├── configs
│   ├── render_config.gin
│   ├── blender_refnerf.gin
│   ├── blender_mipnerf.gin
│   └── llff_refnerf.gin
├── requirements.txt
├── .gitignore
├── condor
│   ├── interactive.job
│   ├── eval.job
│   ├── train.job
│   ├── render.job
│   ├── eval.sh
│   ├── batch_jobs.sh
│   ├── render.sh
│   ├── train_debug.sh
│   └── train.sh
├── tests
│   ├── utils_test.py
│   ├── camera_utils_test.py
│   ├── ref_utils_test.py
│   ├── datasets_test.py
│   ├── image_test.py
│   ├── math_test.py
│   ├── geopoly_test.py
│   └── coord_test.py
├── scripts
│   ├── run_all_unit_tests.sh
│   └── local_colmap_and_resize.sh
├── CONTRIBUTING.md
├── internal
│   ├── geopoly.py
│   ├── math.py
│   ├── coord.py
│   ├── ref_utils.py
│   ├── utils.py
│   ├── configs.py
│   ├── image.py
│   ├── render.py
│   ├── vis.py
│   ├── train_utils.py
│   └── stepfun.py
├── render.py
├── LICENSE
├── eval.py
├── train.py
└── README.md

/configs/render_config.gin:
--------------------------------------------------------------------------------
Config.render_path = True
Config.render_path_frames = 480
Config.render_video_fps = 60
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
opencv-python
Pillow
tensorboard
gin-config
#dm_pix
rawpy
mediapy
flatdict
scipy
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
internal/pycolmap
__pycache__/
internal/__pycache__/
tests/__pycache__/
.DS_Store
.vscode/
.idea/
__MACOSX/
logs/
*.code-workspace
data
*.pkl
notebooks/
--------------------------------------------------------------------------------
/condor/interactive.job:
--------------------------------------------------------------------------------
Universe = vanilla

RequestCpus = 4
Request_GPUs = 1
RequestMemory = 8Gb
+RequestWalltime = 14000

Requirements = (CUDACapability >= 5.0)&&(CUDAGlobalMemoryMb >= 8000.0)&&(machine != "andromeda.esat.kuleuven.be")

Initialdir = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor
Executable = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor/train.sh
Arguments = ref_shiny/car $(Cluster) configs/blender_refnerf.gin

Log = ../logs/ref_shiny/car/inter.$(Cluster).log
Output = ../logs/ref_shiny/car/inter.$(Cluster).out
Error = ../logs/ref_shiny/car/inter.$(Cluster).err

Notification = Complete
Queue 1
--------------------------------------------------------------------------------
/condor/eval.job:
--------------------------------------------------------------------------------
Universe = vanilla

RequestCpus = 4
Request_GPUs = 1
RequestMemory = 10Gb
+RequestWalltime = 259000

Requirements = (CUDACapability >= 7.0)&&(CUDAGlobalMemoryMb >= 15000.0)&&(machine != "andromeda.esat.kuleuven.be")

Initialdir = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor
Executable = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor/eval.sh
Arguments = ref_shiny/car 13555 configs/blender_refnerf.gin

NiceUser = true

Log = ../logs/ref_shiny/car/eval.$(Cluster).log
Output = ../logs/ref_shiny/car/eval.$(Cluster).out
Error = ../logs/ref_shiny/car/eval.$(Cluster).err

Notification = Complete
Queue 1
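
# Usage sketch (assumes a standard HTCondor setup; paths above are
# site-specific): submit files like this one are queued from the condor/
# directory with
#   condor_submit eval.job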
--------------------------------------------------------------------------------
/condor/train.job:
--------------------------------------------------------------------------------
Universe = vanilla

RequestCpus = 4
Request_GPUs = 1
RequestMemory = 10Gb
+RequestWalltime = 259000

Requirements = (CUDACapability >= 7.0)&&(CUDAGlobalMemoryMb >= 15000.0)&&(machine != "andromeda.esat.kuleuven.be")

Initialdir = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor
Executable = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor/train.sh
Arguments = ref_shiny/car $(Cluster) configs/blender_refnerf.gin

NiceUser = true

Log = ../logs/ref_shiny/car/exp.$(Cluster).log
Output = ../logs/ref_shiny/car/exp.$(Cluster).out
Error = ../logs/ref_shiny/car/exp.$(Cluster).err

Notification = Complete
Queue 1
--------------------------------------------------------------------------------
/condor/render.job:
--------------------------------------------------------------------------------
Universe = vanilla

RequestCpus = 4
Request_GPUs = 1
RequestMemory = 10Gb
+RequestWalltime = 86400

Requirements = (CUDACapability >= 7.0)&&(CUDAGlobalMemoryMb >= 11000.0)&&(machine != "andromeda.esat.kuleuven.be")

Initialdir = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor
Executable = /users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/condor/render.sh
Arguments = ref_shiny/car 13790 configs/blender_refnerf.gin

NiceUser = true

Log = ../logs/ref_shiny/car/render.$(Cluster).log
Output = ../logs/ref_shiny/car/render.$(Cluster).out
Error = ../logs/ref_shiny/car/render.$(Cluster).err

Notification = Complete
Queue 1
--------------------------------------------------------------------------------
/condor/eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

NAME=$1
EXP=$2
CONFIG=$3
DATA_DIR=/esat/topaz/gkouros/datasets/nerf/$1

source ~/miniconda3/etc/profile.d/conda.sh
conda activate refnerf

export PATH="/usr/local/cuda-12/bin:/usr/local/cuda/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-12/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES/CUDA/}


DIR=/users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/
cd ${DIR}

DEG_VIEW=5
RENDER_CHUNK_SIZE=8192

python3 eval.py \
    --gin_configs="${DIR}/logs/$NAME/$EXP/config.gin" \
    --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
    --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
    --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
    --gin_bindings="NerfMLP.deg_view = $DEG_VIEW" \
    --logtostderr

conda deactivate
--------------------------------------------------------------------------------
/condor/batch_jobs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# cd to the dir of the script so that you can execute it from anywhere
DIR=$( realpath -e -- $( dirname -- ${BASH_SOURCE[0]}))
cd $DIR
echo $DIR

# ARGS="Model.ray_shape=\'line\'"
# condor_send --jobname "train_ref_pytorch_shiny_car_$ARGS" --conda-env 'refnerf' --gpumem 65 --mem 32 --gpus 1 --cpus 2 --timeout 2.99 --nice 1 \
#   -c "bash train.sh ref_shiny/car EXP configs/blender_refnerf.gin $ARGS"

# EXP="14256"
# condor_send --jobname "train_ref_pytorch_shiny_car_$ARGS" --conda-env 'refnerf' --gpumem 12 --mem 32 --gpus 1 --cpus 2 --timeout 1.00 --nice 1 \
#   -c "bash eval.sh ref_shiny/car $EXP configs/blender_refnerf.gin $ARGS"

# EXP="14260"
# condor_send --jobname "train_ref_pytorch_shiny_car_$ARGS" --conda-env 'refnerf' --gpumem 16 --mem 32 --gpus 1 --cpus 2 --timeout 1.00 --nice 1 \
#   -c "bash eval.sh ref_shiny/car $EXP configs/blender_refnerf.gin $ARGS"
--------------------------------------------------------------------------------
/tests/utils_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for utils."""

from absl.testing import absltest

from internal import utils


class UtilsTest(absltest.TestCase):

  def test_dummy_rays(self):
    """Ensures that the dummy Rays object is correctly initialized."""
    rays = utils.dummy_rays()
    self.assertEqual(rays.origins.shape[-1], 3)


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------
/scripts/run_all_unit_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


python -m unittest tests.camera_utils_test
python -m unittest tests.geopoly_test
python -m unittest tests.stepfun_test
python -m unittest tests.coord_test
python -m unittest tests.image_test
python -m unittest tests.ref_utils_test
python -m unittest tests.utils_test
python -m unittest tests.datasets_test
python -m unittest tests.math_test
python -m unittest tests.render_test
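
# A single test module can also be run on its own, e.g.:
#   python -m unittest tests.math_test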
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code Reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/condor/render.sh:
--------------------------------------------------------------------------------
#!/bin/bash

NAME=$1
EXP=$2
CONFIG=$3
DATA_DIR=/esat/topaz/gkouros/datasets/nerf/$NAME

export PATH="/usr/local/cuda-12/bin:/usr/local/cuda/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-12/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES/CUDA/}

source ~/miniconda3/etc/profile.d/conda.sh
conda activate refnerf

DIR=/users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/
cd ${DIR}

DEG_VIEW=5
RENDER_CHUNK_SIZE=8192

if [[ "$CONFIG" == *"llff"* ]]; then
    RENDER_PATH=True
else
    RENDER_PATH=False
fi

python3 render.py \
    --gin_configs="${DIR}/logs/$NAME/$EXP/config.gin" \
    --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
    --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
    --gin_bindings="Config.render_dir = '${DIR}/logs/$NAME/$EXP/render/'" \
    --gin_bindings="Config.render_path = $RENDER_PATH" \
    --gin_bindings="Config.render_path_frames = 480" \
    --gin_bindings="Config.render_video_fps = 60" \
    --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
    --gin_bindings="NerfMLP.deg_view = $DEG_VIEW" \
    --logtostderr

conda deactivate
--------------------------------------------------------------------------------
/configs/blender_refnerf.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'blender'
Config.batching = 'single_image'
Config.near = 2
Config.far = 6
Config.eval_render_interval = 5
Config.compute_normal_metrics = False
Config.data_loss_type = 'mse'
Config.distortion_loss_mult = 0.0
Config.orientation_loss_mult = 0.1
Config.orientation_loss_target = 'normals_pred'
Config.predicted_normal_loss_mult = 3e-4
Config.orientation_coarse_loss_mult = 0.01
Config.predicted_normal_coarse_loss_mult = 3e-5
Config.interlevel_loss_mult = 0.0
Config.data_coarse_loss_mult = 0.1
Config.data_loss_mult = 1.0
Config.adam_beta1 = 0.9
Config.adam_beta2 = 0.999
Config.adam_eps = 1e-6
Config.batch_size = 16384
Config.render_chunk_size = 16384
Config.max_steps = 250000

Model.num_levels = 2
Model.single_mlp = True
Model.num_prop_samples = 128 # This needs to be set despite single_mlp = True.
Model.num_nerf_samples = 128
Model.anneal_slope = 0.
Model.dilation_multiplier = 0.
Model.dilation_bias = 0.
Model.single_jitter = False
Model.resample_padding = 0.01

NerfMLP.net_depth = 8
NerfMLP.net_width = 256
NerfMLP.net_depth_viewdirs = 8
NerfMLP.net_width_viewdirs = 256
NerfMLP.basis_shape = 'octahedron'
NerfMLP.basis_subdivisions = 1
NerfMLP.disable_density_normals = False
NerfMLP.enable_pred_normals = True
NerfMLP.use_directional_enc = True
NerfMLP.use_reflections = True
NerfMLP.deg_view = 5
NerfMLP.enable_pred_roughness = True
NerfMLP.use_diffuse_color = True
NerfMLP.use_specular_tint = True
NerfMLP.use_n_dot_v = True
NerfMLP.bottleneck_width = 128
NerfMLP.bottleneck_noise = 0.0
NerfMLP.density_bias = 0.5
NerfMLP.max_deg_point = 16
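
# Usage sketch: any binding above can be overridden at launch time without
# editing this file (the binding shown is illustrative), e.g.:
#   python3 train.py --gin_configs=configs/blender_refnerf.gin \
#     --gin_bindings="Config.batch_size = 4096"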
--------------------------------------------------------------------------------
/configs/blender_mipnerf.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'blender'
Config.batching = 'single_image'
Config.near = 2
Config.far = 6
Config.eval_render_interval = 5
Config.compute_normal_metrics = False
Config.data_loss_type = 'mse'
Config.distortion_loss_mult = 0.0
Config.orientation_loss_mult = 0.0
Config.orientation_loss_target = 'normals_pred'
Config.predicted_normal_loss_mult = 0
Config.orientation_coarse_loss_mult = 0.0
Config.predicted_normal_coarse_loss_mult = 0
Config.interlevel_loss_mult = 0
Config.data_coarse_loss_mult = 0.1
Config.data_loss_mult = 1.0
Config.adam_beta1 = 0.9
Config.adam_beta2 = 0.999
Config.adam_eps = 1e-6
Config.batch_size = 16384
Config.render_chunk_size = 16384
Config.max_steps = 250000
Config.dataset_debug_mode = False

Model.num_levels = 1
Model.single_mlp = True
Model.num_prop_samples = 128 # This needs to be set despite single_mlp = True.
Model.num_nerf_samples = 128
Model.dilation_multiplier = 0.
Model.dilation_bias = 0.
Model.single_jitter = False
Model.resample_padding = 0.01

NerfMLP.net_depth = 8
NerfMLP.net_width = 256
NerfMLP.net_depth_viewdirs = 8
NerfMLP.net_width_viewdirs = 256
NerfMLP.basis_shape = 'octahedron'
NerfMLP.basis_subdivisions = 1
NerfMLP.disable_density_normals = True
NerfMLP.enable_pred_normals = False
NerfMLP.use_directional_enc = False
NerfMLP.use_reflections = False
NerfMLP.deg_view = 5
NerfMLP.enable_pred_roughness = False
NerfMLP.use_diffuse_color = False
NerfMLP.use_specular_tint = False
NerfMLP.use_n_dot_v = False
NerfMLP.bottleneck_width = 128
NerfMLP.bottleneck_noise = 0.0
NerfMLP.density_bias = 0.5
NerfMLP.max_deg_point = 16
--------------------------------------------------------------------------------
/configs/llff_refnerf.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'llff'
Config.batching = 'all_images'
Config.llff_white_background = True
Config.near = 0.2
Config.far = 1e6
Config.factor = 0
Config.forward_facing = False
Config.eval_render_interval = 5
Config.compute_normal_metrics = False
Config.data_loss_type = 'mse'
Config.distortion_loss_mult = 0.0
Config.orientation_loss_mult = 0.1
Config.orientation_loss_target = 'normals_pred'
Config.predicted_normal_loss_mult = 1e-3
Config.orientation_coarse_loss_mult = 0.01
Config.predicted_normal_coarse_loss_mult = 1e-4
Config.interlevel_loss_mult = 0.0
Config.data_coarse_loss_mult = 0.1
Config.data_loss_mult = 1.0
Config.adam_eps = 1e-6
Config.batch_size = 16384
Config.render_chunk_size = 16384
Config.max_steps = 250000


Model.num_levels = 2
Model.single_mlp = True
Model.num_prop_samples = 128 # This needs to be set despite single_mlp = True.
Model.num_nerf_samples = 128
Model.anneal_slope = 0.
Model.dilation_multiplier = 0.
Model.dilation_bias = 0.
Model.single_jitter = False
Model.resample_padding = 0.01
Model.opaque_background = False

NerfMLP.net_depth = 8
NerfMLP.net_width = 256
NerfMLP.net_depth_viewdirs = 8
NerfMLP.net_width_viewdirs = 256
NerfMLP.basis_shape = 'octahedron'
NerfMLP.basis_subdivisions = 1
NerfMLP.disable_density_normals = False
NerfMLP.enable_pred_normals = True
NerfMLP.use_directional_enc = True
NerfMLP.use_reflections = True
NerfMLP.deg_view = 5
NerfMLP.enable_pred_roughness = True
NerfMLP.use_diffuse_color = True
NerfMLP.use_specular_tint = True
NerfMLP.use_n_dot_v = True
NerfMLP.bottleneck_width = 128
NerfMLP.bottleneck_noise = 0.
NerfMLP.density_bias = 0.5
NerfMLP.max_deg_point = 16
--------------------------------------------------------------------------------
/condor/train_debug.sh:
--------------------------------------------------------------------------------
#!/bin/bash

NAME=$1
EXP=$2
CONFIG=$3
DATA_DIR=/esat/topaz/gkouros/datasets/nerf/$1

source ~/miniconda3/etc/profile.d/conda.sh
conda activate refnerf

export PATH="/usr/local/cuda-11/bin:/usr/local/cuda/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-11/lib64:/usr/local/cuda/lib64:$CONDA_PREFIX/lib/:$LD_LIBRARY_PATH"

DIR=/users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/
cd ${DIR}

DEG_VIEW=5
BATCH_SIZE=1024
RENDER_CHUNK_SIZE=1024
MAX_STEPS=1000

if [[ "$CONFIG" == *"llff"* ]]; then
    RENDER_PATH=True
else
    RENDER_PATH=False
fi

CONFIG_PATH="$CONFIG"

#TODO: DELETEME
rm -rf ${DIR}/logs/$NAME/$EXP

python3 train.py \
    --gin_configs="$CONFIG_PATH" \
    --gin_bindings="Config.max_steps = $MAX_STEPS" \
    --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
    --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
    --gin_bindings="Config.batch_size = $BATCH_SIZE" \
    --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
    --gin_bindings="NerfMLP.deg_view = $DEG_VIEW"
# && \
# python3 render.py \
#     --gin_configs="${DIR}/logs/$NAME/$EXP/config.gin" \
#     --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
#     --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
#     --gin_bindings="Config.render_dir = '${DIR}/logs/$NAME/$EXP/render/'" \
#     --gin_bindings="Config.render_path = $RENDER_PATH" \
#     --gin_bindings="Config.render_path_frames = 480" \
#     --gin_bindings="Config.render_video_fps = 60" \
#     --gin_bindings="Config.batch_size = $BATCH_SIZE" \
#     --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
#     --gin_bindings="NerfMLP.deg_view = $DEG_VIEW" \
# && \
# python3 eval.py \
#     --gin_configs="${DIR}/logs/$NAME/$EXP/config.gin" \
#     --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
#     --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
#     --gin_bindings="Config.batch_size = $BATCH_SIZE" \
#     --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
#     --gin_bindings="NerfMLP.deg_view = $DEG_VIEW"

conda deactivate
--------------------------------------------------------------------------------
/condor/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash

NAME=$1
EXP=$2
CONFIG=$3
# format the arguments to gin bindings
ARGS=${@:4} # all subsequent args are assumed args for the python script
ARGS=($ARGS) # split the space-separated string into an array
ARGS_STR=''

for (( i=0; i<${#ARGS[@]}; ++i ));
do
    ARGS_STR="$ARGS_STR --gin_bindings=${ARGS[$i]}"
done
echo ARGS="$ARGS_STR"

source ~/miniconda3/etc/profile.d/conda.sh
conda activate refnerf

export PATH="/usr/local/cuda-12/bin:/usr/local/cuda/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-12/lib64:/usr/local/cuda/lib64:$CONDA_PREFIX/lib/:$LD_LIBRARY_PATH"
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES/CUDA/}

DATA_DIR=/esat/topaz/gkouros/datasets/nerf/$NAME

DIR=/users/visics/gkouros/projects/nerf-repos/refnerf-pytorch/
cd ${DIR}

DEG_VIEW=5
BATCH_SIZE=1024
RENDER_CHUNK_SIZE=8192
MAX_STEPS=250000

if [[ "$CONFIG" == *"llff"* ]]; then
    RENDER_PATH=True
else
    RENDER_PATH=False
fi

# If the job gets evicted, reload the generated config file rather than the
# original, which might have been modified in the meantime
if [ -f "${DIR}/logs/$NAME/$EXP/config.gin" ]; then
    CONFIG_PATH="${DIR}/logs/$NAME/$EXP/config.gin"
else
    CONFIG_PATH="$CONFIG"
fi

python3 train.py \
    --gin_configs="$CONFIG_PATH" \
    --gin_bindings="Config.max_steps = $MAX_STEPS" \
    --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
    --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
    --gin_bindings="Config.batch_size = $BATCH_SIZE" \
    --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
    --gin_bindings="NerfMLP.deg_view = $DEG_VIEW" \
    $ARGS_STR \
&& \
python3 render.py \
    --gin_configs="${DIR}/logs/$NAME/$EXP/config.gin" \
    --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
    --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
    --gin_bindings="Config.render_dir = '${DIR}/logs/$NAME/$EXP/render/'" \
    --gin_bindings="Config.render_path = $RENDER_PATH" \
    --gin_bindings="Config.render_path_frames = 480" \
    --gin_bindings="Config.render_video_fps = 60" \
    --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
    --gin_bindings="NerfMLP.deg_view = $DEG_VIEW" \
&& \
python3 eval.py \
    --gin_configs="${DIR}/logs/$NAME/$EXP/config.gin" \
    --gin_bindings="Config.data_dir = '${DIR}/data/$NAME'" \
    --gin_bindings="Config.checkpoint_dir = '${DIR}/logs/$NAME/$EXP'" \
    --gin_bindings="Config.render_chunk_size = $RENDER_CHUNK_SIZE" \
    --gin_bindings="NerfMLP.deg_view = $DEG_VIEW"

conda deactivate
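
# Usage sketch (the experiment id 00001 is illustrative; since $ARGS is
# word-split above, trailing gin bindings must not contain spaces):
#   bash train.sh ref_shiny/car 00001 configs/blender_refnerf.gin "Model.ray_shape='line'"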
--------------------------------------------------------------------------------
/tests/camera_utils_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for camera_utils."""

from absl.testing import absltest
from absl.testing import parameterized
from internal import camera_utils
from jax import random
import jax.numpy as jnp
import numpy as np


class CameraUtilsTest(parameterized.TestCase):

  def test_convert_to_ndc(self):
    rng = random.PRNGKey(0)
    for _ in range(10):
      # Random pinhole camera intrinsics.
      key, rng = random.split(rng)
      focal, width, height = random.uniform(key, (3,), minval=100., maxval=200.)
      camtopix = camera_utils.intrinsic_matrix(focal, focal, width / 2.,
                                               height / 2.)
      pixtocam = np.linalg.inv(camtopix)
      near = 1.

      # Random rays, pointing forward (negative z direction).
      num_rays = 1000
      key, rng = random.split(rng)
      origins = jnp.array([0., 0., 1.])
      origins += random.uniform(key, (num_rays, 3), minval=-1., maxval=1.)
      directions = jnp.array([0., 0., -1.])
      directions += random.uniform(key, (num_rays, 3), minval=-.5, maxval=.5)

      # Project world-space points along each ray into NDC space.
      t = jnp.linspace(0., 1., 10)
      pts_world = origins + t[:, None, None] * directions
      pts_ndc = jnp.stack([
          -focal / (.5 * width) * pts_world[..., 0] / pts_world[..., 2],
          -focal / (.5 * height) * pts_world[..., 1] / pts_world[..., 2],
          1. + 2. * near / pts_world[..., 2],
      ], axis=-1)

      # Get NDC space rays.
      origins_ndc, directions_ndc = camera_utils.convert_to_ndc(
          origins, directions, pixtocam, near)

      # Ensure that the NDC space points lie on the calculated rays.
      directions_ndc_norm = jnp.linalg.norm(
          directions_ndc, axis=-1, keepdims=True)
      directions_ndc_unit = directions_ndc / directions_ndc_norm
      projection = ((pts_ndc - origins_ndc) * directions_ndc_unit).sum(axis=-1)
      pts_ndc_proj = origins_ndc + directions_ndc_unit * projection[..., None]

      # pts_ndc should be close to their projections pts_ndc_proj onto the rays.
      np.testing.assert_allclose(pts_ndc, pts_ndc_proj, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------
/scripts/local_colmap_and_resize.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Set to 0 if you do not have a GPU.
USE_GPU=1
# Path to a directory `base/` with images in `base/images/`.
DATASET_PATH=$1
# Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye.
CAMERA=${2:-OPENCV}


# Run COLMAP.

### Feature extraction

colmap feature_extractor \
    --database_path "$DATASET_PATH"/database.db \
    --image_path "$DATASET_PATH"/images \
    --ImageReader.single_camera 1 \
    --ImageReader.camera_model "$CAMERA" \
    --SiftExtraction.use_gpu "$USE_GPU"


### Feature matching

colmap exhaustive_matcher \
    --database_path "$DATASET_PATH"/database.db \
    --SiftMatching.use_gpu "$USE_GPU"

## Use if your scene has > 500 images
## Replace this path with your own local copy of the file.
## Download from: https://demuc.de/colmap/#download
# VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin
# colmap vocab_tree_matcher \
#     --database_path "$DATASET_PATH"/database.db \
#     --VocabTreeMatching.vocab_tree_path $VOCABTREE_PATH \
#     --SiftMatching.use_gpu "$USE_GPU"


### Bundle adjustment

# The default Mapper tolerance is unnecessarily large,
# decreasing it speeds up bundle adjustment steps.
mkdir -p "$DATASET_PATH"/sparse
colmap mapper \
    --database_path "$DATASET_PATH"/database.db \
    --image_path "$DATASET_PATH"/images \
    --output_path "$DATASET_PATH"/sparse \
    --Mapper.ba_global_function_tolerance=0.000001


### Image undistortion

## Use this if you want to undistort your images into ideal pinhole intrinsics.
# mkdir -p "$DATASET_PATH"/dense
# colmap image_undistorter \
#     --image_path "$DATASET_PATH"/images \
#     --input_path "$DATASET_PATH"/sparse/0 \
#     --output_path "$DATASET_PATH"/dense \
#     --output_type COLMAP

# Resize images.

cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_2

pushd "$DATASET_PATH"/images_2
ls | xargs -P 8 -I {} mogrify -resize 50% {}
popd

cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_4

pushd "$DATASET_PATH"/images_4
ls | xargs -P 8 -I {} mogrify -resize 25% {}
popd

cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_8

pushd "$DATASET_PATH"/images_8
ls | xargs -P 8 -I {} mogrify -resize 12.5% {}
popd
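
# Usage sketch (the dataset path is illustrative; requires COLMAP and
# ImageMagick's mogrify on the PATH):
#   bash scripts/local_colmap_and_resize.sh /path/to/base OPENCV_FISHEYE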
--------------------------------------------------------------------------------
/tests/ref_utils_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ref_utils."""

from absl.testing import absltest
from internal import ref_utils
from jax import random
import jax.numpy as jnp
import numpy as np
import scipy


def generate_dir_enc_fn_scipy(deg_view):
  """Return spherical harmonics using scipy.special.sph_harm."""
  ml_array = ref_utils.get_ml_array(deg_view)

  def dir_enc_fn(theta, phi):
    de = [scipy.special.sph_harm(m, l, phi, theta) for m, l in ml_array.T]
    de = np.stack(de, axis=-1)
    # Split into real and imaginary parts.
    return np.concatenate([np.real(de), np.imag(de)], axis=-1)

  return dir_enc_fn


class RefUtilsTest(absltest.TestCase):

  def test_reflection(self):
    """Make sure reflected vectors have the same angle from normals as input."""
    rng = random.PRNGKey(0)
    for shape in [(45, 3), (4, 7, 3)]:
      key, rng = random.split(rng)
      normals = random.normal(key, shape)
      key, rng = random.split(rng)
      directions = random.normal(key, shape)

      # Normalize normal vectors.
      normals = normals / (
          jnp.linalg.norm(normals, axis=-1, keepdims=True) + 1e-10)

      reflected_directions = ref_utils.reflect(directions, normals)

      cos_angle_original = jnp.sum(directions * normals, axis=-1)
      cos_angle_reflected = jnp.sum(reflected_directions * normals, axis=-1)

      np.testing.assert_allclose(
          cos_angle_original, cos_angle_reflected, atol=1E-5, rtol=1E-5)

  def test_spherical_harmonics(self):
    """Make sure the fast spherical harmonics are accurate."""
    shape = (12, 11, 13)

    # Generate random points on sphere.
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    theta = random.uniform(key1, shape, minval=0.0, maxval=jnp.pi)
    phi = random.uniform(key2, shape, minval=0.0, maxval=2.0*jnp.pi)

    # Convert to Cartesian coordinates.
    x = jnp.sin(theta) * jnp.cos(phi)
    y = jnp.sin(theta) * jnp.sin(phi)
    z = jnp.cos(theta)
    xyz = jnp.stack([x, y, z], axis=-1)

    deg_view = 5
    de = ref_utils.generate_dir_enc_fn(deg_view)(xyz)
    de_scipy = generate_dir_enc_fn_scipy(deg_view)(theta, phi)

    np.testing.assert_allclose(
        de, de_scipy, atol=0.02, rtol=1e6)  # Only use atol.
    self.assertFalse(jnp.any(jnp.isnan(de)))


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------
/tests/datasets_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for datasets."""

from absl.testing import absltest
from internal import camera_utils
from internal import configs
from internal import datasets
from jax import random
import numpy as np


class DummyDataset(datasets.Dataset):

  def _load_renderings(self, config):
    """Generates dummy image and pose data."""
    self._n_examples = 2
    self.height = 3
    self.width = 4
    self._resolution = self.height * self.width
    self.focal = 5.
    self.pixtocams = np.linalg.inv(
        camera_utils.intrinsic_matrix(self.focal, self.focal, self.width * 0.5,
                                      self.height * 0.5))

    rng = random.PRNGKey(0)

    key, rng = random.split(rng)
    images_shape = (self._n_examples, self.height, self.width, 3)
    self.images = random.uniform(key, images_shape)

    key, rng = random.split(rng)
    self.camtoworlds = np.stack([
        camera_utils.viewmatrix(*random.normal(k, (3, 3)))
        for k in random.split(key, self._n_examples)
    ], axis=0)


class DatasetsTest(absltest.TestCase):

  def test_dataset_batch_creation(self):
    np.random.seed(0)
    config = configs.Config(batch_size=8)

    # Check shapes are consistent across all ray attributes.
    for split in ['train', 'test']:
      dummy_dataset = DummyDataset(split, '', config)
      rays = dummy_dataset.peek().rays
      sh_gt = rays.origins.shape[:-1]
      for z in rays.__dict__.values():
        if z is not None:
          self.assertEqual(z.shape[:-1], sh_gt)

    # Check test batch generation matches golden data.
    dummy_dataset = DummyDataset('test', '', config)
    batch = dummy_dataset.peek()

    rgb = batch.rgb.ravel()
    rgb_gt = np.array([
        0.5289556, 0.28869557, 0.24527192, 0.12083626, 0.8904066, 0.6259936,
        0.57573485, 0.09355974, 0.8017353, 0.538651, 0.4998169, 0.42061496,
        0.5591258, 0.00577283, 0.6804651, 0.9139203, 0.00444758, 0.96962905,
        0.52956843, 0.38282406, 0.28777933, 0.6640035, 0.39736128, 0.99495006,
        0.13100398, 0.7597165, 0.8532667, 0.67468107, 0.6804743, 0.26873016,
        0.60699487, 0.5722265, 0.44482303, 0.6511061, 0.54807067, 0.09894073
    ])
    np.testing.assert_allclose(rgb, rgb_gt, atol=1e-4, rtol=1e-4)

    ray_origins = batch.rays.origins.ravel()
    ray_origins_gt = np.array([
        -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472,
        -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469,
        -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224,
        -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472,
        -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469,
        -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224,
        -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224
    ])
    np.testing.assert_allclose(
        ray_origins, ray_origins_gt, atol=1e-4, rtol=1e-4)

    ray_dirs = batch.rays.directions.ravel()
    ray_dirs_gt = np.array([
        0.24370372, 0.89296186, -0.5227117, 0.05601424, 0.8468699, -0.57417226,
        -0.13167524, 0.8007779, -0.62563276, -0.31936473, 0.75468594,
        -0.67709327, 0.17780769, 0.96766925, -0.34928587, -0.0098818, 0.9215773,
        -0.4007464, -0.19757128, 0.87548524, -0.4522069, -0.38526076,
        0.82939327, -0.5036674, 0.11191163, 1.0423766, -0.17586003, -0.07577785,
        0.9962846, -0.22732055, -0.26346734, 0.95019263, -0.2787811,
        -0.45115682, 0.90410066, -0.3302416
    ])
    np.testing.assert_allclose(ray_dirs, ray_dirs_gt, atol=1e-4, rtol=1e-4)


if __name__ == '__main__':
  absltest.main()
--------------------------------------------------------------------------------
/internal/geopoly.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for constructing geodesic polyhedra, which are used as a basis."""

import itertools
import numpy as np


def compute_sq_dist(mat0, mat1=None):
  """Compute the squared Euclidean distance between all pairs of columns."""
  if mat1 is None:
    mat1 = mat0
  # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.
  sq_norm0 = np.sum(mat0**2, 0)
  sq_norm1 = np.sum(mat1**2, 0)
  sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1
  sq_dist = np.maximum(0, sq_dist)  # Negative values must be numerical errors.
  return sq_dist


def compute_tesselation_weights(v):
  """Tesselate the vertices of a triangle by a factor of `v`."""
  if v < 1:
    raise ValueError(f'v {v} must be >= 1')
  int_weights = []
  for i in range(v + 1):
    for j in range(v + 1 - i):
      int_weights.append((i, j, v - (i + j)))
  int_weights = np.array(int_weights)
  weights = int_weights / v  # Barycentric weights.
  return weights


def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):
  """Tesselate the vertices of a geodesic polyhedron.

  Args:
    base_verts: tensor of floats, the vertex coordinates of the geodesic.
    base_faces: tensor of ints, the indices of the vertices of base_verts that
      constitute each face of the polyhedron.
    v: int, the factor of the tesselation (v==1 is a no-op).
    eps: float, a small value used to determine if two vertices are the same.

  Returns:
    verts: a tensor of floats, the coordinates of the tesselated vertices.
  """
  if not isinstance(v, int):
    raise ValueError(f'v {v} must be an integer')
  tri_weights = compute_tesselation_weights(v)

  verts = []
  for base_face in base_faces:
    new_verts = np.matmul(tri_weights, base_verts[base_face, :])
    new_verts /= np.sqrt(np.sum(new_verts**2, 1, keepdims=True))
    verts.append(new_verts)
  verts = np.concatenate(verts, 0)

  sq_dist = compute_sq_dist(verts.T)
  assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])
  unique = np.unique(assignment)
  verts = verts[unique, :]

  return verts


def generate_basis(base_shape,
                   angular_tesselation,
                   remove_symmetries=True,
                   eps=1e-4):
  """Generates a 3D basis by tesselating a geometric polyhedron.

  Args:
    base_shape: string, the name of the starting polyhedron, must be either
      'icosahedron' or 'octahedron'.
    angular_tesselation: int, the number of times to tesselate the polyhedron,
      must be >= 1 (a value of 1 is a no-op to the polyhedron).
    remove_symmetries: bool, if True then remove the symmetric basis columns,
      which is usually a good idea because otherwise projections onto the basis
      will have redundant negative copies of each other.
    eps: float, a small number used to determine symmetries.

  Returns:
    basis: a matrix with shape [n, 3].
  """
  if base_shape == 'icosahedron':
    a = (np.sqrt(5) + 1) / 2
    verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1),
                      (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0),
                      (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2)
    faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1),
                      (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3),
                      (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6),
                      (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5),
                      (7, 2, 11)])
    verts = tesselate_geodesic(verts, faces, angular_tesselation)
  elif base_shape == 'octahedron':
    verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0),
                      (1, 0, 0)])
    corners = np.array(list(itertools.product([-1, 1], repeat=3)))
    pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)
    faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)
    verts = tesselate_geodesic(verts, faces, angular_tesselation)
  else:
    raise ValueError(f'base_shape {base_shape} not supported')

  if remove_symmetries:
    # Remove elements of `verts` that are reflections of each other.
    match = compute_sq_dist(verts.T, -verts.T) < eps
    verts = verts[np.any(np.triu(match), 1), :]
  basis = verts[:, ::-1].astype(np.float32)
  return basis
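
# Minimal usage sketch (numpy only; the shapes shown are assumptions based on
# the code above): each row of the returned basis is a unit vector, so
# projecting 3D points onto the basis is a plain matrix product.
#   basis = generate_basis('octahedron', 2)  # float32 array of shape [n, 3]
#   projections = points @ basis.T           # for points of shape [m, 3]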
--------------------------------------------------------------------------------
/internal/math.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mathy utility functions."""

import torch
import functorch
import numpy as np


def safe_trig_helper(x, fn, t=100 * torch.pi):
  """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
  return fn(torch.where(torch.abs(x) < t, x, x % t))


def safe_cos(x):
  """torch.cos() on a TPU may NaN out for large values."""
  return safe_trig_helper(x, torch.cos)


def safe_sin(x):
  """torch.sin() on a TPU may NaN out for large values."""
  return safe_trig_helper(x, torch.sin)


def log_lerp(t, v0, v1):
  """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
  if v0 <= 0 or v1 <= 0:
    raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
  lv0 = np.log(v0)
  lv1 = np.log(v1)
  return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)


def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  The returned rate is lr_init when step=0 and lr_final when step=max_steps,
  and is log-linearly interpolated elsewhere (equivalent to exponential decay).
  If lr_delay_steps > 0 then the learning rate will be scaled by some smooth
  function of lr_delay_mult, such that the initial learning rate is
  lr_init*lr_delay_mult at the beginning of optimization but will be eased back
  to the normal learning rate when steps > lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning rate for the current step 'step'.
  """
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
  return delay_rate * log_lerp(step / max_steps, lr_init, lr_final) / lr_init


# def interp(*args):
#   """A gather-based (GPU-friendly) vectorized replacement for torch.interp()."""
#   args_flat = [x.reshape([-1, x.shape[-1]]) for x in args]
#   ret = functorch.vmap(torch.interp)(*args_flat).reshape(args[0].shape)
#   return ret


def sorted_interp(x, xp, fp):
  """A TPU-friendly version of interp(), where xp and fp must be sorted."""

  # Identify the location in `xp` that corresponds to each `x`.
  # The final `True` index in `mask` is the start of the matching interval.
  mask = x[..., None, :] >= xp[..., :, None]

  def find_interval(x):
    # Grab the value where `mask` switches from True to False, and vice versa.
    # This approach takes advantage of the fact that `x` is sorted.
    x0 = torch.max(torch.where(
        mask, x[..., None], x[..., :1, None]), dim=-2).values
    x1 = torch.min(torch.where(
        ~mask, x[..., None], x[..., -1:, None]), dim=-2).values
    return x0, x1

  fp0, fp1 = find_interval(fp)
  xp0, xp1 = find_interval(xp)

  offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
  ret = fp0 + offset * (fp1 - fp0)
  return ret


def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
  """One-dimensional linear interpolation for monotonically increasing sample
  points. Similar to np.interp.

  Returns the one-dimensional piecewise linear interpolant to a function with
  given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

  Args:
    x: the :math:`x`-coordinates at which to evaluate the interpolated values.
    xp: the :math:`x`-coordinates of the data points, must be increasing.
    fp: the :math:`y`-coordinates of the data points, same length as `xp`.

  Returns:
    the interpolated values, same size as `x`.

  Details:
    Taken from the issue at https://github.com/pytorch/pytorch/issues/50334
  """
  x = x.double()
  xp = xp.double()
  fp = fp.double()
  m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
  b = fp[:-1] - (m * xp[:-1])

  indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
  indices = torch.clamp(indices, 0, len(m) - 1)

  return m[indices] * x + b[indices]
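
# Minimal sketch of the decay schedule (illustrative values, not taken from
# any config in this repo); note the result is normalized by lr_init, which
# suits torch.optim.lr_scheduler.LambdaLR-style multiplicative schedulers:
#   lr_mult = learning_rate_decay(step=5000, lr_init=2e-3, lr_final=2e-5,
#                                 max_steps=250000, lr_delay_steps=512,
#                                 lr_delay_mult=0.01)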
--------------------------------------------------------------------------------
/internal/coord.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for manipulating coordinate spaces and distances along rays."""

import torch
from internal import math


def contract(x):
  """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077)."""
  eps = torch.finfo(torch.float32).eps
  # Clamping to eps prevents non-finite gradients when x == 0.
  x_mag_sq = torch.clamp(torch.sum(x**2, dim=-1, keepdims=True), min=eps)
  z = torch.where(x_mag_sq <= 1, x,
                  ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
  return z


def inv_contract(z):
  """The inverse of contract()."""
  eps = torch.finfo(torch.float32).eps
  # Clamping to eps prevents non-finite gradients when z == 0.
  z_mag_sq = torch.clamp(torch.sum(z**2, dim=-1, keepdims=True), min=eps)
  x = torch.where(z_mag_sq <= 1, z, z / (2 * torch.sqrt(z_mag_sq) - z_mag_sq))
  return x


# def track_linearize(fn, mean, cov):
#   """Apply function `fn` to a set of means and covariances, ala a Kalman filter.
#
#   We can analytically transform a Gaussian parameterized by `mean` and `cov`
#   with a function `fn` by linearizing `fn` around `mean`, and taking advantage
#   of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
#   https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).
#
#   Args:
#     fn: the function applied to the Gaussians parameterized by (mean, cov).
#     mean: a tensor of means, where the last axis is the dimension.
#     cov: a tensor of covariances, where the last two axes are the dimensions.
#
#   Returns:
#     fn_mean: the transformed means.
#     fn_cov: the transformed covariances.
#   """
#   if (len(mean.shape) + 1) != len(cov.shape):
#     raise ValueError('cov must be non-diagonal')
#   fn_mean, lin_fn = jax.linearize(fn, mean)  # TODO: port to torch, e.g. via functorch.jvp.
#   fn_cov = torch.vmap(lin_fn, -1, -2)(torch.vmap(lin_fn, -1, -2)(cov))
#   return fn_mean, fn_cov


def construct_ray_warps(fn, t_near, t_far):
  """Construct a bijection between metric distances and normalized distances.

  See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
  detailed explanation.

  Args:
    fn: the function applied to ray distances (None for the identity).
    t_near: a tensor of near-plane distances.
    t_far: a tensor of far-plane distances.

  Returns:
    t_to_s: a function that maps distances to normalized distances in [0, 1].
    s_to_t: the inverse of t_to_s.
  """
  if fn is None:
    fn_fwd = lambda x: x
    fn_inv = lambda x: x
  elif fn == 'piecewise':
    # Piecewise spacing combining identity and 1/x functions to allow t_near=0.
    fn_fwd = lambda x: torch.where(x < 1, .5 * x, 1 - .5 / x)
    fn_inv = lambda x: torch.where(x < .5, 2 * x, .5 / (1 - x))
  else:
    inv_mapping = {
        'reciprocal': torch.reciprocal,
        'log': torch.exp,
        'exp': torch.log,
        'sqrt': torch.square,
        'square': torch.sqrt
    }
    fn_fwd = fn
    fn_inv = inv_mapping[fn.__name__]

  s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)]
  t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near)
  s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near)
  return t_to_s, s_to_t


def expected_sin(mean, var):
  """Compute the mean of sin(x), x ~ N(mean, var)."""
  return torch.exp(-0.5 * var) * math.safe_sin(mean)  # large var -> small value.


def integrated_pos_enc(mean, var, min_deg, max_deg):
  """Encode `x` with sinusoids scaled by 2^[min_deg, max_deg).

  Args:
    mean: tensor, the mean coordinates to be encoded.
    var: tensor, the variance of the coordinates to be encoded.
    min_deg: int, the min degree of the encoding.
    max_deg: int, the max degree of the encoding.

  Returns:
    encoded: torch.Tensor, encoded variables.
  """
  scales = 2**torch.arange(min_deg, max_deg)
  shape = mean.shape[:-1] + (-1,)
  scaled_mean = torch.reshape(mean[..., None, :] * scales[:, None], shape)
  scaled_var = torch.reshape(var[..., None, :] * scales[:, None]**2, shape)

  return expected_sin(
      torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1),
      torch.cat([scaled_var] * 2, dim=-1))


def lift_and_diagonalize(mean, cov, basis):
  """Project `mean` and `cov` onto basis and diagonalize the projected cov."""
  fn_mean = torch.matmul(mean, basis)
  fn_cov_diag = torch.sum(basis * torch.matmul(cov, basis), dim=-2)
  return fn_mean, fn_cov_diag


def pos_enc(x, min_deg, max_deg, append_identity=True):
  """The positional encoding used by the original NeRF paper."""
  scales = 2**torch.arange(min_deg, max_deg)
  shape = x.shape[:-1] + (-1,)
  scaled_x = torch.reshape((x[..., None, :] * scales[:, None]), shape)
  # Note that we're not using safe_sin, unlike IPE.
  four_feat = torch.sin(
      torch.cat([scaled_x, scaled_x + 0.5 * torch.pi], dim=-1))
  if append_identity:
    return torch.cat([x] + [four_feat], dim=-1)
  else:
    return four_feat
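
# Quick round-trip sketch (torch only): inv_contract should invert contract
# for points both inside and outside the unit ball, e.g.
#   x = torch.tensor([[0.3, 0.0, 0.0], [4.0, 0.0, 0.0]])
#   torch.allclose(inv_contract(contract(x)), x)  # -> True
# For the second point: ||x|| = 4, so contract maps it to (2 - 1/4) = 1.75
# along the same direction, keeping all contracted points within radius 2.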
--------------------------------------------------------------------------------
/tests/image_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for image."""

from absl.testing import absltest
from internal import image
import jax
from jax import random
import jax.numpy as jnp
import numpy as np


def matmul(a, b):
  """jnp.matmul defaults to bfloat16, but this helper function doesn't."""
  return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)


class ImageTest(absltest.TestCase):

  def test_color_correction(self):
    """Test that color correction can undo a CCM + quadratic warp + shift."""
    im_shape = (128, 128, 3)
    rng = random.PRNGKey(0)
    for _ in range(10):
      # Construct a random image.
      key, rng = random.split(rng)
      im0 = random.uniform(key, shape=im_shape, minval=0.1, maxval=0.9)

      # Construct a random linear + quadratic color transformation.
      key, rng = random.split(rng)
      ccm_scale = random.normal(key) / 10
      key, rng = random.split(rng)
      shift = random.normal(key) / 10
      key, rng = random.split(rng)
      sq_mult = random.normal(key) / 10
      key, rng = random.split(rng)
      ccm = jnp.eye(3) + random.normal(key, shape=(3, 3)) * ccm_scale

      # Apply that random transformation to the image.
      im1 = jnp.clip(
          (matmul(jnp.reshape(im0, [-1, 3]), ccm)).reshape(im0.shape) +
          sq_mult * im0**2 + shift, 0, 1)

      # Check that color correction recovers the randomly transformed image.
      im0_cc = image.color_correct(im0, im1)
      np.testing.assert_allclose(im0_cc, im1, atol=1E-5, rtol=1E-5)

  def test_psnr_mse_round_trip(self):
    """PSNR -> MSE -> PSNR is a no-op."""
    for psnr in [10., 20., 30.]:
      np.testing.assert_allclose(
          image.mse_to_psnr(image.psnr_to_mse(psnr)),
          psnr,
          atol=1E-5,
          rtol=1E-5)

  def test_ssim_dssim_round_trip(self):
    """SSIM -> DSSIM -> SSIM is a no-op."""
    for ssim in [-0.9, 0, 0.9]:
      np.testing.assert_allclose(
          image.dssim_to_ssim(image.ssim_to_dssim(ssim)),
          ssim,
          atol=1E-5,
          rtol=1E-5)

  def test_srgb_linearize(self):
    x = jnp.linspace(-1, 3, 10000)  # Nobody should call this <0 but it works.
    # Check that the round-trip transformation is a no-op.
    np.testing.assert_allclose(
        image.linear_to_srgb(image.srgb_to_linear(x)), x, atol=1E-5, rtol=1E-5)
    np.testing.assert_allclose(
        image.srgb_to_linear(image.linear_to_srgb(x)), x, atol=1E-5, rtol=1E-5)
    # Check that gradients are finite.
    self.assertTrue(
        jnp.all(jnp.isfinite(jax.vmap(jax.grad(image.linear_to_srgb))(x))))
    self.assertTrue(
        jnp.all(jnp.isfinite(jax.vmap(jax.grad(image.srgb_to_linear))(x))))

  def test_srgb_to_linear_golden(self):
    """A lazy golden test for srgb_to_linear."""
    srgb = jnp.linspace(0, 1, 64)
    linear = image.srgb_to_linear(srgb)
    linear_gt = jnp.array([
        0.00000000, 0.00122856, 0.00245712, 0.00372513, 0.00526076, 0.00711347,
        0.00929964, 0.01183453, 0.01473243, 0.01800687, 0.02167065, 0.02573599,
        0.03021459, 0.03511761, 0.04045585, 0.04623971, 0.05247922, 0.05918410,
        0.06636375, 0.07402734, 0.08218378, 0.09084171, 0.10000957, 0.10969563,
        0.11990791, 0.13065430, 0.14194246, 0.15377994, 0.16617411, 0.17913227,
        0.19266140, 0.20676863, 0.22146071, 0.23674440, 0.25262633, 0.26911288,
        0.28621066, 0.30392596, 0.32226467, 0.34123330, 0.36083785, 0.38108405,
        0.40197787, 0.42352500, 0.44573134, 0.46860245, 0.49214387, 0.51636110,
        0.54125960, 0.56684470, 0.59312177, 0.62009590, 0.64777250, 0.67615650,
        0.70525320, 0.73506740, 0.76560410, 0.79686830, 0.82886493, 0.86159873,
        0.89507430, 0.92929670, 0.96427040, 1.00000000
    ])
    np.testing.assert_allclose(linear, linear_gt, atol=1E-5, rtol=1E-5)

  def test_mse_to_psnr_golden(self):
    """A lazy golden test for mse_to_psnr."""
    mse = jnp.exp(jnp.linspace(-10, 0, 64))
    psnr = image.mse_to_psnr(mse)
    psnr_gt = jnp.array([
        43.429447, 42.740090, 42.050735, 41.361378, 40.6720240, 39.982666,
        39.293310, 38.603954, 37.914597, 37.225240, 36.5358850, 35.846527,
        35.157170, 34.467810, 33.778458, 33.089100, 32.3997460, 31.710388,
        31.021034, 30.331675, 29.642320, 28.952961, 28.2636070, 27.574250,
        26.884893, 26.195538, 25.506180, 24.816826, 24.1274700, 23.438112,
        22.748756, 22.059400, 21.370045, 20.680689, 19.9913310, 19.301975,
        18.612620, 17.923262, 17.233906, 16.544550, 15.8551940, 15.165837,
        14.4764805, 13.787125, 13.097769, 12.408413, 11.719056, 11.029700,
        10.3403420, 9.6509850, 8.9616290, 8.2722720, 7.5829163, 6.8935600,
        6.2042036, 5.5148473, 4.825491, 4.136135, 3.4467785, 2.7574227,
        2.0680661, 1.37871, 0.68935364, 0.
    ])
    np.testing.assert_allclose(psnr, psnr_gt, atol=1E-5, rtol=1E-5)


if __name__ == '__main__':
  absltest.main()
24 | 25 | The reflection of a vector v about a unit vector n is a vector u such that 26 | dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two 27 | equations is u = 2 dot(n, v) n - v. 28 | 29 | Args: 30 | viewdirs: [..., 3] array of view directions. 31 | normals: [..., 3] array of normal directions (assumed to be unit vectors). 32 | 33 | Returns: 34 | [..., 3] array of reflection directions. 35 | """ 36 | return 2.0 * torch.sum( 37 | normals * viewdirs, dim=-1, keepdims=True) * normals - viewdirs 38 | 39 | 40 | def l2_normalize(x, eps=torch.tensor(torch.finfo(torch.float32).eps)): 41 | """Normalize x to unit length along last axis.""" 42 | return x / torch.sqrt(torch.maximum(torch.sum(x**2, dim=-1, keepdims=True), eps)) 43 | 44 | 45 | def compute_weighted_mae(weights, normals, normals_gt): 46 | """Compute weighted mean angular error, assuming normals are unit length.""" 47 | one_eps = torch.tensor(1 - torch.finfo(torch.float32).eps) 48 | return (weights * torch.arccos( 49 | torch.clip((normals * normals_gt).sum(-1), -one_eps, 50 | one_eps))).sum() / weights.sum() * 180.0 / torch.pi 51 | 52 | 53 | def generalized_binomial_coeff(a, k): 54 | """Compute generalized binomial coefficients.""" 55 | return np.prod(a - np.arange(k)) / np.math.factorial(k) 56 | 57 | 58 | def assoc_legendre_coeff(l, m, k): 59 | """Compute associated Legendre polynomial coefficients. 60 | 61 | Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the 62 | (l, m)th associated Legendre polynomial, P_l^m(cos(theta)). 63 | 64 | Args: 65 | l: associated Legendre polynomial degree. 66 | m: associated Legendre polynomial order. 67 | k: power of cos(theta). 68 | 69 | Returns: 70 | A float, the coefficient of the term corresponding to the inputs. 71 | """ 72 | return ((-1)**m * 2**l * np.math.factorial(l) / np.math.factorial(k) / 73 | np.math.factorial(l - k - m) * 74 | generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l)) 75 | 76 | 77 | def sph_harm_coeff(l, m, k): 78 | """Compute spherical harmonic coefficients.""" 79 | return (np.sqrt( 80 | (2.0 * l + 1.0) * np.math.factorial(l - m) / 81 | (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k)) 82 | 83 | 84 | def get_ml_array(deg_view): 85 | """Create a list with all pairs of (l, m) values to use in the encoding.""" 86 | ml_list = [] 87 | for i in range(deg_view): 88 | l = 2**i 89 | # Only use nonnegative m values, later splitting real and imaginary parts. 90 | for m in range(l + 1): 91 | ml_list.append((m, l)) 92 | 93 | # Convert list into a numpy array. 94 | ml_array = np.array(ml_list).T 95 | return ml_array 96 | 97 | 98 | def generate_ide_fn(deg_view): 99 | """Generate integrated directional encoding (IDE) function. 100 | 101 | This function returns a function that computes the integrated directional 102 | encoding from Equations 6-8 of arxiv.org/abs/2112.03907. 103 | 104 | Args: 105 | deg_view: number of spherical harmonics degrees to use. 106 | 107 | Returns: 108 | A function for evaluating integrated directional encoding. 109 | 110 | Raises: 111 | ValueError: if deg_view is larger than 5. 
112 | """ 113 | if deg_view > 5: 114 | print('WARNING: Only deg_view of at most 5 is numerically stable.') 115 | # raise ValueError('Only deg_view of at most 5 is numerically stable.') 116 | 117 | ml_array = get_ml_array(deg_view) 118 | l_max = 2**(deg_view - 1) 119 | 120 | # Create a matrix corresponding to ml_array holding all coefficients, which, 121 | # when multiplied (from the right) by the z coordinate Vandermonde matrix, 122 | # results in the z component of the encoding. 123 | mat = torch.zeros((l_max + 1, ml_array.shape[1])) 124 | for i, (m, l) in enumerate(ml_array.T): 125 | for k in range(l - m + 1): 126 | mat[k, i] = sph_harm_coeff(l, m, k) 127 | 128 | def integrated_dir_enc_fn(xyz, kappa_inv): 129 | """Function returning integrated directional encoding (IDE). 130 | 131 | Args: 132 | xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at. 133 | kappa_inv: [..., 1] reciprocal of the concentration parameter of the von 134 | Mises-Fisher distribution. 135 | 136 | Returns: 137 | An array with the resulting IDE. 138 | """ 139 | x = xyz[..., 0:1] 140 | y = xyz[..., 1:2] 141 | z = xyz[..., 2:3] 142 | 143 | # Compute z Vandermonde matrix. 144 | vmz = torch.cat([z**i for i in range(mat.shape[0])], dim=-1) 145 | 146 | # Compute x+iy Vandermonde matrix. 147 | vmxy = torch.cat( 148 | [(x + 1j * y)**m for m in ml_array[0, :]], dim=-1) 149 | 150 | # Get spherical harmonics. 151 | sph_harms = vmxy * torch.matmul(vmz, mat) 152 | 153 | # Apply attenuation function using the von Mises-Fisher distribution 154 | # concentration parameter, kappa. 155 | sigma = torch.tensor(0.5 * ml_array[1, :] * (ml_array[1, :] + 1), dtype=torch.float32) 156 | ide = sph_harms * torch.exp(-sigma * kappa_inv) 157 | 158 | # Split into real and imaginary parts and return 159 | return torch.cat([torch.real(ide), torch.imag(ide)], dim=-1) 160 | 161 | return integrated_dir_enc_fn 162 | 163 | 164 | def generate_dir_enc_fn(deg_view): 165 | """Generate directional encoding (DE) function. 166 | 167 | Args: 168 | deg_view: number of spherical harmonics degrees to use. 169 | 170 | Returns: 171 | A function for evaluating directional encoding. 172 | """ 173 | integrated_dir_enc_fn = generate_ide_fn(deg_view) 174 | 175 | def dir_enc_fn(xyz): 176 | """Function returning directional encoding (DE).""" 177 | return integrated_dir_enc_fn(xyz, torch.zeros_like(xyz[..., :1])) 178 | 179 | return dir_enc_fn 180 | -------------------------------------------------------------------------------- /tests/math_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
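Before the math tests, one property of ref_utils.reflect above is worth demonstrating: reflection about a unit normal preserves both the projection onto the normal and the vector norm, exactly as its docstring states. A short check:

import torch
from internal import ref_utils

v = ref_utils.l2_normalize(torch.randn(16, 3))
n = ref_utils.l2_normalize(torch.randn(16, 3))  # unit normals, as assumed by reflect
u = ref_utils.reflect(v, n)
assert torch.allclose((u * n).sum(-1), (v * n).sum(-1), atol=1e-5)  # dot(u, n) == dot(v, n)
assert torch.allclose(u.norm(dim=-1), v.norm(dim=-1), atol=1e-5)    # norms preserved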
14 | 15 | """Unit tests for math.""" 16 | 17 | import functools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from internal import math 22 | import jax 23 | from jax import random 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | def safe_trig_harness(fn, max_exp): 29 | x = 10**np.linspace(-30, max_exp, 10000) 30 | x = np.concatenate([-x[::-1], np.array([0]), x]) 31 | y_true = getattr(np, fn)(x) 32 | y = getattr(math, 'safe_' + fn)(x) 33 | return y_true, y 34 | 35 | 36 | class MathTest(parameterized.TestCase): 37 | 38 | def test_sin(self): 39 | """In [-1e10, 1e10] safe_sin and safe_cos are accurate.""" 40 | for fn in ['sin', 'cos']: 41 | y_true, y = safe_trig_harness(fn, 10) 42 | self.assertLess(jnp.max(jnp.abs(y - y_true)), 1e-4) 43 | self.assertFalse(jnp.any(jnp.isnan(y))) 44 | # Beyond that range it's less accurate but we just don't want it to be NaN. 45 | for fn in ['sin', 'cos']: 46 | y_true, y = safe_trig_harness(fn, 60) 47 | self.assertFalse(jnp.any(jnp.isnan(y))) 48 | 49 | def test_safe_exp_correct(self): 50 | """math.safe_exp() should match np.exp() for not-huge values.""" 51 | x = jnp.linspace(-80, 80, 10001) 52 | y = math.safe_exp(x) 53 | g = jax.vmap(jax.grad(math.safe_exp))(x) 54 | yg_true = jnp.exp(x) 55 | np.testing.assert_allclose(y, yg_true) 56 | np.testing.assert_allclose(g, yg_true) 57 | 58 | def test_safe_exp_finite(self): 59 | """math.safe_exp() behaves reasonably for huge values.""" 60 | x = jnp.linspace(-100000, 100000, 10001) 61 | y = math.safe_exp(x) 62 | g = jax.vmap(jax.grad(math.safe_exp))(x) 63 | # `y` and `g` should both always be finite. 64 | self.assertTrue(jnp.all(jnp.isfinite(y))) 65 | self.assertTrue(jnp.all(jnp.isfinite(g))) 66 | # The derivative of exp() should be exp(). 67 | np.testing.assert_allclose(y, g) 68 | # safe_exp()'s output and gradient should be monotonic. 69 | self.assertTrue(jnp.all(y[1:] >= y[:-1])) 70 | self.assertTrue(jnp.all(g[1:] >= g[:-1])) 71 | 72 | def test_learning_rate_decay(self): 73 | rng = random.PRNGKey(0) 74 | for _ in range(10): 75 | key, rng = random.split(rng) 76 | lr_init = jnp.exp(random.normal(key) - 3) 77 | key, rng = random.split(rng) 78 | lr_final = lr_init * jnp.exp(random.normal(key) - 5) 79 | key, rng = random.split(rng) 80 | max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key)))) 81 | 82 | lr_fn = functools.partial( 83 | math.learning_rate_decay, 84 | lr_init=lr_init, 85 | lr_final=lr_final, 86 | max_steps=max_steps) 87 | 88 | # Test that the rate at the beginning is the initial rate. 89 | np.testing.assert_allclose(lr_fn(0), lr_init, atol=1E-5, rtol=1E-5) 90 | 91 | # Test that the rate at the end is the final rate. 92 | np.testing.assert_allclose( 93 | lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5) 94 | 95 | # Test that the rate at the middle is the geometric mean of the two rates. 
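The geometric-mean property asserted next pins the schedule down: it is log-linear in the step count. A hypothetical reimplementation of that core behavior (illustrative only; the repo's internal/math.py additionally supports the delayed warmup tested below):

import numpy as np

def lr_log_linear(step, lr_init, lr_final, max_steps):
    # Interpolate log-linearly (i.e. geometrically) from lr_init to lr_final.
    t = np.clip(step / max_steps, 0., 1.)
    return float(np.exp((1 - t) * np.log(lr_init) + t * np.log(lr_final)))

assert np.isclose(lr_log_linear(0, 1e-2, 1e-4, 100), 1e-2)
assert np.isclose(lr_log_linear(100, 1e-2, 1e-4, 100), 1e-4)
assert np.isclose(lr_log_linear(50, 1e-2, 1e-4, 100), 1e-3)  # geometric mean at the midpoint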
96 | np.testing.assert_allclose( 97 | lr_fn(max_steps / 2), 98 | jnp.sqrt(lr_init * lr_final), 99 | atol=1E-5, 100 | rtol=1E-5) 101 | 102 | # Test that the rate past the end is the final rate 103 | np.testing.assert_allclose( 104 | lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5) 105 | 106 | def test_delayed_learning_rate_decay(self): 107 | rng = random.PRNGKey(0) 108 | for _ in range(10): 109 | key, rng = random.split(rng) 110 | lr_init = jnp.exp(random.normal(key) - 3) 111 | key, rng = random.split(rng) 112 | lr_final = lr_init * jnp.exp(random.normal(key) - 5) 113 | key, rng = random.split(rng) 114 | max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key)))) 115 | key, rng = random.split(rng) 116 | lr_delay_steps = int( 117 | random.uniform(key, minval=0.1, maxval=0.4) * max_steps) 118 | key, rng = random.split(rng) 119 | lr_delay_mult = jnp.exp(random.normal(key) - 3) 120 | 121 | lr_fn = functools.partial( 122 | math.learning_rate_decay, 123 | lr_init=lr_init, 124 | lr_final=lr_final, 125 | max_steps=max_steps, 126 | lr_delay_steps=lr_delay_steps, 127 | lr_delay_mult=lr_delay_mult) 128 | 129 | # Test that the rate at the beginning is the delayed initial rate. 130 | np.testing.assert_allclose( 131 | lr_fn(0), lr_delay_mult * lr_init, atol=1E-5, rtol=1E-5) 132 | 133 | # Test that the rate at the end is the final rate. 134 | np.testing.assert_allclose( 135 | lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5) 136 | 137 | # Test that the rate at after the delay is over is the usual rate. 138 | np.testing.assert_allclose( 139 | lr_fn(lr_delay_steps), 140 | math.learning_rate_decay(lr_delay_steps, lr_init, lr_final, 141 | max_steps), 142 | atol=1E-5, 143 | rtol=1E-5) 144 | 145 | # Test that the rate at the middle is the geometric mean of the two rates. 146 | np.testing.assert_allclose( 147 | lr_fn(max_steps / 2), 148 | jnp.sqrt(lr_init * lr_final), 149 | atol=1E-5, 150 | rtol=1E-5) 151 | 152 | # Test that the rate past the end is the final rate 153 | np.testing.assert_allclose( 154 | lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5) 155 | 156 | @parameterized.named_parameters(('', False), ('sort', True)) 157 | def test_interp(self, sort): 158 | n, d0, d1 = 100, 10, 20 159 | rng = random.PRNGKey(0) 160 | 161 | key, rng = random.split(rng) 162 | x = random.normal(key, [n, d0]) 163 | 164 | key, rng = random.split(rng) 165 | xp = random.normal(key, [n, d1]) 166 | 167 | key, rng = random.split(rng) 168 | fp = random.normal(key, [n, d1]) 169 | 170 | if sort: 171 | xp = jnp.sort(xp, axis=-1) 172 | fp = jnp.sort(fp, axis=-1) 173 | z = math.sorted_interp(x, xp, fp) 174 | else: 175 | z = math.interp(x, xp, fp) 176 | 177 | z_true = jnp.stack([jnp.interp(x[i], xp[i], fp[i]) for i in range(n)]) 178 | np.testing.assert_allclose(z, z_true, atol=1e-5, rtol=1e-5) 179 | 180 | 181 | if __name__ == '__main__': 182 | absltest.main() 183 | -------------------------------------------------------------------------------- /tests/geopoly_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
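The interp test above loops np.interp over rows; the same contract can be vectorized with searchsorted. A hypothetical torch sketch (illustrative only, not the repo's math.interp):

import torch

def interp_rows(x, xp, fp):
    # Batched 1D linear interpolation; xp must be sorted along the last axis.
    idx = torch.searchsorted(xp, x).clamp(1, xp.shape[-1] - 1)
    x0, x1 = xp.gather(-1, idx - 1), xp.gather(-1, idx)
    f0, f1 = fp.gather(-1, idx - 1), fp.gather(-1, idx)
    t = ((x - x0) / (x1 - x0)).clamp(0, 1)  # clamping matches np.interp at the ends
    return f0 + t * (f1 - f0)

xp = torch.tensor([[0., 1., 2.]])
fp = torch.tensor([[0., 10., 20.]])
print(interp_rows(torch.tensor([[0.5, 1.5]]), xp, fp))  # tensor([[ 5., 15.]])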
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit tests for geopoly.""" 16 | import itertools 17 | 18 | from absl.testing import absltest 19 | from internal import geopoly 20 | import jax 21 | from jax import random 22 | import numpy as np 23 | 24 | 25 | def is_same_basis(x, y, tol=1e-10): 26 | """Check if `x` and `y` describe the same linear basis.""" 27 | match = np.minimum( 28 | geopoly.compute_sq_dist(x, y), geopoly.compute_sq_dist(x, -y)) <= tol 29 | return (np.all(np.array(x.shape) == np.array(y.shape)) and 30 | np.all(np.sum(match, axis=0) == 1) and 31 | np.all(np.sum(match, axis=1) == 1)) 32 | 33 | 34 | class GeopolyTest(absltest.TestCase): 35 | 36 | def test_compute_sq_dist_reference(self): 37 | """Test against a simple reimplementation of compute_sq_dist.""" 38 | num_points = 100 39 | num_dims = 10 40 | rng = random.PRNGKey(0) 41 | key, rng = random.split(rng) 42 | mat0 = jax.random.normal(key, [num_dims, num_points]) 43 | key, rng = random.split(rng) 44 | mat1 = jax.random.normal(key, [num_dims, num_points]) 45 | 46 | sq_dist = geopoly.compute_sq_dist(mat0, mat1) 47 | 48 | sq_dist_ref = np.zeros([num_points, num_points]) 49 | for i in range(num_points): 50 | for j in range(num_points): 51 | sq_dist_ref[i, j] = np.sum((mat0[:, i] - mat1[:, j])**2) 52 | 53 | np.testing.assert_allclose(sq_dist, sq_dist_ref, atol=1e-4, rtol=1e-4) 54 | 55 | def test_compute_sq_dist_single_input(self): 56 | """Test that compute_sq_dist with a single input works correctly.""" 57 | rng = random.PRNGKey(0) 58 | num_points = 100 59 | num_dims = 10 60 | key, rng = random.split(rng) 61 | mat0 = jax.random.normal(key, [num_dims, num_points]) 62 | 63 | sq_dist = geopoly.compute_sq_dist(mat0) 64 | sq_dist_ref = geopoly.compute_sq_dist(mat0, mat0) 65 | np.testing.assert_allclose(sq_dist, sq_dist_ref) 66 | 67 | def test_compute_tesselation_weights_reference(self): 68 | """A reference implementation for triangle tesselation.""" 69 | for v in range(1, 10): 70 | w = geopoly.compute_tesselation_weights(v) 71 | perm = np.array(list(itertools.product(range(v + 1), repeat=3))) 72 | w_ref = perm[np.sum(perm, axis=-1) == v, :] / v 73 | # Check that all rows of x are close to some row in x_ref. 
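That reference enumerates the barycentric grid directly, so two standalone facts follow and make a cheap self-test: each row of weights sums to one, and the number of grid points is the triangular number (v + 1)(v + 2) / 2.

import itertools
import numpy as np

v = 3
perm = np.array(list(itertools.product(range(v + 1), repeat=3)))
w = perm[np.sum(perm, axis=-1) == v, :] / v  # barycentric weights on a triangle
assert np.allclose(np.sum(w, axis=-1), 1.0)
assert w.shape[0] == (v + 1) * (v + 2) // 2  # 10 grid points for v = 3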
74 | self.assertTrue(is_same_basis(w.T, w_ref.T)) 75 | 76 | def test_generate_basis_golden(self): 77 | """A mediocre golden test against two arbitrary basis choices.""" 78 | basis = geopoly.generate_basis('icosahedron', 2) 79 | basis_golden = np.array([[0.85065081, 0.00000000, 0.52573111], 80 | [0.80901699, 0.50000000, 0.30901699], 81 | [0.52573111, 0.85065081, 0.00000000], 82 | [1.00000000, 0.00000000, 0.00000000], 83 | [0.80901699, 0.50000000, -0.30901699], 84 | [0.85065081, 0.00000000, -0.52573111], 85 | [0.30901699, 0.80901699, -0.50000000], 86 | [0.00000000, 0.52573111, -0.85065081], 87 | [0.50000000, 0.30901699, -0.80901699], 88 | [0.00000000, 1.00000000, 0.00000000], 89 | [-0.52573111, 0.85065081, 0.00000000], 90 | [-0.30901699, 0.80901699, -0.50000000], 91 | [0.00000000, 0.52573111, 0.85065081], 92 | [-0.30901699, 0.80901699, 0.50000000], 93 | [0.30901699, 0.80901699, 0.50000000], 94 | [0.50000000, 0.30901699, 0.80901699], 95 | [0.50000000, -0.30901699, 0.80901699], 96 | [0.00000000, 0.00000000, 1.00000000], 97 | [-0.50000000, 0.30901699, 0.80901699], 98 | [-0.80901699, 0.50000000, 0.30901699], 99 | [-0.80901699, 0.50000000, -0.30901699]]) 100 | self.assertTrue(is_same_basis(basis.T, basis_golden.T)) 101 | 102 | basis = geopoly.generate_basis('octahedron', 4) 103 | basis_golden = np.array([[0.00000000, 0.00000000, -1.00000000], 104 | [0.00000000, -0.31622777, -0.94868330], 105 | [0.00000000, -0.70710678, -0.70710678], 106 | [0.00000000, -0.94868330, -0.31622777], 107 | [0.00000000, -1.00000000, 0.00000000], 108 | [-0.31622777, 0.00000000, -0.94868330], 109 | [-0.40824829, -0.40824829, -0.81649658], 110 | [-0.40824829, -0.81649658, -0.40824829], 111 | [-0.31622777, -0.94868330, 0.00000000], 112 | [-0.70710678, 0.00000000, -0.70710678], 113 | [-0.81649658, -0.40824829, -0.40824829], 114 | [-0.70710678, -0.70710678, 0.00000000], 115 | [-0.94868330, 0.00000000, -0.31622777], 116 | [-0.94868330, -0.31622777, 0.00000000], 117 | [-1.00000000, 0.00000000, 0.00000000], 118 | [0.00000000, -0.31622777, 0.94868330], 119 | [0.00000000, -0.70710678, 0.70710678], 120 | [0.00000000, -0.94868330, 0.31622777], 121 | [0.40824829, -0.40824829, 0.81649658], 122 | [0.40824829, -0.81649658, 0.40824829], 123 | [0.31622777, -0.94868330, 0.00000000], 124 | [0.81649658, -0.40824829, 0.40824829], 125 | [0.70710678, -0.70710678, 0.00000000], 126 | [0.94868330, -0.31622777, 0.00000000], 127 | [0.31622777, 0.00000000, -0.94868330], 128 | [0.40824829, 0.40824829, -0.81649658], 129 | [0.40824829, 0.81649658, -0.40824829], 130 | [0.70710678, 0.00000000, -0.70710678], 131 | [0.81649658, 0.40824829, -0.40824829], 132 | [0.94868330, 0.00000000, -0.31622777], 133 | [0.40824829, -0.40824829, -0.81649658], 134 | [0.40824829, -0.81649658, -0.40824829], 135 | [0.81649658, -0.40824829, -0.40824829]]) 136 | self.assertTrue(is_same_basis(basis.T, basis_golden.T)) 137 | 138 | 139 | if __name__ == '__main__': 140 | absltest.main() 141 | -------------------------------------------------------------------------------- /internal/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
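The compute_sq_dist reference loop earlier in geopoly_test is O(N^2 D) in Python; the vectorized identity it validates is the standard expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. A numpy sketch of that expansion (points stored as columns, matching the tests):

import numpy as np

def sq_dist_sketch(mat0, mat1):
    sq0 = np.sum(mat0**2, axis=0)[:, None]   # ||a||^2 per column of mat0
    sq1 = np.sum(mat1**2, axis=0)[None, :]   # ||b||^2 per column of mat1
    return sq0 + sq1 - 2 * mat0.T @ mat1

mat0, mat1 = np.random.randn(10, 5), np.random.randn(10, 7)
ref = np.array([[np.sum((mat0[:, i] - mat1[:, j])**2) for j in range(7)]
                for i in range(5)])
assert np.allclose(sq_dist_sketch(mat0, mat1), ref)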
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions.""" 16 | 17 | import enum 18 | import os 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import numpy as np 22 | from PIL import Image 23 | import torch 24 | from dataclasses import dataclass, fields 25 | 26 | 27 | _Array = Union[np.ndarray, torch.Tensor] 28 | 29 | 30 | @dataclass 31 | class Pixels: 32 | """All tensors must have the same num_dims and first n-1 dims must match.""" 33 | pix_x_int: _Array 34 | pix_y_int: _Array 35 | lossmult: _Array 36 | near: _Array 37 | far: _Array 38 | cam_idx: _Array 39 | 40 | def __getitem__(self, s): 41 | if isinstance(s, int): 42 | return Pixels(*([getattr(self, dim.name)[s]] 43 | for dim in fields(self))) 44 | elif isinstance(s, slice): 45 | return Pixels(*(getattr(self, dim.name)[s] 46 | for dim in fields(self))) 47 | else: 48 | raise ValueError('Argument to __getitem__ must be int or slice') 49 | 50 | 51 | @dataclass 52 | class Rays: 53 | """All tensors must have the same num_dims and first n-1 dims must match.""" 54 | origins: _Array 55 | directions: _Array 56 | viewdirs: _Array 57 | radii: _Array 58 | imageplane: _Array 59 | lossmult: _Array 60 | near: _Array 61 | far: _Array 62 | cam_idx: _Array 63 | 64 | def __getitem__(self, s): 65 | if isinstance(s, int): 66 | return Rays(*[[getattr(self, dim.name)[s]] 67 | for dim in fields(self)]) 68 | elif isinstance(s, slice): 69 | return Rays(*[getattr(self, dim.name)[s] 70 | for dim in fields(self)]) 71 | else: 72 | raise ValueError('Argument to __getitem__ must be int or slice') 73 | 74 | def to(self, device): 75 | for dim in fields(self): 76 | if isinstance(getattr(self, dim.name), np.ndarray): 77 | # convert to tensor and send to device 78 | setattr(self, dim.name, torch.tensor(getattr(self, dim.name), 79 | dtype=torch.float32, device=device)) 80 | elif isinstance(getattr(self, dim.name), torch.Tensor): 81 | # send to device if not already there 82 | if getattr(self, dim.name).device != device: 83 | setattr(self, dim.name, getattr(self, dim.name).to(device)) 84 | else: 85 | raise ValueError('Rays members must be either np.ndarray or torch.Tensor') 86 | 87 | def reshape(self, *dims): 88 | return Rays(*[getattr(self, dim.name).reshape(*dims) 89 | for dim in fields(self)]) 90 | @property 91 | def shape(self): 92 | return self.origins.shape 93 | 94 | 95 | # Dummy Rays object that can be used to initialize NeRF model.
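A hedged usage sketch for the Rays container above and the dummy_rays helper defined just below (assuming internal.utils imports): indexing preserves the dataclass structure, .shape proxies origins.shape, and .to() moves every field.

import torch
from internal import utils

rays = utils.dummy_rays()
assert rays.shape == (1, 3)            # .shape is origins.shape
sub = rays[0:1]                        # slicing returns another Rays
assert isinstance(sub, utils.Rays)
rays.to(torch.device('cpu'))           # converts/moves all fields in place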
96 | def dummy_rays() -> Rays: 97 | def data_fn(n): return torch.zeros((1, n)) 98 | return Rays( 99 | origins=data_fn(3), 100 | directions=data_fn(3), 101 | viewdirs=data_fn(3), 102 | radii=data_fn(1), 103 | imageplane=data_fn(2), 104 | lossmult=data_fn(1), 105 | near=data_fn(1), 106 | far=data_fn(1), 107 | cam_idx=data_fn(1).type(torch.int32)) 108 | 109 | 110 | @dataclass 111 | class Batch: 112 | """Data batch for NeRF training or testing.""" 113 | rays: Union[Pixels, Rays] 114 | rgb: Optional[_Array] = None 115 | disps: Optional[_Array] = None 116 | normals: Optional[_Array] = None 117 | alphas: Optional[_Array] = None 118 | 119 | 120 | class DataSplit(enum.Enum): 121 | """Dataset split.""" 122 | TRAIN = 'train' 123 | TEST = 'test' 124 | 125 | 126 | class BatchingMethod(enum.Enum): 127 | """Draw rays randomly from a single image or all images, in each batch.""" 128 | ALL_IMAGES = 'all_images' 129 | SINGLE_IMAGE = 'single_image' 130 | 131 | 132 | def open_file(pth, mode='r'): 133 | return open(pth, mode=mode) 134 | 135 | 136 | def file_exists(pth): 137 | return os.path.exists(pth) 138 | 139 | 140 | def listdir(pth): 141 | return os.listdir(pth) 142 | 143 | 144 | def isdir(pth): 145 | return os.path.isdir(pth) 146 | 147 | 148 | def makedirs(pth): 149 | if not file_exists(pth): 150 | os.makedirs(pth) 151 | 152 | 153 | def unshard(x, padding=0): 154 | """Collect the sharded tensor to the shape before sharding.""" 155 | y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) 156 | if padding > 0: 157 | y = y[:-padding] 158 | return y 159 | 160 | 161 | def load_img(pth: str) -> np.ndarray: 162 | """Load an image and cast to float32.""" 163 | with open_file(pth, 'rb') as f: 164 | image = np.array(Image.open(f), dtype=np.float32) 165 | return image 166 | 167 | 168 | def save_img_u8(img, pth, mask=None): 169 | """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" 170 | with open_file(pth, 'wb') as f: 171 | img_np = (np.clip(np.nan_to_num(img), 0., 1.) 172 | * 255).astype(np.uint8).squeeze() 173 | if mask is not None: 174 | mask_np = (np.nan_to_num(mask)).astype(np.float32).squeeze() 175 | mask_np = 255 * (mask_np - mask_np.min()) / \ 176 | (mask_np.max() - mask_np.min()) 177 | img_np = (255 - mask_np) + img_np 178 | img_np = np.array((255 * (img_np - img_np.min()) / 179 | (img_np.max() - img_np.min())), dtype=np.uint8) 180 | 181 | Image.fromarray(img_np).save(f, 'PNG') 182 | 183 | 184 | def save_img_f32(depthmap, pth): 185 | """Save an image (probably a depthmap) to disk as a float32 TIFF.""" 186 | with open_file(pth, 'wb') as f: 187 | Image.fromarray(np.nan_to_num(depthmap).astype( 188 | np.float32)).save(f, 'TIFF') 189 | 190 | 191 | def merge_chunks(chunks): 192 | merged_chunks = {} 193 | for key in chunks[0]: 194 | if isinstance(chunks[0][key], list): 195 | merged_chunks[key] = [ 196 | torch.cat([chunk[key][idx] for chunk in chunks]) 197 | for idx in range(len(chunks[0][key])) 198 | ] 199 | elif isinstance(chunks[0][key], torch.Tensor): 200 | merged_chunks[key] = torch.cat([tdict[key] for tdict in chunks]) 201 | else: 202 | raise ValueError('Contents should be either list or tensor') 203 | return merged_chunks 204 | 205 | 206 | def recursive_detach(v: [list, torch.Tensor]): 207 | if isinstance(v, torch.Tensor): 208 | return v.detach() 209 | elif isinstance(v, list): 210 | return [recursive_detach(vk) for vk in v] 211 | elif isinstance(v, dict): 212 | return {k: recursive_detach(vk) for k, vk in v.items()} 213 | else: 214 | raise ValueError('Invalid input. 
v must be torch.Tensor or list') 215 | 216 | 217 | def recursive_device_switch( 218 | v: [list, torch.Tensor], device: torch.device): 219 | if isinstance(v, torch.Tensor): 220 | return v.to(device) 221 | elif isinstance(v, list): 222 | return [recursive_device_switch(vk, device) for vk in v] 223 | elif isinstance(v, dict): 224 | return {k: recursive_device_switch(vk, device) for k, vk in v.items()} 225 | else: 226 | raise ValueError('Invalid input. v must be torch.Tensor or list') 227 | -------------------------------------------------------------------------------- /internal/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions for handling configurations.""" 16 | 17 | import dataclasses 18 | from typing import Any, Callable, Optional, Tuple 19 | 20 | from absl import flags 21 | import gin 22 | from internal import utils 23 | import numpy as np 24 | 25 | gin.add_config_file_search_path('experimental/users/barron/mipnerf360/') 26 | 27 | @gin.configurable() 28 | @dataclasses.dataclass 29 | class Config: 30 | """Configuration flags for everything.""" 31 | dataset_loader: str = 'llff' # The type of dataset loader to use. 32 | dataset_debug_mode: bool = False # If True, always loads specific batch 33 | batching: str = 'all_images' # Batch composition, [single_image, all_images]. 34 | batch_size: int = 16384 # The number of rays/pixels in each batch. 35 | patch_size: int = 1 # Resolution of patches sampled for training batches. 36 | factor: int = 0 # The downsample factor of images, 0 for no downsampling. 37 | load_alphabetical: bool = True # Load images in COLMAP vs alphabetical 38 | # ordering (affects heldout test set). 39 | forward_facing: bool = False # Set to True for forward-facing LLFF captures. 40 | render_path: bool = False # If True, render a path. Used only by LLFF. 41 | llffhold: int = 8 # Use every Nth image for the test set. Used only by LLFF. 42 | # If true, use all input images for training. 43 | llff_use_all_images_for_training: bool = False 44 | llff_white_background: bool = False # If True, remove bkgd with masks 45 | use_tiffs: bool = False # If True, use 32-bit TIFFs. Used only by Blender. 46 | compute_eval_metrics: bool = False # If True, compute SSIM and PSNR 47 | compute_disp_metrics: bool = False # If True, load and compute disparity MSE. 48 | compute_normal_metrics: bool = False # If True, load and compute normal MAE. 49 | gc_every: int = 10000 # The number of steps between garbage collections. 50 | disable_multiscale_loss: bool = False # If True, disable multiscale loss. 51 | randomized: bool = True # Use randomized stratified sampling. 52 | near: float = 2. # Near plane distance. 53 | far: float = 6. # Far plane distance. 54 | checkpoint_dir: Optional[str] = None # Where to log checkpoints. 55 | render_dir: Optional[str] = None # Output rendering directory. 
56 | data_dir: Optional[str] = None # Input data directory. 57 | vocab_tree_path: Optional[str] = None # Path to vocab tree for COLMAP. 58 | render_chunk_size: int = 16384 # Chunk size for whole-image renderings. 59 | num_showcase_images: int = 5 # The number of test-set images to showcase. 60 | deterministic_showcase: bool = True # If True, showcase the same images. 61 | vis_num_rays: int = 16 # The number of rays to visualize. 62 | # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage. 63 | vis_decimate: int = 0 64 | 65 | # Only used by train.py: 66 | max_steps: int = 250000 # The number of optimization steps. 67 | early_exit_steps: Optional[int] = None # Early stopping, for debugging. 68 | checkpoint_every: int = 25000 # The number of steps to save a checkpoint. 69 | print_every: int = 100 # The number of steps between reports to tensorboard. 70 | train_render_every: int = 5000 # Steps between test-set renders when training. 71 | cast_rays_in_train_step: bool = False # If True, compute rays in train step. 72 | data_loss_type: str = 'charb' # What kind of loss to use ('mse' or 'charb'). 73 | charb_padding: float = 0.001 # The padding used for Charbonnier loss. 74 | data_loss_mult: float = 1.0 # Mult for the finest data term in the loss. 75 | data_coarse_loss_mult: float = 0. # Multiplier for the coarser data terms. 76 | interlevel_loss_mult: float = 1.0 # Mult. for the loss on the proposal MLP. 77 | orientation_loss_mult: float = 0.0 # Multiplier on the orientation loss. 78 | orientation_coarse_loss_mult: float = 0.0 # Coarser orientation loss weights. 79 | # What that loss is imposed on, options are 'normals' or 'normals_pred'. 80 | orientation_loss_target: str = 'normals_pred' 81 | predicted_normal_loss_mult: float = 0.0 # Mult. on the predicted normal loss. 82 | # Mult. on the coarser predicted normal loss. 83 | predicted_normal_coarse_loss_mult: float = 0.0 84 | 85 | lr_init: float = 0.002 # The initial learning rate. 86 | lr_final: float = 0.00002 # The final learning rate. 87 | lr_delay_steps: int = 512 # The number of "warmup" learning steps. 88 | lr_delay_mult: float = 0.01 # How severe the "warmup" should be. 89 | adam_beta1: float = 0.9 # Adam's beta1 hyperparameter. 90 | adam_beta2: float = 0.999 # Adam's beta2 hyperparameter. 91 | adam_eps: float = 1e-6 # Adam's epsilon hyperparameter. 92 | grad_max_norm: float = 0.001 # Gradient clipping magnitude, disabled if == 0. 93 | grad_max_val: float = 0. # Gradient clipping value, disabled if == 0. 94 | distortion_loss_mult: float = 0.01 # Multiplier on the distortion loss. 95 | 96 | # Only used by eval.py: 97 | eval_only_once: bool = True # If True evaluate the model only once, otherwise loop. 98 | eval_save_output: bool = True # If True save predicted images to disk. 99 | eval_save_ray_data: bool = False # If True save individual ray traces. 100 | eval_render_interval: int = 1 # The interval between images saved to disk. 101 | eval_dataset_limit: int = np.iinfo(np.int32).max # Num test images to eval. 102 | eval_quantize_metrics: bool = True # If True, run metrics on 8-bit images. 103 | eval_crop_borders: int = 0 # Ignore c border pixels in eval (x[c:-c, c:-c]). 104 | 105 | # Only used by render.py: 106 | render_video_fps: int = 60 # Framerate in frames-per-second. 107 | render_video_crf: int = 18 # Constant rate factor for ffmpeg video quality. 108 | render_path_frames: int = 120 # Number of frames in render path. 109 | z_variation: float = 0. # How much height variation in render path. 110 | z_phase: float = 0. # Phase offset for height variation in render path. 111 | render_dist_percentile: float = 0.5 # How much to trim from near/far planes. 112 | render_dist_curve_fn: Callable[..., Any] = np.log # How depth is curved. 113 | render_path_file: Optional[str] = None # Numpy render pose file to load. 114 | render_job_id: int = 0 # Render job id. 115 | render_num_jobs: int = 1 # Total number of render jobs. 116 | render_resolution: Optional[Tuple[int, int]] = None # Render resolution, as 117 | # (width, height). 118 | render_focal: Optional[float] = None # Render focal length. 119 | render_camtype: Optional[str] = None # 'perspective', 'fisheye', or 'pano'. 120 | render_spherical: bool = False # Render spherical 360 panoramas. 121 | render_save_async: bool = True # Save to CNS using a separate thread. 122 | 123 | render_spline_keyframes: Optional[str] = None # Text file containing names of 124 | # images to be used as spline 125 | # keyframes, OR directory 126 | # containing those images. 127 | render_spline_n_interp: int = 30 # Num. frames to interpolate per keyframe. 128 | render_spline_degree: int = 5 # Polynomial degree of B-spline interpolation. 129 | render_spline_smoothness: float = .03 # B-spline smoothing factor, 0 for exact interpolation of the keyframes. 130 | 131 | def define_common_flags(): 132 | # Define the flags used by both train.py and eval.py 133 | flags.DEFINE_string('mode', None, 'Required by GINXM, not used.') 134 | flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.') 135 | flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.') 136 | flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.') 137 | 138 | 139 | def load_config(save_config=True): 140 | """Load the config, and optionally checkpoint it.""" 141 | gin.parse_config_files_and_bindings( 142 | flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True) 143 | config = Config() 144 | if save_config: 145 | utils.makedirs(config.checkpoint_dir) 146 | with utils.open_file(config.checkpoint_dir + '/config.gin', 'w') as f: 147 | f.write(gin.config_str()) 148 | return config 149 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
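Config above is a gin-configurable dataclass, so any field can be overridden from .gin files or binding strings; in the scripts these arrive via the --gin_configs/--gin_bindings flags from define_common_flags. A minimal sketch using gin directly:

import gin
from internal import configs

gin.parse_config("""
Config.near = 0.5
Config.far = 10.0
Config.max_steps = 1000
""")
config = configs.Config()
assert (config.near, config.far, config.max_steps) == (0.5, 10.0, 1000)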
14 | 15 | """Render script.""" 16 | 17 | import concurrent.futures 18 | import functools 19 | import glob 20 | import os 21 | import time 22 | import gc 23 | 24 | from absl import app 25 | import torch 26 | import gin 27 | from internal import configs 28 | from internal import datasets 29 | from internal import models 30 | from internal import train_utils 31 | from internal import utils 32 | 33 | from matplotlib import cm 34 | import mediapy as media 35 | import numpy as np 36 | 37 | configs.define_common_flags() 38 | 39 | 40 | def create_videos(config, base_dir, out_dir, out_name, num_frames): 41 | """Creates videos out of the images saved to disk.""" 42 | names = [n for n in config.checkpoint_dir.split('/') if n] 43 | # Last two parts of checkpoint path are experiment name and scene name. 44 | exp_name, scene_name = names[-2:] 45 | video_prefix = f'{scene_name}_{exp_name}_{out_name}' 46 | 47 | zpad = max(3, len(str(num_frames - 1))) 48 | 49 | def idx_to_str(idx): 50 | return str(idx).zfill(zpad) 51 | 52 | utils.makedirs(base_dir) 53 | 54 | # Load one example frame to get image shape and depth range. 55 | depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff') 56 | depth_frame = utils.load_img(depth_file) 57 | shape = depth_frame.shape 58 | p = config.render_dist_percentile 59 | distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p]) 60 | lo, hi = [config.render_dist_curve_fn(x) for x in distance_limits] 61 | print(f'Video shape is {shape[:2]}') 62 | 63 | video_kwargs = { 64 | 'shape': shape[:2], 65 | 'codec': 'h264', 66 | 'fps': config.render_video_fps, 67 | 'crf': config.render_video_crf, 68 | } 69 | 70 | for k in ['color', 'normals', 'acc', 'distance_mean', 'distance_median']: 71 | video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4') 72 | input_format = 'gray' if k == 'acc' else 'rgb' 73 | file_ext = 'png' if k in ['color', 'normals'] else 'tiff' 74 | idx = 0 75 | file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}') 76 | if not utils.file_exists(file0): 77 | print(f'Images missing for tag {k}') 78 | continue 79 | print(f'Making video {video_file}...') 80 | with media.VideoWriter( 81 | video_file, **video_kwargs, input_format=input_format) as writer: 82 | for idx in range(num_frames): 83 | img_file = os.path.join( 84 | out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}') 85 | if not utils.file_exists(img_file): 86 | raise ValueError(f'Image file {img_file} does not exist.') 87 | img = utils.load_img(img_file) 88 | if k in ['color', 'normals']: 89 | img = img / 255. 90 | elif k.startswith('distance'): 91 | img = config.render_dist_curve_fn(img) 92 | img = np.clip((img - np.minimum(lo, hi)) / 93 | np.abs(hi - lo), 0, 1) 94 | img = cm.get_cmap('turbo')(img)[..., :3] 95 | 96 | frame = (np.clip(np.nan_to_num(img), 0., 1.) 97 | * 255.).astype(np.uint8) 98 | writer.add_image(frame) 99 | 100 | 101 | def main(unused_argv): 102 | 103 | 104 | config = configs.load_config(save_config=False) 105 | 106 | # Setup device. 107 | if torch.cuda.is_available(): 108 | device = torch.device('cuda') 109 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 110 | else: 111 | device = torch.device('cpu') 112 | torch.set_default_tensor_type('torch.FloatTensor') 113 | 114 | # Create test dataset. 115 | dataset = datasets.load_dataset('test', config.data_dir, config) 116 | 117 | # Set random number generator seeds. 118 | torch.manual_seed(20221019) 119 | np.random.seed(20221019) 120 | 121 | # Create model.
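One detail in the checkpoint-loading logic below is easy to miss: checkpoints are sorted by their integer step suffix, because a plain lexicographic sort would rank 'checkpoint_9' above 'checkpoint_25000'. A tiny standalone check:

names = ['checkpoint_9', 'checkpoint_100', 'checkpoint_25000']
assert sorted(names, key=lambda x: int(x.split('_')[-1]))[-1] == 'checkpoint_25000'
assert sorted(names)[-1] == 'checkpoint_9'  # lexicographic order picks the wrong one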
122 | setup = train_utils.setup_model(config, dataset=dataset) 123 | model, _, _, render_eval_fn, _ = setup 124 | state = dict(step=0, model=model.state_dict()) 125 | 126 | # Load states from checkpoint. 127 | if utils.isdir(config.checkpoint_dir): 128 | files = sorted([f for f in os.listdir(config.checkpoint_dir) 129 | if f.startswith('checkpoint')], 130 | key=lambda x: int(x.split('_')[-1])) 131 | # if there are checkpoints in the dir, load the latest checkpoint 132 | if files: 133 | checkpoint_name = files[-1] 134 | state = torch.load(os.path.join( 135 | config.checkpoint_dir, checkpoint_name)) 136 | model.load_state_dict(state['model']) 137 | model.eval() 138 | model.to(device) 139 | else: 140 | utils.makedirs(config.checkpoint_dir) 141 | 142 | step = int(state['step']) 143 | print(f'Rendering checkpoint at step {step}.') 144 | 145 | out_name = 'path_renders' if config.render_path else 'test_preds' 146 | out_name = f'{out_name}_step_{step}' 147 | base_dir = config.render_dir 148 | if base_dir is None: 149 | base_dir = os.path.join(config.checkpoint_dir, 'render') 150 | out_dir = os.path.join(base_dir, out_name) 151 | if not utils.isdir(out_dir): 152 | utils.makedirs(out_dir) 153 | 154 | def path_fn(x): 155 | return os.path.join(out_dir, x) 156 | 157 | # Ensure sufficient zero-padding of image indices in output filenames. 158 | zpad = max(3, len(str(dataset.size - 1))) 159 | 160 | def idx_to_str(idx): 161 | return str(idx).zfill(zpad) 162 | 163 | if config.render_save_async: 164 | async_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) 165 | async_futures = [] 166 | 167 | def save_fn(fn, *args, **kwargs): 168 | async_futures.append(async_executor.submit(fn, *args, **kwargs)) 169 | else: 170 | def save_fn(fn, *args, **kwargs): 171 | fn(*args, **kwargs) 172 | 173 | for idx in range(dataset.size): 174 | if idx % config.render_num_jobs != config.render_job_id: 175 | continue 176 | # If current image and next image both already exist, skip ahead. 177 | idx_str = idx_to_str(idx) 178 | curr_file = path_fn(f'color_{idx_str}.png') 179 | next_idx_str = idx_to_str(idx + config.render_num_jobs) 180 | next_file = path_fn(f'color_{next_idx_str}.png') 181 | if utils.file_exists(curr_file) and utils.file_exists(next_file): 182 | print(f'Image {idx}/{dataset.size} already exists, skipping') 183 | continue 184 | print(f'Evaluating image {idx+1}/{dataset.size}') 185 | eval_start_time = time.time() 186 | batch = dataset.generate_ray_batch(idx) 187 | train_frac = 1. 188 | 189 | with torch.no_grad(): 190 | rendering = models.render_image( 191 | functools.partial(render_eval_fn, train_frac), 192 | batch.rays, config) 193 | 194 | print(f'Rendered in {(time.time() - eval_start_time):0.3f}s') 195 | 196 | save_fn( 197 | utils.save_img_u8, rendering['rgb'], path_fn(f'color_{idx_str}.png')) 198 | if 'normals' in rendering: 199 | save_fn( 200 | utils.save_img_u8, rendering['normals'] / 2. 
+ 0.5, 201 | path_fn(f'normals_{idx_str}.png')) 202 | save_fn( 203 | utils.save_img_f32, rendering['distance_mean'], 204 | path_fn(f'distance_mean_{idx_str}.tiff')) 205 | save_fn( 206 | utils.save_img_f32, rendering['distance_median'], 207 | path_fn(f'distance_median_{idx_str}.tiff')) 208 | save_fn( 209 | utils.save_img_f32, rendering['acc'], path_fn(f'acc_{idx_str}.tiff')) 210 | save_fn( 211 | utils.save_img_u8, rendering['roughness'], path_fn( 212 | f'rho_{idx_str}.png'), 213 | mask=rendering['acc']) 214 | 215 | num_files = len(glob.glob(path_fn('acc_*.tiff'))) 216 | if num_files == dataset.size: 217 | print( 218 | f'All files found, creating videos (job {config.render_job_id}).') 219 | create_videos(config, base_dir, out_dir, out_name, dataset.size) 220 | 221 | 222 | if __name__ == '__main__': 223 | with gin.config_scope('eval'): # Use the same scope as eval.py 224 | app.run(main) 225 | -------------------------------------------------------------------------------- /internal/image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for processing images.""" 16 | 17 | import types 18 | from typing import Optional, Union 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | # import dm_pix 23 | import numpy as np 24 | from numpy import array as tensor 25 | 26 | _Array = Union[np.ndarray, torch.Tensor] 27 | np.tensor = np.array 28 | 29 | 30 | def mse_to_psnr(mse): 31 | """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" 32 | return -10.
/ torch.log(torch.tensor(10.)) * torch.log(mse) 33 | 34 | 35 | def psnr_to_mse(psnr): 36 | """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" 37 | return torch.exp(-0.1 * torch.log(torch.tensor(10.)) * psnr) 38 | 39 | 40 | def ssim_to_dssim(ssim): 41 | """Compute DSSIM given an SSIM.""" 42 | return (1 - ssim) / 2 43 | 44 | 45 | def dssim_to_ssim(dssim): 46 | """Compute SSIM given a DSSIM.""" 47 | return 1 - 2 * dssim 48 | 49 | 50 | def linear_to_srgb(linear: _Array, 51 | eps: Optional[float] = None, 52 | xnp: types.ModuleType = torch) -> _Array: 53 | """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.""" 54 | if eps is None: 55 | eps = xnp.tensor(xnp.finfo(xnp.float32).eps) 56 | srgb0 = 323 / 25 * linear 57 | srgb1 = (211 * xnp.maximum(eps, linear) ** (5 / 12) - 11) / 200 58 | return xnp.where(linear <= 0.0031308, srgb0, srgb1) 59 | 60 | 61 | def srgb_to_linear(srgb: _Array, 62 | eps: Optional[float] = None, 63 | xnp: types.ModuleType = torch) -> _Array: 64 | """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.""" 65 | if eps is None: 66 | eps = xnp.tensor(xnp.finfo(xnp.float32).eps) 67 | linear0 = 25 / 323 * srgb 68 | linear1 = xnp.maximum(eps, ((200 * srgb + 11) / (211))) ** (12 / 5) 69 | return xnp.where(srgb <= 0.04045, linear0, linear1) 70 | 71 | 72 | def downsample(img, factor): 73 | """Area downsample img (factor must evenly divide img height and width).""" 74 | sh = img.shape 75 | if not (sh[0] % factor == 0 and sh[1] % factor == 0): 76 | raise ValueError(f'Downsampling factor {factor} does not ' 77 | f'evenly divide image shape {sh[:2]}') 78 | img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:]) 79 | img = img.mean((1, 3)) 80 | return img 81 | 82 | 83 | def color_correct(img, ref, num_iters=5, eps=0.5 / 255): 84 | """Warp `img` to match the colors in `ref`.""" 85 | if img.shape[-1] != ref.shape[-1]: 86 | raise ValueError( 87 | f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match' 88 | ) 89 | num_channels = img.shape[-1] 90 | img_mat = img.reshape([-1, num_channels]) 91 | ref_mat = ref.reshape([-1, num_channels]) 92 | is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps)) # z \in [eps, 1-eps]. 93 | mask0 = is_unclipped(img_mat) 94 | # Because the set of saturated pixels may change after solving for a 95 | # transformation, we repeatedly solve a system `num_iters` times and update 96 | # our estimate of which pixels are saturated. 97 | for _ in range(num_iters): 98 | # Construct the left hand side of a linear system that contains a quadratic 99 | # expansion of each pixel of `img`. 100 | a_mat = [] 101 | for c in range(num_channels): 102 | a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:]) # Quadratic term. 103 | a_mat.append(img_mat) # Linear term. 104 | a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term. 105 | a_mat = torch.cat(a_mat, dim=-1) 106 | warp = [] 107 | for c in range(num_channels): 108 | # Construct the right hand side of a linear system containing each color 109 | # of `ref`. 110 | b = ref_mat[:, c] 111 | # Ignore rows of the linear system that were saturated in the input or are 112 | # saturated in the current corrected color estimate. 113 | mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) 114 | ma_mat = torch.where(mask[:, None], a_mat, 0) 115 | mb = torch.where(mask, b, 0) 116 | # Solve the linear system for this channel using least squares 117 | # (torch.linalg.lstsq below).
118 | w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0] 119 | assert torch.all(torch.isfinite(w)) 120 | warp.append(w) 121 | warp = torch.stack(warp, dim=-1) 122 | # Apply the warp to update img_mat. 123 | img_mat = torch.clip( 124 | torch.matmul(a_mat, warp), 0, 1) 125 | corrected_img = torch.reshape(img_mat, img.shape) 126 | return corrected_img 127 | 128 | 129 | class MetricHarness: 130 | """A helper class for evaluating several error metrics.""" 131 | 132 | def __init__(self): 133 | pass 134 | 135 | def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s): 136 | """Evaluate the error between a predicted rgb image and the true image.""" 137 | psnr = float(mse_to_psnr(((rgb_pred - rgb_gt) ** 2).mean())) 138 | ssim = float(self.compute_ssim(rgb_pred, rgb_gt)) 139 | 140 | return { 141 | name_fn('psnr'): psnr, 142 | name_fn('ssim'): ssim, 143 | } 144 | 145 | @staticmethod 146 | def compute_ssim(img0, img1, max_val=1, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, return_map=False): 147 | """Computes SSIM from two images. 148 | This function was modeled after tf.image.ssim, and should produce comparable 149 | output. 150 | Args: 151 | img0: torch.tensor. An image of size [..., width, height, num_channels]. 152 | img1: torch.tensor. An image of size [..., width, height, num_channels]. 153 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 154 | filter_size: int >= 1. Window size. 155 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 156 | k1: float > 0. One of the SSIM dampening parameters. 157 | k2: float > 0. One of the SSIM dampening parameters. 158 | return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned 159 | Returns: 160 | Each image's mean SSIM, or a tensor of individual values if `return_map`. 161 | """ 162 | device = img0.device 163 | ori_shape = img0.size() 164 | width, height, num_channels = ori_shape[-3:] 165 | img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 166 | img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 167 | batch_size = img0.shape[0] 168 | dtype = img0.dtype 169 | 170 | # Construct a 1D Gaussian blur filter. 171 | hw = filter_size // 2 172 | shift = (2 * hw - filter_size + 1) / 2 173 | f_i = ((torch.arange(filter_size, device=device, dtype=dtype) - hw + shift) / filter_sigma) ** 2 174 | filt = torch.exp(-0.5 * f_i) 175 | filt /= torch.sum(filt) 176 | 177 | # Blur in x and y (faster than the 2D convolution). 178 | # z is a tensor of size [B, H, W, C] 179 | filt_fn1 = lambda z: F.conv2d( 180 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 181 | padding=[hw, 0], groups=num_channels) 182 | filt_fn2 = lambda z: F.conv2d( 183 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 184 | padding=[0, hw], groups=num_channels) 185 | 186 | # Vmap the blurs to the tensor size, and then compose them. 187 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 188 | mu0 = filt_fn(img0) 189 | mu1 = filt_fn(img1) 190 | mu00 = mu0 * mu0 191 | mu11 = mu1 * mu1 192 | mu01 = mu0 * mu1 193 | sigma00 = filt_fn(img0 ** 2) - mu00 194 | sigma11 = filt_fn(img1 ** 2) - mu11 195 | sigma01 = filt_fn(img0 * img1) - mu01 196 | 197 | # Clip the variances and covariances to valid values. 
198 | # Variance must be non-negative: 199 | sigma00 = torch.clamp(sigma00, min=0.0) 200 | sigma11 = torch.clamp(sigma11, min=0.0) 201 | sigma01 = torch.sign(sigma01) * torch.min( 202 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) 203 | ) 204 | 205 | c1 = (k1 * max_val) ** 2 206 | c2 = (k2 * max_val) ** 2 207 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 208 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 209 | ssim_map = numer / denom 210 | ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1) 211 | return ssim_map if return_map else ssim 212 | -------------------------------------------------------------------------------- /internal/render.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for shooting and rendering rays.""" 16 | 17 | import torch 18 | from internal import stepfun 19 | 20 | 21 | def lift_gaussian(d, t_mean, t_var, r_var, diag): 22 | """Lift a Gaussian defined along a ray to 3D coordinates.""" 23 | mean = d[..., None, :] * t_mean[..., None] 24 | 25 | eps = torch.tensor(1e-10) 26 | d_mag_sq = torch.maximum(eps, torch.sum(d**2, dim=-1, keepdims=True)) 27 | 28 | if diag: 29 | d_outer_diag = d**2 30 | null_outer_diag = 1 - d_outer_diag / d_mag_sq 31 | t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :] 32 | xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :] 33 | cov_diag = t_cov_diag + xy_cov_diag 34 | return mean, cov_diag 35 | else: 36 | d_outer = d[..., :, None] * d[..., None, :] 37 | eye = torch.eye(d.shape[-1]) 38 | null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :] 39 | t_cov = t_var[..., None, None] * d_outer[..., None, :, :] 40 | xy_cov = r_var[..., None, None] * null_outer[..., None, :, :] 41 | cov = t_cov + xy_cov 42 | return mean, cov 43 | 44 | 45 | def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True): 46 | """Approximate a conical frustum as a Gaussian distribution (mean+cov). 47 | 48 | Assumes the ray is originating from the origin, and base_radius is the 49 | radius at dist=1. Doesn't assume `d` is normalized. 50 | 51 | Args: 52 | d: torch.float32 3-vector, the axis of the cone 53 | t0: float, the starting distance of the frustum. 54 | t1: float, the ending distance of the frustum. 55 | base_radius: float, the scale of the radius as a function of distance. 56 | diag: boolean, whether or not the Gaussian will be diagonal or full-covariance. 57 | stable: boolean, whether or not to use the stable computation described in 58 | the paper (setting this to False will cause catastrophic failure). 59 | 60 | Returns: 61 | a Gaussian (mean and covariance). 62 | """ 63 | if stable: 64 | # Equation 7 in the paper (https://arxiv.org/abs/2103.13415). 65 | mu = (t0 + t1) / 2 # The average of the two `t` values. 66 | hw = (t1 - t0) / 2 # The half-width of the two `t` values.
67 | eps = torch.tensor(torch.finfo(torch.float32).eps) 68 | t_mean = mu + (2 * mu * hw**2) / torch.maximum(eps, 3 * mu**2 + hw**2) 69 | denom = torch.maximum(eps, 3 * mu**2 + hw**2) 70 | t_var = (hw**2) / 3 - (4 / 15) * hw**4 * (12 * mu**2 - hw**2) / denom**2 71 | r_var = (mu**2) / 4 + (5 / 12) * hw**2 - (4 / 15) * (hw**4) / denom 72 | else: 73 | # Equations 37-39 in the paper. 74 | t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3)) 75 | r_var = 3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3) 76 | t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3) 77 | t_var = t_mosq - t_mean**2 78 | r_var *= base_radius**2 79 | return lift_gaussian(d, t_mean, t_var, r_var, diag) 80 | 81 | 82 | def cylinder_to_gaussian(d, t0, t1, radius, diag): 83 | """Approximate a cylinder as a Gaussian distribution (mean+cov). 84 | 85 | Assumes the ray is originating from the origin, and radius is the 86 | radius. Does not renormalize `d`. 87 | 88 | Args: 89 | d: torch.float32 3-vector, the axis of the cylinder 90 | t0: float, the starting distance of the cylinder. 91 | t1: float, the ending distance of the cylinder. 92 | radius: float, the radius of the cylinder 93 | diag: boolean, whether or the Gaussian will be diagonal or full-covariance. 94 | 95 | Returns: 96 | a Gaussian (mean and covariance). 97 | """ 98 | t_mean = (t0 + t1) / 2 99 | r_var = radius**2 / 4 100 | t_var = (t1 - t0)**2 / 12 101 | return lift_gaussian(d, t_mean, t_var, r_var, diag) 102 | 103 | 104 | def cast_rays(tdist, origins, directions, radii, ray_shape, diag=True): 105 | """Cast rays (cone- or cylinder-shaped) and featurize sections of it. 106 | 107 | Args: 108 | tdist: float array, the "fencepost" distances along the ray. 109 | origins: float array, the ray origin coordinates. 110 | directions: float array, the ray direction vectors. 111 | radii: float array, the radii (base radii for cones) of the rays. 112 | ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'. 113 | diag: boolean, whether or not the covariance matrices should be diagonal. 114 | 115 | Returns: 116 | a tuple of arrays of means and covariances. 117 | """ 118 | t0 = tdist[..., :-1] 119 | t1 = tdist[..., 1:] 120 | if ray_shape == 'cone': 121 | gaussian_fn = conical_frustum_to_gaussian 122 | elif ray_shape == 'cylinder': 123 | gaussian_fn = cylinder_to_gaussian 124 | else: 125 | raise ValueError('ray_shape must be \'cone\' or \'cylinder\'') 126 | means, covs = gaussian_fn(directions, t0, t1, radii, diag) 127 | means = means + origins[..., None, :] 128 | return means, covs 129 | 130 | 131 | def compute_alpha_weights(density, tdist, dirs, opaque_background=False): 132 | """Helper function for computing alpha compositing weights.""" 133 | t_delta = tdist[..., 1:] - tdist[..., :-1] 134 | delta = t_delta * torch.linalg.norm(dirs[..., None, :], dim=-1) 135 | density_delta = density * delta 136 | 137 | if opaque_background: 138 | # Equivalent to making the final t-interval infinitely wide. 139 | density_delta = torch.cat([ 140 | density_delta[..., :-1], 141 | torch.full_like(density_delta[..., -1:], torch.inf)], dim=-1) 142 | 143 | alpha = 1 - torch.exp(-density_delta) 144 | trans = torch.exp(-torch.cat([ 145 | torch.zeros_like(density_delta[..., :1]), 146 | torch.cumsum(density_delta[..., :-1], dim=-1)], dim=-1)) 147 | weights = alpha * trans 148 | return weights, alpha, trans 149 | 150 | 151 | def volumetric_rendering(rgbs, 152 | weights, 153 | tdist, 154 | bg_rgbs, 155 | t_far, 156 | compute_extras, 157 | extras=None): 158 | """Volumetric Rendering Function. 
159 | 160 | Args: 161 | rgbs: torch.ndarray(float32), color, [batch_size, num_samples, 3] 162 | weights: torch.ndarray(float32), weights, [batch_size, num_samples]. 163 | tdist: torch.ndarray(float32), [batch_size, num_samples]. 164 | bg_rgbs: torch.ndarray(float32), the color(s) to use for the background. 165 | t_far: torch.ndarray(float32), [batch_size, 1], the distance of the far plane. 166 | compute_extras: bool, if True, compute extra quantities besides color. 167 | extras: dict, a set of values along rays to render by alpha compositing. 168 | 169 | Returns: 170 | rendering: a dict containing an rgb image of size [batch_size, 3], and other 171 | visualizations if compute_extras=True. 172 | """ 173 | eps = torch.tensor(torch.finfo(torch.float32).eps) 174 | rendering = {} 175 | 176 | acc = weights.sum(dim=-1) 177 | # The weight of the background. 178 | bg_w = torch.maximum(torch.tensor(0), 1 - acc[..., None]) 179 | rgb = (weights[..., None] * rgbs).sum(dim=-2) + bg_w * bg_rgbs 180 | rendering['rgb'] = rgb 181 | 182 | if compute_extras: 183 | rendering['acc'] = acc 184 | 185 | if extras is not None: 186 | for k, v in extras.items(): 187 | if v is not None: 188 | rendering[k] = (weights[..., None] * v).sum(dim=-2) 189 | 190 | def expectation(x): 191 | return (weights * x).sum(dim=-1) / torch.max(eps, acc) 192 | 193 | t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:]) 194 | # For numerical stability this expectation is computing using log-distance. 195 | rendering['distance_mean'] = torch.clip( 196 | torch.nan_to_num( 197 | torch.exp(expectation(torch.log(t_mids))), torch.inf), 198 | tdist[..., 0], tdist[..., -1]) 199 | 200 | # Add an extra fencepost with the far distance at the end of each ray, with 201 | # whatever weight is needed to make the new weight vector sum to exactly 1 202 | # (`weights` is only guaranteed to sum to <= 1, not == 1). 203 | t_aug = torch.cat([tdist, t_far], dim=-1) 204 | weights_aug = torch.cat([weights, bg_w], dim=-1) 205 | 206 | ps = [5, 50, 95] 207 | distance_percentiles = stepfun.weighted_percentile( 208 | t_aug, weights_aug, ps) 209 | 210 | for i, p in enumerate(ps): 211 | s = 'median' if p == 50 else 'percentile_' + str(p) 212 | rendering['distance_' + s] = distance_percentiles[..., i] 213 | 214 | return rendering 215 | -------------------------------------------------------------------------------- /tests/coord_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
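A minimal end-to-end sketch of the ray helpers defined in internal/render.py above (illustrative only: the density and color tensors are random stand-ins for MLP outputs, and all shapes here are hypothetical):

import torch
from internal import render

batch, n_samples = 4, 32
origins = torch.zeros(batch, 3)
directions = torch.nn.functional.normalize(torch.randn(batch, 3), dim=-1)
radii = torch.full((batch, 1), 1e-3)
tdist = torch.linspace(2.0, 6.0, n_samples + 1).expand(batch, -1)

# Featurize each ray segment as a Gaussian (these would feed the MLP).
means, covs = render.cast_rays(tdist, origins, directions, radii, 'cone')

density = torch.rand(batch, n_samples)   # stand-in for predicted density
rgbs = torch.rand(batch, n_samples, 3)   # stand-in for predicted color
weights, _, _ = render.compute_alpha_weights(density, tdist, directions)
rendering = render.volumetric_rendering(
    rgbs, weights, tdist, bg_rgbs=torch.ones(3), t_far=tdist[..., -1:],
    compute_extras=False)
print(rendering['rgb'].shape)  # torch.Size([4, 3])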
14 | 15 | """Unit tests for coord.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from internal import coord 20 | from internal import math 21 | import jax 22 | from jax import random 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | def sample_covariance(rng, batch_size, num_dims): 28 | """Sample a random covariance matrix.""" 29 | half_cov = jax.random.normal(rng, [batch_size] + [num_dims] * 2) 30 | cov = math.matmul(half_cov, jnp.moveaxis(half_cov, -1, -2)) 31 | return cov 32 | 33 | 34 | def stable_pos_enc(x, n): 35 | """A stable pos_enc for very high degrees, courtesy of Sameer Agarwal.""" 36 | sin_x = np.sin(x) 37 | cos_x = np.cos(x) 38 | output = [] 39 | rotmat = np.array([[cos_x, -sin_x], [sin_x, cos_x]], dtype='double') 40 | for _ in range(n): 41 | output.append(rotmat[::-1, 0, :]) 42 | rotmat = np.einsum('ijn,jkn->ikn', rotmat, rotmat) 43 | return np.reshape(np.transpose(np.stack(output, 0), [2, 1, 0]), [-1, 2 * n]) 44 | 45 | 46 | class CoordTest(parameterized.TestCase): 47 | 48 | def test_stable_pos_enc(self): 49 | """Test that the stable posenc implementation works on multiples of pi/2.""" 50 | n = 10 51 | x = np.linspace(-np.pi, np.pi, 5) 52 | z = stable_pos_enc(x, n).reshape([-1, 2, n]) 53 | z0_true = np.zeros_like(z[:, 0, :]) 54 | z1_true = np.ones_like(z[:, 1, :]) 55 | z0_true[:, 0] = [0, -1, 0, 1, 0] 56 | z1_true[:, 0] = [-1, 0, 1, 0, -1] 57 | z1_true[:, 1] = [1, -1, 1, -1, 1] 58 | z_true = np.stack([z0_true, z1_true], axis=1) 59 | np.testing.assert_allclose(z, z_true, atol=1e-10) 60 | 61 | def test_contract_matches_special_case(self): 62 | """Test the math for Figure 2 of https://arxiv.org/abs/2111.12077.""" 63 | n = 10 64 | _, s_to_t = coord.construct_ray_warps(jnp.reciprocal, 1, jnp.inf) 65 | s = jnp.linspace(0, 1 - jnp.finfo(jnp.float32).eps, n + 1) 66 | tc = coord.contract(s_to_t(s)[:, None])[:, 0] 67 | delta_tc = tc[1:] - tc[:-1] 68 | np.testing.assert_allclose( 69 | delta_tc, np.full_like(delta_tc, 1 / n), atol=1E-5, rtol=1E-5) 70 | 71 | def test_contract_is_bounded(self): 72 | n, d = 10000, 3 73 | rng = random.PRNGKey(0) 74 | key0, key1, rng = random.split(rng, 3) 75 | x = jnp.where(random.bernoulli(key0, shape=[n, d]), 1, -1) * jnp.exp( 76 | random.uniform(key1, [n, d], minval=-10, maxval=10)) 77 | y = coord.contract(x) 78 | self.assertLessEqual(jnp.max(y), 2) 79 | 80 | def test_contract_is_noop_when_norm_is_leq_one(self): 81 | n, d = 10000, 3 82 | rng = random.PRNGKey(0) 83 | key, rng = random.split(rng) 84 | x = random.normal(key, shape=[n, d]) 85 | xc = x / jnp.maximum(1, jnp.linalg.norm(x, axis=-1, keepdims=True)) 86 | 87 | # Sanity check on the test itself. 88 | assert jnp.abs(jnp.max(jnp.linalg.norm(xc, axis=-1)) - 1) < 1e-6 89 | 90 | yc = coord.contract(xc) 91 | np.testing.assert_allclose(xc, yc, atol=1E-5, rtol=1E-5) 92 | 93 | def test_contract_gradients_are_finite(self): 94 | # Construct x such that we probe x == 0, where things are unstable. 
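# For reference, contract() implements the mip-NeRF 360 scene contraction
#   contract(x) = x                         if ||x|| <= 1,
#   contract(x) = (2 - 1/||x||) * x/||x||   otherwise,
# which is the identity inside the unit ball and maps all of space into a
# ball of radius 2 (both properties are checked above). The 1/||x|| factor
# is why the gradient near x == 0 deserves an explicit finiteness check.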
95 | x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1) 96 | grad = jax.grad(lambda x: jnp.sum(coord.contract(x)))(x) 97 | self.assertTrue(jnp.all(jnp.isfinite(grad))) 98 | 99 | def test_inv_contract_gradients_are_finite(self): 100 | z = jnp.stack(jnp.meshgrid(*[jnp.linspace(-2, 2, 21)] * 2), axis=-1) 101 | z = z.reshape([-1, 2]) 102 | z = z[jnp.sum(z**2, axis=-1) < 2, :] 103 | grad = jax.grad(lambda z: jnp.sum(coord.inv_contract(z)))(z) 104 | self.assertTrue(jnp.all(jnp.isfinite(grad))) 105 | 106 | def test_inv_contract_inverts_contract(self): 107 | """Do a round-trip from metric space to contracted space and back.""" 108 | x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1) 109 | x_recon = coord.inv_contract(coord.contract(x)) 110 | np.testing.assert_allclose(x, x_recon, atol=1E-5, rtol=1E-5) 111 | 112 | @parameterized.named_parameters( 113 | ('05_1e-5', 5, 1e-5), 114 | ('10_1e-4', 10, 1e-4), 115 | ('15_0.005', 15, 0.005), 116 | ('20_0.2', 20, 0.2), # At high degrees, our implementation is unstable. 117 | ('25_2', 25, 2), # 2 is the maximum possible error. 118 | ('30_2', 30, 2), 119 | ) 120 | def test_pos_enc(self, n, tol): 121 | """test pos_enc against a stable recursive implementation.""" 122 | x = np.linspace(-np.pi, np.pi, 10001) 123 | z = coord.pos_enc(x[:, None], 0, n, append_identity=False) 124 | z_stable = stable_pos_enc(x, n) 125 | max_err = np.max(np.abs(z - z_stable)) 126 | print(f'PE of degree {n} has a maximum error of {max_err}') 127 | self.assertLess(max_err, tol) 128 | 129 | def test_pos_enc_matches_integrated(self): 130 | """Integrated positional encoding with a variance of zero must be pos_enc.""" 131 | min_deg = 0 132 | max_deg = 10 133 | np.linspace(-jnp.pi, jnp.pi, 10) 134 | x = jnp.stack( 135 | jnp.meshgrid(*[np.linspace(-jnp.pi, jnp.pi, 10)] * 2), axis=-1) 136 | x = np.linspace(-jnp.pi, jnp.pi, 10000) 137 | z_ipe = coord.integrated_pos_enc(x, jnp.zeros_like(x), min_deg, max_deg) 138 | z_pe = coord.pos_enc(x, min_deg, max_deg, append_identity=False) 139 | # We're using a pretty wide tolerance because IPE uses safe_sin(). 140 | np.testing.assert_allclose(z_pe, z_ipe, atol=1e-4) 141 | 142 | def test_track_linearize(self): 143 | rng = random.PRNGKey(0) 144 | batch_size = 20 145 | for _ in range(30): 146 | # Construct some random Gaussians with dimensionalities in [1, 10]. 147 | key, rng = random.split(rng) 148 | in_dims = random.randint(key, (), 1, 10) 149 | key, rng = random.split(rng) 150 | mean = jax.random.normal(key, [batch_size, in_dims]) 151 | key, rng = random.split(rng) 152 | cov = sample_covariance(key, batch_size, in_dims) 153 | key, rng = random.split(rng) 154 | out_dims = random.randint(key, (), 1, 10) 155 | 156 | # Construct a random affine transformation. 157 | key, rng = random.split(rng) 158 | a_mat = jax.random.normal(key, [out_dims, in_dims]) 159 | key, rng = random.split(rng) 160 | b = jax.random.normal(key, [out_dims]) 161 | 162 | def fn(x): 163 | x_vec = x.reshape([-1, x.shape[-1]]) 164 | y_vec = jax.vmap(lambda z: math.matmul(a_mat, z))(x_vec) + b # pylint:disable=cell-var-from-loop 165 | y = y_vec.reshape(list(x.shape[:-1]) + [y_vec.shape[-1]]) 166 | return y 167 | 168 | # Apply the affine function to the Gaussians. 169 | fn_mean_true = fn(mean) 170 | fn_cov_true = math.matmul(math.matmul(a_mat, cov), a_mat.T) 171 | 172 | # Tracking the Gaussians through a linearized function of a linear 173 | # operator should be the same. 
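# (For an affine map f(x) = A x + b, first-order propagation is exact:
# the mean maps to A @ mean + b and the covariance to A @ cov @ A.T,
# which is exactly what fn_mean_true and fn_cov_true hold above.)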
174 | fn_mean, fn_cov = coord.track_linearize(fn, mean, cov) 175 | np.testing.assert_allclose(fn_mean, fn_mean_true, atol=1E-5, rtol=1E-5) 176 | np.testing.assert_allclose(fn_cov, fn_cov_true, atol=1e-5, rtol=1e-5) 177 | 178 | @parameterized.named_parameters(('reciprocal', jnp.reciprocal), 179 | ('log', jnp.log), ('sqrt', jnp.sqrt)) 180 | def test_construct_ray_warps_extents(self, fn): 181 | n = 100 182 | rng = random.PRNGKey(0) 183 | key, rng = random.split(rng) 184 | t_near = jnp.exp(jax.random.normal(key, [n])) 185 | key, rng = random.split(rng) 186 | t_far = t_near + jnp.exp(jax.random.normal(key, [n])) 187 | 188 | t_to_s, s_to_t = coord.construct_ray_warps(fn, t_near, t_far) 189 | 190 | np.testing.assert_allclose( 191 | t_to_s(t_near), jnp.zeros_like(t_near), atol=1E-5, rtol=1E-5) 192 | np.testing.assert_allclose( 193 | t_to_s(t_far), jnp.ones_like(t_far), atol=1E-5, rtol=1E-5) 194 | np.testing.assert_allclose( 195 | s_to_t(jnp.zeros_like(t_near)), t_near, atol=1E-5, rtol=1E-5) 196 | np.testing.assert_allclose( 197 | s_to_t(jnp.ones_like(t_near)), t_far, atol=1E-5, rtol=1E-5) 198 | 199 | def test_construct_ray_warps_special_reciprocal(self): 200 | """Test fn=1/x against its closed form.""" 201 | n = 100 202 | rng = random.PRNGKey(0) 203 | key, rng = random.split(rng) 204 | t_near = jnp.exp(jax.random.normal(key, [n])) 205 | key, rng = random.split(rng) 206 | t_far = t_near + jnp.exp(jax.random.normal(key, [n])) 207 | 208 | key, rng = random.split(rng) 209 | u = jax.random.uniform(key, [n]) 210 | t = t_near * (1 - u) + t_far * u 211 | key, rng = random.split(rng) 212 | s = jax.random.uniform(key, [n]) 213 | 214 | t_to_s, s_to_t = coord.construct_ray_warps(jnp.reciprocal, t_near, t_far) 215 | 216 | # Special cases for fn=reciprocal. 217 | s_to_t_ref = lambda s: 1 / (s / t_far + (1 - s) / t_near) 218 | t_to_s_ref = lambda t: (t_far * (t - t_near)) / (t * (t_far - t_near)) 219 | 220 | np.testing.assert_allclose(t_to_s(t), t_to_s_ref(t), atol=1E-5, rtol=1E-5) 221 | np.testing.assert_allclose(s_to_t(s), s_to_t_ref(s), atol=1E-5, rtol=1E-5) 222 | 223 | def test_expected_sin(self): 224 | normal_samples = random.normal(random.PRNGKey(0), (10000,)) 225 | for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]: 226 | sin_mu = coord.expected_sin(mu, var) 227 | x = jnp.sin(jnp.sqrt(var) * normal_samples + mu) 228 | np.testing.assert_allclose(sin_mu, jnp.mean(x), atol=1e-2) 229 | 230 | def test_integrated_pos_enc(self): 231 | num_dims = 2 # The number of input dimensions. 232 | min_deg = 0 # Must be 0 for this test to work. 233 | max_deg = 4 234 | num_samples = 100000 235 | rng = random.PRNGKey(0) 236 | for _ in range(5): 237 | # Generate a coordinate's mean and covariance matrix. 238 | key, rng = random.split(rng) 239 | mean = random.normal(key, (2,)) 240 | key, rng = random.split(rng) 241 | half_cov = jax.random.normal(key, [num_dims] * 2) 242 | cov = half_cov @ half_cov.T 243 | var = jnp.diag(cov) 244 | # Generate an IPE. 245 | enc = coord.integrated_pos_enc( 246 | mean, 247 | var, 248 | min_deg, 249 | max_deg, 250 | ) 251 | 252 | # Draw samples, encode them, and take their mean. 253 | key, rng = random.split(rng) 254 | samples = random.multivariate_normal(key, mean, cov, [num_samples]) 255 | assert min_deg == 0 256 | enc_samples = np.concatenate( 257 | [stable_pos_enc(x, max_deg) for x in tuple(samples.T)], axis=-1) 258 | # Correct for a different dimension ordering in stable_pos_enc. 
259 | enc_gt = jnp.mean(enc_samples, 0) 260 | enc_gt = enc_gt.reshape([num_dims, max_deg * 2]).T.reshape([-1]) 261 | np.testing.assert_allclose(enc, enc_gt, rtol=1e-2, atol=1e-2) 262 | 263 | 264 | if __name__ == '__main__': 265 | absltest.main() 266 | -------------------------------------------------------------------------------- /internal/vis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for visualizing things.""" 16 | 17 | import torch 18 | from internal import stepfun 19 | from matplotlib import cm 20 | from internal import math 21 | 22 | 23 | def weighted_percentile(x, weight, ps, assume_sorted=False): 24 | """Compute the weighted percentile(s) of a single vector.""" 25 | x = x.reshape([-1]) 26 | weight = weight.reshape([-1]) 27 | if not assume_sorted: 28 | sortidx = torch.argsort(x) 29 | x, weight = x[sortidx], weight[torch.remainder(sortidx, len(weight))] 30 | 31 | acc_w = torch.cumsum(weight, dim=0) 32 | ps = torch.tensor(ps, device=x.device) 33 | return math.interp(ps * acc_w[-1] / 100, acc_w, x) 34 | 35 | 36 | def sinebow(h): 37 | """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" 38 | def f(x): return torch.sin(torch.pi * x)**2 39 | return torch.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) 40 | 41 | 42 | def matte(vis, acc, dark=0.8, light=1.0, width=8): 43 | """Set non-accumulated pixels to a Photoshop-esque checker pattern.""" 44 | bg_mask = torch.logical_xor( 45 | (torch.arange(acc.shape[0]) % (2 * width) // width)[:, None], 46 | (torch.arange(acc.shape[1]) % (2 * width) // width)[None, :]) 47 | bg = torch.where(bg_mask, light, dark) 48 | return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None] 49 | 50 | 51 | def visualize_cmap(value, 52 | weight, 53 | colormap, 54 | lo=None, 55 | hi=None, 56 | percentile=99., 57 | curve_fn=lambda x: x, 58 | modulus=None, 59 | matte_background=True): 60 | """Visualize a 1D image and a 1D weighting according to some colormap. 61 | 62 | Args: 63 | value: A 1D image. 64 | weight: A weight map, in [0, 1]. 65 | colormap: A colormap function. 66 | lo: The lower bound to use when rendering, if None then use a percentile. 67 | hi: The upper bound to use when rendering, if None then use a percentile. 68 | percentile: What percentile of the value map to crop to when automatically 69 | generating `lo` and `hi`. Depends on `weight` as well as `value'. 70 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 71 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 72 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If 73 | `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. 74 | matte_background: If True, matte the image over a checkerboard. 75 | 76 | Returns: 77 | A colormap rendering. 
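  Example (a sketch, assuming an [H, W] depth map `d` and a weight map `w`
  of the same shape):
    vis = visualize_cmap(d, w, cm.get_cmap('turbo'),
                         curve_fn=lambda x: -torch.log(x + 1e-6))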
78 | """ 79 | 80 | 81 | # Identify the values that bound the middle of `value' according to `weight`. 82 | lo_auto, hi_auto = weighted_percentile( 83 | value, weight, [50 - percentile / 2, 50 + percentile / 2]) 84 | 85 | # If `lo` or `hi` are None, use the automatically-computed bounds above. 86 | eps = torch.tensor(torch.finfo(torch.float32).eps) 87 | lo = lo or (lo_auto - eps) 88 | hi = hi or (hi_auto + eps) 89 | 90 | # Curve all values. 91 | value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] 92 | 93 | # Wrap the values around if requested. 94 | if modulus: 95 | value = torch.mod(value, modulus) / modulus 96 | else: 97 | # Otherwise, just scale to [0, 1]. 98 | value = torch.nan_to_num( 99 | torch.clip((value - torch.min(lo, hi)) / torch.abs(hi - lo), 0, 1)) 100 | 101 | if colormap: 102 | colorized = torch.tensor(colormap(value.cpu())[:, :, :3], dtype=torch.float32) 103 | else: 104 | if len(value.shape) != 3: 105 | raise ValueError( 106 | f'value must have 3 dims but has {len(value.shape)}') 107 | if value.shape[-1] != 3: 108 | raise ValueError( 109 | f'value must have 3 channels but has {len(value.shape[-1])}') 110 | colorized = value 111 | 112 | return matte(colorized, weight) if matte_background else colorized 113 | 114 | 115 | def visualize_coord_mod(coords, acc): 116 | """Visualize the coordinate of each point within its "cell".""" 117 | return matte(((coords + 1) % 2) / 2, acc) 118 | 119 | 120 | def visualize_rays(dist, 121 | dist_range, 122 | weights, 123 | rgbs, 124 | accumulate=False, 125 | renormalize=False, 126 | resolution=2048, 127 | bg_color=0.8): 128 | """Visualize a bundle of rays.""" 129 | dist_vis = torch.linspace(*dist_range, resolution + 1) 130 | vis_rgb, vis_alpha = [], [] 131 | for ds, ws, rs in zip(dist, weights, rgbs): 132 | vis_rs, vis_ws = [], [] 133 | for d, w, r in zip(ds, ws, rs): 134 | if accumulate: 135 | # Produce the accumulated color and weight at each point along the ray. 136 | w_csum = torch.cumsum(w, dim=0) 137 | rw_csum = torch.cumsum((r * w[:, None]), dim=0) 138 | eps = torch.finfo(torch.float32).eps 139 | r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum 140 | vis_rs.append(stepfun.resample(dist_vis, d, r.T, use_avg=True).T) 141 | vis_ws.append(stepfun.resample(dist_vis, d, w.T, use_avg=True).T) 142 | vis_rgb.append(torch.stack(vis_rs)) 143 | vis_alpha.append(torch.stack(vis_ws)) 144 | vis_rgb = torch.stack(vis_rgb, dim=1) 145 | vis_alpha = torch.stack(vis_alpha, dim=1) 146 | 147 | if renormalize: 148 | # Scale the alphas so that the largest value is 1, for visualization. 149 | vis_alpha /= torch.max(torch.finfo(torch.float32).eps, 150 | torch.max(vis_alpha)) 151 | 152 | if resolution > vis_rgb.shape[0]: 153 | rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1) 154 | stride = rep * vis_rgb.shape[1] 155 | 156 | vis_rgb = torch.tile(vis_rgb, (1, 1, rep, 1)).reshape( 157 | (-1,) + vis_rgb.shape[2:]) 158 | vis_alpha = torch.tile(vis_alpha, (1, 1, rep)).reshape( 159 | (-1,) + vis_alpha.shape[2:]) 160 | 161 | # Add a strip of background pixels after each set of levels of rays. 162 | vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:]) 163 | vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:]) 164 | vis_rgb = torch.cat([vis_rgb, torch.zeros_like(vis_rgb[:, :1])], 165 | dim=1).reshape((-1,) + vis_rgb.shape[2:]) 166 | vis_alpha = torch.cat( 167 | [vis_alpha, torch.zeros_like(vis_alpha[:, :1])], 168 | dim=1).reshape((-1,) + vis_alpha.shape[2:]) 169 | 170 | # Matte the RGB image over the background. 
171 | vis = vis_rgb * vis_alpha[..., None] + \ 172 | (bg_color * (1 - vis_alpha))[..., None] 173 | 174 | # Remove the final row of background pixels. 175 | vis = vis[:-1] 176 | vis_alpha = vis_alpha[:-1] 177 | return vis, vis_alpha 178 | 179 | 180 | def visualize_suite(rendering, rays): 181 | """A wrapper around other visualizations for easy integration.""" 182 | 183 | def depth_curve_fn(x): 184 | return -torch.log(x + torch.tensor(torch.finfo(torch.float32).eps)) 185 | 186 | rgb = rendering['rgb'] 187 | acc = rendering['acc'] 188 | 189 | distance_mean = rendering['distance_mean'] 190 | distance_median = rendering['distance_median'] 191 | distance_p5 = rendering['distance_percentile_5'] 192 | distance_p95 = rendering['distance_percentile_95'] 193 | acc = torch.where(torch.isnan(distance_mean), torch.zeros_like(acc), acc) 194 | 195 | # The xyz coordinates where rays terminate. 196 | coords = rays.origins + rays.directions * distance_mean[:, :, None] 197 | 198 | vis_depth_mean, vis_depth_median = [ 199 | visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn) 200 | for x in [distance_mean, distance_median] 201 | ] 202 | 203 | # Render three depth percentiles directly to RGB channels, where the spacing 204 | # determines the color. delta == big change, epsilon = small change. 205 | # Gray: A strong discontinuitiy, [x-epsilon, x, x+epsilon] 206 | # Purple: A thin but even density, [x-delta, x, x+delta] 207 | # Red: A thin density, then a thick density, [x-delta, x, x+epsilon] 208 | # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta] 209 | depth_triplet = torch.stack([2 * distance_median - distance_p5, 210 | distance_median, distance_p95], dim=-1) 211 | vis_depth_triplet = visualize_cmap( 212 | depth_triplet, acc, None, 213 | curve_fn=lambda x: torch.log(x + torch.tensor(torch.finfo(torch.float32).eps))) 214 | 215 | dist = rendering['ray_sdist'] 216 | dist_range = (0, 1) 217 | weights = rendering['ray_weights'] 218 | rgbs = [torch.clip(r, 0, 1) for r in rendering['ray_rgbs']] 219 | 220 | vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs) 221 | 222 | sqrt_weights = [torch.sqrt(w) for w in weights] 223 | sqrt_ray_weights, ray_alpha = visualize_rays( 224 | dist, 225 | dist_range, 226 | [torch.ones_like(lw) for lw in sqrt_weights], 227 | [lw[..., None] for lw in sqrt_weights], 228 | bg_color=0, 229 | ) 230 | sqrt_ray_weights = sqrt_ray_weights[..., 0] 231 | # print(len(sqrt_weights), sqrt_weights[0].shape, len(sqrt_ray_weights), sqrt_ray_weights[0].shape) 232 | 233 | null_color = torch.tensor([1., 0., 0.]) 234 | vis_ray_weights_cmap = visualize_cmap( 235 | sqrt_ray_weights, 236 | torch.ones_like(sqrt_ray_weights), 237 | cm.get_cmap('gray'), 238 | lo=torch.tensor(0), 239 | hi=torch.tensor(1), 240 | matte_background=False, 241 | ) 242 | 243 | vis_ray_weights = torch.where( 244 | ray_alpha[:, :, None] == 0, 245 | null_color[None, None], 246 | vis_ray_weights_cmap 247 | ) 248 | 249 | vis = { 250 | 'color': rgb, 251 | 'acc': acc, 252 | 'color_matte': matte(rgb, acc), 253 | 'depth_mean': vis_depth_mean, 254 | 'depth_median': vis_depth_median, 255 | 'depth_triplet': vis_depth_triplet, 256 | 'coords_mod': visualize_coord_mod(coords, acc), 257 | 'ray_colors': vis_ray_colors, 258 | 'ray_weights': vis_ray_weights, 259 | } 260 | 261 | if 'rgb_cc' in rendering: 262 | vis['color_corrected'] = rendering['rgb_cc'] 263 | 264 | # Render every item named "normals*". 
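# (Unit normals have per-channel values in [-1, 1], so val / 2 + 0.5
# remaps them into [0, 1] for display before matting over the background.)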
265 | for key, val in rendering.items(): 266 | if key.startswith('normals'): 267 | vis[key] = matte(val / 2. + 0.5, acc) 268 | 269 | if 'roughness' in rendering: 270 | vis['roughness'] = matte(torch.tanh(rendering['roughness']), acc) 271 | if 'diffuse' in rendering: 272 | vis['diffuse'] = rendering['diffuse'] 273 | if 'specular' in rendering: 274 | vis['specular'] = rendering['specular'] 275 | if 'tint' in rendering: 276 | vis['tint'] = rendering['tint'] 277 | 278 | return vis 279 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2022 Georgios Kouros 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /internal/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Training step and model creation functions.""" 16 | 17 | import collections 18 | import functools 19 | from typing import Any, Callable, Dict, MutableMapping, Optional, Text, Tuple 20 | import torch 21 | 22 | from internal import camera_utils 23 | from internal import configs 24 | from internal import datasets 25 | from internal import image 26 | from internal import math 27 | from internal import models 28 | from internal import ref_utils 29 | from internal import stepfun 30 | from internal import utils 31 | 32 | 33 | def compute_data_loss(batch, renderings, rays, config): 34 | """Computes data loss terms for RGB, normal, and depth outputs.""" 35 | data_losses = [] 36 | stats = collections.defaultdict(lambda: []) 37 | 38 | # lossmult can be used to apply a weight to each ray in the batch. 39 | # For example: masking out rays, applying the Bayer mosaic mask, upweighting 40 | # rays from lower resolution images and so on. 41 | lossmult = rays.lossmult 42 | lossmult = torch.broadcast_to(lossmult, batch.rgb[..., :3].shape) 43 | if config.disable_multiscale_loss: 44 | lossmult = torch.ones_like(lossmult) 45 | for rendering in renderings: 46 | resid_sq = (rendering['rgb'] - torch.tensor(batch.rgb[..., :3]))**2 47 | denom = lossmult.sum() 48 | stats['mses'].append((lossmult * resid_sq).sum() / denom) 49 | 50 | if config.data_loss_type == 'mse': 51 | # Mean-squared error (L2) loss. 52 | data_loss = resid_sq 53 | elif config.data_loss_type == 'charb': 54 | # Charbonnier loss. 55 | data_loss = torch.sqrt(resid_sq + config.charb_padding**2) 56 | else: 57 | assert False 58 | data_losses.append((lossmult * data_loss).sum() / denom) 59 | 60 | if config.compute_disp_metrics: 61 | # Using mean to compute disparity, but other distance statistics can 62 | # be used instead. 63 | disp = 1 / (1 + rendering['distance_mean']) 64 | stats['disparity_mses'].append(((disp - batch.disps)**2).mean()) 65 | 66 | if config.compute_normal_metrics: 67 | if 'normals' in rendering: 68 | weights = rendering['acc'] * batch.alphas 69 | normalized_normals_gt = ref_utils.l2_normalize(batch.normals) 70 | normalized_normals = ref_utils.l2_normalize( 71 | rendering['normals']) 72 | normal_mae = ref_utils.compute_weighted_mae(weights, normalized_normals, 73 | normalized_normals_gt) 74 | else: 75 | # If normals are not computed, set MAE to NaN. 76 | normal_mae = torch.nan 77 | 78 | stats['normal_maes'].append(normal_mae) 79 | 80 | data_losses = torch.stack(data_losses) 81 | loss = \ 82 | config.data_coarse_loss_mult * torch.sum(data_losses[:-1]) + \ 83 | config.data_loss_mult * data_losses[-1] 84 | stats = {k: torch.tensor(stats[k]) for k in stats} 85 | return loss, stats 86 | 87 | 88 | def interlevel_loss(ray_history, config): 89 | """Computes the interlevel loss defined in mip-NeRF 360.""" 90 | # Stop the gradient from the interlevel loss onto the NeRF MLP. 91 | last_ray_results = ray_history[-1] 92 | c = last_ray_results['sdist'].detach() 93 | w = last_ray_results['weights'].detach() 94 | loss_interlevel = 0. 
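# Each proposal level below is penalized (via stepfun.lossfun_outer)
# wherever its histogram (cp, wp) fails to bound the final NeRF level's
# histogram (c, w); because c and w are detached above, gradients flow
# only into the proposal MLPs.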
95 | for ray_results in ray_history[:-1]: 96 | cp = ray_results['sdist'] 97 | wp = ray_results['weights'] 98 | loss_interlevel += torch.mean(stepfun.lossfun_outer(c, w, cp, wp)) 99 | return config.interlevel_loss_mult * loss_interlevel 100 | 101 | 102 | def orientation_loss(rays, model, ray_history, config): 103 | """Computes the orientation loss regularizer defined in ref-NeRF.""" 104 | total_loss = 0. 105 | zero = torch.tensor(0.0, dtype=torch.float32) 106 | for i, ray_results in enumerate(ray_history): 107 | w = ray_results['weights'] 108 | n = ray_results[config.orientation_loss_target] 109 | if n is None: 110 | raise ValueError( 111 | 'Normals cannot be None if orientation loss is on.') 112 | # Negate viewdirs to represent normalized vectors from point to camera. 113 | v = -rays.viewdirs 114 | n_dot_v = (n * v[..., None, :]).sum(dim=-1) 115 | loss = torch.mean((w * torch.minimum(zero, n_dot_v)**2).sum(dim=-1)) 116 | if i < model.num_levels - 1: 117 | total_loss += config.orientation_coarse_loss_mult * loss 118 | else: 119 | total_loss += config.orientation_loss_mult * loss 120 | return total_loss 121 | 122 | 123 | def predicted_normal_loss(model, ray_history, config): 124 | """Computes the predicted normal supervision loss defined in ref-NeRF.""" 125 | total_loss = 0. 126 | for i, ray_results in enumerate(ray_history): 127 | w = ray_results['weights'] 128 | n = ray_results['normals'] 129 | n_pred = ray_results['normals_pred'] 130 | if n is None or n_pred is None: 131 | raise ValueError( 132 | 'Predicted normals and gradient normals cannot be None if ' 133 | 'predicted normal loss is on.') 134 | loss = torch.mean( 135 | (w * (1.0 - torch.sum(n * n_pred, dim=-1))).sum(dim=-1)) 136 | if i < model.num_levels - 1: 137 | total_loss += config.predicted_normal_coarse_loss_mult * loss 138 | else: 139 | total_loss += config.predicted_normal_loss_mult * loss 140 | return total_loss 141 | 142 | 143 | def create_train_step(model: models.Model, 144 | config: configs.Config, 145 | dataset: Optional[datasets.Dataset] = None): 146 | """Creates the pmap'ed Nerf training function. 147 | 148 | Args: 149 | model: The linen model. 150 | config: The configuration. 151 | dataset: Training dataset. 152 | 153 | Returns: 154 | training function. 155 | """ 156 | if dataset is None: 157 | camtype = camera_utils.ProjectionType.PERSPECTIVE 158 | else: 159 | camtype = dataset.camtype 160 | 161 | def train_step( 162 | model, 163 | optimizer, 164 | lr_scheduler, 165 | batch, 166 | cameras, 167 | train_frac, 168 | ): 169 | """One optimization step. 170 | 171 | Args: 172 | state: TrainState, state of the model/optimizer. 173 | batch: dict, a mini-batch of data for training. 174 | cameras: module containing camera poses. 175 | train_frac: float, the fraction of training that is complete. 176 | 177 | Returns: 178 | A tuple (new_state, stats) with 179 | new_state: TrainState, new training state. 180 | stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)]. 
181 | """ 182 | rays = batch.rays 183 | if config.cast_rays_in_train_step: 184 | rays = camera_utils.cast_ray_batch( 185 | cameras, rays, camtype, xnp=torch).to(device) 186 | else: 187 | rays.to(model.device) 188 | 189 | # clear gradients 190 | optimizer.zero_grad() 191 | 192 | renderings, ray_history = model( 193 | rays, 194 | train_frac=train_frac, 195 | compute_extras=\ 196 | config.compute_disp_metrics or config.compute_normal_metrics) 197 | 198 | losses = {} 199 | 200 | # calculate photometric error 201 | data_loss, stats = compute_data_loss(batch, renderings, rays, config) 202 | losses['data'] = data_loss 203 | 204 | # calculate interlevel loss 205 | if config.interlevel_loss_mult > 0: 206 | losses['interlevel'] = interlevel_loss(ray_history, config) 207 | 208 | # calculate normals orientation loss 209 | if (config.orientation_coarse_loss_mult > 0 or 210 | config.orientation_loss_mult > 0): 211 | losses['orientation'] = orientation_loss( 212 | rays, model, ray_history, config) 213 | 214 | # calculate predicted normal loss 215 | if (config.predicted_normal_coarse_loss_mult > 0 or 216 | config.predicted_normal_loss_mult > 0): 217 | losses['predicted_normals'] = predicted_normal_loss( 218 | model, ray_history, config) 219 | 220 | params = dict(model.named_parameters()) 221 | stats['weights_l2s'] = {k.replace('.', '/') : params[k].detach().norm() ** 2 for k in params} 222 | 223 | # calculate total loss 224 | loss = torch.sum(torch.stack(list(losses.values()))) 225 | stats['loss'] = loss.detach().cpu() 226 | stats['losses'] = {key: losses[key].detach().cpu() for key in losses} 227 | 228 | # backprop 229 | loss.backward() 230 | 231 | # calculate average grad and stats 232 | stats['grad_norms'] = {k.replace('.', '/') : params[k].grad.detach().cpu().norm() for k in params} 233 | stats['grad_maxes'] = {k.replace('.', '/') : params[k].grad.detach().cpu().abs().max() for k in params} 234 | 235 | # Clip gradients 236 | if config.grad_max_val > 0: 237 | torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=config.grad_max_val) 238 | if config.grad_max_norm > 0: 239 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.grad_max_norm) 240 | 241 | #TODO: set nan grads to 0 242 | 243 | # update the model weights 244 | optimizer.step() 245 | 246 | # update learning rate 247 | lr_scheduler.step() 248 | 249 | #TODO: difference between previous and current state - Redundant? 
250 | # stats['opt_update_norms'] = summarize_tree(opt_delta, tree_norm) 251 | # stats['opt_update_maxes'] = summarize_tree(opt_delta, tree_abs_max) 252 | 253 | # Calculate PSNR metric 254 | stats['psnrs'] = image.mse_to_psnr(stats['mses']) 255 | stats['psnr'] = stats['psnrs'][-1] 256 | 257 | # return new state and statistics 258 | return stats 259 | 260 | return train_step 261 | 262 | 263 | def create_optimizer( 264 | config: configs.Config, 265 | params: Dict) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]: 266 | """Creates optimizer for model training.""" 267 | adam_kwargs = { 268 | 'lr': config.lr_init, 269 | 'betas': (config.adam_beta1, config.adam_beta2), 270 | 'eps': config.adam_eps, 271 | } 272 | lr_kwargs = { 273 | 'lr_init': config.lr_init, 274 | 'lr_final': config.lr_final, 275 | 'max_steps': config.max_steps, 276 | 'lr_delay_steps': config.lr_delay_steps, 277 | 'lr_delay_mult': config.lr_delay_mult, 278 | } 279 | optimizer = torch.optim.Adam(params=params, **adam_kwargs) 280 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 281 | optimizer, functools.partial(math.learning_rate_decay, **lr_kwargs)) 282 | return optimizer, lr_scheduler 283 | 284 | 285 | def create_render_fn(model: models.Model): 286 | """Creates a function for full image rendering.""" 287 | def render_eval_fn(train_frac, rays): 288 | return model( 289 | rays, 290 | train_frac=train_frac, 291 | compute_extras=True) 292 | return render_eval_fn 293 | 294 | 295 | def setup_model( 296 | config: configs.Config, 297 | dataset: Optional[datasets.Dataset] = None, 298 | ): 299 | """Creates NeRF model, optimizer, and pmap-ed train/render functions.""" 300 | 301 | dummy_rays = utils.dummy_rays() 302 | model = models.construct_model(dummy_rays, config) 303 | 304 | optimizer, lr_scheduler = create_optimizer(config, model.parameters()) 305 | render_eval_fn = create_render_fn(model) 306 | train_step = create_train_step(model, config, dataset=dataset) 307 | 308 | return model, optimizer, lr_scheduler, render_eval_fn, train_step 309 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
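A minimal sketch of how the helpers above are intended to be wired together (hypothetical `config`, `dataset`, and `cameras` objects; the real loop in train.py adds logging, checkpointing, and periodic evaluation):

setup = train_utils.setup_model(config, dataset=dataset)
model, optimizer, lr_scheduler, render_eval_fn, train_step = setup

for step in range(config.max_steps):
    batch = next(dataset)
    stats = train_step(model, optimizer, lr_scheduler, batch, cameras,
                       train_frac=step / config.max_steps)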
14 | 15 | """Evaluation script.""" 16 | 17 | import functools 18 | import os 19 | from os import path 20 | import sys 21 | import time 22 | from absl import app, flags 23 | import torch 24 | from torch.utils.tensorboard import SummaryWriter 25 | import gin 26 | from internal import configs 27 | from internal import datasets 28 | from internal import image 29 | from internal import models 30 | from internal import ref_utils 31 | from internal import train_utils 32 | from internal import utils 33 | from internal import vis 34 | import numpy as np 35 | 36 | 37 | configs.define_common_flags() 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | def main(unused_argv): 42 | config = configs.load_config(save_config=False) 43 | 44 | # setup device 45 | if torch.cuda.is_available(): 46 | device = torch.device('cuda') 47 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 48 | else: 49 | device = torch.device('cpu') 50 | torch.set_default_tensor_type('torch.FloatTensor') 51 | 52 | dataset = datasets.load_dataset('test', config.data_dir, config) 53 | setup = train_utils.setup_model(config, dataset=dataset) 54 | model, _, _, render_eval_fn, _ = setup 55 | model.eval() 56 | state = dict(step=0, model=model.state_dict()) 57 | 58 | cc_fun = image.color_correct 59 | 60 | metric_harness = image.MetricHarness() 61 | 62 | last_step = 0 63 | out_dir = os.path.join(config.checkpoint_dir, 64 | 'path_renders' if config.render_path else 'test_preds') 65 | 66 | def path_fn(x): return os.path.join(out_dir, x) 67 | 68 | if not config.eval_only_once: 69 | summary_writer = SummaryWriter( 70 | os.path.join(config.checkpoint_dir, 'eval')) 71 | 72 | while True: 73 | 74 | # load checkpoint from file if it exists 75 | files = sorted([f for f in os.listdir(config.checkpoint_dir) 76 | if f.startswith('checkpoint')], key=lambda x: int(x.split('_')[-1])) 77 | # if there are checkpoints in the dir, load the latest checkpoint 78 | if not files: 79 | print(f'No checkpoints yet. Sleeping.') 80 | time.sleep(10) 81 | continue 82 | 83 | # reload state 84 | checkpoint_name = files[-1] 85 | state = torch.load(os.path.join( 86 | config.checkpoint_dir, checkpoint_name)) 87 | model.load_state_dict(state['model']) 88 | 89 | step = int(state['step']) 90 | 91 | if step <= last_step: 92 | print( 93 | f'Checkpoint step {step} <= last step {last_step}, sleeping.') 94 | time.sleep(10) 95 | continue 96 | print(f'Evaluating checkpoint at step {step}.') 97 | if config.eval_save_output and (not utils.isdir(out_dir)): 98 | utils.makedirs(out_dir) 99 | 100 | num_eval = min(dataset.size, config.eval_dataset_limit) 101 | perm = torch.randperm(num_eval) 102 | showcase_indices = torch.sort(perm[:config.num_showcase_images]) 103 | 104 | metrics = [] 105 | metrics_cc = [] 106 | showcases = [] 107 | render_times = [] 108 | 109 | # render and evaluate all test images 110 | for idx in range(dataset.size): 111 | eval_start_time = time.time() 112 | batch = next(dataset) 113 | if idx >= num_eval: 114 | print(f'Skipping image {idx+1}/{dataset.size}') 115 | continue 116 | print(f'Evaluating image {idx+1}/{dataset.size}') 117 | rays = batch.rays 118 | train_frac = state['step'] / config.max_steps 119 | with torch.no_grad(): 120 | rendering = models.render_image( 121 | functools.partial(render_eval_fn, train_frac), rays, config) 122 | 123 | render_times.append((time.time() - eval_start_time)) 124 | print(f'Rendered in {render_times[-1]:0.3f}s') 125 | 126 | # Cast to 64-bit to ensure high precision for color correction function. 
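# (color_correct solves a per-channel least-squares system with
# torch.linalg.lstsq in internal/image.py; doing the solve in float64
# reduces round-off when the feature matrix is poorly conditioned.)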
127 | gt_rgb = torch.tensor( 128 | batch.rgb, dtype=torch.float64, device=torch.device('cpu')) 129 | 130 | # move renderings to cpu to allow for metrics calculations 131 | rendering = {k: v.cpu().double() for k, v in rendering.items() if not k.startswith('ray_')} 132 | 133 | cc_start_time = time.time() 134 | rendering['rgb_cc'] = cc_fun(rendering['rgb'], gt_rgb) 135 | print(f'Color corrected in {(time.time() - cc_start_time):0.3f}s') 136 | 137 | if not config.eval_only_once and idx in showcase_indices: 138 | showcase_idx = idx if config.deterministic_showcase else len( 139 | showcases) 140 | showcases.append((showcase_idx, rendering, batch)) 141 | if not config.render_path: 142 | rgb = rendering['rgb'] 143 | rgb_cc = rendering['rgb_cc'] 144 | rgb_gt = gt_rgb 145 | 146 | if config.eval_quantize_metrics: 147 | # Ensures that the images written to disk reproduce the metrics. 148 | rgb = np.round(rgb * 255) / 255 149 | rgb_cc = np.round(rgb_cc * 255) / 255 150 | 151 | if config.eval_crop_borders > 0: 152 | def crop_fn( 153 | x, c=config.eval_crop_borders): return x[c:-c, c:-c] 154 | rgb = crop_fn(rgb) 155 | rgb_cc = crop_fn(rgb_cc) 156 | rgb_gt = crop_fn(rgb_gt) 157 | 158 | # calculate PSNR and SSIM metrics between rendering and gt 159 | metric = metric_harness(rgb, rgb_gt) 160 | metric_cc = metric_harness(rgb_cc, rgb_gt) 161 | 162 | if config.compute_disp_metrics: 163 | for tag in ['mean', 'median']: 164 | key = f'distance_{tag}' 165 | if key in rendering: 166 | disparity = 1 / (1 + rendering[key]) 167 | metric[f'disparity_{tag}_mse'] = float( 168 | ((disparity - batch.disps)**2).mean()) 169 | 170 | if config.compute_normal_metrics: 171 | weights = rendering['acc'] * batch.alphas 172 | normalized_normals_gt = ref_utils.l2_normalize( 173 | batch.normals) 174 | for key, val in rendering.items(): 175 | if key.startswith('normals') and val is not None: 176 | normalized_normals = ref_utils.l2_normalize(val) 177 | metric[key + '_mae'] = ref_utils.compute_weighted_mae( 178 | weights, normalized_normals, normalized_normals_gt) 179 | 180 | for m, v in metric.items(): 181 | print(f'{m:30s} = {v:.4f}') 182 | 183 | metrics.append(metric) 184 | metrics_cc.append(metric_cc) 185 | 186 | if config.eval_save_output and (config.eval_render_interval > 0): 187 | if (idx % config.eval_render_interval) == 0: 188 | utils.save_img_u8(rendering['rgb'], 189 | path_fn(f'color_{idx:03d}.png')) 190 | utils.save_img_u8(rendering['rgb_cc'], 191 | path_fn(f'color_cc_{idx:03d}.png')) 192 | 193 | for key in ['distance_mean', 'distance_median']: 194 | if key in rendering: 195 | utils.save_img_f32(rendering[key], 196 | path_fn(f'{key}_{idx:03d}.tiff')) 197 | 198 | for key in ['normals']: 199 | if key in rendering: 200 | utils.save_img_u8(rendering[key] / 2. 
+ 0.5,
201 |                                       path_fn(f'{key}_{idx:03d}.png'))
202 | 
203 |                     utils.save_img_f32(
204 |                         rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))
205 | 
206 |         if not config.eval_only_once:
207 |             summary_writer.add_scalar(
208 |                 'eval_median_render_time', np.median(render_times), step)
209 |             for name in metrics[0]:
210 |                 scores = [m[name] for m in metrics]
211 |                 summary_writer.add_scalar(
212 |                     'eval_metrics/' + name, np.mean(scores), step)
213 |                 summary_writer.add_histogram(
214 |                     'eval_metrics/' + 'perimage_' + name, np.array(scores), step)
215 |             for name in metrics_cc[0]:
216 |                 scores = [m[name] for m in metrics_cc]
217 |                 summary_writer.add_scalar(
218 |                     'eval_metrics_cc/' + name, np.mean(scores), step)
219 |                 summary_writer.add_histogram(
220 |                     'eval_metrics_cc/' + 'perimage_' + name, np.array(scores), step)
221 | 
222 |             for i, r, b in showcases:
223 |                 if config.vis_decimate > 1:
224 |                     d = config.vis_decimate
225 | 
226 |                     def decimate_fn(x, d=d):
227 |                         return None if x is None else x[::d, ::d]
228 |                 else:
229 |                     def decimate_fn(x): return x
230 |                 r = {k: decimate_fn(v) for k, v in r.items()}
231 |                 # The batch is not an array, so its fields are decimated where used below.
232 |                 visualizations = vis.visualize_suite(r, b.rays)
233 |                 for k, v in visualizations.items():
234 |                     summary_writer.add_image(f'output_{k}_{i}', v, step, dataformats='HWC' if len(v.shape) == 3 else 'HW')
235 |                 if not config.render_path:
236 |                     target = decimate_fn(b.rgb)
237 |                     summary_writer.add_image(f'true_color_{i}', target, step, dataformats='HWC')
238 |                     pred = visualizations['color']
239 |                     residual = np.clip(pred - target + 0.5, 0, 1)
240 |                     summary_writer.add_image(f'true_residual_{i}', residual, step, dataformats='HWC')
241 |                     if config.compute_normal_metrics:
242 |                         summary_writer.add_image(f'true_normals_{i}', decimate_fn(b.normals) / 2. + 0.5,
243 |                                                  step, dataformats='HWC')
244 | 
245 |         if (config.eval_save_output and not config.render_path):
246 |             with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
247 |                 f.write(' '.join([str(r) for r in render_times]))
248 |             for name in metrics[0]:
249 |                 with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
250 |                     f.write(' '.join([str(m[name]) for m in metrics]))
251 |             for name in metrics_cc[0]:
252 |                 with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f:
253 |                     f.write(' '.join([str(m[name]) for m in metrics_cc]))
254 |             if config.eval_save_ray_data:
255 |                 for i, r, b in showcases:
256 |                     rays = {k: v for k, v in r.items() if 'ray_' in k}
257 |                     np.set_printoptions(threshold=sys.maxsize)
258 |                     with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:
259 |                         f.write(repr(rays))
260 | 
261 |             with utils.open_file(path_fn(f'avg_metrics_{step}.txt'), 'w') as f:
262 |                 f.write(f'render_time: {np.mean(render_times)}\n')
263 |                 for name in metrics[0]:
264 |                     f.write(f'{name}: {np.mean([m[name] for m in metrics])}\n')
265 |                 for name in metrics_cc[0]:
266 |                     f.write(
267 |                         f'cc_{name}: {np.mean([m[name] for m in metrics_cc])}\n')
268 | 
269 |         if config.eval_only_once:
270 |             break
271 |         if config.early_exit_steps is not None:
272 |             num_steps = config.early_exit_steps
273 |         else:
274 |             num_steps = config.max_steps
275 |         if int(step) >= num_steps:
276 |             break
277 |         last_step = step
278 | 
279 | 
280 | if __name__ == '__main__':
281 |     with gin.config_scope('eval'):
282 |         app.run(main)
283 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Training script."""
16 | 
17 | import functools
18 | import os
19 | import sys
20 | import gc
21 | import time
22 | import numpy as np
23 | import random
24 | import torch
25 | import flatdict
26 | import logging.config
27 | from absl import flags
28 | import absl
29 | 
30 | from torch.utils.tensorboard import SummaryWriter
31 | from absl import app
32 | import gin.torch
33 | from internal import configs
34 | from internal import datasets
35 | from internal import image
36 | from internal import models
37 | from internal import train_utils
38 | from internal import utils
39 | from internal import vis
40 | 
41 | configs.define_common_flags()
42 | FLAGS = flags.FLAGS
43 | TIME_PRECISION = 1000  # Internally represent integer times in milliseconds.
44 | 
45 | 
46 | def main(unused_argv):
47 |     # load config file and save params to checkpoint folder
48 |     config = configs.load_config()
49 | 
50 |     # setup device
51 |     if torch.cuda.is_available():
52 |         device = torch.device('cuda')
53 |         torch.set_default_tensor_type('torch.cuda.FloatTensor')
54 |     else:
55 |         device = torch.device('cpu')
56 |         torch.set_default_tensor_type('torch.FloatTensor')
57 | 
58 |     # set random seeds for reproducibility
59 |     torch.manual_seed(20221216)
60 |     np.random.seed(20221216)
61 | 
62 |     # load training and test sets
63 |     dataset = datasets.load_dataset('train', config.data_dir, config)
64 |     test_dataset = datasets.load_dataset('test', config.data_dir, config)
65 | 
66 |     # create model, state, rendering evaluation function, training step, and lr scheduler
67 |     setup = train_utils.setup_model(config, dataset=dataset)
68 |     model, optimizer, lr_scheduler, render_eval_fn, train_step = setup
69 |     state = dict(
70 |         step=0,
71 |         model=model.state_dict(),
72 |         optimizer=optimizer.state_dict(),  # key must match the checkpoints saved below
73 |         lr_scheduler=lr_scheduler.state_dict(),
74 |     )
75 | 
76 |     # create object for calculating metrics
77 |     metric_harness = image.MetricHarness()
78 | 
79 |     # load saved checkpoint or create checkpoint dir if not there
80 |     if utils.isdir(config.checkpoint_dir):
81 |         files = sorted([f for f in os.listdir(config.checkpoint_dir)
82 |                         if f.startswith('checkpoint')], key=lambda x: int(x.split('_')[-1]))
83 |         # if there are checkpoints in the dir, load the latest checkpoint
84 |         if files:
85 |             checkpoint_name = files[-1]
86 |             state = torch.load(os.path.join(config.checkpoint_dir, checkpoint_name))
87 |             model.load_state_dict(state['model'])
88 |             optimizer.load_state_dict(state['optimizer'])
89 |             lr_scheduler.load_state_dict(state['lr_scheduler'])
90 |     else:
91 |         utils.makedirs(config.checkpoint_dir)
92 | 
93 |     # setup logging to file
94 |     logfile = os.path.join(config.checkpoint_dir, 'output.log')
95 |     logging.getLogger().handlers = []
96 |     logging.basicConfig(
97 |         level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s",
98 |         handlers=[logging.FileHandler(logfile), logging.StreamHandler(sys.stdout)])
99 | 
100 |     # print the number of parameters of the model
101 |     num_params = sum(p.numel() for p in model.parameters())
102 |     logging.info(f'Number of parameters being optimized: {num_params}')
103 | 
104 |     # Resume training at the step of the last checkpoint.
105 |     init_step = state['step'] + 1
106 | 
107 |     # setup tensorboard for logging
108 |     summary_writer = SummaryWriter(config.checkpoint_dir)
109 | 
110 |     # The dataset prefetch queue holds 3 batches (see internal/datasets.py).
111 |     # gc.disable()  # Disable automatic garbage collection for efficiency.
112 |     total_time = 0
113 |     total_steps = 0
114 |     reset_stats = True
115 |     if config.early_exit_steps is not None:
116 |         num_steps = config.early_exit_steps
117 |     else:
118 |         num_steps = config.max_steps
119 | 
120 |     # set model to training mode and send to device
121 |     model.to(device)
122 | 
123 |     # start training loop
124 |     for step, batch in zip(range(init_step, num_steps + 1), dataset):
125 | 
126 |         model.train()
127 | 
128 |         # clear stats for this iteration
129 |         if reset_stats:
130 |             stats_buffer = []
131 |             train_start_time = time.time()
132 |             reset_stats = False
133 | 
134 |         # update fraction of completed training
135 |         train_frac = np.clip((step - 1) / (config.max_steps - 1), 0, 1)
136 | 
137 |         # perform training step
138 |         stats = train_step(
139 |             model,
140 |             optimizer,
141 |             lr_scheduler,
142 |             batch,
143 |             dataset.cameras,
144 |             train_frac,
145 |         )
146 | 
147 |         # If automatic garbage collection is disabled above, collect manually:
148 |         # if step % config.gc_every == 0:
149 |         #     gc.collect()
150 | 
151 |         # TODO: check whether this explicit cleanup is redundant.
152 |         del batch
153 |         gc.collect()
154 |         torch.cuda.empty_cache()
155 | 
156 |         # set model to inference mode
157 |         model.eval()
158 | 
159 |         with torch.no_grad():
160 | 
161 |             # Log training summaries
162 |             stats_buffer.append(stats)
163 | 
164 |             if step == init_step or step % config.print_every == 0:
165 | 
166 |                 elapsed_time = time.time() - train_start_time
167 |                 steps_per_sec = config.print_every / elapsed_time
168 |                 rays_per_sec = config.batch_size * steps_per_sec
169 | 
170 |                 # A robust approximation of total training time, in case of pre-emption.
171 |                 total_time += int(round(TIME_PRECISION * elapsed_time))
172 |                 total_steps += config.print_every
173 |                 approx_total_time = int(round(step * total_time / total_steps))
174 | 
175 |                 # Stack stats_buffer along axis 0.
176 |                 fs = [dict(flatdict.FlatDict(s, delimiter='/')) for s in stats_buffer]
177 |                 stats_stacked = {k: torch.stack([f[k] for f in fs]) for k in fs[0].keys()}
178 | 
179 |                 # Split every statistic that isn't a vector into a set of statistics.
180 |                 stats_split = {}
181 |                 for k, v in stats_stacked.items():
182 |                     if v.ndim not in [1, 2] or v.shape[0] != len(stats_buffer):  # reject anything that is not [n] or [n, k]
183 |                         raise ValueError('statistics must be of size [n], or [n, k].')
184 |                     if v.ndim == 1:
185 |                         stats_split[k] = v
186 |                     elif v.ndim == 2:
187 |                         for i, vi in enumerate(tuple(v.T)):
188 |                             stats_split[f'{k}/{i}'] = vi
189 | 
190 |                 # Summarize the entire histogram of each statistic.
191 |                 for k, v in stats_split.items():
192 |                     summary_writer.add_histogram('train_' + k, v, step)
193 | 
194 |                 # Take the mean and max of each statistic since the last summary.
195 |                 avg_stats = {k: torch.mean(v) for k, v in stats_split.items()}
196 |                 max_stats = {k: torch.max(v) for k, v in stats_split.items()}
197 | 
198 |                 # Summarize the mean and max of each statistic.
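                # (The summary window is everything accumulated in stats_buffer
                # since the last summary, i.e. the last config.print_every steps.)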
199 |                 for k, v in avg_stats.items():
200 |                     summary_writer.add_scalar(f'train_avg_{k}', v, step)
201 |                 for k, v in max_stats.items():
202 |                     summary_writer.add_scalar(f'train_max_{k}', v, step)
203 |                 summary_writer.add_scalar('train_num_params', num_params, step)
204 |                 summary_writer.add_scalar('train_learning_rate', lr_scheduler.get_last_lr()[0], step)
205 | 
206 |                 summary_writer.add_scalar('train_steps_per_sec', steps_per_sec, step)
207 |                 summary_writer.add_scalar('train_rays_per_sec', rays_per_sec, step)
208 |                 summary_writer.add_scalar('train_avg_psnr_timed', avg_stats['psnr'],
209 |                                           total_time // TIME_PRECISION)
210 |                 summary_writer.add_scalar('train_avg_psnr_timed_approx', avg_stats['psnr'],
211 |                                           approx_total_time // TIME_PRECISION)
212 |                 precision = int(np.ceil(np.log10(config.max_steps))) + 1
213 |                 avg_loss = avg_stats['loss']
214 |                 avg_psnr = avg_stats['psnr']
215 |                 str_losses = {  # Grab each "losses_{x}" field and print it as "x[:4]".
216 |                     k[7:11]: (f'{v:0.5f}' if v >= 1e-4 and v < 10 else f'{v:0.1e}')
217 |                     for k, v in avg_stats.items()
218 |                     if k.startswith('losses/')
219 |                 }
220 |                 logging.info(f'{step:{precision}d}' + f'/{config.max_steps:d}: ' +
221 |                              f'loss={avg_loss:0.5f}, ' + f'psnr={avg_psnr:6.3f}, ' +
222 |                              f'lr={lr_scheduler.get_last_lr()[0]:0.2e} | ' +
223 |                              ', '.join([f'{k}={s}' for k, s in str_losses.items()]) +
224 |                              f', {rays_per_sec:0.0f} r/s')
225 | 
226 |                 # Reset everything we are tracking between summarizations.
227 |                 reset_stats = True
228 | 
229 |             # Save a checkpoint on the first step and every Nth step thereafter.
230 |             if step == 1 or step % config.checkpoint_every == 0:
231 |                 # Save checkpoint.
232 |                 state = dict(
233 |                     step=step,
234 |                     model=model.state_dict(),
235 |                     optimizer=optimizer.state_dict(),
236 |                     lr_scheduler=lr_scheduler.state_dict())
237 |                 torch.save(state, os.path.join(config.checkpoint_dir, f'checkpoint_{step}'))
238 | 
239 |             # Test-set evaluation.
240 |             if config.train_render_every > 0 and step % config.train_render_every == 0:
241 |                 # We reuse the same random number generator from the optimization step
242 |                 # here on purpose so that the visualization matches what happened in
243 |                 # training.
244 |                 eval_start_time = time.time()
245 |                 test_case = next(test_dataset)
246 |                 test_case.rays.to(device)
247 | 
248 |                 # Render test image.
249 |                 rendering = models.render_image(
250 |                     functools.partial(render_eval_fn, train_frac),
251 |                     test_case.rays, config)
252 | 
253 |                 # Log eval summaries.
254 |                 eval_time = time.time() - eval_start_time
255 |                 num_rays = np.prod(np.array(test_case.rays.directions.shape[:-1]))
256 |                 rays_per_sec = num_rays / eval_time
257 |                 summary_writer.add_scalar('test_rays_per_sec', rays_per_sec, step)
258 |                 logging.info(f'Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec')
259 | 
260 |                 # Compute metrics.
261 |                 if config.compute_eval_metrics:
262 |                     metric_start_time = time.time()
263 |                     metric = metric_harness(rendering['rgb'], test_case.rgb)
264 |                     logging.info(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')
265 |                     for name, val in metric.items():
266 |                         if not np.isnan(val):
267 |                             logging.info(f'{name} = {val:.4f}')
268 |                             summary_writer.add_scalar(
269 |                                 'train_metrics/' + name, val, step)
270 | 
271 |                 # Log images to tensorboard.
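                # (SummaryWriter.add_image expects CHW by default, so
                # dataformats='HWC' is passed below for the channel-last images
                # produced by vis.visualize_suite.)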
272 |                 vis_start_time = time.time()
273 |                 vis_suite = vis.visualize_suite(rendering, test_case.rays)
274 |                 logging.info(f'Visualized in {(time.time() - vis_start_time):0.3f}s')
275 |                 summary_writer.add_image(
276 |                     'test_true_color', test_case.rgb, step, dataformats='HWC')
277 |                 if config.compute_normal_metrics:
278 |                     summary_writer.add_image(
279 |                         'test_true_normals', test_case.normals / 2. + 0.5, step,
280 |                         dataformats='HWC')
281 |                 for k, v in vis_suite.items():
282 |                     summary_writer.add_image(
283 |                         'test_output_' + k, v, step,
284 |                         dataformats='HWC' if len(v.shape) == 3 else 'HW')
285 | 
286 |     # save last checkpoint if it wasn't already saved.
287 |     if config.max_steps % config.checkpoint_every != 0:
288 |         state = dict(
289 |             step=config.max_steps,
290 |             model=model.state_dict(),
291 |             optimizer=optimizer.state_dict(),
292 |             lr_scheduler=lr_scheduler.state_dict())
293 |         torch.save(state, os.path.join(
294 |             config.checkpoint_dir, f'checkpoint_{config.max_steps}'))
295 | 
296 | 
297 | if __name__ == '__main__':
298 |     with gin.config_scope('train'):
299 |         FLAGS(sys.argv)
300 |         main(sys.argv)
301 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # refnerf-pytorch: An unofficial port of the JAX-based Ref-NeRF code release to PyTorch
2 | 
3 | You can find the original JAX version released by Google at [https://github.com/google-research/multinerf](https://github.com/google-research/multinerf).
4 | 
5 | *This is not an officially supported Google product.*
6 | 
7 | This repository contains the code release for the CVPR 2022 Ref-NeRF paper:
8 | [Ref-NeRF](https://dorverbin.github.io/refnerf/).
9 | This codebase was adapted from the [multinerf](https://github.com/google/multinerf)
10 | code release that combines the mip-NeRF 360, Raw-NeRF, and Ref-NeRF papers from CVPR 2022.
11 | The original release of Ref-NeRF differs from Google's internal codebase that was
12 | used for the paper, and since this is a port from [JAX](https://github.com/google/jax)
13 | to [PyTorch](https://github.com/pytorch/pytorch), results may differ as well.
14 | 
15 | This is research code, and should be treated accordingly.
16 | 
17 | ## Setup
18 | 
19 | ```
20 | # Clone the repo.
21 | git clone https://github.com/gkouros/refnerf-pytorch.git
22 | cd refnerf-pytorch
23 | 
24 | # Make a conda environment.
25 | conda create --name refnerf python=3.9
26 | conda activate refnerf
27 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
28 | 
29 | # Prepare pip.
30 | conda install pip
31 | pip install --upgrade pip
32 | 
33 | # Install requirements.
34 | pip install -r requirements.txt
35 | 
36 | # Manually install rmbrualla's `pycolmap` (don't use pip's! It's different).
37 | git clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap
38 | 
39 | # Confirm that all the unit tests pass. (Not all test scripts have been ported to PyTorch yet.)
40 | ./scripts/run_all_unit_tests.sh
41 | ```
42 | 
43 | ## Running
44 | 
45 | Example scripts for training, evaluating, and rendering can be found in
46 | `scripts/`. You'll need to change the paths to point to wherever the datasets
47 | are located. [Gin](https://github.com/google/gin-config) configuration files
48 | for our model and some ablations can be found in `configs/`.
49 | After evaluating on the test set of each scene in one of the datasets, you can 50 | use `scripts/generate_tables.ipynb` to produce error metrics across all scenes 51 | in the same format as was used in tables in the paper. 52 | 53 | ### OOM errors 54 | 55 | You may need to reduce the batch size (`Config.batch_size`) to avoid out of memory 56 | errors. If you do this, but want to preserve quality, be sure to increase the number 57 | of training iterations and decrease the learning rate by whatever scale factor you 58 | decrease batch size by. 59 | 60 | ## Using your own data 61 | 62 | Summary: first, calculate poses. Second, train Ref-NeRF. Third, render a result video from the trained NeRF model. 63 | 64 | 1. Calculating poses (using COLMAP): 65 | ``` 66 | DATA_DIR=my_dataset_dir 67 | bash scripts/local_colmap_and_resize.sh ${DATA_DIR} 68 | ``` 69 | 2. Training Ref-NeRF: 70 | ``` 71 | python -m train \ 72 | --gin_configs=configs/llff_refnerf.gin \ 73 | --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ 74 | --gin_bindings="Config.checkpoint_dir = '${DATA_DIR}/checkpoints'" \ 75 | --logtostderr 76 | ``` 77 | 3. Rendering Ref-NeRF: 78 | ``` 79 | python -m render \ 80 | --gin_configs=configs/render_config.gin \ 81 | --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ 82 | --gin_bindings="Config.checkpoint_dir = '${DATA_DIR}/checkpoints'" \ 83 | --gin_bindings="Config.render_dir = '${DATA_DIR}/render'" \ 84 | --gin_bindings="Config.render_path = True" \ 85 | --gin_bindings="Config.render_path_frames = 480" \ 86 | --gin_bindings="Config.render_video_fps = 60" \ 87 | --logtostderr 88 | ``` 89 | Your output video should now exist in the directory `my_dataset_dir/render/`. 90 | 91 | See below for more detailed instructions on either using COLMAP to calculate poses or writing your own dataset loader (if you already have pose data from another source, like SLAM or RealityCapture). 92 | 93 | ### Running COLMAP to get camera poses 94 | 95 | In order to run Ref-NeRF on your own captured images of a scene, you must first run [COLMAP](https://colmap.github.io/install.html) to calculate camera poses. You can do this using our provided script `scripts/local_colmap_and_resize.sh`. Just make a directory `my_dataset_dir/` and copy your input images into a folder `my_dataset_dir/images/`, then run: 96 | ``` 97 | bash scripts/local_colmap_and_resize.sh my_dataset_dir 98 | ``` 99 | This will run COLMAP and create 2x, 4x, and 8x downsampled versions of your images. These lower resolution images can be used in NeRF by setting, e.g., the `Config.factor = 4` gin flag. 100 | 101 | By default, `local_colmap_and_resize.sh` uses the OPENCV camera model, which is a perspective pinhole camera with k1, k2 radial and t1, t2 tangential distortion coefficients. To switch to another COLMAP camera model, for example OPENCV_FISHEYE, you can run 102 | ``` 103 | bash scripts/local_colmap_and_resize.sh my_dataset_dir OPENCV_FISHEYE 104 | ``` 105 | 106 | If you have a very large capture of more than around 500 images, we recommend switching from the exhaustive matcher to the vocabulary tree matcher in COLMAP (see the script for a commented-out example). 
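If you run COLMAP yourself rather than through the script, a quick sanity check of the resulting layout can save a confusing dataloader error later. The snippet below is a minimal sketch (the `my_dataset_dir` path is a placeholder); the expected directory structure is described next.

```
import os

data_dir = 'my_dataset_dir'  # placeholder: your capture directory
for sub in ('images', os.path.join('sparse', '0')):
    path = os.path.join(data_dir, sub)
    status = 'ok' if os.path.isdir(path) else 'MISSING'
    print(f'{path}: {status}')
```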
107 | 
108 | Our script is simply a thin wrapper for COLMAP--if you have run COLMAP yourself, all you need to do to load your scene in NeRF is ensure it has the following format:
109 | ```
110 | my_dataset_dir/images/    <--- all input images
111 | my_dataset_dir/sparse/0/  <--- COLMAP sparse reconstruction files (cameras, images, points)
112 | ```
113 | 
114 | ### Writing a custom dataloader
115 | 
116 | If you already have poses for your own data, you may prefer to write your own custom dataloader.
117 | 
118 | Ref-NeRF includes a variety of dataloaders, all of which inherit from the
119 | base
120 | [Dataset class](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L152).
121 | 
122 | The job of this class is to load all image and pose information from disk, then
123 | create batches of ray and color data for training or rendering a NeRF model.
124 | 
125 | Any inherited subclass is responsible for loading images and camera poses from
126 | disk by implementing the `_load_renderings` method (which is marked as
127 | abstract by the decorator `@abc.abstractmethod`). This data is then used to
128 | generate train and test batches of ray + color data for feeding through the NeRF
129 | model. The ray parameters are calculated in `_make_ray_batch`.
130 | 
131 | #### Existing data loaders
132 | 
133 | To work from an example, you can see how this function is overloaded for the
134 | different dataloaders we have already implemented:
135 | 
136 | - [Blender](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L470)
137 | - [DTU dataset](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L793)
138 | - [Tanks and Temples](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L680),
139 |   as processed by the NeRF++ paper
140 | - [Tanks and Temples](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L728),
141 |   as processed by the Free View Synthesis paper
142 | 
143 | The main data loader we rely on is
144 | [LLFF](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L526)
145 | (named for historical reasons), which is the loader for a dataset that has been
146 | posed by COLMAP.
147 | 
148 | #### Making your own loader by implementing `_load_renderings`
149 | 
150 | To make a new dataset, make a class inheriting from `Dataset` and overload the
151 | `_load_renderings` method:
152 | 
153 | ```
154 | class MyNewDataset(Dataset):
155 |   def _load_renderings(self, config):
156 |     ...
157 | ```
158 | 
159 | In this function, you **must** set the following public attributes:
160 | 
161 | - images
162 | - camtoworlds
163 | - pixtocams
164 | - height, width
165 | 
166 | Many of our dataset loaders also set other useful attributes, but these are the
167 | critical ones for generating rays. You can see how they are used (along with a batch of pixel coordinates) to create rays in [`camera_utils.pixels_to_rays`](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/camera_utils.py#L520).
168 | 
169 | **Images**
170 | 
171 | `images` = [N, height, width, 3] numpy array of RGB images. Currently we
172 | require all images to have the same resolution.
173 | 
174 | **Extrinsic camera poses**
175 | 
176 | `camtoworlds` = [N, 3, 4] numpy array of extrinsic pose matrices.
177 | `camtoworlds[i]` should be in **camera-to-world** format, such that we can run
178 | 
179 | ```
180 | pose = camtoworlds[i]
181 | x_world = pose[:3, :3] @ x_camera + pose[:3, 3:4]
182 | ```
183 | 
184 | to convert a 3D camera space point `x_camera` into a world space point `x_world`.
185 | 
186 | These matrices must be stored in the **OpenGL** coordinate system convention for camera rotation:
187 | x-axis to the right, y-axis upward, and z-axis backward along the camera's focal
188 | axis.
189 | 
190 | The most common conventions are
191 | 
192 | - `[right, up, backwards]`: OpenGL, NeRF, most graphics code.
193 | - `[right, down, forwards]`: OpenCV, COLMAP, most computer vision code.
194 | 
195 | Fortunately, switching from OpenCV/COLMAP to NeRF is
196 | [simple](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/datasets.py#L108):
197 | you just need to right-multiply the OpenCV pose matrices by `np.diag([1, -1, -1, 1])`,
198 | which will flip the sign of the y-axis (from down to up) and z-axis (from
199 | forwards to backwards):
200 | ```
201 | camtoworlds_opengl = camtoworlds_opencv @ np.diag([1, -1, -1, 1])
202 | ```
203 | 
204 | You may also want to **scale** your camera pose translations such that they all
205 | lie within the `[-1, 1]^3` cube for best performance with the default mipnerf360
206 | config files.
207 | 
208 | We provide a useful helper function [`camera_utils.transform_poses_pca`](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/camera_utils.py#L191) that computes a translation/rotation/scaling transform for the input poses that aligns the world space x-y plane with the ground (based on PCA) and scales the scene so that all input pose positions lie within `[-1, 1]^3`. (This function is applied by default when loading mip-NeRF 360 scenes with the LLFF data loader.) For a scene where this transformation has been applied, [`camera_utils.generate_ellipse_path`](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/camera_utils.py#L230) can be used to generate a nice elliptical camera path for rendering videos.
209 | 
210 | **Intrinsic camera poses**
211 | 
212 | `pixtocams` = [N, 3, 3] numpy array of inverse intrinsic matrices, OR [3, 3]
213 | numpy array of a single shared inverse intrinsic matrix. These should be in
214 | **OpenCV** format, e.g.
215 | 
216 | ```
217 | camtopix = np.array([
218 |   [focal,     0,  width/2],
219 |   [    0, focal, height/2],
220 |   [    0,     0,        1],
221 | ])
222 | pixtocam = np.linalg.inv(camtopix)
223 | ```
224 | 
225 | Given a focal length and image size (and assuming a centered principal point),
226 | this matrix can be created using
227 | [`camera_utils.get_pixtocam`](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/camera_utils.py#L411).
228 | 
229 | Alternatively, it can be created by using
230 | [`camera_utils.intrinsic_matrix`](https://github.com/gkouros/refnerf-pytorch/blob/main/internal/camera_utils.py#L398)
231 | and inverting the resulting matrix.
232 | 
233 | **Resolution**
234 | 
235 | `height` = int, height of images.
236 | 
237 | `width` = int, width of images.
238 | 
239 | **Distortion parameters (optional)**
240 | 
241 | `distortion_params` = dict, camera lens distortion model parameters. This
242 | dictionary must map from strings -> floats, and the allowed keys are `['k1',
243 | 'k2', 'k3', 'k4', 'p1', 'p2']` (up to four radial coefficients and up to two
244 | tangential coefficients). By default, this is set to the empty dictionary `{}`,
245 | in which case undistortion is not run.
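Putting these attributes together, a minimal custom loader might look like the sketch below. This is illustrative only: the `poses.npy` and `intrinsics.npy` file names and their contents are assumptions made for this example (as is the assumption of 3-channel RGB inputs), and the real loaders in `internal/datasets.py` remain the reference.

```
import os

import numpy as np
from PIL import Image

from internal.datasets import Dataset


class MyNewDataset(Dataset):

    def _load_renderings(self, config):
        # Load all RGB images into a single [N, height, width, 3] float array.
        image_dir = os.path.join(config.data_dir, 'images')
        filenames = sorted(os.listdir(image_dir))
        self.images = np.stack([
            np.asarray(Image.open(os.path.join(image_dir, f)),
                       dtype=np.float32) / 255.
            for f in filenames])
        self.height, self.width = self.images.shape[1:3]

        # Hypothetical pose file: [N, 3, 4] OpenCV-convention camera-to-world
        # matrices. Flip the y and z axes to convert to the OpenGL convention.
        camtoworlds = np.load(os.path.join(config.data_dir, 'poses.npy'))
        self.camtoworlds = camtoworlds @ np.diag([1., -1., -1., 1.])

        # Hypothetical intrinsics file: a single shared 3x3 OpenCV intrinsic
        # matrix; store its inverse (maps pixel coordinates back to camera).
        camtopix = np.load(os.path.join(config.data_dir, 'intrinsics.npy'))
        self.pixtocams = np.linalg.inv(camtopix).astype(np.float32)
```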
246 | 247 | ### Details of the inner workings of Dataset 248 | 249 | The public interface mimics the behavior of a standard machine learning pipeline 250 | dataset provider that can provide infinite batches of data to the 251 | training/testing pipelines without exposing any details of how the batches are 252 | loaded/created or how this is parallelized. Therefore, the initializer runs all 253 | setup, including data loading from disk using `_load_renderings`, and begins 254 | the thread using its parent start() method. After the initializer returns, the 255 | caller can request batches of data straight away. 256 | 257 | The internal `self._queue` is initialized as `queue.Queue(3)`, so the infinite 258 | loop in `run()` will block on the call `self._queue.put(self._next_fn())` once 259 | there are 3 elements. The main thread training job runs in a loop that pops 1 260 | element at a time off the front of the queue. The Dataset thread's `run()` loop 261 | will populate the queue with 3 elements, then wait until a batch has been 262 | removed and push one more onto the end. 263 | 264 | This repeats indefinitely until the main thread's training loop completes 265 | (typically hundreds of thousands of iterations), then the main thread will exit 266 | and the Dataset thread will automatically be killed since it is a daemon. 267 | 268 | 269 | ## Citation 270 | If you use this software package or build on top of it, please use the following citation: 271 | ``` 272 | @misc{refnerf-pytorch, 273 | title={refnerf-pytorch: A port of Ref-NeRF from jax to pytorch}, 274 | author={Georgios Kouros}, 275 | year={2022}, 276 | url={https://github.com/gkouros/refnerf-pytorch}, 277 | } 278 | ``` 279 | ## References 280 | ``` 281 | @article{verbin2022refnerf, 282 | title={{Ref-NeRF}: Structured View-Dependent Appearance for 283 | Neural Radiance Fields}, 284 | author={Dor Verbin and Peter Hedman and Ben Mildenhall and 285 | Todd Zickler and Jonathan T. Barron and Pratul P. Srinivasan}, 286 | journal={CVPR}, 287 | year={2022} 288 | } 289 | ``` 290 | ``` 291 | @misc{multinerf2022, 292 | title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}}, 293 | author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron}, 294 | year={2022}, 295 | url={https://github.com/google-research/multinerf}, 296 | } 297 | ``` 298 | -------------------------------------------------------------------------------- /internal/stepfun.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools for manipulating step functions (piecewise-constant 1D functions). 16 | 17 | We have a shared naming and dimension convention for these functions. 18 | All input/output step functions are assumed to be aligned along the last axis. 
19 | `t` always indicates the x coordinates of the *endpoints* of a step function. 20 | `y` indicates unconstrained values for the *bins* of a step function 21 | `w` indicates bin weights that sum to <= 1. `p` indicates non-negative bin 22 | values that *integrate* to <= 1. 23 | """ 24 | 25 | import torch 26 | import functorch 27 | from internal import math 28 | import numpy as np 29 | 30 | 31 | def searchsorted(a, v): 32 | """Find indices where v should be inserted into a to maintain order. 33 | 34 | This behaves like torch.searchsorted (its second output is the same as 35 | torch.searchsorted's output if all elements of v are in [a[0], a[-1]]) but is 36 | faster because it wastes memory to save some compute. 37 | 38 | Args: 39 | a: tensor, the sorted reference points that we are scanning to see where v 40 | should lie. 41 | v: tensor, the query points that we are pretending to insert into a. Does 42 | not need to be sorted. All but the last dimensions should match or expand 43 | to those of a, the last dimension can differ. 44 | 45 | Returns: 46 | (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the 47 | range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or 48 | last index of a. 49 | """ 50 | i = torch.arange(a.shape[-1]) 51 | v_ge_a = v[..., None, :] >= a[..., :, None] 52 | idx_lo = torch.max(torch.where( 53 | v_ge_a, i[..., :, None], i[..., :1, None]), dim=-2).values 54 | idx_hi = torch.min(torch.where( 55 | ~v_ge_a, i[..., :, None], i[..., -1:, None]), dim=-2).values 56 | return idx_lo, idx_hi 57 | 58 | 59 | def query(tq, t, y, outside_value=0): 60 | """Look up the values of the step function (t, y) at locations tq.""" 61 | idx_lo, idx_hi = searchsorted(t, tq) 62 | yq = torch.where(idx_lo == idx_hi, outside_value, 63 | torch.take_along_dim(y, idx_lo, dim=-1)) 64 | return yq 65 | 66 | 67 | def inner_outer(t0, t1, y1): 68 | """Construct inner and outer measures on (t1, y1) for t0.""" 69 | cy1 = torch.cat([torch.zeros_like(y1[..., :1]), 70 | torch.cumsum(y1, dim=-1)], 71 | dim=-1) 72 | idx_lo, idx_hi = searchsorted(t1, t0) 73 | 74 | cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1) 75 | cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1) 76 | 77 | y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1] 78 | y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:], 79 | cy1_lo[..., 1:] - cy1_hi[..., :-1], 0) 80 | return y0_inner, y0_outer 81 | 82 | 83 | def lossfun_outer(t, w, t_env, w_env, eps=torch.finfo(torch.float32).eps): 84 | """The proposal weight should be an upper envelope on the nerf weight.""" 85 | _, w_outer = inner_outer(t, t_env, w_env) 86 | # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's 87 | # more effective to pull w_outer up than it is to push w_inner down. 88 | # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0. 
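    # (For w >> w_outer the loss behaves like w itself, so its gradient stays
    # bounded near 1 instead of blowing up the way a plain quadratic would.)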
89 | return torch.maximum(torch.tensor(.0), w - w_outer)**2 / (w + eps) 90 | 91 | 92 | def weight_to_pdf(t, w, eps=torch.tensor(torch.finfo(torch.float32).eps**2)): 93 | """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.""" 94 | return w / torch.maximum(eps, (t[..., 1:] - t[..., :-1])) 95 | 96 | 97 | def pdf_to_weight(t, p): 98 | """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.""" 99 | return p * (t[..., 1:] - t[..., :-1]) 100 | 101 | 102 | def max_dilate(t, w, dilation, domain=(-torch.tensor(float('inf')), 103 | torch.tensor(float('inf')))): 104 | """Dilate (via max-pooling) a non-negative step function.""" 105 | t0 = t[..., :-1] - dilation 106 | t1 = t[..., 1:] + dilation 107 | t_dilate = torch.sort(torch.cat([t, t0, t1], dim=-1), dim=-1).values 108 | t_dilate = torch.clip(t_dilate, *domain) 109 | w_dilate = torch.max( 110 | torch.where( 111 | (t0[..., None, :] <= t_dilate[..., None]) 112 | & (t1[..., None, :] > t_dilate[..., None]), 113 | w[..., None, :], 0), dim=-1).values[..., :-1] 114 | return t_dilate, w_dilate 115 | 116 | 117 | def max_dilate_weights( 118 | t, 119 | w, 120 | dilation, 121 | domain=(-torch.tensor(float('inf')), torch.tensor(float('inf'))), 122 | renormalize=False, 123 | eps=torch.tensor(torch.finfo(torch.float32).eps**2)): 124 | """Dilate (via max-pooling) a set of weights.""" 125 | p = weight_to_pdf(t, w) 126 | t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain) 127 | w_dilate = pdf_to_weight(t_dilate, p_dilate) 128 | if renormalize: 129 | w_dilate /= torch.maximum(eps, 130 | torch.sum(w_dilate, dim=-1, keepdims=True)) 131 | return t_dilate, w_dilate 132 | 133 | 134 | def integrate_weights(w): 135 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1. 136 | 137 | The output's size on the last dimension is one greater than that of the input, 138 | because we're computing the integral corresponding to the endpoints of a step 139 | function, not the integral of the interior/bin values. 140 | 141 | Args: 142 | w: Tensor, which will be integrated along the last axis. This is assumed to 143 | sum to 1 along the last axis, and this function will (silently) break if 144 | that is not the case. 145 | 146 | Returns: 147 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 148 | """ 149 | cw = torch.minimum(torch.tensor(1), torch.cumsum(w[..., :-1], dim=-1)) 150 | shape = cw.shape[:-1] + (1,) 151 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1. 152 | cw0 = torch.cat( 153 | [torch.zeros(shape), cw, torch.ones(shape)], dim=-1) 154 | return cw0 155 | 156 | 157 | def invert_cdf(u, t, w_logits, use_gpu_resampling=False): 158 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" 159 | # Compute the PDF and CDF for each weight vector. 160 | w = torch.nn.functional.softmax(w_logits, dim=-1) 161 | cw = integrate_weights(w) 162 | # Interpolate into the inverse CDF. 163 | interp_fn = math.interp if use_gpu_resampling else math.sorted_interp 164 | t_new = interp_fn(u, cw, t) 165 | return t_new 166 | 167 | 168 | def sample( 169 | t, 170 | w_logits, 171 | num_samples, 172 | single_jitter=False, 173 | deterministic_center=False, 174 | use_gpu_resampling=False 175 | ): 176 | """Piecewise-Constant PDF sampling from a step function. 177 | 178 | Args: 179 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 180 | w_logits: [..., num_bins], logits corresponding to bin weights 181 | num_samples: int, the number of samples. 
182 |         single_jitter: bool, if True, jitter every sample along each ray by the same
183 |             amount in the inverse CDF; otherwise independently. (Unused in this port.)
184 |         deterministic_center: bool, if False, return samples that linspace the
185 |             entire PDF. If True, skip the front and back of the linspace so that
186 |             the centers of each PDF interval are returned.
187 |         use_gpu_resampling: bool, If True this resamples the rays based on a
188 |             "gather" instruction, which is fast on GPUs but slow on TPUs. If False,
189 |             this resamples the rays based on brute-force searches, which is fast on
190 |             TPUs, but slow on GPUs.
191 | 
192 |     Returns:
193 |         t_samples: torch.Tensor (float32), [batch_size, num_samples].
194 |     """
195 |     eps = torch.tensor(torch.finfo(torch.float32).eps)
196 | 
197 |     # Draw uniform samples.
198 |     # Match the behavior of jax.random.uniform() by spanning [0, 1-eps].
199 |     if deterministic_center:
200 |         pad = 1 / (2 * num_samples)
201 |         u = torch.linspace(pad, 1. - pad - eps, num_samples)
202 |     else:
203 |         u = torch.linspace(0, 1. - eps, num_samples)
204 |     u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
205 | 
206 |     return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling)
207 | 
208 | 
209 | def sample_intervals(
210 |         t,
211 |         w_logits,
212 |         num_samples,
213 |         single_jitter=False,
214 |         domain=(-torch.tensor(float('inf')), torch.tensor(float('inf'))),
215 |         use_gpu_resampling=False
216 | ):
217 |     """Sample *intervals* (rather than points) from a step function.
218 | 
219 |     Args:
220 |         t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
221 |         w_logits: [..., num_bins], logits corresponding to bin weights
222 |         num_samples: int, the number of intervals to sample.
223 |         single_jitter: bool, if True, jitter every sample along each ray by the same
224 |             amount in the inverse CDF. Otherwise, jitter each sample independently.
225 |         domain: (minval, maxval), the range of valid values for `t`.
226 |         use_gpu_resampling: bool, If True this resamples the rays based on a
227 |             "gather" instruction, which is fast on GPUs but slow on TPUs. If False,
228 |             this resamples the rays based on brute-force searches, which is fast on
229 |             TPUs, but slow on GPUs.
230 | 
231 |     Returns:
232 |         t_samples: torch.Tensor (float32), [batch_size, num_samples].
233 |     """
234 |     if num_samples <= 1:
235 |         raise ValueError(f'num_samples must be > 1, is {num_samples}.')
236 | 
237 |     # Sample a set of points from the step function.
238 |     centers = sample(
239 |         t,
240 |         w_logits,
241 |         num_samples,
242 |         single_jitter,
243 |         deterministic_center=True,
244 |         use_gpu_resampling=use_gpu_resampling)
245 | 
246 |     # The intervals we return will span the midpoints of each adjacent sample.
247 |     mid = (centers[..., 1:] + centers[..., :-1]) / 2
248 | 
249 |     # Each first/last fencepost is the reflection of the first/last midpoint
250 |     # around the first/last sampled center. We clamp to the limits of the input
251 |     # domain, provided by the caller.
252 |     minval, maxval = domain
253 |     first = torch.maximum(torch.tensor(minval), 2 *
254 |                           centers[..., :1] - mid[..., :1])
255 |     last = torch.minimum(torch.tensor(maxval), 2 *
256 |                          centers[..., -1:] - mid[..., -1:])
257 |     t_samples = torch.cat([first, mid, last], dim=-1)
258 |     return t_samples
259 | 
260 | 
261 | def lossfun_distortion(t, w):
262 |     """Compute iint w[i] w[j] |t[i] - t[j]| di dj."""
263 |     # The loss incurred between all pairs of intervals.
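    # (Each interval is collapsed to its midpoint ut, giving the pairwise term
    # sum_ij w_i w_j |ut_i - ut_j|; the intra term below adds the exact expected
    # distance between two uniform samples in one interval, (t1 - t0) / 3.)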
264 | ut = (t[..., 1:] + t[..., :-1]) / 2 265 | dut = torch.abs(ut[..., :, None] - ut[..., None, :]) 266 | loss_inter = torch.sum( 267 | w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1) 268 | 269 | # The loss incurred within each individual interval with itself. 270 | loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3 271 | 272 | return loss_inter + loss_intra 273 | 274 | 275 | def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi): 276 | """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).""" 277 | # Distortion when the intervals do not overlap. 278 | d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2) 279 | 280 | # Distortion when the intervals overlap. 281 | d_overlap = ( 282 | 2 * (torch.minimum(t0_hi, t1_hi)**3 - torch.maximum(t0_lo, t1_lo)**3) + 283 | 3 * (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) + 284 | t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) + t1_hi * t0_lo * 285 | (t0_lo - t1_hi) + t1_lo * t0_hi * 286 | (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)) 287 | 288 | # Are the two intervals not overlapping? 289 | are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi) 290 | 291 | return torch.where(are_disjoint, d_disjoint, d_overlap) 292 | 293 | 294 | def weighted_percentile(t, w, ps): 295 | """Compute the weighted percentiles of a step function. w's must sum to 1.""" 296 | cw = integrate_weights(w) 297 | # We want to interpolate into the integrated weights according to `ps`. 298 | 299 | def fn(cw_i, t_i): 300 | return math.interp(torch.tensor(ps) / 100, cw_i, t_i) 301 | 302 | # Vmap fn to an arbitrary number of leading dimensions. 303 | cw_mat = cw.reshape([-1, cw.shape[-1]]) 304 | t_mat = t.reshape([-1, t.shape[-1]]) 305 | wprctile_mat = (functorch.vmap(fn, 0)(cw_mat, t_mat)) 306 | wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),)) 307 | return wprctile 308 | 309 | 310 | def resample(t, tp, vp, use_avg=False, 311 | eps=torch.tensor(torch.finfo(torch.float32).eps)): 312 | """Resample a step function defined by (tp, vp) into intervals t. 313 | 314 | Notation roughly matches jnp.interp. Resamples by summation by default. 315 | 316 | Args: 317 | t: tensor with shape (..., n+1), the endpoints to resample into. 318 | tp: tensor with shape (..., m+1), the endpoints of the step function being 319 | resampled. 320 | vp: tensor with shape (..., m), the values of the step function being 321 | resampled. 322 | use_avg: bool, if False, return the sum of the step function for each 323 | interval in `t`. If True, return the average, weighted by the width of 324 | each interval in `t`. 325 | eps: float, a small value to prevent division by zero when use_avg=True. 326 | 327 | Returns: 328 | v: tensor with shape (..., n), the values of the resampled step function. 329 | """ 330 | if use_avg: 331 | wp = torch.diff(tp, dim=-1) 332 | v_numer = resample(t, tp, vp * wp, use_avg=False) 333 | v_denom = resample(t, tp, wp, use_avg=False) 334 | v = v_numer / torch.maximum(eps, v_denom) 335 | return v 336 | 337 | acc = torch.cumsum(vp, dim=-1) 338 | acc0 = torch.cat([torch.zeros(acc.shape[:-1] + (1,)), acc], dim=-1) 339 | 340 | if len(acc0.shape) == 2: 341 | acc0_resampled = torch.stack([ 342 | math.interp(t, tp, acc0[dim]) for dim in range(len(acc0))], dim=0) 343 | else: 344 | acc0_resampled = math.interp(t, tp, acc0) 345 | 346 | v = torch.diff(acc0_resampled, dim=-1) 347 | 348 | return v 349 | --------------------------------------------------------------------------------