├── internal
│   ├── pycolmap
│   │   ├── .gitignore
│   │   ├── pycolmap
│   │   │   ├── __init__.py
│   │   │   ├── image.py
│   │   │   ├── camera.py
│   │   │   ├── database.py
│   │   │   └── rotation.py
│   │   ├── README.md
│   │   ├── tools
│   │   │   ├── delete_images.py
│   │   │   ├── transform_model.py
│   │   │   ├── write_camera_track_to_bundler.py
│   │   │   ├── colmap_to_nvm.py
│   │   │   ├── save_cameras_as_ply.py
│   │   │   ├── write_depthmap_to_ply.py
│   │   │   └── impute_missing_cameras.py
│   │   └── LICENSE.txt
│   ├── math.py
│   ├── geopoly.py
│   ├── image.py
│   ├── utils.py
│   ├── coord.py
│   ├── ref_utils.py
│   ├── vis.py
│   ├── configs.py
│   └── stepfun.py
├── configs
│   ├── render_config.gin
│   ├── tat.gin
│   ├── llff_256_cluster_uw.gin
│   ├── llff_256.gin
│   ├── llff_fog_sim.gin
│   ├── llff_fog_sim_no_extra_samp.gin
│   └── llff_256_uw.gin
├── scripts
│   ├── render_all_uw.sh
│   ├── train_llff.sh
│   ├── run_all_unit_tests.sh
│   ├── train_uw_cluster.sh
│   ├── train_llff_uw.sh
│   ├── eval_llff_uw.sh
│   ├── render_llff_uw.sh
│   ├── eval_all_uw.sh
│   └── local_colmap_and_resize.sh
├── tests
│   ├── utils_test.py
│   ├── camera_utils_test.py
│   ├── ref_utils_test.py
│   ├── datasets_test.py
│   ├── image_test.py
│   ├── math_test.py
│   ├── geopoly_test.py
│   └── coord_test.py
├── install.sh
├── requirements.txt
├── README.md
├── render.py
└── LICENSE

/internal/pycolmap/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.sw*
--------------------------------------------------------------------------------
/configs/render_config.gin:
--------------------------------------------------------------------------------
1 | Config.render_path = True
2 | Config.render_path_frames = 480
3 | Config.render_video_fps = 60
--------------------------------------------------------------------------------
/configs/tat.gin:
--------------------------------------------------------------------------------
1 | # This config is meant to be run while overriding a 360*.gin config.
2 | 
3 | Config.dataset_loader = 'tat_nerfpp'
4 | Config.near = 0.1
5 | Config.far = 1e6
--------------------------------------------------------------------------------
/internal/pycolmap/pycolmap/__init__.py:
--------------------------------------------------------------------------------
1 | from .camera import Camera
2 | from .database import COLMAPDatabase
3 | from .image import Image
4 | from .scene_manager import SceneManager
5 | from .rotation import Quaternion, DualQuaternion
--------------------------------------------------------------------------------
/internal/pycolmap/README.md:
--------------------------------------------------------------------------------
1 | # pycolmap
2 | Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions.
3 | 
4 | This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output.
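For example, loading a reconstruction and walking over its registered images looks roughly like this (a minimal sketch, assuming a COLMAP sparse model directory such as `sparse/0`):

```python
from pycolmap import SceneManager

# Folder containing cameras.bin / images.bin / points3D.bin.
scene_manager = SceneManager("path/to/sparse/0")
scene_manager.load()

for image_id, image in scene_manager.images.items():
    camera = scene_manager.cameras[image.camera_id]
    # image.R() is the world-to-camera rotation and image.C() the camera
    # center in world coordinates.
    print(image.name, 0.5 * (camera.fx + camera.fy), image.C())
```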
5 | -------------------------------------------------------------------------------- /scripts/render_all_uw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #CKPT_DIR="ckpt/uw/ablation_v1" 4 | CKPT_DIR="ckpt/uw/cvpr2023" 5 | #export CUDA_VISIBLE_DEVICES=1 6 | 7 | EXPERIMENT=uw 8 | DATA_DIR=data/"$EXPERIMENT" 9 | 10 | 11 | 12 | for EXP_DIR in $CKPT_DIR/*; 13 | do 14 | EXP_NAME=${EXP_DIR##*/} 15 | echo $EXP_NAME 16 | arrSCENE=(${EXP_NAME//_/ }) 17 | SCENE="${arrSCENE[0]}" 18 | echo $SCENE 19 | 20 | python -m render \ 21 | --gin_configs=${EXP_DIR}/config.gin \ 22 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 23 | --gin_bindings="Config.checkpoint_dir = '${EXP_DIR}'" \ 24 | --gin_bindings="Config.render_path = True" \ 25 | --gin_bindings="Config.render_path_frames = 240" \ 26 | --gin_bindings="Config.render_dir = '${EXP_DIR}/render/'" \ 27 | --gin_bindings="Config.render_video_fps = 5" \ 28 | --logtostderr 29 | 30 | 31 | done 32 | -------------------------------------------------------------------------------- /tests/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | from absl.testing import absltest 18 | 19 | from internal import utils 20 | 21 | 22 | class UtilsTest(absltest.TestCase): 23 | 24 | def test_dummy_rays(self): 25 | """Ensures that the dummy Rays object is correctly initialized.""" 26 | rays = utils.dummy_rays() 27 | self.assertEqual(rays.origins.shape[-1], 3) 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /scripts/train_llff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
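# Usage sketch: run from the repo root as ./scripts/train_llff.sh; it assumes
# the LLFF "fern" scene sits under data/nerf_llff_data/. Note that the rm below
# wipes any previous checkpoints, so training always starts from scratch.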
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=fern 19 | EXPERIMENT=llff 20 | DATA_DIR=data/nerf_llff_data 21 | CHECKPOINT_DIR=ckpt/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | rm "$CHECKPOINT_DIR"/* 24 | python -m train \ 25 | --gin_configs=configs/llff_256.gin \ 26 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 27 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/run_all_unit_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | python -m unittest tests.camera_utils_test 18 | python -m unittest tests.geopoly_test 19 | python -m unittest tests.stepfun_test 20 | python -m unittest tests.coord_test 21 | python -m unittest tests.image_test 22 | python -m unittest tests.ref_utils_test 23 | python -m unittest tests.utils_test 24 | python -m unittest tests.datasets_test 25 | python -m unittest tests.math_test 26 | python -m unittest tests.render_test 27 | -------------------------------------------------------------------------------- /internal/pycolmap/pycolmap/image.py: -------------------------------------------------------------------------------- 1 | # Author: True Price 2 | 3 | import numpy as np 4 | 5 | #------------------------------------------------------------------------------- 6 | # 7 | # Image 8 | # 9 | #------------------------------------------------------------------------------- 10 | 11 | class Image: 12 | def __init__(self, name_, camera_id_, q_, tvec_): 13 | self.name = name_ 14 | self.camera_id = camera_id_ 15 | self.q = q_ 16 | self.tvec = tvec_ 17 | 18 | self.points2D = np.empty((0, 2), dtype=np.float64) 19 | self.point3D_ids = np.empty((0,), dtype=np.uint64) 20 | 21 | #--------------------------------------------------------------------------- 22 | 23 | def R(self): 24 | return self.q.ToR() 25 | 26 | #--------------------------------------------------------------------------- 27 | 28 | def C(self): 29 | return -self.R().T.dot(self.tvec) 30 | 31 | #--------------------------------------------------------------------------- 32 | 33 | @property 34 | def t(self): 35 | return self.tvec 36 | -------------------------------------------------------------------------------- /internal/pycolmap/tools/delete_images.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import numpy as np 5 | 6 | from pycolmap import DualQuaternion, Image, SceneManager 7 | 8 | 9 | #------------------------------------------------------------------------------- 10 | 11 | def main(args): 12 | scene_manager = SceneManager(args.input_folder) 13 | scene_manager.load() 14 | 15 | image_ids = map(scene_manager.get_image_from_name, 16 | iter(lambda: sys.stdin.readline().strip(), "")) 17 | 
scene_manager.delete_images(image_ids) 18 | 19 | scene_manager.save(args.output_folder) 20 | 21 | 22 | #------------------------------------------------------------------------------- 23 | 24 | if __name__ == "__main__": 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser( 28 | description="Deletes images (filenames read from stdin) from a model.", 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | 31 | parser.add_argument("input_folder") 32 | parser.add_argument("output_folder") 33 | 34 | args = parser.parse_args() 35 | 36 | main(args) 37 | -------------------------------------------------------------------------------- /scripts/train_uw_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file was modified by Deborah Levy 3 | # Copyright 2022 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | #export CUDA_VISIBLE_DEVICES=0 18 | 19 | SCENE=matanWb_strong 20 | EXPERIMENT=uw 21 | DATA_DIR=/root/Deborah/nerfren/load 22 | CHECKPOINT_DIR=ckpt/nerf_results/"$EXPERIMENT"/"$SCENE" 23 | 24 | rm "$CHECKPOINT_DIR"/* 25 | python -m train \ 26 | --gin_configs=configs/llff_256_cluster.gin \ 27 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 28 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 29 | --logtostderr 30 | -------------------------------------------------------------------------------- /scripts/train_llff_uw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## This file was modified by Deborah Levy 3 | # Copyright 2022 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
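# Note: any Config field can be overridden from the command line through
# --gin_bindings, exactly like the bindings passed below; for example
# (illustrative only):
#   --gin_bindings="Config.max_steps = 250000"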
16 | 17 | #export CUDA_VISIBLE_DEVICES=1 18 | 19 | SCENE=Panama 20 | EXPERIMENT=uw 21 | EXPERIMENT_NAME=debug_config_uwmlp_xyz_atten 22 | DATA_DIR=data/ 23 | CHECKPOINT_DIR=ckpt/"$EXPERIMENT"/"$SCENE"_"$EXPERIMENT_NAME" 24 | 25 | rm "$CHECKPOINT_DIR"/* 26 | python -m train \ 27 | --gin_configs=configs/llff_256_uw.gin \ 28 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 29 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 30 | --logtostderr 31 | -------------------------------------------------------------------------------- /internal/pycolmap/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 True Price, UNC Chapel Hill 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/llff_256_cluster_uw.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0. 3 | Config.far = 1. 
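# Note: with Config.forward_facing = True (set below), LLFF-style scenes are
# traced in normalized device coordinates, so near = 0 and far = 1 span the
# full NDC depth range rather than metric distances.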
4 | Config.factor = 1 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | Config.batch_size = 16384 8 | Config.print_every = 1000 # TODO Deborah, added to reduce calls to wandb.log 9 | Config.eval_on_train = False # TODO Naama 10 | Config.eval_only_once = True 11 | Config.use_uw_mlp = True 12 | Config.use_uw_bs_initial_loss = False 13 | Config.use_uw_sig_obj_loss = False 14 | Config.data_loss_type = 'rawnerf' 15 | UWMLP.uw_old_model = True 16 | UWMLP.uw_fog_model = True 17 | 18 | 19 | Model.ray_shape = 'cylinder' 20 | Model.opaque_background = True 21 | Model.num_levels = 2 22 | Model.num_prop_samples = 128 23 | Model.num_nerf_samples = 32 24 | 25 | PropMLP.net_depth = 4 26 | PropMLP.net_width = 256 27 | PropMLP.basis_shape = 'octahedron' 28 | PropMLP.basis_subdivisions = 1 29 | PropMLP.disable_density_normals = True 30 | PropMLP.disable_rgb = True 31 | 32 | NerfMLP.net_depth = 8 33 | NerfMLP.net_width = 256 34 | NerfMLP.basis_shape = 'octahedron' 35 | NerfMLP.basis_subdivisions = 1 36 | NerfMLP.disable_density_normals = True 37 | 38 | NerfMLP.max_deg_point = 16 39 | PropMLP.max_deg_point = 16 40 | -------------------------------------------------------------------------------- /scripts/eval_llff_uw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | #export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=Panama 19 | EXPERIMENT=uw 20 | EXPERIMENT_NAME=exp1 21 | DATA_DIR=data/ 22 | CHECKPOINT_DIR=ckpt/"$EXPERIMENT"/"$SCENE"_"$EXPERIMENT_NAME" 23 | 24 | python -m eval \ 25 | --gin_configs=${CHECKPOINT_DIR}/config.gin \ 26 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 27 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 28 | --gin_bindings="Config.eval_on_train = True" \ 29 | --gin_bindings="Config.eval_only_once = True" \ 30 | --logtostderr 31 | 32 | # change Config.eval_on_train to False to get the rendering of the test set 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /configs/llff_256.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0. 3 | Config.far = 1. 
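# Note: Config.factor is the image downsampling factor; this config trains at
# quarter resolution (factor = 4), whereas the *_uw configs use factor = 1,
# i.e. full-resolution images.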
4 | Config.factor = 4 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | Config.batch_size = 2048 8 | Config.eval_on_train = False # TODO Naama 9 | Config.eval_only_once = True 10 | Config.eval_save_ray_data = True # TODO Naame 11 | Config.lr_init = 0.00025 12 | Config.lr_final = 0.000025 13 | Config.max_steps = 2000000 14 | Model.ray_shape = 'cylinder' 15 | Model.opaque_background = True 16 | Model.num_levels = 2 17 | Model.num_prop_samples = 128 18 | Model.num_nerf_samples = 32 19 | 20 | PropMLP.net_depth = 4 21 | PropMLP.net_width = 256 22 | PropMLP.basis_shape = 'octahedron' 23 | PropMLP.basis_subdivisions = 1 24 | PropMLP.disable_density_normals = True 25 | PropMLP.disable_rgb = True 26 | 27 | NerfMLP.net_depth = 8 28 | NerfMLP.net_width = 256 29 | NerfMLP.basis_shape = 'octahedron' 30 | NerfMLP.basis_subdivisions = 1 31 | NerfMLP.disable_density_normals = True 32 | 33 | NerfMLP.max_deg_point = 16 34 | PropMLP.max_deg_point = 16 35 | 36 | # RGB activation we use for linear color outputs is exp(x - 5). 37 | # NerfMLP.rgb_padding = 0. 38 | # NerfMLP.rgb_activation = @math.safe_exp 39 | # NerfMLP.rgb_bias = -5. 40 | # PropMLP.rgb_padding = 0. 41 | # PropMLP.rgb_activation = @math.safe_exp 42 | # PropMLP.rgb_bias = -5. -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | cd /root/Deborah/uw-multinerf/ 2 | apt-get update && apt-get -y install sudo 3 | apt-get install vim 4 | apt-get install git 5 | apt-get install -y wget 6 | apt-get install ffmpeg libsm6 libxext6 -y 7 | apt-get install python3-pip 8 | # apt install python3.9 9 | # sudo ln -s /usr/bin/python3.9 /usr/bin/python 10 | 11 | wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && /bin/bash ~/miniconda.sh -b -p /opt/conda 12 | 13 | mkdir -p ~/miniconda3 14 | wget https://repo.anaconda.com/miniconda/Miniconda3-py39_22.11.1-1-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 15 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 16 | rm -rf ~/miniconda3/miniconda.sh 17 | 18 | 19 | 20 | 21 | # Prepare pip. 22 | pip install --upgrade pip 23 | pip install https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.1+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl 24 | # pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 25 | pip install -r requirements.txt 26 | # Manually install rmbrualla's `pycolmap` (don't use pip's! It's different). 27 | git clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap 28 | pip install pypfm 29 | pip install wandb 30 | wandb login 31 | # Confirm that all the unit tests pass. 32 | ./scripts/run_all_unit_tests.sh 33 | : 34 | -------------------------------------------------------------------------------- /scripts/render_llff_uw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file was modified by Deborah Levy 3 | # Copyright 2022 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | #export CUDA_VISIBLE_DEVICES=0
18 | 
19 | SCENE=Panama
20 | EXPERIMENT=uw
21 | EXPERIMENT_NAME=exp1
22 | DATA_DIR=data/
23 | CHECKPOINT_DIR=ckpt/"$EXPERIMENT"/"$SCENE"_"$EXPERIMENT_NAME"
24 | 
25 | python -m render \
26 |   --gin_configs=${CHECKPOINT_DIR}/config.gin \
27 |   --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \
28 |   --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \
29 |   --gin_bindings="Config.render_path = True" \
30 |   --gin_bindings="Config.render_path_frames = 240" \
31 |   --gin_bindings="Config.render_dir = '${CHECKPOINT_DIR}/render/'" \
32 |   --gin_bindings="Config.render_video_fps = 5" \
33 |   --logtostderr
34 | 
--------------------------------------------------------------------------------
/scripts/eval_all_uw.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #CKPT_DIR="ckpt/uw/ablation_v1"
4 | CKPT_DIR="ckpt/results/uw/cvpr2023/"
5 | #export CUDA_VISIBLE_DEVICES=0
6 | 
7 | EXPERIMENT=uw
8 | DATA_DIR=data/"$EXPERIMENT"
9 | 
10 | 
11 | 
12 | for EXP_DIR in $CKPT_DIR/*;
13 | do
14 |   EXP_NAME=${EXP_DIR##*/}
15 |   echo $EXP_NAME
16 |   arrSCENE=(${EXP_NAME//_/ })
17 |   SCENE="${arrSCENE[0]}"
18 |   echo $SCENE
19 |   for mode in train test
20 |   do
21 |     PREDS_DIR="${EXP_DIR}/${mode}_preds"
22 |     if test -d "$PREDS_DIR"; then
23 |       echo "----------------------------------"
24 |       echo " Skipping $PREDS_DIR...";
25 |       echo "----------------------------------"
26 |     else
27 |       echo "----------------------------------"
28 |       echo " Processing $PREDS_DIR...";
29 |       echo "----------------------------------"
30 |       if [[ $mode == "train" ]]; then
31 |         python -m eval \
32 |           --gin_configs=${EXP_DIR}/config.gin \
33 |           --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \
34 |           --gin_bindings="Config.checkpoint_dir = '${EXP_DIR}'" \
35 |           --gin_bindings="Config.eval_on_train = True" \
36 |           --gin_bindings="Config.eval_only_once = True" \
37 |           --logtostderr
38 |       else
39 |         python -m eval \
40 |           --gin_configs=${EXP_DIR}/config.gin \
41 |           --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \
42 |           --gin_bindings="Config.checkpoint_dir = '${EXP_DIR}'" \
43 |           --gin_bindings="Config.eval_only_once = True" \
44 |           --logtostderr
45 |       fi
46 |     fi
47 |   done
48 | done
49 | 
50 | 
51 | 
52 | 
53 | 
--------------------------------------------------------------------------------
/internal/pycolmap/tools/transform_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | 
4 | import numpy as np
5 | 
6 | from pycolmap import Quaternion, SceneManager
7 | 
8 | 
9 | #-------------------------------------------------------------------------------
10 | 
11 | def main(args):
12 |     scene_manager = SceneManager(args.input_folder)
13 |     scene_manager.load()
14 | 
15 |     # each line of input is expected to correspond to one row of the matrix
16 |     P = np.array([
17 |         list(map(float, sys.stdin.readline().strip().split())) for _ in range(3)])
18 | 
19 |     scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3]
20 | 
21 |     # get rotation without any global scaling (assuming isotropic scaling)
22 |     scale = np.cbrt(np.linalg.det(P[:,:3]))
23 |     q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale)
24 | 
25 |     for image in scene_manager.images.values():
26 |         image.q *= q_old_from_new
27 |         image.tvec = scale * image.tvec - image.R().dot(P[:,3])
28 | 
29 |     scene_manager.save(args.output_folder)
30 | 
31 | 
32 | #-------------------------------------------------------------------------------
33 | 
34 | if __name__ == "__main__":
35 |     import argparse
36 | 
37 |     parser = argparse.ArgumentParser(
38 |         description="Apply a 3x4 transformation matrix to a COLMAP model and "
39 |         "save the result as a new model. Row-major input can be piped in from "
40 |         "a file or entered via the command line.",
41 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
42 | 
43 |     parser.add_argument("input_folder")
44 |     parser.add_argument("output_folder")
45 | 
46 |     args = parser.parse_args()
47 | 
48 |     main(args)
49 | 
--------------------------------------------------------------------------------
/internal/pycolmap/tools/write_camera_track_to_bundler.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | 
4 | import numpy as np
5 | 
6 | from pycolmap import SceneManager
7 | 
8 | 
9 | #-------------------------------------------------------------------------------
10 | 
11 | def main(args):
12 |     scene_manager = SceneManager(args.input_folder)
13 |     scene_manager.load_cameras()
14 |     scene_manager.load_images()
15 | 
16 |     if args.sort:
17 |         images = sorted(
18 |             scene_manager.images.values(), key=lambda im: im.name)
19 |     else:
20 |         images = scene_manager.images.values()
21 | 
22 |     fid = open(args.output_file, "w")
23 |     fid_filenames = open(args.output_file + ".list.txt", "w")
24 | 
25 |     print("# Bundle file v0.3", file=fid)
26 |     print(len(images), 0, file=fid)
27 | 
28 |     for image in images:
29 |         print(image.name, file=fid_filenames)
30 |         camera = scene_manager.cameras[image.camera_id]
31 |         print(0.5 * (camera.fx + camera.fy), 0, 0, file=fid)
32 |         R, t = image.R(), image.t
33 |         print(R[0, 0], R[0, 1], R[0, 2], file=fid)
34 |         print(-R[1, 0], -R[1, 1], -R[1, 2], file=fid)
35 |         print(-R[2, 0], -R[2, 1], -R[2, 2], file=fid)
36 |         print(t[0], -t[1], -t[2], file=fid)
37 | 
38 |     fid.close()
39 |     fid_filenames.close()
40 | 
41 | 
42 | #-------------------------------------------------------------------------------
43 | 
44 | if __name__ == "__main__":
45 |     import argparse
46 | 
47 |     parser = argparse.ArgumentParser(
48 |         description="Saves the camera positions in the Bundler format. Note "
49 |         "that 3D points are not saved.",
50 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
51 | 
52 |     parser.add_argument("input_folder")
53 |     parser.add_argument("output_file")
54 | 
55 |     parser.add_argument("--sort", default=False, action="store_true",
56 |         help="sort the images by their filename")
57 | 
58 |     args = parser.parse_args()
59 | 
60 |     main(args)
61 | 
--------------------------------------------------------------------------------
/configs/llff_fog_sim.gin:
--------------------------------------------------------------------------------
1 | Config.dataset_loader = 'llff'
2 | Config.near = 0.
3 | Config.far = 1.
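# Note: this config differs from llff_fog_sim_no_extra_samp.gin only in
# Config.extra_samples (True here, False there) and in the learning-rate
# schedule.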
4 | Config.factor = 4 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | Config.batch_size = 2048 8 | Config.print_every = 1000 9 | Config.eval_on_train = False 10 | Config.eval_only_once = True 11 | Config.train_render_every = 5000 12 | Config.use_uw_mlp = True 13 | Config.use_uw_acc_weights_loss = False 14 | Config.use_uw_acc_trans_loss= True 15 | Config.max_steps = 250000 16 | Config.data_loss_type = 'rawnerf' 17 | Config.distortion_loss_mult = 0. 18 | Config.interlevel_loss_mult = 1 19 | Config.data_loss_mult = 1 20 | Config.uw_initial_acc_weights_loss_mult = 0.0001 21 | Config.uw_final_acc_weights_loss_mult = 0.0001 22 | Config.uw_initial_acc_trans_loss_mult = 0.00001 23 | Config.uw_final_acc_trans_loss_mult = 0.00001 24 | Config.uw_acc_loss_factor = 1 25 | Config.extra_samples = True 26 | Config.gen_eq = False #need to update the values twice, I know it's annoying, I will solve it soon and update the code 27 | UWMLP.gen_eq = False 28 | Config.uw_fog_model = True 29 | UWMLP.uw_fog_model = True 30 | Config.uw_atten_xyz = False 31 | UWMLP.uw_atten_xyz = False 32 | Config.uw_old_model = False 33 | UWMLP.uw_old_model= False #need to update the values twice, I know it's annoying, I will solve it soon and update the code 34 | UWMLP.uw_rgb_dir = False 35 | 36 | Config.lr_init = 0.002 37 | Config.lr_final = 0.00002 38 | Model.ray_shape = 'cylinder' 39 | Model.opaque_background = False 40 | Model.num_levels = 2 41 | Model.num_prop_samples = 128 42 | Model.num_nerf_samples = 32 43 | PropMLP.density_noise = 0 44 | PropMLP.net_depth = 4 45 | PropMLP.net_width = 256 46 | PropMLP.basis_shape = 'octahedron' 47 | PropMLP.basis_subdivisions = 1 48 | PropMLP.disable_density_normals = True 49 | PropMLP.disable_rgb = True 50 | 51 | PropMLP.density_bias = -1 52 | 53 | UWMLP.density_noise = 0 54 | UWMLP.net_depth = 8 55 | UWMLP.net_width = 256 56 | UWMLP.basis_shape = 'octahedron' 57 | UWMLP.basis_subdivisions = 1 58 | UWMLP.disable_density_normals = True 59 | UWMLP.water_bias = 0 60 | UWMLP.density_bias = 0 61 | UWMLP.max_deg_point = 16 62 | PropMLP.max_deg_point = 16 63 | -------------------------------------------------------------------------------- /configs/llff_fog_sim_no_extra_samp.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0. 3 | Config.far = 1. 4 | Config.factor = 4 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | Config.batch_size = 2048 8 | Config.print_every = 1000 9 | Config.eval_on_train = False 10 | Config.eval_only_once = True 11 | Config.train_render_every = 5000 12 | Config.use_uw_mlp = True 13 | Config.use_uw_acc_weights_loss = False 14 | Config.use_uw_acc_trans_loss= True 15 | Config.max_steps = 250000 16 | Config.data_loss_type = 'rawnerf' 17 | Config.distortion_loss_mult = 0. 
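# Note: distortion_loss_mult = 0 disables the mip-NeRF 360 distortion
# regularizer, which would otherwise penalize density spread out along a ray --
# exactly what a scattering medium (fog, water) is expected to produce.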
18 | Config.interlevel_loss_mult = 1 19 | Config.data_loss_mult = 1 20 | Config.uw_initial_acc_weights_loss_mult = 0.0001 21 | Config.uw_final_acc_weights_loss_mult = 0.0001 22 | Config.uw_initial_acc_trans_loss_mult = 0.00001 23 | Config.uw_final_acc_trans_loss_mult = 0.00001 24 | Config.uw_acc_loss_factor = 1 25 | Config.extra_samples = False 26 | Config.gen_eq = False #need to update the values twice, I know it's annoying, I will solve it soon and update the code 27 | UWMLP.gen_eq = False 28 | Config.uw_fog_model = True 29 | UWMLP.uw_fog_model = True 30 | Config.uw_atten_xyz = False 31 | UWMLP.uw_atten_xyz = False 32 | Config.uw_old_model = False 33 | UWMLP.uw_old_model= False #need to update the values twice, I know it's annoying, I will solve it soon and update the code 34 | UWMLP.uw_rgb_dir = False 35 | 36 | Config.lr_init = 0.003 37 | Config.lr_final = 0.00001 38 | Model.ray_shape = 'cylinder' 39 | Model.opaque_background = False 40 | Model.num_levels = 2 41 | Model.num_prop_samples = 128 42 | Model.num_nerf_samples = 32 43 | PropMLP.density_noise = 0 44 | PropMLP.net_depth = 4 45 | PropMLP.net_width = 256 46 | PropMLP.basis_shape = 'octahedron' 47 | PropMLP.basis_subdivisions = 1 48 | PropMLP.disable_density_normals = True 49 | PropMLP.disable_rgb = True 50 | 51 | PropMLP.density_bias = -1 52 | 53 | UWMLP.density_noise = 0 54 | UWMLP.net_depth = 8 55 | UWMLP.net_width = 256 56 | UWMLP.basis_shape = 'octahedron' 57 | UWMLP.basis_subdivisions = 1 58 | UWMLP.disable_density_normals = True 59 | UWMLP.water_bias = 0 60 | UWMLP.density_bias = 0 61 | UWMLP.max_deg_point = 16 62 | PropMLP.max_deg_point = 16 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | asttokens==2.1.0 3 | astunparse==1.6.3 4 | backcall==0.2.0 5 | cachetools==5.2.0 6 | certifi @ file:///croot/certifi_1665076670883/work/certifi 7 | charset-normalizer==2.1.1 8 | chex==0.1.5 9 | click==8.1.3 10 | commonmark==0.9.1 11 | contourpy==1.0.6 12 | cycler==0.11.0 13 | decorator==5.1.1 14 | dm-pix==0.3.4 15 | dm-tree==0.1.7 16 | docker-pycreds==0.4.0 17 | etils==0.9.0 18 | executing==1.2.0 19 | flatbuffers==22.10.26 20 | flax==0.6.1 21 | fonttools==4.38.0 22 | gast==0.4.0 23 | gin-config==0.5.0 24 | gitdb==4.0.9 25 | GitPython==3.1.29 26 | google-auth==2.14.0 27 | google-auth-oauthlib==0.4.6 28 | google-pasta==0.2.0 29 | grpcio==1.50.0 30 | h5py==3.7.0 31 | idna==3.4 32 | importlib-metadata==5.0.0 33 | importlib-resources==5.10.0 34 | ipython==8.6.0 35 | jax==0.3.23 36 | jaxlib==0.3.22+cuda11.cudnn82 37 | jedi==0.18.1 38 | keras==2.10.0 39 | Keras-Preprocessing==1.1.2 40 | kiwisolver==1.4.4 41 | libclang==14.0.6 42 | Markdown==3.4.1 43 | MarkupSafe==2.1.1 44 | matplotlib==3.6.2 45 | matplotlib-inline==0.1.6 46 | mediapy==1.1.2 47 | msgpack==1.0.4 48 | numpy==1.23.4 49 | oauthlib==3.2.2 50 | opencv-python==4.6.0.66 51 | opt-einsum==3.3.0 52 | optax==0.1.3 53 | packaging==21.3 54 | parso==0.8.3 55 | pathtools==0.1.2 56 | pexpect==4.8.0 57 | pickleshare==0.7.5 58 | pillow>=10.0.1 59 | promise==2.3 60 | prompt-toolkit==3.0.32 61 | protobuf==3.19.6 62 | psutil==5.9.4 63 | ptyprocess==0.7.0 64 | pure-eval==0.2.2 65 | pyasn1==0.4.8 66 | pyasn1-modules==0.2.8 67 | Pygments==2.13.0 68 | pyparsing==3.0.9 69 | python-dateutil==2.8.2 70 | PyYAML==6.0 71 | rawpy==0.17.3 72 | requests==2.28.1 73 | requests-oauthlib==1.3.1 74 | rich==12.6.0 75 | rsa==4.9 76 | scipy==1.9.3 77 | 
sentry-sdk==1.10.1 78 | setproctitle==1.3.2 79 | shortuuid==1.0.11 80 | six==1.16.0 81 | smmap==5.0.0 82 | stack-data==0.6.0 83 | tensorboard==2.10.1 84 | tensorboard-data-server==0.6.1 85 | tensorboard-plugin-wit==1.8.1 86 | tensorflow==2.10.0 87 | tensorflow-estimator==2.10.0 88 | tensorflow-io-gcs-filesystem==0.27.0 89 | termcolor==2.1.0 90 | toolz==0.12.0 91 | traitlets==5.5.0 92 | typing_extensions==4.4.0 93 | urllib3==1.26.12 94 | wandb==0.13.5 95 | wcwidth==0.2.5 96 | Werkzeug==2.2.2 97 | wrapt==1.14.1 98 | zipp==3.10.0 99 | -------------------------------------------------------------------------------- /configs/llff_256_uw.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0. 3 | Config.far = 1. 4 | Config.factor = 1 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | Config.batch_size = 16384 8 | Config.print_every = 1000 9 | Config.eval_on_train = False 10 | Config.eval_only_once = True 11 | Config.train_render_every = 5000 12 | Config.use_uw_mlp = True 13 | Config.use_uw_acc_weights_loss = False 14 | Config.use_uw_acc_trans_loss= True 15 | Config.max_steps = 250000 16 | Config.data_loss_type = 'rawnerf' 17 | Config.distortion_loss_mult = 0. 18 | Config.interlevel_loss_mult = 1 19 | Config.data_loss_mult = 1 20 | Config.uw_initial_acc_weights_loss_mult = 0.001 21 | Config.uw_final_acc_weights_loss_mult = 0.001 22 | Config.uw_initial_acc_trans_loss_mult = 0.0001 23 | Config.uw_final_acc_trans_loss_mult = 0.0001 24 | Config.use_uw_sig_med_loss = False 25 | Config.uw_sig_med_mult = 0.0001 26 | Config.gen_eq = False #need to update the values twice, I know it's annoying, I will solve it soon and update the code 27 | UWMLP.gen_eq = False 28 | Config.uw_fog_model = False 29 | UWMLP.uw_fog_model = False 30 | Config.uw_atten_xyz = False 31 | UWMLP.uw_atten_xyz = False 32 | Config.uw_old_model = False 33 | UWMLP.uw_old_model= False #need to update the values twice, I know it's annoying, I will solve it soon and update the code 34 | 35 | UWMLP.uw_rgb_dir = False 36 | 37 | Config.lr_init = 0.002 38 | Config.lr_final = 0.00002 39 | #Model.num_glo_features = 9 40 | #Model.num_glo_embeddings = 25 41 | Model.ray_shape = 'cylinder' 42 | Model.opaque_background = False 43 | Model.num_levels = 2 44 | Model.num_prop_samples = 128 45 | Model.num_nerf_samples = 32 46 | PropMLP.density_noise = 0 47 | PropMLP.net_depth = 4 48 | PropMLP.net_width = 256 49 | PropMLP.basis_shape = 'octahedron' 50 | PropMLP.basis_subdivisions = 1 51 | PropMLP.disable_density_normals = True 52 | PropMLP.disable_rgb = True 53 | 54 | PropMLP.density_bias = 0 55 | 56 | UWMLP.density_noise = 0 57 | UWMLP.net_depth = 8 58 | UWMLP.net_width = 256 59 | UWMLP.basis_shape = 'octahedron' 60 | UWMLP.basis_subdivisions = 1 61 | UWMLP.disable_density_normals = True 62 | UWMLP.water_bias = 0 63 | UWMLP.density_bias = 0 64 | 65 | UWMLP.max_deg_point = 16 66 | PropMLP.max_deg_point = 16 67 | 68 | -------------------------------------------------------------------------------- /internal/pycolmap/tools/colmap_to_nvm.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import sys 3 | sys.path.append("..") 4 | 5 | import numpy as np 6 | 7 | from pycolmap import Quaternion, SceneManager 8 | 9 | 10 | #------------------------------------------------------------------------------- 11 | 12 | def main(args): 13 | scene_manager = SceneManager(args.input_folder) 14 | scene_manager.load() 15 | 16 
|     with open(args.output_file, "w") as fid:
17 |         fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images)))
18 | 
19 |         image_fmt_str = " {:.3f} " + 7 * "{:.7f} "
20 |         for image_id, image in scene_manager.images.items():
21 |             camera = scene_manager.cameras[image.camera_id]
22 |             f = 0.5 * (camera.fx + camera.fy)
23 |             fid.write(args.image_name_prefix + image.name)
24 |             fid.write(image_fmt_str.format(
25 |                 *((f,) + tuple(image.q.q) + tuple(image.C()))))
26 |             if camera.distortion_func is None:
27 |                 fid.write("0 0\n")
28 |             else:
29 |                 fid.write("{:.7f} 0\n".format(-camera.k1))
30 | 
31 |         image_id_to_idx = dict(
32 |             (image_id, i) for i, image_id in enumerate(scene_manager.images))
33 | 
34 |         fid.write("{:d}\n".format(len(scene_manager.points3D)))
35 |         for i, point3D_id in enumerate(scene_manager.point3D_ids):
36 |             fid.write(
37 |                 "{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i]))
38 |             fid.write(
39 |                 "{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i]))
40 |             keypoints = [
41 |                 (image_id_to_idx[image_id], kp_idx) +
42 |                 tuple(scene_manager.images[image_id].points2D[kp_idx])
43 |                 for image_id, kp_idx in
44 |                 scene_manager.point3D_id_to_images[point3D_id]]
45 |             fid.write("{:d}".format(len(keypoints)))
46 |             fid.write(
47 |                 (len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format(
48 |                     *itertools.chain(*keypoints)))
49 | 
50 | 
51 | #-------------------------------------------------------------------------------
52 | 
53 | if __name__ == "__main__":
54 |     import argparse
55 | 
56 |     parser = argparse.ArgumentParser(
57 |         description="Save a COLMAP reconstruction in the NVM format "
58 |         "(http://ccwu.me/vsfm/doc.html#nvm).",
59 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
60 | 
61 |     parser.add_argument("input_folder")
62 |     parser.add_argument("output_file")
63 | 
64 |     parser.add_argument("--image_name_prefix", type=str, default="",
65 |         help="prefix image names with this string (e.g., 'images/')")
66 | 
67 |     args = parser.parse_args()
68 | 
69 |     main(args)
70 | 
--------------------------------------------------------------------------------
/tests/camera_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for camera_utils."""
16 | 
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from internal import camera_utils
20 | from jax import random
21 | import jax.numpy as jnp
22 | import numpy as np
23 | 
24 | 
25 | class CameraUtilsTest(parameterized.TestCase):
26 | 
27 |   def test_convert_to_ndc(self):
28 |     rng = random.PRNGKey(0)
29 |     for _ in range(10):
30 |       # Random pinhole camera intrinsics.
31 |       key, rng = random.split(rng)
32 |       focal, width, height = random.uniform(key, (3,), minval=100., maxval=200.)
33 |       camtopix = camera_utils.intrinsic_matrix(focal, focal, width / 2.,
34 |                                                height / 2.)
35 |       pixtocam = np.linalg.inv(camtopix)
36 |       near = 1.
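      # For reference, the NDC mapping exercised below sends a world-space
      # point (x, y, z) with z < 0 to
      #   x' = -f/(W/2) * x/z,  y' = -f/(H/2) * y/z,  z' = 1 + 2*near/z,
      # so the plane z = -near maps to z' = -1 and z -> -inf maps to z' -> 1.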
37 | 38 | # Random rays, pointing forward (negative z direction). 39 | num_rays = 1000 40 | key, rng = random.split(rng) 41 | origins = jnp.array([0., 0., 1.]) 42 | origins += random.uniform(key, (num_rays, 3), minval=-1., maxval=1.) 43 | directions = jnp.array([0., 0., -1.]) 44 | directions += random.uniform(key, (num_rays, 3), minval=-.5, maxval=.5) 45 | 46 | # Project world-space points along each ray into NDC space. 47 | t = jnp.linspace(0., 1., 10) 48 | pts_world = origins + t[:, None, None] * directions 49 | pts_ndc = jnp.stack([ 50 | -focal / (.5 * width) * pts_world[..., 0] / pts_world[..., 2], 51 | -focal / (.5 * height) * pts_world[..., 1] / pts_world[..., 2], 52 | 1. + 2. * near / pts_world[..., 2], 53 | ], 54 | axis=-1) 55 | 56 | # Get NDC space rays. 57 | origins_ndc, directions_ndc = camera_utils.convert_to_ndc( 58 | origins, directions, pixtocam, near) 59 | 60 | # Ensure that the NDC space points lie on the calculated rays. 61 | directions_ndc_norm = jnp.linalg.norm( 62 | directions_ndc, axis=-1, keepdims=True) 63 | directions_ndc_unit = directions_ndc / directions_ndc_norm 64 | projection = ((pts_ndc - origins_ndc) * directions_ndc_unit).sum(axis=-1) 65 | pts_ndc_proj = origins_ndc + directions_ndc_unit * projection[..., None] 66 | 67 | # pts_ndc should be close to their projections pts_ndc_proj onto the rays. 68 | np.testing.assert_allclose(pts_ndc, pts_ndc_proj, atol=1e-5, rtol=1e-5) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /scripts/local_colmap_and_resize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Set to 0 if you do not have a GPU. 18 | USE_GPU=1 19 | # Path to a directory `base/` with images in `base/images/`. 20 | DATASET_PATH=$1 21 | # Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye. 22 | CAMERA=${2:-OPENCV} 23 | 24 | 25 | # Run COLMAP. 26 | 27 | ### Feature extraction 28 | 29 | colmap feature_extractor \ 30 | --database_path "$DATASET_PATH"/database.db \ 31 | --image_path "$DATASET_PATH"/images \ 32 | --ImageReader.single_camera 1 \ 33 | --ImageReader.camera_model "$CAMERA" \ 34 | --SiftExtraction.use_gpu "$USE_GPU" 35 | 36 | 37 | ### Feature matching 38 | 39 | colmap exhaustive_matcher \ 40 | --database_path "$DATASET_PATH"/database.db \ 41 | --SiftMatching.use_gpu "$USE_GPU" 42 | 43 | ## Use if your scene has > 500 images 44 | ## Replace this path with your own local copy of the file. 
45 | ## Download from: https://demuc.de/colmap/#download 46 | # VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin 47 | # colmap vocab_tree_matcher \ 48 | # --database_path "$DATASET_PATH"/database.db \ 49 | # --VocabTreeMatching.vocab_tree_path $VOCABTREE_PATH \ 50 | # --SiftMatching.use_gpu "$USE_GPU" 51 | 52 | 53 | ### Bundle adjustment 54 | 55 | # The default Mapper tolerance is unnecessarily large, 56 | # decreasing it speeds up bundle adjustment steps. 57 | mkdir -p "$DATASET_PATH"/sparse 58 | colmap mapper \ 59 | --database_path "$DATASET_PATH"/database.db \ 60 | --image_path "$DATASET_PATH"/images \ 61 | --output_path "$DATASET_PATH"/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001 63 | 64 | 65 | ### Image undistortion 66 | 67 | ## Use this if you want to undistort your images into ideal pinhole intrinsics. 68 | # mkdir -p "$DATASET_PATH"/dense 69 | # colmap image_undistorter \ 70 | # --image_path "$DATASET_PATH"/images \ 71 | # --input_path "$DATASET_PATH"/sparse/0 \ 72 | # --output_path "$DATASET_PATH"/dense \ 73 | # --output_type COLMAP 74 | 75 | # Resize images. 76 | 77 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_2 78 | 79 | pushd "$DATASET_PATH"/images_2 80 | ls | xargs -P 8 -I {} mogrify -resize 50% {} 81 | popd 82 | 83 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_4 84 | 85 | pushd "$DATASET_PATH"/images_4 86 | ls | xargs -P 8 -I {} mogrify -resize 25% {} 87 | popd 88 | 89 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_8 90 | 91 | pushd "$DATASET_PATH"/images_8 92 | ls | xargs -P 8 -I {} mogrify -resize 12.5% {} 93 | popd 94 | -------------------------------------------------------------------------------- /tests/ref_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ref_utils.""" 16 | 17 | from absl.testing import absltest 18 | from internal import ref_utils 19 | from jax import random 20 | import jax.numpy as jnp 21 | import numpy as np 22 | import scipy 23 | 24 | 25 | def generate_dir_enc_fn_scipy(deg_view): 26 | """Return spherical harmonics using scipy.special.sph_harm.""" 27 | ml_array = ref_utils.get_ml_array(deg_view) 28 | 29 | def dir_enc_fn(theta, phi): 30 | de = [scipy.special.sph_harm(m, l, phi, theta) for m, l in ml_array.T] 31 | de = np.stack(de, axis=-1) 32 | # Split into real and imaginary parts. 33 | return np.concatenate([np.real(de), np.imag(de)], axis=-1) 34 | 35 | return dir_enc_fn 36 | 37 | 38 | class RefUtilsTest(absltest.TestCase): 39 | 40 | def test_reflection(self): 41 | """Make sure reflected vectors have the same angle from normals as input.""" 42 | rng = random.PRNGKey(0) 43 | for shape in [(45, 3), (4, 7, 3)]: 44 | key, rng = random.split(rng) 45 | normals = random.normal(key, shape) 46 | key, rng = random.split(rng) 47 | directions = random.normal(key, shape) 48 | 49 | # Normalize normal vectors. 
50 |       normals = normals / (
51 |           jnp.linalg.norm(normals, axis=-1, keepdims=True) + 1e-10)
52 | 
53 |       reflected_directions = ref_utils.reflect(directions, normals)
54 | 
55 |       cos_angle_original = jnp.sum(directions * normals, axis=-1)
56 |       cos_angle_reflected = jnp.sum(reflected_directions * normals, axis=-1)
57 | 
58 |       np.testing.assert_allclose(
59 |           cos_angle_original, cos_angle_reflected, atol=1E-5, rtol=1E-5)
60 | 
61 |   def test_spherical_harmonics(self):
62 |     """Make sure the fast spherical harmonics are accurate."""
63 |     shape = (12, 11, 13)
64 | 
65 |     # Generate random points on sphere.
66 |     rng = random.PRNGKey(0)
67 |     key1, key2 = random.split(rng)
68 |     theta = random.uniform(key1, shape, minval=0.0, maxval=jnp.pi)
69 |     phi = random.uniform(key2, shape, minval=0.0, maxval=2.0*jnp.pi)
70 | 
71 |     # Convert to Cartesian coordinates.
72 |     x = jnp.sin(theta) * jnp.cos(phi)
73 |     y = jnp.sin(theta) * jnp.sin(phi)
74 |     z = jnp.cos(theta)
75 |     xyz = jnp.stack([x, y, z], axis=-1)
76 | 
77 |     deg_view = 5
78 |     de = ref_utils.generate_dir_enc_fn(deg_view)(xyz)
79 |     de_scipy = generate_dir_enc_fn_scipy(deg_view)(theta, phi)
80 | 
81 |     np.testing.assert_allclose(
82 |         de, de_scipy, atol=0.02, rtol=1e6)  # Only use atol.
83 |     self.assertFalse(jnp.any(jnp.isnan(de)))
84 | 
85 | 
86 | if __name__ == '__main__':
87 |   absltest.main()
88 | 
--------------------------------------------------------------------------------
/internal/pycolmap/tools/save_cameras_as_ply.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | 
4 | import numpy as np
5 | import os
6 | 
7 | from pycolmap import SceneManager
8 | 
9 | 
10 | #-------------------------------------------------------------------------------
11 | 
12 | # Saves the cameras as a mesh
13 | #
14 | # inputs:
15 | # - ply_file: output file
16 | # - images: ordered array of pycolmap Image objects
17 | #   (vertex colors are assigned automatically along a gradient)
18 | # - scale: amount to shrink/grow the camera model
19 | def save_camera_ply(ply_file, images, scale):
20 |     points3D = scale * np.array((
21 |         (0., 0., 0.),
22 |         (-1., -1., 1.),
23 |         (-1., 1., 1.),
24 |         (1., -1., 1.),
25 |         (1., 1., 1.)))
26 | 
27 |     faces = np.array(((0, 2, 1),
28 |                       (0, 4, 2),
29 |                       (0, 3, 4),
30 |                       (0, 1, 3),
31 |                       (1, 2, 4),
32 |                       (1, 4, 3)))
33 | 
34 |     r = np.linspace(0, 255, len(images), dtype=np.uint8)
35 |     g = 255 - r
36 |     b = r - np.linspace(0, 128, len(images), dtype=np.uint8)
37 |     color = np.column_stack((r, g, b))
38 | 
39 |     with open(ply_file, "w") as fid:
40 |         print("ply", file=fid)
41 |         print("format ascii 1.0", file=fid)
42 |         print("element vertex", len(points3D) * len(images), file=fid)
43 |         print("property float x", file=fid)
44 |         print("property float y", file=fid)
45 |         print("property float z", file=fid)
46 |         print("property uchar red", file=fid)
47 |         print("property uchar green", file=fid)
48 |         print("property uchar blue", file=fid)
49 |         print("element face", len(faces) * len(images), file=fid)
50 |         print("property list uchar int vertex_index", file=fid)
51 |         print("end_header", file=fid)
52 | 
53 |         for image, c in zip(images, color):
54 |             for p3D in (points3D.dot(image.R()) + image.C()):
55 |                 print(p3D[0], p3D[1], p3D[2], c[0], c[1], c[2], file=fid)
56 | 
57 |         for i in range(len(images)):
58 |             for f in (faces + len(points3D) * i):
59 |                 print("3 {} {} {}".format(*f), file=fid)
60 | 
61 | 
62 | #-------------------------------------------------------------------------------
63 | 
64 | def main(args):
65 |     scene_manager = SceneManager(args.input_folder)
66 |     scene_manager.load_images()
67 | 
68 |     images = sorted(scene_manager.images.values(),
69 |                     key=lambda image: image.name)
70 | 
71 |     save_camera_ply(args.output_file, images, args.scale)
72 | 
73 | 
74 | #-------------------------------------------------------------------------------
75 | 
76 | if __name__ == "__main__":
77 |     import argparse
78 | 
79 |     parser = argparse.ArgumentParser(
80 |         description="Saves camera positions to a PLY for easy viewing outside "
81 |         "of COLMAP. Currently, camera FoV is not reflected in the output.",
82 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
83 | 
84 |     parser.add_argument("input_folder")
85 |     parser.add_argument("output_file")
86 | 
87 |     parser.add_argument("--scale", type=float, default=1.,
88 |         help="Scaling factor for the camera mesh.")
89 | 
90 |     args = parser.parse_args()
91 | 
92 |     main(args)
93 | 
--------------------------------------------------------------------------------
/tests/datasets_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for datasets."""
16 | 
17 | from absl.testing import absltest
18 | from internal import camera_utils
19 | from internal import configs
20 | from internal import datasets
21 | from jax import random
22 | import numpy as np
23 | 
24 | 
25 | class DummyDataset(datasets.Dataset):
26 | 
27 |   def _load_renderings(self, config):
28 |     """Generates dummy image and pose data."""
29 |     self._n_examples = 2
30 |     self.height = 3
31 |     self.width = 4
32 |     self._resolution = self.height * self.width
33 |     self.focal = 5.
34 |     self.pixtocams = np.linalg.inv(
35 |         camera_utils.intrinsic_matrix(self.focal, self.focal, self.width * 0.5,
36 |                                       self.height * 0.5))
37 | 
38 |     rng = random.PRNGKey(0)
39 | 
40 |     key, rng = random.split(rng)
41 |     images_shape = (self._n_examples, self.height, self.width, 3)
42 |     self.images = random.uniform(key, images_shape)
43 | 
44 |     key, rng = random.split(rng)
45 |     self.camtoworlds = np.stack([
46 |         camera_utils.viewmatrix(*random.normal(k, (3, 3)))
47 |         for k in random.split(key, self._n_examples)
48 |     ],
49 |                                 axis=0)
50 | 
51 | 
52 | class DatasetsTest(absltest.TestCase):
53 | 
54 |   def test_dataset_batch_creation(self):
55 |     np.random.seed(0)
56 |     config = configs.Config(batch_size=8)
57 | 
58 |     # Check shapes are consistent across all ray attributes.
59 |     for split in ['train', 'test']:
60 |       dummy_dataset = DummyDataset(split, '', config)
61 |       rays = dummy_dataset.peek().rays
62 |       sh_gt = rays.origins.shape[:-1]
63 |       for z in rays.__dict__.values():
64 |         if z is not None:
65 |           self.assertEqual(z.shape[:-1], sh_gt)
66 | 
67 |     # Check test batch generation matches golden data.
68 | dummy_dataset = DummyDataset('test', '', config) 69 | batch = dummy_dataset.peek() 70 | 71 | rgb = batch.rgb.ravel() 72 | rgb_gt = np.array([ 73 | 0.5289556, 0.28869557, 0.24527192, 0.12083626, 0.8904066, 0.6259936, 74 | 0.57573485, 0.09355974, 0.8017353, 0.538651, 0.4998169, 0.42061496, 75 | 0.5591258, 0.00577283, 0.6804651, 0.9139203, 0.00444758, 0.96962905, 76 | 0.52956843, 0.38282406, 0.28777933, 0.6640035, 0.39736128, 0.99495006, 77 | 0.13100398, 0.7597165, 0.8532667, 0.67468107, 0.6804743, 0.26873016, 78 | 0.60699487, 0.5722265, 0.44482303, 0.6511061, 0.54807067, 0.09894073 79 | ]) 80 | np.testing.assert_allclose(rgb, rgb_gt, atol=1e-4, rtol=1e-4) 81 | 82 | ray_origins = batch.rays.origins.ravel() 83 | ray_origins_gt = np.array([ 84 | -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, 85 | -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469, 86 | -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224, 87 | -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, 88 | -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469, 89 | -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224, 90 | -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224 91 | ]) 92 | np.testing.assert_allclose( 93 | ray_origins, ray_origins_gt, atol=1e-4, rtol=1e-4) 94 | 95 | ray_dirs = batch.rays.directions.ravel() 96 | ray_dirs_gt = np.array([ 97 | 0.24370372, 0.89296186, -0.5227117, 0.05601424, 0.8468699, -0.57417226, 98 | -0.13167524, 0.8007779, -0.62563276, -0.31936473, 0.75468594, 99 | -0.67709327, 0.17780769, 0.96766925, -0.34928587, -0.0098818, 0.9215773, 100 | -0.4007464, -0.19757128, 0.87548524, -0.4522069, -0.38526076, 101 | 0.82939327, -0.5036674, 0.11191163, 1.0423766, -0.17586003, -0.07577785, 102 | 0.9962846, -0.22732055, -0.26346734, 0.95019263, -0.2787811, 103 | -0.45115682, 0.90410066, -0.3302416 104 | ]) 105 | np.testing.assert_allclose(ray_dirs, ray_dirs_gt, atol=1e-4, rtol=1e-4) 106 | 107 | 108 | if __name__ == '__main__': 109 | absltest.main() 110 | -------------------------------------------------------------------------------- /internal/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Mathy utility functions."""
16 | 
17 | import jax
18 | import jax.numpy as jnp
19 | 
20 | 
21 | def matmul(a, b):
22 |   """jnp.matmul defaults to bfloat16, but this helper function doesn't."""
23 |   return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
24 | 
25 | 
26 | def safe_trig_helper(x, fn, t=100 * jnp.pi):
27 |   """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
28 |   return fn(jnp.where(jnp.abs(x) < t, x, x % t))
29 | 
30 | 
31 | def safe_cos(x):
32 |   """jnp.cos() on a TPU may NaN out for large values."""
33 |   return safe_trig_helper(x, jnp.cos)
34 | 
35 | 
36 | def safe_sin(x):
37 |   """jnp.sin() on a TPU may NaN out for large values."""
38 |   return safe_trig_helper(x, jnp.sin)
39 | 
40 | 
41 | @jax.custom_jvp
42 | def safe_exp(x):
43 |   """jnp.exp() but with finite output and gradients for large inputs."""
44 |   return jnp.exp(jnp.minimum(x, 88.))  # jnp.exp(89) overflows to inf in float32.
45 | 
46 | 
47 | # @jax.custom_jvp
48 | # def safe_sig(x,wb):
49 | #   """jnp.exp() but with finite output and gradients for large inputs."""
50 | #   return 1/((1+jnp.exp(-x) )*wb)
51 | 
52 | 
53 | @safe_exp.defjvp
54 | def safe_exp_jvp(primals, tangents):
55 |   """Override safe_exp()'s gradient so that it's large when inputs are large."""
56 |   x, = primals
57 |   x_dot, = tangents
58 |   exp_x = safe_exp(x)
59 |   exp_x_dot = exp_x * x_dot
60 |   return exp_x, exp_x_dot
61 | 
62 | 
63 | def log_lerp(t, v0, v1):
64 |   """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
65 |   if v0 <= 0 or v1 <= 0:
66 |     raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
67 |   lv0 = jnp.log(v0)
68 |   lv1 = jnp.log(v1)
69 |   return jnp.exp(jnp.clip(t, 0, 1) * (lv1 - lv0) + lv0)
70 | 
71 | 
72 | def learning_rate_decay(step,
73 |                         lr_init,
74 |                         lr_final,
75 |                         max_steps,
76 |                         lr_delay_steps=0,
77 |                         lr_delay_mult=1):
78 |   """Continuous learning rate decay function.
79 | 
80 |   The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
81 |   is log-linearly interpolated elsewhere (equivalent to exponential decay).
82 |   If lr_delay_steps>0 then the learning rate will be scaled by some smooth
83 |   function of lr_delay_mult, such that the initial learning rate is
84 |   lr_init*lr_delay_mult at the beginning of optimization but will be eased back
85 |   to the normal learning rate when steps>lr_delay_steps.
86 | 
87 |   Args:
88 |     step: int, the current optimization step.
89 |     lr_init: float, the initial learning rate.
90 |     lr_final: float, the final learning rate.
91 |     max_steps: int, the number of steps during optimization.
92 |     lr_delay_steps: int, the number of steps to delay the full learning rate.
93 |     lr_delay_mult: float, the multiplier on the rate when delaying it.
94 | 
95 |   Returns:
96 |     lr: the learning rate for the current step 'step'.
97 |   """
98 |   if lr_delay_steps > 0:
99 |     # A kind of reverse cosine decay.
100 |     delay_rate = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin(
101 |         0.5 * jnp.pi * jnp.clip(step / lr_delay_steps, 0, 1))
102 |   else:
103 |     delay_rate = 1.
104 |   return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
105 | 
106 | 
107 | def interp(*args):
108 |   """A gather-based (GPU-friendly) vectorized replacement for jnp.interp()."""
109 |   args_flat = [x.reshape([-1, x.shape[-1]]) for x in args]
110 |   ret = jax.vmap(jnp.interp)(*args_flat).reshape(args[0].shape)
111 |   return ret
112 | 
113 | 
114 | def sorted_interp(x, xp, fp):
115 |   """A TPU-friendly version of interp(), where xp and fp must be sorted."""
116 | 
117 |   # Identify the location in `xp` that corresponds to each `x`.
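  # mask[..., i, j] is True wherever x[..., j] >= xp[..., i]; since `xp` is
  # sorted, each column of `mask` is a run of Trues followed by Falses.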
118 |   # The final `True` index in `mask` is the start of the matching interval.
119 |   mask = x[..., None, :] >= xp[..., :, None]
120 | 
121 |   def find_interval(x):
122 |     # Grab the value where `mask` switches from True to False, and vice versa.
123 |     # This approach takes advantage of the fact that `x` is sorted.
124 |     x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
125 |     x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
126 |     return x0, x1
127 | 
128 |   fp0, fp1 = find_interval(fp)
129 |   xp0, xp1 = find_interval(xp)
130 | 
131 |   offset = jnp.clip(jnp.nan_to_num((x - xp0) / (xp1 - xp0)), 0, 1)
132 |   ret = fp0 + offset * (fp1 - fp0)
133 |   return ret
134 | 
--------------------------------------------------------------------------------
/internal/geopoly.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tools for constructing geodesic polyhedra, which are used as a basis."""
16 | 
17 | import itertools
18 | import numpy as np
19 | 
20 | 
21 | def compute_sq_dist(mat0, mat1=None):
22 |   """Compute the squared Euclidean distance between all pairs of columns."""
23 |   if mat1 is None:
24 |     mat1 = mat0
25 |   # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.
26 |   sq_norm0 = np.sum(mat0**2, 0)
27 |   sq_norm1 = np.sum(mat1**2, 0)
28 |   sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1
29 |   sq_dist = np.maximum(0, sq_dist)  # Negative values must be numerical errors.
30 |   return sq_dist
31 | 
32 | 
33 | def compute_tesselation_weights(v):
34 |   """Tesselate the vertices of a triangle by a factor of `v`."""
35 |   if v < 1:
36 |     raise ValueError(f'v {v} must be >= 1')
37 |   int_weights = []
38 |   for i in range(v + 1):
39 |     for j in range(v + 1 - i):
40 |       int_weights.append((i, j, v - (i + j)))
41 |   int_weights = np.array(int_weights)
42 |   weights = int_weights / v  # Barycentric weights.
43 |   return weights
44 | 
45 | 
46 | def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):
47 |   """Tesselate the vertices of a geodesic polyhedron.
48 | 
49 |   Args:
50 |     base_verts: tensor of floats, the vertex coordinates of the geodesic.
51 |     base_faces: tensor of ints, the indices of the vertices of base_verts that
52 |       constitute each face of the polyhedron.
53 |     v: int, the factor of the tesselation (v==1 is a no-op).
54 |     eps: float, a small value used to determine if two vertices are the same.
55 | 
56 |   Returns:
57 |     verts: a tensor of floats, the coordinates of the tesselated vertices.
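      Each base face is subdivided into v**2 smaller triangles whose vertices
      are re-projected onto the unit sphere; duplicate vertices along shared
      edges (squared distance <= eps) are merged.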
58 |   """
59 |   if not isinstance(v, int):
60 |     raise ValueError(f'v {v} must be an integer')
61 |   tri_weights = compute_tesselation_weights(v)
62 | 
63 |   verts = []
64 |   for base_face in base_faces:
65 |     new_verts = np.matmul(tri_weights, base_verts[base_face, :])
66 |     new_verts /= np.sqrt(np.sum(new_verts**2, 1, keepdims=True))
67 |     verts.append(new_verts)
68 |   verts = np.concatenate(verts, 0)
69 | 
70 |   sq_dist = compute_sq_dist(verts.T)
71 |   assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])
72 |   unique = np.unique(assignment)
73 |   verts = verts[unique, :]
74 | 
75 |   return verts
76 | 
77 | 
78 | def generate_basis(base_shape,
79 |                    angular_tesselation,
80 |                    remove_symmetries=True,
81 |                    eps=1e-4):
82 |   """Generates a 3D basis by tesselating a geometric polyhedron.
83 | 
84 |   Args:
85 |     base_shape: string, the name of the starting polyhedron, must be either
86 |       'icosahedron' or 'octahedron'.
87 |     angular_tesselation: int, the number of times to tesselate the polyhedron,
88 |       must be >= 1 (a value of 1 is a no-op to the polyhedron).
89 |     remove_symmetries: bool, if True then remove the symmetric basis columns,
90 |       which is usually a good idea because otherwise projections onto the basis
91 |       will have redundant negative copies of each other.
92 |     eps: float, a small number used to determine symmetries.
93 | 
94 |   Returns:
95 |     basis: a matrix with shape [3, n].
96 |   """
97 |   if base_shape == 'icosahedron':
98 |     a = (np.sqrt(5) + 1) / 2
99 |     verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1),
100 |                       (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0),
101 |                       (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2)
102 |     faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1),
103 |                       (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3),
104 |                       (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6),
105 |                       (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5),
106 |                       (7, 2, 11)])
107 |     verts = tesselate_geodesic(verts, faces, angular_tesselation)
108 |   elif base_shape == 'octahedron':
109 |     verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0),
110 |                       (1, 0, 0)])
111 |     corners = np.array(list(itertools.product([-1, 1], repeat=3)))
112 |     pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)
113 |     faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)
114 |     verts = tesselate_geodesic(verts, faces, angular_tesselation)
115 |   else:
116 |     raise ValueError(f'base_shape {base_shape} not supported')
117 | 
118 |   if remove_symmetries:
119 |     # Remove elements of `verts` that are reflections of each other.
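    # Tesselating a centrally-symmetric polyhedron yields antipodal vertex
    # pairs (v and -v); each pair spans the same direction up to sign, so only
    # one vertex per pair is kept.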
120 |     match = compute_sq_dist(verts.T, -verts.T) < eps
121 |     verts = verts[np.any(np.triu(match), 1), :]
122 | 
123 |   basis = verts[:, ::-1]
124 |   return basis
125 | 
--------------------------------------------------------------------------------
/internal/pycolmap/tools/write_depthmap_to_ply.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | 
4 | import imageio
5 | import numpy as np
6 | import os
7 | 
8 | from plyfile import PlyData, PlyElement
9 | from pycolmap import SceneManager
10 | from scipy.ndimage.interpolation import zoom
11 | 
12 | 
13 | #-------------------------------------------------------------------------------
14 | 
15 | def main(args):
16 |     suffix = ".photometric.bin" if args.photometric else ".geometric.bin"
17 | 
18 |     image_file = os.path.join(args.dense_folder, "images", args.image_filename)
19 |     depth_file = os.path.join(
20 |         args.dense_folder, args.stereo_folder, "depth_maps",
21 |         args.image_filename + suffix)
22 |     if args.save_normals:
23 |         normals_file = os.path.join(
24 |             args.dense_folder, args.stereo_folder, "normal_maps",
25 |             args.image_filename + suffix)
26 | 
27 |     # load camera intrinsics from the COLMAP reconstruction
28 |     scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse"))
29 |     scene_manager.load_cameras()
30 |     scene_manager.load_images()
31 | 
32 |     image_id, image = scene_manager.get_image_from_name(args.image_filename)
33 |     camera = scene_manager.cameras[image.camera_id]
34 |     rotation_camera_from_world = image.R()
35 |     camera_center = image.C()
36 | 
37 |     # load image, depth map, and normal map
38 |     image = imageio.imread(image_file)
39 | 
40 |     with open(depth_file, "rb") as fid:
41 |         w = int(b"".join(iter(lambda: fid.read(1), b"&")))
42 |         h = int(b"".join(iter(lambda: fid.read(1), b"&")))
43 |         c = int(b"".join(iter(lambda: fid.read(1), b"&")))
44 |         depth_map = np.fromfile(fid, np.float32).reshape(h, w)
45 |     if (h, w) != image.shape[:2]:
46 |         depth_map = zoom(
47 |             depth_map,
48 |             (float(image.shape[0]) / h, float(image.shape[1]) / w),
49 |             order=0)
50 | 
51 |     if args.save_normals:
52 |         with open(normals_file, "rb") as fid:
53 |             w = int(b"".join(iter(lambda: fid.read(1), b"&")))
54 |             h = int(b"".join(iter(lambda: fid.read(1), b"&")))
55 |             c = int(b"".join(iter(lambda: fid.read(1), b"&")))
56 |             normals = np.fromfile(
57 |                 fid, np.float32).reshape(c, h, w).transpose([1, 2, 0])
58 |         if (h, w) != image.shape[:2]:
59 |             normals = zoom(
60 |                 normals,
61 |                 (float(image.shape[0]) / h, float(image.shape[1]) / w, 1.),
62 |                 order=0)
63 | 
64 |     if args.min_depth is not None:
65 |         depth_map[depth_map < args.min_depth] = 0.
66 |     if args.max_depth is not None:
67 |         depth_map[depth_map > args.max_depth] = 0.
68 | 
69 |     # create 3D points
70 |     #depth_map = np.minimum(depth_map, 100.)
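    # Back-project each pixel: camera.get_image_grid() provides per-pixel
    # normalized image-plane coordinates, so scaling (x, y) by depth (with
    # z = depth) gives the corresponding 3D point in camera space.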
71 | points3D = np.dstack(camera.get_image_grid() + [depth_map]) 72 | points3D[:,:,:2] *= depth_map[:,:,np.newaxis] 73 | 74 | # save 75 | points3D = points3D.astype(np.float32).reshape(-1, 3) 76 | if args.save_normals: 77 | normals = normals.astype(np.float32).reshape(-1, 3) 78 | image = image.reshape(-1, 3) 79 | if image.dtype != np.uint8: 80 | if image.max() <= 1: 81 | image = (image * 255.).astype(np.uint8) 82 | else: 83 | image = image.astype(np.uint8) 84 | 85 | if args.world_space: 86 | points3D = points3D.dot(rotation_camera_from_world) + camera_center 87 | if args.save_normals: 88 | normals = normals.dot(rotation_camera_from_world) 89 | 90 | if args.save_normals: 91 | vertices = np.rec.fromarrays( 92 | tuple(points3D.T) + tuple(normals.T) + tuple(image.T), 93 | names="x,y,z,nx,ny,nz,red,green,blue") 94 | else: 95 | vertices = np.rec.fromarrays( 96 | tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue") 97 | vertices = PlyElement.describe(vertices, "vertex") 98 | PlyData([vertices]).write(args.output_filename) 99 | 100 | 101 | #------------------------------------------------------------------------------- 102 | 103 | if __name__ == "__main__": 104 | import argparse 105 | 106 | parser = argparse.ArgumentParser( 107 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 108 | 109 | parser.add_argument("dense_folder", type=str) 110 | parser.add_argument("image_filename", type=str) 111 | parser.add_argument("output_filename", type=str) 112 | 113 | parser.add_argument( 114 | "--photometric", default=False, action="store_true", 115 | help="use photometric depthmap instead of geometric") 116 | 117 | parser.add_argument( 118 | "--world_space", default=False, action="store_true", 119 | help="apply the camera->world extrinsic transformation to the result") 120 | 121 | parser.add_argument( 122 | "--save_normals", default=False, action="store_true", 123 | help="load the estimated normal map and save as part of the PLY") 124 | 125 | parser.add_argument( 126 | "--stereo_folder", type=str, default="stereo", 127 | help="folder in the dense workspace containing depth and normal maps") 128 | 129 | parser.add_argument( 130 | "--min_depth", type=float, default=None, 131 | help="set pixels with depth less than this value to zero depth") 132 | 133 | parser.add_argument( 134 | "--max_depth", type=float, default=None, 135 | help="set pixels with depth greater than this value to zero depth") 136 | 137 | args = parser.parse_args() 138 | 139 | main(args) 140 | -------------------------------------------------------------------------------- /internal/image.py: -------------------------------------------------------------------------------- 1 | # This file was modified by Deborah Levy 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | """Functions for processing images."""
17 | 
18 | import types
19 | from typing import Optional, Union
20 | 
21 | import dm_pix
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | 
26 | _Array = Union[np.ndarray, jnp.ndarray]
27 | 
28 | 
29 | def mse_to_psnr(mse):
30 |   """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
31 |   return -10. / jnp.log(10.) * jnp.log(mse)
32 | 
33 | 
34 | def psnr_to_mse(psnr):
35 |   """Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
36 |   return jnp.exp(-0.1 * jnp.log(10.) * psnr)
37 | 
38 | 
39 | def ssim_to_dssim(ssim):
40 |   """Compute DSSIM given an SSIM."""
41 |   return (1 - ssim) / 2
42 | 
43 | 
44 | def dssim_to_ssim(dssim):
45 |   """Compute SSIM given a DSSIM."""
46 |   return 1 - 2 * dssim
47 | 
48 | 
49 | def linear_to_srgb(linear: _Array,
50 |                    eps: Optional[float] = None,
51 |                    xnp: types.ModuleType = jnp) -> _Array:
52 |   """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
53 |   if eps is None:
54 |     eps = xnp.finfo(xnp.float32).eps
55 |   srgb0 = 323 / 25 * linear
56 |   srgb1 = (211 * xnp.maximum(eps, linear)**(5 / 12) - 11) / 200
57 |   return xnp.where(linear <= 0.0031308, srgb0, srgb1)
58 | 
59 | 
60 | def srgb_to_linear(srgb: _Array,
61 |                    eps: Optional[float] = None,
62 |                    xnp: types.ModuleType = jnp) -> _Array:
63 |   """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
64 |   if eps is None:
65 |     eps = xnp.finfo(xnp.float32).eps
66 |   linear0 = 25 / 323 * srgb
67 |   linear1 = xnp.maximum(eps, ((200 * srgb + 11) / (211)))**(12 / 5)
68 |   return xnp.where(srgb <= 0.04045, linear0, linear1)
69 | 
70 | 
71 | def downsample(img, factor):
72 |   """Area downsample img (factor must evenly divide img height and width)."""
73 |   sh = img.shape
74 |   if not (sh[0] % factor == 0 and sh[1] % factor == 0):
75 |     raise ValueError(f'Downsampling factor {factor} does not '
76 |                      f'evenly divide image shape {sh[:2]}')
77 |   img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:])
78 |   img = img.mean((1, 3))
79 |   return img
80 | 
81 | 
82 | def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
83 |   """Warp `img` to match the colors in `ref`."""
84 |   if img.shape[-1] != ref.shape[-1]:
85 |     raise ValueError(
86 |         f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
87 |     )
88 |   num_channels = img.shape[-1]
89 |   img_mat = img.reshape([-1, num_channels])
90 |   ref_mat = ref.reshape([-1, num_channels])
91 |   is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps))  # z \in [eps, 1-eps].
92 |   mask0 = is_unclipped(img_mat)
93 |   # Because the set of saturated pixels may change after solving for a
94 |   # transformation, we repeatedly solve a system `num_iters` times and update
95 |   # our estimate of which pixels are saturated.
96 |   for _ in range(num_iters):
97 |     # Construct the left hand side of a linear system that contains a quadratic
98 |     # expansion of each pixel of `img`.
99 |     a_mat = []
100 |     for c in range(num_channels):
101 |       a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:])  # Quadratic term.
102 |     a_mat.append(img_mat)  # Linear term.
103 |     a_mat.append(jnp.ones_like(img_mat[:, :1]))  # Bias term.
104 |     a_mat = jnp.concatenate(a_mat, axis=-1)
105 |     warp = []
106 |     for c in range(num_channels):
107 |       # Construct the right hand side of a linear system containing each color
108 |       # of `ref`.
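      # Each output channel c is fit independently against the shared quadratic
      # expansion `a_mat`, i.e. a per-channel linear least-squares problem.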
109 |       b = ref_mat[:, c]
110 |       # Ignore rows of the linear system that were saturated in the input or are
111 |       # saturated in the current corrected color estimate.
112 |       mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
113 |       ma_mat = jnp.where(mask[:, None], a_mat, 0)
114 |       mb = jnp.where(mask, b, 0)
115 |       # Solve the linear system. We use np.linalg.lstsq instead of jnp's
116 |       # because it's significantly more stable in this case, for some reason.
117 |       w = np.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
118 |       assert jnp.all(jnp.isfinite(w))
119 |       warp.append(w)
120 |     warp = jnp.stack(warp, axis=-1)
121 |     # Apply the warp to update img_mat.
122 |     img_mat = jnp.clip(
123 |         jnp.matmul(a_mat, warp, precision=jax.lax.Precision.HIGHEST), 0, 1)
124 |   corrected_img = jnp.reshape(img_mat, img.shape)
125 |   return corrected_img
126 | 
127 | 
128 | class MetricHarness:
129 |   """A helper class for evaluating several error metrics."""
130 | 
131 |   def __init__(self):
132 |     self.ssim_fn = jax.jit(dm_pix.ssim)
133 | 
134 |   def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
135 |     """Evaluate the error between a predicted rgb image and the true image."""
136 |     psnr = float(mse_to_psnr(((rgb_pred - rgb_gt)**2).mean()))
137 |     ssim = float(self.ssim_fn(rgb_pred, rgb_gt))
138 | 
139 |     return {
140 |         name_fn('psnr'): psnr,
141 |         name_fn('ssim'): ssim,
142 |     }
143 | 
--------------------------------------------------------------------------------
/internal/utils.py:
--------------------------------------------------------------------------------
1 | # This file was modified by Deborah Levy
2 | # Copyright 2022 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 16 | """Utility functions.""" 17 | 18 | import enum 19 | import os 20 | from typing import Any, Dict, Optional, Union 21 | 22 | import flax 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | from PIL import ExifTags 27 | from PIL import Image 28 | # from pypfm import PFMLoader 29 | 30 | # pfm_loader = PFMLoader(color=False, compress=False) 31 | _Array = Union[np.ndarray, jnp.ndarray] 32 | 33 | 34 | @flax.struct.dataclass 35 | class Pixels: 36 | """All tensors must have the same num_dims and first n-1 dims must match.""" 37 | pix_x_int: _Array 38 | pix_y_int: _Array 39 | lossmult: _Array 40 | near: _Array 41 | far: _Array 42 | cam_idx: _Array 43 | exposure_idx: Optional[_Array] = None 44 | exposure_values: Optional[_Array] = None 45 | 46 | 47 | @flax.struct.dataclass 48 | class Rays: 49 | """All tensors must have the same num_dims and first n-1 dims must match.""" 50 | origins: _Array 51 | directions: _Array 52 | viewdirs: _Array 53 | radii: _Array 54 | imageplane: _Array 55 | lossmult: _Array 56 | near: _Array 57 | far: _Array 58 | cam_idx: _Array 59 | exposure_idx: Optional[_Array] = None 60 | exposure_values: Optional[_Array] = None 61 | x_coord: Optional[_Array] = None 62 | y_coord: Optional[_Array] = None 63 | 64 | 65 | # Dummy Rays object that can be used to initialize NeRF model. 66 | def dummy_rays(include_exposure_idx: bool = False, 67 | include_exposure_values: bool = False) -> Rays: 68 | data_fn = lambda n: jnp.zeros((1, n)) 69 | exposure_kwargs = {} 70 | if include_exposure_idx: 71 | exposure_kwargs['exposure_idx'] = data_fn(1).astype(jnp.int32) 72 | if include_exposure_values: 73 | exposure_kwargs['exposure_values'] = data_fn(1) 74 | return Rays( 75 | origins=data_fn(3), 76 | directions=data_fn(3), 77 | viewdirs=data_fn(3), 78 | radii=data_fn(1), 79 | imageplane=data_fn(2), 80 | lossmult=data_fn(1), 81 | near=data_fn(1), 82 | far=data_fn(1), 83 | cam_idx=data_fn(1).astype(jnp.int32), 84 | **exposure_kwargs) 85 | 86 | 87 | @flax.struct.dataclass 88 | class Batch: 89 | """Data batch for NeRF training or testing.""" 90 | rays: Union[Pixels, Rays] 91 | rgb: Optional[_Array] = None 92 | rgb_path: Optional[str] = None 93 | disps: Optional[_Array] = None 94 | normals: Optional[_Array] = None 95 | alphas: Optional[_Array] = None 96 | 97 | 98 | class DataSplit(enum.Enum): 99 | """Dataset split.""" 100 | TRAIN = 'train' 101 | TEST = 'test' 102 | 103 | 104 | class BatchingMethod(enum.Enum): 105 | """Draw rays randomly from a single image or all images, in each batch.""" 106 | ALL_IMAGES = 'all_images' 107 | SINGLE_IMAGE = 'single_image' 108 | 109 | 110 | def open_file(pth, mode='r'): 111 | return open(pth, mode=mode) 112 | 113 | 114 | def file_exists(pth): 115 | return os.path.exists(pth) 116 | 117 | 118 | def listdir(pth): 119 | return os.listdir(pth) 120 | 121 | 122 | def isdir(pth): 123 | return os.path.isdir(pth) 124 | 125 | 126 | def makedirs(pth): 127 | if not file_exists(pth): 128 | os.makedirs(pth) 129 | 130 | 131 | def shard(xs): 132 | """Split data into shards for multiple devices along the first dimension.""" 133 | return jax.tree_util.tree_map( 134 | lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs) 135 | 136 | 137 | def unshard(x, padding=0): 138 | """Collect the sharded tensor to the shape before sharding.""" 139 | y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) 140 | if padding > 0: 141 | y = y[:-padding] 142 | return y 143 | 144 | 145 | def load_img(pth: str) -> np.ndarray: 146 | """Load an image and cast to 
float32."""
147 |   with open_file(pth, 'rb') as f:
148 |     image = np.array(Image.open(f), dtype=np.float32)
149 |   return image
150 | 
151 | 
152 | def load_exif(pth: str) -> Dict[str, Any]:
153 |   """Load EXIF data for an image."""
154 |   with open_file(pth, 'rb') as f:
155 |     image_pil = Image.open(f)
156 |     exif_pil = image_pil._getexif()  # pylint: disable=protected-access
157 |     if exif_pil is not None:
158 |       exif = {
159 |           ExifTags.TAGS[k]: v for k, v in exif_pil.items() if k in ExifTags.TAGS
160 |       }
161 |     else:
162 |       exif = {}
163 |   return exif
164 | 
165 | 
166 | def save_img_u8(img, pth):
167 |   """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG."""
168 |   with open_file(pth, 'wb') as f:
169 |     Image.fromarray(
170 |         (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save(
171 |             f, 'PNG')
172 | 
173 | 
174 | def save_img_f32(depthmap, pth):
175 |   """Save an image (probably a depthmap) to disk as a float32 TIFF."""
176 |   with open_file(pth, 'wb') as f:
177 |     # A 2D float32 array is handled by PIL's 'F' (32-bit float) mode.
178 |     Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF')
179 | 
180 | 
181 | def save_npy(npy, pth):
182 |   with open_file(pth, 'wb') as f:
183 |     print(np.max(npy), np.min(npy))
184 |     np.save(f, npy)
185 | 
186 | 
187 | # def save_pfm(pfm, pth):
188 | #   print(pfm.max(), pfm.min())
189 | #   pfm_loader.save_pfm(pth, pfm)
190 | 
--------------------------------------------------------------------------------
/internal/coord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tools for manipulating coordinate spaces and distances along rays."""
15 | 
16 | from internal import math
17 | import jax
18 | import jax.numpy as jnp
19 | 
20 | 
21 | def contract(x):
22 |   """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077)."""
23 |   eps = jnp.finfo(jnp.float32).eps
24 |   # Clamping to eps prevents non-finite gradients when x == 0.
25 |   x_mag_sq = jnp.maximum(eps, jnp.sum(x**2, axis=-1, keepdims=True))
26 |   z = jnp.where(x_mag_sq <= 1, x, ((2 * jnp.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
27 |   return z
28 | 
29 | 
30 | def inv_contract(z):
31 |   """The inverse of contract()."""
32 |   eps = jnp.finfo(jnp.float32).eps
33 |   # Clamping to eps prevents non-finite gradients when z == 0.
34 |   z_mag_sq = jnp.maximum(eps, jnp.sum(z**2, axis=-1, keepdims=True))
35 |   x = jnp.where(z_mag_sq <= 1, z, z / (2 * jnp.sqrt(z_mag_sq) - z_mag_sq))
36 |   return x
37 | 
38 | 
39 | def track_linearize(fn, mean, cov):
40 |   """Apply function `fn` to a set of means and covariances, ala a Kalman filter.
41 | 
42 |   We can analytically transform a Gaussian parameterized by `mean` and `cov`
43 |   with a function `fn` by linearizing `fn` around `mean`, and taking advantage
44 |   of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
45 |   https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).
46 | 
47 |   Args:
48 |     fn: the function applied to the Gaussians parameterized by (mean, cov).
49 |     mean: a tensor of means, where the last axis is the dimension.
50 |     cov: a tensor of covariances, where the last two axes are the dimensions.
51 | 
52 |   Returns:
53 |     fn_mean: the transformed means.
54 |     fn_cov: the transformed covariances.
55 |   """
56 |   if (len(mean.shape) + 1) != len(cov.shape):
57 |     raise ValueError('cov must be a full covariance (one more dim than mean)')
58 |   fn_mean, lin_fn = jax.linearize(fn, mean)
59 |   fn_cov = jax.vmap(lin_fn, -1, -2)(jax.vmap(lin_fn, -1, -2)(cov))
60 |   return fn_mean, fn_cov
61 | 
62 | 
63 | def construct_ray_warps(fn, t_near, t_far):
64 |   """Construct a bijection between metric distances and normalized distances.
65 | 
66 |   See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
67 |   detailed explanation.
68 | 
69 |   Args:
70 |     fn: the warp applied to ray distances (None, 'piecewise', or a function).
71 |     t_near: a tensor of near-plane distances.
72 |     t_far: a tensor of far-plane distances.
73 | 
74 |   Returns:
75 |     t_to_s: a function that maps distances to normalized distances in [0, 1].
76 |     s_to_t: the inverse of t_to_s.
77 |   """
78 |   if fn is None:
79 |     fn_fwd = lambda x: x
80 |     fn_inv = lambda x: x
81 |   elif fn == 'piecewise':
82 |     # Piecewise spacing combining identity and 1/x functions to allow t_near=0.
83 |     fn_fwd = lambda x: jnp.where(x < 1, .5 * x, 1 - .5 / x)
84 |     fn_inv = lambda x: jnp.where(x < .5, 2 * x, .5 / (1 - x))
85 |   else:
86 |     inv_mapping = {
87 |         'reciprocal': jnp.reciprocal,
88 |         'log': jnp.exp,
89 |         'exp': jnp.log,
90 |         'sqrt': jnp.square,
91 |         'square': jnp.sqrt
92 |     }
93 |     fn_fwd = fn
94 |     fn_inv = inv_mapping[fn.__name__]
95 | 
96 |   s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)]
97 |   t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near)
98 |   s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near)
99 |   return t_to_s, s_to_t
100 | 
101 | 
102 | def expected_sin(mean, var):
103 |   """Compute the mean of sin(x), x ~ N(mean, var)."""
104 |   return jnp.exp(-0.5 * var) * math.safe_sin(mean)  # large var -> small value.
105 | 
106 | 
107 | def integrated_pos_enc(mean, var, min_deg, max_deg):
108 |   """Encode `x` with sinusoids scaled by 2^[min_deg, max_deg).
109 | 
110 |   Args:
111 |     mean: tensor, the mean coordinates to be encoded.
112 |     var: tensor, the variance of the coordinates to be encoded.
113 |     min_deg: int, the min degree of the encoding.
114 |     max_deg: int, the max degree of the encoding.
115 | 
116 |   Returns:
117 |     encoded: jnp.ndarray, encoded variables.
118 |   """
119 |   scales = 2**jnp.arange(min_deg, max_deg)
120 |   shape = mean.shape[:-1] + (-1,)
121 |   scaled_mean = jnp.reshape(mean[..., None, :] * scales[:, None], shape)
122 |   scaled_var = jnp.reshape(var[..., None, :] * scales[:, None]**2, shape)
123 | 
124 |   return expected_sin(
125 |       jnp.concatenate([scaled_mean, scaled_mean + 0.5 * jnp.pi], axis=-1),
126 |       jnp.concatenate([scaled_var] * 2, axis=-1))
127 | 
128 | 
129 | def lift_and_diagonalize(mean, cov, basis):
130 |   """Project `mean` and `cov` onto basis and diagonalize the projected cov."""
131 |   fn_mean = math.matmul(mean, basis)
132 |   fn_cov_diag = jnp.sum(basis * math.matmul(cov, basis), axis=-2)
133 |   return fn_mean, fn_cov_diag
134 | 
135 | 
136 | def pos_enc(x, min_deg, max_deg, append_identity=True):
137 |   """The positional encoding used by the original NeRF paper."""
138 |   scales = 2**jnp.arange(min_deg, max_deg)
139 |   shape = x.shape[:-1] + (-1,)
140 |   scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), shape)
141 |   # Note that we're not using safe_sin, unlike IPE.
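  # (Presumably safe here because pos_enc inputs are typically normalized or
  # contracted coordinates, so scaled_x stays far from the magnitudes where
  # TPU sin() becomes inaccurate.)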
142 | four_feat = jnp.sin( 143 | jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1)) 144 | if append_identity: 145 | return jnp.concatenate([x] + [four_feat], axis=-1) 146 | else: 147 | return four_feat 148 | -------------------------------------------------------------------------------- /tests/image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit tests for image.""" 16 | 17 | from absl.testing import absltest 18 | from internal import image 19 | import jax 20 | from jax import random 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | def matmul(a, b): 26 | """jnp.matmul defaults to bfloat16, but this helper function doesn't.""" 27 | return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST) 28 | 29 | 30 | class ImageTest(absltest.TestCase): 31 | 32 | def test_color_correction(self): 33 | """Test that color correction can undo a CCM + quadratic warp + shift.""" 34 | im_shape = (128, 128, 3) 35 | rng = random.PRNGKey(0) 36 | for _ in range(10): 37 | # Construct a random image. 38 | key, rng = random.split(rng) 39 | im0 = random.uniform(key, shape=im_shape, minval=0.1, maxval=0.9) 40 | 41 | # Construct a random linear + quadratic color transformation. 42 | key, rng = random.split(rng) 43 | ccm_scale = random.normal(key) / 10 44 | key, rng = random.split(rng) 45 | shift = random.normal(key) / 10 46 | key, rng = random.split(rng) 47 | sq_mult = random.normal(key) / 10 48 | key, rng = random.split(rng) 49 | ccm = jnp.eye(3) + random.normal(key, shape=(3, 3)) * ccm_scale 50 | 51 | # Apply that random transformation to the image. 52 | im1 = jnp.clip( 53 | (matmul(jnp.reshape(im0, [-1, 3]), ccm)).reshape(im0.shape) + 54 | sq_mult * im0**2 + shift, 0, 1) 55 | 56 | # Check that color correction recovers the randomly transformed image. 57 | im0_cc = image.color_correct(im0, im1) 58 | np.testing.assert_allclose(im0_cc, im1, atol=1E-5, rtol=1E-5) 59 | 60 | def test_psnr_mse_round_trip(self): 61 | """PSNR -> MSE -> PSNR is a no-op.""" 62 | for psnr in [10., 20., 30.]: 63 | np.testing.assert_allclose( 64 | image.mse_to_psnr(image.psnr_to_mse(psnr)), 65 | psnr, 66 | atol=1E-5, 67 | rtol=1E-5) 68 | 69 | def test_ssim_dssim_round_trip(self): 70 | """SSIM -> DSSIM -> SSIM is a no-op.""" 71 | for ssim in [-0.9, 0, 0.9]: 72 | np.testing.assert_allclose( 73 | image.dssim_to_ssim(image.ssim_to_dssim(ssim)), 74 | ssim, 75 | atol=1E-5, 76 | rtol=1E-5) 77 | 78 | def test_srgb_linearize(self): 79 | x = jnp.linspace(-1, 3, 10000) # Nobody should call this <0 but it works. 80 | # Check that the round-trip transformation is a no-op. 81 | np.testing.assert_allclose( 82 | image.linear_to_srgb(image.srgb_to_linear(x)), x, atol=1E-5, rtol=1E-5) 83 | np.testing.assert_allclose( 84 | image.srgb_to_linear(image.linear_to_srgb(x)), x, atol=1E-5, rtol=1E-5) 85 | # Check that gradients are finite. 
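    # (The eps clamp inside these functions is what keeps gradients finite:
    # without it, x**(5/12) has an infinite derivative at 0 and a NaN one for
    # x < 0.)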
86 | self.assertTrue( 87 | jnp.all(jnp.isfinite(jax.vmap(jax.grad(image.linear_to_srgb))(x)))) 88 | self.assertTrue( 89 | jnp.all(jnp.isfinite(jax.vmap(jax.grad(image.srgb_to_linear))(x)))) 90 | 91 | def test_srgb_to_linear_golden(self): 92 | """A lazy golden test for srgb_to_linear.""" 93 | srgb = jnp.linspace(0, 1, 64) 94 | linear = image.srgb_to_linear(srgb) 95 | linear_gt = jnp.array([ 96 | 0.00000000, 0.00122856, 0.00245712, 0.00372513, 0.00526076, 0.00711347, 97 | 0.00929964, 0.01183453, 0.01473243, 0.01800687, 0.02167065, 0.02573599, 98 | 0.03021459, 0.03511761, 0.04045585, 0.04623971, 0.05247922, 0.05918410, 99 | 0.06636375, 0.07402734, 0.08218378, 0.09084171, 0.10000957, 0.10969563, 100 | 0.11990791, 0.13065430, 0.14194246, 0.15377994, 0.16617411, 0.17913227, 101 | 0.19266140, 0.20676863, 0.22146071, 0.23674440, 0.25262633, 0.26911288, 102 | 0.28621066, 0.30392596, 0.32226467, 0.34123330, 0.36083785, 0.38108405, 103 | 0.40197787, 0.42352500, 0.44573134, 0.46860245, 0.49214387, 0.51636110, 104 | 0.54125960, 0.56684470, 0.59312177, 0.62009590, 0.64777250, 0.67615650, 105 | 0.70525320, 0.73506740, 0.76560410, 0.79686830, 0.82886493, 0.86159873, 106 | 0.89507430, 0.92929670, 0.96427040, 1.00000000 107 | ]) 108 | np.testing.assert_allclose(linear, linear_gt, atol=1E-5, rtol=1E-5) 109 | 110 | def test_mse_to_psnr_golden(self): 111 | """A lazy golden test for mse_to_psnr.""" 112 | mse = jnp.exp(jnp.linspace(-10, 0, 64)) 113 | psnr = image.mse_to_psnr(mse) 114 | psnr_gt = jnp.array([ 115 | 43.429447, 42.740090, 42.050735, 41.361378, 40.6720240, 39.982666, 116 | 39.293310, 38.603954, 37.914597, 37.225240, 36.5358850, 35.846527, 117 | 35.157170, 34.467810, 33.778458, 33.089100, 32.3997460, 31.710388, 118 | 31.021034, 30.331675, 29.642320, 28.952961, 28.2636070, 27.574250, 119 | 26.884893, 26.195538, 25.506180, 24.816826, 24.1274700, 23.438112, 120 | 22.748756, 22.059400, 21.370045, 20.680689, 19.9913310, 19.301975, 121 | 18.612620, 17.923262, 17.233906, 16.544550, 15.8551940, 15.165837, 122 | 14.4764805, 13.787125, 13.097769, 12.408413, 11.719056, 11.029700, 123 | 10.3403420, 9.6509850, 8.9616290, 8.2722720, 7.5829163, 6.8935600, 124 | 6.2042036, 5.5148473, 4.825491, 4.136135, 3.4467785, 2.7574227, 125 | 2.0680661, 1.37871, 0.68935364, 0. 
126 |     ])
127 |     np.testing.assert_allclose(psnr, psnr_gt, atol=1E-5, rtol=1E-5)
128 | 
129 | 
130 | if __name__ == '__main__':
131 |   absltest.main()
132 | 
--------------------------------------------------------------------------------
/internal/pycolmap/tools/impute_missing_cameras.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | 
4 | import numpy as np
5 | 
6 | from pycolmap import DualQuaternion, Image, SceneManager
7 | 
8 | 
9 | #-------------------------------------------------------------------------------
10 | 
11 | image_to_idx = lambda im: int(im.name[:im.name.rfind(".")])
12 | 
13 | 
14 | #-------------------------------------------------------------------------------
15 | 
16 | def interpolate_linear(images, camera_id, file_format):
17 |     if len(images) < 2:
18 |         raise ValueError("Need at least two images for linear interpolation!")
19 | 
20 |     prev_image = images[0]
21 |     prev_idx = image_to_idx(prev_image)
22 |     prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t)
23 |     start = prev_idx
24 | 
25 |     new_images = []
26 | 
27 |     for image in images[1:]:
28 |         curr_idx = image_to_idx(image)
29 |         curr_dq = DualQuaternion.FromQT(image.q, image.t)
30 |         T = curr_idx - prev_idx
31 |         Tinv = 1. / T
32 | 
33 |         # like quaternions, dq(x) = -dq(x), so we'll need to pick the one more
34 |         # appropriate for interpolation by taking -dq if the dot product of the
35 |         # two q-vectors is negative
36 |         if prev_dq.q0.dot(curr_dq.q0) < 0:
37 |             curr_dq = -curr_dq
38 | 
39 |         for i in range(1, T):
40 |             t = i * Tinv
41 |             dq = (1. - t) * prev_dq + t * curr_dq  # blend prev (t=0) -> curr (t=1)
42 |             q, trans = dq.ToQT()
43 |             new_images.append(
44 |                 Image(file_format.format(prev_idx + i), camera_id, q, trans))
45 | 
46 |         prev_idx = curr_idx
47 |         prev_dq = curr_dq
48 | 
49 |     return new_images
50 | 
51 | 
52 | #-------------------------------------------------------------------------------
53 | 
54 | def interpolate_hermite(images, camera_id, file_format):
55 |     if len(images) < 4:
56 |         raise ValueError(
57 |             "Need at least four images for Hermite spline interpolation!")
58 | 
59 |     new_images = []
60 | 
61 |     # linear blending for the first frames
62 |     T0 = image_to_idx(images[0])
63 |     dq0 = DualQuaternion.FromQT(images[0].q, images[0].t)
64 |     T1 = image_to_idx(images[1])
65 |     dq1 = DualQuaternion.FromQT(images[1].q, images[1].t)
66 | 
67 |     if dq0.q0.dot(dq1.q0) < 0:
68 |         dq1 = -dq1
69 |     dT = 1. / float(T1 - T0)
70 |     for j in range(1, T1 - T0):
71 |         t = j * dT
72 |         dq = ((1. - t) * dq0 + t * dq1).normalize()
73 |         new_images.append(
74 |             Image(file_format.format(T0 + j), camera_id, *dq.ToQT()))
75 | 
76 |     T2 = image_to_idx(images[2])
77 |     dq2 = DualQuaternion.FromQT(images[2].q, images[2].t)
78 |     if dq1.q0.dot(dq2.q0) < 0:
79 |         dq2 = -dq2
80 | 
81 |     # Hermite spline interpolation of dual quaternions
82 |     # pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf
83 |     for i in range(1, len(images) - 2):
84 |         T3 = image_to_idx(images[i + 2])
85 |         dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t)
86 |         if dq2.q0.dot(dq3.q0) < 0:
87 |             dq3 = -dq3
88 | 
89 |         prev_duration = T1 - T0
90 |         current_duration = T2 - T1
91 |         next_duration = T3 - T2
92 | 
93 |         # approximate the derivatives at dq1 and dq2 using weighted central
94 |         # differences
95 |         dt1 = 1. / float(T2 - T0)
96 |         dt2 = 1. / float(T3 - T1)
97 | 
98 |         m1 = (current_duration * dt1) * (dq2 - dq1) + \
99 |              (prev_duration * dt1) * (dq1 - dq0)
100 |         m2 = (next_duration * dt2) * (dq3 - dq2) + \
101 |              (current_duration * dt2) * (dq2 - dq1)
102 | 
103 |         dT = 1. / float(current_duration)
104 | 
105 |         for j in range(1, current_duration):
106 |             t = j * dT  # 0 to 1
107 |             t2 = t * t  # t squared
108 |             t3 = t2 * t  # t cubed
109 | 
110 |             # coefficients of the Hermite spline (a=>dq and b=>m)
111 |             a1 = 2. * t3 - 3. * t2 + 1.
112 |             b1 = t3 - 2. * t2 + t
113 |             a2 = -2. * t3 + 3. * t2
114 |             b2 = t3 - t2
115 | 
116 |             dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize()
117 | 
118 |             new_images.append(
119 |                 Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
120 | 
121 |         T0, T1, T2 = T1, T2, T3
122 |         dq0, dq1, dq2 = dq1, dq2, dq3
123 | 
124 |     # linear blending for the last frames
125 |     dT = 1. / float(T2 - T1)
126 |     for j in range(1, T2 - T1):
127 |         t = j * dT  # 0 to 1
128 |         dq = ((1. - t) * dq1 + t * dq2).normalize()
129 |         new_images.append(
130 |             Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
131 | 
132 |     return new_images
133 | 
134 | 
135 | #-------------------------------------------------------------------------------
136 | 
137 | def main(args):
138 |     scene_manager = SceneManager(args.input_folder)
139 |     scene_manager.load()
140 | 
141 |     images = sorted(scene_manager.images.values(), key=image_to_idx)
142 | 
143 |     if args.method.lower() == "linear":
144 |         new_images = interpolate_linear(images, args.camera_id, args.format)
145 |     else:
146 |         new_images = interpolate_hermite(images, args.camera_id, args.format)
147 | 
148 |     for image in new_images:
149 |         scene_manager.add_image(image)
150 | 
151 |     scene_manager.save(args.output_folder)
152 | 
153 | 
154 | #-------------------------------------------------------------------------------
155 | 
156 | if __name__ == "__main__":
157 |     import argparse
158 | 
159 |     parser = argparse.ArgumentParser(
160 |         description="Given a reconstruction with ordered images *with integer "
161 |         "filenames* like '000100.png', fill in missing camera positions for "
162 |         "intermediate frames.",
163 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
164 | 
165 |     parser.add_argument("input_folder")
166 |     parser.add_argument("output_folder")
167 | 
168 |     parser.add_argument("--camera_id", type=int, default=1,
169 |                         help="camera id to use for the missing images")
170 | 
171 |     parser.add_argument("--format", type=str, default="{:06d}.png",
172 |                         help="filename format to use for added images")
173 | 
174 |     parser.add_argument(
175 |         "--method", type=str.lower, choices=("linear", "hermite"),
176 |         default="hermite",
177 |         help="Pose imputation method")
178 | 
179 |     args = parser.parse_args()
180 | 
181 |     main(args)
182 | 
--------------------------------------------------------------------------------
/internal/ref_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for reflection directions and directional encodings.""" 16 | 17 | from internal import math 18 | import jax.numpy as jnp 19 | import numpy as np 20 | 21 | 22 | def reflect(viewdirs, normals): 23 | """Reflect view directions about normals. 24 | 25 | The reflection of a vector v about a unit vector n is a vector u such that 26 | dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two 27 | equations is u = 2 dot(n, v) n - v. 28 | 29 | Args: 30 | viewdirs: [..., 3] array of view directions. 31 | normals: [..., 3] array of normal directions (assumed to be unit vectors). 32 | 33 | Returns: 34 | [..., 3] array of reflection directions. 35 | """ 36 | return 2.0 * jnp.sum( 37 | normals * viewdirs, axis=-1, keepdims=True) * normals - viewdirs 38 | 39 | 40 | def l2_normalize(x, eps=jnp.finfo(jnp.float32).eps): 41 | """Normalize x to unit length along last axis.""" 42 | return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps)) 43 | 44 | 45 | def compute_weighted_mae(weights, normals, normals_gt): 46 | """Compute weighted mean angular error, assuming normals are unit length.""" 47 | one_eps = 1 - jnp.finfo(jnp.float32).eps 48 | return (weights * jnp.arccos( 49 | jnp.clip((normals * normals_gt).sum(-1), -one_eps, 50 | one_eps))).sum() / weights.sum() * 180.0 / jnp.pi 51 | 52 | 53 | def generalized_binomial_coeff(a, k): 54 | """Compute generalized binomial coefficients.""" 55 | return np.prod(a - np.arange(k)) / np.math.factorial(k) 56 | 57 | 58 | def assoc_legendre_coeff(l, m, k): 59 | """Compute associated Legendre polynomial coefficients. 60 | 61 | Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the 62 | (l, m)th associated Legendre polynomial, P_l^m(cos(theta)). 63 | 64 | Args: 65 | l: associated Legendre polynomial degree. 66 | m: associated Legendre polynomial order. 67 | k: power of cos(theta). 68 | 69 | Returns: 70 | A float, the coefficient of the term corresponding to the inputs. 71 | """ 72 | return ((-1)**m * 2**l * np.math.factorial(l) / np.math.factorial(k) / 73 | np.math.factorial(l - k - m) * 74 | generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l)) 75 | 76 | 77 | def sph_harm_coeff(l, m, k): 78 | """Compute spherical harmonic coefficients.""" 79 | return (np.sqrt( 80 | (2.0 * l + 1.0) * np.math.factorial(l - m) / 81 | (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k)) 82 | 83 | 84 | def get_ml_array(deg_view): 85 | """Create a list with all pairs of (l, m) values to use in the encoding.""" 86 | ml_list = [] 87 | for i in range(deg_view): 88 | l = 2**i 89 | # Only use nonnegative m values, later splitting real and imaginary parts. 90 | for m in range(l + 1): 91 | ml_list.append((m, l)) 92 | 93 | # Convert list into a numpy array. 94 | ml_array = np.array(ml_list).T 95 | return ml_array 96 | 97 | 98 | def generate_ide_fn(deg_view): 99 | """Generate integrated directional encoding (IDE) function. 100 | 101 | This function returns a function that computes the integrated directional 102 | encoding from Equations 6-8 of arxiv.org/abs/2112.03907. 103 | 104 | Args: 105 | deg_view: number of spherical harmonics degrees to use. 106 | 107 | Returns: 108 | A function for evaluating integrated directional encoding. 109 | 110 | Raises: 111 | ValueError: if deg_view is larger than 5. 
112 | """ 113 | if deg_view > 5: 114 | raise ValueError('Only deg_view of at most 5 is numerically stable.') 115 | 116 | ml_array = get_ml_array(deg_view) 117 | l_max = 2**(deg_view - 1) 118 | 119 | # Create a matrix corresponding to ml_array holding all coefficients, which, 120 | # when multiplied (from the right) by the z coordinate Vandermonde matrix, 121 | # results in the z component of the encoding. 122 | mat = np.zeros((l_max + 1, ml_array.shape[1])) 123 | for i, (m, l) in enumerate(ml_array.T): 124 | for k in range(l - m + 1): 125 | mat[k, i] = sph_harm_coeff(l, m, k) 126 | 127 | def integrated_dir_enc_fn(xyz, kappa_inv): 128 | """Function returning integrated directional encoding (IDE). 129 | 130 | Args: 131 | xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at. 132 | kappa_inv: [..., 1] reciprocal of the concentration parameter of the von 133 | Mises-Fisher distribution. 134 | 135 | Returns: 136 | An array with the resulting IDE. 137 | """ 138 | x = xyz[..., 0:1] 139 | y = xyz[..., 1:2] 140 | z = xyz[..., 2:3] 141 | 142 | # Compute z Vandermonde matrix. 143 | vmz = jnp.concatenate([z**i for i in range(mat.shape[0])], axis=-1) 144 | 145 | # Compute x+iy Vandermonde matrix. 146 | vmxy = jnp.concatenate([(x + 1j * y)**m for m in ml_array[0, :]], axis=-1) 147 | 148 | # Get spherical harmonics. 149 | sph_harms = vmxy * math.matmul(vmz, mat) 150 | 151 | # Apply attenuation function using the von Mises-Fisher distribution 152 | # concentration parameter, kappa. 153 | sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1) 154 | ide = sph_harms * jnp.exp(-sigma * kappa_inv) 155 | 156 | # Split into real and imaginary parts and return 157 | return jnp.concatenate([jnp.real(ide), jnp.imag(ide)], axis=-1) 158 | 159 | return integrated_dir_enc_fn 160 | 161 | 162 | def generate_dir_enc_fn(deg_view): 163 | """Generate directional encoding (DE) function. 164 | 165 | Args: 166 | deg_view: number of spherical harmonics degrees to use. 167 | 168 | Returns: 169 | A function for evaluating directional encoding. 170 | """ 171 | integrated_dir_enc_fn = generate_ide_fn(deg_view) 172 | 173 | def dir_enc_fn(xyz): 174 | """Function returning directional encoding (DE).""" 175 | return integrated_dir_enc_fn(xyz, jnp.zeros_like(xyz[..., :1])) 176 | 177 | return dir_enc_fn 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeaThru-NeRF: Neural Radiance Fields In Scattering Media, CVPR 2023 2 | 3 | #### [project page](https://sea-thru-nerf.github.io/) | [paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Levy_SeaThru-NeRF_Neural_Radiance_Fields_in_Scattering_Media_CVPR_2023_paper.pdf) | [Nerfstudio](https://docs.nerf.studio/nerfology/methods/seathru_nerf.html) 4 | 5 | > SeaThru-NeRF: Neural Radiance Fields In Scattering Media 6 | > [Deborah Levy](mailto:dlrun14@gmail.com) | Amit Peleg | [Naama Pearl](https://naamapearl.github.io/) | Dan Rosenbaum | [Derya Akkaynak](https://www.deryaakkaynak.com/) | [Tali Treibitz](https://www.viseaon.haifa.ac.il/) | [Simon Korman](https://www.cs.haifa.ac.il/~skorman/) 7 | > CVPR 2023 8 | 9 | 10 | Our implementation is based on the paper "Mip-NeRF 360: Unbounded Anti-Aliased Neural Radiance Fields" (CVPR 2022) and their [github repository](https://github.com/google-research/multinerf). 11 | 12 | 13 | This implementation is written in [JAX](https://github.com/google/jax). 
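Before training, it's worth sanity-checking that JAX actually sees your accelerator (the CPU fallback is silent). For example:

```
python -c "import jax; print(jax.default_backend()); print(jax.devices())"
```

This should report `gpu` (or `tpu`) rather than `cpu`.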
14 | 
15 | ## Setup
16 | 
17 | ```
18 | # Clone the repo.
19 | git clone https://github.com/deborahLevy130/seathru_NeRF.git
20 | cd seathru_NeRF
21 | mkdir data
22 | 
23 | # Make a conda environment.
24 | conda create --name seathruNeRF python=3.9
25 | conda activate seathruNeRF
26 | 
27 | # Prepare pip.
28 | conda install pip
29 | pip install --upgrade pip
30 | 
31 | # Install requirements.
32 | pip install -r requirements.txt
33 | 
34 | # Manually install rmbrualla's `pycolmap` (don't use pip's! It's different).
35 | git clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap
36 | 
37 | ```
38 | You'll probably also need to update your JAX installation to support GPUs or TPUs.
39 | 
40 | ## Running
41 | 
42 | Example scripts for training, evaluating, and rendering can be found in
43 | `scripts/`. You'll need to change the paths to point to wherever the datasets
44 | are located. [Gin](https://github.com/google/gin-config) configuration files
45 | for our model and some ablations can be found in `configs/`.
46 | 
47 | 
48 | ### OOM errors
49 | 
50 | You may need to reduce the batch size (`Config.batch_size`) to avoid out-of-memory
51 | errors. If you do this, but want to preserve quality, be sure to increase the number
52 | of training iterations and decrease the learning rate by whatever scale factor you
53 | decrease the batch size by.
54 | 
55 | ## Using your own data
56 | 
57 | Summary: first, calculate poses. Second, train SeaThru-NeRF. Third, evaluate the trained model on existing images or render a result video of novel views.
58 | 
59 | 1. Calculating poses (using [COLMAP](https://colmap.github.io/install.html)):
60 | 
61 | 2. Training SeaThru-NeRF:
62 | ```
63 | ./scripts/train_llff_uw.sh
64 | ```
65 | Set `SCENE` to the image set you wish to use.
66 | 
67 | 3. Evaluating SeaThru-NeRF on existing images:
68 | 
69 | ```
70 | ./scripts/eval_llff_uw.sh
71 | ```
72 | Set `SCENE` and `EXPERIMENT_NAME` to the corresponding experiment.
73 | 
74 | 4. Rendering SeaThru-NeRF novel views:
75 | ```
76 | ./scripts/render_llff_uw.sh
77 | ```
78 | Set `SCENE` and `EXPERIMENT_NAME` to the corresponding experiment.
79 | 
80 | Your output video should now exist in the directory `ckpt/uw/${SCENE}_${EXPERIMENT_NAME}/render/`.
81 | There you will find the underwater renderings, the restored-image renderings (J), and the depth maps.
82 | ## Dataset (photo credit: Matan Yuval)
83 | 
84 | [Here](https://drive.google.com/uc?export=download&id=1RzojBFvBWjUUhuJb95xJPSNP3nJwZWaT) you will find the underwater scenes from the paper.
85 | Extract the files into the data folder and train SeaThru-NeRF with those scenes:
86 | 
87 | ```
88 | ./scripts/train_llff_uw.sh
89 | ```
90 | 
91 | Set `SCENE` to the name of the scene you wish to work with.
92 | 
93 | For more dataset formats, you can refer to [multinerf](https://github.com/google-research/multinerf).
94 | 
95 | For now, our NeRF only supports forward-facing scenes.
96 | 
97 | ### Running SeaThru-NeRF on your own data
98 | 
99 | In order to run SeaThru-NeRF on your own captured images of a scene, you must first run [COLMAP](https://colmap.github.io/install.html) to calculate camera poses. After you run COLMAP, you can run [this](https://github.com/Fyusion/LLFF/blob/master/imgs2poses.py) script from LLFF to get the `poses_bounds.npy` file.
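For example (illustrative; `my_dataset_dir` is a placeholder for your scene folder):

```
git clone https://github.com/Fyusion/LLFF.git
python LLFF/imgs2poses.py my_dataset_dir
```
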
100 | After you run COLMAP, all you need to do to load your scene in SeaThru-NeRF is ensure it has the following format:
101 | ```
102 | my_dataset_dir/images_wb/       <--- all input images
103 | my_dataset_dir/sparse/0/        <--- COLMAP sparse reconstruction files (cameras, images, points)
104 | my_dataset_dir/poses_bounds.npy
105 | ```
106 | ### How to implement SeaThru-NeRF in your own NeRF
107 | 
108 | To incorporate our NeRF into an existing NeRF framework, follow these steps:
109 | 
110 | 1. Incorporate the medium's module into the MLP, following the architecture described in Section 4.5 of the paper ("Implementation and Optimization"). You can also refer to the code available [here](https://github.com/deborahLevy130/seathru_NeRF/blob/master/internal/models.py#L866).
111 | Update (17.08.23): the water network depth is 1.
112 | 
113 | 2. Modify the rendering equations as outlined in the paper.
114 | 
115 | 3. Integrate the accuracy loss described in the paper for the object's transmission. You can refer to our implementation available [here](https://github.com/deborahLevy130/seathru_NeRF/blob/master/internal/train_utils.py#L153). If you have an alternative loss function that encourages the weights of your rendering equations to be unimodal (or close to a Dirac delta function), you may use it instead of the accuracy loss. Simply apply it to the weights of the objects.
116 | 
117 | 
118 | 
119 | 
120 | 
121 | ## Citation
122 | If you use this software package, please cite whichever constituent paper(s)
123 | you build upon, or feel free to cite this entire codebase as:
124 | 
125 | ```
126 | @inproceedings{levy2023seathru,
127 |   title={SeaThru-NeRF: Neural Radiance Fields in Scattering Media},
128 |   author={Levy, Deborah and Peleg, Amit and Pearl, Naama and Rosenbaum, Dan and Akkaynak, Derya and Korman, Simon and Treibitz, Tali},
129 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
130 |   pages={56--65},
131 |   year={2023}
132 | }
133 | 
134 | @misc{multinerf2022,
135 |   title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}},
136 |   author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron},
137 |   year={2022},
138 |   url={https://github.com/google-research/multinerf},
139 | }
140 | ```
141 | 
--------------------------------------------------------------------------------
/tests/math_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | """Unit tests for math.""" 16 | 17 | import functools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from internal import math 22 | import jax 23 | from jax import random 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | def safe_trig_harness(fn, max_exp): 29 | x = 10**np.linspace(-30, max_exp, 10000) 30 | x = np.concatenate([-x[::-1], np.array([0]), x]) 31 | y_true = getattr(np, fn)(x) 32 | y = getattr(math, 'safe_' + fn)(x) 33 | return y_true, y 34 | 35 | 36 | class MathTest(parameterized.TestCase): 37 | 38 | def test_sin(self): 39 | """In [-1e10, 1e10] safe_sin and safe_cos are accurate.""" 40 | for fn in ['sin', 'cos']: 41 | y_true, y = safe_trig_harness(fn, 10) 42 | self.assertLess(jnp.max(jnp.abs(y - y_true)), 1e-4) 43 | self.assertFalse(jnp.any(jnp.isnan(y))) 44 | # Beyond that range it's less accurate but we just don't want it to be NaN. 45 | for fn in ['sin', 'cos']: 46 | y_true, y = safe_trig_harness(fn, 60) 47 | self.assertFalse(jnp.any(jnp.isnan(y))) 48 | 49 | def test_safe_exp_correct(self): 50 | """math.safe_exp() should match np.exp() for not-huge values.""" 51 | x = jnp.linspace(-80, 80, 10001) 52 | y = math.safe_exp(x) 53 | g = jax.vmap(jax.grad(math.safe_exp))(x) 54 | yg_true = jnp.exp(x) 55 | np.testing.assert_allclose(y, yg_true) 56 | np.testing.assert_allclose(g, yg_true) 57 | 58 | def test_safe_exp_finite(self): 59 | """math.safe_exp() behaves reasonably for huge values.""" 60 | x = jnp.linspace(-100000, 100000, 10001) 61 | y = math.safe_exp(x) 62 | g = jax.vmap(jax.grad(math.safe_exp))(x) 63 | # `y` and `g` should both always be finite. 64 | self.assertTrue(jnp.all(jnp.isfinite(y))) 65 | self.assertTrue(jnp.all(jnp.isfinite(g))) 66 | # The derivative of exp() should be exp(). 67 | np.testing.assert_allclose(y, g) 68 | # safe_exp()'s output and gradient should be monotonic. 69 | self.assertTrue(jnp.all(y[1:] >= y[:-1])) 70 | self.assertTrue(jnp.all(g[1:] >= g[:-1])) 71 | 72 | def test_learning_rate_decay(self): 73 | rng = random.PRNGKey(0) 74 | for _ in range(10): 75 | key, rng = random.split(rng) 76 | lr_init = jnp.exp(random.normal(key) - 3) 77 | key, rng = random.split(rng) 78 | lr_final = lr_init * jnp.exp(random.normal(key) - 5) 79 | key, rng = random.split(rng) 80 | max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key)))) 81 | 82 | lr_fn = functools.partial( 83 | math.learning_rate_decay, 84 | lr_init=lr_init, 85 | lr_final=lr_final, 86 | max_steps=max_steps) 87 | 88 | # Test that the rate at the beginning is the initial rate. 89 | np.testing.assert_allclose(lr_fn(0), lr_init, atol=1E-5, rtol=1E-5) 90 | 91 | # Test that the rate at the end is the final rate. 92 | np.testing.assert_allclose( 93 | lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5) 94 | 95 | # Test that the rate at the middle is the geometric mean of the two rates. 
96 | np.testing.assert_allclose( 97 | lr_fn(max_steps / 2), 98 | jnp.sqrt(lr_init * lr_final), 99 | atol=1E-5, 100 | rtol=1E-5) 101 | 102 | # Test that the rate past the end is the final rate 103 | np.testing.assert_allclose( 104 | lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5) 105 | 106 | def test_delayed_learning_rate_decay(self): 107 | rng = random.PRNGKey(0) 108 | for _ in range(10): 109 | key, rng = random.split(rng) 110 | lr_init = jnp.exp(random.normal(key) - 3) 111 | key, rng = random.split(rng) 112 | lr_final = lr_init * jnp.exp(random.normal(key) - 5) 113 | key, rng = random.split(rng) 114 | max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key)))) 115 | key, rng = random.split(rng) 116 | lr_delay_steps = int( 117 | random.uniform(key, minval=0.1, maxval=0.4) * max_steps) 118 | key, rng = random.split(rng) 119 | lr_delay_mult = jnp.exp(random.normal(key) - 3) 120 | 121 | lr_fn = functools.partial( 122 | math.learning_rate_decay, 123 | lr_init=lr_init, 124 | lr_final=lr_final, 125 | max_steps=max_steps, 126 | lr_delay_steps=lr_delay_steps, 127 | lr_delay_mult=lr_delay_mult) 128 | 129 | # Test that the rate at the beginning is the delayed initial rate. 130 | np.testing.assert_allclose( 131 | lr_fn(0), lr_delay_mult * lr_init, atol=1E-5, rtol=1E-5) 132 | 133 | # Test that the rate at the end is the final rate. 134 | np.testing.assert_allclose( 135 | lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5) 136 | 137 | # Test that the rate at after the delay is over is the usual rate. 138 | np.testing.assert_allclose( 139 | lr_fn(lr_delay_steps), 140 | math.learning_rate_decay(lr_delay_steps, lr_init, lr_final, 141 | max_steps), 142 | atol=1E-5, 143 | rtol=1E-5) 144 | 145 | # Test that the rate at the middle is the geometric mean of the two rates. 146 | np.testing.assert_allclose( 147 | lr_fn(max_steps / 2), 148 | jnp.sqrt(lr_init * lr_final), 149 | atol=1E-5, 150 | rtol=1E-5) 151 | 152 | # Test that the rate past the end is the final rate 153 | np.testing.assert_allclose( 154 | lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5) 155 | 156 | @parameterized.named_parameters(('', False), ('sort', True)) 157 | def test_interp(self, sort): 158 | n, d0, d1 = 100, 10, 20 159 | rng = random.PRNGKey(0) 160 | 161 | key, rng = random.split(rng) 162 | x = random.normal(key, [n, d0]) 163 | 164 | key, rng = random.split(rng) 165 | xp = random.normal(key, [n, d1]) 166 | 167 | key, rng = random.split(rng) 168 | fp = random.normal(key, [n, d1]) 169 | 170 | if sort: 171 | xp = jnp.sort(xp, axis=-1) 172 | fp = jnp.sort(fp, axis=-1) 173 | z = math.sorted_interp(x, xp, fp) 174 | else: 175 | z = math.interp(x, xp, fp) 176 | 177 | z_true = jnp.stack([jnp.interp(x[i], xp[i], fp[i]) for i in range(n)]) 178 | np.testing.assert_allclose(z, z_true, atol=1e-5, rtol=1e-5) 179 | 180 | 181 | if __name__ == '__main__': 182 | absltest.main() 183 | -------------------------------------------------------------------------------- /tests/geopoly_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for geopoly."""
16 | import itertools
17 |
18 | from absl.testing import absltest
19 | from internal import geopoly
20 | import jax
21 | from jax import random
22 | import numpy as np
23 |
24 |
25 | def is_same_basis(x, y, tol=1e-10):
26 |   """Check if `x` and `y` describe the same linear basis."""
27 |   match = np.minimum(
28 |       geopoly.compute_sq_dist(x, y), geopoly.compute_sq_dist(x, -y)) <= tol
29 |   return (np.all(np.array(x.shape) == np.array(y.shape)) and
30 |           np.all(np.sum(match, axis=0) == 1) and
31 |           np.all(np.sum(match, axis=1) == 1))
32 |
33 |
34 | class GeopolyTest(absltest.TestCase):
35 |
36 |   def test_compute_sq_dist_reference(self):
37 |     """Test against a simple reimplementation of compute_sq_dist."""
38 |     num_points = 100
39 |     num_dims = 10
40 |     rng = random.PRNGKey(0)
41 |     key, rng = random.split(rng)
42 |     mat0 = jax.random.normal(key, [num_dims, num_points])
43 |     key, rng = random.split(rng)
44 |     mat1 = jax.random.normal(key, [num_dims, num_points])
45 |
46 |     sq_dist = geopoly.compute_sq_dist(mat0, mat1)
47 |
48 |     sq_dist_ref = np.zeros([num_points, num_points])
49 |     for i in range(num_points):
50 |       for j in range(num_points):
51 |         sq_dist_ref[i, j] = np.sum((mat0[:, i] - mat1[:, j])**2)
52 |
53 |     np.testing.assert_allclose(sq_dist, sq_dist_ref, atol=1e-4, rtol=1e-4)
54 |
55 |   def test_compute_sq_dist_single_input(self):
56 |     """Test that compute_sq_dist with a single input works correctly."""
57 |     rng = random.PRNGKey(0)
58 |     num_points = 100
59 |     num_dims = 10
60 |     key, rng = random.split(rng)
61 |     mat0 = jax.random.normal(key, [num_dims, num_points])
62 |
63 |     sq_dist = geopoly.compute_sq_dist(mat0)
64 |     sq_dist_ref = geopoly.compute_sq_dist(mat0, mat0)
65 |     np.testing.assert_allclose(sq_dist, sq_dist_ref)
66 |
67 |   def test_compute_tesselation_weights_reference(self):
68 |     """Test compute_tesselation_weights against a reference implementation."""
69 |     for v in range(1, 10):
70 |       w = geopoly.compute_tesselation_weights(v)
71 |       perm = np.array(list(itertools.product(range(v + 1), repeat=3)))
72 |       w_ref = perm[np.sum(perm, axis=-1) == v, :] / v
73 |       # Check that all rows of `w` are close to some row in `w_ref`.
74 | self.assertTrue(is_same_basis(w.T, w_ref.T)) 75 | 76 | def test_generate_basis_golden(self): 77 | """A mediocre golden test against two arbitrary basis choices.""" 78 | basis = geopoly.generate_basis('icosahedron', 2) 79 | basis_golden = np.array([[0.85065081, 0.00000000, 0.52573111], 80 | [0.80901699, 0.50000000, 0.30901699], 81 | [0.52573111, 0.85065081, 0.00000000], 82 | [1.00000000, 0.00000000, 0.00000000], 83 | [0.80901699, 0.50000000, -0.30901699], 84 | [0.85065081, 0.00000000, -0.52573111], 85 | [0.30901699, 0.80901699, -0.50000000], 86 | [0.00000000, 0.52573111, -0.85065081], 87 | [0.50000000, 0.30901699, -0.80901699], 88 | [0.00000000, 1.00000000, 0.00000000], 89 | [-0.52573111, 0.85065081, 0.00000000], 90 | [-0.30901699, 0.80901699, -0.50000000], 91 | [0.00000000, 0.52573111, 0.85065081], 92 | [-0.30901699, 0.80901699, 0.50000000], 93 | [0.30901699, 0.80901699, 0.50000000], 94 | [0.50000000, 0.30901699, 0.80901699], 95 | [0.50000000, -0.30901699, 0.80901699], 96 | [0.00000000, 0.00000000, 1.00000000], 97 | [-0.50000000, 0.30901699, 0.80901699], 98 | [-0.80901699, 0.50000000, 0.30901699], 99 | [-0.80901699, 0.50000000, -0.30901699]]) 100 | self.assertTrue(is_same_basis(basis.T, basis_golden.T)) 101 | 102 | basis = geopoly.generate_basis('octahedron', 4) 103 | basis_golden = np.array([[0.00000000, 0.00000000, -1.00000000], 104 | [0.00000000, -0.31622777, -0.94868330], 105 | [0.00000000, -0.70710678, -0.70710678], 106 | [0.00000000, -0.94868330, -0.31622777], 107 | [0.00000000, -1.00000000, 0.00000000], 108 | [-0.31622777, 0.00000000, -0.94868330], 109 | [-0.40824829, -0.40824829, -0.81649658], 110 | [-0.40824829, -0.81649658, -0.40824829], 111 | [-0.31622777, -0.94868330, 0.00000000], 112 | [-0.70710678, 0.00000000, -0.70710678], 113 | [-0.81649658, -0.40824829, -0.40824829], 114 | [-0.70710678, -0.70710678, 0.00000000], 115 | [-0.94868330, 0.00000000, -0.31622777], 116 | [-0.94868330, -0.31622777, 0.00000000], 117 | [-1.00000000, 0.00000000, 0.00000000], 118 | [0.00000000, -0.31622777, 0.94868330], 119 | [0.00000000, -0.70710678, 0.70710678], 120 | [0.00000000, -0.94868330, 0.31622777], 121 | [0.40824829, -0.40824829, 0.81649658], 122 | [0.40824829, -0.81649658, 0.40824829], 123 | [0.31622777, -0.94868330, 0.00000000], 124 | [0.81649658, -0.40824829, 0.40824829], 125 | [0.70710678, -0.70710678, 0.00000000], 126 | [0.94868330, -0.31622777, 0.00000000], 127 | [0.31622777, 0.00000000, -0.94868330], 128 | [0.40824829, 0.40824829, -0.81649658], 129 | [0.40824829, 0.81649658, -0.40824829], 130 | [0.70710678, 0.00000000, -0.70710678], 131 | [0.81649658, 0.40824829, -0.40824829], 132 | [0.94868330, 0.00000000, -0.31622777], 133 | [0.40824829, -0.40824829, -0.81649658], 134 | [0.40824829, -0.81649658, -0.40824829], 135 | [0.81649658, -0.40824829, -0.40824829]]) 136 | self.assertTrue(is_same_basis(basis.T, basis_golden.T)) 137 | 138 | 139 | if __name__ == '__main__': 140 | absltest.main() 141 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # This file was modified by Deborah Levy 2 | # 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
 6 | # You may obtain a copy of the License at
 7 | #
 8 | #     https://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Render script."""
17 |
18 | import concurrent.futures
19 | import functools
20 | import glob
21 | import os
22 | import time
23 |
24 | from absl import app
25 | from flax.training import checkpoints
26 | import gin
27 | from internal import configs
28 | from internal import datasets
29 | from internal import models
30 | from internal import train_utils
31 | from internal import utils
32 | import jax
33 | from jax import random
34 | from matplotlib import cm
35 | import mediapy as media
36 | import numpy as np
37 | import matplotlib.image
38 |
39 | configs.define_common_flags()
40 | jax.config.parse_flags_with_absl()
41 |
42 |
43 | def create_videos(config, base_dir, out_dir, out_name, num_frames):
44 |   """Creates videos out of the images saved to disk."""
45 |   names = [n for n in config.checkpoint_dir.split('/') if n]
46 |   # Last two parts of checkpoint path are experiment name and scene name.
47 |   exp_name, scene_name = names[-2:]
48 |   video_prefix = f'{scene_name}_{exp_name}_{out_name}'
49 |
50 |   zpad = max(3, len(str(num_frames - 1)))
51 |   idx_to_str = lambda idx: str(idx).zfill(zpad)
52 |
53 |   utils.makedirs(base_dir)
54 |
55 |   # Load one example frame to get image shape and depth range.
56 |   depth_file = os.path.join(out_dir, f'distance_median_{idx_to_str(0)}.tiff')
57 |   depth_frame = utils.load_img(depth_file)
58 |   shape = depth_frame.shape
59 |   p = config.render_dist_percentile
60 |   distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])
61 |   lo, hi = [config.render_dist_curve_fn(x) for x in distance_limits]
62 |   print(f'Video shape is {shape[:2]}')
63 |
64 |   video_kwargs = {
65 |       'shape': shape[:2],
66 |       'codec': 'h264',
67 |       'fps': config.render_video_fps,
68 |       'crf': config.render_video_crf,
69 |   }
70 |
71 |   for k in ['color', 'J']:
72 |     video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
73 |     input_format = 'gray' if k == 'acc' else 'rgb'
74 |     file_ext = 'png' if k in ['color', 'J', 'distance_median'] else 'tiff'
75 |     idx = 0
76 |     file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}')
77 |     if not utils.file_exists(file0):
78 |       print(f'Images missing for tag {k}')
79 |       continue
80 |     print(f'Making video {video_file}...')
81 |     with media.VideoWriter(
82 |         video_file, **video_kwargs, input_format=input_format) as writer:
83 |       for idx in range(num_frames):
84 |         img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
85 |         if not utils.file_exists(img_file):
86 |           raise ValueError(f'Image file {img_file} does not exist.')
87 |         img = utils.load_img(img_file)
88 |         if k in ['color', 'J']:
89 |           img = img / 255.
90 |         elif k.startswith('distance'):
91 |           img = config.render_dist_curve_fn(img)
92 |           img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)
93 |           img = cm.get_cmap('turbo')(img)[..., :3]
94 |
95 |         frame = (np.clip(np.nan_to_num(img), 0., 1.)
* 255.).astype(np.uint8) 96 | writer.add_image(frame) 97 | idx += 1 98 | 99 | 100 | def main(unused_argv): 101 | config = configs.load_config(save_config=False) 102 | 103 | dataset = datasets.load_dataset('test', config.data_dir, config) 104 | 105 | key = random.PRNGKey(20200823) 106 | _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key) 107 | 108 | if config.rawnerf_mode: 109 | postprocess_fn = dataset.metadata['postprocess_fn'] 110 | else: 111 | postprocess_fn = lambda z: z 112 | 113 | state = checkpoints.restore_checkpoint(config.checkpoint_dir, state) 114 | step = int(state.step) 115 | print(f'Rendering checkpoint at step {step}.') 116 | 117 | out_name = 'path_renders' if config.render_path else 'test_preds' 118 | out_name = f'{out_name}_step_{step}' 119 | base_dir = config.render_dir 120 | if base_dir is None: 121 | base_dir = os.path.join(config.checkpoint_dir, 'render') 122 | out_dir = os.path.join(base_dir, out_name) 123 | if not utils.isdir(out_dir): 124 | utils.makedirs(out_dir) 125 | 126 | path_fn = lambda x: os.path.join(out_dir, x) 127 | 128 | # Ensure sufficient zero-padding of image indices in output filenames. 129 | zpad = max(3, len(str(dataset.size - 1))) 130 | idx_to_str = lambda idx: str(idx).zfill(zpad) 131 | 132 | if config.render_save_async: 133 | async_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) 134 | async_futures = [] 135 | 136 | def save_fn(fn, *args, **kwargs): 137 | async_futures.append(async_executor.submit(fn, *args, **kwargs)) 138 | else: 139 | def save_fn(fn, *args, **kwargs): 140 | fn(*args, **kwargs) 141 | 142 | for idx in range(dataset.size): 143 | if idx % config.render_num_jobs != config.render_job_id: 144 | continue 145 | # If current image and next image both already exist, skip ahead. 146 | idx_str = idx_to_str(idx) 147 | curr_file = path_fn(f'color_{idx_str}.png') 148 | next_idx_str = idx_to_str(idx + config.render_num_jobs) 149 | next_file = path_fn(f'color_{next_idx_str}.png') 150 | if utils.file_exists(curr_file) and utils.file_exists(next_file): 151 | print(f'Image {idx}/{dataset.size} already exists, skipping') 152 | continue 153 | print(f'Evaluating image {idx + 1}/{dataset.size}') 154 | eval_start_time = time.time() 155 | rays = dataset.generate_ray_batch(idx).rays 156 | train_frac = 1. 157 | rendering = models.render_image( 158 | functools.partial(render_eval_pfn, state.params, train_frac), 159 | rays, None, config) 160 | print(f'Rendered in {(time.time() - eval_start_time):0.3f}s') 161 | 162 | if jax.host_id() != 0: # Only record via host 0. 163 | continue 164 | 165 | rendering['rgb'] = postprocess_fn(rendering['rgb']) 166 | 167 | save_fn( 168 | utils.save_img_u8, rendering['rgb'], path_fn(f'color_{idx_str}.png')) 169 | if 'J' in rendering: 170 | rendering['J'] = postprocess_fn(rendering['J']) 171 | save_fn(utils.save_img_u8, rendering['J'], path_fn(f'J_{idx_str}.png')) 172 | rendering['bs'] = postprocess_fn(rendering['bs']) 173 | save_fn(utils.save_img_u8, rendering['bs'], path_fn(f'bs_{idx_str}.png')) 174 | 175 | 176 | matplotlib.image.imsave(path_fn(f'distance_median_{idx_str}.png'), np.asarray(rendering['distance_median'])) 177 | 178 | if config.render_save_async: 179 | # Wait until all worker threads finish. 180 | async_executor.shutdown(wait=True) 181 | 182 | # This will ensure that exceptions in child threads are raised to the 183 | # main thread. 
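    # (Future.result() re-raises any exception captured while the submitted
    # save call was running in the worker.)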
184 |     for future in async_futures:
185 |       future.result()
186 |
187 |   time.sleep(1)
188 |   num_files = len(glob.glob(path_fn('distance_median*.tiff')))
189 |   time.sleep(10)
190 |   if jax.host_id() == 0 and num_files == dataset.size:
191 |     print(f'All files found, creating videos (job {config.render_job_id}).')
192 |     create_videos(config, base_dir, out_dir, out_name, dataset.size)
193 |
194 |   # A hack that forces Jax to keep all TPUs alive until every TPU is finished.
195 |   x = jax.numpy.ones([jax.local_device_count()])
196 |   x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
197 |   print(x)
198 |
199 |
200 | if __name__ == '__main__':
201 |   with gin.config_scope('eval'):  # Use the same scope as eval.py
202 |     app.run(main)
203 |
--------------------------------------------------------------------------------
/internal/pycolmap/pycolmap/camera.py:
--------------------------------------------------------------------------------
 1 | # Author: True Price
 2 |
 3 | import numpy as np
 4 |
 5 | from scipy.optimize import root
 6 |
 7 |
 8 | #-------------------------------------------------------------------------------
 9 | #
10 | # camera distortion functions for arrays of size (..., 2)
11 | #
12 | #-------------------------------------------------------------------------------
13 |
14 | def simple_radial_distortion(camera, x):
15 |     return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True))
16 |
17 | def radial_distortion(camera, x):
18 |     r_sq = np.square(x).sum(axis=-1, keepdims=True)
19 |     return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq))
20 |
21 | def opencv_distortion(camera, x):
22 |     x_sq = np.square(x)  # columns are (x^2, y^2)
23 |     xy = np.prod(x, axis=-1, keepdims=True)
24 |     r_sq = x_sq.sum(axis=-1, keepdims=True)
25 |
26 |     return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + np.concatenate((
27 |         2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * x_sq[..., :1]),
28 |         camera.p1 * (r_sq + 2. * x_sq[..., 1:]) + 2. 
* camera.p2 * xy), 29 | axis=-1) 30 | 31 | 32 | #------------------------------------------------------------------------------- 33 | # 34 | # Camera 35 | # 36 | #------------------------------------------------------------------------------- 37 | 38 | class Camera: 39 | @staticmethod 40 | def GetNumParams(type_): 41 | if type_ == 0 or type_ == 'SIMPLE_PINHOLE': 42 | return 3 43 | if type_ == 1 or type_ == 'PINHOLE': 44 | return 4 45 | if type_ == 2 or type_ == 'SIMPLE_RADIAL': 46 | return 4 47 | if type_ == 3 or type_ == 'RADIAL': 48 | return 5 49 | if type_ == 4 or type_ == 'OPENCV': 50 | return 8 51 | #if type_ == 5 or type_ == 'OPENCV_FISHEYE': 52 | # return 8 53 | #if type_ == 6 or type_ == 'FULL_OPENCV': 54 | # return 12 55 | #if type_ == 7 or type_ == 'FOV': 56 | # return 5 57 | #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE': 58 | # return 4 59 | #if type_ == 9 or type_ == 'RADIAL_FISHEYE': 60 | # return 5 61 | #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE': 62 | # return 12 63 | 64 | # TODO: not supporting other camera types, currently 65 | raise Exception('Camera type not supported') 66 | 67 | 68 | #--------------------------------------------------------------------------- 69 | 70 | @staticmethod 71 | def GetNameFromType(type_): 72 | if type_ == 0: return 'SIMPLE_PINHOLE' 73 | if type_ == 1: return 'PINHOLE' 74 | if type_ == 2: return 'SIMPLE_RADIAL' 75 | if type_ == 3: return 'RADIAL' 76 | if type_ == 4: return 'OPENCV' 77 | #if type_ == 5: return 'OPENCV_FISHEYE' 78 | #if type_ == 6: return 'FULL_OPENCV' 79 | #if type_ == 7: return 'FOV' 80 | #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE' 81 | #if type_ == 9: return 'RADIAL_FISHEYE' 82 | #if type_ == 10: return 'THIN_PRISM_FISHEYE' 83 | 84 | raise Exception('Camera type not supported') 85 | 86 | 87 | #--------------------------------------------------------------------------- 88 | 89 | def __init__(self, type_, width_, height_, params): 90 | self.width = width_ 91 | self.height = height_ 92 | 93 | if type_ == 0 or type_ == 'SIMPLE_PINHOLE': 94 | self.fx, self.cx, self.cy = params 95 | self.fy = self.fx 96 | self.distortion_func = None 97 | self.camera_type = 0 98 | 99 | elif type_ == 1 or type_ == 'PINHOLE': 100 | self.fx, self.fy, self.cx, self.cy = params 101 | self.distortion_func = None 102 | self.camera_type = 1 103 | 104 | elif type_ == 2 or type_ == 'SIMPLE_RADIAL': 105 | self.fx, self.cx, self.cy, self.k1 = params 106 | self.fy = self.fx 107 | self.distortion_func = simple_radial_distortion 108 | self.camera_type = 2 109 | 110 | elif type_ == 3 or type_ == 'RADIAL': 111 | self.fx, self.cx, self.cy, self.k1, self.k2 = params 112 | self.fy = self.fx 113 | self.distortion_func = radial_distortion 114 | self.camera_type = 3 115 | 116 | elif type_ == 4 or type_ == 'OPENCV': 117 | self.fx, self.fy, self.cx, self.cy = params[:4] 118 | self.k1, self.k2, self.p1, self.p2 = params[4:] 119 | self.distortion_func = opencv_distortion 120 | self.camera_type = 4 121 | 122 | else: 123 | raise Exception('Camera type not supported') 124 | 125 | 126 | #--------------------------------------------------------------------------- 127 | 128 | def __str__(self): 129 | s = (self.GetNameFromType(self.camera_type) + 130 | ' {} {} {}'.format(self.width, self.height, self.fx)) 131 | 132 | if self.camera_type in (1, 4): # PINHOLE, OPENCV 133 | s += ' {}'.format(self.fy) 134 | 135 | s += ' {} {}'.format(self.cx, self.cy) 136 | 137 | if self.camera_type == 2: # SIMPLE_RADIAL 138 | s += ' {}'.format(self.k1) 139 | 140 | elif self.camera_type == 
3: # RADIAL 141 | s += ' {} {}'.format(self.k1, self.k2) 142 | 143 | elif self.camera_type == 4: # OPENCV 144 | s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2) 145 | 146 | return s 147 | 148 | 149 | #--------------------------------------------------------------------------- 150 | 151 | # return the camera parameters in the same order as the colmap output format 152 | def get_params(self): 153 | if self.camera_type == 0: 154 | return np.array((self.fx, self.cx, self.cy)) 155 | if self.camera_type == 1: 156 | return np.array((self.fx, self.fy, self.cx, self.cy)) 157 | if self.camera_type == 2: 158 | return np.array((self.fx, self.cx, self.cy, self.k1)) 159 | if self.camera_type == 3: 160 | return np.array((self.fx, self.cx, self.cy, self.k1, self.k2)) 161 | if self.camera_type == 4: 162 | return np.array((self.fx, self.fy, self.cx, self.cy, self.k1, 163 | self.k2, self.p1, self.p2)) 164 | 165 | 166 | #--------------------------------------------------------------------------- 167 | 168 | def get_camera_matrix(self): 169 | return np.array( 170 | ((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1))) 171 | 172 | def get_inverse_camera_matrix(self): 173 | return np.array( 174 | ((1. / self.fx, 0, -self.cx / self.fx), 175 | (0, 1. / self.fy, -self.cy / self.fy), 176 | (0, 0, 1))) 177 | 178 | @property 179 | def K(self): 180 | return self.get_camera_matrix() 181 | 182 | @property 183 | def K_inv(self): 184 | return self.get_inverse_camera_matrix() 185 | 186 | #--------------------------------------------------------------------------- 187 | 188 | # return the inverse camera matrix 189 | def get_inv_camera_matrix(self): 190 | inv_fx, inv_fy = 1. / self.fx, 1. / self.fy 191 | return np.array(((inv_fx, 0, -inv_fx * self.cx), 192 | (0, inv_fy, -inv_fy * self.cy), 193 | (0, 0, 1))) 194 | 195 | 196 | #--------------------------------------------------------------------------- 197 | 198 | # return an (x, y) pixel coordinate grid for this camera 199 | def get_image_grid(self): 200 | xmin = (0.5 - self.cx) / self.fx 201 | xmax = (self.width - 0.5 - self.cx) / self.fx 202 | ymin = (0.5 - self.cy) / self.fy 203 | ymax = (self.height - 0.5 - self.cy) / self.fy 204 | return np.meshgrid(np.linspace(xmin, xmax, self.width), 205 | np.linspace(ymin, ymax, self.height)) 206 | 207 | 208 | #--------------------------------------------------------------------------- 209 | 210 | # x: array of shape (N,2) or (2,) 211 | # normalized: False if the input points are in pixel coordinates 212 | # denormalize: True if the points should be put back into pixel coordinates 213 | def distort_points(self, x, normalized=True, denormalize=True): 214 | x = np.atleast_2d(x) 215 | 216 | # put the points into normalized camera coordinates 217 | if not normalized: 218 | x -= np.array([[self.cx, self.cy]]) 219 | x /= np.array([[self.fx, self.fy]]) 220 | 221 | # distort, if necessary 222 | if self.distortion_func is not None: 223 | x = self.distortion_func(self, x) 224 | 225 | if denormalize: 226 | x *= np.array([[self.fx, self.fy]]) 227 | x += np.array([[self.cx, self.cy]]) 228 | 229 | return x 230 | 231 | 232 | #--------------------------------------------------------------------------- 233 | 234 | # x: array of shape (N1,N2,...,2), (N,2), or (2,) 235 | # normalized: False if the input points are in pixel coordinates 236 | # denormalize: True if the points should be put back into pixel coordinates 237 | def undistort_points(self, x, normalized=False, denormalize=True): 238 | x = np.atleast_2d(x) 239 | 240 | # put the 
points into normalized camera coordinates 241 | if not normalized: 242 | x = x - np.array([self.cx, self.cy]) # creates a copy 243 | x /= np.array([self.fx, self.fy]) 244 | 245 | # undistort, if necessary 246 | if self.distortion_func is not None: 247 | def objective(xu): 248 | return (x - self.distortion_func(self, xu.reshape(*x.shape)) 249 | ).ravel() 250 | 251 | xu = root(objective, x).x.reshape(*x.shape) 252 | else: 253 | xu = x 254 | 255 | if denormalize: 256 | xu *= np.array([[self.fx, self.fy]]) 257 | xu += np.array([[self.cx, self.cy]]) 258 | 259 | return xu 260 | -------------------------------------------------------------------------------- /internal/vis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for visualizing things.""" 16 | 17 | from internal import stepfun 18 | import jax.numpy as jnp 19 | from matplotlib import cm 20 | 21 | 22 | def weighted_percentile(x, w, ps, assume_sorted=False): 23 | """Compute the weighted percentile(s) of a single vector.""" 24 | x = x.reshape([-1]) 25 | w = w.reshape([-1]) 26 | if not assume_sorted: 27 | sortidx = jnp.argsort(x) 28 | x, w = x[sortidx], w[sortidx] 29 | acc_w = jnp.cumsum(w) 30 | return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x) 31 | 32 | 33 | def sinebow(h): 34 | """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" 35 | f = lambda x: jnp.sin(jnp.pi * x)**2 36 | return jnp.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) 37 | 38 | 39 | def matte(vis, acc, dark=0.8, light=1.0, width=8): 40 | """Set non-accumulated pixels to a Photoshop-esque checker pattern.""" 41 | bg_mask = jnp.logical_xor( 42 | (jnp.arange(acc.shape[0]) % (2 * width) // width)[:, None], 43 | (jnp.arange(acc.shape[1]) % (2 * width) // width)[None, :]) 44 | bg = jnp.where(bg_mask, light, dark) 45 | return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None] 46 | 47 | 48 | def visualize_cmap(value, 49 | weight, 50 | colormap, 51 | lo=None, 52 | hi=None, 53 | percentile=99., 54 | curve_fn=lambda x: x, 55 | modulus=None, 56 | matte_background=True): 57 | """Visualize a 1D image and a 1D weighting according to some colormap. 58 | 59 | Args: 60 | value: A 1D image. 61 | weight: A weight map, in [0, 1]. 62 | colormap: A colormap function. 63 | lo: The lower bound to use when rendering, if None then use a percentile. 64 | hi: The upper bound to use when rendering, if None then use a percentile. 65 | percentile: What percentile of the value map to crop to when automatically 66 | generating `lo` and `hi`. Depends on `weight` as well as `value'. 67 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 68 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 69 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. 
If
 70 |       `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
 71 |     matte_background: If True, matte the image over a checkerboard.
 72 |
 73 |   Returns:
 74 |     A colormap rendering.
 75 |   """
 76 |   # Identify the values that bound the middle of `value' according to `weight`.
 77 |   lo_auto, hi_auto = weighted_percentile(
 78 |       value, weight, [50 - percentile / 2, 50 + percentile / 2])
 79 |
 80 |   # If `lo` or `hi` are None, use the automatically-computed bounds above.
 81 |   eps = jnp.finfo(jnp.float32).eps
 82 |   lo = lo or (lo_auto - eps)
 83 |   hi = hi or (hi_auto + eps)
 84 |
 85 |   # Curve all values.
 86 |   value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]
 87 |
 88 |   # Wrap the values around if requested.
 89 |   if modulus:
 90 |     value = jnp.mod(value, modulus) / modulus
 91 |   else:
 92 |     # Otherwise, just scale to [0, 1].
 93 |     value = jnp.nan_to_num(
 94 |         jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))
 95 |
 96 |   if colormap:
 97 |     colorized = colormap(value)[:, :, :3]
 98 |   else:
 99 |     if len(value.shape) != 3:
100 |       raise ValueError(f'value must have 3 dims but has {len(value.shape)}')
101 |     if value.shape[-1] != 3:
102 |       raise ValueError(
103 |           f'value must have 3 channels but has {value.shape[-1]}')
104 |     colorized = value
105 |
106 |   return matte(colorized, weight) if matte_background else colorized
107 |
108 |
109 | def visualize_coord_mod(coords, acc):
110 |   """Visualize the coordinate of each point within its "cell"."""
111 |   return matte(((coords + 1) % 2) / 2, acc)
112 |
113 |
114 | def visualize_rays(dist,
115 |                    dist_range,
116 |                    weights,
117 |                    rgbs,
118 |                    accumulate=False,
119 |                    renormalize=False,
120 |                    resolution=2048,
121 |                    bg_color=0.8):
122 |   """Visualize a bundle of rays."""
123 |   dist_vis = jnp.linspace(*dist_range, resolution + 1)
124 |   vis_rgb, vis_alpha = [], []
125 |   for ds, ws, rs in zip(dist, weights, rgbs):
126 |     vis_rs, vis_ws = [], []
127 |     for d, w, r in zip(ds, ws, rs):
128 |       if accumulate:
129 |         # Produce the accumulated color and weight at each point along the ray.
130 |         w_csum = jnp.cumsum(w, axis=0)
131 |         rw_csum = jnp.cumsum((r * w[:, None]), axis=0)
132 |         eps = jnp.finfo(jnp.float32).eps
133 |         r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum
134 |       vis_rs.append(stepfun.resample(dist_vis, d, r.T, use_avg=True).T)
135 |       vis_ws.append(stepfun.resample(dist_vis, d, w.T, use_avg=True).T)
136 |     vis_rgb.append(jnp.stack(vis_rs))
137 |     vis_alpha.append(jnp.stack(vis_ws))
138 |   vis_rgb = jnp.stack(vis_rgb, axis=1)
139 |   vis_alpha = jnp.stack(vis_alpha, axis=1)
140 |
141 |   if renormalize:
142 |     # Scale the alphas so that the largest value is 1, for visualization.
143 |     vis_alpha /= jnp.maximum(jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha))
144 |
145 |   if resolution > vis_rgb.shape[0]:
146 |     rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1)
147 |     stride = rep * vis_rgb.shape[1]
148 |
149 |     vis_rgb = jnp.tile(vis_rgb, (1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:])
150 |     vis_alpha = jnp.tile(vis_alpha, (1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:])
151 |
152 |     # Add a strip of background pixels after each set of levels of rays.
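    # (The reshapes below group the rows into blocks of `stride`, the
    # concatenates append one zero row per block, and the final reshapes
    # flatten back, leaving a separator row after every block.)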
153 | vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:]) 154 | vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:]) 155 | vis_rgb = jnp.concatenate([vis_rgb, jnp.zeros_like(vis_rgb[:, :1])], 156 | axis=1).reshape((-1,) + vis_rgb.shape[2:]) 157 | vis_alpha = jnp.concatenate( 158 | [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])], 159 | axis=1).reshape((-1,) + vis_alpha.shape[2:]) 160 | 161 | # Matte the RGB image over the background. 162 | vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None] 163 | 164 | # Remove the final row of background pixels. 165 | vis = vis[:-1] 166 | vis_alpha = vis_alpha[:-1] 167 | return vis, vis_alpha 168 | 169 | 170 | def visualize_suite(rendering, rays): 171 | """A wrapper around other visualizations for easy integration.""" 172 | 173 | depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps) 174 | 175 | rgb = rendering['rgb'] 176 | acc = rendering['acc'] 177 | 178 | distance_mean = rendering['distance_mean'] 179 | distance_median = rendering['distance_median'] 180 | # distance_check = rendering['distance_check'] 181 | distance_p5 = rendering['distance_percentile_5'] 182 | distance_p95 = rendering['distance_percentile_95'] 183 | acc = jnp.where(jnp.isnan(distance_mean), jnp.zeros_like(acc), acc) 184 | 185 | # The xyz coordinates where rays terminate. 186 | coords = rays.origins + rays.directions * distance_mean[:, :, None] 187 | 188 | vis_depth_mean, vis_depth_median = [ #TODO distance Check 189 | visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn) 190 | for x in [distance_mean, distance_median] # TODO distance Check 191 | ] 192 | 193 | # Render three depth percentiles directly to RGB channels, where the spacing 194 | # determines the color. delta == big change, epsilon = small change. 
195 | # Gray: A strong discontinuitiy, [x-epsilon, x, x+epsilon] 196 | # Purple: A thin but even density, [x-delta, x, x+delta] 197 | # Red: A thin density, then a thick density, [x-delta, x, x+epsilon] 198 | # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta] 199 | vis_depth_triplet = visualize_cmap( 200 | jnp.stack( 201 | [2 * distance_median - distance_p5, distance_median, distance_p95], 202 | axis=-1), 203 | acc, 204 | None, 205 | curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps)) 206 | 207 | dist = rendering['ray_sdist'] 208 | dist_range = (0, 1) 209 | weights = rendering['ray_weights'] 210 | rgbs = [jnp.clip(r, 0, 1) for r in rendering['ray_rgbs']] 211 | 212 | vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs) 213 | 214 | sqrt_weights = [jnp.sqrt(w) for w in weights] 215 | sqrt_ray_weights, ray_alpha = visualize_rays( 216 | dist, 217 | dist_range, 218 | [jnp.ones_like(lw) for lw in sqrt_weights], 219 | [lw[..., None] for lw in sqrt_weights], 220 | bg_color=0, 221 | ) 222 | sqrt_ray_weights = sqrt_ray_weights[..., 0] 223 | 224 | null_color = jnp.array([1., 0., 0.]) 225 | vis_ray_weights = jnp.where( 226 | ray_alpha[:, :, None] == 0, 227 | null_color[None, None], 228 | visualize_cmap( 229 | sqrt_ray_weights, 230 | jnp.ones_like(sqrt_ray_weights), 231 | cm.get_cmap('gray'), 232 | lo=0, 233 | hi=1, 234 | matte_background=False, 235 | ), 236 | ) 237 | 238 | vis = { 239 | 'color': rgb, 240 | 'acc': acc, 241 | 'color_matte': matte(rgb, acc), 242 | 'depth_mean': vis_depth_mean, 243 | #'depth_check': vis_depth_check, TODO: Deborah depth 244 | 'depth_median': vis_depth_median, 245 | 'depth_triplet': vis_depth_triplet, 246 | 'coords_mod': visualize_coord_mod(coords, acc), 247 | 'ray_colors': vis_ray_colors, 248 | 'ray_weights': vis_ray_weights, 249 | } 250 | 251 | if 'rgb_cc' in rendering: 252 | vis['color_corrected'] = rendering['rgb_cc'] 253 | 254 | # Render every item named "normals*". 255 | for key, val in rendering.items(): 256 | if key.startswith('normals'): 257 | vis[key] = matte(val / 2. + 0.5, acc) 258 | 259 | if 'roughness' in rendering: 260 | vis['roughness'] = matte(jnp.tanh(rendering['roughness']), acc) 261 | 262 | return vis 263 | -------------------------------------------------------------------------------- /tests/coord_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unit tests for coord.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from internal import coord 20 | from internal import math 21 | import jax 22 | from jax import random 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | def sample_covariance(rng, batch_size, num_dims): 28 | """Sample a random covariance matrix.""" 29 | half_cov = jax.random.normal(rng, [batch_size] + [num_dims] * 2) 30 | cov = math.matmul(half_cov, jnp.moveaxis(half_cov, -1, -2)) 31 | return cov 32 | 33 | 34 | def stable_pos_enc(x, n): 35 | """A stable pos_enc for very high degrees, courtesy of Sameer Agarwal.""" 36 | sin_x = np.sin(x) 37 | cos_x = np.cos(x) 38 | output = [] 39 | rotmat = np.array([[cos_x, -sin_x], [sin_x, cos_x]], dtype='double') 40 | for _ in range(n): 41 | output.append(rotmat[::-1, 0, :]) 42 | rotmat = np.einsum('ijn,jkn->ikn', rotmat, rotmat) 43 | return np.reshape(np.transpose(np.stack(output, 0), [2, 1, 0]), [-1, 2 * n]) 44 | 45 | 46 | class CoordTest(parameterized.TestCase): 47 | 48 | def test_stable_pos_enc(self): 49 | """Test that the stable posenc implementation works on multiples of pi/2.""" 50 | n = 10 51 | x = np.linspace(-np.pi, np.pi, 5) 52 | z = stable_pos_enc(x, n).reshape([-1, 2, n]) 53 | z0_true = np.zeros_like(z[:, 0, :]) 54 | z1_true = np.ones_like(z[:, 1, :]) 55 | z0_true[:, 0] = [0, -1, 0, 1, 0] 56 | z1_true[:, 0] = [-1, 0, 1, 0, -1] 57 | z1_true[:, 1] = [1, -1, 1, -1, 1] 58 | z_true = np.stack([z0_true, z1_true], axis=1) 59 | np.testing.assert_allclose(z, z_true, atol=1e-10) 60 | 61 | def test_contract_matches_special_case(self): 62 | """Test the math for Figure 2 of https://arxiv.org/abs/2111.12077.""" 63 | n = 10 64 | _, s_to_t = coord.construct_ray_warps(jnp.reciprocal, 1, jnp.inf) 65 | s = jnp.linspace(0, 1 - jnp.finfo(jnp.float32).eps, n + 1) 66 | tc = coord.contract(s_to_t(s)[:, None])[:, 0] 67 | delta_tc = tc[1:] - tc[:-1] 68 | np.testing.assert_allclose( 69 | delta_tc, np.full_like(delta_tc, 1 / n), atol=1E-5, rtol=1E-5) 70 | 71 | def test_contract_is_bounded(self): 72 | n, d = 10000, 3 73 | rng = random.PRNGKey(0) 74 | key0, key1, rng = random.split(rng, 3) 75 | x = jnp.where(random.bernoulli(key0, shape=[n, d]), 1, -1) * jnp.exp( 76 | random.uniform(key1, [n, d], minval=-10, maxval=10)) 77 | y = coord.contract(x) 78 | self.assertLessEqual(jnp.max(y), 2) 79 | 80 | def test_contract_is_noop_when_norm_is_leq_one(self): 81 | n, d = 10000, 3 82 | rng = random.PRNGKey(0) 83 | key, rng = random.split(rng) 84 | x = random.normal(key, shape=[n, d]) 85 | xc = x / jnp.maximum(1, jnp.linalg.norm(x, axis=-1, keepdims=True)) 86 | 87 | # Sanity check on the test itself. 88 | assert jnp.abs(jnp.max(jnp.linalg.norm(xc, axis=-1)) - 1) < 1e-6 89 | 90 | yc = coord.contract(xc) 91 | np.testing.assert_allclose(xc, yc, atol=1E-5, rtol=1E-5) 92 | 93 | def test_contract_gradients_are_finite(self): 94 | # Construct x such that we probe x == 0, where things are unstable. 
95 | x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1) 96 | grad = jax.grad(lambda x: jnp.sum(coord.contract(x)))(x) 97 | self.assertTrue(jnp.all(jnp.isfinite(grad))) 98 | 99 | def test_inv_contract_gradients_are_finite(self): 100 | z = jnp.stack(jnp.meshgrid(*[jnp.linspace(-2, 2, 21)] * 2), axis=-1) 101 | z = z.reshape([-1, 2]) 102 | z = z[jnp.sum(z**2, axis=-1) < 2, :] 103 | grad = jax.grad(lambda z: jnp.sum(coord.inv_contract(z)))(z) 104 | self.assertTrue(jnp.all(jnp.isfinite(grad))) 105 | 106 | def test_inv_contract_inverts_contract(self): 107 | """Do a round-trip from metric space to contracted space and back.""" 108 | x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1) 109 | x_recon = coord.inv_contract(coord.contract(x)) 110 | np.testing.assert_allclose(x, x_recon, atol=1E-5, rtol=1E-5) 111 | 112 | @parameterized.named_parameters( 113 | ('05_1e-5', 5, 1e-5), 114 | ('10_1e-4', 10, 1e-4), 115 | ('15_0.005', 15, 0.005), 116 | ('20_0.2', 20, 0.2), # At high degrees, our implementation is unstable. 117 | ('25_2', 25, 2), # 2 is the maximum possible error. 118 | ('30_2', 30, 2), 119 | ) 120 | def test_pos_enc(self, n, tol): 121 | """test pos_enc against a stable recursive implementation.""" 122 | x = np.linspace(-np.pi, np.pi, 10001) 123 | z = coord.pos_enc(x[:, None], 0, n, append_identity=False) 124 | z_stable = stable_pos_enc(x, n) 125 | max_err = np.max(np.abs(z - z_stable)) 126 | print(f'PE of degree {n} has a maximum error of {max_err}') 127 | self.assertLess(max_err, tol) 128 | 129 | def test_pos_enc_matches_integrated(self): 130 | """Integrated positional encoding with a variance of zero must be pos_enc.""" 131 | min_deg = 0 132 | max_deg = 10 133 | np.linspace(-jnp.pi, jnp.pi, 10) 134 | x = jnp.stack( 135 | jnp.meshgrid(*[np.linspace(-jnp.pi, jnp.pi, 10)] * 2), axis=-1) 136 | x = np.linspace(-jnp.pi, jnp.pi, 10000) 137 | z_ipe = coord.integrated_pos_enc(x, jnp.zeros_like(x), min_deg, max_deg) 138 | z_pe = coord.pos_enc(x, min_deg, max_deg, append_identity=False) 139 | # We're using a pretty wide tolerance because IPE uses safe_sin(). 140 | np.testing.assert_allclose(z_pe, z_ipe, atol=1e-4) 141 | 142 | def test_track_linearize(self): 143 | rng = random.PRNGKey(0) 144 | batch_size = 20 145 | for _ in range(30): 146 | # Construct some random Gaussians with dimensionalities in [1, 10]. 147 | key, rng = random.split(rng) 148 | in_dims = random.randint(key, (), 1, 10) 149 | key, rng = random.split(rng) 150 | mean = jax.random.normal(key, [batch_size, in_dims]) 151 | key, rng = random.split(rng) 152 | cov = sample_covariance(key, batch_size, in_dims) 153 | key, rng = random.split(rng) 154 | out_dims = random.randint(key, (), 1, 10) 155 | 156 | # Construct a random affine transformation. 157 | key, rng = random.split(rng) 158 | a_mat = jax.random.normal(key, [out_dims, in_dims]) 159 | key, rng = random.split(rng) 160 | b = jax.random.normal(key, [out_dims]) 161 | 162 | def fn(x): 163 | x_vec = x.reshape([-1, x.shape[-1]]) 164 | y_vec = jax.vmap(lambda z: math.matmul(a_mat, z))(x_vec) + b # pylint:disable=cell-var-from-loop 165 | y = y_vec.reshape(list(x.shape[:-1]) + [y_vec.shape[-1]]) 166 | return y 167 | 168 | # Apply the affine function to the Gaussians. 169 | fn_mean_true = fn(mean) 170 | fn_cov_true = math.matmul(math.matmul(a_mat, cov), a_mat.T) 171 | 172 | # Tracking the Gaussians through a linearized function of a linear 173 | # operator should be the same. 
174 | fn_mean, fn_cov = coord.track_linearize(fn, mean, cov) 175 | np.testing.assert_allclose(fn_mean, fn_mean_true, atol=1E-5, rtol=1E-5) 176 | np.testing.assert_allclose(fn_cov, fn_cov_true, atol=1e-5, rtol=1e-5) 177 | 178 | @parameterized.named_parameters(('reciprocal', jnp.reciprocal), 179 | ('log', jnp.log), ('sqrt', jnp.sqrt)) 180 | def test_construct_ray_warps_extents(self, fn): 181 | n = 100 182 | rng = random.PRNGKey(0) 183 | key, rng = random.split(rng) 184 | t_near = jnp.exp(jax.random.normal(key, [n])) 185 | key, rng = random.split(rng) 186 | t_far = t_near + jnp.exp(jax.random.normal(key, [n])) 187 | 188 | t_to_s, s_to_t = coord.construct_ray_warps(fn, t_near, t_far) 189 | 190 | np.testing.assert_allclose( 191 | t_to_s(t_near), jnp.zeros_like(t_near), atol=1E-5, rtol=1E-5) 192 | np.testing.assert_allclose( 193 | t_to_s(t_far), jnp.ones_like(t_far), atol=1E-5, rtol=1E-5) 194 | np.testing.assert_allclose( 195 | s_to_t(jnp.zeros_like(t_near)), t_near, atol=1E-5, rtol=1E-5) 196 | np.testing.assert_allclose( 197 | s_to_t(jnp.ones_like(t_near)), t_far, atol=1E-5, rtol=1E-5) 198 | 199 | def test_construct_ray_warps_special_reciprocal(self): 200 | """Test fn=1/x against its closed form.""" 201 | n = 100 202 | rng = random.PRNGKey(0) 203 | key, rng = random.split(rng) 204 | t_near = jnp.exp(jax.random.normal(key, [n])) 205 | key, rng = random.split(rng) 206 | t_far = t_near + jnp.exp(jax.random.normal(key, [n])) 207 | 208 | key, rng = random.split(rng) 209 | u = jax.random.uniform(key, [n]) 210 | t = t_near * (1 - u) + t_far * u 211 | key, rng = random.split(rng) 212 | s = jax.random.uniform(key, [n]) 213 | 214 | t_to_s, s_to_t = coord.construct_ray_warps(jnp.reciprocal, t_near, t_far) 215 | 216 | # Special cases for fn=reciprocal. 217 | s_to_t_ref = lambda s: 1 / (s / t_far + (1 - s) / t_near) 218 | t_to_s_ref = lambda t: (t_far * (t - t_near)) / (t * (t_far - t_near)) 219 | 220 | np.testing.assert_allclose(t_to_s(t), t_to_s_ref(t), atol=1E-5, rtol=1E-5) 221 | np.testing.assert_allclose(s_to_t(s), s_to_t_ref(s), atol=1E-5, rtol=1E-5) 222 | 223 | def test_expected_sin(self): 224 | normal_samples = random.normal(random.PRNGKey(0), (10000,)) 225 | for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]: 226 | sin_mu = coord.expected_sin(mu, var) 227 | x = jnp.sin(jnp.sqrt(var) * normal_samples + mu) 228 | np.testing.assert_allclose(sin_mu, jnp.mean(x), atol=1e-2) 229 | 230 | def test_integrated_pos_enc(self): 231 | num_dims = 2 # The number of input dimensions. 232 | min_deg = 0 # Must be 0 for this test to work. 233 | max_deg = 4 234 | num_samples = 100000 235 | rng = random.PRNGKey(0) 236 | for _ in range(5): 237 | # Generate a coordinate's mean and covariance matrix. 238 | key, rng = random.split(rng) 239 | mean = random.normal(key, (2,)) 240 | key, rng = random.split(rng) 241 | half_cov = jax.random.normal(key, [num_dims] * 2) 242 | cov = half_cov @ half_cov.T 243 | var = jnp.diag(cov) 244 | # Generate an IPE. 245 | enc = coord.integrated_pos_enc( 246 | mean, 247 | var, 248 | min_deg, 249 | max_deg, 250 | ) 251 | 252 | # Draw samples, encode them, and take their mean. 253 | key, rng = random.split(rng) 254 | samples = random.multivariate_normal(key, mean, cov, [num_samples]) 255 | assert min_deg == 0 256 | enc_samples = np.concatenate( 257 | [stable_pos_enc(x, max_deg) for x in tuple(samples.T)], axis=-1) 258 | # Correct for a different dimension ordering in stable_pos_enc. 
259 |       enc_gt = jnp.mean(enc_samples, 0)
260 |       enc_gt = enc_gt.reshape([num_dims, max_deg * 2]).T.reshape([-1])
261 |       np.testing.assert_allclose(enc, enc_gt, rtol=1e-2, atol=1e-2)
262 |
263 |
264 | if __name__ == '__main__':
265 |   absltest.main()
266 |
--------------------------------------------------------------------------------
/internal/pycolmap/pycolmap/database.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import os
 3 | import sqlite3
 4 |
 5 |
 6 | #-------------------------------------------------------------------------------
 7 | # convert SQLite BLOBs to/from numpy arrays
 8 |
 9 | def array_to_blob(arr):
10 |     return arr.tobytes()  # np.getbuffer() is Python 2 only and fails on 3.x
11 |
12 | def blob_to_array(blob, dtype, shape=(-1,)):
13 |     return np.frombuffer(blob, dtype).reshape(*shape)
14 |
15 |
16 | #-------------------------------------------------------------------------------
17 | # convert to/from image pair ids
18 |
19 | MAX_IMAGE_ID = 2**31 - 1
20 |
21 | def get_pair_id(image_id1, image_id2):
22 |     if image_id1 > image_id2:
23 |         image_id1, image_id2 = image_id2, image_id1
24 |     return image_id1 * MAX_IMAGE_ID + image_id2
25 |
26 |
27 | def get_image_ids_from_pair_id(pair_id):
28 |     image_id2 = pair_id % MAX_IMAGE_ID
29 |     return (pair_id - image_id2) // MAX_IMAGE_ID, image_id2  # integer division
30 |
31 |
32 | #-------------------------------------------------------------------------------
33 | # create table commands
34 |
35 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
36 |     camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
37 |     model INTEGER NOT NULL,
38 |     width INTEGER NOT NULL,
39 |     height INTEGER NOT NULL,
40 |     params BLOB,
41 |     prior_focal_length INTEGER NOT NULL)"""
42 |
43 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
44 |     image_id INTEGER PRIMARY KEY NOT NULL,
45 |     rows INTEGER NOT NULL,
46 |     cols INTEGER NOT NULL,
47 |     data BLOB,
48 |     FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
49 |
50 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
51 |     image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
52 |     name TEXT NOT NULL UNIQUE,
53 |     camera_id INTEGER NOT NULL,
54 |     prior_qw REAL,
55 |     prior_qx REAL,
56 |     prior_qy REAL,
57 |     prior_qz REAL,
58 |     prior_tx REAL,
59 |     prior_ty REAL,
60 |     prior_tz REAL,
61 |     CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647),
62 |     FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))"""
63 |
64 | CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries (
65 |     pair_id INTEGER PRIMARY KEY NOT NULL,
66 |     rows INTEGER NOT NULL,
67 |     cols INTEGER NOT NULL,
68 |     data BLOB,
69 |     config INTEGER NOT NULL,
70 |     F BLOB,
71 |     E BLOB,
72 |     H BLOB)"""
73 |
74 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
75 |     image_id INTEGER PRIMARY KEY NOT NULL,
76 |     rows INTEGER NOT NULL,
77 |     cols INTEGER NOT NULL,
78 |     data BLOB,
79 |     FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
80 |
81 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
82 |     pair_id INTEGER PRIMARY KEY NOT NULL,
83 |     rows INTEGER NOT NULL,
84 |     cols INTEGER NOT NULL,
85 |     data BLOB)"""
86 |
87 | CREATE_NAME_INDEX = \
88 |     "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
89 |
90 | CREATE_ALL = "; ".join([CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE,
91 |     CREATE_IMAGES_TABLE, CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE,
92 |     CREATE_MATCHES_TABLE, CREATE_NAME_INDEX])
93 |
94 |
 95 | #-------------------------------------------------------------------------------
 96 | # functional interface for adding objects
 97 |
 98 | def add_camera(db, model, width, height, params, prior_focal_length=False,
 99 |                camera_id=None):
100 |     # TODO: Parameter count checks
101 |     params = np.asarray(params, np.float64)
102 |     db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
103 |                (camera_id, model, width, height, array_to_blob(params),
104 |                 prior_focal_length))
105 |
106 |
107 | def add_descriptors(db, image_id, descriptors):
108 |     descriptors = np.ascontiguousarray(descriptors, np.uint8)
109 |     db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)",
110 |                (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
111 |
112 |
113 | def add_image(db, name, camera_id, prior_q=np.zeros(4), prior_t=np.zeros(3),
114 |               image_id=None):
115 |     db.execute("INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
116 |                (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
117 |                 prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
118 |
119 |
120 | # config: defaults to fundamental matrix
121 | def add_inlier_matches(db, image_id1, image_id2, matches, config=2, F=None,
122 |                        E=None, H=None):
123 |     assert(len(matches.shape) == 2)
124 |     assert(matches.shape[1] == 2)
125 |
126 |     if image_id1 > image_id2:
127 |         matches = matches[:,::-1]
128 |
129 |     if F is not None:
130 |         F = array_to_blob(np.asarray(F, np.float64))
131 |     if E is not None:
132 |         E = array_to_blob(np.asarray(E, np.float64))
133 |     if H is not None:
134 |         H = array_to_blob(np.asarray(H, np.float64))
135 |
136 |     pair_id = get_pair_id(image_id1, image_id2)
137 |     matches = np.asarray(matches, np.uint32)
138 |     db.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
139 |         (pair_id,) + matches.shape + (array_to_blob(matches), config, F, E, H))
140 |
141 |
142 | def add_keypoints(db, image_id, keypoints):
143 |     assert(len(keypoints.shape) == 2)
144 |     assert(keypoints.shape[1] in [2, 4, 6])
145 |
146 |     keypoints = np.asarray(keypoints, np.float32)
147 |     db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)",
148 |                (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
149 |
150 |
151 |
152 | def add_matches(db, image_id1, image_id2, matches):
153 |     assert(len(matches.shape) == 2)
154 |     assert(matches.shape[1] == 2)
155 |
156 |     if image_id1 > image_id2:
157 |         matches = matches[:,::-1]
158 |
159 |     pair_id = get_pair_id(image_id1, image_id2)
160 |     matches = np.asarray(matches, np.uint32)
161 |     db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)",
162 |         (pair_id,) + matches.shape + (array_to_blob(matches),))
163 |
164 |
165 | #-------------------------------------------------------------------------------
166 | # simple functional interface
167 |
168 | class COLMAPDatabase(sqlite3.Connection):
169 |     @staticmethod
170 |     def connect(database_path):
171 |         return sqlite3.connect(database_path, factory=COLMAPDatabase)
172 |
173 |
174 |     def __init__(self, *args, **kwargs):
175 |         super(COLMAPDatabase, self).__init__(*args, **kwargs)
176 |
177 |         self.initialize_tables = lambda: self.executescript(CREATE_ALL)
178 |
179 |         self.initialize_cameras = \
180 |             lambda: self.executescript(CREATE_CAMERAS_TABLE)
181 |         self.initialize_descriptors = \
182 |             lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
183 |         self.initialize_images = \
184 |             lambda: self.executescript(CREATE_IMAGES_TABLE)
185 |         self.initialize_inlier_matches = \
186 |             lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE)
187 |         self.initialize_keypoints = \
188 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 189 | self.initialize_matches = \ 190 | lambda: self.executescript(CREATE_MATCHES_TABLE) 191 | 192 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 193 | 194 | 195 | add_camera = add_camera 196 | add_descriptors = add_descriptors 197 | add_image = add_image 198 | add_inlier_matches = add_inlier_matches 199 | add_keypoints = add_keypoints 200 | add_matches = add_matches 201 | 202 | 203 | #------------------------------------------------------------------------------- 204 | 205 | def main(args): 206 | import os 207 | 208 | if os.path.exists(args.database_path): 209 | print("Error: database path already exists -- will not modify it.") 210 | exit() 211 | 212 | db = COLMAPDatabase.connect(args.database_path) 213 | 214 | # 215 | # for convenience, try creating all the tables upfront 216 | # 217 | 218 | db.initialize_tables() 219 | 220 | 221 | # 222 | # create dummy cameras 223 | # 224 | 225 | model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.)) 226 | model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1)) 227 | 228 | db.add_camera(model1, w1, h1, params1) 229 | db.add_camera(model2, w2, h2, params2) 230 | 231 | 232 | # 233 | # create dummy images 234 | # 235 | 236 | db.add_image("image1.png", 0) 237 | db.add_image("image2.png", 0) 238 | db.add_image("image3.png", 2) 239 | db.add_image("image4.png", 2) 240 | 241 | 242 | # 243 | # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y), 244 | # 4D keypoints (x, y, theta, scale), and 6D affine keypoints 245 | # (x, y, a_11, a_12, a_21, a_22) 246 | # 247 | 248 | N = 1000 249 | kp1 = np.random.rand(N, 2) * (1024., 768.) 250 | kp2 = np.random.rand(N, 2) * (1024., 768.) 251 | kp3 = np.random.rand(N, 2) * (1024., 768.) 252 | kp4 = np.random.rand(N, 2) * (1024., 768.) 
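    # (np.random.rand gives unit-square samples; multiplying by (1024., 768.)
    # broadcasts them into pixel coordinates for the dummy 1024x768 cameras.)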
253 | 254 | db.add_keypoints(1, kp1) 255 | db.add_keypoints(2, kp2) 256 | db.add_keypoints(3, kp3) 257 | db.add_keypoints(4, kp4) 258 | 259 | 260 | # 261 | # create dummy matches 262 | # 263 | 264 | M = 50 265 | m12 = np.random.randint(N, size=(M, 2)) 266 | m23 = np.random.randint(N, size=(M, 2)) 267 | m34 = np.random.randint(N, size=(M, 2)) 268 | 269 | db.add_matches(1, 2, m12) 270 | db.add_matches(2, 3, m23) 271 | db.add_matches(3, 4, m34) 272 | 273 | 274 | # 275 | # check cameras; params were stored as float64 blobs, so read them back as 276 | # float64 277 | 278 | rows = db.execute("SELECT * FROM cameras") 279 | 280 | camera_id, model, width, height, params, prior = next(rows) 281 | params = blob_to_array(params, np.float64) 282 | assert model == model1 and width == w1 and height == h1 283 | assert np.allclose(params, params1) 284 | 285 | camera_id, model, width, height, params, prior = next(rows) 286 | params = blob_to_array(params, np.float64) 287 | assert model == model2 and width == w2 and height == h2 288 | assert np.allclose(params, params2) 289 | 290 | 291 | # 292 | # check keypoints 293 | # 294 | 295 | kps = dict( 296 | (image_id, blob_to_array(data, np.float32, (-1, 2))) 297 | for image_id, data in db.execute( 298 | "SELECT image_id, data FROM keypoints")) 299 | 300 | assert np.allclose(kps[1], kp1) 301 | assert np.allclose(kps[2], kp2) 302 | assert np.allclose(kps[3], kp3) 303 | assert np.allclose(kps[4], kp4) 304 | 305 | 306 | # 307 | # check matches 308 | # 309 | 310 | pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]] 311 | 312 | matches = dict( 313 | (get_image_ids_from_pair_id(pair_id), 314 | blob_to_array(data, np.uint32, (-1, 2))) 315 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches")) 316 | 317 | assert np.all(matches[(1, 2)] == m12) 318 | assert np.all(matches[(2, 3)] == m23) 319 | assert np.all(matches[(3, 4)] == m34) 320 | 321 | # 322 | # clean up 323 | # 324 | 325 | db.close() 326 | os.remove(args.database_path) 327 | 328 | #------------------------------------------------------------------------------- 329 | 330 | if __name__ == "__main__": 331 | import argparse 332 | 333 | parser = argparse.ArgumentParser( 334 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 335 | 336 | parser.add_argument("--database_path", type=str, default="database.db") 337 | 338 | args = parser.parse_args() 339 | 340 | main(args) 341 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /internal/configs.py: -------------------------------------------------------------------------------- 1 | # This file was modified by Deborah Levy 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utility functions for handling configurations.""" 17 | 18 | import dataclasses 19 | from typing import Any, Callable, Optional, Tuple, List 20 | 21 | from absl import flags 22 | from flax.core import FrozenDict 23 | import gin 24 | from internal import utils 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | # gin.add_config_file_search_path('experimental/users/barron/mipnerf360/') 29 | 30 | configurables = { 31 | 'jnp': [jnp.reciprocal, jnp.log, jnp.log1p, jnp.exp, jnp.sqrt, jnp.square], 32 | 'jax.nn': [jax.nn.relu, jax.nn.softplus, jax.nn.silu], 33 | 'jax.nn.initializers.he_normal': [jax.nn.initializers.he_normal()], 34 | 'jax.nn.initializers.he_uniform': [jax.nn.initializers.he_uniform()], 35 | 'jax.nn.initializers.glorot_normal': [jax.nn.initializers.glorot_normal()], 36 | 'jax.nn.initializers.glorot_uniform': [ 37 | jax.nn.initializers.glorot_uniform() 38 | ], 39 | } 40 | 41 | for module, fns in configurables.items(): 42 | for configurable in fns: 43 | gin.config.external_configurable(configurable, module=module) 44 | 45 | 46 | @gin.configurable() 47 | @dataclasses.dataclass 48 | class Config: 49 | """Configuration flags for everything.""" 50 | 51 | dataset_loader: str = 'llff' # The type of dataset loader to use. 52 | batching: str = 'all_images' # Batch composition, [single_image, all_images]. 53 | batch_size: int = 16384 # The number of rays/pixels in each batch. 54 | patch_size: int = 1 # Resolution of patches sampled for training batches. 55 | factor: int = 0 # The downsample factor of images, 0 for no downsampling. 56 | load_alphabetical: bool = True # Load images in COLMAP vs alphabetical 57 | # ordering (affects heldout test set). 58 | forward_facing: bool = True # Set to True for forward-facing LLFF captures. 59 | render_path: bool = False # If True, render a path. Used only by LLFF. 60 | llffhold: int = 8 # Use every Nth image for the test set. Used only by LLFF. 61 | # If true, use all input images for training. 62 | llff_use_all_images_for_training: bool = False 63 | use_tiffs: bool = False # If True, use 32-bit TIFFs. Used only by Blender. 64 | compute_disp_metrics: bool = False 65 | compute_normal_metrics: bool = False 66 | gc_every: int = 10000 # The number of steps between garbage collections. 67 | disable_multiscale_loss: bool = False # If True, disable multiscale loss. 68 | randomized: bool = True # Use randomized stratified sampling. 69 | near: float = 2. # Near plane distance. 70 | far: float = 6. # Far plane distance. 71 | checkpoint_dir: Optional[str] = None # Where to log checkpoints. 72 | render_dir: Optional[str] = None # Output rendering directory.
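# These Config fields are set through gin: from the .gin files under configs/
# or via command-line overrides such as
#   --gin_bindings="Config.data_dir = 'data/my_scene'"
# (parsed by load_config() below; 'data/my_scene' is a placeholder path).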
73 | data_dir: Optional[str] = None # Input data directory. 74 | vocab_tree_path: Optional[str] = None # Path to vocab tree for COLMAP. 75 | render_chunk_size: int = 16384 # Chunk size for whole-image renderings. 76 | num_showcase_images: int = 5 # The number of test-set images to showcase. 77 | deterministic_showcase: bool = True # If True, showcase the same images. 78 | vis_num_rays: int = 8 # The number of rays to visualize. 79 | # Decimate images for tensorboard (i.e., x[::d, ::d]) to conserve memory usage. 80 | vis_decimate: int = 0 81 | 82 | # Only used by train.py: 83 | max_steps: int = 250000 # The number of optimization steps. 84 | early_exit_steps: Optional[int] = None # Early stopping, for debugging. 85 | checkpoint_every: int = 25000 # The number of steps to save a checkpoint. 86 | print_every: int = 100 # The number of steps between reports to tensorboard. 87 | train_render_every: int = 5000 # Steps between test set renders when training. 88 | cast_rays_in_train_step: bool = False # If True, compute rays in train step. 89 | data_loss_type: str = 'rawnerf' # What kind of loss to use ('mse', 'charb', or 'rawnerf'). 90 | charb_padding: float = 0.001 # The padding used for Charbonnier loss. 91 | data_loss_mult: float = 1.0 # Mult for the finest data term in the loss. 92 | data_coarse_loss_mult: float = 0. # Multiplier for the coarser data terms. 93 | interlevel_loss_mult: float = 1.0 # Mult. for the loss on the proposal MLP. 94 | 95 | weight_decay_mults: FrozenDict[str, Any] = FrozenDict({}) # Weight decays. 96 | # An example that regularizes the NeRF and the first layer of the prop MLP: 97 | # weight_decay_mults = { 98 | # 'NerfMLP_0': 0.00001, 99 | # 'PropMLP_0/Dense_0': 0.001, 100 | # } 101 | # Any model parameter that isn't specified gets a mult of 0. See the 102 | # train_weight_l2_* parameters in TensorBoard to know what can be regularized. 103 | 104 | lr_init: float = 0.002 # The initial learning rate. 105 | lr_final: float = 0.00002 # The final learning rate. 106 | lr_delay_steps: int = 512 # The number of "warmup" learning steps. 107 | lr_delay_mult: float = 0.01 # How severe the "warmup" should be. 108 | adam_beta1: float = 0.9 # Adam's beta1 hyperparameter. 109 | adam_beta2: float = 0.999 # Adam's beta2 hyperparameter. 110 | adam_eps: float = 1e-6 # Adam's epsilon hyperparameter. 111 | grad_max_norm: float = 0.001 # Gradient clipping magnitude, disabled if == 0. 112 | grad_max_val: float = 0. # Gradient clipping value, disabled if == 0. 113 | distortion_loss_mult: float = 0.01 # Multiplier on the distortion loss. 114 | 115 | # Only used by eval.py: 116 | eval_only_once: bool = True # If True, evaluate the model only once; otherwise loop. 117 | eval_save_output: bool = True # If True, save predicted images to disk. 118 | eval_save_ray_data: bool = False # If True, save individual ray traces. 119 | eval_render_interval: int = 1 # The interval between images saved to disk. 120 | eval_dataset_limit: int = jnp.iinfo(jnp.int32).max # Num test images to eval. 121 | eval_quantize_metrics: bool = True # If True, run metrics on 8-bit images. 122 | eval_crop_borders: int = 0 # Ignore c border pixels in eval (x[c:-c, c:-c]). 123 | eval_on_train: bool = False 124 | 125 | # Only used by render.py 126 | render_video_fps: int = 60 # Framerate in frames-per-second. 127 | render_video_crf: int = 18 # Constant rate factor for ffmpeg video quality. 128 | render_path_frames: int = 120 # Number of frames in render path. 129 | z_variation: float = 0. # How much height variation in render path.
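# NOTE (assumption, not verified against the render-path code): z_variation
# scales a sinusoidal height offset along the generated elliptical render
# path, and z_phase (next field) shifts that sinusoid's phase.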
130 | z_phase: float = 0. # Phase offset for height variation in render path. 131 | render_dist_percentile: float = 0.5 # How much to trim from near/far planes. 132 | render_dist_curve_fn: Callable[..., Any] = jnp.log # How depth is curved. 133 | render_path_file: Optional[str] = None # Numpy render pose file to load. 134 | render_job_id: int = 0 # Render job id. 135 | render_num_jobs: int = 1 # Total number of render jobs. 136 | render_resolution: Optional[Tuple[int, int]] = None # Render resolution, as 137 | # (width, height). 138 | render_focal: Optional[float] = None # Render focal length. 139 | render_camtype: Optional[str] = None # 'perspective', 'fisheye', or 'pano'. 140 | render_spherical: bool = False # Render spherical 360 panoramas. 141 | render_save_async: bool = True # Save to CNS using a separate thread. 142 | 143 | render_spline_keyframes: Optional[str] = None # Text file containing names of 144 | # images to be used as spline 145 | # keyframes, OR directory 146 | # containing those images. 147 | render_spline_n_interp: int = 30 # Num. frames to interpolate per keyframe. 148 | render_spline_degree: int = 5 # Polynomial degree of B-spline interpolation. 149 | render_spline_smoothness: float = .03 # B-spline smoothing factor, 0 for 150 | # exact interpolation of keyframes. 151 | # Interpolate per-frame exposure value from spline keyframes. 152 | render_spline_interpolate_exposure: bool = False 153 | 154 | # Flags for raw datasets. 155 | rawnerf_mode: bool = False # Load raw images and train in raw color space. 156 | exposure_percentile: float = 97. # Image percentile to expose as white. 157 | num_border_pixels_to_mask: int = 0 # During training, discard N-pixel border 158 | # around each input image. 159 | apply_bayer_mask: bool = False # During training, apply Bayer mosaic mask. 160 | autoexpose_renders: bool = False # During rendering, autoexpose each image. 161 | # For raw test scenes, use affine raw-space color correction. 162 | eval_raw_affine_cc: bool = False 163 | 164 | # underwater flags 165 | use_uw_mlp: bool = False 166 | use_uw_acc_weights_loss: bool = False # use the accuracy loss (equation (27) in the paper) on the object's summed weights 167 | use_uw_acc_trans_loss: bool = False # use the accuracy loss (equation (27) in the paper) on the object's transmittance, as in the paper 168 | use_uw_sig_med_loss: bool = False # use an std loss on the medium's densities to encourage smooth densities (not in the paper) 169 | uw_initial_acc_weights_loss_mult: float = 0.01 # initial mult factor for acc_weights_loss; changes after uw_decay_acc iterations 170 | uw_initial_acc_trans_loss_mult: float = 0.01 # initial mult factor for acc_trans_loss; changes after uw_decay_acc iterations 171 | uw_final_acc_weights_loss_mult: float = 0.01 # final mult factor for acc_weights_loss after uw_decay_acc iterations 172 | uw_final_acc_trans_loss_mult: float = 0.01 # final mult factor for acc_trans_loss after uw_decay_acc iterations 173 | uw_acc_loss_factor: float = 6 # factor that makes one term of the accuracy loss dominant: it pushes the transmittance toward 1 and the weights toward 0. 174 | uw_sig_med_mult: float = 0.01 # mult factor for the std loss (not in the paper) 175 | # uw_usual_rendering : int = 600 176 | # uw_final_usual_rendering: int = 1500 177 | uw_old_model: bool = False # If True, use the same sigmas for attenuation and backscatter.
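# NOTE (assumption): a plausible reading of the four multipliers above is a
# step schedule gated by uw_decay_acc (next field), e.g.
#   mult = initial_mult if step < uw_decay_acc else final_mult
# With the defaults, the initial and final values are equal, so the switch is
# a no-op unless overridden in a .gin config.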
178 | uw_decay_acc: int = 5000 # number of iterations after which the acc loss multipliers switch from their initial to their final values 179 | gen_eq: bool = False # use equations (11)-(14) from the paper instead of (22) 180 | uw_atten_xyz: bool = False # If True, use the rgb xyz coordinates also as input for the sigma_atten prediction 181 | uw_fog_model: bool = False # If True, use the same sigmas for attenuation and backscatter, and the same across all color channels. 182 | uw_rgb_dir: bool = False # If True, use view_dir also as input for rgb_obj 183 | 184 | extra_samples: bool = False # If True, add extra samples at the beginning of the backscatter (bs) component (for fog sim) 185 | 186 | 187 | def define_common_flags(): 188 | # Define the flags used by both train.py and eval.py 189 | flags.DEFINE_string('mode', None, 'Required by GINXM, not used.') 190 | flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.') 191 | flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.') 192 | flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.') 193 | 194 | 195 | def load_config(save_config=True): 196 | """Load the config, and optionally checkpoint it.""" 197 | gin.parse_config_files_and_bindings( 198 | flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True) 199 | config = Config() 200 | if save_config and jax.host_id() == 0: 201 | utils.makedirs(config.checkpoint_dir) 202 | with utils.open_file(config.checkpoint_dir + '/config.gin', 'w') as f: 203 | f.write(gin.config_str()) 204 | return config 205 | -------------------------------------------------------------------------------- /internal/pycolmap/pycolmap/rotation.py: -------------------------------------------------------------------------------- 1 | # Author: True Price 2 | 3 | import numpy as np 4 | 5 | #------------------------------------------------------------------------------- 6 | # 7 | # Axis-Angle Functions 8 | # 9 | #------------------------------------------------------------------------------- 10 | 11 | # returns the cross product matrix representation of a 3-vector v 12 | def cross_prod_matrix(v): 13 | return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.))) 14 | 15 | #------------------------------------------------------------------------------- 16 | 17 | # www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/ 18 | # if angle is None, assume ||axis|| == angle, in radians 19 | # if angle is not None, assume that axis is a unit vector 20 | def axis_angle_to_rotation_matrix(axis, angle=None): 21 | if angle is None: 22 | angle = np.linalg.norm(axis) 23 | if np.abs(angle) > np.finfo('float').eps: 24 | axis = axis / angle 25 | 26 | cp_axis = cross_prod_matrix(axis) 27 | return np.eye(3) + ( 28 | np.sin(angle) * cp_axis + (1.
- np.cos(angle)) * cp_axis.dot(cp_axis)) 29 | 30 | #------------------------------------------------------------------------------- 31 | 32 | # after some deliberation, I've decided the easiest way to do this is to use 33 | # quaternions as an intermediary 34 | def rotation_matrix_to_axis_angle(R): 35 | return Quaternion.FromR(R).ToAxisAngle() 36 | 37 | #------------------------------------------------------------------------------- 38 | # 39 | # Quaternion 40 | # 41 | #------------------------------------------------------------------------------- 42 | 43 | class Quaternion: 44 | # create a quaternion from an existing rotation matrix 45 | # euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/ 46 | @staticmethod 47 | def FromR(R): 48 | trace = np.trace(R) 49 | 50 | if trace > 0: 51 | qw = 0.5 * np.sqrt(1. + trace) 52 | qx = (R[2,1] - R[1,2]) * 0.25 / qw 53 | qy = (R[0,2] - R[2,0]) * 0.25 / qw 54 | qz = (R[1,0] - R[0,1]) * 0.25 / qw 55 | elif R[0,0] > R[1,1] and R[0,0] > R[2,2]: 56 | s = 2. * np.sqrt(1. + R[0,0] - R[1,1] - R[2,2]) 57 | qw = (R[2,1] - R[1,2]) / s 58 | qx = 0.25 * s 59 | qy = (R[0,1] + R[1,0]) / s 60 | qz = (R[0,2] + R[2,0]) / s 61 | elif R[1,1] > R[2,2]: 62 | s = 2. * np.sqrt(1. + R[1,1] - R[0,0] - R[2,2]) 63 | qw = (R[0,2] - R[2,0]) / s 64 | qx = (R[0,1] + R[1,0]) / s 65 | qy = 0.25 * s 66 | qz = (R[1,2] + R[2,1]) / s 67 | else: 68 | s = 2. * np.sqrt(1. + R[2,2] - R[0,0] - R[1,1]) 69 | qw = (R[1,0] - R[0,1]) / s 70 | qx = (R[0,2] + R[2,0]) / s 71 | qy = (R[1,2] + R[2,1]) / s 72 | qz = 0.25 * s 73 | 74 | return Quaternion(np.array((qw, qx, qy, qz))) 75 | 76 | # if angle is None, assume ||axis|| == angle, in radians 77 | # if angle is not None, assume that axis is a unit vector 78 | @staticmethod 79 | def FromAxisAngle(axis, angle=None): 80 | if angle is None: 81 | angle = np.linalg.norm(axis) 82 | if np.abs(angle) > np.finfo('float').eps: 83 | axis = axis / angle 84 | 85 | qw = np.cos(0.5 * angle) 86 | axis = axis * np.sin(0.5 * angle) 87 | 88 | return Quaternion(np.array((qw, axis[0], axis[1], axis[2]))) 89 | 90 | #--------------------------------------------------------------------------- 91 | 92 | def __init__(self, q=np.array((1., 0., 0., 0.))): 93 | if isinstance(q, Quaternion): 94 | self.q = q.q.copy() 95 | else: 96 | q = np.asarray(q) 97 | if q.size == 4: 98 | self.q = q.copy() 99 | elif q.size == 3: # convert from a 3-vector to a quaternion 100 | self.q = np.empty(4) 101 | self.q[0], self.q[1:] = 0., q.ravel() 102 | else: 103 | raise Exception('Input quaternion should be a 3- or 4-vector') 104 | 105 | def __add__(self, other): 106 | return Quaternion(self.q + other.q) 107 | 108 | def __iadd__(self, other): 109 | self.q += other.q 110 | return self 111 | 112 | # conjugation via the ~ operator 113 | def __invert__(self): 114 | return Quaternion( 115 | np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3]))) 116 | 117 | # returns: self.q * other.q if other is a Quaternion; otherwise performs 118 | # scalar multiplication 119 | def __mul__(self, other): 120 | if isinstance(other, Quaternion): # quaternion multiplication 121 | return Quaternion(np.array(( 122 | self.q[0] * other.q[0] - self.q[1] * other.q[1] - 123 | self.q[2] * other.q[2] - self.q[3] * other.q[3], 124 | self.q[0] * other.q[1] + self.q[1] * other.q[0] + 125 | self.q[2] * other.q[3] - self.q[3] * other.q[2], 126 | self.q[0] * other.q[2] - self.q[1] * other.q[3] + 127 | self.q[2] * other.q[0] + self.q[3] * other.q[1], 128 | self.q[0] * other.q[3] + self.q[1] * other.q[2] - 129 | 
self.q[2] * other.q[1] + self.q[3] * other.q[0]))) 130 | else: # scalar multiplication (assumed) 131 | return Quaternion(other * self.q) 132 | 133 | def __rmul__(self, other): 134 | return self * other 135 | 136 | def __imul__(self, other): 137 | self.q[:] = (self * other).q 138 | return self 139 | 140 | def __irmul__(self, other): 141 | self.q[:] = (self * other).q 142 | return self 143 | 144 | def __neg__(self): 145 | return Quaternion(-self.q) 146 | 147 | def __sub__(self, other): 148 | return Quaternion(self.q - other.q) 149 | 150 | def __isub__(self, other): 151 | self.q -= other.q 152 | return self 153 | 154 | def __str__(self): 155 | return str(self.q) 156 | 157 | def copy(self): 158 | return Quaternion(self) 159 | 160 | def dot(self, other): 161 | return self.q.dot(other.q) 162 | 163 | # assume the quaternion is nonzero! 164 | def inverse(self): 165 | return Quaternion((~self).q / self.q.dot(self.q)) 166 | 167 | def norm(self): 168 | return np.linalg.norm(self.q) 169 | 170 | def normalize(self): 171 | self.q /= np.linalg.norm(self.q) 172 | return self 173 | 174 | # assume x is a Nx3 numpy array or a numpy 3-vector 175 | def rotate_points(self, x): 176 | x = np.atleast_2d(x) 177 | return x.dot(self.ToR().T) 178 | 179 | # convert to a rotation matrix 180 | def ToR(self): 181 | return np.eye(3) + 2 * np.array(( 182 | (-self.q[2] * self.q[2] - self.q[3] * self.q[3], 183 | self.q[1] * self.q[2] - self.q[3] * self.q[0], 184 | self.q[1] * self.q[3] + self.q[2] * self.q[0]), 185 | ( self.q[1] * self.q[2] + self.q[3] * self.q[0], 186 | -self.q[1] * self.q[1] - self.q[3] * self.q[3], 187 | self.q[2] * self.q[3] - self.q[1] * self.q[0]), 188 | ( self.q[1] * self.q[3] - self.q[2] * self.q[0], 189 | self.q[2] * self.q[3] + self.q[1] * self.q[0], 190 | -self.q[1] * self.q[1] - self.q[2] * self.q[2]))) 191 | 192 | # convert to axis-angle representation, with angle encoded by the length 193 | def ToAxisAngle(self): 194 | # recall that for axis-angle representation (a, angle), with "a" unit: 195 | # q = (cos(angle/2), a * sin(angle/2)) 196 | # below, for readability, "theta" actually means half of the angle 197 | 198 | sin_sq_theta = self.q[1:].dot(self.q[1:]) 199 | 200 | # if theta is non-zero, then we can compute a unique rotation 201 | if np.abs(sin_sq_theta) > np.finfo('float').eps: 202 | sin_theta = np.sqrt(sin_sq_theta) 203 | cos_theta = self.q[0] 204 | 205 | # atan2 is more stable, so we use it to compute theta 206 | # note that we multiply by 2 to get the actual angle 207 | angle = 2. * ( 208 | np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else 209 | np.arctan2(sin_theta, cos_theta)) 210 | 211 | return self.q[1:] * (angle / sin_theta) 212 | 213 | # otherwise, the result is singular, and we avoid dividing by 214 | # sin(angle/2) = 0 215 | return np.zeros(3) 216 | 217 | # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler 218 | # this assumes the quaternion is non-zero 219 | # returns yaw, pitch, roll, with application in that order 220 | def ToEulerAngles(self): 221 | qsq = self.q**2 222 | k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum() 223 | 224 | if (1. - k) < np.finfo('float').eps: # north pole singularity 225 | return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0. 226 | if (1. + k) < np.finfo('float').eps: # south pole singularity 227 | return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0. 228 | 229 | yaw = np.arctan2(2. 
* (self.q[0] * self.q[2] - self.q[1] * self.q[3]), 230 | qsq[0] + qsq[1] - qsq[2] - qsq[3]) 231 | pitch = np.arcsin(k) 232 | roll = np.arctan2(2. * (self.q[0] * self.q[1] - self.q[2] * self.q[3]), 233 | qsq[0] - qsq[1] + qsq[2] - qsq[3]) 234 | 235 | return yaw, pitch, roll 236 | 237 | #------------------------------------------------------------------------------- 238 | # 239 | # DualQuaternion 240 | # 241 | #------------------------------------------------------------------------------- 242 | 243 | class DualQuaternion: 244 | # DualQuaternion from an existing rotation + translation 245 | @staticmethod 246 | def FromQT(q, t): 247 | return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q) 248 | 249 | def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)): 250 | self.q0, self.qe = Quaternion(q0), Quaternion(qe) 251 | 252 | def __add__(self, other): 253 | return DualQuaternion(self.q0 + other.q0, self.qe + other.qe) 254 | 255 | def __iadd__(self, other): 256 | self.q0 += other.q0 257 | self.qe += other.qe 258 | return self 259 | 260 | # conjugation via the ~ operator 261 | def __invert__(self): 262 | return DualQuaternion(~self.q0, ~self.qe) 263 | 264 | def __mul__(self, other): 265 | if isinstance(other, DualQuaternion): 266 | return DualQuaternion( 267 | self.q0 * other.q0, 268 | self.q0 * other.qe + self.qe * other.q0) 269 | elif isinstance(other, complex): # multiplication by a dual number 270 | return DualQuaternion( 271 | self.q0 * other.real, 272 | self.q0 * other.imag + self.qe * other.real) 273 | else: # scalar multiplication (assumed) 274 | return DualQuaternion(other * self.q0, other * self.qe) 275 | 276 | def __rmul__(self, other): 277 | return self.__mul__(other) 278 | 279 | def __imul__(self, other): 280 | tmp = self * other 281 | self.q0, self.qe = tmp.q0, tmp.qe 282 | return self 283 | 284 | def __neg__(self): 285 | return DualQuaternion(-self.q0, -self.qe) 286 | 287 | def __sub__(self, other): 288 | return DualQuaternion(self.q0 - other.q0, self.qe - other.qe) 289 | 290 | def __isub__(self, other): 291 | self.q0 -= other.q0 292 | self.qe -= other.qe 293 | return self 294 | 295 | # q^-1 = q* / ||q||^2 296 | # assume that q0 is nonzero! 297 | def inverse(self): 298 | normsq = complex(self.q0.dot(self.q0), 2. * self.q0.q.dot(self.qe.q)) 299 | inv_len_real = 1. / normsq.real 300 | return ~self * complex( 301 | inv_len_real, -normsq.imag * inv_len_real * inv_len_real) 302 | 303 | # returns a complex representation of the real and imaginary parts of the norm 304 | # assume that q0 is nonzero! 305 | def norm(self): 306 | q0_norm = self.q0.norm() 307 | return complex(q0_norm, self.q0.dot(self.qe) / q0_norm) 308 | 309 | # assume that q0 is nonzero! 310 | def normalize(self): 311 | # current length is ||q0|| + eps * (<q0, qe> / ||q0||) 312 | # writing this as a + eps * b, the inverse is 313 | # 1/||q|| = 1/a - eps * b / a^2 314 | norm = self.norm() 315 | inv_len_real = 1.
/ norm.real 316 | self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real) 317 | return self 318 | 319 | # return the translation vector for this dual quaternion 320 | def getT(self): 321 | return 2 * (self.qe * ~self.q0).q[1:] 322 | 323 | def ToQT(self): 324 | return self.q0, self.getT() 325 | -------------------------------------------------------------------------------- /internal/stepfun.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools for manipulating step functions (piecewise-constant 1D functions). 16 | 17 | We have a shared naming and dimension convention for these functions. 18 | All input/output step functions are assumed to be aligned along the last axis. 19 | `t` always indicates the x coordinates of the *endpoints* of a step function. 20 | `y` indicates unconstrained values for the *bins* of a step function. 21 | `w` indicates bin weights that sum to <= 1. `p` indicates non-negative bin 22 | values that *integrate* to <= 1. 23 | """ 24 | 25 | from internal import math 26 | import jax 27 | import jax.numpy as jnp 28 | 29 | 30 | def searchsorted(a, v): 31 | """Find indices where v should be inserted into a to maintain order. 32 | 33 | This behaves like jnp.searchsorted (its second output is the same as 34 | jnp.searchsorted's output if all elements of v are in [a[0], a[-1]]) but is 35 | faster because it wastes memory to save some compute. 36 | 37 | Args: 38 | a: tensor, the sorted reference points that we are scanning to see where v 39 | should lie. 40 | v: tensor, the query points that we are pretending to insert into a. Does 41 | not need to be sorted. All but the last dimensions should match or expand 42 | to those of a, the last dimension can differ. 43 | 44 | Returns: 45 | (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the 46 | range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or 47 | last index of a.
48 | """ 49 | i = jnp.arange(a.shape[-1]) 50 | v_ge_a = v[..., None, :] >= a[..., :, None] 51 | idx_lo = jnp.max(jnp.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2) 52 | idx_hi = jnp.min(jnp.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2) 53 | return idx_lo, idx_hi 54 | 55 | 56 | def query(tq, t, y, outside_value=0): 57 | """Look up the values of the step function (t, y) at locations tq.""" 58 | idx_lo, idx_hi = searchsorted(t, tq) 59 | yq = jnp.where(idx_lo == idx_hi, outside_value, 60 | jnp.take_along_axis(y, idx_lo, axis=-1)) 61 | return yq 62 | 63 | 64 | def inner_outer(t0, t1, y1): 65 | """Construct inner and outer measures on (t1, y1) for t0.""" 66 | cy1 = jnp.concatenate([jnp.zeros_like(y1[..., :1]), 67 | jnp.cumsum(y1, axis=-1)], 68 | axis=-1) 69 | idx_lo, idx_hi = searchsorted(t1, t0) 70 | 71 | cy1_lo = jnp.take_along_axis(cy1, idx_lo, axis=-1) 72 | cy1_hi = jnp.take_along_axis(cy1, idx_hi, axis=-1) 73 | 74 | y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1] 75 | y0_inner = jnp.where(idx_hi[..., :-1] <= idx_lo[..., 1:], 76 | cy1_lo[..., 1:] - cy1_hi[..., :-1], 0) 77 | return y0_inner, y0_outer 78 | 79 | 80 | def lossfun_outer(t, w, t_env, w_env, eps=jnp.finfo(jnp.float32).eps): 81 | """The proposal weight should be an upper envelope on the nerf weight.""" 82 | _, w_outer = inner_outer(t, t_env, w_env) 83 | # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's 84 | # more effective to pull w_outer up than it is to push w_inner down. 85 | # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0. 86 | return jnp.maximum(0, w - w_outer)**2 / (w + eps) 87 | 88 | 89 | def weight_to_pdf(t, w, eps=jnp.finfo(jnp.float32).eps**2): 90 | """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.""" 91 | return w / jnp.maximum(eps, (t[..., 1:] - t[..., :-1])) 92 | 93 | 94 | def pdf_to_weight(t, p): 95 | """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.""" 96 | return p * (t[..., 1:] - t[..., :-1]) 97 | 98 | 99 | def max_dilate(t, w, dilation, domain=(-jnp.inf, jnp.inf)): 100 | """Dilate (via max-pooling) a non-negative step function.""" 101 | t0 = t[..., :-1] - dilation 102 | t1 = t[..., 1:] + dilation 103 | t_dilate = jnp.sort(jnp.concatenate([t, t0, t1], axis=-1), axis=-1) 104 | t_dilate = jnp.clip(t_dilate, *domain) 105 | w_dilate = jnp.max( 106 | jnp.where( 107 | (t0[..., None, :] <= t_dilate[..., None]) 108 | & (t1[..., None, :] > t_dilate[..., None]), 109 | w[..., None, :], 110 | 0, 111 | ), 112 | axis=-1)[..., :-1] 113 | return t_dilate, w_dilate 114 | 115 | 116 | def max_dilate_weights(t, 117 | w, 118 | dilation, 119 | domain=(-jnp.inf, jnp.inf), 120 | renormalize=False, 121 | eps=jnp.finfo(jnp.float32).eps**2): 122 | """Dilate (via max-pooling) a set of weights.""" 123 | p = weight_to_pdf(t, w) 124 | t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain) 125 | w_dilate = pdf_to_weight(t_dilate, p_dilate) 126 | if renormalize: 127 | w_dilate /= jnp.maximum(eps, jnp.sum(w_dilate, axis=-1, keepdims=True)) 128 | return t_dilate, w_dilate 129 | 130 | 131 | def integrate_weights(w): 132 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1. 133 | 134 | The output's size on the last dimension is one greater than that of the input, 135 | because we're computing the integral corresponding to the endpoints of a step 136 | function, not the integral of the interior/bin values. 137 | 138 | Args: 139 | w: Tensor, which will be integrated along the last axis. 
This is assumed to 140 | sum to 1 along the last axis, and this function will (silently) break if 141 | that is not the case. 142 | 143 | Returns: 144 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 145 | """ 146 | cw = jnp.minimum(1, jnp.cumsum(w[..., :-1], axis=-1)) 147 | shape = cw.shape[:-1] + (1,) 148 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1. 149 | cw0 = jnp.concatenate([jnp.zeros(shape), cw, jnp.ones(shape)], axis=-1) 150 | return cw0 151 | 152 | 153 | def invert_cdf(u, t, w_logits, use_gpu_resampling=False): 154 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" 155 | # Compute the PDF and CDF for each weight vector. 156 | w = jax.nn.softmax(w_logits, axis=-1) 157 | cw = integrate_weights(w) 158 | # Interpolate into the inverse CDF. 159 | interp_fn = math.interp if use_gpu_resampling else math.sorted_interp 160 | t_new = interp_fn(u, cw, t) 161 | return t_new 162 | 163 | 164 | def sample(rng, 165 | t, 166 | w_logits, 167 | num_samples, 168 | single_jitter=False, 169 | deterministic_center=False, 170 | use_gpu_resampling=False): 171 | """Piecewise-Constant PDF sampling from a step function. 172 | 173 | Args: 174 | rng: random number generator (or None for `linspace` sampling). 175 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 176 | w_logits: [..., num_bins], logits corresponding to bin weights 177 | num_samples: int, the number of samples. 178 | single_jitter: bool, if True, jitter every sample along each ray by the same 179 | amount in the inverse CDF. Otherwise, jitter each sample independently. 180 | deterministic_center: bool, if False, when `rng` is None return samples that 181 | linspace the entire PDF. If True, skip the front and back of the linspace 182 | so that the centers of each PDF interval are returned. 183 | use_gpu_resampling: bool, If True this resamples the rays based on a 184 | "gather" instruction, which is fast on GPUs but slow on TPUs. If False, 185 | this resamples the rays based on brute-force searches, which is fast on 186 | TPUs, but slow on GPUs. 187 | 188 | Returns: 189 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 190 | """ 191 | eps = jnp.finfo(jnp.float32).eps 192 | 193 | # Draw uniform samples. 194 | if rng is None: 195 | # Match the behavior of jax.random.uniform() by spanning [0, 1-eps]. 196 | if deterministic_center: 197 | pad = 1 / (2 * num_samples) 198 | u = jnp.linspace(pad, 1. - pad - eps, num_samples) 199 | else: 200 | u = jnp.linspace(0, 1. - eps, num_samples) 201 | u = jnp.broadcast_to(u, t.shape[:-1] + (num_samples,)) 202 | else: 203 | # `u` is in [0, 1) --- it can be zero, but it can never be 1. 204 | u_max = eps + (1 - eps) / num_samples 205 | max_jitter = (1 - u_max) / (num_samples - 1) - eps 206 | d = 1 if single_jitter else num_samples 207 | u = ( 208 | jnp.linspace(0, 1 - u_max, num_samples) + 209 | jax.random.uniform(rng, t.shape[:-1] + (d,), maxval=max_jitter)) 210 | 211 | return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling) 212 | 213 | 214 | def sample_intervals(rng, 215 | t, 216 | w_logits, 217 | num_samples, 218 | single_jitter=False, 219 | domain=(-jnp.inf, jnp.inf), 220 | use_gpu_resampling=False): 221 | """Sample *intervals* (rather than points) from a step function. 222 | 223 | Args: 224 | rng: random number generator (or None for `linspace` sampling). 
225 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 226 | w_logits: [..., num_bins], logits corresponding to bin weights 227 | num_samples: int, the number of intervals to sample. 228 | single_jitter: bool, if True, jitter every sample along each ray by the same 229 | amount in the inverse CDF. Otherwise, jitter each sample independently. 230 | domain: (minval, maxval), the range of valid values for `t`. 231 | use_gpu_resampling: bool, If True this resamples the rays based on a 232 | "gather" instruction, which is fast on GPUs but slow on TPUs. If False, 233 | this resamples the rays based on brute-force searches, which is fast on 234 | TPUs, but slow on GPUs. 235 | 236 | Returns: 237 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 238 | """ 239 | if num_samples <= 1: 240 | raise ValueError(f'num_samples must be > 1, is {num_samples}.') 241 | 242 | # Sample a set of points from the step function. 243 | centers = sample( 244 | rng, 245 | t, 246 | w_logits, 247 | num_samples, 248 | single_jitter, 249 | deterministic_center=True, 250 | use_gpu_resampling=use_gpu_resampling) 251 | 252 | # The intervals we return will span the midpoints of each adjacent sample. 253 | mid = (centers[..., 1:] + centers[..., :-1]) / 2 254 | 255 | # Each first/last fencepost is the reflection of the first/last midpoint 256 | # around the first/last sampled center. We clamp to the limits of the input 257 | # domain, provided by the caller. 258 | minval, maxval = domain 259 | first = jnp.maximum(minval, 2 * centers[..., :1] - mid[..., :1]) 260 | last = jnp.minimum(maxval, 2 * centers[..., -1:] - mid[..., -1:]) 261 | 262 | t_samples = jnp.concatenate([first, mid, last], axis=-1) 263 | return t_samples 264 | 265 | 266 | def lossfun_distortion(t, w): 267 | """Compute iint w[i] w[j] |t[i] - t[j]| di dj.""" 268 | # The loss incurred between all pairs of intervals. 269 | ut = (t[..., 1:] + t[..., :-1]) / 2 270 | dut = jnp.abs(ut[..., :, None] - ut[..., None, :]) 271 | loss_inter = jnp.sum(w * jnp.sum(w[..., None, :] * dut, axis=-1), axis=-1) 272 | 273 | # The loss incurred within each individual interval with itself. 274 | loss_intra = jnp.sum(w**2 * (t[..., 1:] - t[..., :-1]), axis=-1) / 3 275 | 276 | return loss_inter + loss_intra 277 | 278 | 279 | def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi): 280 | """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).""" 281 | # Distortion when the intervals do not overlap. 282 | d_disjoint = jnp.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2) 283 | 284 | # Distortion when the intervals overlap. 285 | d_overlap = (2 * 286 | (jnp.minimum(t0_hi, t1_hi)**3 - jnp.maximum(t0_lo, t1_lo)**3) + 287 | 3 * (t1_hi * t0_hi * jnp.abs(t1_hi - t0_hi) + 288 | t1_lo * t0_lo * jnp.abs(t1_lo - t0_lo) + t1_hi * t0_lo * 289 | (t0_lo - t1_hi) + t1_lo * t0_hi * 290 | (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)) 291 | 292 | # Are the two intervals not overlapping? 293 | are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi) 294 | 295 | return jnp.where(are_disjoint, d_disjoint, d_overlap) 296 | 297 | 298 | def weighted_percentile(t, w, ps): 299 | """Compute the weighted percentiles of a step function. w's must sum to 1.""" 300 | cw = integrate_weights(w) 301 | # We want to interpolate into the integrated weights according to `ps`. 302 | fn = lambda cw_i, t_i: jnp.interp(jnp.array(ps) / 100, cw_i, t_i) 303 | # Vmap fn to an arbitrary number of leading dimensions. 
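# (jnp.interp only supports 1-D arrays, so the code below flattens all
# leading batch dimensions into one axis, vmaps fn over that axis, and then
# restores the batch shape on the result.)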
304 | cw_mat = cw.reshape([-1, cw.shape[-1]]) 305 | t_mat = t.reshape([-1, t.shape[-1]]) 306 | wprctile_mat = (jax.vmap(fn, 0)(cw_mat, t_mat)) 307 | wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),)) 308 | return wprctile 309 | 310 | 311 | def resample(t, tp, vp, use_avg=False, eps=jnp.finfo(jnp.float32).eps): 312 | """Resample a step function defined by (tp, vp) into intervals t. 313 | 314 | Notation roughly matches jnp.interp. Resamples by summation by default. 315 | 316 | Args: 317 | t: tensor with shape (..., n+1), the endpoints to resample into. 318 | tp: tensor with shape (..., m+1), the endpoints of the step function being 319 | resampled. 320 | vp: tensor with shape (..., m), the values of the step function being 321 | resampled. 322 | use_avg: bool, if False, return the sum of the step function for each 323 | interval in `t`. If True, return the average, weighted by the width of 324 | each interval in `t`. 325 | eps: float, a small value to prevent division by zero when use_avg=True. 326 | 327 | Returns: 328 | v: tensor with shape (..., n), the values of the resampled step function. 329 | """ 330 | if use_avg: 331 | wp = jnp.diff(tp, axis=-1) 332 | v_numer = resample(t, tp, vp * wp, use_avg=False) 333 | v_denom = resample(t, tp, wp, use_avg=False) 334 | v = v_numer / jnp.maximum(eps, v_denom) 335 | return v 336 | 337 | acc = jnp.cumsum(vp, axis=-1) 338 | acc0 = jnp.concatenate([jnp.zeros(acc.shape[:-1] + (1,)), acc], axis=-1) 339 | acc0_resampled = jnp.vectorize( 340 | jnp.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0) 341 | v = jnp.diff(acc0_resampled, axis=-1) 342 | return v 343 | --------------------------------------------------------------------------------
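A minimal usage sketch for stepfun.sample, appended for illustration (an assumption-laden example, not part of the repository: it presumes the repository root is on PYTHONPATH so that `internal` is importable; shapes and values are arbitrary):

import jax
import jax.numpy as jnp
from internal import stepfun

key_u, key_t, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
t = jnp.sort(jax.random.uniform(key_t, (4, 9)), axis=-1)  # 8 sorted bin endpoints per ray
w_logits = jax.random.normal(key_w, (4, 8))               # one logit per bin
# Draw 16 stratified samples per ray from the step-function PDF on (t, w_logits).
t_new = stepfun.sample(key_u, t, w_logits, num_samples=16, single_jitter=True)
print(t_new.shape)  # (4, 16)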