├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── LICENSES ├── LICENSE-DVGO ├── LICENSE-MipNeRF ├── LICENSE-MultiNeRF ├── LICENSE-NeRF ├── LICENSE-NeRFPP └── LICENSE-Plenoxels ├── NOTICE ├── README.md ├── cache └── .gitkeep ├── configs ├── dvgo │ ├── 360_v2.gin │ ├── blender.gin │ ├── lf.gin │ ├── llff.gin │ ├── shiny_blender.gin │ └── tnt.gin ├── mipnerf │ ├── 360_v2.gin │ ├── blender.gin │ ├── blender_ms.gin │ ├── lf.gin │ ├── llff.gin │ ├── refnerf_real.gin │ ├── shiny_blender.gin │ └── tnt.gin ├── mipnerf360 │ ├── 360_v2.gin │ ├── lf.gin │ ├── llff.gin │ └── tnt.gin ├── nerf │ ├── 360_v2.gin │ ├── blender.gin │ ├── blender_ms.gin │ ├── lf.gin │ ├── llff.gin │ ├── shiny_blender.gin │ └── tnt.gin ├── nerfpp │ ├── 360_v2.gin │ ├── lf.gin │ └── tnt.gin ├── plenoxel │ ├── 360_v2.gin │ ├── blender.gin │ ├── blender_ms.gin │ ├── lf.gin │ ├── llff.gin │ ├── shiny_blender.gin │ └── tnt.gin └── refnerf │ ├── blender.gin │ ├── refnerf_real.gin │ └── shiny_blender.gin ├── data └── .gitkeep ├── lib ├── dvgo │ ├── adam_upd.cpp │ ├── adam_upd_kernel.cu │ ├── cuda │ │ ├── adam_upd.cpp │ │ ├── adam_upd_kernel.cu │ │ ├── render_utils.cpp │ │ ├── render_utils_kernel.cu │ │ ├── total_variation.cpp │ │ ├── total_variation_kernel.cu │ │ ├── ub360_utils.cpp │ │ └── ub360_utils_kernel.cu │ ├── render_utils.cpp │ ├── render_utils_kernel.cu │ ├── total_variation.cpp │ ├── total_variation_kernel.cu │ ├── ub360_utils.cpp │ └── ub360_utils_kernel.cu └── plenoxel │ ├── CMakeLists.txt │ ├── include │ ├── cubemap_util.cuh │ ├── cuda_util.cuh │ ├── data_spec.hpp │ ├── data_spec_packed.cuh │ ├── random_util.cuh │ ├── render_util.cuh │ └── util.hpp │ ├── loss_kernel.cu │ ├── misc_kernel.cu │ ├── optim_kernel.cu │ ├── render_lerp_kernel_cuvol.cu │ ├── svox2.cpp │ ├── svox2_kernel.cu │ └── version.py ├── nerf_factory.yml ├── requirements.txt ├── run.py ├── sbatch.sh ├── scripts ├── collage.sh ├── download_data.sh ├── dvgo.sh ├── mipnerf.sh ├── mipnerf360.sh ├── nerf.sh ├── nerfpp.sh ├── plenoxel.sh ├── refnerf.sh └── test.sh ├── setup.py ├── src ├── data │ ├── data_util │ │ ├── blender.py │ │ ├── blender_ms.py │ │ ├── lf.py │ │ ├── llff.py │ │ ├── nerf_360_v2.py │ │ ├── refnerf_real.py │ │ ├── shiny_blender.py │ │ └── tnt.py │ ├── interface.py │ ├── litdata.py │ ├── pose_utils.py │ ├── ray_utils.py │ └── sampler.py └── model │ ├── dvgo │ ├── __global__.py │ ├── dcvgo.py │ ├── dmpigo.py │ ├── dvgo.py │ ├── grid.py │ ├── masked_adam.py │ ├── model.py │ └── utils.py │ ├── interface.py │ ├── mipnerf │ ├── helper.py │ └── model.py │ ├── mipnerf360 │ ├── helper.py │ ├── model.py │ └── test.py │ ├── nerf │ ├── helper.py │ └── model.py │ ├── nerfpp │ ├── helper.py │ └── model.py │ ├── plenoxel │ ├── __global__.py │ ├── autograd.py │ ├── dataclass.py │ ├── model.py │ ├── sparse_grid.py │ └── utils.py │ └── refnerf │ ├── helper.py │ ├── model.py │ └── ref_utils.py └── utils ├── check_mean_score.py ├── create_scripts.py ├── preprocess_shiny_blender.py ├── select_option.py └── store_image.py /.gitignore: -------------------------------------------------------------------------------- 1 | *logs/* 2 | logs 3 | wandb/* 4 | *vscode* 5 | *pycache* 6 | *swap-pane* 7 | *.idea* 8 | *build* 9 | *.egg-info* 10 | *.eggs* 11 | *_debug* 12 | 13 | cache/* 14 | !cache/.gitkeep 15 | data/* 16 | !data/.gitkeep 17 | dataloader/co3d_lists/* 18 | !dataloader/co3d_lists/.gitkeep -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | exclude: 'build|egg-info|dist' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v1.2.3 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: check-added-large-files 12 | - id: end-of-file-fixer 13 | 14 | - repo: https://github.com/pycqa/isort 15 | rev: 5.6.3 16 | hooks: 17 | - id: isort 18 | name: isort (python) 19 | args: ["--profile", "black"] 20 | 21 | - repo: https://github.com/psf/black 22 | rev: 22.3.0 23 | hooks: 24 | - id: black 25 | language_version: python3 26 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | To contribute to our NeRF-Factory code, we strongly recommend following the protocols below. 2 | 3 | 1. Do not modify the original code, since doing so can change the performance of the implemented models. 4 | 2. Use pull requests to contribute to our code. We will run a code review and merge your changes after internal acceptance. 5 | 3. For clarity, please provide a description of the code when adding your model. 6 | - It is best to provide a public link to the corresponding article. 7 | 4. If you want to add a custom dataset, please share the dataset publicly. 8 | - Add an automatic download script for your dataset to `scripts/download_data.sh`. 9 | 5. [Recommended] Add detailed descriptions of the code using comments inside the code. 10 | 6. Do not use util functions from a different model. This could cause side effects in the future. 11 | -------------------------------------------------------------------------------- /LICENSES/LICENSE-NeRF: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 bmild 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSES/LICENSE-NeRFPP: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, the NeRF++ authors 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /LICENSES/LICENSE-Plenoxels: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, the Plenoxels authors 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | How to apply the Apache License to your work. 2 | 3 | To apply the Apache License to your work, attach the following 4 | boilerplate notice, with the fields enclosed by brackets "[]" 5 | replaced with your own identifying information. (Don't include 6 | the brackets!) The text should be enclosed in the appropriate 7 | comment syntax for the file format. We also recommend that a 8 | file or class name and description of purpose be included on the 9 | same "printed page" as the copyright notice for easier 10 | identification within third-party archives. 
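For reference, the boilerplate notice that the paragraph above refers to is the standard Apache License 2.0 appendix:

    Copyright [yyyy] [name of copyright owner]

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.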
11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deprecated Repository 2 | 3 | This repository is no longer maintained. Please note that all further development and updates have been moved to a new repository. 4 | 5 | You can find the new repository **[here](https://github.com/POSTECH-CVLab/NeRF-Factory)**. 6 | 7 | Thank you for your understanding and continued support. 8 | 9 | 10 | # NeRF-Factory: An awesome PyTorch NeRF collection 11 | 12 | ![logo](https://user-images.githubusercontent.com/33657821/191188990-d15744b5-c030-48ac-9669-2a0600bacdec.png) 13 | 14 | [Project Page](https://kakaobrain.github.io/NeRF-Factory/) | [Checkpoints](https://huggingface.co/nrtf/nerf_factory) 15 | 16 | Attention all NeRF researchers! We are here with a PyTorch-reimplemented large-scale NeRF library. Our library is easily extensible and usable. 17 | 18 |

19 | [animated preview GIF] 20 | [animated preview GIF] 21 | 

22 | 23 | 24 | This repository contains PyTorch implementations of 7 popular NeRF models. 25 | - NeRF: [[Project Page]](https://www.matthewtancik.com/nerf) [[Paper]](https://arxiv.org/abs/2003.08934) [[Code]](https://github.com/bmild/nerf) 26 | - NeRF++: [[Paper]](http://arxiv.org/abs/2010.07492) [[Code]](https://github.com/Kai-46/nerfplusplus) 27 | - DVGO: [[Project Page]](https://sunset1995.github.io/dvgo/) [[Paper-v1]](https://arxiv.org/abs/2111.11215) [[Paper-v2]](https://arxiv.org/abs/2206.05085) [[Code]](https://github.com/sunset1995/DirectVoxGO) 28 | - Plenoxels: [[Project Page]](https://alexyu.net/plenoxels/) [[Paper]](https://arxiv.org/abs/2112.05131) [[Code]](https://github.com/sxyu/svox2) 29 | - Mip-NeRF: [[Project Page]](https://jonbarron.info/mipnerf/) [[Paper]](https://arxiv.org/abs/2103.13415) [[Code]](https://github.com/google/mipnerf) 30 | - Mip-NeRF360: [[Project Page]](https://jonbarron.info/mipnerf360/) [[Paper]](https://arxiv.org/abs/2111.12077) [[Code]](https://github.com/google-research/multinerf) 31 | - Ref-NeRF: [[Project Page]](https://dorverbin.github.io/refnerf/) [[Paper]](https://arxiv.org/abs/2112.03907) [[Code]](https://github.com/google-research/multinerf) 32 | 33 | It also supports 7 popular NeRF datasets. 34 | - NeRF Blender: [link](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 35 | - NeRF LLFF: [link](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 36 | - Tanks and Temples: [link](https://drive.google.com/file/d/11KRfN91W1AxAW6lOFs4EeYDbeoQZCi87/view?usp=sharing) 37 | - LF: [link](https://drive.google.com/file/d/1gsjDjkbTh4GAR9fFqlIDZ__qR9NYTURQ/view?usp=sharing) 38 | - NeRF-360: [link](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 39 | - NeRF-360-v2: [link](https://jonbarron.info/mipnerf360/) 40 | - Shiny Blender: [link](https://dorverbin.github.io/refnerf/) 41 | 42 | All you need to do to run the code is: 43 | 44 | ```bash 45 | python3 -m run --ginc configs/[model]/[data].gin 46 | # ex) python3 -m run --ginc configs/nerf/blender.gin 47 | ``` 48 | 49 | We also provide convenient visualizers for NeRF researchers. 50 | 51 | 52 | ## Contributors 53 | This project was created and is maintained by [Yoonwoo Jeong](https://github.com/jeongyw12382), [Seungjoo Shin](https://github.com/seungjooshin), and [Kibaek Park](https://github.com/parkkibaek). 54 | 55 | ## Requirements 56 | ``` 57 | conda create -n nerf_factory -c anaconda python=3.8 58 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 59 | pip3 install -r requirements.txt 60 | 61 | ## Optional (Plenoxel) 62 | pip3 install . 63 | 64 | ## Or you can build the environment directly from nerf_factory.yml 65 | conda env create --file nerf_factory.yml 66 | ``` 67 | 68 | ## Command 69 | 70 | ```bash 71 | python3 -m run --ginc configs/[model]/[data].gin 72 | # ex) python3 -m run --ginc configs/nerf/blender.gin 73 | ``` 74 | 75 | ## Preparing Dataset 76 | 77 | We provide an automatic download script for all datasets. 
78 | 79 | ```bash 80 | # NeRF-blender dataset 81 | bash scripts/download_data.sh nerf_synthetic 82 | # NeRF-LLFF (NeRF-Real) dataset 83 | bash scripts/download_data.sh nerf_llff 84 | # NeRF-360 dataset 85 | bash scripts/download_data.sh nerf_real_360 86 | # Tanks and Temples dataset 87 | bash scripts/download_data.sh tanks_and_temples 88 | # LF dataset 89 | bash scripts/download_data.sh lf 90 | # NeRF-360-v2 dataset 91 | bash scripts/download_data.sh nerf_360_v2 92 | # Shiny-blender dataset 93 | bash scripts/download_data.sh shiny_blender 94 | ``` 95 | 96 | ## Run the Code! 97 | 98 | Running the code is very simple. 99 | 100 | 101 | ### Training Code 102 | 103 | A script for running the training code. 104 | 105 | ```bash 106 | python3 run.py --ginc configs/[model]/[data].gin --scene [scene] 107 | 108 | ## ex) train nerf on the chair scene of the blender dataset 109 | python3 run.py --ginc configs/nerf/blender.gin --scene chair 110 | ``` 111 | 112 | ### Evaluation Code 113 | 114 | A script for running the evaluation code only. 115 | 116 | ```bash 117 | python3 run.py --ginc configs/[model]/[data].gin --scene [scene] \ 118 | --ginb run.run_train=False 119 | 120 | ## ex) evaluate nerf on the chair scene of the blender dataset 121 | python3 run.py --ginc configs/nerf/blender.gin --scene chair \ 122 | --ginb run.run_train=False 123 | ``` 124 | 125 | ## Custom 126 | 127 | How do you add a custom dataset or a custom model to NeRF-Factory? 128 | 129 | ### Custom Dataset 130 | 131 | - Add the files of the custom dataset under ```./data/[custom_dataset]```. 132 | - Implement a dataset loader in ```./src/data/data_util/[custom_dataset].py```. 133 | - Implement a custom dataset class ```LitData[custom_dataset]``` in ```./src/data/litdata.py```. 134 | - Add an option for selecting the custom dataset in the function ```def select_dataset()``` of ```./utils/select_option.py```. 135 | - Add a gin config file for each model as ```./configs/[model]/[custom_dataset].gin```. 136 | 137 | ### Custom Model 138 | 139 | - Implement the custom model in ```./src/model/[custom_model]/model.py```. 140 | - Implement the custom model's helper code in ```./src/model/[custom_model]/helper.py```. 141 | - [Optional] If you need more code files for the custom model, you can add them in ```./src/model/[custom_model]/```. - Add an option for selecting the custom model in the function ```def select_model()``` of ```./utils/select_option.py```. 142 | - Add a gin config file for each dataset as ```./configs/[custom_model]/[dataset].gin```. 143 | 144 | ### License 145 | 146 | Copyright (c) 2022 POSTECH, KAIST, and Kakao Brain Corp. All Rights Reserved. 
147 | Licensed under the Apache License, Version 2.0 (see [LICENSE](https://github.com/kakaobrain/NeRF-Factory/tree/main/LICENSE) for details) 148 | -------------------------------------------------------------------------------- /cache/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/nerf-factory/b8f152ac279010106f32c2b9d7d4f009781ca5b2/cache/.gitkeep -------------------------------------------------------------------------------- /configs/dvgo/360_v2.gin: -------------------------------------------------------------------------------- 1 | ### 360-v2 Specific Arguments 2 | 3 | run.dataset_name = "nerf_360_v2" 4 | run.datadir = "data/nerf_360_v2" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | LitData.needs_train_info = True 11 | LitData.batch_size = 4096 12 | LitData.chunk = 4096 13 | LitData.use_pixel_centers = True 14 | LitData.epoch_size = 40000 15 | LitData.use_near_clip = True 16 | 17 | run.max_steps = 40000 18 | run.log_every_n_steps = 100 19 | run.progressbar_refresh_rate = 100 20 | run.model_name = "dvgo" 21 | 22 | LitDVGO.bbox_type="unbounded_inward" 23 | LitDVGO.model_type="dcvgo" 24 | LitDVGO.N_iters_coarse=0 25 | LitDVGO.N_rand_fine=4096 26 | LitDVGO.lrate_decay_fine=80 27 | LitDVGO.weight_nearclip_fine=1.0 28 | LitDVGO.weight_distortion_fine=0.01 29 | LitDVGO.pg_scale_fine=[2000, 4000, 6000, 8000, 10000, 12000, 14000, 16000] 30 | LitDVGO.tv_before_fine=20000 31 | LitDVGO.tv_dense_before_fine=20000 32 | LitDVGO.weight_tv_density_fine=1e-6 33 | LitDVGO.weight_tv_k0_fine=1e-7 34 | 35 | LitDVGO.num_voxels_fine = 32768000 36 | LitDVGO.num_voxels_base_fine = 32768000 37 | LitDVGO.alpha_init_fine = 1e-4 38 | LitDVGO.stepsize_fine = 0.5 39 | LitDVGO.fast_color_thres_fine = "outdoor_default" 40 | LitDVGO.world_bound_scale_fine = 1. 
41 | LitDVGO.contracted_norm_fine = "l2" 42 | -------------------------------------------------------------------------------- /configs/dvgo/blender.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "blender" 4 | run.datadir = "data/blender" 5 | 6 | LitData.batch_sampler = "dynamic_all_images" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | LitData.needs_train_info = True 11 | LitData.batch_size = 4096 12 | LitData.chunk = 32768 13 | LitData.use_pixel_centers = True 14 | LitData.epoch_size = 25000 15 | LitDataBlender.white_bkgd = True 16 | 17 | run.max_steps = 25000 18 | run.log_every_n_steps = 100 19 | run.progressbar_refresh_rate = 100 20 | run.model_name = "dvgo" 21 | run.save_last = False 22 | 23 | MultipleImageDynamicDDPSampler.N_coarse = 5000 24 | LitDVGO.model_type = "dvgo" 25 | LitDVGO.ray_masking = True -------------------------------------------------------------------------------- /configs/dvgo/lf.gin: -------------------------------------------------------------------------------- 1 | ### TnT Specific Arguments 2 | 3 | run.dataset_name = "lf" 4 | run.datadir = "data/lf_data" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | LitData.needs_train_info = True 11 | LitData.batch_size = 4096 12 | LitData.chunk = 4096 13 | LitData.use_pixel_centers = True 14 | LitData.epoch_size = 25000 15 | LitData.use_near_clip = True 16 | 17 | run.max_steps = 25000 18 | run.log_every_n_steps = 100 19 | run.progressbar_refresh_rate = 100 20 | run.model_name = "dvgo" 21 | 22 | LitDVGO.bbox_type="unbounded_inward" 23 | LitDVGO.model_type="dcvgo" 24 | LitDVGO.N_iters_coarse=0 25 | LitDVGO.N_rand_fine=4096 26 | LitDVGO.weight_distortion_fine=0.01 27 | LitDVGO.pg_scale_fine=[1000,2000,3000,4000,5000,6000] 28 | LitDVGO.tv_before_fine=1e9 29 | LitDVGO.tv_dense_before_fine=10000 30 | LitDVGO.weight_tv_density_fine=1e-6 31 | LitDVGO.weight_tv_k0_fine=1e-7 32 | 33 | LitDVGO.num_voxels_fine = 16777216 34 | LitDVGO.num_voxels_base_fine = 16777216 35 | LitDVGO.alpha_init_fine = 1e-4 36 | LitDVGO.stepsize_fine = 0.5 37 | LitDVGO.fast_color_thres_fine = "outdoor_default" 38 | LitDVGO.world_bound_scale_fine = 1. 
39 | LitDVGO.contracted_norm_fine = "l2" 40 | 41 | LitDataLF.test_skip = 16 42 | -------------------------------------------------------------------------------- /configs/dvgo/llff.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "llff" 4 | run.datadir = "data/llff" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitData.ndc_coord = True 8 | 9 | ### NeRF Standard Specific Arguments 10 | 11 | LitData.needs_train_info = True 12 | LitData.batch_size = 4096 13 | LitData.chunk = 4096 14 | LitData.use_pixel_centers = True 15 | LitData.epoch_size = 30000 16 | 17 | run.max_steps = 30000 18 | run.log_every_n_steps = 100 19 | run.progressbar_refresh_rate = 100 20 | run.model_name = "dvgo" 21 | 22 | LitDVGO.model_type = "dmpigo" 23 | LitDVGO.N_iters_coarse = 0 24 | LitDVGO.weight_distortion_fine = 0.01 25 | LitDVGO.pg_scale_fine = [2000, 4000, 6000, 8000] 26 | LitDVGO.decay_after_scale_fine = 0.1 27 | LitDVGO.tv_before_fine = 1e9 28 | LitDVGO.tv_dense_before_fine = 10000 29 | LitDVGO.weight_tv_density_fine = 1e-5 30 | LitDVGO.weight_tv_k0_fine = 1e-6 31 | 32 | LitDVGO.num_voxels_fine = 37748736 33 | LitDVGO.mpi_depth_fine = 256 34 | LitDVGO.stepsize_fine = 1.0 35 | LitDVGO.rgbnet_dim_fine = 9 36 | LitDVGO.rgbnet_width_fine = 64 37 | LitDVGO.world_bound_scale_fine =1. 38 | LitDVGO.fast_color_thres_fine = 0.00078125 39 | LitDVGO.rand_bkgd = True -------------------------------------------------------------------------------- /configs/dvgo/shiny_blender.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "shiny_blender" 4 | run.datadir = "data/refnerf_shinyblender" 5 | 6 | LitData.batch_sampler = "dynamic_all_images" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | LitData.needs_train_info = True 11 | LitData.batch_size = 4096 12 | LitData.chunk = 32768 13 | LitData.use_pixel_centers = True 14 | LitData.epoch_size = 25000 15 | LitDataBlender.white_bkgd = True 16 | 17 | run.max_steps = 25000 18 | run.log_every_n_steps = 100 19 | run.progressbar_refresh_rate = 100 20 | run.model_name = "dvgo" 21 | run.save_last = False 22 | 23 | MultipleImageDynamicDDPSampler.N_coarse = 5000 24 | LitDVGO.model_type = "dvgo" 25 | LitDVGO.ray_masking = True -------------------------------------------------------------------------------- /configs/dvgo/tnt.gin: -------------------------------------------------------------------------------- 1 | ### TnT Specific Arguments 2 | 3 | run.dataset_name = "tanks_and_temples" 4 | run.datadir = "data/tanks_and_temples" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | LitData.needs_train_info = True 11 | LitData.batch_size = 4096 12 | LitData.chunk = 4096 13 | LitData.use_pixel_centers = True 14 | LitData.epoch_size = 30000 15 | LitData.use_near_clip = True 16 | 17 | run.max_steps = 30000 18 | run.log_every_n_steps = 100 19 | run.progressbar_refresh_rate = 100 20 | run.model_name = "dvgo" 21 | 22 | LitDVGO.bbox_type="unbounded_inward" 23 | LitDVGO.model_type="dcvgo" 24 | LitDVGO.N_iters_coarse=0 25 | LitDVGO.N_rand_fine=4096 26 | LitDVGO.weight_distortion_fine=0.01 27 | LitDVGO.pg_scale_fine=[1000,2000,3000,4000,5000,6000,7000] 28 | LitDVGO.tv_before_fine=1e9 29 | LitDVGO.tv_dense_before_fine=10000 30 | LitDVGO.weight_tv_density_fine=1e-6 31 | LitDVGO.weight_tv_k0_fine=1e-7 32 | 33 | LitDVGO.num_voxels_fine = 32768000 34 | 
LitDVGO.num_voxels_base_fine = 32768000 35 | LitDVGO.alpha_init_fine = 1e-4 36 | LitDVGO.stepsize_fine = 0.5 37 | LitDVGO.fast_color_thres_fine = "outdoor_default" 38 | LitDVGO.world_bound_scale_fine = 1. 39 | LitDVGO.contracted_norm_fine = "l2" 40 | -------------------------------------------------------------------------------- /configs/mipnerf/360_v2.gin: -------------------------------------------------------------------------------- 1 | ### 360-v2 Specific Arguments 2 | 3 | run.dataset_name = "nerf_360_v2" 4 | run.datadir = "data/nerf_360_v2" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | MipNeRF.density_noise = 1. 21 | LitDataNeRF360V2.cam_scale_factor = 0.125 22 | -------------------------------------------------------------------------------- /configs/mipnerf/blender.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "blender" 4 | run.datadir = "data/blender" 5 | 6 | LitData.batch_sampler = "single_image" 7 | LitDataBlender.white_bkgd = True 8 | 9 | ## MipNeRF Standard Specific Arguments 10 | 11 | 12 | run.model_name = "mipnerf" 13 | run.max_steps = 1000000 14 | run.log_every_n_steps = 100 15 | 16 | LitData.load_radii = True 17 | LitData.batch_size = 4096 18 | LitData.chunk = 8192 19 | LitData.use_pixel_centers = True 20 | LitData.epoch_size = 250000 21 | -------------------------------------------------------------------------------- /configs/mipnerf/blender_ms.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "blender_multiscale" 4 | run.datadir = "data/multiscale" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataBlenderMultiScale.white_bkgd = True 8 | 9 | ## MipNeRF Standard Specific Arguments 10 | 11 | run.model_name = "mipnerf" 12 | run.max_steps = 1000000 13 | run.log_every_n_steps = 100 14 | 15 | LitData.load_radii = True 16 | LitData.batch_size = 4096 17 | LitData.chunk = 8192 18 | LitData.use_pixel_centers = True 19 | LitData.epoch_size = 250000 20 | 21 | LitMipNeRF.use_multiscale = True -------------------------------------------------------------------------------- /configs/mipnerf/lf.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "lf" 4 | run.datadir = "data/lf_data" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | LitDataLF.cam_scale_factor = 0.125 21 | MipNeRF.density_noise = 1. 
22 | -------------------------------------------------------------------------------- /configs/mipnerf/llff.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "llff" 4 | run.datadir = "data/llff" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitData.ndc_coord = True 8 | 9 | ## MipNeRF Standard Specific Arguments 10 | 11 | run.model_name = "mipnerf" 12 | run.max_steps = 1000000 13 | run.log_every_n_steps = 100 14 | 15 | LitData.load_radii = True 16 | LitData.batch_size = 4096 17 | LitData.chunk = 8192 18 | LitData.use_pixel_centers = True 19 | LitData.epoch_size = 250000 20 | 21 | MipNeRF.ray_shape = "cylinder" 22 | MipNeRF.density_noise = 1. -------------------------------------------------------------------------------- /configs/mipnerf/refnerf_real.gin: -------------------------------------------------------------------------------- 1 | ### RefNeRF Real Specific Arguments 2 | 3 | run.dataset_name = "refnerf_real" 4 | run.datadir = "data/ref_real" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ### MipNeRF Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | MipNeRF.density_noise = 1. -------------------------------------------------------------------------------- /configs/mipnerf/shiny_blender.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "shiny_blender" 4 | run.datadir = "data/refnerf" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataShinyBlender.white_bkgd = True 8 | 9 | ### MipNeRF Standard Specific Arguments 10 | 11 | run.model_name = "mipnerf" 12 | run.max_steps = 1000000 13 | run.log_every_n_steps = 100 14 | 15 | LitData.load_radii = True 16 | LitData.batch_size = 4096 17 | LitData.chunk = 8192 18 | LitData.use_pixel_centers = True 19 | LitData.epoch_size = 250000 -------------------------------------------------------------------------------- /configs/mipnerf/tnt.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "tanks_and_temples" 4 | run.datadir = "data/tanks_and_temples" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | LitDataTnT.cam_scale_factor = 0.125 21 | MipNeRF.density_noise = 1. 
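Every `Name.attribute = value` line in these .gin files is a gin binding that `run.py` applies when the file is passed via `--ginc`: gin rewires the default arguments of the `@gin.configurable` classes and functions with matching names. Below is a minimal, self-contained sketch of that mechanism; the class body is a stand-in for illustration, not the real `MipNeRF` from `src/model/mipnerf/model.py`.

```python
import gin


@gin.configurable
class MipNeRF:
    # Defaults below are overridden by bindings such as
    # 'MipNeRF.ray_shape = "cylinder"' in configs/mipnerf/llff.gin.
    def __init__(self, ray_shape="cone", density_noise=0.0):
        self.ray_shape = ray_shape
        self.density_noise = density_noise


# Same binding syntax as the config files above, parsed from a string:
gin.parse_config('MipNeRF.ray_shape = "cylinder"\nMipNeRF.density_noise = 1.')
model = MipNeRF()
assert model.ray_shape == "cylinder" and model.density_noise == 1.0
```

This is why the configs never instantiate anything themselves: they only declare argument values, and whichever code path constructs `MipNeRF()` picks them up automatically.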
-------------------------------------------------------------------------------- /configs/mipnerf360/360_v2.gin: -------------------------------------------------------------------------------- 1 | ### 360-v2 Specific Arguments 2 | 3 | run.dataset_name = "nerf_360_v2" 4 | run.datadir = "data/nerf_360_v2" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf360" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 4096 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | 21 | LitDataNeRF360V2.near = 0.1 22 | LitDataNeRF360V2.far = 1e6 23 | 24 | MipNeRF360.opaque_background = True 25 | 26 | run.grad_max_norm = 0.001 -------------------------------------------------------------------------------- /configs/mipnerf360/lf.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "lf" 4 | run.datadir = "data/lf_data" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF360 Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf360" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 4096 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | LitDataLF.near = 0.1 21 | LitDataLF.far = 1e6 22 | 23 | MipNeRF360.opaque_background = True 24 | 25 | run.grad_max_norm = 0.001 -------------------------------------------------------------------------------- /configs/mipnerf360/llff.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "llff" 4 | run.datadir = "data/llff" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF360 Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf360" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 4096 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | LitDataLLFF.near = 0.2 21 | LitDataLLFF.far = 1e6 22 | 23 | MipNeRF360.opaque_background = True 24 | 25 | run.grad_max_norm = 0.001 -------------------------------------------------------------------------------- /configs/mipnerf360/tnt.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "tanks_and_temples" 4 | run.datadir = "data/tanks_and_temples" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## MipNeRF360 Standard Specific Arguments 9 | 10 | run.model_name = "mipnerf360" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 4096 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | LitDataTnT.near = 0.1 21 | LitDataTnT.far = 1e6 22 | 23 | MipNeRF360.opaque_background = True 24 | 25 | run.grad_max_norm = 0.001 -------------------------------------------------------------------------------- /configs/nerf/360_v2.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "nerf_360_v2" 4 | run.datadir = "data/nerf_360_v2" 5 | 6 | LitData.batch_sampler = "all_images" 
7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | NeRF.num_coarse_samples = 128 11 | NeRF.num_fine_samples = 256 12 | NeRF.noise_std = 1. 13 | 14 | LitData.batch_size = 4096 15 | LitData.chunk = 8192 16 | LitData.use_pixel_centers = True 17 | LitData.epoch_size = 250000 18 | LitDataBlender.white_bkgd = True 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerf" 23 | 24 | LitDataNeRF360V2.cam_scale_factor = 0.125 -------------------------------------------------------------------------------- /configs/nerf/blender.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "blender" 4 | run.datadir = "data/blender" 5 | 6 | LitData.batch_sampler = "single_image" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | NeRF.num_coarse_samples = 64 11 | NeRF.num_fine_samples = 128 12 | 13 | LitData.batch_size = 4096 14 | LitData.chunk = 8192 15 | LitData.use_pixel_centers = True 16 | LitData.epoch_size = 250000 17 | LitDataBlender.white_bkgd = True 18 | 19 | run.max_steps = 1000000 20 | run.log_every_n_steps = 100 21 | run.model_name = "nerf" -------------------------------------------------------------------------------- /configs/nerf/blender_ms.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "blender_multiscale" 4 | run.datadir = "data/multiscale" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ### NeRF Standard Specific Arguments 9 | 10 | NeRF.num_coarse_samples = 64 11 | NeRF.num_fine_samples = 128 12 | 13 | LitData.batch_size = 4096 14 | LitData.chunk = 8192 15 | LitData.use_pixel_centers = True 16 | LitData.epoch_size = 250000 17 | LitDataBlenderMultiScale.white_bkgd = True 18 | 19 | run.max_steps = 1000000 20 | run.log_every_n_steps = 100 21 | run.model_name = "nerf" -------------------------------------------------------------------------------- /configs/nerf/lf.gin: -------------------------------------------------------------------------------- 1 | ### LF Specific Arguments 2 | 3 | run.dataset_name = "lf" 4 | run.datadir = "data/lf_data" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataLF.cam_scale_factor = 0.125 8 | 9 | ### NeRF Standard Specific Arguments 10 | 11 | NeRF.num_coarse_samples = 128 12 | NeRF.num_fine_samples = 256 13 | NeRF.noise_std = 1. 14 | 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerf" 23 | -------------------------------------------------------------------------------- /configs/nerf/llff.gin: -------------------------------------------------------------------------------- 1 | ### LLFF Specific Arguments 2 | 3 | run.dataset_name = "llff" 4 | run.datadir = "data/llff" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitData.ndc_coord = True 8 | 9 | ### NeRF Standard Specific Arguments 10 | 11 | NeRF.num_coarse_samples = 64 12 | NeRF.num_fine_samples = 128 13 | NeRF.noise_std = 1. 
14 | 15 | LitData.batch_size = 4096 16 | LitData.chunk = 16384 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerf" -------------------------------------------------------------------------------- /configs/nerf/shiny_blender.gin: -------------------------------------------------------------------------------- 1 | ### Shiny Blender Specific Arguments 2 | 3 | run.dataset_name = "shiny_blender" 4 | run.datadir = "data/refnerf" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataShinyBlender.white_bkgd = True 8 | 9 | ### NeRF Standard Specific Arguments 10 | 11 | NeRF.num_coarse_samples = 64 12 | NeRF.num_fine_samples = 128 13 | 14 | LitData.batch_size = 4096 15 | LitData.chunk = 8192 16 | LitData.use_pixel_centers = True 17 | LitData.epoch_size = 250000 18 | LitDataBlender.white_bkgd = True 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerf" -------------------------------------------------------------------------------- /configs/nerf/tnt.gin: -------------------------------------------------------------------------------- 1 | ### TnT Specific Arguments 2 | 3 | run.dataset_name = "tanks_and_temples" 4 | run.datadir = "data/tanks_and_temples" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataTnT.cam_scale_factor = 0.125 8 | 9 | ### NeRF Standard Specific Arguments 10 | 11 | NeRF.num_coarse_samples = 128 12 | NeRF.num_fine_samples = 256 13 | NeRF.noise_std = 1. 14 | 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerf" 23 | -------------------------------------------------------------------------------- /configs/nerfpp/360_v2.gin: -------------------------------------------------------------------------------- 1 | ### 360-v2 Specific Arguments 2 | 3 | run.dataset_name = "nerf_360_v2" 4 | run.datadir = "data/nerf_360_v2" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ## NeRF++ Standard Specific Arguments 9 | 10 | run.model_name = "nerfpp" 11 | 12 | run.max_steps = 1000000 13 | run.log_every_n_steps = 100 14 | 15 | LitData.load_radii = True 16 | LitData.batch_size = 4096 17 | LitData.chunk = 8192 18 | LitData.use_pixel_centers = True 19 | LitData.epoch_size = 250000 20 | 21 | NeRFPP.density_noise = 1. 22 | LitDataNeRF360V2.strict_scaling = True 23 | 24 | -------------------------------------------------------------------------------- /configs/nerfpp/lf.gin: -------------------------------------------------------------------------------- 1 | ### LF Specific Arguments 2 | 3 | run.dataset_name = "lf" 4 | run.datadir = "data/lf_data" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataLF.cam_scale_factor = 0 8 | 9 | ### NeRF++ Standard Specific Arguments 10 | 11 | NeRFPP.num_coarse_samples = 64 12 | NeRFPP.num_fine_samples = 128 13 | NeRFPP.density_noise = 1. 
14 | 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerfpp" 23 | -------------------------------------------------------------------------------- /configs/nerfpp/tnt.gin: -------------------------------------------------------------------------------- 1 | ### TnT Specific Arguments 2 | 3 | run.dataset_name = "tanks_and_temples" 4 | run.datadir = "data/tanks_and_temples" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataTnT.cam_scale_factor = 0 8 | 9 | ### NeRF++ Standard Specific Arguments 10 | 11 | NeRFPP.num_coarse_samples = 64 12 | NeRFPP.num_fine_samples = 128 13 | NeRFPP.density_noise = 1. 14 | 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | run.max_steps = 1000000 21 | run.log_every_n_steps = 100 22 | run.model_name = "nerfpp" 23 | -------------------------------------------------------------------------------- /configs/plenoxel/360_v2.gin: -------------------------------------------------------------------------------- 1 | 2 | ### TNT Specific 3 | 4 | run.dataset_name = "nerf_360_v2" 5 | run.datadir = "data/nerf_360_v2" 6 | 7 | LitPlenoxel.reso = [[128, 128, 128], [256, 256, 256], [512, 512, 512]] 8 | LitPlenoxel.background_nlayers = 64 9 | LitPlenoxel.background_reso = 1024 10 | LitPlenoxel.lr_sigma = 3.0e+1 11 | LitPlenoxel.lr_sh = 1.0e-2 12 | LitPlenoxel.lr_sigma_delay_steps = 15000 13 | LitPlenoxel.weight_thresh = 1.28 14 | LitPlenoxel.thresh_type = "weight" 15 | LitPlenoxel.lambda_tv = 5.0e-3 16 | LitPlenoxel.lambda_tv_sh = 5.0e-3 17 | LitPlenoxel.lambda_tv_background_color = 1.0e-3 18 | LitPlenoxel.lambda_tv_background_color = 1.0e-3 19 | LitPlenoxel.lambda_beta = 1.0e-5 20 | LitPlenoxel.lambda_sparsity = 1.0e-11 21 | LitPlenoxel.background_brightness = 0.5 22 | LitPlenoxel.tv_early_only = 0 23 | LitPlenoxel.lr_fg_begin_step = 1000 24 | LitPlenoxel.near_clip = 0.35 25 | 26 | ResampleCallBack.upsamp_every = 25600 27 | run.max_steps = 76800 28 | 29 | ### Plenoxel Specific 30 | 31 | LitData.batch_sampler = "all_images" 32 | LitData.epoch_size = 25600 33 | LitData.batch_size = 5000 34 | LitData.chunk = 8192 35 | LitData.use_pixel_centers = True 36 | 37 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/plenoxel/blender.gin: -------------------------------------------------------------------------------- 1 | 2 | ### Blender Specific 3 | 4 | LitData.batch_sampler = "all_images" 5 | 6 | LitPlenoxel.reso = [[256, 256, 256], [512, 512, 512]] 7 | LitPlenoxel.lr_sigma = 3.0e+1 8 | LitPlenoxel.lr_sh = 1.0e-2 9 | LitPlenoxel.lambda_tv = 1.0e-5 10 | LitPlenoxel.lambda_tv_sh = 1.0e-3 11 | 12 | ResampleCallBack.upsamp_every = 38400 13 | LitDataBlender.cam_scale_factor = 0.6666 14 | run.dataset_name = "blender" 15 | run.datadir = "data/blender" 16 | run.logbase = "logs" 17 | run.max_steps = 153600 18 | 19 | ### Plenoxel Specific 20 | 21 | LitData.batch_sampler = "all_images" 22 | LitData.epoch_size = 38400 23 | LitData.batch_size = 5000 24 | LitData.chunk = 8192 25 | LitData.use_pixel_centers = True 26 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/plenoxel/blender_ms.gin: -------------------------------------------------------------------------------- 1 | 2 | ### Blender Specific 3 | 4 | 
LitData.batch_sampler = "all_images" 5 | 6 | LitPlenoxel.reso = [[256, 256, 256], [512, 512, 512]] 7 | LitPlenoxel.lr_sigma = 3.0e+1 8 | LitPlenoxel.lr_sh = 1.0e-2 9 | LitPlenoxel.lambda_tv = 1.0e-5 10 | LitPlenoxel.lambda_tv_sh = 1.0e-3 11 | 12 | ResampleCallBack.upsamp_every = 38400 13 | LitDataBlenderMultiScale.cam_scale_factor = 0.6666 14 | run.dataset_name = "blender_multiscale" 15 | run.datadir = "data/multiscale" 16 | run.max_steps = 153600 17 | run.log_every_n_steps = 100 18 | run.progressbar_refresh_rate = 100 19 | 20 | ### Plenoxel Specific 21 | 22 | LitData.batch_sampler = "all_images" 23 | LitData.epoch_size = 38400 24 | LitData.batch_size = 5000 25 | LitData.chunk = 8192 26 | LitData.use_pixel_centers = True 27 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/plenoxel/lf.gin: -------------------------------------------------------------------------------- 1 | 2 | ### LF Specific 3 | 4 | run.dataset_name = "lf" 5 | run.datadir = "data/lf_data" 6 | 7 | LitPlenoxel.reso = [[128, 128, 128], [256, 256, 256], [512, 512, 512]] 8 | LitPlenoxel.background_nlayers = 64 9 | LitPlenoxel.background_reso = 1024 10 | LitPlenoxel.lr_sigma = 3.0e+1 11 | LitPlenoxel.lr_sh = 1.0e-2 12 | LitPlenoxel.lr_sigma_delay_steps = 15000 13 | LitPlenoxel.weight_thresh = 1.28 14 | LitPlenoxel.thresh_type = "weight" 15 | LitPlenoxel.lambda_tv = 5.0e-3 16 | LitPlenoxel.lambda_tv_sh = 5.0e-3 17 | LitPlenoxel.lambda_tv_background_color = 1.0e-3 18 | LitPlenoxel.lambda_tv_background_color = 1.0e-3 19 | LitPlenoxel.lambda_beta = 1.0e-5 20 | LitPlenoxel.lambda_sparsity = 1.0e-11 21 | LitPlenoxel.background_brightness = 0.5 22 | LitPlenoxel.tv_early_only = 0 23 | LitPlenoxel.lr_fg_begin_step = 1000 24 | LitPlenoxel.near_clip = 0.35 25 | 26 | ResampleCallBack.upsamp_every = 25600 27 | run.max_steps = 153600 28 | 29 | ### Plenoxel Specific 30 | 31 | LitData.batch_sampler = "all_images" 32 | LitData.epoch_size = 38400 33 | LitData.batch_size = 5000 34 | LitData.chunk = 8192 35 | LitData.use_pixel_centers = True 36 | 37 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/plenoxel/llff.gin: -------------------------------------------------------------------------------- 1 | 2 | ### LLFF Specific 3 | 4 | run.dataset_name = "llff" 5 | run.datadir = "data/llff" 6 | run.logbase = "logs" 7 | 8 | LitData.ndc_coord = True 9 | 10 | LitPlenoxel.reso = [[256, 256, 256], [512, 512, 512], [1408, 1156, 128]] 11 | LitPlenoxel.lr_sigma = 3.0e+1 12 | LitPlenoxel.lr_sh = 1.0e-2 13 | LitPlenoxel.lambda_tv = 5.0e-4 14 | LitPlenoxel.density_thresh = 5.0 15 | LitPlenoxel.thresh_type = "sigma" 16 | LitPlenoxel.lambda_tv_sh = 5.0e-3 17 | LitPlenoxel.lambda_sparsity = 1.0e-12 18 | LitPlenoxel.background_brightness = 0.5 19 | LitPlenoxel.tv_early_only = 0 20 | LitPlenoxel.last_sample_opaque = False 21 | 22 | ResampleCallBack.upsamp_every = 38400 23 | run.max_steps = 153600 24 | 25 | ### Plenoxel Specific Arguments 26 | 27 | LitData.batch_sampler = "all_images" 28 | LitData.epoch_size = 38400 29 | LitData.batch_size = 5000 30 | LitData.chunk = 8192 31 | LitData.use_pixel_centers = True 32 | 33 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/plenoxel/shiny_blender.gin: -------------------------------------------------------------------------------- 1 | 2 | ### Blender Specific 3 | 4 | LitData.batch_sampler = "all_images" 5 | 6 | 
LitPlenoxel.reso = [[256, 256, 256], [512, 512, 512]] 7 | LitPlenoxel.lr_sigma = 3.0e+1 8 | LitPlenoxel.lr_sh = 1.0e-2 9 | LitPlenoxel.lambda_tv = 1.0e-5 10 | LitPlenoxel.lambda_tv_sh = 1.0e-3 11 | 12 | ResampleCallBack.upsamp_every = 38400 13 | LitDataShinyBlender.cam_scale_factor = 0.6666 14 | run.dataset_name = "shiny_blender" 15 | run.datadir = "data/refnerf_shinyblender" 16 | run.logbase = "logs" 17 | run.max_steps = 153600 18 | 19 | ### Plenoxel Specific 20 | 21 | LitData.batch_sampler = "all_images" 22 | LitData.epoch_size = 38400 23 | LitData.batch_size = 5000 24 | LitData.chunk = 8192 25 | LitData.use_pixel_centers = True 26 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/plenoxel/tnt.gin: -------------------------------------------------------------------------------- 1 | 2 | ### TNT Specific 3 | 4 | run.dataset_name = "tanks_and_temples" 5 | run.datadir = "data/tanks_and_temples" 6 | 7 | LitPlenoxel.reso = [[128, 128, 128], [256, 256, 256], [512, 512, 512], [640, 640, 640]] 8 | LitPlenoxel.background_nlayers = 64 9 | LitPlenoxel.background_reso = 1024 10 | LitPlenoxel.lr_sigma = 3.0e+1 11 | LitPlenoxel.lr_sh = 1.0e-2 12 | LitPlenoxel.lr_sigma_delay_steps = 15000 13 | LitPlenoxel.weight_thresh = 1.28 14 | LitPlenoxel.thresh_type = "weight" 15 | LitPlenoxel.lambda_tv = 5.0e-5 16 | LitPlenoxel.lambda_tv_sh = 5.0e-3 17 | LitPlenoxel.lambda_tv_background_color = 1.0e-3 18 | LitPlenoxel.lambda_tv_background_color = 1.0e-3 19 | LitPlenoxel.lambda_beta = 1.0e-5 20 | LitPlenoxel.lambda_sparsity = 1.0e-11 21 | LitPlenoxel.background_brightness = 1.0 22 | LitPlenoxel.tv_early_only = 0 23 | 24 | ResampleCallBack.upsamp_every = 38400 25 | run.max_steps = 153600 26 | 27 | ### Plenoxel Specific 28 | 29 | LitData.batch_sampler = "all_images" 30 | LitData.epoch_size = 38400 31 | LitData.batch_size = 5000 32 | LitData.chunk = 8192 33 | LitData.use_pixel_centers = True 34 | 35 | run.model_name = "plenoxel" -------------------------------------------------------------------------------- /configs/refnerf/blender.gin: -------------------------------------------------------------------------------- 1 | ### Blender Specific Arguments 2 | 3 | run.dataset_name = "blender" 4 | run.datadir = "data/blender" 5 | 6 | LitData.batch_sampler = "single_image" 7 | LitDataBlender.white_bkgd = True 8 | 9 | ### RefNeRF Standard Specific Arguments 10 | 11 | run.model_name = "refnerf" 12 | run.max_steps = 1000000 13 | run.log_every_n_steps = 100 14 | 15 | LitData.load_radii = True 16 | LitData.batch_size = 4096 17 | LitData.chunk = 8192 18 | LitData.use_pixel_centers = True 19 | LitData.epoch_size = 250000 20 | 21 | LitRefNeRF.compute_normal_metrics = False -------------------------------------------------------------------------------- /configs/refnerf/refnerf_real.gin: -------------------------------------------------------------------------------- 1 | ### RefNeRF Real Specific Arguments 2 | 3 | run.dataset_name = "refnerf_real" 4 | run.datadir = "data/ref_real" 5 | 6 | LitData.batch_sampler = "all_images" 7 | 8 | ### RefNeRF Standard Specific Arguments 9 | 10 | run.model_name = "refnerf" 11 | run.max_steps = 1000000 12 | run.log_every_n_steps = 100 13 | 14 | LitData.load_radii = True 15 | LitData.batch_size = 4096 16 | LitData.chunk = 8192 17 | LitData.use_pixel_centers = True 18 | LitData.epoch_size = 250000 19 | 20 | RefNeRFMLP.density_noise = 1. 
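All of these configs select their implementation purely through the `run.model_name` and `run.dataset_name` strings; `utils/select_option.py` maps those strings to the `Lit*` classes bound above. A rough, hypothetical sketch of that dispatch (the real function's signature and the exact class locations may differ):

```python
# Hypothetical sketch of the dispatch implied by run.model_name;
# the actual implementation lives in utils/select_option.py.
def select_model(model_name: str):
    if model_name == "refnerf":
        from src.model.refnerf.model import LitRefNeRF
        return LitRefNeRF
    if model_name == "plenoxel":
        from src.model.plenoxel.model import LitPlenoxel
        return LitPlenoxel
    raise ValueError(f"unknown model: {model_name}")
```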
-------------------------------------------------------------------------------- /configs/refnerf/shiny_blender.gin: -------------------------------------------------------------------------------- 1 | ### Shiny Blender Specific Arguments 2 | 3 | run.dataset_name = "shiny_blender" 4 | run.datadir = "data/refnerf_shinyblender" 5 | 6 | LitData.batch_sampler = "all_images" 7 | LitDataShinyBlender.white_bkgd = True 8 | 9 | ### RefNeRF Standard Specific Arguments 10 | 11 | run.model_name = "refnerf" 12 | run.max_steps = 1000000 13 | run.log_every_n_steps = 100 14 | 15 | LitData.load_radii = True 16 | LitData.batch_size = 4096 17 | LitData.chunk = 8192 18 | LitData.use_pixel_centers = True 19 | LitData.epoch_size = 250000 20 | 21 | LitRefNeRF.compute_normal_metrics = True -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/nerf-factory/b8f152ac279010106f32c2b9d7d4f009781ca5b2/data/.gitkeep -------------------------------------------------------------------------------- /lib/dvgo/adam_upd.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void adam_upd_cuda( 8 | torch::Tensor param, 9 | torch::Tensor grad, 10 | torch::Tensor exp_avg, 11 | torch::Tensor exp_avg_sq, 12 | int step, float beta1, float beta2, float lr, float eps); 13 | 14 | void masked_adam_upd_cuda( 15 | torch::Tensor param, 16 | torch::Tensor grad, 17 | torch::Tensor exp_avg, 18 | torch::Tensor exp_avg_sq, 19 | int step, float beta1, float beta2, float lr, float eps); 20 | 21 | void adam_upd_with_perlr_cuda( 22 | torch::Tensor param, 23 | torch::Tensor grad, 24 | torch::Tensor exp_avg, 25 | torch::Tensor exp_avg_sq, 26 | torch::Tensor perlr, 27 | int step, float beta1, float beta2, float lr, float eps); 28 | 29 | 30 | // C++ interface 31 | 32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 35 | 36 | void adam_upd( 37 | torch::Tensor param, 38 | torch::Tensor grad, 39 | torch::Tensor exp_avg, 40 | torch::Tensor exp_avg_sq, 41 | int step, float beta1, float beta2, float lr, float eps) { 42 | CHECK_INPUT(param); 43 | CHECK_INPUT(grad); 44 | CHECK_INPUT(exp_avg); 45 | CHECK_INPUT(exp_avg_sq); 46 | adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, 47 | step, beta1, beta2, lr, eps); 48 | } 49 | 50 | void masked_adam_upd( 51 | torch::Tensor param, 52 | torch::Tensor grad, 53 | torch::Tensor exp_avg, 54 | torch::Tensor exp_avg_sq, 55 | int step, float beta1, float beta2, float lr, float eps) { 56 | CHECK_INPUT(param); 57 | CHECK_INPUT(grad); 58 | CHECK_INPUT(exp_avg); 59 | CHECK_INPUT(exp_avg_sq); 60 | masked_adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, 61 | step, beta1, beta2, lr, eps); 62 | } 63 | 64 | void adam_upd_with_perlr( 65 | torch::Tensor param, 66 | torch::Tensor grad, 67 | torch::Tensor exp_avg, 68 | torch::Tensor exp_avg_sq, 69 | torch::Tensor perlr, 70 | int step, float beta1, float beta2, float lr, float eps) { 71 | CHECK_INPUT(param); 72 | CHECK_INPUT(grad); 73 | CHECK_INPUT(exp_avg); 74 | CHECK_INPUT(exp_avg_sq); 75 | adam_upd_with_perlr_cuda(param, grad, exp_avg, exp_avg_sq, perlr, 76 | step, beta1, beta2, lr, eps); 77 | } 78 | 79 | 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 80 | m.def("adam_upd", &adam_upd, 81 | "Adam update"); 82 | m.def("masked_adam_upd", &masked_adam_upd, 83 | "Adam update ignoring zero grad"); 84 | m.def("adam_upd_with_perlr", &adam_upd_with_perlr, 85 | "Adam update ignoring zero grad with per-voxel lr"); 86 | } 87 | 88 | -------------------------------------------------------------------------------- /lib/dvgo/adam_upd_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | template <typename scalar_t> 9 | __global__ void adam_upd_cuda_kernel( 10 | scalar_t* __restrict__ param, 11 | const scalar_t* __restrict__ grad, 12 | scalar_t* __restrict__ exp_avg, 13 | scalar_t* __restrict__ exp_avg_sq, 14 | const size_t N, 15 | const float step_size, const float beta1, const float beta2, const float eps) { 16 | 17 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 18 | if(index<N) { 19 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 20 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 21 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 22 | } 23 | } 24 | 25 | template <typename scalar_t> 26 | __global__ void masked_adam_upd_cuda_kernel( 27 | scalar_t* __restrict__ param, 28 | const scalar_t* __restrict__ grad, 29 | scalar_t* __restrict__ exp_avg, 30 | scalar_t* __restrict__ exp_avg_sq, 31 | const size_t N, 32 | const float step_size, const float beta1, const float beta2, const float eps) { 33 | 34 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 35 | if(index<N && grad[index]!=0) { 36 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 37 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 38 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 39 | } 40 | } 41 | 42 | template <typename scalar_t> 43 | __global__ void adam_upd_with_perlr_cuda_kernel( 44 | scalar_t* __restrict__ param, 45 | const scalar_t* __restrict__ grad, 46 | scalar_t* __restrict__ exp_avg, 47 | scalar_t* __restrict__ exp_avg_sq, 48 | scalar_t* __restrict__ perlr, 49 | const size_t N, 50 | const float step_size, const float beta1, const float beta2, const float eps) { 51 | 52 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 53 | if(index<N) { 54 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 55 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 56 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps) * perlr[index]; 57 | } 58 | } 59 | 60 | void adam_upd_cuda( 61 | torch::Tensor param, 62 | torch::Tensor grad, 63 | torch::Tensor exp_avg, 64 | torch::Tensor exp_avg_sq, 65 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 66 | 67 | const size_t N = param.numel(); 68 | 69 | const int threads = 256; 70 | const int blocks = (N + threads - 1) / threads; 71 | 72 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 73 | 74 | AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_cuda", ([&] { 75 | adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>( 76 | param.data<scalar_t>(), 77 | grad.data<scalar_t>(), 78 | exp_avg.data<scalar_t>(), 79 | exp_avg_sq.data<scalar_t>(), 80 | N, step_size, beta1, beta2, eps); 81 | })); 82 | } 83 | 84 | void masked_adam_upd_cuda( 85 | torch::Tensor param, 86 | torch::Tensor grad, 87 | torch::Tensor exp_avg, 88 | torch::Tensor exp_avg_sq, 89 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 90 | 91 | const size_t N = param.numel(); 92 | 93 | const int threads = 256; 94 | const int blocks = (N + threads - 1) / threads; 95 | 96 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 97 | 98 | AT_DISPATCH_FLOATING_TYPES(param.type(), "masked_adam_upd_cuda", ([&] { 99 | masked_adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>( 100 | param.data<scalar_t>(), 101 | grad.data<scalar_t>(), 102 | exp_avg.data<scalar_t>(), 103 | exp_avg_sq.data<scalar_t>(), 104 | N, step_size, beta1, beta2, eps); 105 | })); 106 | } 107 | 108 | void adam_upd_with_perlr_cuda( 109 | torch::Tensor param, 110 | torch::Tensor grad, 111 | torch::Tensor exp_avg, 112 | torch::Tensor exp_avg_sq, 113 | torch::Tensor perlr, 114 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 115 | 116 | const size_t N = param.numel(); 117 | 118 | const int threads = 256; 119 | const int blocks = (N + threads - 1) / threads; 120 | 121 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 122 | 123 | AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_with_perlr_cuda", ([&] { 124 | adam_upd_with_perlr_cuda_kernel<scalar_t><<<blocks, threads>>>( 125 | param.data<scalar_t>(), 126 | grad.data<scalar_t>(), 127 | exp_avg.data<scalar_t>(), 128 | exp_avg_sq.data<scalar_t>(), 129 | perlr.data<scalar_t>(), 130 | N, step_size, beta1, beta2, eps); 131 | })); 132 | } 133 | 134 | 
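The bindings above are consumed from Python via PyTorch's JIT extension loader. Below is a minimal, hedged sketch of compiling and calling them, in the spirit of `src/model/dvgo/masked_adam.py`; the extension name and the exact call site are illustrative assumptions, not the repository's code.

```python
import torch
from torch.utils.cpp_extension import load

# JIT-compile the C++/CUDA sources above into an importable module.
adam_upd = load(
    name="adam_upd_cuda_ext",  # illustrative name
    sources=["lib/dvgo/adam_upd.cpp", "lib/dvgo/adam_upd_kernel.cu"],
    verbose=True,
)

# All three ops update param / exp_avg / exp_avg_sq in place.
param = torch.zeros(1024, device="cuda")
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)      # first-moment Adam state
exp_avg_sq = torch.zeros_like(param)   # second-moment Adam state

step, beta1, beta2, lr, eps = 1, 0.9, 0.99, 1e-1, 1e-8
adam_upd.masked_adam_upd(param, grad, exp_avg, exp_avg_sq,
                         step, beta1, beta2, lr, eps)
```

Note that the ops return nothing: the kernels write `param`, `exp_avg`, and `exp_avg_sq` in place, and `masked_adam_upd` additionally skips entries whose gradient is exactly zero, which keeps the update cheap for sparse voxel grids.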
-------------------------------------------------------------------------------- /lib/dvgo/cuda/adam_upd.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void adam_upd_cuda( 8 | torch::Tensor param, 9 | torch::Tensor grad, 10 | torch::Tensor exp_avg, 11 | torch::Tensor exp_avg_sq, 12 | int step, float beta1, float beta2, float lr, float eps); 13 | 14 | void masked_adam_upd_cuda( 15 | torch::Tensor param, 16 | torch::Tensor grad, 17 | torch::Tensor exp_avg, 18 | torch::Tensor exp_avg_sq, 19 | int step, float beta1, float beta2, float lr, float eps); 20 | 21 | void adam_upd_with_perlr_cuda( 22 | torch::Tensor param, 23 | torch::Tensor grad, 24 | torch::Tensor exp_avg, 25 | torch::Tensor exp_avg_sq, 26 | torch::Tensor perlr, 27 | int step, float beta1, float beta2, float lr, float eps); 28 | 29 | 30 | // C++ interface 31 | 32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 35 | 36 | void adam_upd( 37 | torch::Tensor param, 38 | torch::Tensor grad, 39 | torch::Tensor exp_avg, 40 | torch::Tensor exp_avg_sq, 41 | int step, float beta1, float beta2, float lr, float eps) { 42 | CHECK_INPUT(param); 43 | CHECK_INPUT(grad); 44 | CHECK_INPUT(exp_avg); 45 | CHECK_INPUT(exp_avg_sq); 46 | adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, 47 | step, beta1, beta2, lr, eps); 48 | } 49 | 50 | void masked_adam_upd( 51 | torch::Tensor param, 52 | torch::Tensor grad, 53 | torch::Tensor exp_avg, 54 | torch::Tensor exp_avg_sq, 55 | int step, float beta1, float beta2, float lr, float eps) { 56 | CHECK_INPUT(param); 57 | CHECK_INPUT(grad); 58 | CHECK_INPUT(exp_avg); 59 | CHECK_INPUT(exp_avg_sq); 60 | masked_adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, 61 | step, beta1, beta2, lr, eps); 62 | } 63 | 64 | void adam_upd_with_perlr( 65 | torch::Tensor param, 66 | torch::Tensor grad, 67 | torch::Tensor exp_avg, 68 | torch::Tensor exp_avg_sq, 69 | torch::Tensor perlr, 70 | int step, float beta1, float beta2, float lr, float eps) { 71 | CHECK_INPUT(param); 72 | CHECK_INPUT(grad); 73 | CHECK_INPUT(exp_avg); 74 | CHECK_INPUT(exp_avg_sq); 75 | adam_upd_with_perlr_cuda(param, grad, exp_avg, exp_avg_sq, perlr, 76 | step, beta1, beta2, lr, eps); 77 | } 78 | 79 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 80 | m.def("adam_upd", &adam_upd, 81 | "Adam update"); 82 | m.def("masked_adam_upd", &masked_adam_upd, 83 | "Adam update ignoring zero grad"); 84 | m.def("adam_upd_with_perlr", &adam_upd_with_perlr, 85 | "Adam update ignoring zero grad with per-voxel lr"); 86 | } 87 | 88 |
-------------------------------------------------------------------------------- /lib/dvgo/cuda/adam_upd_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | template <typename scalar_t> 9 | __global__ void adam_upd_cuda_kernel( 10 | scalar_t* __restrict__ param, 11 | const scalar_t* __restrict__ grad, 12 | scalar_t* __restrict__ exp_avg, 13 | scalar_t* __restrict__ exp_avg_sq, 14 | const size_t N, 15 | const float step_size, const float beta1, const float beta2, const float eps) { 16 | 17 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 18 | if(index<N) { 19 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 20 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 21 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 22 | } 23 | } 24 | 25 | template <typename scalar_t> 26 | __global__ void masked_adam_upd_cuda_kernel( 27 | scalar_t* __restrict__ param, 28 | const scalar_t* __restrict__ grad, 29 | scalar_t* __restrict__ exp_avg, 30 | scalar_t* __restrict__ exp_avg_sq, 31 | const size_t N, 32 | const float step_size, const float beta1, const float beta2, const float eps) { 33 | 34 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 35 | if(index<N && grad[index]!=0) { 36 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 37 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 38 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 39 | } 40 | } 41 | 42 | template <typename scalar_t> 43 | __global__ void adam_upd_with_perlr_cuda_kernel( 44 | scalar_t* __restrict__ param, 45 | const scalar_t* __restrict__ grad, 46 | scalar_t* __restrict__ exp_avg, 47 | scalar_t* __restrict__ exp_avg_sq, 48 | scalar_t* __restrict__ perlr, 49 | const size_t N, 50 | const float step_size, const float beta1, const float beta2, const float eps) { 51 | 52 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 53 | if(index<N) { 54 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 55 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 56 | param[index] -= step_size * perlr[index] * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 57 | } 58 | } 59 | 60 | void adam_upd_cuda( 61 | torch::Tensor param, 62 | torch::Tensor grad, 63 | torch::Tensor exp_avg, 64 | torch::Tensor exp_avg_sq, 65 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 66 | 67 | const size_t N = param.numel(); 68 | 69 | const int threads = 256; 70 | const int blocks = (N + threads - 1) / threads; 71 | 72 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 73 | 74 | AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_cuda", ([&] { 75 | adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>( 76 | param.data<scalar_t>(), 77 | grad.data<scalar_t>(), 78 | exp_avg.data<scalar_t>(), 79 | exp_avg_sq.data<scalar_t>(), 80 | N, step_size, beta1, beta2, eps); 81 | })); 82 | } 83 | 84 | void masked_adam_upd_cuda( 85 | torch::Tensor param, 86 | torch::Tensor grad, 87 | torch::Tensor exp_avg, 88 | torch::Tensor exp_avg_sq, 89 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 90 | 91 | const size_t N = param.numel(); 92 | 93 | const int threads = 256; 94 | const int blocks = (N + threads - 1) / threads; 95 | 96 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 97 | 98 | AT_DISPATCH_FLOATING_TYPES(param.type(), "masked_adam_upd_cuda", ([&] { 99 | masked_adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>( 100 | param.data<scalar_t>(), 101 | grad.data<scalar_t>(), 102 | exp_avg.data<scalar_t>(), 103 | exp_avg_sq.data<scalar_t>(), 104 | N, step_size, beta1, beta2, eps); 105 | })); 106 | } 107 | 108 | void adam_upd_with_perlr_cuda( 109 | torch::Tensor param, 110 | torch::Tensor grad, 111 | torch::Tensor exp_avg, 112 | torch::Tensor exp_avg_sq, 113 | torch::Tensor perlr, 114 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 115 | 116 | const size_t N = param.numel(); 117 | 118 | const int threads = 256; 119 | const int blocks = (N + threads - 1) / threads; 120 | 121 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 122 | 123 | AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_with_perlr_cuda", ([&] { 124 | adam_upd_with_perlr_cuda_kernel<scalar_t><<<blocks, threads>>>( 125 | param.data<scalar_t>(), 126 | grad.data<scalar_t>(), 127 | exp_avg.data<scalar_t>(), 128 | exp_avg_sq.data<scalar_t>(), 129 | perlr.data<scalar_t>(), 130 | N, step_size, beta1, beta2, eps); 131 | })); 132 | } 133 | 134 |
-------------------------------------------------------------------------------- /lib/dvgo/cuda/total_variation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode); 8 | 9 | 10 | // C++ interface 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | 16 | void total_variation_add_grad(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { 17 | CHECK_INPUT(param); 18 | CHECK_INPUT(grad); 19 | total_variation_add_grad_cuda(param, grad, wx, wy, wz, dense_mode); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("total_variation_add_grad", &total_variation_add_grad, "Add total variation grad"); 24 | } 25 | 26 |
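For intuition, total_variation_add_grad accumulates, for every voxel, the gradient of a clamped total-variation penalty: the difference to each existing neighbor is clamped to [-1, 1], weighted per axis, and added into the existing gradient buffer. A rough 1-D CPU reference (illustration only; the function name tv_add_grad_1d is mine, and the CUDA kernel reconstructed below applies the same rule along all three axes after rescaling the weights):

#include <algorithm>
#include <vector>

// Add the TV-penalty gradient for one axis into grad, mirroring the
// per-voxel rule of the CUDA kernel: sum of clamped neighbor differences.
void tv_add_grad_1d(const std::vector<float>& param,
                    std::vector<float>& grad, float w) {
    const int n = (int)param.size();
    for (int i = 0; i < n; ++i) {
        float g = 0.f;
        if (i > 0)
            g += w * std::clamp(param[i] - param[i - 1], -1.f, 1.f);
        if (i + 1 < n)
            g += w * std::clamp(param[i] - param[i + 1], -1.f, 1.f);
        grad[i] += g;  // accumulated, not overwritten, as in the kernel
    }
}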
-------------------------------------------------------------------------------- /lib/dvgo/cuda/total_variation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | template <typename scalar_t, typename bound_t> 9 | __device__ __forceinline__ scalar_t clamp(const scalar_t v, const bound_t lo, const bound_t hi) { 10 | return min(max(v, lo), hi); 11 | } 12 | 13 | template <typename scalar_t, bool dense_mode> 14 | __global__ void total_variation_add_grad_cuda_kernel( 15 | const scalar_t* __restrict__ param, 16 | scalar_t* __restrict__ grad, 17 | float wx, float wy, float wz, 18 | const size_t sz_i, const size_t sz_j, const size_t sz_k, const size_t N) { 19 | 20 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 21 | if(index<N && (dense_mode || grad[index]!=0)) { 22 | const size_t k = index % sz_k; 23 | const size_t j = index / sz_k % sz_j; 24 | const size_t i = index / sz_k / sz_j % sz_i; 25 | 26 | float grad_to_add = 0; 27 | grad_to_add += (k==0 ? 0 : wz * clamp(param[index]-param[index-1], -1.f, 1.f)); 28 | grad_to_add += (k==sz_k-1 ? 0 : wz * clamp(param[index]-param[index+1], -1.f, 1.f)); 29 | grad_to_add += (j==0 ? 0 : wy * clamp(param[index]-param[index-sz_k], -1.f, 1.f)); 30 | grad_to_add += (j==sz_j-1 ? 0 : wy * clamp(param[index]-param[index+sz_k], -1.f, 1.f)); 31 | grad_to_add += (i==0 ? 0 : wx * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f)); 32 | grad_to_add += (i==sz_i-1 ? 0 : wx * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f)); 33 | grad[index] += grad_to_add; 34 | } 35 | } 36 | 37 | void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { 38 | const size_t N = param.numel(); 39 | const size_t sz_i = param.size(2); 40 | const size_t sz_j = param.size(3); 41 | const size_t sz_k = param.size(4); 42 | const int threads = 256; 43 | const int blocks = (N + threads - 1) / threads; 44 | 45 | wx /= 6; 46 | wy /= 6; 47 | wz /= 6; 48 | 49 | if(dense_mode) { 50 | AT_DISPATCH_FLOATING_TYPES(param.type(), "total_variation_add_grad_cuda", ([&] { 51 | total_variation_add_grad_cuda_kernel<scalar_t, true><<<blocks, threads>>>( 52 | param.data<scalar_t>(), 53 | grad.data<scalar_t>(), 54 | wx, wy, wz, 55 | sz_i, sz_j, sz_k, N); 56 | })); 57 | } 58 | else { 59 | AT_DISPATCH_FLOATING_TYPES(param.type(), "total_variation_add_grad_cuda", ([&] { 60 | total_variation_add_grad_cuda_kernel<scalar_t, false><<<blocks, threads>>>( 61 | param.data<scalar_t>(), 62 | grad.data<scalar_t>(), 63 | wx, wy, wz, 64 | sz_i, sz_j, sz_k, N); 65 | })); 66 | } 67 | } 68 | 69 |
-------------------------------------------------------------------------------- /lib/dvgo/cuda/ub360_utils.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres); 8 | 9 | // C++ interface 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor cumdist_thres(torch::Tensor dist, float thres) { 16 | CHECK_INPUT(dist); 17 | return cumdist_thres_cuda(dist, thres); 18 | } 19 | 20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 21 | m.def("cumdist_thres", &cumdist_thres, "Generate mask for cumulative dist."); 22 | } 23 | 24 |
-------------------------------------------------------------------------------- /lib/dvgo/cuda/ub360_utils_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | /* 9 | helper function to skip oversampled points, 10 | especially near the foreground scene bbox boundary 11 | */ 12 | template <typename scalar_t> 13 | __global__ void cumdist_thres_cuda_kernel( 14 | scalar_t* __restrict__ dist, 15 | const float thres, 16 | const int n_rays, 17 | const int n_pts, 18 | bool* __restrict__ mask) { 19 | const int i_ray = blockIdx.x * blockDim.x + threadIdx.x; 20 | if(i_ray<n_rays) { 21 | float cum_dist = 0; 22 | bool over = false; 23 | 24 | for(int i=i_ray*n_pts, i_end=i+n_pts; i<i_end; ++i) { 25 | cum_dist += dist[i]; 26 | 27 | over = over || (cum_dist > thres); 28 | cum_dist *= float(!over); 29 | mask[i] = over; 30 | } 31 | } 32 | } 33 | 34 | torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres) { 35 | const int n_rays = dist.size(0); 36 | const int n_pts = dist.size(1); 37 | const int threads = 256; 38 | const int blocks = (n_rays + threads - 1) / threads; 39 | auto mask = torch::zeros({n_rays, n_pts}, torch::dtype(torch::kBool).device(torch::kCUDA)); 40 | AT_DISPATCH_FLOATING_TYPES(dist.type(), "cumdist_thres_cuda", ([&] { 41 | cumdist_thres_cuda_kernel<scalar_t><<<blocks, threads>>>( 42 | dist.data<scalar_t>(), thres, 43 | n_rays, n_pts, 44 | mask.data<bool>()); 45 | })); 46 | return mask; 47 | } 48 | 49 |
-------------------------------------------------------------------------------- /lib/dvgo/total_variation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode); 8 | 9 | 10 | // C++ interface 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | 16 | void total_variation_add_grad(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { 17 | CHECK_INPUT(param); 18 | CHECK_INPUT(grad); 19 | total_variation_add_grad_cuda(param, grad, wx, wy, wz, dense_mode); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("total_variation_add_grad", &total_variation_add_grad, "Add total variation grad"); 24 | } 25 | 26 |
-------------------------------------------------------------------------------- /lib/dvgo/total_variation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | template <typename scalar_t, typename bound_t> 9 | __device__ __forceinline__ scalar_t clamp(const scalar_t v, const bound_t lo, const bound_t hi) { 10 | return min(max(v, lo), hi); 11 | } 12 | 13 | template <typename scalar_t, bool dense_mode> 14 | __global__ void total_variation_add_grad_cuda_kernel( 15 | const scalar_t* __restrict__ param, 16 | scalar_t* __restrict__ grad, 17 | float wx, float wy, float wz, 18 | const size_t sz_i, const size_t sz_j, const size_t sz_k, const size_t N) { 19 | 20 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 21 | if(index<N && (dense_mode || grad[index]!=0)) { 22 | const size_t k = index % sz_k; 23 | const size_t j = index / sz_k % sz_j; 24 | const size_t i = index / sz_k / sz_j % sz_i; 25 | 26 | float grad_to_add = 0; 27 | grad_to_add += (k==0 ? 0 : wz * clamp(param[index]-param[index-1], -1.f, 1.f)); 28 | grad_to_add += (k==sz_k-1 ? 0 : wz * clamp(param[index]-param[index+1], -1.f, 1.f)); 29 | grad_to_add += (j==0 ? 0 : wy * clamp(param[index]-param[index-sz_k], -1.f, 1.f)); 30 | grad_to_add += (j==sz_j-1 ? 0 : wy * clamp(param[index]-param[index+sz_k], -1.f, 1.f)); 31 | grad_to_add += (i==0 ? 0 : wx * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f)); 32 | grad_to_add += (i==sz_i-1 ? 0 : wx * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f)); 33 | grad[index] += grad_to_add; 34 | } 35 | } 36 | 37 | void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { 38 | const size_t N = param.numel(); 39 | const size_t sz_i = param.size(2); 40 | const size_t sz_j = param.size(3); 41 | const size_t sz_k = param.size(4); 42 | const int threads = 256; 43 | const int blocks = (N + threads - 1) / threads; 44 | 45 | wx /= 6; 46 | wy /= 6; 47 | wz /= 6; 48 | 49 | if(dense_mode) { 50 | AT_DISPATCH_FLOATING_TYPES(param.type(), "total_variation_add_grad_cuda", ([&] { 51 | total_variation_add_grad_cuda_kernel<scalar_t, true><<<blocks, threads>>>( 52 | param.data<scalar_t>(), 53 | grad.data<scalar_t>(), 54 | wx, wy, wz, 55 | sz_i, sz_j, sz_k, N); 56 | })); 57 | } 58 | else { 59 | AT_DISPATCH_FLOATING_TYPES(param.type(), "total_variation_add_grad_cuda", ([&] { 60 | total_variation_add_grad_cuda_kernel<scalar_t, false><<<blocks, threads>>>( 61 | param.data<scalar_t>(), 62 | grad.data<scalar_t>(), 63 | wx, wy, wz, 64 | sz_i, sz_j, sz_k, N); 65 | })); 66 | } 67 | } 68 | 69 |
-------------------------------------------------------------------------------- /lib/dvgo/ub360_utils.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres); 8 | 9 | // C++ interface 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor cumdist_thres(torch::Tensor dist, float thres) { 16 | CHECK_INPUT(dist); 17 | return cumdist_thres_cuda(dist, thres); 18 | } 19 | 20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 21 | m.def("cumdist_thres", &cumdist_thres, "Generate mask for cumulative dist."); 22 | } 23 | 24 |
-------------------------------------------------------------------------------- /lib/dvgo/ub360_utils_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | /* 9 | helper function to skip oversampled points, 10 | especially near the foreground scene bbox boundary 11 | */ 12 | template <typename scalar_t> 13 | __global__ void cumdist_thres_cuda_kernel( 14 | scalar_t* __restrict__ dist, 15 | const float thres, 16 | const int n_rays, 17 | const int n_pts, 18 | bool* __restrict__ mask) { 19 | const int i_ray = blockIdx.x * blockDim.x + threadIdx.x; 20 | if(i_ray<n_rays) { 21 | float cum_dist = 0; 22 | bool over = false; 23 | 24 | for(int i=i_ray*n_pts, i_end=i+n_pts; i<i_end; ++i) { 25 | cum_dist += dist[i]; 26 | 27 | over = over || (cum_dist > thres); 28 | cum_dist *= float(!over); 29 | mask[i] = over; 30 | } 31 | } 32 | } 33 | 34 | torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres) { 35 | const int n_rays = dist.size(0); 36 | const int n_pts = dist.size(1); 37 | const int threads = 256; 38 | const int blocks = (n_rays + threads - 1) / threads; 39 | auto mask = torch::zeros({n_rays, n_pts}, torch::dtype(torch::kBool).device(torch::kCUDA)); 40 | AT_DISPATCH_FLOATING_TYPES(dist.type(), "cumdist_thres_cuda", ([&] { 41 | cumdist_thres_cuda_kernel<scalar_t><<<blocks, threads>>>( 42 | dist.data<scalar_t>(), thres, 43 | n_rays, n_pts, 44 | mask.data<bool>()); 45 | })); 46 | return mask; 47 | } 48 | 49 |
-------------------------------------------------------------------------------- /lib/plenoxel/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 PlenOctree Authors. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions are met: 5 | # 6 | # 1. Redistributions of source code must retain the above copyright notice, 7 | # this list of conditions and the following disclaimer. 8 | # 9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 13 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 15 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 16 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 17 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 18 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 19 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 20 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 21 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 22 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 23 | # POSSIBILITY OF SUCH DAMAGE. 24 | 25 | # NOTE: This CMakeLists is for development purposes only 26 | # (To check CUDA compile errors) 27 | # It is NOT necessary to use this for installation. Just use pip install .
28 | cmake_minimum_required( VERSION 3.3 ) 29 | 30 | if(NOT CMAKE_BUILD_TYPE) 31 | set(CMAKE_BUILD_TYPE Release) 32 | endif() 33 | if (POLICY CMP0048) 34 | cmake_policy(SET CMP0048 NEW) 35 | endif (POLICY CMP0048) 36 | if (POLICY CMP0069) 37 | cmake_policy(SET CMP0069 NEW) 38 | endif (POLICY CMP0069) 39 | if (POLICY CMP0072) 40 | cmake_policy(SET CMP0072 NEW) 41 | endif (POLICY CMP0072) 42 | 43 | project( svox2 ) 44 | 45 | set(CMAKE_CXX_STANDARD 14) 46 | enable_language(CUDA) 47 | message(STATUS "CUDA enabled") 48 | set( CMAKE_CUDA_STANDARD 14 ) 49 | set( CMAKE_CUDA_STANDARD_REQUIRED ON) 50 | set( CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -Xcudafe \"--display_error_number --diag_suppress=3057 --diag_suppress=3058 --diag_suppress=3059 --diag_suppress=3060\" -lineinfo -arch=sm_75 ") 51 | # -Xptxas=\"-v\" 52 | 53 | set( INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" ) 54 | 55 | if( MSVC ) 56 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd") 57 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /GLT /Ox") 58 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler=\"/MT\"" ) 59 | endif() 60 | 61 | file(GLOB SOURCES 62 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp 63 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cu) 64 | 65 | find_package(pybind11 REQUIRED) 66 | find_package(Torch REQUIRED) 67 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 68 | 69 | include_directories (${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 70 | 71 | pybind11_add_module(svox2-test SHARED ${SOURCES}) 72 | target_link_libraries(svox2-test PRIVATE "${TORCH_LIBRARIES}") 73 | target_include_directories(svox2-test PRIVATE "${INCLUDE_DIR}") 74 | 75 | if (MSVC) 76 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") 77 | add_custom_command(TARGET svox2-test 78 | POST_BUILD 79 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 80 | ${TORCH_DLLS} 81 | $<TARGET_FILE_DIR:svox2-test>) 82 | endif (MSVC) 83 |
-------------------------------------------------------------------------------- /lib/plenoxel/include/cuda_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include <c10/cuda/CUDAGuard.h> 4 | #include <ATen/cuda/CUDAContext.h> 5 | #include <cuda.h> 6 | #include <cuda_runtime.h> 7 | #include "util.hpp" 8 | 9 | 10 | #define DEVICE_GUARD(_ten) \ 11 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); 12 | 13 | #define CUDA_GET_THREAD_ID(tid, Q) const int tid = blockIdx.x * blockDim.x + threadIdx.x; \ 14 | if (tid >= Q) return 15 | #define CUDA_GET_THREAD_ID_U64(tid, Q) const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \ 16 | if (tid >= Q) return 17 | #define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1) 18 | #define CUDA_CHECK_ERRORS \ 19 | cudaError_t err = cudaGetLastError(); \ 20 | if (err != cudaSuccess) \ 21 | printf("Error in svox2.%s : %s\n", __FUNCTION__, cudaGetErrorString(err)) 22 | 23 | #define CUDA_MAX_THREADS at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock 24 | 25 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 26 | #else 27 | __device__ inline double atomicAdd(double* address, double val){ 28 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 29 | unsigned long long int old = *address_as_ull, assumed; 30 | do { 31 | assumed = old; 32 | old = atomicCAS(address_as_ull, assumed, 33 | __double_as_longlong(val + __longlong_as_double(assumed))); 34 | } while (assumed != old); 35 | return __longlong_as_double(old); 36 | } 37 | #endif 38 | 39 | __device__ inline void atomicMax(float* result, float value){ 40 |
unsigned* result_as_u = (unsigned*)result; 41 | unsigned old = *result_as_u, assumed; 42 | do { 43 | assumed = old; 44 | old = atomicCAS(result_as_u, assumed, 45 | __float_as_int(fmaxf(value, __int_as_float(assumed)))); 46 | } while (old != assumed); 47 | return; 48 | } 49 | 50 | __device__ inline void atomicMax(double* result, double value){ 51 | unsigned long long int* result_as_ull = (unsigned long long int*)result; 52 | unsigned long long int old = *result_as_ull, assumed; 53 | do { 54 | assumed = old; 55 | old = atomicCAS(result_as_ull, assumed, 56 | __double_as_longlong(fmaxf(value, __longlong_as_double(assumed)))); 57 | } while (old != assumed); 58 | return; 59 | } 60 | 61 | __device__ __inline__ void transform_coord(float* __restrict__ point, 62 | const float* __restrict__ scaling, 63 | const float* __restrict__ offset) { 64 | point[0] = fmaf(point[0], scaling[0], offset[0]); // a*b + c 65 | point[1] = fmaf(point[1], scaling[1], offset[1]); // a*b + c 66 | point[2] = fmaf(point[2], scaling[2], offset[2]); // a*b + c 67 | } 68 | 69 | // Linear interp 70 | // Subtract and fused multiply-add 71 | // (1-w) a + w b 72 | template <class T> 73 | __host__ __device__ __inline__ T lerp(T a, T b, T w) { 74 | return fmaf(w, b - a, a); 75 | } 76 | 77 | __device__ __inline__ static float _norm( 78 | const float* __restrict__ dir) { 79 | // return sqrtf(dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2]); 80 | return norm3df(dir[0], dir[1], dir[2]); 81 | } 82 | 83 | __device__ __inline__ static float _rnorm( 84 | const float* __restrict__ dir) { 85 | // return 1.f / _norm(dir); 86 | return rnorm3df(dir[0], dir[1], dir[2]); 87 | } 88 | 89 | __host__ __device__ __inline__ static void xsuby3d( 90 | float* __restrict__ x, 91 | const float* __restrict__ y) { 92 | x[0] -= y[0]; 93 | x[1] -= y[1]; 94 | x[2] -= y[2]; 95 | } 96 | 97 | __host__ __device__ __inline__ static float _dot( 98 | const float* __restrict__ x, 99 | const float* __restrict__ y) { 100 | return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]; 101 | } 102 | 103 | __host__ __device__ __inline__ static void _cross( 104 | const float* __restrict__ a, 105 | const float* __restrict__ b, 106 | float* __restrict__ out) { 107 | out[0] = a[1] * b[2] - a[2] * b[1]; 108 | out[1] = a[2] * b[0] - a[0] * b[2]; 109 | out[2] = a[0] * b[1] - a[1] * b[0]; 110 | } 111 | 112 | __device__ __inline__ static float _dist_ray_to_origin( 113 | const float* __restrict__ origin, 114 | const float* __restrict__ dir) { 115 | // dir must be unit vector 116 | float tmp[3]; 117 | _cross(origin, dir, tmp); 118 | return _norm(tmp); 119 | } 120 | 121 | #define int_div2_ceil(x) ((((x) - 1) >> 1) + 1) 122 | 123 | __host__ __inline__ cudaError_t cuda_assert( 124 | const cudaError_t code, const char* const file, 125 | const int line, const bool abort) { 126 | if (code != cudaSuccess) { 127 | fprintf(stderr, "cuda_assert: %s %s %s %d\n", cudaGetErrorName(code) ,cudaGetErrorString(code), 128 | file, line); 129 | 130 | if (abort) { 131 | cudaDeviceReset(); 132 | exit(code); 133 | } 134 | } 135 | 136 | return code; 137 | } 138 | 139 | #define cuda(...)
cuda_assert((cuda##__VA_ARGS__), __FILE__, __LINE__, true); 140 | 141 |
-------------------------------------------------------------------------------- /lib/plenoxel/include/data_spec.hpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include "util.hpp" 4 | #include <torch/extension.h> 5 | 6 | using torch::Tensor; 7 | 8 | enum BasisType { 9 | BASIS_TYPE_SH = 1, 10 | BASIS_TYPE_3D_TEXTURE = 4, 11 | BASIS_TYPE_MLP = 255, 12 | }; 13 | 14 | struct SparseGridSpec { 15 | Tensor density_data; 16 | Tensor sh_data; 17 | Tensor links; 18 | Tensor _offset; 19 | Tensor _scaling; 20 | 21 | Tensor background_links; 22 | Tensor background_data; 23 | 24 | int basis_dim; 25 | uint8_t basis_type; 26 | Tensor basis_data; 27 | 28 | inline void check() { 29 | CHECK_INPUT(density_data); 30 | CHECK_INPUT(sh_data); 31 | CHECK_INPUT(links); 32 | if (background_links.defined()) { 33 | CHECK_INPUT(background_links); 34 | CHECK_INPUT(background_data); 35 | TORCH_CHECK(background_links.ndimension() == 36 | 2); // (H, W) -> [N] \cup {-1} 37 | TORCH_CHECK(background_data.ndimension() == 3); // (N, D, C) -> R 38 | } 39 | if (basis_data.defined()) { 40 | CHECK_INPUT(basis_data); 41 | } 42 | CHECK_CPU_INPUT(_offset); 43 | CHECK_CPU_INPUT(_scaling); 44 | TORCH_CHECK(density_data.ndimension() == 2); 45 | TORCH_CHECK(sh_data.ndimension() == 2); 46 | TORCH_CHECK(links.ndimension() == 3); 47 | } 48 | }; 49 | 50 | struct GridOutputGrads { 51 | torch::Tensor grad_density_out; 52 | torch::Tensor grad_sh_out; 53 | torch::Tensor grad_basis_out; 54 | torch::Tensor grad_background_out; 55 | 56 | torch::Tensor mask_out; 57 | torch::Tensor mask_background_out; 58 | inline void check() { 59 | if (grad_density_out.defined()) { 60 | CHECK_INPUT(grad_density_out); 61 | } 62 | if (grad_sh_out.defined()) { 63 | CHECK_INPUT(grad_sh_out); 64 | } 65 | if (grad_basis_out.defined()) { 66 | CHECK_INPUT(grad_basis_out); 67 | } 68 | if (grad_background_out.defined()) { 69 | CHECK_INPUT(grad_background_out); 70 | } 71 | if (mask_out.defined() && mask_out.size(0) > 0) { 72 | CHECK_INPUT(mask_out); 73 | } 74 | if (mask_background_out.defined() && mask_background_out.size(0) > 0) { 75 | CHECK_INPUT(mask_background_out); 76 | } 77 | } 78 | }; 79 | 80 | struct CameraSpec { 81 | torch::Tensor c2w; 82 | float fx; 83 | float fy; 84 | float cx; 85 | float cy; 86 | int width; 87 | int height; 88 | 89 | float ndc_coeffx; 90 | float ndc_coeffy; 91 | 92 | inline void check() { 93 | CHECK_INPUT(c2w); 94 | TORCH_CHECK(c2w.is_floating_point()); 95 | TORCH_CHECK(c2w.ndimension() == 2); 96 | TORCH_CHECK(c2w.size(1) == 4); 97 | } 98 | }; 99 | 100 | struct RaysSpec { 101 | Tensor origins; 102 | Tensor dirs; 103 | inline void check() { 104 | CHECK_INPUT(origins); 105 | CHECK_INPUT(dirs); 106 | TORCH_CHECK(origins.is_floating_point()); 107 | TORCH_CHECK(dirs.is_floating_point()); 108 | } 109 | }; 110 | 111 | struct RenderOptions { 112 | float background_brightness; 113 | float step_size; 114 | float sigma_thresh; 115 | float stop_thresh; 116 | 117 | float near_clip; 118 | bool use_spheric_clip; 119 | 120 | bool last_sample_opaque; 121 | }; 122 |
-------------------------------------------------------------------------------- /lib/plenoxel/include/data_spec_packed.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include <torch/extension.h> 4 | #include "data_spec.hpp" 5 | #include "cuda_util.cuh" 6 | #include "random_util.cuh" 7 | 8 |
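// NOTE (added for clarity; not part of the original file): torch::Tensor
// objects cannot be passed to CUDA kernels by value, so each *Spec struct
// from data_spec.hpp gets a "packed" mirror below that captures raw device
// pointers, sizes and strides once on the host; the packed structs are
// plain data and can be copied to the device cheaply at every launch.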
namespace { 9 | namespace device { 10 | 11 | struct PackedSparseGridSpec { 12 | PackedSparseGridSpec(SparseGridSpec& spec) 13 | : 14 | density_data(spec.density_data.data_ptr<float>()), 15 | sh_data(spec.sh_data.data_ptr<float>()), 16 | links(spec.links.data_ptr<int32_t>()), 17 | basis_type(spec.basis_type), 18 | basis_data(spec.basis_data.defined() ? spec.basis_data.data_ptr<float>() : nullptr), 19 | background_links(spec.background_links.defined() ? 20 | spec.background_links.data_ptr<int32_t>() : 21 | nullptr), 22 | background_data(spec.background_data.defined() ? 23 | spec.background_data.data_ptr<float>() : 24 | nullptr), 25 | size{(int)spec.links.size(0), 26 | (int)spec.links.size(1), 27 | (int)spec.links.size(2)}, 28 | stride_x{(int)spec.links.stride(0)}, 29 | background_reso{ 30 | spec.background_links.defined() ? (int)spec.background_links.size(1) : 0, 31 | }, 32 | background_nlayers{ 33 | spec.background_data.defined() ? (int)spec.background_data.size(1) : 0 34 | }, 35 | basis_dim(spec.basis_dim), 36 | sh_data_dim((int)spec.sh_data.size(1)), 37 | basis_reso(spec.basis_data.defined() ? spec.basis_data.size(0) : 0), 38 | _offset{spec._offset.data_ptr<float>()[0], 39 | spec._offset.data_ptr<float>()[1], 40 | spec._offset.data_ptr<float>()[2]}, 41 | _scaling{spec._scaling.data_ptr<float>()[0], 42 | spec._scaling.data_ptr<float>()[1], 43 | spec._scaling.data_ptr<float>()[2]} { 44 | } 45 | 46 | float* __restrict__ density_data; 47 | float* __restrict__ sh_data; 48 | const int32_t* __restrict__ links; 49 | 50 | const uint8_t basis_type; 51 | float* __restrict__ basis_data; 52 | 53 | const int32_t* __restrict__ background_links; 54 | float* __restrict__ background_data; 55 | 56 | const int size[3], stride_x; 57 | const int background_reso, background_nlayers; 58 | 59 | const int basis_dim, sh_data_dim, basis_reso; 60 | const float _offset[3]; 61 | const float _scaling[3]; 62 | }; 63 | 64 | struct PackedGridOutputGrads { 65 | PackedGridOutputGrads(GridOutputGrads& grads) : 66 | grad_density_out(grads.grad_density_out.defined() ? grads.grad_density_out.data_ptr<float>() : nullptr), 67 | grad_sh_out(grads.grad_sh_out.defined() ? grads.grad_sh_out.data_ptr<float>() : nullptr), 68 | grad_basis_out(grads.grad_basis_out.defined() ? grads.grad_basis_out.data_ptr<float>() : nullptr), 69 | grad_background_out(grads.grad_background_out.defined() ? grads.grad_background_out.data_ptr<float>() : nullptr), 70 | mask_out((grads.mask_out.defined() && grads.mask_out.size(0) > 0) ? grads.mask_out.data_ptr<bool>() : nullptr), 71 | mask_background_out((grads.mask_background_out.defined() && grads.mask_background_out.size(0) > 0) ?
grads.mask_background_out.data_ptr<bool>() : nullptr) 72 | {} 73 | float* __restrict__ grad_density_out; 74 | float* __restrict__ grad_sh_out; 75 | float* __restrict__ grad_basis_out; 76 | float* __restrict__ grad_background_out; 77 | 78 | bool* __restrict__ mask_out; 79 | bool* __restrict__ mask_background_out; 80 | }; 81 | 82 | struct PackedCameraSpec { 83 | PackedCameraSpec(CameraSpec& cam) : 84 | c2w(cam.c2w.packed_accessor32<float, 2, torch::RestrictPtrTraits>()), 85 | fx(cam.fx), fy(cam.fy), 86 | cx(cam.cx), cy(cam.cy), 87 | width(cam.width), height(cam.height), 88 | ndc_coeffx(cam.ndc_coeffx), ndc_coeffy(cam.ndc_coeffy) {} 89 | const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> 90 | c2w; 91 | float fx; 92 | float fy; 93 | float cx; 94 | float cy; 95 | int width; 96 | int height; 97 | 98 | float ndc_coeffx; 99 | float ndc_coeffy; 100 | }; 101 | 102 | struct PackedRaysSpec { 103 | const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> origins; 104 | const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> dirs; 105 | PackedRaysSpec(RaysSpec& spec) : 106 | origins(spec.origins.packed_accessor32<float, 2, torch::RestrictPtrTraits>()), 107 | dirs(spec.dirs.packed_accessor32<float, 2, torch::RestrictPtrTraits>()) 108 | { } 109 | }; 110 | 111 | struct SingleRaySpec { 112 | SingleRaySpec() = default; 113 | __device__ SingleRaySpec(const float* __restrict__ origin, const float* __restrict__ dir) 114 | : origin{origin[0], origin[1], origin[2]}, 115 | dir{dir[0], dir[1], dir[2]} {} 116 | __device__ void set(const float* __restrict__ origin, const float* __restrict__ dir) { 117 | #pragma unroll 3 118 | for (int i = 0; i < 3; ++i) { 119 | this->origin[i] = origin[i]; 120 | this->dir[i] = dir[i]; 121 | } 122 | } 123 | 124 | float origin[3]; 125 | float dir[3]; 126 | float tmin, tmax, world_step; 127 | 128 | float pos[3]; 129 | int32_t l[3]; 130 | RandomEngine32 rng; 131 | }; 132 | 133 | } // namespace device 134 | } // namespace 135 |
-------------------------------------------------------------------------------- /lib/plenoxel/include/random_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include <cmath> 4 | #include <cstdint> 5 | 6 | // A custom xorshift random generator 7 | // Maybe replace with some CUDA internal stuff?
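// NOTE (added for clarity; not part of the original file): the generator
// below is a Marsaglia-style xorshift RNG over three 32-bit words. rand2()
// splits one 32-bit draw into two 16-bit uniforms in [0, 1), and
// randn()/randn2() convert uniform draws into Gaussian samples with the
// Box-Muller transform (the sqrtf(-2 * logf(u1)) * cosf/sinf(2 * M_PI * u2)
// expressions in the bodies below).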
8 | struct RandomEngine32 { 9 | uint32_t x, y, z; 10 | 11 | // Inclusive both 12 | __host__ __device__ 13 | uint32_t randint(uint32_t lo, uint32_t hi) { 14 | if (hi <= lo) return lo; 15 | uint32_t z = (*this)(); 16 | return z % (hi - lo + 1) + lo; 17 | } 18 | 19 | __host__ __device__ 20 | void rand2(float* out1, float* out2) { 21 | const uint32_t z = (*this)(); 22 | const uint32_t fmax = (1 << 16); 23 | const uint32_t z1 = z >> 16; 24 | const uint32_t z2 = z & (fmax - 1); 25 | const float ifmax = 1.f / fmax; 26 | 27 | *out1 = z1 * ifmax; 28 | *out2 = z2 * ifmax; 29 | } 30 | 31 | __host__ __device__ 32 | float rand() { 33 | uint32_t z = (*this)(); 34 | return float(z) / (1LL << 32); 35 | } 36 | 37 | 38 | __host__ __device__ 39 | void randn2(float* out1, float* out2) { 40 | rand2(out1, out2); 41 | // Box-Muller transform 42 | const float srlog = sqrtf(-2 * logf(*out1 + 1e-32f)); 43 | *out2 *= 2 * M_PI; 44 | *out1 = srlog * cosf(*out2); 45 | *out2 = srlog * sinf(*out2); 46 | } 47 | 48 | __host__ __device__ 49 | float randn() { 50 | float x, y; 51 | rand2(&x, &y); 52 | // Box-Muller transform 53 | return sqrtf(-2 * logf(x + 1e-32f))* cosf(2 * M_PI * y); 54 | } 55 | 56 | __host__ __device__ 57 | uint32_t operator()() { 58 | uint32_t t; 59 | x ^= x << 16; 60 | x ^= x >> 5; 61 | x ^= x << 1; 62 | t = x; 63 | x = y; 64 | y = z; 65 | z = t ^ x ^ y; 66 | return z; 67 | } 68 | }; 69 |
-------------------------------------------------------------------------------- /lib/plenoxel/include/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // Changed from x.type().is_cuda() due to deprecation 3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 4 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") 5 | #define CHECK_CONTIGUOUS(x) \ 6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) \ 8 | CHECK_CUDA(x); \ 9 | CHECK_CONTIGUOUS(x) 10 | #define CHECK_CPU_INPUT(x) \ 11 | CHECK_CPU(x); \ 12 | CHECK_CONTIGUOUS(x) 13 | 14 | #if defined(__CUDACC__) 15 | // #define _EXP(x) expf(x) // SLOW EXP 16 | #define _EXP(x) __expf(x) // FAST EXP 17 | #define _SIGMOID(x) (1 / (1 + _EXP(-(x)))) 18 | 19 | #else 20 | 21 | #define _EXP(x) expf(x) 22 | #define _SIGMOID(x) (1 / (1 + expf(-(x)))) 23 | #endif 24 | #define _SQR(x) ((x) * (x)) 25 |
-------------------------------------------------------------------------------- /lib/plenoxel/svox2.cpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | 3 | // This file contains only Python bindings 4 | #include "data_spec.hpp" 5 | #include <cstdint> 6 | #include <torch/extension.h> 7 | #include <tuple> 8 | 9 | using torch::Tensor; 10 | 11 | std::tuple<torch::Tensor, torch::Tensor> sample_grid(SparseGridSpec &, Tensor, 12 | bool); 13 | void sample_grid_backward(SparseGridSpec &, Tensor, Tensor, Tensor, Tensor, 14 | Tensor, bool); 15 | 16 | // ** NeRF rendering formula (trilerp) 17 | Tensor volume_render_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &); 18 | void volume_render_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 19 | Tensor, Tensor, GridOutputGrads &); 20 | void volume_render_cuvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &, 21 | Tensor, float, float, Tensor, GridOutputGrads &); 22 | // Expected termination (depth) rendering 23 | torch::Tensor volume_render_expected_term(SparseGridSpec &, RaysSpec &, 24 | RenderOptions &); 25 | // Depth rendering based on sigma-threshold as in Dex-NeRF 26 | torch::Tensor
volume_render_sigma_thresh(SparseGridSpec &, RaysSpec &, 27 | RenderOptions &, float); 28 | 29 | // Misc 30 | Tensor dilate(Tensor); 31 | void accel_dist_prop(Tensor); 32 | void grid_weight_render(Tensor, CameraSpec &, float, float, bool, Tensor, 33 | Tensor, Tensor); 34 | 35 | // Loss 36 | Tensor tv(Tensor, Tensor, int, int, bool, float, bool, float, float); 37 | void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float, 38 | Tensor); 39 | void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool, 40 | float, bool, bool, float, float, Tensor); 41 | void msi_tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, float, float, Tensor); 42 | void lumisphere_tv_grad_sparse(SparseGridSpec &, Tensor, Tensor, Tensor, float, 43 | float, float, float, GridOutputGrads &); 44 | 45 | // Optim 46 | void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float, 47 | float); 48 | void sgd_step(Tensor, Tensor, Tensor, float, float); 49 | 50 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 51 | #define _REG_FUNC(funname) m.def(#funname, &funname) 52 | _REG_FUNC(sample_grid); 53 | _REG_FUNC(sample_grid_backward); 54 | _REG_FUNC(volume_render_cuvol); 55 | _REG_FUNC(volume_render_cuvol_backward); 56 | _REG_FUNC(volume_render_cuvol_fused); 57 | _REG_FUNC(volume_render_expected_term); 58 | _REG_FUNC(volume_render_sigma_thresh); 59 | 60 | // Loss 61 | _REG_FUNC(tv); 62 | _REG_FUNC(tv_grad); 63 | _REG_FUNC(tv_grad_sparse); 64 | _REG_FUNC(msi_tv_grad_sparse); 65 | _REG_FUNC(lumisphere_tv_grad_sparse); 66 | 67 | // Misc 68 | _REG_FUNC(dilate); 69 | _REG_FUNC(accel_dist_prop); 70 | _REG_FUNC(grid_weight_render); 71 | 72 | // Optimizer 73 | _REG_FUNC(rmsprop_step); 74 | _REG_FUNC(sgd_step); 75 | #undef _REG_FUNC 76 | 77 | py::class_<SparseGridSpec>(m, "SparseGridSpec") 78 | .def(py::init<>()) 79 | .def_readwrite("density_data", &SparseGridSpec::density_data) 80 | .def_readwrite("sh_data", &SparseGridSpec::sh_data) 81 | .def_readwrite("links", &SparseGridSpec::links) 82 | .def_readwrite("_offset", &SparseGridSpec::_offset) 83 | .def_readwrite("_scaling", &SparseGridSpec::_scaling) 84 | .def_readwrite("basis_dim", &SparseGridSpec::basis_dim) 85 | .def_readwrite("basis_type", &SparseGridSpec::basis_type) 86 | .def_readwrite("basis_data", &SparseGridSpec::basis_data) 87 | .def_readwrite("background_links", &SparseGridSpec::background_links) 88 | .def_readwrite("background_data", &SparseGridSpec::background_data); 89 | 90 | py::class_<CameraSpec>(m, "CameraSpec") 91 | .def(py::init<>()) 92 | .def_readwrite("c2w", &CameraSpec::c2w) 93 | .def_readwrite("fx", &CameraSpec::fx) 94 | .def_readwrite("fy", &CameraSpec::fy) 95 | .def_readwrite("cx", &CameraSpec::cx) 96 | .def_readwrite("cy", &CameraSpec::cy) 97 | .def_readwrite("width", &CameraSpec::width) 98 | .def_readwrite("height", &CameraSpec::height) 99 | .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx) 100 | .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy); 101 | 102 | py::class_<RaysSpec>(m, "RaysSpec") 103 | .def(py::init<>()) 104 | .def_readwrite("origins", &RaysSpec::origins) 105 | .def_readwrite("dirs", &RaysSpec::dirs); 106 | 107 | py::class_<RenderOptions>(m, "RenderOptions") 108 | .def(py::init<>()) 109 | .def_readwrite("background_brightness", 110 | &RenderOptions::background_brightness) 111 | .def_readwrite("step_size", &RenderOptions::step_size) 112 | .def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh) 113 | .def_readwrite("stop_thresh", &RenderOptions::stop_thresh) 114 | .def_readwrite("near_clip", &RenderOptions::near_clip) 115 |
.def_readwrite("use_spheric_clip", &RenderOptions::use_spheric_clip) 116 | .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque); 117 | 118 | py::class_(m, "GridOutputGrads") 119 | .def(py::init<>()) 120 | .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out) 121 | .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out) 122 | .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out) 123 | .def_readwrite("grad_background_out", 124 | &GridOutputGrads::grad_background_out) 125 | .def_readwrite("mask_out", &GridOutputGrads::mask_out) 126 | .def_readwrite("mask_background_out", 127 | &GridOutputGrads::mask_background_out); 128 | } 129 | -------------------------------------------------------------------------------- /lib/plenoxel/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio 2 | tqdm 3 | requests 4 | configargparse 5 | scikit-image 6 | imageio-ffmpeg 7 | piqa 8 | wandb 9 | torch-scatter 10 | pytorch_lightning 11 | opencv-python 12 | gin-config 13 | gdown 14 | ninja 15 | functorch 16 | torch==1.11.0 17 | torch_efficient_distloss 18 | -------------------------------------------------------------------------------- /sbatch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH -J icn # Job name 4 | #SBATCH -o sbatch_log/pytorch-1gpu.%j.out # Name of stdout output file (%j expands to %jobId) 5 | #SBATCH -p A100 # queue name or partiton name titanxp/titanrtx/2080ti 6 | #SBATCH -t 3-00:00:00 # Run time (hh:mm:ss) - 1.5 hours 7 | #SBATCH --gres=gpu:1 # number of gpus you want to use 8 | 9 | #SBATCH --nodes=1 10 | ##SBATCH --exclude=n13 11 | ##SBTACH --nodelist=n12 12 | 13 | ##SBTACH --ntasks=1 14 | ##SBATCH --tasks-per-node=1 15 | ##SBATCH --cpus-per-task=1 16 | 17 | cd $SLURM_SUBMIT_DIR 18 | 19 | echo "SLURM_SUBMIT_DIR=$SLURM_SUBMIT_DIR" 20 | echo "CUDA_HOME=$CUDA_HOME" 21 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 22 | echo "CUDA_VERSION=$CUDA_VERSION" 23 | 24 | srun -l /bin/hostname 25 | srun -l /bin/pwd 26 | srun -l /bin/date 27 | 28 | module purge 29 | 30 | echo "Start" 31 | export NCCL_NSOCKS_PERTHREAD=4 32 | export NCCL_SOCKET_NTHREADS=2 33 | export WANDB_SPAWN_METHOD=fork 34 | 35 | 36 | nvidia-smi 37 | date 38 | squeue --job $SLURM_JOBID 39 | 40 | echo "##### END #####" -------------------------------------------------------------------------------- /scripts/collage.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name garden --ginb run.run_render=True --ginb run.run_eval=False --ginb run.run_train=False 2 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name garden --ginb run.run_render=True --ginb run.run_eval=False --ginb run.run_train=False 3 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name garden --ginb run.run_render=True --ginb run.run_eval=False --ginb run.run_train=False 4 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name garden --ginb run.run_render=True --ginb run.run_eval=False --ginb run.run_train=False 5 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name garden --ginb run.run_render=True --ginb run.run_eval=False --ginb run.run_train=False 6 | python3 
-m run --ginc configs/dvgo/360_v2.gin --scene_name garden --ginb run.run_render=True --ginb run.run_eval=False --ginb run.run_train=False 7 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | dname=$1 2 | 3 | case $dname in 4 | "nerf_synthetic") 5 | gdown https://drive.google.com/uc?id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG 6 | unzip nerf_synthetic.zip 7 | rm -rf __MACOSX 8 | mv nerf_synthetic data/blender 9 | rm nerf_synthetic.zip 10 | ;; 11 | "nerf_llff") 12 | gdown https://drive.google.com/uc?id=16VnMcF1KJYxN9QId6TClMsZRahHNMW5g 13 | unzip nerf_llff_data.zip 14 | rm -rf __MACOSX 15 | mv nerf_llff_data data/llff 16 | rm nerf_llff_data.zip 17 | ;; 18 | "nerf_real_360") 19 | gdown https://drive.google.com/uc?id=1jzggQ7IPaJJTKx9yLASWHrX8dXHnG5eB 20 | unzip nerf_real_360.zip 21 | rm -rf __MACOSX 22 | mkdir nerf_real_360 23 | mv vasedeck nerf_real_360 24 | mv pinecone nerf_real_360 25 | mv nerf_real_360 data/nerf_360 26 | rm nerf_real_360.zip 27 | ;; 28 | "tanks_and_temples") 29 | gdown 11KRfN91W1AxAW6lOFs4EeYDbeoQZCi87 30 | unzip tanks_and_temples.zip 31 | cd tanks_and_temples/tat_training_Truck 32 | cp -r "test" "validation" 33 | cd ../.. 34 | mv tanks_and_temples data 35 | rm tanks_and_temples.zip 36 | rm -rf __MACOSX 37 | ;; 38 | "lf") 39 | gdown 1gsjDjkbTh4GAR9fFqlIDZ__qR9NYTURQ 40 | unzip lf_data.zip 41 | mv lf_data data 42 | rm -rf __MACOSX 43 | rm lf_data.zip 44 | ;; 45 | "nerf_360_v2") 46 | wget http://storage.googleapis.com/gresearch/refraw360/360_v2.zip 47 | mkdir 360_v2 48 | unzip 360_v2.zip -d 360_v2 49 | mv 360_v2 data 50 | rm 360_v2.zip 51 | ;; 52 | "shiny_blender") 53 | wget https://storage.googleapis.com/gresearch/refraw360/ref.zip 54 | unzip ref.zip 55 | mv refnerf data 56 | rm ref.zip 57 | python utils/preprocess_shiny_blender.py 58 | ;; 59 | esac 60 | -------------------------------------------------------------------------------- /scripts/dvgo.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name chair 2 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name drums 3 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name ficus 4 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name hotdog 5 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name lego 6 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name materials 7 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name mic 8 | python3 -m run --ginc configs/dvgo/blender.gin --scene_name ship 9 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name fern 10 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name flower 11 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name fortress 12 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name horns 13 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name orchids 14 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name leaves 15 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name room 16 | python3 -m run --ginc configs/dvgo/llff.gin --scene_name trex 17 | python3 -m run --ginc configs/dvgo/tnt.gin --scene_name tat_intermediate_M60 18 | python3 -m run --ginc configs/dvgo/tnt.gin --scene_name tat_intermediate_Playground 19 | python3 -m run --ginc configs/dvgo/tnt.gin --scene_name tat_intermediate_Train 20 | python3 -m run --ginc configs/dvgo/tnt.gin --scene_name 
tat_training_Truck 21 | python3 -m run --ginc configs/dvgo/lf.gin --scene_name africa 22 | python3 -m run --ginc configs/dvgo/lf.gin --scene_name basket 23 | python3 -m run --ginc configs/dvgo/lf.gin --scene_name ship 24 | python3 -m run --ginc configs/dvgo/lf.gin --scene_name statue 25 | python3 -m run --ginc configs/dvgo/lf.gin --scene_name torch 26 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name bicycle 27 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name bonsai 28 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name counter 29 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name garden 30 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name kitchen 31 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name room 32 | python3 -m run --ginc configs/dvgo/360_v2.gin --scene_name stump 33 | python3 -m run --ginc configs/dvgo/shiny_blender.gin --scene_name ball 34 | python3 -m run --ginc configs/dvgo/shiny_blender.gin --scene_name car 35 | python3 -m run --ginc configs/dvgo/shiny_blender.gin --scene_name coffee 36 | python3 -m run --ginc configs/dvgo/shiny_blender.gin --scene_name helmet 37 | python3 -m run --ginc configs/dvgo/shiny_blender.gin --scene_name teapot 38 | python3 -m run --ginc configs/dvgo/shiny_blender.gin --scene_name toaster -------------------------------------------------------------------------------- /scripts/mipnerf.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name chair 2 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name drums 3 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name ficus 4 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name hotdog 5 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name lego 6 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name materials 7 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name mic 8 | python3 -m run --ginc configs/mipnerf/blender.gin --scene_name ship 9 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name fern 10 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name flower 11 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name fortress 12 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name horns 13 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name leaves 14 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name orchids 15 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name room 16 | python3 -m run --ginc configs/mipnerf/llff.gin --scene_name trex 17 | python3 -m run --ginc configs/mipnerf/tnt.gin --scene_name tat_intermediate_M60 18 | python3 -m run --ginc configs/mipnerf/tnt.gin --scene_name tat_intermediate_Playground 19 | python3 -m run --ginc configs/mipnerf/tnt.gin --scene_name tat_intermediate_Train 20 | python3 -m run --ginc configs/mipnerf/tnt.gin --scene_name tat_training_Truck 21 | python3 -m run --ginc configs/mipnerf/lf.gin --scene_name africa 22 | python3 -m run --ginc configs/mipnerf/lf.gin --scene_name basket 23 | python3 -m run --ginc configs/mipnerf/lf.gin --scene_name ship 24 | python3 -m run --ginc configs/mipnerf/lf.gin --scene_name statue 25 | python3 -m run --ginc configs/mipnerf/lf.gin --scene_name torch 26 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name chair 27 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name drums 28 | python3 -m run --ginc 
configs/mipnerf/blender_ms.gin --scene_name ficus 29 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name hotdog 30 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name lego 31 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name materials 32 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name mic 33 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --scene_name ship 34 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name bicycle 35 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name bonsai 36 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name counter 37 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name garden 38 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name kitchen 39 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name room 40 | python3 -m run --ginc configs/mipnerf/360_v2.gin --scene_name stump -------------------------------------------------------------------------------- /scripts/mipnerf360.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/mipnerf360/tnt.gin --scene_name tat_intermediate_M60 2 | python3 -m run --ginc configs/mipnerf360/tnt.gin --scene_name tat_intermediate_Playground 3 | python3 -m run --ginc configs/mipnerf360/tnt.gin --scene_name tat_intermediate_Train 4 | python3 -m run --ginc configs/mipnerf360/tnt.gin --scene_name tat_training_Truck 5 | python3 -m run --ginc configs/mipnerf360/lf.gin --scene_name africa 6 | python3 -m run --ginc configs/mipnerf360/lf.gin --scene_name basket 7 | python3 -m run --ginc configs/mipnerf360/lf.gin --scene_name ship 8 | python3 -m run --ginc configs/mipnerf360/lf.gin --scene_name statue 9 | python3 -m run --ginc configs/mipnerf360/lf.gin --scene_name torch 10 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name bicycle 11 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name bonsai 12 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name counter 13 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name garden 14 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name kitchen 15 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name room 16 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --scene_name stump 17 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name fern 18 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name flower 19 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name fortress 20 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name horns 21 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name leaves 22 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name orchids 23 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name room 24 | python3 -m run --ginc configs/mipnerf360/llff.gin --scene_name trex -------------------------------------------------------------------------------- /scripts/nerf.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/nerf/blender.gin --scene_name chair 2 | python3 -m run --ginc configs/nerf/blender.gin --scene_name drums 3 | python3 -m run --ginc configs/nerf/blender.gin --scene_name ficus 4 | python3 -m run --ginc configs/nerf/blender.gin --scene_name hotdog 5 | python3 -m run --ginc configs/nerf/blender.gin --scene_name lego 6 | python3 -m run --ginc 
configs/nerf/blender.gin --scene_name materials 7 | python3 -m run --ginc configs/nerf/blender.gin --scene_name mic 8 | python3 -m run --ginc configs/nerf/blender.gin --scene_name ship 9 | python3 -m run --ginc configs/nerf/llff.gin --scene_name fern 10 | python3 -m run --ginc configs/nerf/llff.gin --scene_name flower 11 | python3 -m run --ginc configs/nerf/llff.gin --scene_name fortress 12 | python3 -m run --ginc configs/nerf/llff.gin --scene_name horns 13 | python3 -m run --ginc configs/nerf/llff.gin --scene_name orchids 14 | python3 -m run --ginc configs/nerf/llff.gin --scene_name leaves 15 | python3 -m run --ginc configs/nerf/llff.gin --scene_name room 16 | python3 -m run --ginc configs/nerf/llff.gin --scene_name trex 17 | python3 -m run --ginc configs/nerf/tnt.gin --scene_name tat_intermediate_M60 18 | python3 -m run --ginc configs/nerf/tnt.gin --scene_name tat_intermediate_Playground 19 | python3 -m run --ginc configs/nerf/tnt.gin --scene_name tat_intermediate_Train 20 | python3 -m run --ginc configs/nerf/tnt.gin --scene_name tat_training_Truck 21 | python3 -m run --ginc configs/nerf/lf.gin --scene_name africa 22 | python3 -m run --ginc configs/nerf/lf.gin --scene_name basket 23 | python3 -m run --ginc configs/nerf/lf.gin --scene_name ship 24 | python3 -m run --ginc configs/nerf/lf.gin --scene_name statue 25 | python3 -m run --ginc configs/nerf/lf.gin --scene_name torch 26 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name chair 27 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name drums 28 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name ficus 29 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name hotdog 30 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name lego 31 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name materials 32 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name mic 33 | python3 -m run --ginc configs/nerf/blender_ms.gin --scene_name ship 34 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name bicycle 35 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name bonsai 36 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name counter 37 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name garden 38 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name kitchen 39 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name room 40 | python3 -m run --ginc configs/nerf/360_v2.gin --scene_name stump -------------------------------------------------------------------------------- /scripts/nerfpp.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/nerfpp/tnt.gin --scene_name tat_intermediate_M60 2 | python3 -m run --ginc configs/nerfpp/tnt.gin --scene_name tat_intermediate_Playground 3 | python3 -m run --ginc configs/nerfpp/tnt.gin --scene_name tat_intermediate_Train 4 | python3 -m run --ginc configs/nerfpp/tnt.gin --scene_name tat_training_Truck 5 | python3 -m run --ginc configs/nerfpp/lf.gin --scene_name africa 6 | python3 -m run --ginc configs/nerfpp/lf.gin --scene_name basket 7 | python3 -m run --ginc configs/nerfpp/lf.gin --scene_name ship 8 | python3 -m run --ginc configs/nerfpp/lf.gin --scene_name statue 9 | python3 -m run --ginc configs/nerfpp/lf.gin --scene_name torch 10 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name bicycle 11 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name bonsai 12 | python3 -m run --ginc configs/nerfpp/360_v2.gin 
--scene_name counter 13 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name garden 14 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name kitchen 15 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name room 16 | python3 -m run --ginc configs/nerfpp/360_v2.gin --scene_name stump -------------------------------------------------------------------------------- /scripts/plenoxel.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name chair 2 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name drums 3 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name ficus 4 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name hotdog 5 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name lego 6 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name materials 7 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name mic 8 | python3 -m run --ginc configs/plenoxel/blender.gin --scene_name ship 9 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name fern 10 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name flower 11 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name fortress 12 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name horns 13 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name orchids 14 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name leaves 15 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name room 16 | python3 -m run --ginc configs/plenoxel/llff.gin --scene_name trex 17 | python3 -m run --ginc configs/plenoxel/tnt.gin --scene_name tat_intermediate_M60 18 | python3 -m run --ginc configs/plenoxel/tnt.gin --scene_name tat_intermediate_Playground 19 | python3 -m run --ginc configs/plenoxel/tnt.gin --scene_name tat_intermediate_Train 20 | python3 -m run --ginc configs/plenoxel/tnt.gin --scene_name tat_training_Truck 21 | python3 -m run --ginc configs/plenoxel/lf.gin --scene_name africa 22 | python3 -m run --ginc configs/plenoxel/lf.gin --scene_name basket 23 | python3 -m run --ginc configs/plenoxel/lf.gin --scene_name ship 24 | python3 -m run --ginc configs/plenoxel/lf.gin --scene_name statue 25 | python3 -m run --ginc configs/plenoxel/lf.gin --scene_name torch 26 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name chair 27 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name drums 28 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name ficus 29 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name hotdog 30 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name lego 31 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name materials 32 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name mic 33 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --scene_name ship 34 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name bicycle 35 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name bonsai 36 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name counter 37 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name garden 38 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name kitchen 39 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name room 40 | python3 -m run --ginc configs/plenoxel/360_v2.gin --scene_name stump 41 | python3 -m run --ginc 
configs/plenoxel/shiny_blender.gin --scene_name ball 42 | python3 -m run --ginc configs/plenoxel/shiny_blender.gin --scene_name car 43 | python3 -m run --ginc configs/plenoxel/shiny_blender.gin --scene_name coffee 44 | python3 -m run --ginc configs/plenoxel/shiny_blender.gin --scene_name helmet 45 | python3 -m run --ginc configs/plenoxel/shiny_blender.gin --scene_name teapot 46 | python3 -m run --ginc configs/plenoxel/shiny_blender.gin --scene_name toaster -------------------------------------------------------------------------------- /scripts/refnerf.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name chair 2 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name drums 3 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name ficus 4 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name hotdog 5 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name lego 6 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name materials 7 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name mic 8 | python3 -m run --ginc configs/refnerf/blender.gin --scene_name ship 9 | python3 -m run --ginc configs/refnerf/shiny_blender.gin --scene_name ball 10 | python3 -m run --ginc configs/refnerf/shiny_blender.gin --scene_name car 11 | python3 -m run --ginc configs/refnerf/shiny_blender.gin --scene_name coffee 12 | python3 -m run --ginc configs/refnerf/shiny_blender.gin --scene_name helmet 13 | python3 -m run --ginc configs/refnerf/shiny_blender.gin --scene_name teapot 14 | python3 -m run --ginc configs/refnerf/shiny_blender.gin --scene_name toaster -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | python3 -m run --ginc configs/mipnerf/blender.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 2 | python3 -m run --ginc configs/mipnerf/llff.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 3 | python3 -m run --ginc configs/mipnerf/tnt.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name tat_intermediate_Playground 4 | python3 -m run --ginc configs/mipnerf/lf.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name ship 5 | python3 -m run --ginc configs/mipnerf/blender_ms.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 6 | python3 -m run --ginc configs/mipnerf/360_v2.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 7 | python3 -m run --ginc configs/mipnerf360/tnt.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name tat_intermediate_Playground 8 | python3 -m run --ginc configs/mipnerf360/lf.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name ship 9 | python3 -m run --ginc configs/mipnerf360/360_v2.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb 
run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 10 | python3 -m run --ginc configs/mipnerf360/llff.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 11 | python3 -m run --ginc configs/nerf/blender.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 12 | python3 -m run --ginc configs/nerf/llff.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 13 | python3 -m run --ginc configs/nerf/tnt.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name tat_intermediate_Playground 14 | python3 -m run --ginc configs/nerf/lf.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name ship 15 | python3 -m run --ginc configs/nerf/blender_ms.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 16 | python3 -m run --ginc configs/nerf/360_v2.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 17 | python3 -m run --ginc configs/nerfpp/tnt.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name tat_intermediate_Playground 18 | python3 -m run --ginc configs/nerfpp/lf.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name ship 19 | python3 -m run --ginc configs/nerfpp/360_v2.gin --ginb LitData.epoch_size=2000 --ginb run.max_steps=2000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 20 | python3 -m run --ginc configs/dvgo/blender.gin --ginb LitData.epoch_size=10000 --ginb run.max_steps=10000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 21 | python3 -m run --ginc configs/dvgo/llff.gin --ginb LitData.epoch_size=10000 --ginb run.max_steps=10000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 22 | python3 -m run --ginc configs/dvgo/tnt.gin --ginb LitData.epoch_size=10000 --ginb run.max_steps=10000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name tat_intermediate_Playground 23 | python3 -m run --ginc configs/dvgo/lf.gin --ginb LitData.epoch_size=10000 --ginb run.max_steps=10000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name ship 24 | python3 -m run --ginc configs/dvgo/360_v2.gin --ginb LitData.epoch_size=10000 --ginb run.max_steps=10000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 25 | python3 -m run --ginc configs/plenoxel/blender.gin --ginb LitData.epoch_size=100000 --ginb run.max_steps=100000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 26 | python3 -m run --ginc configs/plenoxel/llff.gin --ginb LitData.epoch_size=100000 --ginb run.max_steps=100000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room 27 | python3 -m run --ginc configs/plenoxel/blender_ms.gin --ginb LitData.epoch_size=100000 --ginb run.max_steps=100000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name lego 28 | python3 -m run --ginc configs/plenoxel/tnt.gin --ginb 
LitData.epoch_size=100000 --ginb run.max_steps=100000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name tat_intermediate_Playground 29 | python3 -m run --ginc configs/plenoxel/lf.gin --ginb LitData.epoch_size=100000 --ginb run.max_steps=100000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name ship 30 | python3 -m run --ginc configs/plenoxel/360_v2.gin --ginb LitData.epoch_size=100000 --ginb run.max_steps=100000 --ginb run.postfix="'test'" --ginb LitData.batch_size=512 --scene_name room -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from Plenoxels (https://github.com/sxyu/svox2) 8 | # Copyright (c) 2022 the Plenoxel authors. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import os 12 | import os.path as osp 13 | import warnings 14 | 15 | from setuptools import setup 16 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 17 | 18 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 19 | 20 | __version__ = None 21 | exec(open("lib/plenoxel/version.py", "r").read()) 22 | 23 | CUDA_FLAGS = [] 24 | INSTALL_REQUIREMENTS = [] 25 | include_dirs = [osp.join(ROOT_DIR, "lib", "plenoxel", "include")] 26 | 27 | # From PyTorch3D 28 | cub_home = os.environ.get("CUB_HOME", None) 29 | if cub_home is None: 30 | prefix = os.environ.get("CONDA_PREFIX", None) 31 | if prefix is not None and os.path.isdir(prefix + "/include/cub"): 32 | cub_home = prefix + "/include" 33 | 34 | if cub_home is None: 35 | warnings.warn( 36 | "The environment variable `CUB_HOME` was not found. " 37 | "Installation will fail if your system CUDA toolkit version is less than 11. " 38 | "NVIDIA CUB can be downloaded " 39 | "from `https://github.com/NVIDIA/cub/releases`. You can unpack " 40 | "it to a location of your choice and set the environment variable " 41 | "`CUB_HOME` to the folder containing the `CMakeLists.txt` file."
42 | ) 43 | else: 44 | include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " ")) 45 | 46 | try: 47 | ext_modules = [ 48 | CUDAExtension( 49 | "lib.plenoxel", 50 | [ 51 | "lib/plenoxel/svox2.cpp", 52 | "lib/plenoxel/svox2_kernel.cu", 53 | "lib/plenoxel/render_lerp_kernel_cuvol.cu", 54 | "lib/plenoxel/misc_kernel.cu", 55 | "lib/plenoxel/loss_kernel.cu", 56 | "lib/plenoxel/optim_kernel.cu", 57 | ], 58 | include_dirs=include_dirs, 59 | optional=False, 60 | ), 61 | ] 62 | except Exception: 63 | import warnings 64 | 65 | warnings.warn("Failed to build CUDA extension") 66 | ext_modules = [] 67 | 68 | setup( 69 | name="plenoxel", 70 | version=__version__, 71 | author="Alex Yu", 72 | author_email="alexyu99126@gmail.com", 73 | description="PyTorch sparse voxel volume extension, including custom CUDA kernels", 74 | long_description="PyTorch sparse voxel volume extension, including custom CUDA kernels", 75 | ext_modules=ext_modules, 76 | setup_requires=["pybind11>=2.5.0"], 77 | packages=["lib.plenoxel"], 78 | cmdclass={"build_ext": BuildExtension}, 79 | zip_safe=False, 80 | ) 81 | -------------------------------------------------------------------------------- /src/data/data_util/blender.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from NeRF (https://github.com/bmild/nerf) 8 | # Copyright (c) 2020 Google LLC. All Rights Reserved.
9 | # ------------------------------------------------------------------------------------ 10 | 11 | import json 12 | import os 13 | 14 | import gdown 15 | import imageio 16 | import numpy as np 17 | import torch 18 | 19 | trans_t = lambda t: torch.tensor( 20 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]] 21 | ).float() 22 | 23 | rot_phi = lambda phi: torch.tensor( 24 | [ 25 | [1, 0, 0, 0], 26 | [0, np.cos(phi), -np.sin(phi), 0], 27 | [0, np.sin(phi), np.cos(phi), 0], 28 | [0, 0, 0, 1], 29 | ] 30 | ).float() 31 | 32 | rot_theta = lambda th: torch.tensor( 33 | [ 34 | [np.cos(th), 0, -np.sin(th), 0], 35 | [0, 1, 0, 0], 36 | [np.sin(th), 0, np.cos(th), 0], 37 | [0, 0, 0, 1], 38 | ] 39 | ).float() 40 | 41 | 42 | def pose_spherical(theta, phi, radius): 43 | c2w = trans_t(radius) 44 | c2w = rot_phi(phi / 180.0 * np.pi) @ c2w 45 | c2w = rot_theta(theta / 180.0 * np.pi) @ c2w 46 | c2w = ( 47 | torch.tensor([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).float() 48 | @ c2w 49 | ) 50 | return c2w 51 | 52 | 53 | def load_blender_data( 54 | datadir: str, 55 | scene_name: str, 56 | train_skip: int, 57 | val_skip: int, 58 | test_skip: int, 59 | cam_scale_factor: float, 60 | white_bkgd: bool, 61 | ): 62 | basedir = os.path.join(datadir, scene_name) 63 | cam_trans = np.diag(np.array([1, -1, -1, 1], dtype=np.float32)) 64 | splits = ["train", "val", "test"] 65 | metas = {} 66 | for s in splits: 67 | with open(os.path.join(basedir, "transforms_{}.json".format(s)), "r") as fp: 68 | metas[s] = json.load(fp) 69 | 70 | images = [] 71 | extrinsics = [] 72 | counts = [0] 73 | 74 | for s in splits: 75 | meta = metas[s] 76 | imgs = [] 77 | poses = [] 78 | 79 | if s == "train": 80 | skip = train_skip 81 | elif s == "val": 82 | skip = val_skip 83 | elif s == "test": 84 | skip = test_skip 85 | 86 | for frame in meta["frames"][::skip]: 87 | fname = os.path.join(basedir, frame["file_path"] + ".png") 88 | imgs.append(imageio.imread(fname)) 89 | poses.append(np.array(frame["transform_matrix"])) 90 | imgs = (np.array(imgs) / 255.0).astype(np.float32) # keep all 4 channels (RGBA) 91 | poses = np.array(poses).astype(np.float32) 92 | counts.append(counts[-1] + imgs.shape[0]) 93 | images.append(imgs) 94 | extrinsics.append(poses) 95 | 96 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 97 | 98 | images = np.concatenate(images, 0) 99 | 100 | extrinsics = np.concatenate(extrinsics, 0) 101 | 102 | extrinsics[:, :3, 3] *= cam_scale_factor 103 | extrinsics = extrinsics @ cam_trans 104 | 105 | h, w = imgs[0].shape[:2] 106 | num_frame = len(extrinsics) 107 | i_split += [np.arange(num_frame)] 108 | 109 | camera_angle_x = float(meta["camera_angle_x"]) 110 | focal = 0.5 * w / np.tan(0.5 * camera_angle_x) 111 | intrinsics = np.array( 112 | [ 113 | [[focal, 0.0, 0.5 * w], [0.0, focal, 0.5 * h], [0.0, 0.0, 1.0]] 114 | for _ in range(num_frame) 115 | ] 116 | ) 117 | image_sizes = np.array([[h, w] for _ in range(num_frame)]) 118 | 119 | render_poses = torch.stack( 120 | [ 121 | pose_spherical(angle, -30.0, 4.0) @ cam_trans 122 | for angle in np.linspace(-180, 180, 40 + 1)[:-1] 123 | ], 124 | 0, 125 | ) 126 | render_poses[:, :3, 3] *= cam_scale_factor 127 | near = 2.0 128 | far = 6.0 129 | 130 | if white_bkgd: 131 | images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:]) 132 | else: 133 | images = images[..., :3] 134 | 135 | return ( 136 | images, 137 | intrinsics, 138 | extrinsics, 139 | image_sizes, 140 | near, 141 | far, 142 | (-1, -1), 143 | i_split, 144 | render_poses, 145 | ) 146 | 
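A minimal, hypothetical usage sketch of `load_blender_data` follows (the `data/blender` path, the `lego` scene name, and the argument values are illustrative assumptions; in this repo these arguments are presumably supplied through the gin configs and `src/data/litdata.py` rather than called by hand):

# Hypothetical standalone call; paths and values below are assumptions, not repo defaults.
from src.data.data_util.blender import load_blender_data

(
    images,        # (N, H, W, 3) float32 in [0, 1]; alpha-composited when white_bkgd
    intrinsics,    # (N, 3, 3) pinhole intrinsics, identical for every frame
    extrinsics,    # (N, 4, 4) camera-to-world poses; y/z axes flipped by cam_trans
    image_sizes,   # (N, 2) per-frame [h, w]
    near,          # 2.0 (hard-coded for the synthetic Blender scenes)
    far,           # 6.0
    ndc_coeffs,    # (-1, -1) placeholder; NDC is not used for this dataset
    i_split,       # (i_train, i_val, i_test, i_all) index arrays
    render_poses,  # (40, 4, 4) spherical orbit at radius 4 for video rendering
) = load_blender_data(
    datadir="data/blender",  # assumed download location
    scene_name="lego",
    train_skip=1,
    val_skip=1,
    test_skip=1,
    cam_scale_factor=1.0,
    white_bkgd=True,
)
assert images.shape[0] == len(extrinsics) == len(i_split[-1])

Note that near/far are fixed to the standard 2.0/6.0 bounds of the synthetic scenes, and render_poses is a 40-frame spherical orbit (pose_spherical at -30 degrees elevation, radius 4), independent of the loaded camera poses.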
-------------------------------------------------------------------------------- /src/data/data_util/blender_ms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from NeRF (https://github.com/bmild/nerf) 8 | # Copyright (c) 2020 Google LLC. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import json 12 | import os 13 | 14 | import imageio 15 | import numpy as np 16 | import torch 17 | 18 | trans_t = lambda t: torch.tensor( 19 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]] 20 | ).float() 21 | 22 | rot_phi = lambda phi: torch.tensor( 23 | [ 24 | [1, 0, 0, 0], 25 | [0, np.cos(phi), -np.sin(phi), 0], 26 | [0, np.sin(phi), np.cos(phi), 0], 27 | [0, 0, 0, 1], 28 | ] 29 | ).float() 30 | 31 | rot_theta = lambda th: torch.tensor( 32 | [ 33 | [np.cos(th), 0, -np.sin(th), 0], 34 | [0, 1, 0, 0], 35 | [np.sin(th), 0, np.cos(th), 0], 36 | [0, 0, 0, 1], 37 | ] 38 | ).float() 39 | 40 | 41 | def pose_spherical(theta, phi, radius): 42 | c2w = trans_t(radius) 43 | c2w = rot_phi(phi / 180.0 * np.pi) @ c2w 44 | c2w = rot_theta(theta / 180.0 * np.pi) @ c2w 45 | c2w = ( 46 | torch.tensor([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).float() 47 | @ c2w 48 | ) 49 | return c2w 50 | 51 | 52 | def load_blender_ms_data( 53 | datadir: str, 54 | scene_name: str, 55 | train_skip: int, 56 | val_skip: int, 57 | test_skip: int, 58 | cam_scale_factor: float, 59 | white_bkgd: bool, 60 | ): 61 | basedir = os.path.join(datadir, scene_name) 62 | cam_trans = np.diag(np.array([1, -1, -1, 1], dtype=np.float32)) 63 | splits = ["train", "val", "test"] 64 | 65 | metadatapath = os.path.join(basedir, "metadata.json") 66 | with open(metadatapath) as fp: 67 | metadata = json.load(fp) 68 | 69 | images = [] 70 | extrinsics = [] 71 | counts = [0] 72 | focals = [] 73 | multlosses = [] 74 | 75 | for s in splits: 76 | meta = metadata[s] 77 | imgs = [] 78 | poses = [] 79 | fs = [] 80 | multloss = [] 81 | 82 | if s == "train": 83 | skip = train_skip 84 | elif s == "val": 85 | skip = val_skip 86 | elif s == "test": 87 | skip = test_skip 88 | 89 | for (filepath, pose, focal, mult) in zip( 90 | meta["file_path"][::skip], 91 | meta["cam2world"][::skip], 92 | meta["focal"][::skip], 93 | meta["lossmult"][::skip], 94 | ): 95 | fname = os.path.join(basedir, filepath) 96 | imgs.append(imageio.imread(fname)) 97 | poses.append(np.array(pose)) 98 | fs.append(focal) 99 | multloss.append(mult) 100 | 101 | imgs = [(img / 255.0).astype(np.float32) for img in imgs] 102 | poses = np.array(poses).astype(np.float32) 103 | counts.append(counts[-1] + len(imgs)) 104 | images += imgs 105 | focals += fs 106 | extrinsics.append(poses) 107 | multlosses.append(np.array(multloss)) 108 | 109 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 110 | extrinsics = np.concatenate(extrinsics, 0) 111 | 112 | extrinsics[:, :3, 3] *= cam_scale_factor 113 | extrinsics = extrinsics @ cam_trans 114 | 115 | image_sizes = np.array([img.shape[:2] for img in images]) 116 | num_frame = 
len(extrinsics) 117 | i_split += [np.arange(num_frame)] 118 | 119 | intrinsics = np.array( 120 | [ 121 | [[focal, 0.0, 0.5 * w], [0.0, focal, 0.5 * h], [0.0, 0.0, 1.0]] 122 | for (focal, (h, w)) in zip(focals, image_sizes) 123 | ] 124 | ) 125 | 126 | render_poses = torch.stack( 127 | [ 128 | pose_spherical(angle, -30.0, 4.0) @ cam_trans 129 | for angle in np.linspace(-180, 180, 40 + 1)[:-1] 130 | ], 131 | 0, 132 | ) 133 | render_poses[:, :3, 3] *= cam_scale_factor 134 | near = 2.0 135 | far = 6.0 136 | 137 | if white_bkgd: 138 | images = [ 139 | image[..., :3] * image[..., -1:] + (1.0 - image[..., -1:]) 140 | for image in images 141 | ] 142 | else: 143 | images = [image[..., :3] for image in images] 144 | 145 | multlosses = np.concatenate(multlosses) 146 | 147 | return ( 148 | images, 149 | intrinsics, 150 | extrinsics, 151 | image_sizes, 152 | near, 153 | far, 154 | (-1, -1), 155 | i_split, 156 | render_poses, 157 | multlosses, # Train only 158 | ) 159 | -------------------------------------------------------------------------------- /src/data/data_util/lf.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from NeRF++ (https://github.com/Kai-46/nerfplusplus) 8 | # Copyright (c) 2020 the NeRF++ authors. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import glob 12 | import os 13 | from typing import * 14 | 15 | import imageio 16 | import numpy as np 17 | 18 | 19 | def find_files(dir, exts): 20 | if os.path.isdir(dir): 21 | files_grabbed = [] 22 | for ext in exts: 23 | files_grabbed.extend(glob.glob(os.path.join(dir, ext))) 24 | if len(files_grabbed) > 0: 25 | files_grabbed = sorted(files_grabbed) 26 | return files_grabbed 27 | else: 28 | return [] 29 | 30 | 31 | def similarity_from_cameras(c2w): 32 | """ 33 | Get a similarity transform to normalize dataset 34 | from c2w (OpenCV convention) cameras 35 | :param c2w: (N, 4, 4) 36 | :return T (4, 4), scale (float) 37 | """ 38 | t = c2w[:, :3, 3] 39 | R = c2w[:, :3, :3] 40 | 41 | # (1) Rotate the world so that z+ is the up axis 42 | # we estimate the up axis by averaging the camera up axes 43 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 44 | world_up = np.mean(ups, axis=0) 45 | world_up /= np.linalg.norm(world_up) 46 | 47 | up_camspace = np.array([0.0, -1.0, 0.0]) 48 | c = (up_camspace * world_up).sum() 49 | cross = np.cross(world_up, up_camspace) 50 | skew = np.array( 51 | [ 52 | [0.0, -cross[2], cross[1]], 53 | [cross[2], 0.0, -cross[0]], 54 | [-cross[1], cross[0], 0.0], 55 | ] 56 | ) 57 | if c > -1: 58 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) 59 | else: 60 | # In the unlikely case the original data has y+ up axis, 61 | # rotate 180-deg about x axis 62 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 63 | 64 | # R_align = np.eye(3) # DEBUG 65 | R = R_align @ R 66 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 67 | t = (R_align @ t[..., None])[..., 0] 68 | 69 | # (2) Recenter the scene using camera center rays 70 | # find the
closest point to the origin for each camera's center ray 71 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 72 | 73 | # median for more robustness 74 | translate = -np.median(nearest, axis=0) 75 | 76 | # translate = -np.mean(t, axis=0) # DEBUG 77 | 78 | transform = np.eye(4) 79 | transform[:3, 3] = translate 80 | transform[:3, :3] = R_align 81 | 82 | # (3) Rescale the scene using camera distances 83 | scale = 1.0 / np.median(np.linalg.norm(t + translate, axis=-1)) 84 | return transform, scale 85 | 86 | 87 | def load_lf_data( 88 | datadir: str, 89 | scene_name: str, 90 | train_skip: int, 91 | val_skip: int, 92 | test_skip: int, 93 | cam_scale_factor: float, 94 | near: Optional[float], 95 | far: Optional[float], 96 | ): 97 | 98 | basedir = os.path.join(datadir, scene_name) 99 | 100 | def parse_txt(filename): 101 | assert os.path.isfile(filename) 102 | nums = open(filename).read().split() 103 | return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32) 104 | 105 | # camera parameters files 106 | intrinsics_files = find_files( 107 | "{}/train/intrinsics".format(basedir), exts=["*.txt"] 108 | )[::train_skip] 109 | intrinsics_files += find_files( 110 | "{}/validation/intrinsics".format(basedir), exts=["*.txt"] 111 | )[::val_skip] 112 | intrinsics_files += find_files( 113 | "{}/test/intrinsics".format(basedir), exts=["*.txt"] 114 | )[::test_skip] 115 | pose_files = find_files("{}/train/pose".format(basedir), exts=["*.txt"])[ 116 | ::train_skip 117 | ] 118 | pose_files += find_files("{}/validation/pose".format(basedir), exts=["*.txt"])[ 119 | ::val_skip 120 | ] 121 | pose_files += find_files("{}/test/pose".format(basedir), exts=["*.txt"])[ 122 | ::test_skip 123 | ] 124 | cam_cnt = len(pose_files) 125 | 126 | # img files 127 | img_files = find_files("{}/rgb".format(basedir), exts=["*.png", "*.jpg"]) 128 | if len(img_files) > 0: 129 | assert len(img_files) == cam_cnt 130 | else: 131 | img_files = [ 132 | None, 133 | ] * cam_cnt 134 | 135 | # assume all images have the same size as training image 136 | train_imgfile = find_files("{}/train/rgb".format(basedir), exts=["*.png", "*.jpg"])[ 137 | ::train_skip 138 | ] 139 | val_imgfile = find_files( 140 | "{}/validation/rgb".format(basedir), exts=["*.png", "*.jpg"] 141 | )[::val_skip] 142 | test_imgfile = find_files("{}/test/rgb".format(basedir), exts=["*.png", "*.jpg"])[ 143 | ::test_skip 144 | ] 145 | i_train = np.arange(len(train_imgfile)) 146 | i_val = np.arange(len(val_imgfile)) + len(train_imgfile) 147 | i_test = np.arange(len(test_imgfile)) + len(train_imgfile) + len(val_imgfile) 148 | i_all = np.arange(len(train_imgfile) + len(val_imgfile) + len(test_imgfile)) 149 | i_split = (i_train, i_val, i_test, i_all) 150 | 151 | images = ( 152 | np.stack( 153 | [ 154 | imageio.imread(imgfile) 155 | for imgfile in train_imgfile + val_imgfile + test_imgfile 156 | ] 157 | ) 158 | / 255.0 159 | ) 160 | h, w = images[0].shape[:2] 161 | 162 | intrinsics = np.stack( 163 | [parse_txt(intrinsics_file) for intrinsics_file in intrinsics_files] 164 | ) 165 | extrinsics = np.stack([parse_txt(pose_file) for pose_file in pose_files]) 166 | 167 | if cam_scale_factor > 0: 168 | T, sscale = similarity_from_cameras(extrinsics) 169 | extrinsics = np.einsum("nij, ki -> nkj", extrinsics, T) 170 | scene_scale = cam_scale_factor * sscale 171 | extrinsics[:, :3, 3] *= scene_scale 172 | 173 | num_frame = len(extrinsics) 174 | 175 | image_sizes = np.array([[h, w] for i in range(num_frame)]) 176 | 177 | near = 0.0 if near is None else near 178 | far = 1.0 if far 
is None else far 179 | 180 | render_poses = extrinsics 181 | 182 | return ( 183 | images, 184 | intrinsics, 185 | extrinsics, 186 | image_sizes, 187 | near, 188 | far, 189 | (-1, -1), 190 | i_split, 191 | render_poses, 192 | ) 193 | -------------------------------------------------------------------------------- /src/data/data_util/shiny_blender.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from Ref-NeRF (https://github.com/google-research/multinerf) 8 | # Copyright (c) 2022 Google LLC. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import json 12 | import os 13 | 14 | import gdown 15 | import imageio 16 | import numpy as np 17 | import torch 18 | 19 | trans_t = lambda t: torch.tensor( 20 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]] 21 | ).float() 22 | 23 | rot_phi = lambda phi: torch.tensor( 24 | [ 25 | [1, 0, 0, 0], 26 | [0, np.cos(phi), -np.sin(phi), 0], 27 | [0, np.sin(phi), np.cos(phi), 0], 28 | [0, 0, 0, 1], 29 | ] 30 | ).float() 31 | 32 | rot_theta = lambda th: torch.tensor( 33 | [ 34 | [np.cos(th), 0, -np.sin(th), 0], 35 | [0, 1, 0, 0], 36 | [np.sin(th), 0, np.cos(th), 0], 37 | [0, 0, 0, 1], 38 | ] 39 | ).float() 40 | 41 | 42 | def pose_spherical(theta, phi, radius): 43 | c2w = trans_t(radius) 44 | c2w = rot_phi(phi / 180.0 * np.pi) @ c2w 45 | c2w = rot_theta(theta / 180.0 * np.pi) @ c2w 46 | c2w = ( 47 | torch.tensor([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).float() 48 | @ c2w 49 | ) 50 | return c2w 51 | 52 | 53 | def load_shiny_blender_data( 54 | datadir: str, 55 | scene_name: str, 56 | train_skip: int, 57 | val_skip: int, 58 | test_skip: int, 59 | cam_scale_factor: float, 60 | white_bkgd: bool, 61 | ): 62 | basedir = os.path.join(datadir, scene_name) 63 | cam_trans = np.diag(np.array([1, -1, -1, 1], dtype=np.float32)) 64 | splits = ["train", "val", "test"] 65 | metas = {} 66 | for s in splits: 67 | if s == "val": 68 | continue 69 | with open(os.path.join(basedir, "transforms_{}.json".format(s)), "r") as fp: 70 | metas[s] = json.load(fp) 71 | metas["val"] = metas["test"] 72 | 73 | images = [] 74 | normals = [] 75 | extrinsics = [] 76 | counts = [0] 77 | 78 | for s in splits: 79 | meta = metas[s] 80 | imgs = [] 81 | norms = [] 82 | alphas = [] 83 | poses = [] 84 | 85 | if s == "train": 86 | skip = train_skip 87 | elif s == "val": 88 | skip = val_skip 89 | elif s == "test": 90 | skip = test_skip 91 | 92 | for frame in meta["frames"][::skip]: 93 | img_fname = os.path.join(basedir, frame["file_path"] + ".png") 94 | norm_fname = os.path.join(basedir, frame["file_path"] + "_normal.png") 95 | imgs.append(imageio.imread(img_fname)) 96 | norms.append(imageio.imread(norm_fname)) 97 | poses.append(np.array(frame["transform_matrix"])) 98 | imgs = (np.array(imgs) / 255.0).astype(np.float32) 99 | alphas = imgs[..., -1:] 100 | norms = np.array(norms).astype(np.float32)[..., :3] * 2.0 / 255.0 - 1.0 101 | # Concatenate normals and alphas for computing MAE 102 | norms = np.concatenate([norms, 
alphas], axis=-1) 103 | poses = np.array(poses).astype(np.float32) 104 | counts.append(counts[-1] + imgs.shape[0]) 105 | images.append(imgs) 106 | normals.append(norms) 107 | extrinsics.append(poses) 108 | 109 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 110 | 111 | images = np.concatenate(images, 0) 112 | normals = np.concatenate(normals, 0) 113 | 114 | extrinsics = np.concatenate(extrinsics, 0) 115 | 116 | extrinsics[:, :3, 3] *= cam_scale_factor 117 | extrinsics = extrinsics @ cam_trans 118 | 119 | h, w = imgs[0].shape[:2] 120 | num_frame = len(extrinsics) 121 | i_split += [np.arange(num_frame)] 122 | 123 | camera_angle_x = float(meta["camera_angle_x"]) 124 | focal = 0.5 * w / np.tan(0.5 * camera_angle_x) 125 | intrinsics = np.array( 126 | [ 127 | [[focal, 0.0, 0.5 * w], [0.0, focal, 0.5 * h], [0.0, 0.0, 1.0]] 128 | for _ in range(num_frame) 129 | ] 130 | ) 131 | image_sizes = np.array([[h, w] for _ in range(num_frame)]) 132 | 133 | render_poses = torch.stack( 134 | [ 135 | pose_spherical(angle, -30.0, 4.0) @ cam_trans 136 | for angle in np.linspace(-180, 180, 40 + 1)[:-1] 137 | ], 138 | 0, 139 | ) 140 | render_poses[:, :3, 3] *= cam_scale_factor 141 | near = 2.0 142 | far = 6.0 143 | 144 | if white_bkgd: 145 | images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:]) 146 | else: 147 | images = images[..., :3] 148 | 149 | return ( 150 | images, 151 | intrinsics, 152 | extrinsics, 153 | image_sizes, 154 | near, 155 | far, 156 | (-1, -1), 157 | i_split, 158 | render_poses, 159 | normals, 160 | ) 161 | -------------------------------------------------------------------------------- /src/data/data_util/tnt.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from NeRF++ (https://github.com/Kai-46/nerfplusplus) 8 | # Copyright (c) 2020 the NeRF++ authors. All Rights Reserved. 
9 | # ------------------------------------------------------------------------------------ 10 | 11 | import glob 12 | import os 13 | from typing import * 14 | 15 | import imageio 16 | import numpy as np 17 | 18 | 19 | def find_files(dir, exts): 20 | if os.path.isdir(dir): 21 | files_grabbed = [] 22 | for ext in exts: 23 | files_grabbed.extend(glob.glob(os.path.join(dir, ext))) 24 | if len(files_grabbed) > 0: 25 | files_grabbed = sorted(files_grabbed) 26 | return files_grabbed 27 | else: 28 | return [] 29 | 30 | 31 | def similarity_from_cameras(c2w): 32 | """ 33 | Get a similarity transform to normalize dataset 34 | from c2w (OpenCV convention) cameras 35 | :param c2w: (N, 4, 4) 36 | :return T (4, 4), scale (float) 37 | """ 38 | t = c2w[:, :3, 3] 39 | R = c2w[:, :3, :3] 40 | 41 | # (1) Rotate the world so that z+ is the up axis 42 | # we estimate the up axis by averaging the camera up axes 43 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 44 | world_up = np.mean(ups, axis=0) 45 | world_up /= np.linalg.norm(world_up) 46 | 47 | up_camspace = np.array([0.0, -1.0, 0.0]) 48 | c = (up_camspace * world_up).sum() 49 | cross = np.cross(world_up, up_camspace) 50 | skew = np.array( 51 | [ 52 | [0.0, -cross[2], cross[1]], 53 | [cross[2], 0.0, -cross[0]], 54 | [-cross[1], cross[0], 0.0], 55 | ] 56 | ) 57 | if c > -1: 58 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) 59 | else: 60 | # In the unlikely case the original data has y+ up axis, 61 | # rotate 180-deg about x axis 62 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 63 | 64 | # R_align = np.eye(3) # DEBUG 65 | R = R_align @ R 66 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 67 | t = (R_align @ t[..., None])[..., 0] 68 | 69 | # (2) Recenter the scene using camera center rays 70 | # find the closest point to the origin for each camera's center ray 71 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 72 | 73 | # median for more robustness 74 | translate = -np.median(nearest, axis=0) 75 | 76 | # translate = -np.mean(t, axis=0) # DEBUG 77 | 78 | transform = np.eye(4) 79 | transform[:3, 3] = translate 80 | transform[:3, :3] = R_align 81 | 82 | # (3) Rescale the scene using camera distances 83 | scale = 1.0 / np.median(np.linalg.norm(t + translate, axis=-1)) 84 | return transform, scale 85 | 86 | 87 | def load_tnt_data( 88 | datadir: str, 89 | scene_name: str, 90 | train_skip: int, 91 | val_skip: int, 92 | test_skip: int, 93 | cam_scale_factor: float, 94 | near: Optional[float], 95 | far: Optional[float], 96 | ): 97 | 98 | basedir = os.path.join(datadir, scene_name) 99 | 100 | def parse_txt(filename): 101 | assert os.path.isfile(filename) 102 | nums = open(filename).read().split() 103 | return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32) 104 | 105 | # camera parameters files 106 | intrinsics_files = find_files( 107 | "{}/train/intrinsics".format(basedir), exts=["*.txt"] 108 | )[::train_skip] 109 | intrinsics_files += find_files( 110 | "{}/validation/intrinsics".format(basedir), exts=["*.txt"] 111 | )[::val_skip] 112 | intrinsics_files += find_files( 113 | "{}/test/intrinsics".format(basedir), exts=["*.txt"] 114 | )[::test_skip] 115 | pose_files = find_files("{}/train/pose".format(basedir), exts=["*.txt"])[ 116 | ::train_skip 117 | ] 118 | pose_files += find_files("{}/validation/pose".format(basedir), exts=["*.txt"])[ 119 | ::val_skip 120 | ] 121 | pose_files += find_files("{}/test/pose".format(basedir), exts=["*.txt"])[ 122 | ::test_skip 123 | ] 124 | cam_cnt = len(pose_files)
125 | 126 | # img files 127 | img_files = find_files("{}/rgb".format(basedir), exts=["*.png", "*.jpg"]) 128 | if len(img_files) > 0: 129 | assert len(img_files) == cam_cnt 130 | else: 131 | img_files = [ 132 | None, 133 | ] * cam_cnt 134 | 135 | # assume all images have the same size as training image 136 | train_imgfile = find_files("{}/train/rgb".format(basedir), exts=["*.png", "*.jpg"])[ 137 | ::train_skip 138 | ] 139 | val_imgfile = find_files( 140 | "{}/validation/rgb".format(basedir), exts=["*.png", "*.jpg"] 141 | )[::val_skip] 142 | test_imgfile = find_files("{}/test/rgb".format(basedir), exts=["*.png", "*.jpg"])[ 143 | ::test_skip 144 | ] 145 | i_train = np.arange(len(train_imgfile)) 146 | i_val = np.arange(len(val_imgfile)) + len(train_imgfile) 147 | i_test = np.arange(len(test_imgfile)) + len(train_imgfile) + len(val_imgfile) 148 | i_all = np.arange(len(train_imgfile) + len(val_imgfile) + len(test_imgfile)) 149 | i_split = (i_train, i_val, i_test, i_all) 150 | 151 | images = ( 152 | np.stack( 153 | [ 154 | imageio.imread(imgfile) 155 | for imgfile in train_imgfile + val_imgfile + test_imgfile 156 | ] 157 | ) 158 | / 255.0 159 | ) 160 | h, w = images[0].shape[:2] 161 | 162 | intrinsics = np.stack( 163 | [parse_txt(intrinsics_file) for intrinsics_file in intrinsics_files] 164 | ) 165 | extrinsics = np.stack([parse_txt(pose_file) for pose_file in pose_files]) 166 | 167 | if cam_scale_factor > 0: 168 | T, sscale = similarity_from_cameras(extrinsics) 169 | extrinsics = np.einsum("nij, ki -> nkj", extrinsics, T) 170 | scene_scale = cam_scale_factor * sscale 171 | extrinsics[:, :3, 3] *= scene_scale 172 | 173 | num_frame = len(extrinsics) 174 | 175 | image_sizes = np.array([[h, w] for i in range(num_frame)]) 176 | 177 | near = 0.0 if near is None else near 178 | far = 1.0 if far is None else far 179 | 180 | render_poses = extrinsics 181 | 182 | return ( 183 | images, 184 | intrinsics, 185 | extrinsics, 186 | image_sizes, 187 | near, 188 | far, 189 | (-1, -1), 190 | i_split, 191 | render_poses, 192 | ) 193 | -------------------------------------------------------------------------------- /src/data/ray_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import numpy as np 8 | 9 | 10 | def convert_to_ndc(origins, directions, ndc_coeffs, near: float = 1.0): 11 | """Convert a set of rays to NDC coordinates.""" 12 | t = (near - origins[Ellipsis, 2]) / directions[Ellipsis, 2] 13 | origins = origins + t[Ellipsis, None] * directions 14 | 15 | dx, dy, dz = directions[:, 0], directions[:, 1], directions[:, 2] 16 | ox, oy, oz = origins[:, 0], origins[:, 1], origins[:, 2] 17 | o0 = ndc_coeffs[0] * (ox / oz) 18 | o1 = ndc_coeffs[1] * (oy / oz) 19 | o2 = 1 - 2 * near / oz 20 | d0 = ndc_coeffs[0] * (dx / dz - ox / oz) 21 | d1 = ndc_coeffs[1] * (dy / dz - oy / oz) 22 | d2 = 2 * near / oz 23 | 24 | origins = np.stack([o0, o1, o2], -1) 25 | directions = np.stack([d0, d1, d2], -1) 26 | 27 | return origins, directions 28 | 29 | 30 | def batchified_get_rays( 31 | intrinsics, 32 | extrinsics, 33 | image_sizes, 34 | use_pixel_centers, 35 | get_radii, 36 | ndc_coord, 37 | ndc_coeffs, 38 | multlosses, 39 | ): 40 | 41 | radii = None 42 | multloss_expand = None 43 | 44 | center = 0.5 if use_pixel_centers else 0.0 45 | mesh_grids = [ 46 | np.meshgrid( 47 | np.arange(w, dtype=np.float32) + center, 48 | np.arange(h, dtype=np.float32) + center, 49 | indexing="xy", 50 | ) 51 | for (h, w) in image_sizes 52 | ] 53 | 54 | i_coords = [mesh_grid[0] for mesh_grid in mesh_grids] 55 | j_coords = [mesh_grid[1] for mesh_grid in mesh_grids] 56 | 57 | dirs = [ 58 | np.stack( 59 | [ 60 | (i - intrinsic[0][2]) / intrinsic[0][0], 61 | (j - intrinsic[1][2]) / intrinsic[1][1], 62 | np.ones_like(i), 63 | ], 64 | -1, 65 | ) 66 | for (intrinsic, i, j) in zip(intrinsics, i_coords, j_coords) 67 | ] 68 | 69 | rays_o = np.concatenate( 70 | [ 71 | np.tile(extrinsic[np.newaxis, :3, 3], (1, h * w, 1)).reshape(-1, 3) 72 | for (extrinsic, (h, w)) in zip(extrinsics, image_sizes) 73 | ] 74 | ).astype(np.float32) 75 | 76 | rays_d = np.concatenate( 77 | [ 78 | np.einsum("hwc, rc -> hwr", dir, extrinsic[:3, :3]).reshape(-1, 3) 79 | for (dir, extrinsic) in zip(dirs, extrinsics) 80 | ] 81 | ).astype(np.float32) 82 | 83 | viewdirs = rays_d 84 | viewdirs /= np.linalg.norm(viewdirs, axis=-1, keepdims=True) 85 | 86 | if ndc_coord: 87 | rays_o, rays_d = convert_to_ndc(rays_o, rays_d, ndc_coeffs) 88 | 89 | if get_radii: 90 | 91 | if not ndc_coord: 92 | rays_d_orig = [ 93 | np.einsum("hwc, rc -> hwr", dir, extrinsic[:3, :3]) 94 | for (dir, extrinsic) in zip(dirs, extrinsics) 95 | ] 96 | dx = [ 97 | np.sqrt(np.sum((v[:-1, :, :] - v[1:, :, :]) ** 2, -1)) 98 | for v in rays_d_orig 99 | ] 100 | dx = [np.concatenate([v, v[-2:-1, :]], 0) for v in dx] 101 | _radii = [v[..., None] * 2 / np.sqrt(12) for v in dx] 102 | radii = np.concatenate([radii_each.reshape(-1) for radii_each in _radii])[ 103 | ..., None 104 | ] 105 | 106 | else: 107 | rays_o_orig, cnt = [], 0 108 | for (h, w) in image_sizes: 109 | rays_o_orig.append(rays_o[cnt : cnt + h * w].reshape(h, w, 3)) 110 | cnt += h * w 111 | 112 | dx = [ 113 | np.sqrt(np.sum((v[:-1, :, :] - v[1:, :, :]) ** 2, -1)) 114 | for v in rays_o_orig 115 | ] 116 | dx = [np.concatenate([v, v[-2:-1, :]], 0) for v in dx] 117 | dy = [ 118 | np.sqrt(np.sum((v[:, :-1, :] - v[:, 1:, :]) ** 2, -1)) 119 | for v in rays_o_orig 120 | ] 121 | dy = [np.concatenate([v, v[:, -2:-1]], 1) for v in dy] 122 | _radii = [(vx + vy)[..., None] / np.sqrt(12) for (vx, vy) in zip(dx, dy)] 123 | radii = 
np.concatenate([radii_each.reshape(-1) for radii_each in _radii])[ 124 | ..., None 125 | ] 126 | 127 | if multlosses is not None: 128 | multloss_expand = np.concatenate( 129 | [ 130 | np.array([scale] * (h * w)) 131 | for (scale, (h, w)) in zip(multlosses, image_sizes) 132 | ] 133 | )[..., None] 134 | 135 | return rays_o, rays_d, viewdirs, radii, multloss_expand 136 | -------------------------------------------------------------------------------- /src/data/sampler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import gin 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | from torch.utils.data.sampler import SequentialSampler 12 | 13 | 14 | class DDPSampler(SequentialSampler): 15 | def __init__(self, batch_size, num_replicas, rank, tpu): 16 | self.data_source = None 17 | self.batch_size = batch_size 18 | self.drop_last = False 19 | ngpus = torch.cuda.device_count() 20 | if ngpus == 1 and not tpu: 21 | rank, num_replicas = 0, 1 22 | else: 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.rank = rank 32 | self.num_replicas = num_replicas 33 | 34 | 35 | class DDPSequnetialSampler(DDPSampler): 36 | def __init__(self, batch_size, num_replicas, rank, N_total, tpu): 37 | self.N_total = N_total 38 | super(DDPSequnetialSampler, self).__init__(batch_size, num_replicas, rank, tpu) 39 | 40 | def __iter__(self): 41 | idx_list = np.arange(self.N_total) 42 | return iter(idx_list[self.rank :: self.num_replicas]) 43 | 44 | def __len__(self): 45 | return int(np.ceil(self.N_total / self.num_replicas)) 46 | 47 | 48 | class SingleImageDDPSampler(DDPSampler): 49 | def __init__( 50 | self, 51 | batch_size, 52 | num_replicas, 53 | rank, 54 | N_img, 55 | N_pixels, 56 | epoch_size, 57 | tpu, 58 | precrop, 59 | precrop_steps, 60 | ): 61 | super(SingleImageDDPSampler, self).__init__(batch_size, num_replicas, rank, tpu) 62 | self.N_pixels = N_pixels 63 | self.N_img = N_img 64 | self.epoch_size = epoch_size 65 | self.precrop = precrop 66 | self.precrop_steps = precrop_steps 67 | 68 | def __iter__(self): 69 | image_choice = np.random.choice( 70 | np.arange(self.N_img), self.epoch_size, replace=True 71 | ) 72 | image_shape = self.N_pixels[image_choice] 73 | if not self.precrop: 74 | idx_choice = [ 75 | np.random.choice( 76 | np.arange(image_shape[i, 0] * image_shape[i, 1]), self.batch_size 77 | ) 78 | for i in range(self.epoch_size) 79 | ] 80 | else: 81 | idx_choice = [] 82 | h_pick = [ 83 | np.random.choice(np.arange(image_shape[i, 0] // 2), self.batch_size) 84 | + image_shape[i, 0] // 4 85 | for i in range(self.precrop_steps) 86 | ] 87 | w_pick = [ 88 | np.random.choice(np.arange(image_shape[i, 1] // 2), self.batch_size) 89 | + image_shape[i, 1] // 4 90 | for i in range(self.precrop_steps) 91 | ] 92 | idx_choice = [ 93 | h_pick[i] * image_shape[i, 1] + w_pick[i] 94 | for i in range(self.precrop_steps) 
95 | ] 96 | 97 | idx_choice += [ 98 | np.random.choice( 99 | np.arange(image_shape[i, 0] * image_shape[i, 1]), self.batch_size 100 | ) 101 | for i in range(self.epoch_size - self.precrop_steps) 102 | ] 103 | self.precrop = False 104 | 105 | idx_choice = np.stack(idx_choice) 106 | idx_jump = np.concatenate( 107 | [ 108 | np.zeros_like(self.N_pixels[0]), 109 | np.cumsum(self.N_pixels[:-1, 0] * self.N_pixels[:-1, 1]), 110 | ] 111 | )[..., None] 112 | idx_shift = idx_jump[image_choice] + idx_choice 113 | idx_shift = idx_shift[:, self.rank :: self.num_replicas] 114 | 115 | return iter(idx_shift) 116 | 117 | def __len__(self): 118 | return self.epoch_size 119 | 120 | 121 | class MultipleImageDDPSampler(DDPSampler): 122 | def __init__(self, batch_size, num_replicas, rank, total_len, epoch_size, tpu): 123 | super(MultipleImageDDPSampler, self).__init__( 124 | batch_size, num_replicas, rank, tpu 125 | ) 126 | self.total_len = total_len 127 | self.epoch_size = epoch_size 128 | 129 | def __iter__(self): 130 | full_index = np.arange(self.total_len) 131 | indices = np.stack( 132 | [ 133 | np.random.choice(full_index, self.batch_size) 134 | for _ in range(self.epoch_size) 135 | ] 136 | ) 137 | for batch in indices: 138 | yield batch[self.rank :: self.num_replicas] 139 | 140 | def __len__(self): 141 | return self.epoch_size 142 | 143 | 144 | @gin.configurable() 145 | class MultipleImageDynamicDDPSampler(DDPSampler): 146 | def __init__( 147 | self, 148 | batch_size, 149 | num_replicas, 150 | rank, 151 | total_len, 152 | N_img, 153 | N_pixels, 154 | epoch_size, 155 | tpu, 156 | N_coarse=0, 157 | ): 158 | super(MultipleImageDynamicDDPSampler, self).__init__( 159 | batch_size, num_replicas, rank, tpu 160 | ) 161 | self.total_len = total_len 162 | self.epoch_size = epoch_size 163 | self.N_pixels = N_pixels 164 | self.N_img = N_img 165 | self.N_coarse = N_coarse 166 | 167 | def __iter__(self): 168 | indices = np.random.rand(self.epoch_size - self.N_coarse, self.batch_size) 169 | 170 | image_choice = np.random.choice( 171 | np.arange(self.N_img), self.N_coarse, replace=True 172 | ) 173 | image_shape = self.N_pixels[image_choice] 174 | idx_choice = [ 175 | np.random.choice( 176 | np.arange(image_shape[i, 0] * image_shape[i, 1]), self.batch_size 177 | ) 178 | for i in range(self.N_coarse) 179 | ] 180 | 181 | idx_choice = np.stack(idx_choice) 182 | idx_jump = np.concatenate( 183 | [ 184 | np.zeros_like(self.N_pixels[0]), 185 | np.cumsum(self.N_pixels[:-1, 0] * self.N_pixels[:-1, 1]), 186 | ] 187 | )[..., None] 188 | idx_shift = idx_jump[image_choice] + idx_choice 189 | idx_shift = idx_shift[:, self.rank :: self.num_replicas] 190 | for batch in idx_shift: 191 | yield batch 192 | 193 | for batch in indices: 194 | yield np.floor( 195 | (batch[self.rank :: self.num_replicas]) * self.total_len 196 | ).astype(np.uint) 197 | 198 | def __len__(self): 199 | return self.epoch_size 200 | -------------------------------------------------------------------------------- /src/model/dvgo/__global__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from DVGO (https://github.com/sunset1995/DirectVoxGO) 8 | # Copyright (c) 2022 Google LLC. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import os 12 | 13 | from torch.utils.cpp_extension import load 14 | 15 | root_dir = __file__.split(os.path.relpath(__file__))[0] 16 | 17 | render_utils_cuda = None 18 | total_variation_cuda = None 19 | ub360_utils_cuda = None 20 | adam_upd_cuda = None 21 | 22 | sources = ["lib/dvgo/cuda/adam_upd.cpp", "lib/dvgo/cuda/adam_upd_kernel.cu"] 23 | 24 | 25 | def init(): 26 | global render_utils_cuda 27 | render_utils_cuda = load( 28 | name="render_utils_cuda", 29 | sources=[ 30 | os.path.join(root_dir, path) 31 | for path in [ 32 | "lib/dvgo/cuda/render_utils.cpp", 33 | "lib/dvgo/cuda/render_utils_kernel.cu", 34 | ] 35 | ], 36 | verbose=True, 37 | ) 38 | 39 | global ub360_utils_cuda 40 | ub360_utils_cuda = load( 41 | name="ub360_utils_cuda", 42 | sources=[ 43 | os.path.join(root_dir, path) 44 | for path in [ 45 | "lib/dvgo/cuda/ub360_utils.cpp", 46 | "lib/dvgo/cuda/ub360_utils_kernel.cu", 47 | ] 48 | ], 49 | verbose=True, 50 | ) 51 | 52 | global total_variation_cuda 53 | if total_variation_cuda is None: 54 | total_variation_cuda = load( 55 | name="total_variation_cuda", 56 | sources=[ 57 | os.path.join(root_dir, path) 58 | for path in [ 59 | "lib/dvgo/cuda/total_variation.cpp", 60 | "lib/dvgo/cuda/total_variation_kernel.cu", 61 | ] 62 | ], 63 | verbose=True, 64 | ) 65 | 66 | global adam_upd_cuda 67 | adam_upd_cuda = load( 68 | name="adam_upd_cuda", 69 | sources=[os.path.join(root_dir, path) for path in sources], 70 | verbose=True, 71 | ) 72 | -------------------------------------------------------------------------------- /src/model/dvgo/masked_adam.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from DVGO (https://github.com/sunset1995/DirectVoxGO) 8 | # Copyright (c) 2022 Google LLC. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import torch 12 | 13 | from src.model.dvgo.__global__ import * 14 | 15 | """ Extend Adam optimizer 16 | 1. support per-voxel learning rate 17 | 2. 
masked update (ignore zero grad), which speeds up training 18 | """ 19 | 20 | 21 | class MaskedAdam(torch.optim.Optimizer): 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8): 23 | 24 | if not 0.0 <= lr: 25 | raise ValueError("Invalid learning rate: {}".format(lr)) 26 | if not 0.0 <= eps: 27 | raise ValueError("Invalid epsilon value: {}".format(eps)) 28 | if not 0.0 <= betas[0] < 1.0: 29 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 32 | defaults = dict(lr=lr, betas=betas, eps=eps) 33 | self.per_lr = None 34 | super(MaskedAdam, self).__init__(params, defaults) 35 | 36 | def __setstate__(self, state): 37 | super(MaskedAdam, self).__setstate__(state) 38 | 39 | def set_pervoxel_lr(self, count): 40 | assert self.param_groups[0]["params"][0].shape == count.shape 41 | self.per_lr = count.float() / count.max() 42 | 43 | @torch.no_grad() 44 | def step(self, closure=None): 45 | 46 | loss = None 47 | if closure is not None: 48 | with torch.enable_grad(): 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | lr = group["lr"] 53 | beta1, beta2 = group["betas"] 54 | eps = group["eps"] 55 | skip_zero_grad = group["skip_zero_grad"] 56 | 57 | for param in group["params"]: 58 | if param.grad is not None: 59 | state = self.state[param] 60 | # Lazy state initialization 61 | if len(state) == 0: 62 | state["step"] = 0 63 | # Exponential moving average of gradient values 64 | state["exp_avg"] = torch.zeros_like( 65 | param, memory_format=torch.preserve_format 66 | ) 67 | # Exponential moving average of squared gradient values 68 | state["exp_avg_sq"] = torch.zeros_like( 69 | param, memory_format=torch.preserve_format 70 | ) 71 | 72 | state["step"] += 1 73 | 74 | from src.model.dvgo.__global__ import adam_upd_cuda 75 | 76 | if self.per_lr is not None and param.shape == self.per_lr.shape: 77 | adam_upd_cuda.adam_upd_with_perlr( 78 | param, 79 | param.grad, 80 | state["exp_avg"], 81 | state["exp_avg_sq"], 82 | self.per_lr, 83 | state["step"], 84 | beta1, 85 | beta2, 86 | lr, 87 | eps, 88 | ) 89 | elif skip_zero_grad: 90 | adam_upd_cuda.masked_adam_upd( 91 | param, 92 | param.grad, 93 | state["exp_avg"], 94 | state["exp_avg_sq"], 95 | state["step"], 96 | beta1, 97 | beta2, 98 | lr, 99 | eps, 100 | ) 101 | else: 102 | adam_upd_cuda.adam_upd( 103 | param, 104 | param.grad, 105 | state["exp_avg"], 106 | state["exp_avg_sq"], 107 | state["step"], 108 | beta1, 109 | beta2, 110 | lr, 111 | eps, 112 | ) 113 | 114 | return loss 115 | -------------------------------------------------------------------------------- /src/model/dvgo/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from DVGO (https://github.com/sunset1995/DirectVoxGO) 8 | # Copyright (c) 2022 Google LLC. All Rights Reserved.
9 | # ------------------------------------------------------------------------------------ 10 | 11 | import numpy as np 12 | import scipy.signal 13 | import torch 14 | import torch.nn as nn 15 | 16 | from src.model.dvgo.__global__ import * 17 | from src.model.dvgo.masked_adam import MaskedAdam 18 | 19 | """ Misc 20 | """ 21 | mse2psnr = lambda x: -10.0 * torch.log10(x) 22 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 23 | 24 | 25 | def create_optimizer_or_freeze_model(model, cfg_train, global_step): 26 | 27 | decay_steps = cfg_train["lrate_decay"] * 1000 28 | decay_factor = 0.1 ** (global_step / decay_steps) 29 | 30 | param_group = [] 31 | for k in cfg_train.keys(): 32 | if not k.startswith("lrate_"): 33 | continue 34 | k = k[len("lrate_") :] 35 | 36 | if not hasattr(model, k): 37 | continue 38 | 39 | param = getattr(model, k) 40 | if param is None: 41 | print(f"create_optimizer_or_freeze_model: param {k} does not exist") 42 | continue 43 | 44 | lr = cfg_train.get(f"lrate_{k}") * decay_factor 45 | 46 | if lr > 0: 47 | print(f"create_optimizer_or_freeze_model: param {k} lr {lr}") 48 | if isinstance(param, nn.Module): 49 | param = param.parameters() 50 | 51 | param_group.append( 52 | { 53 | "params": param, 54 | "lr": lr, 55 | "skip_zero_grad": (k in cfg_train["skip_zero_grad_fields"]), 56 | } 57 | ) 58 | else: 59 | print(f"create_optimizer_or_freeze_model: param {k} freeze") 60 | param.requires_grad = False 61 | 62 | return MaskedAdam(param_group) 63 | 64 | 65 | """ Checkpoint utils 66 | """ 67 | 68 | 69 | def load_checkpoint(model, optimizer, ckpt_path, no_reload_optimizer): 70 | ckpt = torch.load(ckpt_path) 71 | start = ckpt["global_step"] 72 | model.load_state_dict(ckpt["model_state_dict"]) 73 | if not no_reload_optimizer: 74 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 75 | return model, optimizer, start 76 | 77 | 78 | def load_model(model_class, ckpt_path): 79 | ckpt = torch.load(ckpt_path) 80 | model = model_class(**ckpt["model_kwargs"]) 81 | model.load_state_dict(ckpt["model_state_dict"]) 82 | return model 83 | 84 | 85 | """ Evaluation metrics (ssim, lpips) 86 | """ 87 | 88 | 89 | def rgb_ssim( 90 | img0, 91 | img1, 92 | max_val, 93 | filter_size=11, 94 | filter_sigma=1.5, 95 | k1=0.01, 96 | k2=0.03, 97 | return_map=False, 98 | ): 99 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 100 | assert len(img0.shape) == 3 101 | assert img0.shape[-1] == 3 102 | assert img0.shape == img1.shape 103 | 104 | # Construct a 1D Gaussian blur filter. 105 | hw = filter_size // 2 106 | shift = (2 * hw - filter_size + 1) / 2 107 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2 108 | filt = np.exp(-0.5 * f_i) 109 | filt /= np.sum(filt) 110 | 111 | # Blur in x and y (faster than the 2D convolution). 112 | def convolve2d(z, f): 113 | return scipy.signal.convolve2d(z, f, mode="valid") 114 | 115 | filt_fn = lambda z: np.stack( 116 | [ 117 | convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :]) 118 | for i in range(z.shape[-1]) 119 | ], 120 | -1, 121 | ) 122 | mu0 = filt_fn(img0) 123 | mu1 = filt_fn(img1) 124 | mu00 = mu0 * mu0 125 | mu11 = mu1 * mu1 126 | mu01 = mu0 * mu1 127 | sigma00 = filt_fn(img0**2) - mu00 128 | sigma11 = filt_fn(img1**2) - mu11 129 | sigma01 = filt_fn(img0 * img1) - mu01 130 | 131 | # Clip the variances and covariances to valid values.
132 | # Variance must be non-negative: 133 | sigma00 = np.maximum(0.0, sigma00) 134 | sigma11 = np.maximum(0.0, sigma11) 135 | sigma01 = np.sign(sigma01) * np.minimum(np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 136 | c1 = (k1 * max_val) ** 2 137 | c2 = (k2 * max_val) ** 2 138 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 139 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 140 | ssim_map = numer / denom 141 | ssim = np.mean(ssim_map) 142 | return ssim_map if return_map else ssim 143 | 144 | 145 | __LPIPS__ = {} 146 | 147 | 148 | def init_lpips(net_name, device): 149 | assert net_name in ["alex", "vgg"] 150 | import lpips 151 | 152 | print(f"init_lpips: lpips_{net_name}") 153 | return lpips.LPIPS(net=net_name, version="0.1").eval().to(device) 154 | 155 | 156 | def rgb_lpips(np_gt, np_im, net_name, device): 157 | if net_name not in __LPIPS__: 158 | __LPIPS__[net_name] = init_lpips(net_name, device) 159 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 160 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 161 | return __LPIPS__[net_name](gt, im, normalize=True).item() 162 | -------------------------------------------------------------------------------- /src/model/interface.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import json 8 | import os 9 | 10 | import numpy as np 11 | import pytorch_lightning as pl 12 | import torch 13 | from piqa.lpips import LPIPS 14 | from piqa.ssim import SSIM 15 | 16 | import utils.store_image as store_image 17 | 18 | reshape_2d = lambda x: x.reshape((x.shape[0], -1)) 19 | clip_0_1 = lambda x: torch.clip(x, 0, 1).detach() 20 | 21 | 22 | class LitModel(pl.LightningModule): 23 | 24 | # Utils to reorganize output values from evaluation steps, 25 | # i.e., validation and test step. 
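    # A sketch of the shape contract (assuming DDP): all_gather stacks per-rank
    # outputs into [world_size, N, 3]; the permute/flatten below restores the
    # ray order striped across ranks by the distributed sampler, after which the
    # flat buffer is cut into one (h, w, 3) image per entry of image_sizes.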
26 | def alter_gather_cat(self, outputs, key, image_sizes): 27 | each = torch.cat([output[key] for output in outputs]) 28 | all = self.all_gather(each).detach() 29 | if all.dim() == 3: 30 | all = all.permute((1, 0, 2)).flatten(0, 1) 31 | ret, curr = [], 0 32 | for (h, w) in image_sizes: 33 | ret.append(all[curr : curr + h * w].reshape(h, w, 3)) 34 | curr += h * w 35 | return ret 36 | 37 | @torch.no_grad() 38 | def psnr_each(self, preds, gts): 39 | psnr_list = [] 40 | for (pred, gt) in zip(preds, gts): 41 | pred = torch.clip(pred, 0, 1) 42 | gt = torch.clip(gt, 0, 1) 43 | mse = torch.mean((pred - gt) ** 2) 44 | psnr = -10.0 * torch.log(mse) / np.log(10) 45 | psnr_list.append(psnr) 46 | return torch.stack(psnr_list) 47 | 48 | @torch.no_grad() 49 | def ssim_each(self, preds, gts): 50 | ssim_model = SSIM().to(device=self.device) 51 | ssim_list = [] 52 | for (pred, gt) in zip(preds, gts): 53 | pred = torch.clip(pred.permute((2, 0, 1)).unsqueeze(0).float(), 0, 1) 54 | gt = torch.clip(gt.permute((2, 0, 1)).unsqueeze(0).float(), 0, 1) 55 | ssim = ssim_model(pred, gt) 56 | ssim_list.append(ssim) 57 | del ssim_model 58 | return torch.stack(ssim_list) 59 | 60 | @torch.no_grad() 61 | def lpips_each(self, preds, gts): 62 | lpips_model = LPIPS(network="vgg").to(device=self.device) 63 | lpips_list = [] 64 | for (pred, gt) in zip(preds, gts): 65 | pred = torch.clip(pred.permute((2, 0, 1)).unsqueeze(0).float(), 0, 1) 66 | gt = torch.clip(gt.permute((2, 0, 1)).unsqueeze(0).float(), 0, 1) 67 | lpips = lpips_model(pred, gt) 68 | lpips_list.append(lpips) 69 | del lpips_model 70 | return torch.stack(lpips_list) 71 | 72 | @torch.no_grad() 73 | def psnr(self, preds, gts, i_train, i_val, i_test): 74 | ret = {} 75 | ret["name"] = "PSNR" 76 | psnr_list = self.psnr_each(preds, gts) 77 | ret["mean"] = psnr_list.mean().item() 78 | if self.trainer.datamodule.eval_test_only: 79 | ret["test"] = psnr_list.mean().item() 80 | else: 81 | ret["train"] = psnr_list[i_train].mean().item() 82 | ret["val"] = psnr_list[i_val].mean().item() 83 | ret["test"] = psnr_list[i_test].mean().item() 84 | 85 | return ret 86 | 87 | @torch.no_grad() 88 | def ssim(self, preds, gts, i_train, i_val, i_test): 89 | ret = {} 90 | ret["name"] = "SSIM" 91 | ssim_list = self.ssim_each(preds, gts) 92 | ret["mean"] = ssim_list.mean().item() 93 | if self.trainer.datamodule.eval_test_only: 94 | ret["test"] = ssim_list.mean().item() 95 | else: 96 | ret["train"] = ssim_list[i_train].mean().item() 97 | ret["val"] = ssim_list[i_val].mean().item() 98 | ret["test"] = ssim_list[i_test].mean().item() 99 | 100 | return ret 101 | 102 | @torch.no_grad() 103 | def lpips(self, preds, gts, i_train, i_val, i_test): 104 | ret = {} 105 | ret["name"] = "LPIPS" 106 | lpips_list = self.lpips_each(preds, gts) 107 | ret["mean"] = lpips_list.mean().item() 108 | if self.trainer.datamodule.eval_test_only: 109 | ret["test"] = lpips_list.mean().item() 110 | else: 111 | ret["train"] = lpips_list[i_train].mean().item() 112 | ret["val"] = lpips_list[i_val].mean().item() 113 | ret["test"] = lpips_list[i_test].mean().item() 114 | 115 | return ret 116 | 117 | def write_stats(self, fpath, *stats): 118 | 119 | d = {} 120 | for stat in stats: 121 | d[stat["name"]] = { 122 | k: float(w) 123 | for (k, w) in stat.items() 124 | if k != "name" and k != "scene_wise" 125 | } 126 | 127 | with open(fpath, "w") as fp: 128 | json.dump(d, fp, indent=4, sort_keys=True) 129 | 130 | def predict_step(self, *args, **kwargs): 131 | return self.test_step(*args, **kwargs) 132 | 133 | def 
on_predict_epoch_end(self, outputs): 134 | dmodule = self.trainer.datamodule 135 | image_sizes = dmodule.image_sizes 136 | image_num = len(dmodule.render_poses) 137 | all_image_sizes = np.stack([image_sizes[0] for _ in range(image_num)]) 138 | rgbs = self.alter_gather_cat(outputs[0], "rgb", all_image_sizes) 139 | 140 | if self.trainer.is_global_zero: 141 | image_dir = os.path.join(self.logdir, "render_video") 142 | os.makedirs(image_dir, exist_ok=True) 143 | store_image.store_image(image_dir, rgbs) 144 | store_image.store_video(image_dir, rgbs, None) 145 | 146 | return None 147 | -------------------------------------------------------------------------------- /src/model/nerf/helper.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from NeRF (https://github.com/bmild/nerf) 8 | # Copyright (c) 2020 Google LLC. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | 17 | def img2mse(x, y): 18 | return torch.mean((x - y) ** 2) 19 | 20 | 21 | def mse2psnr(x): 22 | return -10.0 * torch.log(x) / np.log(10) 23 | 24 | 25 | def cast_rays(t_vals, origins, directions): 26 | return origins[..., None, :] + t_vals[..., None] * directions[..., None, :] 27 | 28 | 29 | def sample_along_rays( 30 | rays_o, 31 | rays_d, 32 | num_samples, 33 | near, 34 | far, 35 | randomized, 36 | lindisp, 37 | ): 38 | bsz = rays_o.shape[0] 39 | t_vals = torch.linspace(0.0, 1.0, num_samples + 1, device=rays_o.device) 40 | if lindisp: 41 | t_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals) 42 | else: 43 | t_vals = near * (1.0 - t_vals) + far * t_vals 44 | 45 | if randomized: 46 | mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1]) 47 | upper = torch.cat([mids, t_vals[..., -1:]], -1) 48 | lower = torch.cat([t_vals[..., :1], mids], -1) 49 | t_rand = torch.rand((bsz, num_samples + 1), device=rays_o.device) 50 | t_vals = lower + (upper - lower) * t_rand 51 | else: 52 | t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1)) 53 | 54 | coords = cast_rays(t_vals, rays_o, rays_d) 55 | 56 | return t_vals, coords 57 | 58 | 59 | def pos_enc(x, min_deg, max_deg): 60 | scales = torch.tensor([2**i for i in range(min_deg, max_deg)]).type_as(x) 61 | xb = torch.reshape((x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1]) 62 | four_feat = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1)) 63 | return torch.cat([x] + [four_feat], dim=-1) 64 | 65 | 66 | def volumetric_rendering(rgb, density, t_vals, dirs, white_bkgd): 67 | 68 | eps = 1e-10 69 | 70 | dists = torch.cat( 71 | [ 72 | t_vals[..., 1:] - t_vals[..., :-1], 73 | torch.ones(t_vals[..., :1].shape, device=t_vals.device) * 1e10, 74 | ], 75 | dim=-1, 76 | ) 77 | dists = dists * torch.norm(dirs[..., None, :], dim=-1) 78 | alpha = 1.0 - torch.exp(-density[..., 0] * dists) 79 | accum_prod = torch.cat( 80 | [ 81 | torch.ones_like(alpha[..., :1]), 82 | torch.cumprod(1.0 - alpha[..., :-1] + eps, dim=-1), 83 | ], 84 | dim=-1, 85 
| ) 86 | 87 | weights = alpha * accum_prod 88 | 89 | comp_rgb = (weights[..., None] * rgb).sum(dim=-2) 90 | depth = (weights * t_vals).sum(dim=-1) 91 | acc = weights.sum(dim=-1) 92 | inv_eps = 1 / eps 93 | 94 | if white_bkgd: 95 | comp_rgb = comp_rgb + (1.0 - acc[..., None]) 96 | 97 | return comp_rgb, acc, weights 98 | 99 | 100 | def sorted_piecewise_constant_pdf( 101 | bins, weights, num_samples, randomized, float_min_eps=2**-32 102 | ): 103 | 104 | eps = 1e-5 105 | weight_sum = weights.sum(dim=-1, keepdims=True) 106 | padding = torch.fmax(torch.zeros_like(weight_sum), eps - weight_sum) 107 | weights = weights + padding / weights.shape[-1] 108 | weight_sum = weight_sum + padding 109 | 110 | pdf = weights / weight_sum 111 | cdf = torch.fmin( 112 | torch.ones_like(pdf[..., :-1]), torch.cumsum(pdf[..., :-1], dim=-1) 113 | ) 114 | cdf = torch.cat( 115 | [ 116 | torch.zeros(list(cdf.shape[:-1]) + [1], device=weights.device), 117 | cdf, 118 | torch.ones(list(cdf.shape[:-1]) + [1], device=weights.device), 119 | ], 120 | dim=-1, 121 | ) 122 | 123 | s = 1 / num_samples 124 | if randomized: 125 | u = torch.rand(list(cdf.shape[:-1]) + [num_samples], device=cdf.device) 126 | else: 127 | u = torch.linspace(0.0, 1.0 - float_min_eps, num_samples, device=cdf.device) 128 | u = torch.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples]) 129 | 130 | mask = u[..., None, :] >= cdf[..., :, None] 131 | 132 | bin0 = (mask * bins[..., None] + ~mask * bins[..., :1, None]).max(dim=-2)[0] 133 | bin1 = (~mask * bins[..., None] + mask * bins[..., -1:, None]).min(dim=-2)[0] 134 | # Debug Here 135 | cdf0 = (mask * cdf[..., None] + ~mask * cdf[..., :1, None]).max(dim=-2)[0] 136 | cdf1 = (~mask * cdf[..., None] + mask * cdf[..., -1:, None]).min(dim=-2)[0] 137 | 138 | t = torch.clip(torch.nan_to_num((u - cdf0) / (cdf1 - cdf0), 0), 0, 1) 139 | samples = bin0 + t * (bin1 - bin0) 140 | 141 | return samples 142 | 143 | 144 | def sample_pdf(bins, weights, origins, directions, t_vals, num_samples, randomized): 145 | 146 | t_samples = sorted_piecewise_constant_pdf( 147 | bins, weights, num_samples, randomized 148 | ).detach() 149 | t_vals = torch.sort(torch.cat([t_vals, t_samples], dim=-1), dim=-1).values 150 | coords = cast_rays(t_vals, origins, directions) 151 | return t_vals, coords 152 | -------------------------------------------------------------------------------- /src/model/plenoxel/__global__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from Plenoxels (https://github.com/sxyu/svox2) 8 | # Copyright (c) 2022 the Plenoxel authors. All Rights Reserved. 
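#
# (A note on what follows: the BASIS_TYPE_* constants mirror enum values baked
# into the compiled CUDA extension under lib/plenoxel, and _get_c_extension()
# deliberately degrades to None when the extension is absent or was built
# without sample_grid support, so callers can check for None and bail out.)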
9 | # ------------------------------------------------------------------------------------ 10 | 11 | BASIS_TYPE_SH = 1 12 | BASIS_TYPE_3D_TEXTURE = 4 13 | BASIS_TYPE_MLP = 255 14 | 15 | 16 | def _get_c_extension(): 17 | from warnings import warn 18 | 19 | try: 20 | import lib.plenoxel as _C 21 | 22 | if not hasattr(_C, "sample_grid"): 23 | _C = None 24 | except: 25 | _C = None 26 | 27 | return _C 28 | -------------------------------------------------------------------------------- /src/model/plenoxel/autograd.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from Plenoxels (https://github.com/sxyu/svox2) 8 | # Copyright (c) 2022 the Plenoxel authors. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | import torch 12 | import torch.autograd as autograd 13 | import src.model.plenoxel.utils as utils 14 | 15 | from typing import Tuple 16 | 17 | from src.model.plenoxel.__global__ import ( 18 | BASIS_TYPE_SH, 19 | _get_c_extension, 20 | BASIS_TYPE_3D_TEXTURE, 21 | BASIS_TYPE_MLP, 22 | ) 23 | 24 | _C = _get_c_extension() 25 | 26 | # BEGIN Differentiable CUDA functions with custom gradient 27 | class _SampleGridAutogradFunction(autograd.Function): 28 | @staticmethod 29 | def forward( 30 | ctx, 31 | data_density: torch.Tensor, 32 | data_sh: torch.Tensor, 33 | grid, 34 | points: torch.Tensor, 35 | want_colors: bool, 36 | ): 37 | assert not points.requires_grad, "Point gradient not supported" 38 | out_density, out_sh = _C.sample_grid(grid, points, want_colors) 39 | ctx.save_for_backward(points) 40 | ctx.grid = grid 41 | ctx.want_colors = want_colors 42 | return out_density, out_sh 43 | 44 | @staticmethod 45 | def backward(ctx, grad_out_density, grad_out_sh): 46 | (points,) = ctx.saved_tensors 47 | grad_density_grid = torch.zeros_like(ctx.grid.density_data.data) 48 | grad_sh_grid = torch.zeros_like(ctx.grid.sh_data.data) 49 | _C.sample_grid_backward( 50 | ctx.grid, 51 | points, 52 | grad_out_density.contiguous(), 53 | grad_out_sh.contiguous(), 54 | grad_density_grid, 55 | grad_sh_grid, 56 | ctx.want_colors, 57 | ) 58 | if not ctx.needs_input_grad[0]: 59 | grad_density_grid = None 60 | if not ctx.needs_input_grad[1]: 61 | grad_sh_grid = None 62 | 63 | return grad_density_grid, grad_sh_grid, None, None, None 64 | 65 | 66 | class _VolumeRenderFunction(autograd.Function): 67 | @staticmethod 68 | def forward( 69 | ctx, 70 | data_density: torch.Tensor, 71 | data_sh: torch.Tensor, 72 | data_basis: torch.Tensor, 73 | data_background: torch.Tensor, 74 | grid, 75 | rays, 76 | opt, 77 | backend: str, 78 | ): 79 | cu_fn = _C.__dict__[f"volume_render_{backend}"] 80 | color = cu_fn(grid, rays, opt) 81 | ctx.save_for_backward(color) 82 | ctx.grid = grid 83 | ctx.rays = rays 84 | ctx.opt = opt 85 | ctx.backend = backend 86 | ctx.basis_data = data_basis 87 | return color 88 | 89 | @staticmethod 90 | def backward(ctx, grad_out): 91 | (color_cache,) = ctx.saved_tensors 92 | cu_fn = _C.__dict__[f"volume_render_{ctx.backend}_backward"] 93 | grad_density_grid = 
torch.zeros_like(ctx.grid.density_data.data) 94 | grad_sh_grid = torch.zeros_like(ctx.grid.sh_data.data) 95 | if ctx.grid.basis_type == BASIS_TYPE_MLP: 96 | grad_basis = torch.zeros_like(ctx.basis_data) 97 | elif ctx.grid.basis_type == BASIS_TYPE_3D_TEXTURE: 98 | grad_basis = torch.zeros_like(ctx.grid.basis_data.data) 99 | if ctx.grid.background_data is not None: 100 | grad_background = torch.zeros_like(ctx.grid.background_data.data) 101 | grad_holder = _C.GridOutputGrads() 102 | grad_holder.grad_density_out = grad_density_grid 103 | grad_holder.grad_sh_out = grad_sh_grid 104 | if ctx.needs_input_grad[2]: 105 | grad_holder.grad_basis_out = grad_basis 106 | if ctx.grid.background_data is not None and ctx.needs_input_grad[3]: 107 | grad_holder.grad_background_out = grad_background 108 | cu_fn( 109 | ctx.grid, ctx.rays, ctx.opt, grad_out.contiguous(), color_cache, grad_holder 110 | ) 111 | ctx.grid = ctx.rays = ctx.opt = None 112 | if not ctx.needs_input_grad[0]: 113 | grad_density_grid = None 114 | if not ctx.needs_input_grad[1]: 115 | grad_sh_grid = None 116 | if not ctx.needs_input_grad[2]: 117 | grad_basis = None 118 | if not ctx.needs_input_grad[3]: 119 | grad_background = None 120 | ctx.basis_data = None 121 | 122 | return ( 123 | grad_density_grid, 124 | grad_sh_grid, 125 | grad_basis, 126 | grad_background, 127 | None, 128 | None, 129 | None, 130 | None, 131 | ) 132 | 133 | 134 | class _TotalVariationFunction(autograd.Function): 135 | @staticmethod 136 | def forward( 137 | ctx, 138 | data: torch.Tensor, 139 | links: torch.Tensor, 140 | start_dim: int, 141 | end_dim: int, 142 | use_logalpha: bool, 143 | logalpha_delta: float, 144 | ignore_edge: bool, 145 | ndc_coeffs: Tuple[float, float], 146 | ): 147 | tv = _C.tv( 148 | links, 149 | data, 150 | start_dim, 151 | end_dim, 152 | use_logalpha, 153 | logalpha_delta, 154 | ignore_edge, 155 | ndc_coeffs[0], 156 | ndc_coeffs[1], 157 | ) 158 | ctx.save_for_backward(links, data) 159 | ctx.start_dim = start_dim 160 | ctx.end_dim = end_dim 161 | ctx.use_logalpha = use_logalpha 162 | ctx.logalpha_delta = logalpha_delta 163 | ctx.ignore_edge = ignore_edge 164 | ctx.ndc_coeffs = ndc_coeffs 165 | return tv 166 | 167 | @staticmethod 168 | def backward(ctx, grad_out): 169 | links, data = ctx.saved_tensors 170 | grad_grid = torch.zeros_like(data) 171 | _C.tv_grad( 172 | links, 173 | data, 174 | ctx.start_dim, 175 | ctx.end_dim, 176 | 1.0, 177 | ctx.use_logalpha, 178 | ctx.logalpha_delta, 179 | ctx.ignore_edge, 180 | ctx.ndc_coeffs[0], 181 | ctx.ndc_coeffs[1], 182 | grad_grid, 183 | ) 184 | ctx.start_dim = ctx.end_dim = None 185 | if not ctx.needs_input_grad[0]: 186 | grad_grid = None 187 | return grad_grid, None, None, None, None, None, None, None 188 | -------------------------------------------------------------------------------- /src/model/plenoxel/dataclass.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from Plenoxels (https://github.com/sxyu/svox2) 8 | # Copyright (c) 2022 the Plenoxel authors. All Rights Reserved. 
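#
# Usage sketch (hypothetical tensors; the _to_cpp() calls additionally assume
# the compiled lib.plenoxel extension is available):
#
#     rays = Rays(origins=torch.zeros(4096, 3), dirs=torch.ones(4096, 3))
#     opt = RenderOptions(backend="cuvol", step_size=0.5)
#     ray_spec, opt_spec = rays._to_cpp(), opt._to_cpp()
#
# Each dataclass below is a thin Python-side mirror of a C++ spec struct.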
9 | # ------------------------------------------------------------------------------------ 10 | 11 | from dataclasses import dataclass 12 | from random import random 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import torch 16 | 17 | import src.model.plenoxel.utils as utils 18 | from src.model.plenoxel.__global__ import _get_c_extension 19 | 20 | _C = _get_c_extension() 21 | 22 | 23 | @dataclass 24 | class RenderOptions: 25 | """ 26 | Rendering options, passed through to the C++ renderer via _to_cpp(). 27 | Available options: 28 | :param backend: str, renderer backend 29 | :param background_brightness: float 30 | :param step_size: float, step size for rendering 31 | :param sigma_thresh: float 32 | :param stop_thresh: float 33 | """ 34 | 35 | def __init__( 36 | self, 37 | backend: str = "cuvol", 38 | background_brightness: float = 1.0, 39 | step_size: float = 0.5, 40 | sigma_thresh: float = 1e-10, 41 | stop_thresh: float = 1e-7, 42 | last_sample_opaque: bool = False, 43 | near_clip: float = 0.0, 44 | use_spheric_clip: bool = False, 45 | ): 46 | self.backend = backend 47 | self.background_brightness = background_brightness 48 | self.step_size = step_size 49 | self.sigma_thresh = sigma_thresh 50 | self.stop_thresh = stop_thresh 51 | self.last_sample_opaque = last_sample_opaque 52 | self.near_clip = near_clip 53 | self.use_spheric_clip = use_spheric_clip 54 | 55 | def _to_cpp(self, randomize: bool = False): 56 | """ 57 | Generate object to pass to C++ 58 | """ 59 | opt = _C.RenderOptions() 60 | opt.background_brightness = self.background_brightness 61 | opt.step_size = self.step_size 62 | opt.sigma_thresh = self.sigma_thresh 63 | opt.stop_thresh = self.stop_thresh 64 | opt.near_clip = self.near_clip 65 | opt.use_spheric_clip = self.use_spheric_clip 66 | opt.last_sample_opaque = self.last_sample_opaque 67 | 68 | return opt 69 | 70 | 71 | @dataclass 72 | class Rays: 73 | origins: torch.Tensor 74 | dirs: torch.Tensor 75 | 76 | def _to_cpp(self): 77 | """ 78 | Generate object to pass to C++ 79 | """ 80 | spec = _C.RaysSpec() 81 | spec.origins = self.origins 82 | spec.dirs = self.dirs 83 | return spec 84 | 85 | def __getitem__(self, key): 86 | return Rays(self.origins[key], self.dirs[key]) 87 | 88 | @property 89 | def is_cuda(self) -> bool: 90 | return self.origins.is_cuda and self.dirs.is_cuda 91 | 92 | 93 | @dataclass 94 | class Camera: 95 | c2w: torch.Tensor  # OpenCV 96 | fx: float = 1111.11 97 | fy: Optional[float] = None 98 | cx: Optional[float] = None 99 | cy: Optional[float] = None 100 | width: int = 800 101 | height: int = 800 102 | 103 | ndc_coeffs: Union[Tuple[float, float], List[float]] = (-1.0, -1.0) 104 | 105 | @property 106 | def fx_val(self): 107 | return self.fx 108 | 109 | @property 110 | def fy_val(self): 111 | return self.fx if self.fy is None else self.fy 112 | 113 | @property 114 | def cx_val(self): 115 | return self.width * 0.5 if self.cx is None else self.cx 116 | 117 | @property 118 | def cy_val(self): 119 | return self.height * 0.5 if self.cy is None else self.cy 120 | 121 | @property 122 | def using_ndc(self): 123 | return self.ndc_coeffs[0] > 0.0 124 | 125 | def _to_cpp(self): 126 | """ 127 | Generate object to pass to C++ 128 | """ 129 | spec = _C.CameraSpec() 130 | spec.c2w = self.c2w.float() 131 | spec.fx = float(self.fx_val) 132 | spec.fy = float(self.fy_val) 133 | spec.cx = float(self.cx_val) 134 | spec.cy = float(self.cy_val) 135 | spec.width = int(self.width) 136 | spec.height = int(self.height) 137 | spec.ndc_coeffx = float(self.ndc_coeffs[0]) 138 | spec.ndc_coeffy =
float(self.ndc_coeffs[1]) 139 | return spec 140 | 141 | @property 142 | def is_cuda(self) -> bool: 143 | return self.c2w.is_cuda 144 | -------------------------------------------------------------------------------- /src/model/refnerf/ref_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # ------------------------------------------------------------------------------------ 7 | # Modified from Ref-NeRF (https://github.com/google-research/multinerf) 8 | # Copyright (c) 2022 Google LLC. All Rights Reserved. 9 | # ------------------------------------------------------------------------------------ 10 | 11 | 12 | import numpy as np 13 | import torch 14 | 15 | 16 | def reflect(viewdirs, normals): 17 | """Reflect view directions about normals. 18 | 19 | The reflection of a vector v about a unit vector n is a vector u such that 20 | dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two 21 | equations is u = 2 dot(n, v) n - v. 22 | 23 | Args: 24 | viewdirs: [..., 3] array of view directions. 25 | normals: [..., 3] array of normal directions (assumed to be unit vectors). 26 | 27 | Returns: 28 | [..., 3] array of reflection directions. 29 | """ 30 | return ( 31 | 2.0 * torch.sum(normals * viewdirs, dim=-1, keepdims=True) * normals - viewdirs 32 | ) 33 | 34 | 35 | def l2_normalize(x, eps=torch.finfo(torch.float32).eps): 36 | """Normalize x to unit length along last axis.""" 37 | 38 | return x / torch.sqrt( 39 | torch.fmax(torch.sum(x**2, dim=-1, keepdims=True), torch.full_like(x, eps)) 40 | ) 41 | 42 | 43 | def compute_weighted_mae(weights, normals, normals_gt): 44 | """Compute weighted mean angular error, assuming normals are unit length.""" 45 | one_eps = 1 - torch.finfo(torch.float32).eps 46 | return ( 47 | ( 48 | weights 49 | * torch.arccos( 50 | torch.clamp(torch.sum(normals * normals_gt, -1), -one_eps, one_eps) 51 | ) 52 | ).sum() 53 | / torch.sum(weights) 54 | * 180.0 55 | / np.pi 56 | ) 57 | 58 | 59 | def generalized_binomial_coeff(a, k): 60 | """Compute generalized binomial coefficients.""" 61 | # return np.prod(a - np.arange(k)) / np.math.factorial(k) 62 | return np.prod(a - np.arange(k)) / np.math.factorial(k) 63 | 64 | 65 | def assoc_legendre_coeff(l, m, k): 66 | """Compute associated Legendre polynomial coefficients. 67 | 68 | Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the 69 | (l, m)th associated Legendre polynomial, P_l^m(cos(theta)). 70 | 71 | Args: 72 | l: associated Legendre polynomial degree. 73 | m: associated Legendre polynomial order. 74 | k: power of cos(theta). 75 | 76 | Returns: 77 | A float, the coefficient of the term corresponding to the inputs. 
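    For example, assoc_legendre_coeff(1, 0, 1) == 1.0, matching P_1^0(cos(theta)) = cos(theta).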
78 | """ 79 | return ( 80 | (-1) ** m 81 | * 2**l 82 | * np.math.factorial(l) 83 | / np.math.factorial(k) 84 | / np.math.factorial(l - k - m) 85 | * generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l) 86 | ) 87 | 88 | 89 | def sph_harm_coeff(l, m, k): 90 | """Compute spherical harmonic coefficients.""" 91 | # return (np.sqrt( 92 | # (2.0 * l + 1.0) * np.math.factorial(l - m) / 93 | # (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k)) 94 | return np.sqrt( 95 | (2.0 * l + 1.0) 96 | * np.math.factorial(l - m) 97 | / (4.0 * np.pi * np.math.factorial(l + m)) 98 | ) * assoc_legendre_coeff(l, m, k) 99 | 100 | 101 | def get_ml_array(deg_view): 102 | """Create a list with all pairs of (l, m) values to use in the encoding.""" 103 | ml_list = [] 104 | for i in range(deg_view): 105 | l = 2**i 106 | # Only use nonnegative m values, later splitting real and imaginary parts. 107 | for m in range(l + 1): 108 | ml_list.append((m, l)) 109 | 110 | ml_array = np.array(ml_list).T 111 | return ml_array 112 | 113 | 114 | def generate_ide_fn(deg_view): 115 | """Generate integrated directional encoding (IDE) function. 116 | 117 | This function returns a function that computes the integrated directional 118 | encoding from Equations 6-8 of arxiv.org/abs/2112.03907. 119 | 120 | Args: 121 | deg_view: number of spherical harmonics degrees to use. 122 | 123 | Returns: 124 | A function for evaluating integrated directional encoding. 125 | 126 | Raises: 127 | ValueError: if deg_view is larger than 5. 128 | """ 129 | if deg_view > 5: 130 | raise ValueError("Only deg_view of at most 5 is numerically stable.") 131 | 132 | ml_array = get_ml_array(deg_view) 133 | l_max = 2 ** (deg_view - 1) 134 | 135 | # Create a matrix corresponding to ml_array holding all coefficients, which, 136 | # when multiplied (from the right) by the z coordinate Vandermonde matrix, 137 | # results in the z component of the encoding. 138 | mat = np.zeros((l_max + 1, ml_array.shape[1])) 139 | for i, (m, l) in enumerate(ml_array.T): 140 | for k in range(l - m + 1): 141 | mat[k, i] = sph_harm_coeff(l, m, k) 142 | 143 | mat = torch.Tensor(mat) 144 | 145 | def integrated_dir_enc_fn(xyz, kappa_inv): 146 | """Function returning integrated directional encoding (IDE). 147 | 148 | Args: 149 | xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at. 150 | kappa_inv: [..., 1] reciprocal of the concentration parameter of the von 151 | Mises-Fisher distribution. 152 | 153 | Returns: 154 | An array with the resulting IDE. 155 | """ 156 | x = xyz[..., 0:1] 157 | y = xyz[..., 1:2] 158 | z = xyz[..., 2:3] 159 | 160 | vmz = torch.cat([z**i for i in range(mat.shape[0])], dim=-1) 161 | vmxy = torch.cat([(x + 1j * y) ** m for m in ml_array[0, :]], dim=-1) 162 | 163 | sph_harms = vmxy * torch.matmul(vmz, mat.to(xyz.device)) 164 | 165 | sigma = torch.Tensor(0.5 * ml_array[1, :] * (ml_array[1, :] + 1)).to( 166 | kappa_inv.device 167 | ) 168 | ide = sph_harms * torch.exp(-sigma * kappa_inv) 169 | 170 | return torch.cat([torch.real(ide), torch.imag(ide)], dim=-1) 171 | 172 | return integrated_dir_enc_fn 173 | 174 | 175 | def generate_dir_enc_fn(deg_view): 176 | """Generate directional encoding (DE) function. 177 | 178 | Args: 179 | deg_view: number of spherical harmonics degrees to use. 180 | 181 | Returns: 182 | A function for evaluating directional encoding. 
183 | """ 184 | integrated_dir_enc_fn = generate_ide_fn(deg_view) 185 | 186 | def dir_enc_fn(xyz): 187 | """Function returning directional encoding (DE).""" 188 | return integrated_dir_enc_fn(xyz, torch.zeros_like(xyz[..., :1])) 189 | 190 | return dir_enc_fn 191 | -------------------------------------------------------------------------------- /utils/check_mean_score.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import json 9 | import os 10 | 11 | import numpy as np 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--postfix", default="", type=str, help="files name to filter") 16 | parser.add_argument( 17 | "--dirpath", default=".", type=str, help="path to the directory" 18 | ) 19 | args = parser.parse_args() 20 | 21 | json_list = [] 22 | for dirname in os.listdir(args.dirpath): 23 | json_path = os.path.join(args.dirpath, dirname, "results.json") 24 | if os.path.exists(json_path): 25 | with open(json_path, "r") as fp: 26 | json_list.append(json.load(fp)) 27 | 28 | print(len(json_list)) 29 | psnr_mean, psnr_train, psnr_val, psnr_test = [], [], [], [] 30 | ssim_mean, ssim_train, ssim_val, ssim_test = [], [], [], [] 31 | lpips_mean, lpips_train, lpips_val, lpips_test = [], [], [], [] 32 | 33 | for json_file in json_list: 34 | psnr_mean.append(json_file["PSNR"]["mean"]) 35 | psnr_train.append(json_file["PSNR"]["train_mean"]) 36 | psnr_val.append(json_file["PSNR"]["val_mean"]) 37 | psnr_test.append(json_file["PSNR"]["test_mean"]) 38 | ssim_mean.append(json_file["SSIM"]["mean"]) 39 | ssim_train.append(json_file["SSIM"]["train_mean"]) 40 | ssim_val.append(json_file["SSIM"]["val_mean"]) 41 | ssim_test.append(json_file["SSIM"]["test_mean"]) 42 | lpips_mean.append(json_file["LPIPS-VGG"]["mean"]) 43 | lpips_train.append(json_file["LPIPS-VGG"]["train_mean"]) 44 | lpips_val.append(json_file["LPIPS-VGG"]["val_mean"]) 45 | lpips_test.append(json_file["LPIPS-VGG"]["test_mean"]) 46 | 47 | score_name = ( 48 | "psnr_mean", 49 | "psnr_train", 50 | "psnr_val", 51 | "psnr_test", 52 | "ssim_mean", 53 | "ssim_train", 54 | "ssim_val", 55 | "ssim_test", 56 | "lpips_mean", 57 | "lpips_train", 58 | "lpips_val", 59 | "lpips_test", 60 | ) 61 | 62 | for name in score_name: 63 | print(f"{name} : {np.array(eval(name)).mean()}") 64 | -------------------------------------------------------------------------------- /utils/create_scripts.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import os 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--model", type=str, default=None, help="name of the model") 13 | 14 | parser.add_argument( 15 | "--dataset", type=str, action="append", default=None, help="name of dataset" 16 | ) 17 | args = parser.parse_args() 18 | 19 | seeds = [111, 333, 555] 20 | 21 | blender_scene_list = [ 22 | "chair", 23 | "drums", 24 | "ficus", 25 | "hotdog", 26 | "lego", 27 | "materials", 28 | "mic", 29 | "ship", 30 | ] 31 | 32 | llff_scene_list = [ 33 | "fern", 34 | "flower", 35 | "fortress", 36 | "horns", 37 | "orchids", 38 | "leaves", 39 | "room", 40 | "trex", 41 | ] 42 | 43 | tnt_scene_list = [ 44 | "tat_intermediate_M60", 45 | "tat_intermediate_Playground", 46 | "tat_intermediate_Train", 47 | "tat_training_Truck", 48 | ] 49 | 50 | lf_scene_list = [ 51 | "africa", 52 | "basket", 53 | "ship", 54 | "statue", 55 | "torch", 56 | ] 57 | 58 | file_name = f"{args.model}.sh" 59 | file = open(os.path.join("../scripts", file_name), "w") 60 | 61 | scene_list = [] 62 | 63 | if "blender" in args.dataset: 64 | ginc = f"configs/{args.model}/blender.gin" 65 | for scene in blender_scene_list: 66 | for seed in seeds: 67 | file.write( 68 | f"python3 -m run --ginc {ginc} --scene_name {scene} --seed {seed}\n" 69 | ) 70 | 71 | if "llff" in args.dataset: 72 | ginc = f"configs/{args.model}/llff.gin" 73 | for scene in llff_scene_list: 74 | for seed in seeds: 75 | file.write( 76 | f"python3 -m run --ginc {ginc} --scene_name {scene} --seed {seed}\n" 77 | ) 78 | 79 | if "tanks_and_temples" in args.dataset or "tnt" in args.dataset: 80 | ginc = f"configs/{args.model}/tnt.gin" 81 | for scene in tnt_scene_list: 82 | for seed in seeds: 83 | file.write( 84 | f"python3 -m run --ginc {ginc} --scene_name {scene} --seed {seed}\n" 85 | ) 86 | 87 | if "lf" in args.dataset: 88 | ginc = f"configs/{args.model}/lf.gin" 89 | for scene in lf_scene_list: 90 | for seed in seeds: 91 | file.write( 92 | f"python3 -m run --ginc {ginc} --scene_name {scene} --seed {seed}\n" 93 | ) 94 | 95 | file.close() 96 | -------------------------------------------------------------------------------- /utils/preprocess_shiny_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import imageio 4 | import numpy as np 5 | from PIL import Image 6 | 7 | if __name__ == "__main__": 8 | img_dir = "./data/refnerf/ball" 9 | 10 | for i in range(100): 11 | alpha = imageio.imread(os.path.join(img_dir, "train", f"r_{i}_alpha.png")) 12 | img = imageio.imread(os.path.join(img_dir, "train", f"r_{i}.png")) 13 | img = np.concatenate([img, alpha[..., None]], axis=-1) 14 | img = Image.fromarray(img) 15 | img.save(os.path.join(img_dir, "train", f"r_{i}.png")) 16 | 17 | for i in range(200): 18 | alpha = imageio.imread(os.path.join(img_dir, "test", f"r_{i}_alpha.png")) 19 | img = imageio.imread(os.path.join(img_dir, "test", f"r_{i}.png")) 20 | img = np.concatenate([img, alpha[..., None]], axis=-1) 21 | img = Image.fromarray(img) 22 | img.save(os.path.join(img_dir, "test", f"r_{i}.png")) 23 | -------------------------------------------------------------------------------- /utils/select_option.py: -------------------------------------------------------------------------------- 1 | # 
------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | from typing import * 8 | 9 | from src.data.litdata import ( 10 | LitDataBlender, 11 | LitDataBlenderMultiScale, 12 | LitDataLF, 13 | LitDataLLFF, 14 | LitDataNeRF360V2, 15 | LitDataRefNeRFReal, 16 | LitDataShinyBlender, 17 | LitDataTnT, 18 | ) 19 | from src.model.dvgo.model import LitDVGO 20 | from src.model.mipnerf360.model import LitMipNeRF360 21 | from src.model.mipnerf.model import LitMipNeRF 22 | from src.model.nerf.model import LitNeRF 23 | from src.model.nerfpp.model import LitNeRFPP 24 | from src.model.plenoxel.model import LitPlenoxel 25 | from src.model.refnerf.model import LitRefNeRF 26 | 27 | 28 | def select_model( 29 | model_name: str, 30 | ): 31 | 32 | if model_name == "nerf": 33 | return LitNeRF() 34 | elif model_name == "mipnerf": 35 | return LitMipNeRF() 36 | elif model_name == "plenoxel": 37 | return LitPlenoxel() 38 | elif model_name == "nerfpp": 39 | return LitNeRFPP() 40 | elif model_name == "dvgo": 41 | return LitDVGO() 42 | elif model_name == "refnerf": 43 | return LitRefNeRF() 44 | elif model_name == "mipnerf360": 45 | return LitMipNeRF360() 46 | 47 | else: 48 | raise ValueError(f"Unknown model named {model_name}") 49 | 50 | 51 | def select_dataset( 52 | dataset_name: str, 53 | datadir: str, 54 | scene_name: str, 55 | ): 56 | if dataset_name == "blender": 57 | data_fun = LitDataBlender 58 | elif dataset_name == "blender_multiscale": 59 | data_fun = LitDataBlenderMultiScale 60 | elif dataset_name == "llff": 61 | data_fun = LitDataLLFF 62 | elif dataset_name == "tanks_and_temples": 63 | data_fun = LitDataTnT 64 | elif dataset_name == "lf": 65 | data_fun = LitDataLF 66 | elif dataset_name == "nerf_360_v2": 67 | data_fun = LitDataNeRF360V2 68 | elif dataset_name == "shiny_blender": 69 | data_fun = LitDataShinyBlender 70 | elif dataset_name == "refnerf_real": 71 | data_fun = LitDataRefNeRFReal 72 | 73 | return data_fun( 74 | datadir=datadir, 75 | scene_name=scene_name, 76 | ) 77 | 78 | 79 | def select_callback(model_name): 80 | 81 | callbacks = [] 82 | 83 | if model_name == "plenoxel": 84 | import src.model.plenoxel.model as model 85 | 86 | callbacks += [model.ResampleCallBack()] 87 | 88 | if model_name == "dvgo": 89 | import src.model.dvgo.model as model 90 | 91 | callbacks += [ 92 | model.Coarse2Fine(), 93 | model.ProgressiveScaling(), 94 | model.UpdateOccupancyMask(), 95 | ] 96 | 97 | return callbacks 98 | -------------------------------------------------------------------------------- /utils/store_image.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # NeRF-Factory 3 | # Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved.
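#
# Usage sketch: store_image() writes one imageNNN.png per frame and
# store_video() stitches the frames into videos/images.mp4 via imageio;
# the depths argument of store_video() is currently unused.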
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import imageio 10 | import numpy as np 11 | from PIL import Image 12 | 13 | 14 | def to8b(x): 15 | return (255 * np.clip(x, 0, 1)).astype(np.uint8) 16 | 17 | 18 | def norm8b(x): 19 | x = (x - x.min()) / (x.max() - x.min()) 20 | return to8b(x) 21 | 22 | 23 | def store_image(dirpath, rgbs): 24 | for (i, rgb) in enumerate(rgbs): 25 | imgname = f"image{str(i).zfill(3)}.png" 26 | rgbimg = Image.fromarray(to8b(rgb.detach().cpu().numpy())) 27 | imgpath = os.path.join(dirpath, imgname) 28 | rgbimg.save(imgpath) 29 | 30 | 31 | def store_video(dirpath, rgbs, depths): 32 | rgbimgs = [to8b(rgb.cpu().detach().numpy()) for rgb in rgbs] 33 | video_dir = os.path.join(dirpath, "videos") 34 | os.makedirs(video_dir, exist_ok=True) 35 | imageio.mimwrite(os.path.join(video_dir, "images.mp4"), rgbimgs, fps=20, quality=8) 36 | --------------------------------------------------------------------------------
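A typical end-to-end invocation, mirroring the lines emitted by
utils/create_scripts.py (scene and seed are illustrative):

    python3 -m run --ginc configs/nerf/blender.gin --scene_name lego --seed 111

Per-run results.json files can then be averaged across seeds with
utils/check_mean_score.py, e.g. (assuming runs were logged under logs/):

    python3 utils/check_mean_score.py --dirpath logs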