├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── defaults.gin ├── hypernerf_interp_ap_2d.gin ├── hypernerf_interp_ds_2d.gin ├── hypernerf_vrig_ap_2d.gin ├── hypernerf_vrig_ds_2d.gin └── test_local.gin ├── eval.py ├── hypernerf ├── __init__.py ├── camera.py ├── configs.py ├── datasets │ ├── __init__.py │ ├── core.py │ ├── interp.py │ └── nerfies.py ├── dual_quaternion.py ├── evaluation.py ├── gpath.py ├── image_utils.py ├── model_utils.py ├── models.py ├── modules.py ├── quaternion.py ├── rigid_body.py ├── schedules.py ├── testdata │ └── camera.json ├── tf_camera.py ├── training.py ├── types.py ├── utils.py ├── visualization.py └── warping.py ├── notebooks ├── HyperNeRF_Render_Video.ipynb ├── HyperNeRF_Training.ipynb └── figures │ ├── hypernerf_ap_ds_figure.ipynb │ ├── hypernerf_optim_latent.ipynb │ ├── level_set_visualization.ipynb │ ├── nerfies_2d_experiments.ipynb │ ├── nerfies_eval_skeleton.ipynb │ └── sdf_2d.ipynb ├── requirements.txt ├── setup.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .idea 3 | bower_components 4 | node_modules 5 | venv 6 | 7 | *.ts.map 8 | 9 | *~ 10 | *.so 11 | .DS_Store 12 | ._.DS_Store 13 | *.swp 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *,cover 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields 2 | 3 | This is the code for "HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields". 4 | 5 | * [Project Page](https://hypernerf.github.io) 6 | * [Paper](https://arxiv.org/abs/2106.13228) 7 | * [Video](https://www.youtube.com/watch?v=qzgdE_ghkaI) 8 | 9 | This codebase implements HyperNeRF using [JAX](https://github.com/google/jax), 10 | building on [JaxNeRF](https://github.com/google-research/google-research/tree/master/jaxnerf). 11 | 12 | 13 | ## Demo 14 | 15 | We provide an easy-to-get-started demo using Google Colab! 16 | 17 | These Colabs will allow you to train a basic version of our method using 18 | Cloud TPUs (or GPUs) on Google Colab. 19 | 20 | Note that due to limited compute resources available, these are not the fully 21 | featured models and will train quite slowly and the quality will likely not be that great. 22 | If you would like to train a fully featured model, please refer to the instructions below 23 | on how to train on your own machine. 24 | 25 | | Description | Link | 26 | | ----------- | ----------- | 27 | | Process a video into a dataset| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb)| 28 | | Train HyperNeRF| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/hypernerf/blob/main/notebooks/HyperNeRF_Training.ipynb)| 29 | | Render HyperNeRF Videos| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/hypernerf/blob/main/notebooks/HyperNeRF_Render_Video.ipynb)| 30 | 31 | 32 | ## Setup 33 | The code can be run under any environment with Python 3.8 and above. 34 | (It may run with lower versions, but we have not tested it). 35 | 36 | We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) and setting up an environment: 37 | 38 | conda create --name hypernerf python=3.8 39 | 40 | Next, install the required packages: 41 | 42 | pip install -r requirements.txt 43 | 44 | Install the appropriate JAX distribution for your environment by [following the instructions here](https://github.com/google/jax#installation). For example: 45 | 46 | # For CUDA version 11.1 47 | pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html 48 | 49 | 50 | ## Training 51 | After preparing a dataset, you can train a Nerfie by running: 52 | 53 | export DATASET_PATH=/path/to/dataset 54 | export EXPERIMENT_PATH=/path/to/save/experiment/to 55 | python train.py \ 56 | --base_folder $EXPERIMENT_PATH \ 57 | --gin_bindings="data_dir='$DATASET_PATH'" \ 58 | --gin_configs configs/test_local.gin 59 | 60 | To plot telemetry to Tensorboard and render checkpoints on the fly, also 61 | launch an evaluation job by running: 62 | 63 | python eval.py \ 64 | --base_folder $EXPERIMENT_PATH \ 65 | --gin_bindings="data_dir='$DATASET_PATH'" \ 66 | --gin_configs configs/test_local.gin 67 | 68 | The two jobs should use a mutually exclusive set of GPUs. This division allows the 69 | training job to run without having to stop for evaluation. 70 | 71 | 72 | ## Configuration 73 | * We use [Gin](https://github.com/google/gin-config) for configuration. 74 | * We provide a couple preset configurations. 75 | * Please refer to `config.py` for documentation on what each configuration does. 76 | * Preset configs: 77 | - `hypernerf_vrig_ds_2d.gin`: The deformable surface configuration for the validation rig (novel-view synthesis) experiments. 78 | - `hypernerf_vrig_ap_2d.gin`: The axis-aligned plane configuration for the validation rig (novel-view synthesis) experiments. 79 | - `hypernerf_interp_ds_2d.gin`: The deformable surface configuration for the interpolation experiments. 80 | - `hypernerf_interp_ap_2d.gin`: The axis-aligned plane configuration for the interpolation experiments. 81 | 82 | 83 | ## Dataset 84 | The dataset uses the [same format as Nerfies](https://github.com/google/nerfies#datasets). 85 | 86 | 87 | ## Citing 88 | If you find our work useful, please consider citing: 89 | ```BibTeX 90 | @article{park2021hypernerf, 91 | author = {Park, Keunhong and Sinha, Utkarsh and Hedman, Peter and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Martin-Brualla, Ricardo and Seitz, Steven M.}, 92 | title = {HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields}, 93 | journal = {ACM Trans. Graph.}, 94 | issue_date = {December 2021}, 95 | publisher = {ACM}, 96 | volume = {40}, 97 | number = {6}, 98 | month = {dec}, 99 | year = {2021}, 100 | articleno = {238}, 101 | } 102 | ``` 103 | -------------------------------------------------------------------------------- /configs/defaults.gin: -------------------------------------------------------------------------------- 1 | # The base configuration file. 2 | 3 | spatial_point_min_deg = 0 4 | spatial_point_max_deg = 8 5 | warp_min_deg = 0 6 | warp_max_deg = 8 7 | elastic_init_weight = 0.01 8 | lr_delay_steps = 2500 9 | lr_delay_mult = 0.01 10 | 11 | hyper_num_dims = 8 12 | hyper_point_max_deg = 1 13 | 14 | # Predefined warp alpha schedules. 15 | ANNEALED_WARP_ALPHA_SCHED = { 16 | 'type': 'linear', 17 | 'initial_value': %warp_min_deg, 18 | 'final_value': %warp_max_deg, 19 | 'num_steps': 80000, 20 | } 21 | CONSTANT_WARP_ALPHA_SCHED = { 22 | 'type': 'constant', 23 | 'value': %warp_max_deg, 24 | } 25 | 26 | # Predefined elastic loss schedules. 27 | CONSTANT_ELASTIC_LOSS_SCHED = { 28 | 'type': 'constant', 29 | 'value': %elastic_init_weight, 30 | } 31 | DECAYING_ELASTIC_LOSS_SCHED = { 32 | 'type': 'piecewise', 33 | 'schedules': [ 34 | (50000, ('constant', %elastic_init_weight)), 35 | (100000, ('cosine_easing', %elastic_init_weight, 1e-8, 100000)), 36 | ] 37 | } 38 | 39 | DEFAULT_LR_SCHEDULE = { 40 | 'type': 'exponential', 41 | 'initial_value': %init_lr, 42 | 'final_value': %final_lr, 43 | 'num_steps': %max_steps, 44 | } 45 | 46 | DELAYED_LR_SCHEDULE = { 47 | 'type': 'delayed', 48 | 'delay_steps': %lr_delay_steps, 49 | 'delay_mult': %lr_delay_mult, 50 | 'base_schedule': %DEFAULT_LR_SCHEDULE, 51 | } 52 | 53 | DELAYED_HYPER_ALPHA_SCHED = { 54 | 'type': 'piecewise', 55 | 'schedules': [ 56 | (1000, ('constant', 0.0)), 57 | (0, ('linear', 0.0, %hyper_point_max_deg, 10000)) 58 | ], 59 | } 60 | FAST_HYPER_ALPHA_SCHED = { 61 | 'type': 'piecewise', 62 | 'schedules': [ 63 | (0, ('linear', 0.0, %hyper_point_max_deg, 2500)) 64 | ], 65 | } 66 | CONSTANT_HYPER_ALPHA_SCHED = ('constant', %hyper_point_max_deg) 67 | 68 | 69 | # Experiment configs. 70 | ExperimentConfig.image_scale = %image_scale 71 | ExperimentConfig.random_seed = 0 72 | ExperimentConfig.datasource_cls = @NerfiesDataSource 73 | NerfiesDataSource.data_dir = %data_dir 74 | NerfiesDataSource.image_scale = %image_scale 75 | 76 | # Common configs. 77 | NerfModel.use_viewdirs = True 78 | NerfModel.use_stratified_sampling = True 79 | NerfModel.use_posenc_identity = False 80 | NerfModel.spatial_point_min_deg = %spatial_point_min_deg 81 | NerfModel.spatial_point_max_deg = %spatial_point_max_deg 82 | TrainConfig.nerf_alpha_schedule = ('constant', %spatial_point_max_deg) 83 | 84 | # NeRF Metadata 85 | NerfModel.nerf_embed_cls = @nerf/GLOEmbed 86 | nerf/GLOEmbed.num_dims = 8 87 | 88 | # Warp field configs. 89 | NerfModel.warp_embed_cls = @warp/GLOEmbed 90 | warp/GLOEmbed.num_dims = 8 91 | 92 | SE3Field.min_deg = %warp_min_deg 93 | SE3Field.max_deg = %warp_max_deg 94 | SE3Field.use_posenc_identity = False 95 | NerfModel.warp_field_cls = @SE3Field 96 | 97 | # Hyper point configs. 98 | NerfModel.hyper_embed_cls = @hyper/GLOEmbed 99 | hyper/GLOEmbed.num_dims = %hyper_num_dims 100 | 101 | # Use macros to make sure these are set somewhere. 102 | TrainConfig.batch_size = %batch_size 103 | TrainConfig.max_steps = %max_steps 104 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE 105 | TrainConfig.warp_alpha_schedule = %CONSTANT_WARP_ALPHA_SCHED 106 | 107 | # Elastic loss. 108 | TrainConfig.use_elastic_loss = False 109 | TrainConfig.elastic_loss_weight_schedule = %CONSTANT_ELASTIC_LOSS_SCHED 110 | 111 | # Background regularization loss. 112 | TrainConfig.use_background_loss = False 113 | TrainConfig.background_loss_weight = 1.0 114 | 115 | # Script interval configs. 116 | TrainConfig.print_every = 100 117 | TrainConfig.log_every = 500 118 | TrainConfig.save_every = 5000 119 | 120 | EvalConfig.eval_once = False 121 | EvalConfig.save_output = True 122 | EvalConfig.chunk = %eval_batch_size 123 | 124 | EvalConfig.num_val_eval = None 125 | -------------------------------------------------------------------------------- /configs/hypernerf_interp_ap_2d.gin: -------------------------------------------------------------------------------- 1 | include 'configs/defaults.gin' 2 | 3 | image_scale = 2 4 | batch_size = 6144 5 | eval_batch_size = 8096 6 | 7 | max_steps = 250000 8 | lr_decay_steps = 500000 9 | lr_delay_steps = 10000 10 | init_lr = 1e-3 11 | final_lr = 1e-4 12 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE 13 | 14 | # Dataset config. 15 | ExperimentConfig.datasource_cls = @InterpDataSource 16 | InterpDataSource.data_dir = %data_dir 17 | InterpDataSource.image_scale = %image_scale 18 | InterpDataSource.interval = 4 19 | 20 | # Basic model config. 21 | NerfModel.num_coarse_samples = 128 22 | NerfModel.num_fine_samples = 128 23 | NerfModel.use_viewdirs = True 24 | NerfModel.norm_type = 'none' 25 | NerfModel.activation = @jax.nn.relu 26 | 27 | NerfModel.use_posenc_identity = True 28 | SE3Field.use_posenc_identity = True 29 | 30 | # NeRF position encoding configs. 31 | spatial_point_min_deg = 0 32 | spatial_point_max_deg = 8 33 | NERF_EASE_ALPHA_SCHEDULE = { 34 | 'type': 'linear', 35 | 'initial_value': 6.0, 36 | 'final_value': %spatial_point_max_deg, 37 | 'num_steps': 80000, 38 | } 39 | 40 | # Hyper config. 41 | hyper_num_dims = 2 42 | hyper_point_min_deg = 0 43 | hyper_point_max_deg = 1 44 | NerfModel.hyper_point_min_deg = %hyper_point_min_deg 45 | NerfModel.hyper_point_max_deg = %hyper_point_max_deg 46 | TrainConfig.hyper_alpha_schedule = { 47 | 'type': 'piecewise', 48 | 'schedules': [ 49 | (1000, ('constant', 0.0)), 50 | (0, ('linear', 0.0, %hyper_point_max_deg, 10000)) 51 | ], 52 | } 53 | 54 | NerfModel.hyper_slice_method = 'axis_aligned_plane' 55 | NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP 56 | NerfModel.hyper_use_warp_embed = True 57 | 58 | hyper_sheet_min_deg = 0 59 | hyper_sheet_max_deg = 6 60 | HyperSheetMLP.min_deg = %hyper_sheet_min_deg 61 | HyperSheetMLP.max_deg = %hyper_sheet_max_deg 62 | HyperSheetMLP.output_channels = %hyper_num_dims 63 | TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg) 64 | 65 | # Warp config. 66 | NerfModel.use_warp = True 67 | warp_min_deg = 0 68 | warp_max_deg = 6 69 | warp_alpha_steps = 80000 70 | TrainConfig.warp_alpha_schedule = { 71 | 'type': 'linear', 72 | 'initial_value': %warp_min_deg, 73 | 'final_value': %warp_max_deg, 74 | 'num_steps': %warp_alpha_steps, 75 | } 76 | 77 | # Train configs. 78 | TrainConfig.use_weight_norm = False 79 | TrainConfig.use_elastic_loss = False 80 | TrainConfig.use_background_loss = False 81 | TrainConfig.background_loss_weight = 1.0 82 | TrainConfig.use_hyper_reg_loss = False 83 | TrainConfig.hyper_reg_loss_weight = 0.0001 84 | 85 | TrainConfig.print_every = 100 86 | TrainConfig.log_every = 1000 87 | TrainConfig.histogram_every = 5000 88 | TrainConfig.save_every = 10000 89 | 90 | EvalConfig.eval_once = False 91 | EvalConfig.save_output = True 92 | -------------------------------------------------------------------------------- /configs/hypernerf_interp_ds_2d.gin: -------------------------------------------------------------------------------- 1 | include 'configs/defaults.gin' 2 | 3 | image_scale = 2 4 | batch_size = 6144 5 | eval_batch_size = 8096 6 | 7 | max_steps = 250000 8 | lr_decay_steps = 500000 9 | lr_delay_steps = 10000 10 | init_lr = 1e-3 11 | final_lr = 1e-4 12 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE 13 | 14 | # Dataset config. 15 | ExperimentConfig.datasource_cls = @InterpDataSource 16 | InterpDataSource.data_dir = %data_dir 17 | InterpDataSource.image_scale = %image_scale 18 | InterpDataSource.interval = 4 19 | 20 | # Basic model config. 21 | NerfModel.num_coarse_samples = 128 22 | NerfModel.num_fine_samples = 128 23 | NerfModel.use_viewdirs = True 24 | NerfModel.norm_type = 'none' 25 | NerfModel.activation = @jax.nn.relu 26 | 27 | NerfModel.use_posenc_identity = True 28 | SE3Field.use_posenc_identity = True 29 | 30 | # NeRF position encoding configs. 31 | spatial_point_min_deg = 0 32 | spatial_point_max_deg = 8 33 | NERF_EASE_ALPHA_SCHEDULE = { 34 | 'type': 'linear', 35 | 'initial_value': 6.0, 36 | 'final_value': %spatial_point_max_deg, 37 | 'num_steps': 80000, 38 | } 39 | 40 | # Hyper config. 41 | hyper_num_dims = 2 42 | hyper_point_min_deg = 0 43 | hyper_point_max_deg = 1 44 | NerfModel.hyper_point_min_deg = %hyper_point_min_deg 45 | NerfModel.hyper_point_max_deg = %hyper_point_max_deg 46 | TrainConfig.hyper_alpha_schedule = { 47 | 'type': 'piecewise', 48 | 'schedules': [ 49 | (1000, ('constant', 0.0)), 50 | (0, ('linear', 0.0, %hyper_point_max_deg, 10000)) 51 | ], 52 | } 53 | 54 | NerfModel.hyper_slice_method = 'bendy_sheet' 55 | NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP 56 | NerfModel.hyper_use_warp_embed = True 57 | 58 | hyper_sheet_min_deg = 0 59 | hyper_sheet_max_deg = 6 60 | HyperSheetMLP.min_deg = %hyper_sheet_min_deg 61 | HyperSheetMLP.max_deg = %hyper_sheet_max_deg 62 | HyperSheetMLP.output_channels = %hyper_num_dims 63 | TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg) 64 | 65 | # Warp config. 66 | NerfModel.use_warp = True 67 | warp_min_deg = 0 68 | warp_max_deg = 6 69 | warp_alpha_steps = 80000 70 | TrainConfig.warp_alpha_schedule = { 71 | 'type': 'linear', 72 | 'initial_value': %warp_min_deg, 73 | 'final_value': %warp_max_deg, 74 | 'num_steps': %warp_alpha_steps, 75 | } 76 | 77 | # Train configs. 78 | TrainConfig.use_weight_norm = False 79 | TrainConfig.use_elastic_loss = False 80 | TrainConfig.use_background_loss = False 81 | TrainConfig.background_loss_weight = 1.0 82 | TrainConfig.use_hyper_reg_loss = False 83 | TrainConfig.hyper_reg_loss_weight = 0.0001 84 | 85 | TrainConfig.print_every = 100 86 | TrainConfig.log_every = 1000 87 | TrainConfig.histogram_every = 5000 88 | TrainConfig.save_every = 10000 89 | 90 | EvalConfig.eval_once = False 91 | EvalConfig.save_output = True 92 | -------------------------------------------------------------------------------- /configs/hypernerf_vrig_ap_2d.gin: -------------------------------------------------------------------------------- 1 | include 'configs/defaults.gin' 2 | 3 | max_steps = 250000 4 | lr_decay_steps = %max_steps 5 | 6 | image_scale = 4 7 | batch_size = 6144 8 | eval_batch_size = 8096 9 | init_lr = 0.001 10 | final_lr = 0.0001 11 | elastic_init_weight = 0.001 12 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE 13 | 14 | # Basic model config. 15 | NerfModel.nerf_trunk_width = 256 16 | NerfModel.nerf_trunk_depth = 8 17 | NerfModel.num_coarse_samples = 128 18 | NerfModel.num_fine_samples = 128 19 | NerfModel.use_viewdirs = True 20 | NerfModel.use_posenc_identity = True 21 | SE3Field.use_posenc_identity = True 22 | 23 | NerfModel.nerf_embed_key = 'camera' 24 | NerfModel.use_rgb_condition = True 25 | 26 | # NeRF position encoding configs. 27 | spatial_point_min_deg = 0 28 | spatial_point_max_deg = 8 29 | 30 | # Hyper configs. 31 | hyper_num_dims = 2 32 | hyper_point_min_deg = 0 33 | hyper_point_max_deg = 1 34 | NerfModel.hyper_point_min_deg = %hyper_point_min_deg 35 | NerfModel.hyper_point_max_deg = %hyper_point_max_deg 36 | 37 | TrainConfig.hyper_alpha_schedule = %DELAYED_HYPER_ALPHA_SCHED 38 | 39 | NerfModel.hyper_slice_method = 'axis_aligned_plane' 40 | NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP 41 | NerfModel.hyper_use_warp_embed = True 42 | 43 | hyper_sheet_min_deg = 0 44 | hyper_sheet_max_deg = 6 45 | HyperSheetMLP.min_deg = %hyper_sheet_min_deg 46 | HyperSheetMLP.max_deg = %hyper_sheet_max_deg 47 | HyperSheetMLP.output_channels = %hyper_num_dims 48 | TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg) 49 | 50 | # Warp config. 51 | NerfModel.use_warp = True 52 | warp_min_deg = 0 53 | warp_max_deg = 6 54 | warp_alpha_steps = 80000 55 | TrainConfig.warp_alpha_schedule = { 56 | 'type': 'linear', 57 | 'initial_value': %warp_min_deg, 58 | 'final_value': %warp_max_deg, 59 | 'num_steps': %warp_alpha_steps, 60 | } 61 | 62 | TrainConfig.use_weight_norm = False 63 | TrainConfig.use_elastic_loss = False 64 | TrainConfig.use_background_loss = True 65 | TrainConfig.background_loss_weight = 1.0 66 | 67 | TrainConfig.use_warp_reg_loss = False 68 | TrainConfig.warp_reg_loss_weight = 1e-2 69 | 70 | TrainConfig.elastic_reduce_method = 'weight' 71 | TrainConfig.elastic_loss_weight_schedule = { 72 | 'type': 'constant', 73 | 'value': %elastic_init_weight, 74 | } 75 | 76 | 77 | TrainConfig.print_every = 100 78 | TrainConfig.log_every = 100 79 | TrainConfig.histogram_every = 1000 80 | TrainConfig.save_every = 10000 81 | 82 | EvalConfig.num_val_eval = None 83 | EvalConfig.num_train_eval = None 84 | EvalConfig.eval_once = False 85 | EvalConfig.save_output = True 86 | -------------------------------------------------------------------------------- /configs/hypernerf_vrig_ds_2d.gin: -------------------------------------------------------------------------------- 1 | include 'configs/defaults.gin' 2 | 3 | max_steps = 250000 4 | lr_decay_steps = %max_steps 5 | 6 | image_scale = 4 7 | batch_size = 6144 8 | eval_batch_size = 8096 9 | init_lr = 0.001 10 | final_lr = 0.0001 11 | elastic_init_weight = 0.001 12 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE 13 | 14 | # Basic model config. 15 | NerfModel.nerf_trunk_width = 256 16 | NerfModel.nerf_trunk_depth = 8 17 | NerfModel.num_coarse_samples = 128 18 | NerfModel.num_fine_samples = 128 19 | NerfModel.use_viewdirs = True 20 | NerfModel.use_posenc_identity = True 21 | SE3Field.use_posenc_identity = True 22 | 23 | NerfModel.nerf_embed_key = 'camera' 24 | NerfModel.use_rgb_condition = True 25 | 26 | # NeRF position encoding configs. 27 | spatial_point_min_deg = 0 28 | spatial_point_max_deg = 8 29 | 30 | # Hyper configs. 31 | hyper_num_dims = 2 32 | hyper_point_min_deg = 0 33 | hyper_point_max_deg = 1 34 | NerfModel.hyper_point_min_deg = %hyper_point_min_deg 35 | NerfModel.hyper_point_max_deg = %hyper_point_max_deg 36 | 37 | TrainConfig.hyper_alpha_schedule = %DELAYED_HYPER_ALPHA_SCHED 38 | 39 | NerfModel.hyper_slice_method = 'bendy_sheet' 40 | NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP 41 | NerfModel.hyper_use_warp_embed = True 42 | 43 | hyper_sheet_min_deg = 0 44 | hyper_sheet_max_deg = 6 45 | HyperSheetMLP.min_deg = %hyper_sheet_min_deg 46 | HyperSheetMLP.max_deg = %hyper_sheet_max_deg 47 | HyperSheetMLP.output_channels = %hyper_num_dims 48 | TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg) 49 | 50 | # Warp config. 51 | NerfModel.use_warp = True 52 | warp_min_deg = 0 53 | warp_max_deg = 6 54 | warp_alpha_steps = 80000 55 | TrainConfig.warp_alpha_schedule = { 56 | 'type': 'linear', 57 | 'initial_value': %warp_min_deg, 58 | 'final_value': %warp_max_deg, 59 | 'num_steps': %warp_alpha_steps, 60 | } 61 | 62 | TrainConfig.use_weight_norm = False 63 | TrainConfig.use_elastic_loss = False 64 | TrainConfig.use_background_loss = True 65 | TrainConfig.background_loss_weight = 1.0 66 | 67 | TrainConfig.use_warp_reg_loss = False 68 | TrainConfig.warp_reg_loss_weight = 1e-2 69 | 70 | TrainConfig.elastic_reduce_method = 'weight' 71 | TrainConfig.elastic_loss_weight_schedule = { 72 | 'type': 'constant', 73 | 'value': %elastic_init_weight, 74 | } 75 | 76 | 77 | TrainConfig.print_every = 100 78 | TrainConfig.log_every = 100 79 | TrainConfig.histogram_every = 1000 80 | TrainConfig.save_every = 10000 81 | 82 | EvalConfig.num_val_eval = None 83 | EvalConfig.num_train_eval = None 84 | EvalConfig.eval_once = False 85 | EvalConfig.save_output = True 86 | -------------------------------------------------------------------------------- /configs/test_local.gin: -------------------------------------------------------------------------------- 1 | include 'configs/defaults.gin' 2 | 3 | image_scale = 4 4 | batch_size = 512 5 | eval_batch_size = 8192 6 | 7 | elastic_init_weight = 0.01 8 | max_steps = 250000 9 | lr_decay_steps = 500000 10 | init_lr = 1e-3 11 | final_lr = 1e-5 12 | 13 | NerfModel.num_coarse_samples = 64 14 | NerfModel.num_fine_samples = 64 15 | NerfModel.use_viewdirs = True 16 | NerfModel.use_stratified_sampling = True 17 | 18 | NerfModel.norm_type = 'none' 19 | NerfModel.activation = @jax.nn.relu 20 | 21 | spatial_point_min_deg = 0 22 | spatial_point_max_deg = 8 23 | 24 | # Hyper config. 25 | hyper_num_dims = 4 26 | hyper_point_min_deg = 0 27 | hyper_point_max_deg = 1 28 | NerfModel.hyper_point_min_deg = %hyper_point_min_deg 29 | NerfModel.hyper_point_max_deg = %hyper_point_max_deg 30 | TrainConfig.hyper_alpha_schedule = { 31 | 'type': 'piecewise', 32 | 'schedules': [ 33 | (1000, ('constant', 0.0)), 34 | (0, ('linear', 0.0, %hyper_point_max_deg, 10000)) 35 | ], 36 | } 37 | 38 | 39 | NerfModel.hyper_slice_method = 'bendy_sheet' 40 | NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP 41 | NerfModel.hyper_use_warp_embed = True 42 | 43 | hyper_sheet_min_deg = 0 44 | hyper_sheet_max_deg = 6 45 | HyperSheetMLP.min_deg = %hyper_sheet_min_deg 46 | HyperSheetMLP.max_deg = %hyper_sheet_max_deg 47 | HyperSheetMLP.output_channels = %hyper_num_dims 48 | TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg) 49 | 50 | NerfModel.use_warp = True 51 | warp_min_deg = 0 52 | warp_max_deg = 4 53 | TrainConfig.warp_alpha_schedule = { 54 | 'type': 'linear', 55 | 'initial_value': %warp_min_deg, 56 | 'final_value': %warp_max_deg, 57 | 'num_steps': 50000, 58 | } 59 | 60 | TrainConfig.use_weight_norm = False 61 | TrainConfig.use_elastic_loss = False 62 | TrainConfig.use_background_loss = False 63 | TrainConfig.background_loss_weight = 1.0 64 | 65 | TrainConfig.use_warp_reg_loss = True 66 | TrainConfig.warp_reg_loss_weight = 0.001 67 | TrainConfig.use_hyper_reg_loss = False 68 | TrainConfig.hyper_reg_loss_weight = 0.001 69 | 70 | TrainConfig.print_every = 10 71 | TrainConfig.log_every = 100 72 | TrainConfig.histogram_every = 100 73 | TrainConfig.save_every = 1000 74 | 75 | EvalConfig.eval_once = False 76 | EvalConfig.save_output = False 77 | EvalConfig.num_train_eval = 5 78 | EvalConfig.num_val_eval = 5 79 | -------------------------------------------------------------------------------- /hypernerf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /hypernerf/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration classes.""" 16 | import dataclasses 17 | from typing import Any, Callable, Optional 18 | 19 | from flax import linen as nn 20 | import gin 21 | import immutabledict 22 | import jax.numpy 23 | 24 | from hypernerf import datasets 25 | 26 | ScheduleDef = Any 27 | 28 | 29 | gin.config.external_configurable(nn.elu, module='flax.nn') 30 | gin.config.external_configurable(nn.relu, module='flax.nn') 31 | gin.config.external_configurable(nn.leaky_relu, module='flax.nn') 32 | gin.config.external_configurable(nn.tanh, module='flax.nn') 33 | gin.config.external_configurable(nn.sigmoid, module='flax.nn') 34 | gin.config.external_configurable(nn.softplus, module='flax.nn') 35 | gin.config.external_configurable(nn.gelu, module='flax.nn') 36 | 37 | gin.config.external_configurable(jax.numpy.sin, module='jax.numpy') 38 | gin.config.external_configurable(jax.nn.relu, module='jax.nn') 39 | gin.config.external_configurable(jax.nn.silu, module='jax.nn') 40 | gin.config.external_configurable(jax.nn.gelu, module='jax.nn') 41 | 42 | 43 | @gin.configurable() 44 | @dataclasses.dataclass 45 | class ExperimentConfig: 46 | """Experiment configuration.""" 47 | # A subname for the experiment e.g., for parameter sweeps. If this is set 48 | # experiment artifacts will be saves to a subdirectory with this name. 49 | subname: Optional[str] = None 50 | # The image scale to use for the dataset. Should be a power of 2. 51 | image_scale: int = 4 52 | # The random seed used to initialize the RNGs for the experiment. 53 | random_seed: int = 12345 54 | # The datasource class. 55 | datasource_cls: Callable[..., datasets.DataSource] = gin.REQUIRED 56 | 57 | 58 | @gin.configurable() 59 | @dataclasses.dataclass 60 | class TrainConfig: 61 | """Parameters for training.""" 62 | batch_size: int = gin.REQUIRED 63 | 64 | # The definition for the learning rate schedule. 65 | lr_schedule: ScheduleDef = immutabledict.immutabledict({ 66 | 'type': 'exponential', 67 | 'initial_value': 0.001, 68 | 'final_value': 0.0001, 69 | 'num_steps': 1000000, 70 | }) 71 | # The maximum number of training steps. 72 | max_steps: int = 1000000 73 | 74 | # Whether to use weight normalization. 75 | use_weight_norm: bool = False 76 | 77 | # The NeRF alpha schedule. 78 | nerf_alpha_schedule: Optional[ScheduleDef] = None 79 | # The warp alpha schedule. 80 | warp_alpha_schedule: Optional[ScheduleDef] = None 81 | # The schedule or the hyper sheet position encoding. 82 | hyper_alpha_schedule: Optional[ScheduleDef] = None 83 | # The schedule or the hyper sheet position encoding. 84 | hyper_sheet_alpha_schedule: Optional[ScheduleDef] = None 85 | 86 | # Whether to use the elastic regularization loss. 87 | use_elastic_loss: bool = False 88 | # The weight of the elastic regularization loss. 89 | elastic_loss_weight_schedule: Optional[ScheduleDef] = None 90 | # Which method to use to reduce the samples for the elastic loss. 91 | # 'weight' computes a weighted sum using the density weights, and 'median' 92 | # selects the sample at the median depth point. 93 | elastic_reduce_method: str = 'weight' 94 | # Which loss method to use for the elastic loss. 95 | elastic_loss_type: str = 'log_svals' 96 | # Whether to use background regularization. 97 | use_background_loss: bool = False 98 | # The weight for the background loss. 99 | background_loss_weight: float = 0.0 100 | # The batch size for background regularization loss. 101 | background_points_batch_size: int = 16384 102 | # Whether to use the warp reg loss. 103 | use_warp_reg_loss: bool = False 104 | # The weight for the warp reg loss. 105 | warp_reg_loss_weight: float = 0.0 106 | # The alpha for the warp reg loss. 107 | warp_reg_loss_alpha: float = -2.0 108 | # The scale for the warp reg loss. 109 | warp_reg_loss_scale: float = 0.001 110 | # Whether to regularize the hyper points. 111 | use_hyper_reg_loss: bool = False 112 | # The weight for the hyper reg loss. 113 | hyper_reg_loss_weight: float = 0.0 114 | 115 | # The size of the shuffle buffer size when shuffling the training dataset. 116 | # This needs to be sufficiently large to contain a diverse set of images in 117 | # each batch, especially when optimizing GLO embeddings. 118 | shuffle_buffer_size: int = 5000000 119 | # How often to save a checkpoint. 120 | save_every: int = 10000 121 | # How often to log to Tensorboard. 122 | log_every: int = 500 123 | # How often to log histograms to Tensorboard. 124 | histogram_every: int = 5000 125 | # How often to print to the console. 126 | print_every: int = 25 127 | 128 | # Unused, here for backwards compatibility. 129 | use_curvature_loss: bool = False 130 | curvature_loss_alpha: int = 0 131 | curvature_loss_scale: float = 0 132 | curvature_loss_spacing: float = 0 133 | curvature_loss_weight_schedule: Optional[Any] = None 134 | 135 | 136 | @gin.configurable() 137 | @dataclasses.dataclass 138 | class EvalConfig: 139 | """Parameters for evaluation.""" 140 | # If True only evaluate the model once, otherwise evaluate any new 141 | # checkpoints. 142 | eval_once: bool = False 143 | # If True save the predicted images to persistent storage. 144 | save_output: bool = True 145 | # The evaluation batch size. 146 | chunk: int = 8192 147 | # Max render checkpoints. The renders will rotate after this many. 148 | max_render_checkpoints: int = 3 149 | 150 | # The subname to append to 'renders' and 'summaries'. 151 | subname: str = '' 152 | 153 | # Unused args here for backwards compatibility. 154 | val_argmin: bool = False 155 | optim_metadata: bool = False 156 | optim_tile_size: int = 0 157 | optim_lr: float = 0.0 158 | 159 | # The number of validation examples to evaluate. (Default: all). 160 | num_val_eval: Optional[int] = 10 161 | # The number of training examples to evaluate. 162 | num_train_eval: Optional[int] = 10 163 | # The number of test examples to evaluate. 164 | num_test_eval: Optional[int] = 10 165 | -------------------------------------------------------------------------------- /hypernerf/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=unused-import 16 | """Dataset definition and utility package.""" 17 | from hypernerf.datasets.core import * 18 | from hypernerf.datasets.interp import InterpDataSource 19 | from hypernerf.datasets.nerfies import NerfiesDataSource 20 | -------------------------------------------------------------------------------- /hypernerf/datasets/interp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Casual Volumetric Capture datasets. 16 | 17 | Note: Please benchmark before submitted changes to this module. It's very easy 18 | to introduce data loading bottlenecks! 19 | """ 20 | import json 21 | from typing import List, Tuple 22 | 23 | from absl import logging 24 | import cv2 25 | import gin 26 | import numpy as np 27 | 28 | from hypernerf import gpath 29 | from hypernerf import types 30 | from hypernerf import utils 31 | from hypernerf.datasets import core 32 | 33 | 34 | def load_scene_info( 35 | data_dir: types.PathType) -> Tuple[np.ndarray, float, float, float]: 36 | """Loads the scene center, scale, near and far from scene.json. 37 | 38 | Args: 39 | data_dir: the path to the dataset. 40 | 41 | Returns: 42 | scene_center: the center of the scene (unscaled coordinates). 43 | scene_scale: the scale of the scene. 44 | near: the near plane of the scene (scaled coordinates). 45 | far: the far plane of the scene (scaled coordinates). 46 | """ 47 | scene_json_path = gpath.GPath(data_dir, 'scene.json') 48 | with scene_json_path.open('r') as f: 49 | scene_json = json.load(f) 50 | 51 | scene_center = np.array(scene_json['center']) 52 | scene_scale = scene_json['scale'] 53 | near = scene_json['near'] 54 | far = scene_json['far'] 55 | 56 | return scene_center, scene_scale, near, far 57 | 58 | 59 | def _load_image(path: types.PathType) -> np.ndarray: 60 | path = gpath.GPath(path) 61 | with path.open('rb') as f: 62 | raw_im = np.asarray(bytearray(f.read()), dtype=np.uint8) 63 | image = cv2.imdecode(raw_im, cv2.IMREAD_COLOR)[:, :, ::-1] # BGR -> RGB 64 | image = np.asarray(image).astype(np.float32) / 255.0 65 | return image 66 | 67 | 68 | def _load_dataset_ids(data_dir: types.PathType) -> Tuple[List[str], List[str]]: 69 | """Loads dataset IDs.""" 70 | dataset_json_path = gpath.GPath(data_dir, 'dataset.json') 71 | logging.info('*** Loading dataset IDs from %s', dataset_json_path) 72 | with dataset_json_path.open('r') as f: 73 | dataset_json = json.load(f) 74 | 75 | return dataset_json['ids'] 76 | 77 | 78 | @gin.configurable 79 | class InterpDataSource(core.DataSource): 80 | """Data loader for videos.""" 81 | 82 | def __init__( 83 | self, 84 | data_dir=gin.REQUIRED, 85 | image_scale: int = gin.REQUIRED, 86 | interval: int = gin.REQUIRED, 87 | shuffle_pixels=False, 88 | camera_type='json', 89 | test_camera_trajectory='orbit-mild', 90 | **kwargs): 91 | self.data_dir = gpath.GPath(data_dir) 92 | if interval < 2 or interval % 2 != 0: 93 | raise ValueError('interval must be a positive even number.') 94 | all_ids = _load_dataset_ids(self.data_dir) 95 | if interval > len(all_ids) - 1: 96 | raise ValueError('interval exceeds dataset size.') 97 | all_indices = np.arange(len(all_ids)) 98 | train_indices = all_indices[::interval] 99 | # Take the middle frames for validation. 100 | val_indices = (train_indices[:-1] + train_indices[1:]) // 2 101 | train_ids = [all_ids[i] for i in train_indices] 102 | val_ids = [all_ids[i] for i in val_indices] 103 | super().__init__(train_ids=train_ids, val_ids=val_ids, **kwargs) 104 | self.scene_center, self.scene_scale, self._near, self._far = ( 105 | load_scene_info(self.data_dir)) 106 | self.test_camera_trajectory = test_camera_trajectory 107 | 108 | self.image_scale = image_scale 109 | self.shuffle_pixels = shuffle_pixels 110 | 111 | self.rgb_dir = gpath.GPath(data_dir, 'rgb', f'{image_scale}x') 112 | self.depth_dir = gpath.GPath(data_dir, 'depth', f'{image_scale}x') 113 | self.camera_type = camera_type 114 | self.camera_dir = gpath.GPath(data_dir, 'camera') 115 | 116 | self.train_metadata_ids = {t_id: i for i, t_id in enumerate(train_ids)} 117 | # The pair of metadata ids for each val id. 118 | self.val_metadata_ids = { 119 | v: (self.train_metadata_ids[l], self.train_metadata_ids[r]) 120 | for v, l, r in zip(val_ids, train_ids[:-1], train_ids[1:]) 121 | } 122 | # The pair of train ids that correspond to each val id. 123 | self.val_pivot_ids = { 124 | v: (l, r) for v, l, r in zip(val_ids, train_ids[:-1], train_ids[1:]) 125 | } 126 | 127 | metadata_path = self.data_dir / 'metadata.json' 128 | with metadata_path.open('r') as f: 129 | self.metadata_dict = json.load(f) 130 | 131 | @property 132 | def near(self): 133 | return self._near 134 | 135 | @property 136 | def far(self): 137 | return self._far 138 | 139 | @property 140 | def camera_ext(self): 141 | if self.camera_type == 'json': 142 | return '.json' 143 | 144 | raise ValueError(f'Unknown camera_type {self.camera_type}') 145 | 146 | def get_rgb_path(self, item_id): 147 | return self.rgb_dir / f'{item_id}.png' 148 | 149 | def load_rgb(self, item_id): 150 | return _load_image(self.rgb_dir / f'{item_id}.png') 151 | 152 | def load_camera(self, item_id, scale_factor=1.0): 153 | if isinstance(item_id, gpath.GPath): 154 | camera_path = item_id 155 | else: 156 | if self.camera_type == 'proto': 157 | camera_path = self.camera_dir / f'{item_id}{self.camera_ext}' 158 | elif self.camera_type == 'json': 159 | camera_path = self.camera_dir / f'{item_id}{self.camera_ext}' 160 | else: 161 | raise ValueError(f'Unknown camera type {self.camera_type!r}.') 162 | 163 | return core.load_camera(camera_path, 164 | scale_factor=scale_factor / self.image_scale, 165 | scene_center=self.scene_center, 166 | scene_scale=self.scene_scale) 167 | 168 | def glob_cameras(self, path): 169 | path = gpath.GPath(path) 170 | return sorted(path.glob(f'*{self.camera_ext}')) 171 | 172 | def load_test_cameras(self, count=None): 173 | camera_dir = (self.data_dir / 'camera-paths' / self.test_camera_trajectory) 174 | if not camera_dir.exists(): 175 | logging.warning('test camera path does not exist: %s', str(camera_dir)) 176 | return [] 177 | camera_paths = sorted(camera_dir.glob(f'*{self.camera_ext}')) 178 | if count is not None: 179 | stride = max(1, len(camera_paths) // count) 180 | camera_paths = camera_paths[::stride] 181 | cameras = utils.parallel_map(self.load_camera, camera_paths) 182 | return cameras 183 | 184 | def load_points(self, shuffle=False): 185 | with (self.data_dir / 'points.npy').open('rb') as f: 186 | points = np.load(f) 187 | points = (points - self.scene_center) * self.scene_scale 188 | points = points.astype(np.float32) 189 | if shuffle: 190 | logging.info('Shuffling points.') 191 | shuffled_inds = self.rng.permutation(len(points)) 192 | points = points[shuffled_inds] 193 | logging.info('Loaded %d points.', len(points)) 194 | return points 195 | 196 | def _get_metadata_id(self, item_id): 197 | if item_id in self.train_metadata_ids: 198 | return self.train_metadata_ids[item_id] 199 | elif item_id in self.val_metadata_ids: 200 | # If the metadata ID is a pair then define a linear interpolation. 201 | left_id, right_id = self.val_pivot_ids[item_id] 202 | item_ts = float(self.metadata_dict[item_id]['time_id']) 203 | left_ts = float(self.metadata_dict[left_id]['time_id']) 204 | right_ts = float(self.metadata_dict[right_id]['time_id']) 205 | # Compute what interval the middle frame lies on between the left 206 | # and right frames. This should be between 0 and 1. 207 | progression = (item_ts - left_ts) / (right_ts - left_ts) 208 | assert 0.0 <= progression <= 1.0 209 | left_metadata = self.train_metadata_ids[left_id] 210 | right_metadata = self.train_metadata_ids[right_id] 211 | return left_metadata, right_metadata, progression 212 | else: 213 | raise RuntimeError(f'Metadata for item_id {item_id} not known') 214 | 215 | def get_appearance_id(self, item_id): 216 | return self._get_metadata_id(item_id) 217 | 218 | def get_camera_id(self, item_id): 219 | raise NotImplementedError() 220 | 221 | def get_warp_id(self, item_id): 222 | return self._get_metadata_id(item_id) 223 | 224 | def get_time_id(self, item_id): 225 | return self._get_metadata_id(item_id) 226 | -------------------------------------------------------------------------------- /hypernerf/datasets/nerfies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Casual Volumetric Capture datasets. 16 | 17 | Note: Please benchmark before submitted changes to this module. It's very easy 18 | to introduce data loading bottlenecks! 19 | """ 20 | import json 21 | from typing import List, Tuple 22 | 23 | from absl import logging 24 | import cv2 25 | import gin 26 | import numpy as np 27 | 28 | from hypernerf import camera as cam 29 | from hypernerf import gpath 30 | from hypernerf import types 31 | from hypernerf import utils 32 | from hypernerf.datasets import core 33 | 34 | 35 | def load_scene_info( 36 | data_dir: types.PathType) -> Tuple[np.ndarray, float, float, float]: 37 | """Loads the scene center, scale, near and far from scene.json. 38 | 39 | Args: 40 | data_dir: the path to the dataset. 41 | 42 | Returns: 43 | scene_center: the center of the scene (unscaled coordinates). 44 | scene_scale: the scale of the scene. 45 | near: the near plane of the scene (scaled coordinates). 46 | far: the far plane of the scene (scaled coordinates). 47 | """ 48 | scene_json_path = gpath.GPath(data_dir, 'scene.json') 49 | with scene_json_path.open('r') as f: 50 | scene_json = json.load(f) 51 | 52 | scene_center = np.array(scene_json['center']) 53 | scene_scale = scene_json['scale'] 54 | near = scene_json['near'] 55 | far = scene_json['far'] 56 | 57 | return scene_center, scene_scale, near, far 58 | 59 | 60 | def _load_image(path: types.PathType) -> np.ndarray: 61 | path = gpath.GPath(path) 62 | with path.open('rb') as f: 63 | raw_im = np.asarray(bytearray(f.read()), dtype=np.uint8) 64 | image = cv2.imdecode(raw_im, cv2.IMREAD_COLOR)[:, :, ::-1] # BGR -> RGB 65 | image = np.asarray(image).astype(np.float32) / 255.0 66 | return image 67 | 68 | 69 | def _load_dataset_ids(data_dir: types.PathType) -> Tuple[List[str], List[str]]: 70 | """Loads dataset IDs.""" 71 | dataset_json_path = gpath.GPath(data_dir, 'dataset.json') 72 | logging.info('*** Loading dataset IDs from %s', dataset_json_path) 73 | with dataset_json_path.open('r') as f: 74 | dataset_json = json.load(f) 75 | train_ids = dataset_json['train_ids'] 76 | val_ids = dataset_json['val_ids'] 77 | 78 | train_ids = [str(i) for i in train_ids] 79 | val_ids = [str(i) for i in val_ids] 80 | 81 | return train_ids, val_ids 82 | 83 | 84 | @gin.configurable 85 | class NerfiesDataSource(core.DataSource): 86 | """Data loader for videos.""" 87 | 88 | def __init__(self, 89 | data_dir: str = gin.REQUIRED, 90 | image_scale: int = gin.REQUIRED, 91 | shuffle_pixels: bool = False, 92 | camera_type: str = 'json', 93 | test_camera_trajectory: str = 'orbit-mild', 94 | **kwargs): 95 | self.data_dir = gpath.GPath(data_dir) 96 | # Load IDs from JSON if it exists. This is useful since COLMAP fails on 97 | # some images so this gives us the ability to skip invalid images. 98 | train_ids, val_ids = _load_dataset_ids(self.data_dir) 99 | super().__init__(train_ids=train_ids, val_ids=val_ids, 100 | **kwargs) 101 | self.scene_center, self.scene_scale, self._near, self._far = ( 102 | load_scene_info(self.data_dir)) 103 | self.test_camera_trajectory = test_camera_trajectory 104 | 105 | self.image_scale = image_scale 106 | self.shuffle_pixels = shuffle_pixels 107 | 108 | self.rgb_dir = gpath.GPath(data_dir, 'rgb', f'{image_scale}x') 109 | self.depth_dir = gpath.GPath(data_dir, 'depth', f'{image_scale}x') 110 | if camera_type not in ['json']: 111 | raise ValueError('The camera type needs to be json.') 112 | self.camera_type = camera_type 113 | self.camera_dir = gpath.GPath(data_dir, 'camera') 114 | 115 | metadata_path = self.data_dir / 'metadata.json' 116 | if metadata_path.exists(): 117 | with metadata_path.open('r') as f: 118 | self.metadata_dict = json.load(f) 119 | 120 | @property 121 | def near(self) -> float: 122 | return self._near 123 | 124 | @property 125 | def far(self) -> float: 126 | return self._far 127 | 128 | @property 129 | def camera_ext(self) -> str: 130 | if self.camera_type == 'json': 131 | return '.json' 132 | 133 | raise ValueError(f'Unknown camera_type {self.camera_type}') 134 | 135 | def get_rgb_path(self, item_id: str) -> types.PathType: 136 | return self.rgb_dir / f'{item_id}.png' 137 | 138 | def load_rgb(self, item_id: str) -> np.ndarray: 139 | return _load_image(self.rgb_dir / f'{item_id}.png') 140 | 141 | def load_camera(self, 142 | item_id: types.PathType, 143 | scale_factor: float = 1.0) -> cam.Camera: 144 | if isinstance(item_id, gpath.GPath): 145 | camera_path = item_id 146 | else: 147 | if self.camera_type == 'proto': 148 | camera_path = self.camera_dir / f'{item_id}{self.camera_ext}' 149 | elif self.camera_type == 'json': 150 | camera_path = self.camera_dir / f'{item_id}{self.camera_ext}' 151 | else: 152 | raise ValueError(f'Unknown camera type {self.camera_type!r}.') 153 | 154 | return core.load_camera(camera_path, 155 | scale_factor=scale_factor / self.image_scale, 156 | scene_center=self.scene_center, 157 | scene_scale=self.scene_scale) 158 | 159 | def glob_cameras(self, path): 160 | path = gpath.GPath(path) 161 | return sorted(path.glob(f'*{self.camera_ext}')) 162 | 163 | def load_test_cameras(self, count=None): 164 | camera_dir = (self.data_dir / 'camera-paths' / self.test_camera_trajectory) 165 | if not camera_dir.exists(): 166 | logging.warning('test camera path does not exist: %s', str(camera_dir)) 167 | return [] 168 | camera_paths = sorted(camera_dir.glob(f'*{self.camera_ext}')) 169 | if count is not None: 170 | stride = max(1, len(camera_paths) // count) 171 | camera_paths = camera_paths[::stride] 172 | cameras = utils.parallel_map(self.load_camera, camera_paths) 173 | return cameras 174 | 175 | def load_points(self, shuffle=False): 176 | with (self.data_dir / 'points.npy').open('rb') as f: 177 | points = np.load(f) 178 | points = (points - self.scene_center) * self.scene_scale 179 | points = points.astype(np.float32) 180 | if shuffle: 181 | logging.info('Shuffling points.') 182 | shuffled_inds = self.rng.permutation(len(points)) 183 | points = points[shuffled_inds] 184 | logging.info('Loaded %d points.', len(points)) 185 | return points 186 | 187 | def get_appearance_id(self, item_id): 188 | return self.metadata_dict[item_id]['appearance_id'] 189 | 190 | def get_camera_id(self, item_id): 191 | return self.metadata_dict[item_id]['camera_id'] 192 | 193 | def get_warp_id(self, item_id): 194 | return self.metadata_dict[item_id]['warp_id'] 195 | 196 | def get_time_id(self, item_id): 197 | if 'time_id' in self.metadata_dict[item_id]: 198 | return self.metadata_dict[item_id]['time_id'] 199 | else: 200 | # Fallback for older datasets. 201 | return self.metadata_dict[item_id]['warp_id'] 202 | -------------------------------------------------------------------------------- /hypernerf/dual_quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dual quaternion math. 16 | 17 | We encode a dual quaternion as an 8-dimensional array: 18 | [rx, ry, rz, rw, dx, dy, dz, dw] 19 | which represent the dual quaternion: 20 | r + εd = (rx*i, ry*j, rz*k, rw) + ε(dx*i, dy*j, dz*k, dw) 21 | 22 | References: 23 | https://en.wikipedia.org/wiki/Dual_quaternion 24 | """ 25 | from jax import numpy as jnp 26 | from hypernerf import quaternion 27 | 28 | 29 | def real_part(dq): 30 | """Returns the real part of the dual quaternion.""" 31 | return dq[..., :4] 32 | 33 | 34 | def dual_part(dq): 35 | """Returns the dual part of the dual quaternion.""" 36 | return dq[..., 4:] 37 | 38 | 39 | def split_parts(dq): 40 | """Splits the dual quaternion into its real and dual parts.""" 41 | return real_part(dq), dual_part(dq) 42 | 43 | 44 | def join_parts(real, dual): 45 | """Creates a dual quaternion from its real and dual parts.""" 46 | return jnp.concatenate((real, dual), axis=-1) 47 | 48 | 49 | def identity(dtype=jnp.float32): 50 | """Returns the dual quaternion encoding an identity transform.""" 51 | return jnp.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype) 52 | 53 | 54 | def add(dq1, dq2): 55 | """Adds two dual quaternions.""" 56 | return dq1 + dq2 57 | 58 | 59 | def multiply(dq1, dq2): 60 | """Dual quaternion multiplication. 61 | 62 | Args: 63 | dq1: a (*,8) dimensional array representing a dual quaternion. 64 | dq2: a (*,8) dimensional array representing a dual quaternion. 65 | 66 | Returns: 67 | A (*,8) dimensional array representing the output dual quaternion. 68 | """ 69 | a, b = split_parts(dq1) 70 | c, d = split_parts(dq2) 71 | 72 | real = quaternion.multiply(a, c) 73 | dual = quaternion.multiply(a, d) + quaternion.multiply(b, c) 74 | 75 | return join_parts(real, dual) 76 | 77 | 78 | def quaternion_conjugate(dq): 79 | """Returns the quaternion conjugate.""" 80 | real, dual = split_parts(dq) 81 | return join_parts(quaternion.conjugate(real), quaternion.conjugate(dual)) 82 | 83 | 84 | def dual_conjugate(dq): 85 | """Returns the dual number conjugate.""" 86 | real, dual = split_parts(dq) 87 | return join_parts(real, -dual) 88 | 89 | 90 | def quaternion_dual_conjugate(dq): 91 | """Returns the dual number and quaternion conjugate.""" 92 | real, dual = split_parts(dq) 93 | return join_parts(-quaternion.conjugate(real), -quaternion.conjugate(dual)) 94 | 95 | 96 | def normalize(dq): 97 | """Normalize a dual quaternion.""" 98 | real, dual = split_parts(dq) 99 | real_norm = quaternion.norm(real) 100 | return join_parts(real / real_norm, dual / real_norm) 101 | 102 | 103 | def get_rotation(dq): 104 | """Returns a rotation quaternion this dual quaternion encodes.""" 105 | return real_part(dq) 106 | 107 | 108 | def get_translation(dq): 109 | """Returns a translation vector this dual quaternion encodes.""" 110 | real, dual = split_parts(dq) 111 | return 2 * quaternion.im( 112 | quaternion.multiply(dual, quaternion.conjugate(real))) 113 | 114 | 115 | def from_rotation_translation(q, t): 116 | """Creates a dual quaternion from a rotation and translation. 117 | 118 | Args: 119 | q: a (*,4) array containing a rotation quaternion. 120 | t: a (*,3) array containing a translation vector. 121 | 122 | Returns: 123 | A (*,8) array containing a dual quaternion. 124 | """ 125 | # Pad t = [t; 0] 126 | t = jnp.concatenate((t, jnp.zeros_like(t[..., -1:])), axis=-1) 127 | dq_t = join_parts(quaternion.identity(), 0.5 * t) 128 | dq_r = join_parts(q, jnp.zeros_like(q)) 129 | return multiply(dq_t, dq_r) 130 | -------------------------------------------------------------------------------- /hypernerf/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for evaluating a trained NeRF.""" 16 | import math 17 | import time 18 | 19 | from absl import logging 20 | from flax import jax_utils 21 | import jax 22 | from jax import tree_util 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | from hypernerf import utils 27 | 28 | 29 | def encode_metadata(model, params, metadata): 30 | """Encodes metadata embeddings. 31 | 32 | Args: 33 | model: a NerfModel. 34 | params: the parameters of the model. 35 | metadata: the metadata dict. 36 | 37 | Returns: 38 | A new metadata dict with the encoded embeddings. 39 | """ 40 | encoded_metadata = {} 41 | if model.use_nerf_embed: 42 | encoded_metadata['encoded_nerf'] = model.apply( 43 | {'params': params}, metadata, method=model.encode_nerf_embed) 44 | if model.use_warp: 45 | encoded_metadata['encoded_warp'] = model.apply( 46 | {'params': params}, metadata, method=model.encode_warp_embed) 47 | if model.has_hyper_embed: 48 | encoded_metadata['encoded_hyper'] = model.apply( 49 | {'params': params}, metadata, method=model.encode_hyper_embed) 50 | return encoded_metadata 51 | 52 | 53 | def render_image( 54 | state, 55 | rays_dict, 56 | model_fn, 57 | device_count, 58 | rng, 59 | chunk=8192, 60 | default_ret_key=None): 61 | """Render all the pixels of an image (in test mode). 62 | 63 | Args: 64 | state: model_utils.TrainState. 65 | rays_dict: dict, test example. 66 | model_fn: function, jit-ed render function. 67 | device_count: The number of devices to shard batches over. 68 | rng: The random number generator. 69 | chunk: int, the size of chunks to render sequentially. 70 | default_ret_key: either 'fine' or 'coarse'. If None will default to highest. 71 | 72 | Returns: 73 | rgb: jnp.ndarray, rendered color image. 74 | depth: jnp.ndarray, rendered depth. 75 | acc: jnp.ndarray, rendered accumulated weights per pixel. 76 | """ 77 | batch_shape = rays_dict['origins'].shape[:-1] 78 | num_rays = np.prod(batch_shape) 79 | rays_dict = tree_util.tree_map(lambda x: x.reshape((num_rays, -1)), rays_dict) 80 | _, key_0, key_1 = jax.random.split(rng, 3) 81 | key_0 = jax.random.split(key_0, device_count) 82 | key_1 = jax.random.split(key_1, device_count) 83 | proc_id = jax.process_index() 84 | ret_maps = [] 85 | start_time = time.time() 86 | num_batches = int(math.ceil(num_rays / chunk)) 87 | logging.info('Rendering: num_batches = %d, num_rays = %d, chunk = %d', 88 | num_batches, num_rays, chunk) 89 | for batch_idx in range(num_batches): 90 | ray_idx = batch_idx * chunk 91 | logging.log_every_n_seconds( 92 | logging.INFO, 'Rendering batch %d/%d (%d/%d)', 2.0, 93 | batch_idx, num_batches, ray_idx, num_rays) 94 | # pylint: disable=cell-var-from-loop 95 | chunk_slice_fn = lambda x: x[ray_idx:ray_idx + chunk] 96 | chunk_rays_dict = tree_util.tree_map(chunk_slice_fn, rays_dict) 97 | num_chunk_rays = chunk_rays_dict['origins'].shape[0] 98 | remainder = num_chunk_rays % device_count 99 | if remainder != 0: 100 | padding = device_count - remainder 101 | # pylint: disable=cell-var-from-loop 102 | chunk_pad_fn = lambda x: jnp.pad(x, ((0, padding), (0, 0)), mode='edge') 103 | chunk_rays_dict = tree_util.tree_map(chunk_pad_fn, chunk_rays_dict) 104 | else: 105 | padding = 0 106 | # After padding the number of chunk_rays is always divisible by 107 | # proc_count. 108 | per_proc_rays = num_chunk_rays // jax.process_count() 109 | logging.debug( 110 | 'Rendering batch: num_chunk_rays = %d, padding = %d, remainder = %d, ' 111 | 'per_proc_rays = %d', 112 | num_chunk_rays, padding, remainder, per_proc_rays) 113 | chunk_rays_dict = tree_util.tree_map( 114 | lambda x: x[(proc_id * per_proc_rays):((proc_id + 1) * per_proc_rays)], 115 | chunk_rays_dict) 116 | chunk_rays_dict = utils.shard(chunk_rays_dict, device_count) 117 | model_out = model_fn(key_0, key_1, state.optimizer.target['model'], 118 | chunk_rays_dict, state.extra_params) 119 | if not default_ret_key: 120 | ret_key = 'fine' if 'fine' in model_out else 'coarse' 121 | else: 122 | ret_key = default_ret_key 123 | ret_map = jax_utils.unreplicate(model_out[ret_key]) 124 | ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map) 125 | ret_maps.append(ret_map) 126 | ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps) 127 | logging.info('Rendering took %.04s', time.time() - start_time) 128 | out = {} 129 | for key, value in ret_map.items(): 130 | out_shape = (*batch_shape, *value.shape[1:]) 131 | logging.debug('Reshaping %s of shape %s to %s', 132 | key, str(value.shape), str(out_shape)) 133 | out[key] = value.reshape(out_shape) 134 | 135 | return out 136 | -------------------------------------------------------------------------------- /hypernerf/gpath.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A thin wrapper around pathlib.""" 16 | import pathlib 17 | import tensorflow as tf 18 | 19 | 20 | class GPath(pathlib.PurePosixPath): 21 | """A thin wrapper around PurePath to support various filesystems.""" 22 | 23 | def open(self, *args, **kwargs): 24 | return tf.io.gfile.GFile(self, *args, **kwargs) 25 | 26 | def exists(self): 27 | return tf.io.gfile.exists(self) 28 | 29 | # pylint: disable=unused-argument 30 | def mkdir(self, mode=0o777, parents=False, exist_ok=False): 31 | if not exist_ok: 32 | if self.exists(): 33 | raise FileExistsError('Directory already exists.') 34 | 35 | if parents: 36 | return tf.io.gfile.makedirs(self) 37 | else: 38 | return tf.io.gfile.mkdir(self) 39 | 40 | def glob(self, pattern): 41 | return [GPath(x) for x in tf.io.gfile.glob(str(self / pattern))] 42 | 43 | def iterdir(self): 44 | return [GPath(self, x) for x in tf.io.gfile.listdir(self)] 45 | 46 | def is_dir(self): 47 | return tf.io.gfile.isdir(self) 48 | 49 | def rmtree(self): 50 | tf.io.gfile.rmtree(self) 51 | -------------------------------------------------------------------------------- /hypernerf/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Image-related utility functions.""" 16 | import math 17 | from typing import Tuple 18 | 19 | from absl import logging 20 | import cv2 21 | import imageio 22 | import numpy as np 23 | from PIL import Image 24 | 25 | from hypernerf import gpath 26 | from hypernerf import types 27 | 28 | 29 | UINT8_MAX = 255 30 | UINT16_MAX = 65535 31 | 32 | 33 | def make_divisible(image: np.ndarray, divisor: int) -> np.ndarray: 34 | """Trim the image if not divisible by the divisor.""" 35 | height, width = image.shape[:2] 36 | if height % divisor == 0 and width % divisor == 0: 37 | return image 38 | 39 | new_height = height - height % divisor 40 | new_width = width - width % divisor 41 | 42 | return image[:new_height, :new_width] 43 | 44 | 45 | def downsample_image(image: np.ndarray, scale: int) -> np.ndarray: 46 | """Downsamples the image by an integer factor to prevent artifacts.""" 47 | if scale == 1: 48 | return image 49 | 50 | height, width = image.shape[:2] 51 | if height % scale > 0 or width % scale > 0: 52 | raise ValueError(f'Image shape ({height},{width}) must be divisible by the' 53 | f' scale ({scale}).') 54 | out_height, out_width = height // scale, width // scale 55 | resized = cv2.resize(image, (out_width, out_height), cv2.INTER_AREA) 56 | return resized 57 | 58 | 59 | def upsample_image(image: np.ndarray, scale: int) -> np.ndarray: 60 | """Upsamples the image by an integer factor.""" 61 | if scale == 1: 62 | return image 63 | 64 | height, width = image.shape[:2] 65 | out_height, out_width = height * scale, width * scale 66 | resized = cv2.resize(image, (out_width, out_height), cv2.INTER_AREA) 67 | return resized 68 | 69 | 70 | def reshape_image(image: np.ndarray, shape: Tuple[int, int]) -> np.ndarray: 71 | """Reshapes the image to the given shape.""" 72 | out_height, out_width = shape 73 | return cv2.resize( 74 | image, (out_width, out_height), interpolation=cv2.INTER_AREA) 75 | 76 | 77 | def rescale_image(image: np.ndarray, scale_factor: float) -> np.ndarray: 78 | """Resize an image by a scale factor, using integer resizing if possible.""" 79 | scale_factor = float(scale_factor) 80 | if scale_factor <= 0.0: 81 | raise ValueError('scale_factor must be a non-negative number.') 82 | 83 | if scale_factor == 1.0: 84 | return image 85 | 86 | height, width = image.shape[:2] 87 | if scale_factor.is_integer(): 88 | return upsample_image(image, int(scale_factor)) 89 | 90 | inv_scale = 1.0 / scale_factor 91 | if (inv_scale.is_integer() and (scale_factor * height).is_integer() and 92 | (scale_factor * width).is_integer()): 93 | return downsample_image(image, int(inv_scale)) 94 | 95 | logging.warning( 96 | 'resizing image by non-integer factor %f, this may lead to artifacts.', 97 | scale_factor) 98 | 99 | height, width = image.shape[:2] 100 | out_height = math.ceil(height * scale_factor) 101 | out_height -= out_height % 2 102 | out_width = math.ceil(width * scale_factor) 103 | out_width -= out_width % 2 104 | 105 | return reshape_image(image, (out_height, out_width)) 106 | 107 | 108 | def crop_image(image, left=0, right=0, top=0, bottom=0): 109 | pad_width = [max(0, -x) for x in [top, bottom, left, right]] 110 | if any(pad_width): 111 | image = np.pad(image, pad_width=pad_width, mode='constant') 112 | h, w = image.shape[:2] 113 | crop_coords = [max(0, x) for x in (top, bottom, left, right)] 114 | return image[crop_coords[0]:h - crop_coords[1], 115 | crop_coords[2]:w - crop_coords[3], :] 116 | 117 | 118 | def variance_of_laplacian(image: np.ndarray) -> np.ndarray: 119 | """Compute the variance of the Laplacian which measure the focus.""" 120 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 121 | return cv2.Laplacian(gray, cv2.CVX_64F).var() 122 | 123 | 124 | def image_to_uint8(image: np.ndarray) -> np.ndarray: 125 | """Convert the image to a uint8 array.""" 126 | if image.dtype == np.uint8: 127 | return image 128 | if not issubclass(image.dtype.type, np.floating): 129 | raise ValueError( 130 | f'Input image should be a floating type but is of type {image.dtype!r}') 131 | return (image * UINT8_MAX).clip(0.0, UINT8_MAX).astype(np.uint8) 132 | 133 | 134 | def image_to_uint16(image: np.ndarray) -> np.ndarray: 135 | """Convert the image to a uint16 array.""" 136 | if image.dtype == np.uint16: 137 | return image 138 | if not issubclass(image.dtype.type, np.floating): 139 | raise ValueError( 140 | f'Input image should be a floating type but is of type {image.dtype!r}') 141 | return (image * UINT16_MAX).clip(0.0, UINT16_MAX).astype(np.uint16) 142 | 143 | 144 | def image_to_float32(image: np.ndarray) -> np.ndarray: 145 | """Convert the image to a float32 array and scale values appropriately.""" 146 | if image.dtype == np.float32: 147 | return image 148 | 149 | dtype = image.dtype 150 | image = image.astype(np.float32) 151 | if dtype == np.uint8: 152 | return image / UINT8_MAX 153 | elif dtype == np.uint16: 154 | return image / UINT16_MAX 155 | elif dtype == np.float64: 156 | return image 157 | elif dtype == np.float16: 158 | return image 159 | 160 | raise ValueError(f'Not sure how to handle dtype {dtype}') 161 | 162 | 163 | def load_image(path: types.PathType) -> np.ndarray: 164 | """Reads an image.""" 165 | if not isinstance(path, gpath.GPath): 166 | path = gpath.GPath(path) 167 | 168 | with path.open('rb') as f: 169 | return imageio.imread(f) 170 | 171 | 172 | def save_image(path: types.PathType, image: np.ndarray) -> None: 173 | """Saves the image to disk or gfile.""" 174 | if not isinstance(path, gpath.GPath): 175 | path = gpath.GPath(path) 176 | 177 | if not path.parent.exists(): 178 | path.parent.mkdir(exist_ok=True, parents=True) 179 | 180 | with path.open('wb') as f: 181 | image = Image.fromarray(np.asarray(image)) 182 | image.save(f, format=path.suffix.lstrip('.')) 183 | 184 | 185 | def save_depth(path: types.PathType, depth: np.ndarray) -> None: 186 | save_image(path, image_to_uint16(depth / 1000.0)) 187 | 188 | 189 | def load_depth(path: types.PathType) -> np.ndarray: 190 | depth = load_image(path) 191 | if depth.dtype != np.uint16: 192 | raise ValueError('Depth image must be of type uint16.') 193 | return image_to_float32(depth) * 1000.0 194 | 195 | 196 | def checkerboard(h, w, size=8, true_val=1.0, false_val=0.0): 197 | """Creates a checkerboard pattern with height h and width w.""" 198 | i = int(math.ceil(h / (size * 2))) 199 | j = int(math.ceil(w / (size * 2))) 200 | pattern = np.kron([[1, 0] * j, [0, 1] * j] * i, 201 | np.ones((size, size)))[:h, :w] 202 | 203 | true = np.full_like(pattern, fill_value=true_val) 204 | false = np.full_like(pattern, fill_value=false_val) 205 | return np.where(pattern > 0, true, false) 206 | 207 | 208 | def pad_image(image, pad=0, pad_mode='constant', pad_value=0.0): 209 | """Pads a batched image array.""" 210 | batch_shape = image.shape[:-3] 211 | padding = [ 212 | *[(0, 0) for _ in batch_shape], 213 | (pad, pad), (pad, pad), (0, 0), 214 | ] 215 | if pad_mode == 'constant': 216 | return np.pad(image, padding, pad_mode, constant_values=pad_value) 217 | else: 218 | return np.pad(image, padding, pad_mode) 219 | 220 | 221 | def split_tiles(image, tile_size): 222 | """Splits the image into tiles of size `tile_size`.""" 223 | # The copy is necessary due to the use of the memory layout. 224 | if image.ndim == 2: 225 | image = image[..., None] 226 | image = np.array(image) 227 | image = make_divisible(image, tile_size).copy() 228 | height = width = tile_size 229 | nrows, ncols, depth = image.shape 230 | stride = image.strides 231 | 232 | nrows, m = divmod(nrows, height) 233 | ncols, n = divmod(ncols, width) 234 | if m != 0 or n != 0: 235 | raise ValueError('Image must be divisible by tile size.') 236 | 237 | return np.lib.stride_tricks.as_strided( 238 | np.ravel(image), 239 | shape=(nrows, ncols, height, width, depth), 240 | strides=(height * stride[0], width * stride[1], *stride), 241 | writeable=False) 242 | 243 | 244 | def join_tiles(tiles): 245 | """Reconstructs the image from tiles.""" 246 | return np.concatenate(np.concatenate(tiles, 1), 1) 247 | 248 | 249 | def make_grid(batch, grid_height=None, zoom=1, old_buffer=None, border_size=1): 250 | """Creates a grid out an image batch. 251 | 252 | Args: 253 | batch: numpy array of shape [batch_size, height, width, n_channels]. The 254 | data can either be float in [0, 1] or int in [0, 255]. If the data has 255 | only 1 channel it will be converted to a grey 3 channel image. 256 | grid_height: optional int, number of rows to have. If not given, it is 257 | set so that the output is a square. If -1, then tiling will only be 258 | vertical. 259 | zoom: optional int, how much to zoom the input. Default is no zoom. 260 | old_buffer: Buffer to write grid into if possible. If not set, or if shape 261 | doesn't match, we create a new buffer. 262 | border_size: int specifying the white spacing between the images. 263 | 264 | Returns: 265 | A numpy array corresponding to the full grid, with 3 channels and values 266 | in the [0, 255] range. 267 | 268 | Raises: 269 | ValueError: if the n_channels is not one of [1, 3]. 270 | """ 271 | 272 | batch_size, height, width, n_channels = batch.shape 273 | 274 | if grid_height is None: 275 | n = int(math.ceil(math.sqrt(batch_size))) 276 | grid_height = n 277 | grid_width = n 278 | elif grid_height == -1: 279 | grid_height = batch_size 280 | grid_width = 1 281 | else: 282 | grid_width = int(math.ceil(batch_size/grid_height)) 283 | 284 | if n_channels == 1: 285 | batch = np.tile(batch, (1, 1, 1, 3)) 286 | n_channels = 3 287 | 288 | if n_channels != 3: 289 | raise ValueError('Image batch must have either 1 or 3 channels, but ' 290 | 'was {}'.format(n_channels)) 291 | 292 | # We create the numpy buffer if we don't have an old buffer or if the size has 293 | # changed. 294 | shape = (height * grid_height + border_size * (grid_height - 1), 295 | width * grid_width + border_size * (grid_width - 1), 296 | n_channels) 297 | if old_buffer is not None and old_buffer.shape == shape: 298 | buf = old_buffer 299 | else: 300 | buf = np.full(shape, 255, dtype=np.uint8) 301 | 302 | multiplier = 1 if np.issubdtype(batch.dtype, np.integer) else 255 303 | 304 | for k in range(batch_size): 305 | i = k // grid_width 306 | j = k % grid_width 307 | arr = batch[k] 308 | x, y = i * (height + border_size), j * (width + border_size) 309 | buf[x:x + height, y:y + width, :] = np.clip(multiplier * arr, 310 | 0, 255).astype(np.uint8) 311 | 312 | if zoom > 1: 313 | buf = buf.repeat(zoom, axis=0).repeat(zoom, axis=1) 314 | return buf 315 | -------------------------------------------------------------------------------- /hypernerf/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions/classes for model definition.""" 16 | from typing import Optional 17 | 18 | from flax import linen as nn 19 | from flax import optim 20 | from flax import struct 21 | from jax import lax 22 | from jax import random 23 | import jax.numpy as jnp 24 | 25 | 26 | @struct.dataclass 27 | class TrainState: 28 | """Stores training state, including the optimizer and model params.""" 29 | optimizer: optim.Optimizer 30 | nerf_alpha: Optional[jnp.ndarray] = None 31 | warp_alpha: Optional[jnp.ndarray] = None 32 | hyper_alpha: Optional[jnp.ndarray] = None 33 | hyper_sheet_alpha: Optional[jnp.ndarray] = None 34 | 35 | @property 36 | def extra_params(self): 37 | return { 38 | 'nerf_alpha': self.nerf_alpha, 39 | 'warp_alpha': self.warp_alpha, 40 | 'hyper_alpha': self.hyper_alpha, 41 | 'hyper_sheet_alpha': self.hyper_sheet_alpha, 42 | } 43 | 44 | 45 | def sample_along_rays(key, origins, directions, num_coarse_samples, near, far, 46 | use_stratified_sampling, use_linear_disparity): 47 | """Stratified sampling along the rays. 48 | 49 | Args: 50 | key: jnp.ndarray, random generator key. 51 | origins: ray origins. 52 | directions: ray directions. 53 | num_coarse_samples: int. 54 | near: float, near clip. 55 | far: float, far clip. 56 | use_stratified_sampling: use stratified sampling. 57 | use_linear_disparity: sampling linearly in disparity rather than depth. 58 | 59 | Returns: 60 | z_vals: jnp.ndarray, [batch_size, num_coarse_samples], sampled z values. 61 | points: jnp.ndarray, [batch_size, num_coarse_samples, 3], sampled points. 62 | """ 63 | batch_size = origins.shape[0] 64 | 65 | t_vals = jnp.linspace(0., 1., num_coarse_samples) 66 | if not use_linear_disparity: 67 | z_vals = near * (1. - t_vals) + far * t_vals 68 | else: 69 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals) 70 | if use_stratified_sampling: 71 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 72 | upper = jnp.concatenate([mids, z_vals[..., -1:]], -1) 73 | lower = jnp.concatenate([z_vals[..., :1], mids], -1) 74 | t_rand = random.uniform(key, [batch_size, num_coarse_samples]) 75 | z_vals = lower + (upper - lower) * t_rand 76 | else: 77 | # Broadcast z_vals to make the returned shape consistent. 78 | z_vals = jnp.broadcast_to(z_vals[None, ...], 79 | [batch_size, num_coarse_samples]) 80 | 81 | return (z_vals, (origins[..., None, :] + 82 | z_vals[..., :, None] * directions[..., None, :])) 83 | 84 | 85 | def volumetric_rendering(rgb, 86 | sigma, 87 | z_vals, 88 | dirs, 89 | use_white_background, 90 | sample_at_infinity=True, 91 | eps=1e-10): 92 | """Volumetric Rendering Function. 93 | 94 | Args: 95 | rgb: an array of size (B,S,3) containing the RGB color values. 96 | sigma: an array of size (B,S) containing the densities. 97 | z_vals: an array of size (B,S) containing the z-coordinate of the samples. 98 | dirs: an array of size (B,3) containing the directions of rays. 99 | use_white_background: whether to assume a white background or not. 100 | sample_at_infinity: if True adds a sample at infinity. 101 | eps: a small number to prevent numerical issues. 102 | 103 | Returns: 104 | A dictionary containing: 105 | rgb: an array of size (B,3) containing the rendered colors. 106 | depth: an array of size (B,) containing the rendered depth. 107 | acc: an array of size (B,) containing the accumulated density. 108 | weights: an array of size (B,S) containing the weight of each sample. 109 | """ 110 | # TODO(keunhong): remove this hack. 111 | last_sample_z = 1e10 if sample_at_infinity else 1e-19 112 | dists = jnp.concatenate([ 113 | z_vals[..., 1:] - z_vals[..., :-1], 114 | jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape) 115 | ], -1) 116 | dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1) 117 | alpha = 1.0 - jnp.exp(-sigma * dists) 118 | # Prepend a 1.0 to make this an 'exclusive' cumprod as in `tf.math.cumprod`. 119 | accum_prod = jnp.concatenate([ 120 | jnp.ones_like(alpha[..., :1], alpha.dtype), 121 | jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1), 122 | ], axis=-1) 123 | weights = alpha * accum_prod 124 | 125 | rgb = (weights[..., None] * rgb).sum(axis=-2) 126 | exp_depth = (weights * z_vals).sum(axis=-1) 127 | med_depth = compute_depth_map(weights, z_vals) 128 | acc = weights.sum(axis=-1) 129 | if use_white_background: 130 | rgb = rgb + (1. - acc[..., None]) 131 | 132 | if sample_at_infinity: 133 | acc = weights[..., :-1].sum(axis=-1) 134 | 135 | out = { 136 | 'rgb': rgb, 137 | 'depth': exp_depth, 138 | 'med_depth': med_depth, 139 | 'acc': acc, 140 | 'weights': weights, 141 | } 142 | return out 143 | 144 | 145 | def piecewise_constant_pdf(key, bins, weights, num_coarse_samples, 146 | use_stratified_sampling): 147 | """Piecewise-Constant PDF sampling. 148 | 149 | Args: 150 | key: jnp.ndarray(float32), [2,], random number generator. 151 | bins: jnp.ndarray(float32), [batch_size, n_bins + 1]. 152 | weights: jnp.ndarray(float32), [batch_size, n_bins]. 153 | num_coarse_samples: int, the number of samples. 154 | use_stratified_sampling: bool, use use_stratified_sampling samples. 155 | 156 | Returns: 157 | z_samples: jnp.ndarray(float32), [batch_size, num_coarse_samples]. 158 | """ 159 | eps = 1e-5 160 | 161 | # Get pdf 162 | weights += eps # prevent nans 163 | pdf = weights / weights.sum(axis=-1, keepdims=True) 164 | cdf = jnp.cumsum(pdf, axis=-1) 165 | cdf = jnp.concatenate([jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf], axis=-1) 166 | 167 | # Take uniform samples 168 | if use_stratified_sampling: 169 | u = random.uniform(key, list(cdf.shape[:-1]) + [num_coarse_samples]) 170 | else: 171 | u = jnp.linspace(0., 1., num_coarse_samples) 172 | u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_coarse_samples]) 173 | 174 | # Invert CDF. This takes advantage of the fact that `bins` is sorted. 175 | mask = (u[..., None, :] >= cdf[..., :, None]) 176 | 177 | def minmax(x): 178 | x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2) 179 | x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2) 180 | x0 = jnp.minimum(x0, x[..., -2:-1]) 181 | x1 = jnp.maximum(x1, x[..., 1:2]) 182 | return x0, x1 183 | 184 | bins_g0, bins_g1 = minmax(bins) 185 | cdf_g0, cdf_g1 = minmax(cdf) 186 | 187 | denom = (cdf_g1 - cdf_g0) 188 | denom = jnp.where(denom < eps, 1., denom) 189 | t = (u - cdf_g0) / denom 190 | z_samples = bins_g0 + t * (bins_g1 - bins_g0) 191 | 192 | # Prevent gradient from backprop-ing through samples 193 | return lax.stop_gradient(z_samples) 194 | 195 | 196 | def sample_pdf(key, bins, weights, origins, directions, z_vals, 197 | num_coarse_samples, use_stratified_sampling): 198 | """Hierarchical sampling. 199 | 200 | Args: 201 | key: jnp.ndarray(float32), [2,], random number generator. 202 | bins: jnp.ndarray(float32), [batch_size, n_bins + 1]. 203 | weights: jnp.ndarray(float32), [batch_size, n_bins]. 204 | origins: ray origins. 205 | directions: ray directions. 206 | z_vals: jnp.ndarray(float32), [batch_size, n_coarse_samples]. 207 | num_coarse_samples: int, the number of samples. 208 | use_stratified_sampling: bool, use use_stratified_sampling samples. 209 | 210 | Returns: 211 | z_vals: jnp.ndarray(float32), 212 | [batch_size, n_coarse_samples + num_fine_samples]. 213 | points: jnp.ndarray(float32), 214 | [batch_size, n_coarse_samples + num_fine_samples, 3]. 215 | """ 216 | z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples, 217 | use_stratified_sampling) 218 | # Compute united z_vals and sample points 219 | z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1) 220 | return z_vals, ( 221 | origins[..., None, :] + z_vals[..., None] * directions[..., None, :]) 222 | 223 | 224 | def compute_opaqueness_mask(weights, depth_threshold=0.5): 225 | """Computes a mask which will be 1.0 at the depth point. 226 | 227 | Args: 228 | weights: the density weights from NeRF. 229 | depth_threshold: the accumulation threshold which will be used as the depth 230 | termination point. 231 | 232 | Returns: 233 | A tensor containing a mask with the same size as weights that has one 234 | element long the sample dimension that is 1.0. This element is the point 235 | where the 'surface' is. 236 | """ 237 | cumulative_contribution = jnp.cumsum(weights, axis=-1) 238 | depth_threshold = jnp.array(depth_threshold, dtype=weights.dtype) 239 | opaqueness = cumulative_contribution >= depth_threshold 240 | false_padding = jnp.zeros_like(opaqueness[..., :1]) 241 | padded_opaqueness = jnp.concatenate( 242 | [false_padding, opaqueness[..., :-1]], axis=-1) 243 | opaqueness_mask = jnp.logical_xor(opaqueness, padded_opaqueness) 244 | opaqueness_mask = opaqueness_mask.astype(weights.dtype) 245 | return opaqueness_mask 246 | 247 | 248 | def compute_depth_index(weights, depth_threshold=0.5): 249 | """Compute the sample index of the median depth accumulation.""" 250 | opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold) 251 | return jnp.argmax(opaqueness_mask, axis=-1) 252 | 253 | 254 | def compute_depth_map(weights, z_vals, depth_threshold=0.5): 255 | """Compute the depth using the median accumulation. 256 | 257 | Note that this differs from the depth computation in NeRF-W's codebase! 258 | 259 | Args: 260 | weights: the density weights from NeRF. 261 | z_vals: the z coordinates of the samples. 262 | depth_threshold: the accumulation threshold which will be used as the depth 263 | termination point. 264 | 265 | Returns: 266 | A tensor containing the depth of each input pixel. 267 | """ 268 | opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold) 269 | return jnp.sum(opaqueness_mask * z_vals, axis=-1) 270 | 271 | 272 | def noise_regularize(key, raw, noise_std, use_stratified_sampling): 273 | """Regularize the density prediction by adding gaussian noise. 274 | 275 | Args: 276 | key: jnp.ndarray(float32), [2,], random number generator. 277 | raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4]. 278 | noise_std: float, std dev of noise added to regularize sigma output. 279 | use_stratified_sampling: add noise only if use_stratified_sampling is True. 280 | 281 | Returns: 282 | raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4], updated raw. 283 | """ 284 | if (noise_std is not None) and noise_std > 0.0 and use_stratified_sampling: 285 | unused_key, key = random.split(key) 286 | noise = random.normal(key, raw[..., 3:4].shape, dtype=raw.dtype) * noise_std 287 | raw = jnp.concatenate([raw[..., :3], raw[..., 3:4] + noise], axis=-1) 288 | return raw 289 | 290 | 291 | def broadcast_feature_to(array: jnp.ndarray, shape: jnp.shape): 292 | """Matches the shape dimensions (everything except the channel dims). 293 | 294 | This is useful when you watch to match the shape of two features that have 295 | a different number of channels. 296 | 297 | Args: 298 | array: the array to broadcast. 299 | shape: the shape to broadcast the tensor to. 300 | 301 | Returns: 302 | The broadcasted tensor. 303 | """ 304 | out_shape = (*shape[:-1], array.shape[-1]) 305 | return jnp.broadcast_to(array, out_shape) 306 | 307 | 308 | def metadata_like(rays, metadata_id): 309 | """Create a metadata array like a ray batch.""" 310 | return jnp.full_like(rays[..., :1], fill_value=metadata_id, dtype=jnp.uint32) 311 | 312 | 313 | def vmap_module(module, in_axes=0, out_axes=0, num_batch_dims=1): 314 | """Vectorize a module. 315 | 316 | Args: 317 | module: the module to vectorize. 318 | in_axes: the `in_axes` argument passed to vmap. See `jax.vmap`. 319 | out_axes: the `out_axes` argument passed to vmap. See `jax.vmap`. 320 | num_batch_dims: the number of batch dimensions (how many times to apply vmap 321 | to the module). 322 | 323 | Returns: 324 | A vectorized module. 325 | """ 326 | for _ in range(num_batch_dims): 327 | module = nn.vmap( 328 | module, 329 | variable_axes={'params': None}, 330 | split_rngs={'params': False}, 331 | in_axes=in_axes, 332 | out_axes=out_axes) 333 | 334 | return module 335 | 336 | 337 | def identity_initializer(_, shape): 338 | max_shape = max(shape) 339 | return jnp.eye(max_shape)[:shape[0], :shape[1]] 340 | 341 | 342 | def posenc(x, min_deg, max_deg, use_identity=False, alpha=None): 343 | """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].""" 344 | batch_shape = x.shape[:-1] 345 | scales = 2.0 ** jnp.arange(min_deg, max_deg) 346 | # (*, F, C). 347 | xb = x[..., None, :] * scales[:, None] 348 | # (*, F, 2, C). 349 | four_feat = jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], axis=-2)) 350 | 351 | if alpha is not None: 352 | window = posenc_window(min_deg, max_deg, alpha) 353 | four_feat = window[..., None, None] * four_feat 354 | 355 | # (*, 2*F*C). 356 | four_feat = four_feat.reshape((*batch_shape, -1)) 357 | 358 | if use_identity: 359 | return jnp.concatenate([x, four_feat], axis=-1) 360 | else: 361 | return four_feat 362 | 363 | 364 | def posenc_window(min_deg, max_deg, alpha): 365 | """Windows a posenc using a cosiney window. 366 | 367 | This is equivalent to taking a truncated Hann window and sliding it to the 368 | right along the frequency spectrum. 369 | 370 | Args: 371 | min_deg: the lower frequency band. 372 | max_deg: the upper frequency band. 373 | alpha: will ease in each frequency as alpha goes from 0.0 to num_freqs. 374 | 375 | Returns: 376 | A 1-d numpy array with num_sample elements containing the window. 377 | """ 378 | bands = jnp.arange(min_deg, max_deg) 379 | x = jnp.clip(alpha - bands, 0.0, 1.0) 380 | return 0.5 * (1 + jnp.cos(jnp.pi * x + jnp.pi)) 381 | -------------------------------------------------------------------------------- /hypernerf/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Modules for NeRF models.""" 16 | import functools 17 | from typing import Any, Optional, Tuple 18 | 19 | from flax import linen as nn 20 | import gin 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | from hypernerf import model_utils 25 | from hypernerf import types 26 | 27 | 28 | def get_norm_layer(norm_type): 29 | """Translates a norm type to a norm constructor.""" 30 | if norm_type is None or norm_type == 'none': 31 | return None 32 | elif norm_type == 'layer': 33 | return functools.partial(nn.LayerNorm, use_scale=False, use_bias=False) 34 | elif norm_type == 'group': 35 | return functools.partial(nn.GroupNorm, use_scale=False, use_bias=False) 36 | elif norm_type == 'batch': 37 | return functools.partial(nn.BatchNorm, use_scale=False, use_bias=False) 38 | else: 39 | raise ValueError(f'Unknown norm type {norm_type}') 40 | 41 | 42 | class MLP(nn.Module): 43 | """Basic MLP class with hidden layers and an output layers.""" 44 | depth: int 45 | width: int 46 | hidden_init: types.Initializer = jax.nn.initializers.glorot_uniform() 47 | hidden_activation: types.Activation = nn.relu 48 | hidden_norm: Optional[types.Normalizer] = None 49 | output_init: Optional[types.Initializer] = None 50 | output_channels: int = 0 51 | output_activation: Optional[types.Activation] = lambda x: x 52 | use_bias: bool = True 53 | skips: Tuple[int] = tuple() 54 | 55 | @nn.compact 56 | def __call__(self, x): 57 | inputs = x 58 | for i in range(self.depth): 59 | layer = nn.Dense( 60 | self.width, 61 | use_bias=self.use_bias, 62 | kernel_init=self.hidden_init, 63 | name=f'hidden_{i}') 64 | if i in self.skips: 65 | x = jnp.concatenate([x, inputs], axis=-1) 66 | x = layer(x) 67 | if self.hidden_norm is not None: 68 | x = self.hidden_norm()(x) # pylint: disable=not-callable 69 | x = self.hidden_activation(x) 70 | 71 | if self.output_channels > 0: 72 | logit_layer = nn.Dense( 73 | self.output_channels, 74 | use_bias=self.use_bias, 75 | kernel_init=self.output_init, 76 | name='logit') 77 | x = logit_layer(x) 78 | if self.output_activation is not None: 79 | x = self.output_activation(x) 80 | 81 | return x 82 | 83 | 84 | class NerfMLP(nn.Module): 85 | """A simple MLP. 86 | 87 | Attributes: 88 | nerf_trunk_depth: int, the depth of the first part of MLP. 89 | nerf_trunk_width: int, the width of the first part of MLP. 90 | nerf_rgb_branch_depth: int, the depth of the second part of MLP. 91 | nerf_rgb_branch_width: int, the width of the second part of MLP. 92 | activation: function, the activation function used in the MLP. 93 | skips: which layers to add skip layers to. 94 | alpha_channels: int, the number of alpha_channelss. 95 | rgb_channels: int, the number of rgb_channelss. 96 | condition_density: if True put the condition at the begining which 97 | conditions the density of the field. 98 | """ 99 | trunk_depth: int = 8 100 | trunk_width: int = 256 101 | 102 | rgb_branch_depth: int = 1 103 | rgb_branch_width: int = 128 104 | rgb_channels: int = 3 105 | 106 | alpha_branch_depth: int = 0 107 | alpha_branch_width: int = 128 108 | alpha_channels: int = 1 109 | 110 | activation: types.Activation = nn.relu 111 | norm: Optional[Any] = None 112 | skips: Tuple[int] = (4,) 113 | 114 | @nn.compact 115 | def __call__(self, x, alpha_condition, rgb_condition): 116 | """Multi-layer perception for nerf. 117 | 118 | Args: 119 | x: sample points with shape [batch, num_coarse_samples, feature]. 120 | alpha_condition: a condition array provided to the alpha branch. 121 | rgb_condition: a condition array provided in the RGB branch. 122 | 123 | Returns: 124 | raw: [batch, num_coarse_samples, rgb_channels+alpha_channels]. 125 | """ 126 | dense = functools.partial( 127 | nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform()) 128 | 129 | feature_dim = x.shape[-1] 130 | num_samples = x.shape[1] 131 | x = x.reshape([-1, feature_dim]) 132 | 133 | def broadcast_condition(c): 134 | # Broadcast condition from [batch, feature] to 135 | # [batch, num_coarse_samples, feature] since all the samples along the 136 | # same ray has the same viewdir. 137 | c = jnp.tile(c[:, None, :], (1, num_samples, 1)) 138 | # Collapse the [batch, num_coarse_samples, feature] tensor to 139 | # [batch * num_coarse_samples, feature] to be fed into nn.Dense. 140 | c = c.reshape([-1, c.shape[-1]]) 141 | return c 142 | 143 | trunk_mlp = MLP(depth=self.trunk_depth, 144 | width=self.trunk_width, 145 | hidden_activation=self.activation, 146 | hidden_norm=self.norm, 147 | hidden_init=jax.nn.initializers.glorot_uniform(), 148 | skips=self.skips) 149 | rgb_mlp = MLP(depth=self.rgb_branch_depth, 150 | width=self.rgb_branch_width, 151 | hidden_activation=self.activation, 152 | hidden_norm=self.norm, 153 | hidden_init=jax.nn.initializers.glorot_uniform(), 154 | output_init=jax.nn.initializers.glorot_uniform(), 155 | output_channels=self.rgb_channels) 156 | alpha_mlp = MLP(depth=self.alpha_branch_depth, 157 | width=self.alpha_branch_width, 158 | hidden_activation=self.activation, 159 | hidden_norm=self.norm, 160 | hidden_init=jax.nn.initializers.glorot_uniform(), 161 | output_init=jax.nn.initializers.glorot_uniform(), 162 | output_channels=self.alpha_channels) 163 | 164 | if self.trunk_depth > 0: 165 | x = trunk_mlp(x) 166 | 167 | if (alpha_condition is not None) or (rgb_condition is not None): 168 | bottleneck = dense(self.trunk_width, name='bottleneck')(x) 169 | 170 | if alpha_condition is not None: 171 | alpha_condition = broadcast_condition(alpha_condition) 172 | alpha_input = jnp.concatenate([bottleneck, alpha_condition], axis=-1) 173 | else: 174 | alpha_input = x 175 | alpha = alpha_mlp(alpha_input) 176 | 177 | if rgb_condition is not None: 178 | rgb_condition = broadcast_condition(rgb_condition) 179 | rgb_input = jnp.concatenate([bottleneck, rgb_condition], axis=-1) 180 | else: 181 | rgb_input = x 182 | rgb = rgb_mlp(rgb_input) 183 | 184 | return { 185 | 'rgb': rgb.reshape((-1, num_samples, self.rgb_channels)), 186 | 'alpha': alpha.reshape((-1, num_samples, self.alpha_channels)), 187 | } 188 | 189 | 190 | @gin.configurable(denylist=['name']) 191 | class GLOEmbed(nn.Module): 192 | """A GLO encoder module, which is just a thin wrapper around nn.Embed. 193 | 194 | Attributes: 195 | num_embeddings: The number of embeddings. 196 | features: The dimensions of each embedding. 197 | embedding_init: The initializer to use for each. 198 | """ 199 | 200 | num_embeddings: int = gin.REQUIRED 201 | num_dims: int = gin.REQUIRED 202 | embedding_init: types.Activation = nn.initializers.uniform(scale=0.05) 203 | 204 | def setup(self): 205 | self.embed = nn.Embed( 206 | num_embeddings=self.num_embeddings, 207 | features=self.num_dims, 208 | embedding_init=self.embedding_init) 209 | 210 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: 211 | """Method to get embeddings for specified indices. 212 | 213 | Args: 214 | inputs: The indices to fetch embeddings for. 215 | 216 | Returns: 217 | The embeddings corresponding to the indices provided. 218 | """ 219 | if inputs.shape[-1] == 1: 220 | inputs = jnp.squeeze(inputs, axis=-1) 221 | 222 | return self.embed(inputs) 223 | 224 | 225 | @gin.configurable(denylist=['name']) 226 | class HyperSheetMLP(nn.Module): 227 | """An MLP that defines a bendy slicing surface through hyper space.""" 228 | output_channels: int = gin.REQUIRED 229 | min_deg: int = 0 230 | max_deg: int = 1 231 | 232 | depth: int = 6 233 | width: int = 64 234 | skips: Tuple[int] = (4,) 235 | hidden_init: types.Initializer = jax.nn.initializers.glorot_uniform() 236 | output_init: types.Initializer = jax.nn.initializers.normal(1e-5) 237 | # output_init: types.Initializer = jax.nn.initializers.glorot_uniform() 238 | 239 | use_residual: bool = False 240 | 241 | @nn.compact 242 | def __call__(self, points, embed, alpha=None): 243 | points_feat = model_utils.posenc( 244 | points, self.min_deg, self.max_deg, alpha=alpha) 245 | inputs = jnp.concatenate([points_feat, embed], axis=-1) 246 | mlp = MLP(depth=self.depth, 247 | width=self.width, 248 | skips=self.skips, 249 | hidden_init=self.hidden_init, 250 | output_channels=self.output_channels, 251 | output_init=self.output_init) 252 | if self.use_residual: 253 | return mlp(inputs) + embed 254 | else: 255 | return mlp(inputs) 256 | -------------------------------------------------------------------------------- /hypernerf/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Quaternion math. 16 | 17 | This module assumes the xyzw quaternion format where xyz is the imaginary part 18 | and w is the real part. 19 | 20 | Functions in this module support both batched and unbatched quaternions. 21 | """ 22 | from jax import numpy as jnp 23 | from jax.numpy import linalg 24 | 25 | 26 | def safe_acos(t, eps=1e-7): 27 | """A safe version of arccos which avoids evaluating at -1 or 1.""" 28 | return jnp.arccos(jnp.clip(t, -1.0 + eps, 1.0 - eps)) 29 | 30 | 31 | def im(q): 32 | """Fetch the imaginary part of the quaternion.""" 33 | return q[..., :3] 34 | 35 | 36 | def re(q): 37 | """Fetch the real part of the quaternion.""" 38 | return q[..., 3:] 39 | 40 | 41 | def identity(): 42 | return jnp.array([0.0, 0.0, 0.0, 1.0]) 43 | 44 | 45 | def conjugate(q): 46 | """Compute the conjugate of a quaternion.""" 47 | return jnp.concatenate([-im(q), re(q)], axis=-1) 48 | 49 | 50 | def inverse(q): 51 | """Compute the inverse of a quaternion.""" 52 | return normalize(conjugate(q)) 53 | 54 | 55 | def normalize(q): 56 | """Normalize a quaternion.""" 57 | return q / norm(q) 58 | 59 | 60 | def norm(q): 61 | return linalg.norm(q, axis=-1, keepdims=True) 62 | 63 | 64 | def multiply(q1, q2): 65 | """Multiply two quaternions.""" 66 | c = (re(q1) * im(q2) 67 | + re(q2) * im(q1) 68 | + jnp.cross(im(q1), im(q2))) 69 | w = re(q1) * re(q2) - jnp.dot(im(q1), im(q2)) 70 | return jnp.concatenate([c, w], axis=-1) 71 | 72 | 73 | def rotate(q, v): 74 | """Rotate a vector using a quaternion.""" 75 | # Create the quaternion representation of the vector. 76 | q_v = jnp.concatenate([v, jnp.zeros_like(v[..., :1])], axis=-1) 77 | return im(multiply(multiply(q, q_v), conjugate(q))) 78 | 79 | 80 | def log(q, eps=1e-8): 81 | """Computes the quaternion logarithm. 82 | 83 | References: 84 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions 85 | 86 | Args: 87 | q: the quaternion in (x,y,z,w) format. 88 | eps: an epsilon value for numerical stability. 89 | 90 | Returns: 91 | The logarithm of q. 92 | """ 93 | mag = linalg.norm(q, axis=-1, keepdims=True) 94 | v = im(q) 95 | s = re(q) 96 | w = jnp.log(mag) 97 | denom = jnp.maximum( 98 | linalg.norm(v, axis=-1, keepdims=True), eps * jnp.ones_like(v)) 99 | xyz = v / denom * safe_acos(s / eps) 100 | return jnp.concatenate((xyz, w), axis=-1) 101 | 102 | 103 | def exp(q, eps=1e-8): 104 | """Computes the quaternion exponential. 105 | 106 | References: 107 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions 108 | 109 | Args: 110 | q: the quaternion in (x,y,z,w) format or (x,y,z) if is_pure is True. 111 | eps: an epsilon value for numerical stability. 112 | 113 | Returns: 114 | The exponential of q. 115 | """ 116 | is_pure = q.shape[-1] == 3 117 | if is_pure: 118 | s = jnp.zeros_like(q[..., -1:]) 119 | v = q 120 | else: 121 | v = im(q) 122 | s = re(q) 123 | 124 | norm_v = linalg.norm(v, axis=-1, keepdims=True) 125 | exp_s = jnp.exp(s) 126 | w = jnp.cos(norm_v) 127 | xyz = jnp.sin(norm_v) * v / jnp.maximum(norm_v, eps * jnp.ones_like(norm_v)) 128 | return exp_s * jnp.concatenate((xyz, w), axis=-1) 129 | 130 | 131 | def to_rotation_matrix(q): 132 | """Constructs a rotation matrix from a quaternion. 133 | 134 | Args: 135 | q: a (*,4) array containing quaternions. 136 | 137 | Returns: 138 | A (*,3,3) array containing rotation matrices. 139 | """ 140 | x, y, z, w = jnp.split(q, 4, axis=-1) 141 | s = 1.0 / jnp.sum(q ** 2, axis=-1) 142 | return jnp.stack([ 143 | jnp.stack([1 - 2 * s * (y ** 2 + z ** 2), 144 | 2 * s * (x * y - z * w), 145 | 2 * s * (x * z + y * w)], axis=0), 146 | jnp.stack([2 * s * (x * y + z * w), 147 | 1 - s * 2 * (x ** 2 + z ** 2), 148 | 2 * s * (y * z - x * w)], axis=0), 149 | jnp.stack([2 * s * (x * z - y * w), 150 | 2 * s * (y * z + x * w), 151 | 1 - 2 * s * (x ** 2 + y ** 2)], axis=0), 152 | ], axis=0) 153 | 154 | 155 | def from_rotation_matrix(m, eps=1e-9): 156 | """Construct quaternion from a rotation matrix. 157 | 158 | Args: 159 | m: a (*,3,3) array containing rotation matrices. 160 | eps: a small number for numerical stability. 161 | 162 | Returns: 163 | A (*,4) array containing quaternions. 164 | """ 165 | trace = jnp.trace(m) 166 | m00 = m[..., 0, 0] 167 | m01 = m[..., 0, 1] 168 | m02 = m[..., 0, 2] 169 | m10 = m[..., 1, 0] 170 | m11 = m[..., 1, 1] 171 | m12 = m[..., 1, 2] 172 | m20 = m[..., 2, 0] 173 | m21 = m[..., 2, 1] 174 | m22 = m[..., 2, 2] 175 | 176 | def tr_positive(): 177 | sq = jnp.sqrt(trace + 1.0) * 2. # sq = 4 * w. 178 | w = 0.25 * sq 179 | x = jnp.divide(m21 - m12, sq) 180 | y = jnp.divide(m02 - m20, sq) 181 | z = jnp.divide(m10 - m01, sq) 182 | return jnp.stack((x, y, z, w), axis=-1) 183 | 184 | def cond_1(): 185 | sq = jnp.sqrt(1.0 + m00 - m11 - m22 + eps) * 2. # sq = 4 * x. 186 | w = jnp.divide(m21 - m12, sq) 187 | x = 0.25 * sq 188 | y = jnp.divide(m01 + m10, sq) 189 | z = jnp.divide(m02 + m20, sq) 190 | return jnp.stack((x, y, z, w), axis=-1) 191 | 192 | def cond_2(): 193 | sq = jnp.sqrt(1.0 + m11 - m00 - m22 + eps) * 2. # sq = 4 * y. 194 | w = jnp.divide(m02 - m20, sq) 195 | x = jnp.divide(m01 + m10, sq) 196 | y = 0.25 * sq 197 | z = jnp.divide(m12 + m21, sq) 198 | return jnp.stack((x, y, z, w), axis=-1) 199 | 200 | def cond_3(): 201 | sq = jnp.sqrt(1.0 + m22 - m00 - m11 + eps) * 2. # sq = 4 * z. 202 | w = jnp.divide(m10 - m01, sq) 203 | x = jnp.divide(m02 + m20, sq) 204 | y = jnp.divide(m12 + m21, sq) 205 | z = 0.25 * sq 206 | return jnp.stack((x, y, z, w), axis=-1) 207 | 208 | def cond_idx(cond): 209 | cond = jnp.expand_dims(cond, -1) 210 | cond = jnp.tile(cond, [1] * (len(m.shape) - 2) + [4]) 211 | return cond 212 | 213 | where_2 = jnp.where(cond_idx(m11 > m22), cond_2(), cond_3()) 214 | where_1 = jnp.where(cond_idx((m00 > m11) & (m00 > m22)), cond_1(), where_2) 215 | return jnp.where(cond_idx(trace > 0), tr_positive(), where_1) 216 | 217 | 218 | def from_axis_angle(axis, theta): 219 | """Constructs a quaternion for the given axis/angle rotation.""" 220 | qx = axis[0] * jnp.sin(theta / 2) 221 | qy = axis[1] * jnp.sin(theta / 2) 222 | qz = axis[2] * jnp.sin(theta / 2) 223 | qw = jnp.cos(theta / 2) 224 | 225 | return jnp.squeeze(jnp.array([qx, qy, qz, qw])) 226 | -------------------------------------------------------------------------------- /hypernerf/rigid_body.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=invalid-name 16 | # pytype: disable=attribute-error 17 | import jax 18 | from jax import numpy as jnp 19 | 20 | 21 | def matmul(a, b): 22 | """jnp.matmul defaults to bfloat16, but this helper function doesn't.""" 23 | return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST) 24 | 25 | 26 | @jax.jit 27 | def skew(w: jnp.ndarray) -> jnp.ndarray: 28 | """Build a skew matrix ("cross product matrix") for vector w. 29 | 30 | Modern Robotics Eqn 3.30. 31 | 32 | Args: 33 | w: (3,) A 3-vector 34 | 35 | Returns: 36 | W: (3, 3) A skew matrix such that W @ v == w x v 37 | """ 38 | w = jnp.reshape(w, (3)) 39 | return jnp.array([[0.0, -w[2], w[1]], \ 40 | [w[2], 0.0, -w[0]], \ 41 | [-w[1], w[0], 0.0]]) 42 | 43 | 44 | def rp_to_se3(R: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray: 45 | """Rotation and translation to homogeneous transform. 46 | 47 | Args: 48 | R: (3, 3) An orthonormal rotation matrix. 49 | p: (3,) A 3-vector representing an offset. 50 | 51 | Returns: 52 | X: (4, 4) The homogeneous transformation matrix described by rotating by R 53 | and translating by p. 54 | """ 55 | p = jnp.reshape(p, (3, 1)) 56 | return jnp.block([[R, p], [jnp.array([[0.0, 0.0, 0.0, 1.0]])]]) 57 | 58 | 59 | def exp_so3(w: jnp.ndarray, theta: float) -> jnp.ndarray: 60 | """Exponential map from Lie algebra so3 to Lie group SO3. 61 | 62 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula. 63 | 64 | Args: 65 | w: (3,) An axis of rotation. This is assumed to be a unit-vector. 66 | theta: An angle of rotation. 67 | 68 | Returns: 69 | R: (3, 3) An orthonormal rotation matrix representing a rotation of 70 | magnitude theta about axis w. 71 | """ 72 | W = skew(w) 73 | return (jnp.eye(3) 74 | + jnp.sin(theta) * W + (1.0 - jnp.cos(theta)) * matmul(W, W)) 75 | 76 | 77 | def exp_se3(S: jnp.ndarray, theta: float) -> jnp.ndarray: 78 | """Exponential map from Lie algebra so3 to Lie group SO3. 79 | 80 | Modern Robotics Eqn 3.88. 81 | 82 | Args: 83 | S: (6,) A screw axis of motion. 84 | theta: Magnitude of motion. 85 | 86 | Returns: 87 | a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating 88 | motion of magnitude theta about S for one second. 89 | """ 90 | w, v = jnp.split(S, 2) 91 | W = skew(w) 92 | R = exp_so3(w, theta) 93 | p = matmul((theta * jnp.eye(3) + (1.0 - jnp.cos(theta)) * W + 94 | (theta - jnp.sin(theta)) * matmul(W, W)), v) 95 | return rp_to_se3(R, p) 96 | 97 | 98 | def to_homogenous(v): 99 | return jnp.concatenate([v, jnp.ones_like(v[..., :1])], axis=-1) 100 | 101 | 102 | def from_homogenous(v): 103 | return v[..., :3] / v[..., -1:] 104 | -------------------------------------------------------------------------------- /hypernerf/schedules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Annealing Schedules.""" 16 | import abc 17 | import collections 18 | import copy 19 | import math 20 | from typing import Any, Iterable, List, Tuple, Union 21 | 22 | from jax import numpy as jnp 23 | 24 | 25 | def from_tuple(x): 26 | schedule_type, *args = x 27 | return SCHEDULE_MAP[schedule_type](*args) 28 | 29 | 30 | def from_dict(d): 31 | d = copy.copy(dict(d)) 32 | schedule_type = d.pop('type') 33 | return SCHEDULE_MAP[schedule_type](**d) 34 | 35 | 36 | def from_config(schedule): 37 | """Creates a schedule from a configuration.""" 38 | if schedule is None: 39 | return NoneSchedule() 40 | if isinstance(schedule, Schedule): 41 | return schedule 42 | if isinstance(schedule, Tuple) or isinstance(schedule, List): 43 | return from_tuple(schedule) 44 | if isinstance(schedule, collections.Mapping): 45 | return from_dict(schedule) 46 | 47 | raise ValueError(f'Unknown type {type(schedule)}.') 48 | 49 | 50 | class Schedule(abc.ABC): 51 | """An interface for generic schedules..""" 52 | 53 | @abc.abstractmethod 54 | def get(self, step): 55 | """Get the value for the given step.""" 56 | raise NotImplementedError 57 | 58 | def __call__(self, step): 59 | return self.get(step) 60 | 61 | 62 | class NoneSchedule(Schedule): 63 | """Always returns None. Useful for disable schedules.""" 64 | 65 | def get(self, step): 66 | return None 67 | 68 | 69 | class ConstantSchedule(Schedule): 70 | """Linearly scaled scheduler.""" 71 | 72 | def __init__(self, value): 73 | super().__init__() 74 | self.value = value 75 | 76 | def get(self, step): 77 | """Get the value for the given step.""" 78 | if self.value is None: 79 | return None 80 | return jnp.full_like(step, self.value, dtype=jnp.float32) 81 | 82 | 83 | class LinearSchedule(Schedule): 84 | """Linearly scaled scheduler.""" 85 | 86 | def __init__(self, initial_value, final_value, num_steps): 87 | super().__init__() 88 | self.initial_value = initial_value 89 | self.final_value = final_value 90 | self.num_steps = num_steps 91 | 92 | def get(self, step): 93 | """Get the value for the given step.""" 94 | if self.num_steps == 0: 95 | return jnp.full_like(step, self.final_value, dtype=jnp.float32) 96 | alpha = jnp.minimum(step / self.num_steps, 1.0) 97 | return (1.0 - alpha) * self.initial_value + alpha * self.final_value 98 | 99 | 100 | class ExponentialSchedule(Schedule): 101 | """Exponentially decaying scheduler.""" 102 | 103 | def __init__(self, initial_value, final_value, num_steps, eps=1e-10): 104 | super().__init__() 105 | if initial_value <= final_value: 106 | raise ValueError('Final value must be less than initial value.') 107 | 108 | self.initial_value = initial_value 109 | self.final_value = final_value 110 | self.num_steps = num_steps 111 | self.eps = eps 112 | 113 | def get(self, step): 114 | """Get the value for the given step.""" 115 | if step >= self.num_steps: 116 | return jnp.full_like(step, self.final_value, dtype=jnp.float32) 117 | 118 | final_value = max(self.final_value, self.eps) 119 | base = final_value / self.initial_value 120 | exponent = step / (self.num_steps - 1) 121 | if step >= self.num_steps: 122 | return jnp.full_like(step, self.final_value, dtype=jnp.float32) 123 | return self.initial_value * base**exponent 124 | 125 | 126 | class CosineEasingSchedule(Schedule): 127 | """Schedule that eases slowly using a cosine.""" 128 | 129 | def __init__(self, initial_value, final_value, num_steps): 130 | super().__init__() 131 | self.initial_value = initial_value 132 | self.final_value = final_value 133 | self.num_steps = num_steps 134 | 135 | def get(self, step): 136 | """Get the value for the given step.""" 137 | alpha = jnp.minimum(step / self.num_steps, 1.0) 138 | scale = self.final_value - self.initial_value 139 | x = min(max(alpha, 0.0), 1.0) 140 | return (self.initial_value 141 | + scale * 0.5 * (1 + math.cos(jnp.pi * x + jnp.pi))) 142 | 143 | 144 | class StepSchedule(Schedule): 145 | """Schedule that eases slowsly using a cosine.""" 146 | 147 | def __init__(self, 148 | initial_value, 149 | decay_interval, 150 | decay_factor, 151 | max_decays, 152 | final_value=None): 153 | super().__init__() 154 | self.initial_value = initial_value 155 | self.decay_factor = decay_factor 156 | self.decay_interval = decay_interval 157 | self.max_decays = max_decays 158 | if final_value is None: 159 | final_value = self.initial_value * self.decay_factor**self.max_decays 160 | self.final_value = final_value 161 | 162 | def get(self, step): 163 | """Get the value for the given step.""" 164 | phase = step // self.decay_interval 165 | if phase >= self.max_decays: 166 | return self.final_value 167 | else: 168 | return self.initial_value * self.decay_factor**phase 169 | 170 | 171 | class PiecewiseSchedule(Schedule): 172 | """A piecewise combination of multiple schedules.""" 173 | 174 | def __init__( 175 | self, schedules: Iterable[Tuple[int, Union[Schedule, Iterable[Any]]]]): 176 | self.schedules = [from_config(s) for ms, s in schedules] 177 | milestones = jnp.array([ms for ms, s in schedules]) 178 | self.milestones = jnp.cumsum(milestones)[:-1] 179 | 180 | def get(self, step): 181 | idx = jnp.searchsorted(self.milestones, step, side='right') 182 | schedule = self.schedules[idx] 183 | base_idx = self.milestones[idx - 1] if idx >= 1 else 0 184 | return schedule.get(step - base_idx) 185 | 186 | 187 | class DelayedSchedule(Schedule): 188 | """Delays the start of the base schedule.""" 189 | 190 | def __init__(self, base_schedule: Schedule, delay_steps, delay_mult): 191 | self.base_schedule = from_config(base_schedule) 192 | self.delay_steps = delay_steps 193 | self.delay_mult = delay_mult 194 | 195 | def get(self, step): 196 | delay_rate = ( 197 | self.delay_mult 198 | + (1 - self.delay_mult) 199 | * jnp.sin(0.5 * jnp.pi * jnp.clip(step / self.delay_steps, 0, 1))) 200 | 201 | return delay_rate * self.base_schedule(step) 202 | 203 | 204 | SCHEDULE_MAP = { 205 | 'constant': ConstantSchedule, 206 | 'linear': LinearSchedule, 207 | 'exponential': ExponentialSchedule, 208 | 'cosine_easing': CosineEasingSchedule, 209 | 'step': StepSchedule, 210 | 'piecewise': PiecewiseSchedule, 211 | 'delayed': DelayedSchedule, 212 | } 213 | -------------------------------------------------------------------------------- /hypernerf/testdata/camera.json: -------------------------------------------------------------------------------- 1 | { 2 | "orientation": [ 3 | [ 4 | 0.9839451340309302, 5 | -0.09685694918727988, 6 | 0.14990231689666453 7 | ], 8 | [ 9 | -0.03503400050627939, 10 | -0.9284052253532321, 11 | -0.3699139850632039 12 | ], 13 | [ 14 | 0.1749988343543504, 15 | 0.35872338776687934, 16 | -0.9168930903020656 17 | ] 18 | ], 19 | "position": [ 20 | -0.3236620944132945, 21 | -3.2642885887491806, 22 | 5.416047612682676 23 | ], 24 | "focal_length": 2691.1703975975283, 25 | "principal_point": [ 26 | 1220.9973265580372, 27 | 1652.4811427847815 28 | ], 29 | "skew": 0.0, 30 | "pixel_aspect_ratio": 1.0010612377729717, 31 | "radial_distortion": [ 32 | 0.10042443556128597, 33 | -0.20908508396511707, 34 | 0.0 35 | ], 36 | "tangential": [ 37 | 0.001109850269091041, 38 | -2.5733278797779516e-05 39 | ], 40 | "image_size": [ 41 | 2448, 42 | 3264 43 | ] 44 | } 45 | -------------------------------------------------------------------------------- /hypernerf/tf_camera.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A basic camera implementation in Tensorflow.""" 16 | from typing import Tuple, Optional 17 | 18 | import tensorflow as tf 19 | from tensorflow.experimental import numpy as tnp 20 | 21 | 22 | def _norm(x): 23 | return tnp.sqrt(tnp.sum(x ** 2, axis=-1, keepdims=True)) 24 | 25 | 26 | def _compute_residual_and_jacobian( 27 | x: tnp.ndarray, 28 | y: tnp.ndarray, 29 | xd: tnp.ndarray, 30 | yd: tnp.ndarray, 31 | k1: float = 0.0, 32 | k2: float = 0.0, 33 | k3: float = 0.0, 34 | p1: float = 0.0, 35 | p2: float = 0.0, 36 | ) -> Tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray, tnp.ndarray, tnp.ndarray, 37 | tnp.ndarray]: 38 | """Auxiliary function of radial_and_tangential_undistort().""" 39 | # let r(x, y) = x^2 + y^2; 40 | # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3; 41 | r = x * x + y * y 42 | d = 1.0 + r * (k1 + r * (k2 + k3 * r)) 43 | 44 | # The perfect projection is: 45 | # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2); 46 | # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2); 47 | # 48 | # Let's define 49 | # 50 | # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd; 51 | # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd; 52 | # 53 | # We are looking for a solution that satisfies 54 | # fx(x, y) = fy(x, y) = 0; 55 | fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd 56 | fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd 57 | 58 | # Compute derivative of d over [x, y] 59 | d_r = (k1 + r * (2.0 * k2 + 3.0 * k3 * r)) 60 | d_x = 2.0 * x * d_r 61 | d_y = 2.0 * y * d_r 62 | 63 | # Compute derivative of fx over x and y. 64 | fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x 65 | fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y 66 | 67 | # Compute derivative of fy over x and y. 68 | fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x 69 | fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y 70 | 71 | return fx, fy, fx_x, fx_y, fy_x, fy_y 72 | 73 | 74 | def _radial_and_tangential_undistort( 75 | xd: tnp.ndarray, 76 | yd: tnp.ndarray, 77 | k1: float = 0, 78 | k2: float = 0, 79 | k3: float = 0, 80 | p1: float = 0, 81 | p2: float = 0, 82 | eps: float = 1e-9, 83 | max_iterations=10) -> Tuple[tnp.ndarray, tnp.ndarray]: 84 | """Computes undistorted (x, y) from (xd, yd).""" 85 | # Initialize from the distorted point. 86 | x = xd 87 | y = yd 88 | 89 | for _ in range(max_iterations): 90 | fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian( 91 | x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, p1=p1, p2=p2) 92 | denominator = fy_x * fx_y - fx_x * fy_y 93 | x_numerator = fx * fy_y - fy * fx_y 94 | y_numerator = fy * fx_x - fx * fy_x 95 | step_x = tnp.where( 96 | tnp.abs(denominator) > eps, x_numerator / denominator, 97 | tnp.zeros_like(denominator)) 98 | step_y = tnp.where( 99 | tnp.abs(denominator) > eps, y_numerator / denominator, 100 | tnp.zeros_like(denominator)) 101 | 102 | x = x + step_x 103 | y = y + step_y 104 | 105 | return x, y 106 | 107 | 108 | class TFCamera: 109 | """A duplicate of our JAX-basded camera class. 110 | 111 | This is necessary to use tf.data.Dataset. 112 | """ 113 | 114 | def __init__(self, 115 | orientation: tnp.ndarray, 116 | position: tnp.ndarray, 117 | focal_length: float, 118 | principal_point: tnp.ndarray, 119 | image_size: tnp.ndarray, 120 | skew: float = 0.0, 121 | pixel_aspect_ratio: float = 1.0, 122 | radial_distortion: Optional[tnp.ndarray] = None, 123 | tangential_distortion: Optional[tnp.ndarray] = None, 124 | dtype=tnp.float32): 125 | """Constructor for camera class.""" 126 | if radial_distortion is None: 127 | radial_distortion = tnp.array([0.0, 0.0, 0.0], dtype) 128 | if tangential_distortion is None: 129 | tangential_distortion = tnp.array([0.0, 0.0], dtype) 130 | 131 | self.orientation = tnp.array(orientation, dtype) 132 | self.position = tnp.array(position, dtype) 133 | self.focal_length = tnp.array(focal_length, dtype) 134 | self.principal_point = tnp.array(principal_point, dtype) 135 | self.skew = tnp.array(skew, dtype) 136 | self.pixel_aspect_ratio = tnp.array(pixel_aspect_ratio, dtype) 137 | self.radial_distortion = tnp.array(radial_distortion, dtype) 138 | self.tangential_distortion = tnp.array(tangential_distortion, dtype) 139 | self.image_size = tnp.array(image_size, dtype) 140 | self.dtype = dtype 141 | 142 | @property 143 | def scale_factor_x(self): 144 | return self.focal_length 145 | 146 | @property 147 | def scale_factor_y(self): 148 | return self.focal_length * self.pixel_aspect_ratio 149 | 150 | @property 151 | def principal_point_x(self): 152 | return self.principal_point[0] 153 | 154 | @property 155 | def principal_point_y(self): 156 | return self.principal_point[1] 157 | 158 | @property 159 | def image_size_y(self): 160 | return self.image_size[1] 161 | 162 | @property 163 | def image_size_x(self): 164 | return self.image_size[0] 165 | 166 | @property 167 | def image_shape(self): 168 | return self.image_size_y, self.image_size_x 169 | 170 | @property 171 | def optical_axis(self): 172 | return self.orientation[2, :] 173 | 174 | def pixel_to_local_rays(self, pixels: tnp.ndarray): 175 | """Returns the local ray directions for the provided pixels.""" 176 | y = ((pixels[..., 1] - self.principal_point_y) / self.scale_factor_y) 177 | x = ((pixels[..., 0] - self.principal_point_x - y * self.skew) / 178 | self.scale_factor_x) 179 | 180 | x, y = _radial_and_tangential_undistort( 181 | x, 182 | y, 183 | k1=self.radial_distortion[0], 184 | k2=self.radial_distortion[1], 185 | k3=self.radial_distortion[2], 186 | p1=self.tangential_distortion[0], 187 | p2=self.tangential_distortion[1]) 188 | 189 | dirs = tnp.stack([x, y, tnp.ones_like(x)], axis=-1) 190 | return dirs / _norm(dirs) 191 | 192 | def pixels_to_rays(self, 193 | pixels: tnp.ndarray) -> Tuple[tnp.ndarray, tnp.ndarray]: 194 | """Returns the rays for the provided pixels. 195 | 196 | Args: 197 | pixels: [A1, ..., An, 2] tensor or np.array containing 2d pixel positions. 198 | 199 | Returns: 200 | An array containing the normalized ray directions in world coordinates. 201 | """ 202 | if pixels.shape[-1] != 2: 203 | raise ValueError('The last dimension of pixels must be 2.') 204 | if pixels.dtype != self.dtype: 205 | raise ValueError(f'pixels dtype ({pixels.dtype!r}) must match camera ' 206 | f'dtype ({self.dtype!r})') 207 | 208 | local_rays_dir = self.pixel_to_local_rays(pixels) 209 | rays_dir = tf.linalg.matvec( 210 | self.orientation, local_rays_dir, transpose_a=True) 211 | 212 | # Normalize rays. 213 | rays_dir = rays_dir / _norm(rays_dir) 214 | return rays_dir 215 | 216 | def pixels_to_points(self, pixels: tnp.ndarray, depth: tnp.ndarray): 217 | rays_through_pixels = self.pixels_to_rays(pixels) 218 | cosa = rays_through_pixels @ self.optical_axis 219 | points = ( 220 | rays_through_pixels * depth[..., tnp.newaxis] / cosa[..., tnp.newaxis] + 221 | self.position) 222 | return points 223 | 224 | def points_to_local_points(self, points: tnp.ndarray): 225 | translated_points = points - self.position 226 | local_points = (self.orientation @ translated_points.T).T 227 | return local_points 228 | 229 | def get_pixel_centers(self): 230 | """Returns the pixel centers.""" 231 | xx, yy = tf.meshgrid(tf.range(self.image_size_x), 232 | tf.range(self.image_size_y)) 233 | return tf.cast(tf.stack([xx, yy], axis=-1), self.dtype) + 0.5 234 | -------------------------------------------------------------------------------- /hypernerf/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to training NeRFs.""" 16 | import functools 17 | from typing import Any, Callable, Dict 18 | 19 | from absl import logging 20 | import flax 21 | from flax import struct 22 | from flax import traverse_util 23 | from flax.training import checkpoints 24 | import jax 25 | from jax import lax 26 | from jax import numpy as jnp 27 | from jax import random 28 | from jax import vmap 29 | 30 | from hypernerf import model_utils 31 | from hypernerf import models 32 | from hypernerf import utils 33 | 34 | 35 | @struct.dataclass 36 | class ScalarParams: 37 | """Scalar parameters for training.""" 38 | learning_rate: float 39 | elastic_loss_weight: float = 0.0 40 | warp_reg_loss_weight: float = 0.0 41 | warp_reg_loss_alpha: float = -2.0 42 | warp_reg_loss_scale: float = 0.001 43 | background_loss_weight: float = 0.0 44 | background_noise_std: float = 0.001 45 | hyper_reg_loss_weight: float = 0.0 46 | 47 | 48 | def save_checkpoint(path, state, keep=2): 49 | """Save the state to a checkpoint.""" 50 | state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state)) 51 | step = state_to_save.optimizer.state.step 52 | checkpoint_path = checkpoints.save_checkpoint( 53 | path, state_to_save, step, keep=keep) 54 | logging.info('Saved checkpoint: step=%d, path=%s', int(step), checkpoint_path) 55 | return checkpoint_path 56 | 57 | 58 | def zero_adam_param_states(state: flax.optim.OptimizerState, selector: str): 59 | """Applies a gradient for a set of parameters. 60 | 61 | Args: 62 | state: a named tuple containing the state of the optimizer 63 | selector: a path string defining which parameters to freeze. 64 | 65 | Returns: 66 | A tuple containing the new parameters and the new optimizer state. 67 | """ 68 | step = state.step 69 | params = flax.core.unfreeze(state.param_states) 70 | flat_params = {'/'.join(k): v 71 | for k, v in traverse_util.flatten_dict(params).items()} 72 | for k in flat_params: 73 | if k.startswith(selector): 74 | v = flat_params[k] 75 | # pylint: disable=protected-access 76 | flat_params[k] = flax.optim.adam._AdamParamState( 77 | jnp.zeros_like(v.grad_ema), jnp.zeros_like(v.grad_sq_ema)) 78 | 79 | new_param_states = traverse_util.unflatten_dict( 80 | {tuple(k.split('/')): v for k, v in flat_params.items()}) 81 | new_param_states = dict(flax.core.freeze(new_param_states)) 82 | new_state = flax.optim.OptimizerState(step, new_param_states) 83 | return new_state 84 | 85 | 86 | @jax.jit 87 | def nearest_rotation_svd(matrix, eps=1e-6): 88 | """Computes the nearest rotation using SVD.""" 89 | # TODO(keunhong): Currently this produces NaNs for some reason. 90 | u, _, vh = jnp.linalg.svd(matrix + eps, compute_uv=True, full_matrices=False) 91 | # Handle the case when there is a flip. 92 | # M will be the identity matrix except when det(UV^T) = -1 93 | # in which case the last diagonal of M will be -1. 94 | det = jnp.linalg.det(utils.matmul(u, vh)) 95 | m = jnp.stack([jnp.ones_like(det), jnp.ones_like(det), det], axis=-1) 96 | m = jnp.diag(m) 97 | r = utils.matmul(u, utils.matmul(m, vh)) 98 | return r 99 | 100 | 101 | def compute_elastic_loss(jacobian, eps=1e-6, loss_type='log_svals'): 102 | """Compute the elastic regularization loss. 103 | 104 | The loss is given by sum(log(S)^2). This penalizes the singular values 105 | when they deviate from the identity since log(1) = 0.0, 106 | where D is the diagonal matrix containing the singular values. 107 | 108 | Args: 109 | jacobian: the Jacobian of the point transformation. 110 | eps: a small value to prevent taking the log of zero. 111 | loss_type: which elastic loss type to use. 112 | 113 | Returns: 114 | The elastic regularization loss. 115 | """ 116 | if loss_type == 'log_svals': 117 | svals = jnp.linalg.svd(jacobian, compute_uv=False) 118 | log_svals = jnp.log(jnp.maximum(svals, eps)) 119 | sq_residual = jnp.sum(log_svals**2, axis=-1) 120 | elif loss_type == 'svals': 121 | svals = jnp.linalg.svd(jacobian, compute_uv=False) 122 | sq_residual = jnp.sum((svals - 1.0)**2, axis=-1) 123 | elif loss_type == 'jtj': 124 | jtj = utils.matmul(jacobian, jacobian.T) 125 | sq_residual = ((jtj - jnp.eye(3)) ** 2).sum() / 4.0 126 | elif loss_type == 'div': 127 | div = utils.jacobian_to_div(jacobian) 128 | sq_residual = div ** 2 129 | elif loss_type == 'det': 130 | det = jnp.linalg.det(jacobian) 131 | sq_residual = (det - 1.0) ** 2 132 | elif loss_type == 'log_det': 133 | det = jnp.linalg.det(jacobian) 134 | sq_residual = jnp.log(jnp.maximum(det, eps)) ** 2 135 | elif loss_type == 'nr': 136 | rot = nearest_rotation_svd(jacobian) 137 | sq_residual = jnp.sum((jacobian - rot) ** 2) 138 | else: 139 | raise NotImplementedError( 140 | f'Unknown elastic loss type {loss_type!r}') 141 | residual = jnp.sqrt(sq_residual) 142 | loss = utils.general_loss_with_squared_residual( 143 | sq_residual, alpha=-2.0, scale=0.03) 144 | return loss, residual 145 | 146 | 147 | @functools.partial(jax.jit, static_argnums=0) 148 | def compute_background_loss(model, state, params, key, points, noise_std, 149 | alpha=-2, scale=0.001): 150 | """Compute the background regularization loss.""" 151 | metadata = random.choice(key, model.warp_embeds, shape=(points.shape[0], 1)) 152 | point_noise = noise_std * random.normal(key, points.shape) 153 | points = points + point_noise 154 | warp_fn = functools.partial(model.apply, method=model.apply_warp) 155 | warp_fn = jax.vmap(warp_fn, in_axes=(None, 0, 0, None)) 156 | warp_out = warp_fn({'params': params}, points, metadata, state.extra_params) 157 | warped_points = warp_out['warped_points'][..., :3] 158 | sq_residual = jnp.sum((warped_points - points)**2, axis=-1) 159 | loss = utils.general_loss_with_squared_residual( 160 | sq_residual, alpha=alpha, scale=scale) 161 | return loss 162 | 163 | 164 | @functools.partial(jax.jit, 165 | static_argnums=0, 166 | static_argnames=('disable_hyper_grads', 167 | 'grad_max_val', 168 | 'grad_max_norm', 169 | 'use_elastic_loss', 170 | 'elastic_reduce_method', 171 | 'elastic_loss_type', 172 | 'use_background_loss', 173 | 'use_warp_reg_loss', 174 | 'use_hyper_reg_loss')) 175 | def train_step(model: models.NerfModel, 176 | rng_key: Callable[[int], jnp.ndarray], 177 | state: model_utils.TrainState, 178 | batch: Dict[str, Any], 179 | scalar_params: ScalarParams, 180 | disable_hyper_grads: bool = False, 181 | grad_max_val: float = 0.0, 182 | grad_max_norm: float = 0.0, 183 | use_elastic_loss: bool = False, 184 | elastic_reduce_method: str = 'median', 185 | elastic_loss_type: str = 'log_svals', 186 | use_background_loss: bool = False, 187 | use_warp_reg_loss: bool = False, 188 | use_hyper_reg_loss: bool = False): 189 | """One optimization step. 190 | 191 | Args: 192 | model: the model module to evaluate. 193 | rng_key: The random number generator. 194 | state: model_utils.TrainState, state of model and optimizer. 195 | batch: dict. A mini-batch of data for training. 196 | scalar_params: scalar-valued parameters. 197 | disable_hyper_grads: if True disable gradients to the hyper coordinate 198 | branches. 199 | grad_max_val: The gradient clipping value (disabled if == 0). 200 | grad_max_norm: The gradient clipping magnitude (disabled if == 0). 201 | use_elastic_loss: is True use the elastic regularization loss. 202 | elastic_reduce_method: which method to use to reduce the samples for the 203 | elastic loss. 'median' selects the median depth point sample while 204 | 'weight' computes a weighted sum using the density weights. 205 | elastic_loss_type: which method to use for the elastic loss. 206 | use_background_loss: if True use the background regularization loss. 207 | use_warp_reg_loss: if True use the warp regularization loss. 208 | use_hyper_reg_loss: if True regularize the hyper points. 209 | 210 | Returns: 211 | new_state: model_utils.TrainState, new training state. 212 | stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)]. 213 | """ 214 | rng_key, fine_key, coarse_key, reg_key = random.split(rng_key, 4) 215 | 216 | # pylint: disable=unused-argument 217 | def _compute_loss_and_stats( 218 | params, model_out, level, 219 | use_elastic_loss=False, 220 | use_hyper_reg_loss=False): 221 | 222 | if 'channel_set' in batch['metadata']: 223 | num_sets = int(model_out['rgb'].shape[-1] / 3) 224 | losses = [] 225 | for i in range(num_sets): 226 | loss = (model_out['rgb'][..., i * 3:(i + 1) * 3] - batch['rgb'])**2 227 | loss *= (batch['metadata']['channel_set'] == i) 228 | losses.append(loss) 229 | rgb_loss = jnp.sum(jnp.asarray(losses), axis=0).mean() 230 | else: 231 | rgb_loss = ((model_out['rgb'][..., :3] - batch['rgb'][..., :3])**2).mean() 232 | stats = { 233 | 'loss/rgb': rgb_loss, 234 | } 235 | loss = rgb_loss 236 | if use_elastic_loss: 237 | elastic_fn = functools.partial(compute_elastic_loss, 238 | loss_type=elastic_loss_type) 239 | v_elastic_fn = jax.jit(vmap(vmap(jax.jit(elastic_fn)))) 240 | weights = lax.stop_gradient(model_out['weights']) 241 | jacobian = model_out['warp_jacobian'] 242 | # Pick the median point Jacobian. 243 | if elastic_reduce_method == 'median': 244 | depth_indices = model_utils.compute_depth_index(weights) 245 | jacobian = jnp.take_along_axis( 246 | # Unsqueeze axes: sample axis, Jacobian row, Jacobian col. 247 | jacobian, depth_indices[..., None, None, None], axis=-3) 248 | # Compute loss using Jacobian. 249 | elastic_loss, elastic_residual = v_elastic_fn(jacobian) 250 | # Multiply weight if weighting by density. 251 | if elastic_reduce_method == 'weight': 252 | elastic_loss = weights * elastic_loss 253 | elastic_loss = elastic_loss.sum(axis=-1).mean() 254 | stats['loss/elastic'] = elastic_loss 255 | stats['residual/elastic'] = jnp.mean(elastic_residual) 256 | loss += scalar_params.elastic_loss_weight * elastic_loss 257 | 258 | if use_warp_reg_loss: 259 | weights = lax.stop_gradient(model_out['weights']) 260 | depth_indices = model_utils.compute_depth_index(weights) 261 | warp_mag = ((model_out['points'] 262 | - model_out['warped_points'][..., :3]) ** 2).sum(axis=-1) 263 | warp_reg_residual = jnp.take_along_axis( 264 | warp_mag, depth_indices[..., None], axis=-1) 265 | warp_reg_loss = utils.general_loss_with_squared_residual( 266 | warp_reg_residual, 267 | alpha=scalar_params.warp_reg_loss_alpha, 268 | scale=scalar_params.warp_reg_loss_scale).mean() 269 | stats['loss/warp_reg'] = warp_reg_loss 270 | stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual)) 271 | loss += scalar_params.warp_reg_loss_weight * warp_reg_loss 272 | 273 | if use_hyper_reg_loss: 274 | weights = lax.stop_gradient(model_out['weights']) 275 | hyper_points = model_out['warped_points'][..., 3:] 276 | hyper_reg_residual = (hyper_points ** 2).sum(axis=-1) 277 | hyper_reg_loss = utils.general_loss_with_squared_residual( 278 | hyper_reg_residual, alpha=0.0, scale=0.05) 279 | assert weights.shape == hyper_reg_loss.shape 280 | hyper_reg_loss = (weights * hyper_reg_loss).sum(axis=1).mean() 281 | stats['loss/hyper_reg'] = hyper_reg_loss 282 | stats['residual/hyper_reg'] = jnp.mean(jnp.sqrt(hyper_reg_residual)) 283 | loss += scalar_params.hyper_reg_loss_weight * hyper_reg_loss 284 | 285 | if 'warp_jacobian' in model_out: 286 | jacobian = model_out['warp_jacobian'] 287 | jacobian_det = jnp.linalg.det(jacobian) 288 | jacobian_div = utils.jacobian_to_div(jacobian) 289 | jacobian_curl = utils.jacobian_to_curl(jacobian) 290 | stats['metric/jacobian_det'] = jnp.mean(jacobian_det) 291 | stats['metric/jacobian_div'] = jnp.mean(jacobian_div) 292 | stats['metric/jacobian_curl'] = jnp.mean( 293 | jnp.linalg.norm(jacobian_curl, axis=-1)) 294 | 295 | stats['loss/total'] = loss 296 | stats['metric/psnr'] = utils.compute_psnr(rgb_loss) 297 | return loss, stats 298 | 299 | def _loss_fn(params): 300 | ret = model.apply({'params': params['model']}, 301 | batch, 302 | extra_params=state.extra_params, 303 | return_points=(use_warp_reg_loss or use_hyper_reg_loss), 304 | return_weights=(use_warp_reg_loss or use_elastic_loss), 305 | return_warp_jacobian=use_elastic_loss, 306 | rngs={ 307 | 'fine': fine_key, 308 | 'coarse': coarse_key 309 | }) 310 | 311 | losses = {} 312 | stats = {} 313 | if 'fine' in ret: 314 | losses['fine'], stats['fine'] = _compute_loss_and_stats( 315 | params, ret['fine'], 'fine') 316 | if 'coarse' in ret: 317 | losses['coarse'], stats['coarse'] = _compute_loss_and_stats( 318 | params, ret['coarse'], 'coarse', 319 | use_elastic_loss=use_elastic_loss, 320 | use_hyper_reg_loss=use_hyper_reg_loss) 321 | 322 | if use_background_loss: 323 | background_loss = compute_background_loss( 324 | model, 325 | state=state, 326 | params=params['model'], 327 | key=reg_key, 328 | points=batch['background_points'], 329 | noise_std=scalar_params.background_noise_std) 330 | background_loss = background_loss.mean() 331 | losses['background'] = ( 332 | scalar_params.background_loss_weight * background_loss) 333 | stats['background_loss'] = background_loss 334 | 335 | return sum(losses.values()), (stats, ret) 336 | 337 | optimizer = state.optimizer 338 | if disable_hyper_grads: 339 | optimizer = optimizer.replace( 340 | state=zero_adam_param_states(optimizer.state, 'model/hyper_sheet_mlp')) 341 | 342 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 343 | (_, (stats, model_out)), grad = grad_fn(optimizer.target) 344 | grad = jax.lax.pmean(grad, axis_name='batch') 345 | if grad_max_val > 0.0 or grad_max_norm > 0.0: 346 | grad = utils.clip_gradients(grad, grad_max_val, grad_max_norm) 347 | stats = jax.lax.pmean(stats, axis_name='batch') 348 | model_out = jax.lax.pmean(model_out, axis_name='batch') 349 | new_optimizer = optimizer.apply_gradient( 350 | grad, learning_rate=scalar_params.learning_rate) 351 | new_state = state.replace(optimizer=new_optimizer) 352 | return new_state, stats, rng_key, model_out 353 | -------------------------------------------------------------------------------- /hypernerf/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Custom type annotations.""" 16 | import pathlib 17 | from typing import Any, Callable, Tuple, Text, Union 18 | 19 | PRNGKey = Any 20 | Shape = Tuple[int] 21 | Dtype = Any # this could be a real type? 22 | Array = Any 23 | 24 | Activation = Callable[[Array], Array] 25 | Initializer = Callable[[PRNGKey, Shape, Dtype], Array] 26 | Normalizer = Callable[[], Callable[[Array], Array]] 27 | 28 | PathType = Union[Text, pathlib.PurePosixPath] 29 | -------------------------------------------------------------------------------- /hypernerf/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Non-differentiable utility functions.""" 17 | import collections 18 | from concurrent import futures 19 | import contextlib 20 | import functools 21 | import time 22 | from typing import List, Union 23 | 24 | import jax 25 | from jax import tree_util 26 | import jax.numpy as jnp 27 | import numpy as np 28 | from scipy import interpolate 29 | from scipy.spatial import transform as scipy_transform 30 | 31 | 32 | def clip_gradients(grad, grad_max_val=0.0, grad_max_norm=0.0, eps=1e-7): 33 | """Gradient clipping.""" 34 | # Clip the gradient by value. 35 | if grad_max_val > 0: 36 | clip_fn = lambda z: jnp.clip(z, -grad_max_val, grad_max_val) 37 | grad = jax.tree_util.tree_map(clip_fn, grad) 38 | 39 | # Clip the (possibly value-clipped) gradient by norm. 40 | if grad_max_norm > 0: 41 | grad_norm = safe_sqrt( 42 | jax.tree_util.tree_reduce( 43 | lambda x, y: x + jnp.sum(y**2), grad, initializer=0)) 44 | mult = jnp.minimum(1, grad_max_norm / (eps + grad_norm)) 45 | grad = jax.tree_util.tree_map(lambda z: mult * z, grad) 46 | 47 | return grad 48 | 49 | 50 | def matmul(a, b): 51 | """jnp.matmul defaults to bfloat16, but this helper function doesn't.""" 52 | return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST) 53 | 54 | 55 | # pylint: disable=unused-argument 56 | @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 3)) 57 | def safe_norm(x, axis=-1, keepdims=False, tol=1e-9): 58 | """Calculates a np.linalg.norm(d) that's safe for gradients at d=0. 59 | 60 | These gymnastics are to avoid a poorly defined gradient for np.linal.norm(0) 61 | see https://github.com/google/jax/issues/3058 for details 62 | 63 | 64 | Args: 65 | x: A np.array 66 | axis: The axis along which to compute the norm 67 | keepdims: if True don't squeeze the axis. 68 | tol: the absolute threshold within which to zero out the gradient. 69 | 70 | Returns: 71 | Equivalent to np.linalg.norm(d) 72 | """ 73 | return jnp.linalg.norm(x, axis=axis, keepdims=keepdims) 74 | 75 | 76 | @safe_norm.defjvp 77 | def _safe_norm_jvp(axis, keepdims, tol, primals, tangents): 78 | """Custom JVP rule for safe_norm.""" 79 | x, = primals 80 | x_dot, = tangents 81 | safe_tol = max(tol, 1e-30) 82 | y = safe_norm(x, tol=safe_tol, axis=axis, keepdims=True) 83 | y_safe = jnp.maximum(y, tol) # Prevent divide by zero. 84 | y_dot = jnp.where(y > safe_tol, x_dot * x / y_safe, jnp.zeros_like(x)) 85 | y_dot = jnp.sum(y_dot, axis=axis, keepdims=True) 86 | # Squeeze the axis if `keepdims` is True. 87 | if not keepdims: 88 | y = jnp.squeeze(y, axis=axis) 89 | y_dot = jnp.squeeze(y_dot, axis=axis) 90 | return y, y_dot 91 | 92 | 93 | def jacobian_to_curl(jacobian): 94 | """Computes the curl from the Jacobian.""" 95 | dfx_dy = jacobian[..., 0, 1] 96 | dfx_dz = jacobian[..., 0, 2] 97 | dfy_dx = jacobian[..., 1, 0] 98 | dfy_dz = jacobian[..., 1, 2] 99 | dfz_dx = jacobian[..., 2, 0] 100 | dfz_dy = jacobian[..., 2, 1] 101 | 102 | return jnp.stack([ 103 | dfz_dy - dfy_dz, 104 | dfx_dz - dfz_dx, 105 | dfy_dx - dfx_dy, 106 | ], axis=-1) 107 | 108 | 109 | def jacobian_to_div(jacobian): 110 | """Computes the divergence from the Jacobian.""" 111 | # If F : x -> x + f(x) then dF/dx = 1 + df/dx, so subtract 1 for each 112 | # diagonal of the Jacobian. 113 | return jnp.trace(jacobian, axis1=-2, axis2=-1) - 3.0 114 | 115 | 116 | def compute_psnr(mse): 117 | """Compute psnr value given mse (we assume the maximum pixel value is 1). 118 | 119 | Args: 120 | mse: float, mean square error of pixels. 121 | 122 | Returns: 123 | psnr: float, the psnr value. 124 | """ 125 | return -10. * jnp.log(mse) / jnp.log(10.) 126 | 127 | 128 | @jax.jit 129 | def robust_whiten(x): 130 | median = jnp.nanmedian(x) 131 | mad = jnp.nanmean(jnp.abs(x - median)) 132 | return (x - median) / mad 133 | 134 | 135 | def interpolate_codes(codes: Union[np.ndarray, List[np.ndarray]], 136 | num_samples: int, 137 | method='spline', 138 | bc_type='natural'): 139 | """Interpolates latent codes. 140 | 141 | Args: 142 | codes: the codes to interpolate. 143 | num_samples: the number of samples to interpolate to. 144 | method: which method to use for interpolation. 145 | bc_type: interpolation type for spline interpolation. 146 | 147 | Returns: 148 | (np.ndarray): the interpolated codes. 149 | """ 150 | if isinstance(codes, list): 151 | codes = np.array(codes) 152 | t = np.arange(len(codes)) 153 | xs = np.linspace(0, len(codes) - 1, num_samples) 154 | if method == 'spline': 155 | cs = interpolate.CubicSpline(t, codes, bc_type=bc_type) 156 | return cs(xs).astype(np.float32) 157 | elif method in {'linear', 'cubic', 'quadratic', 'slinear'}: 158 | interp = interpolate.interp1d(t, codes, axis=0) 159 | return interp(xs).astype(np.float32) 160 | 161 | raise ValueError(f'Unknown method {method!r}') 162 | 163 | 164 | def interpolate_cameras(cameras, num_samples: int): 165 | """Interpolates the cameras to the number of output samples. 166 | 167 | Uses a spherical linear interpolation (Slerp) to interpolate the camera 168 | orientations and a cubic spline to interpolate the camera positions. 169 | 170 | Args: 171 | cameras: the input cameras to interpolate. 172 | num_samples: the number of output cameras. 173 | 174 | Returns: 175 | (List[vision_sfm.Camera]): a list of interpolated cameras. 176 | """ 177 | rotations = [] 178 | positions = [] 179 | for camera in cameras: 180 | rotations.append(camera.orientation) 181 | positions.append(camera.position) 182 | 183 | in_times = np.linspace(0, 1, len(rotations)) 184 | slerp = scipy_transform.Slerp( 185 | in_times, scipy_transform.Rotation.from_dcm(rotations)) 186 | spline = interpolate.CubicSpline(in_times, positions) 187 | 188 | out_times = np.linspace(0, 1, num_samples) 189 | out_rots = slerp(out_times).as_dcm() 190 | out_positions = spline(out_times) 191 | 192 | ref_camera = cameras[0] 193 | out_cameras = [] 194 | for out_rot, out_pos in zip(out_rots, out_positions): 195 | out_camera = ref_camera.copy() 196 | out_camera.orientation = out_rot 197 | out_camera.position = out_pos 198 | out_cameras.append(out_camera) 199 | return out_cameras 200 | 201 | 202 | def safe_sqrt(x, eps=1e-7): 203 | safe_x = jnp.where(x == 0, jnp.ones_like(x) * eps, x) 204 | return jnp.sqrt(safe_x) 205 | 206 | 207 | @jax.jit 208 | def general_loss_with_squared_residual(x_sq, alpha, scale): 209 | r"""Implements the general form of the loss. 210 | 211 | This implements the rho(x, \alpha, c) function described in "A General and 212 | Adaptive Robust Loss Function", Jonathan T. Barron, 213 | https://arxiv.org/abs/1701.03077. 214 | Args: 215 | x_sq: The residual for which the loss is being computed. x can have any 216 | shape, and alpha and scale will be broadcasted to match x's shape if 217 | necessary. 218 | alpha: The shape parameter of the loss (\alpha in the paper), where more 219 | negative values produce a loss with more robust behavior (outliers "cost" 220 | less), and more positive values produce a loss with less robust behavior 221 | (outliers are penalized more heavily). Alpha can be any value in 222 | [-infinity, infinity], but the gradient of the loss with respect to alpha 223 | is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth 224 | interpolation between several discrete robust losses: 225 | alpha=-Infinity: Welsch/Leclerc Loss. 226 | alpha=-2: Geman-McClure loss. 227 | alpha=0: Cauchy/Lortentzian loss. 228 | alpha=1: Charbonnier/pseudo-Huber loss. 229 | alpha=2: L2 loss. 230 | scale: The scale parameter of the loss. When |x| < scale, the loss is an 231 | L2-like quadratic bowl, and when |x| > scale the loss function takes on a 232 | different shape according to alpha. 233 | 234 | Returns: 235 | The losses for each element of x, in the same shape as x. 236 | """ 237 | eps = jnp.finfo(jnp.float32).eps 238 | 239 | # `scale` must be > 0. 240 | scale = jnp.maximum(eps, scale) 241 | 242 | # The loss when alpha == 2. This will get reused repeatedly. 243 | loss_two = 0.5 * x_sq / (scale**2) 244 | 245 | # "Safe" versions of log1p and expm1 that will not NaN-out. 246 | log1p_safe = lambda x: jnp.log1p(jnp.minimum(x, 3e37)) 247 | expm1_safe = lambda x: jnp.expm1(jnp.minimum(x, 87.5)) 248 | 249 | # The loss when not in one of the special casess. 250 | # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by. 251 | a = jnp.where(alpha >= 0, jnp.ones_like(alpha), 252 | -jnp.ones_like(alpha)) * jnp.maximum(eps, jnp.abs(alpha)) 253 | # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by. 254 | b = jnp.maximum(eps, jnp.abs(alpha - 2)) 255 | loss_ow = (b / a) * ((loss_two / (0.5 * b) + 1)**(0.5 * alpha) - 1) 256 | 257 | # Select which of the cases of the loss to return as a function of alpha. 258 | return scale * jnp.where( 259 | alpha == -jnp.inf, -expm1_safe(-loss_two), 260 | jnp.where( 261 | alpha == 0, log1p_safe(loss_two), 262 | jnp.where(alpha == 2, loss_two, 263 | jnp.where(alpha == jnp.inf, expm1_safe(loss_two), 264 | loss_ow)))) 265 | 266 | 267 | def points_bound(points): 268 | """Computes the min and max dims of the points.""" 269 | min_dim = np.min(points, axis=0) 270 | max_dim = np.max(points, axis=0) 271 | return np.stack((min_dim, max_dim), axis=1) 272 | 273 | 274 | def points_centroid(points): 275 | """Computes the centroid of the points from the bounding box.""" 276 | return points_bound(points).mean(axis=1) 277 | 278 | 279 | def points_bounding_size(points): 280 | """Computes the bounding size of the points from the bounding box.""" 281 | bounds = points_bound(points) 282 | return np.linalg.norm(bounds[:, 1] - bounds[:, 0]) 283 | 284 | 285 | def shard(xs, device_count=None): 286 | """Split data into shards for multiple devices along the first dimension.""" 287 | if device_count is None: 288 | device_count = jax.local_device_count() 289 | return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs) 290 | 291 | 292 | def to_device(xs): 293 | """Transfer data to devices (GPU/TPU).""" 294 | return jax.tree_map(jnp.array, xs) 295 | 296 | 297 | def unshard(x, padding=0): 298 | """Collect the sharded tensor to the shape before sharding.""" 299 | if padding > 0: 300 | return x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))[:-padding] 301 | else: 302 | return x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) 303 | 304 | 305 | def normalize(x): 306 | """Normalization helper function.""" 307 | return x / np.linalg.norm(x) 308 | 309 | 310 | def parallel_map(f, iterable, max_threads=None, show_pbar=False, **kwargs): 311 | """Parallel version of map().""" 312 | with futures.ThreadPoolExecutor(max_threads) as executor: 313 | if show_pbar: 314 | # pylint: disable=g-import-not-at-top 315 | import tqdm 316 | results = tqdm.tqdm( 317 | executor.map(f, iterable, **kwargs), total=len(iterable)) 318 | else: 319 | results = executor.map(f, iterable, **kwargs) 320 | return list(results) 321 | 322 | 323 | def parallel_tree_map(f, tree, **kwargs): 324 | """Parallel version of jax.tree_map.""" 325 | leaves, treedef = jax.tree_flatten(tree) 326 | results = parallel_map(f, leaves, **kwargs) 327 | return jax.tree_unflatten(treedef, results) 328 | 329 | 330 | def strided_subset(sequence, count): 331 | """Returns a strided subset of a list.""" 332 | if count: 333 | stride = max(1, len(sequence) // count) 334 | return sequence[::stride] 335 | return sequence 336 | 337 | 338 | def tree_collate(list_of_pytrees): 339 | """Collates a list of pytrees with the same structure.""" 340 | return tree_util.tree_multimap(lambda *x: np.stack(x), *list_of_pytrees) 341 | 342 | 343 | @contextlib.contextmanager 344 | def print_time(name): 345 | """Records the time elapsed.""" 346 | start = time.time() 347 | yield 348 | elapsed = time.time() - start 349 | print(f'[{name}] time elapsed: {elapsed:.04f}') 350 | 351 | 352 | class ValueMeter: 353 | """Tracks the average of a value.""" 354 | 355 | def __init__(self): 356 | self._values = [] 357 | 358 | def reset(self): 359 | """Resets the meter.""" 360 | self._values.clear() 361 | 362 | def update(self, value): 363 | """Adds a value to the meter.""" 364 | self._values.append(value) 365 | 366 | def reduce(self, reduction='mean'): 367 | """Reduces the tracked values.""" 368 | if reduction == 'mean': 369 | return np.mean(self._values) 370 | elif reduction == 'std': 371 | return np.std(self._values) 372 | elif reduction == 'last': 373 | return self._values[-1] 374 | else: 375 | raise ValueError(f'Unknown reduction {reduction}') 376 | 377 | 378 | class TimeTracker: 379 | """Tracks the average time elapsed over multiple steps.""" 380 | 381 | def __init__(self): 382 | self._meters = collections.defaultdict(ValueMeter) 383 | self._marked_time = collections.defaultdict(float) 384 | 385 | @contextlib.contextmanager 386 | def record_time(self, key: str): 387 | """Records the time elapsed.""" 388 | start = time.time() 389 | yield 390 | elapsed = time.time() - start 391 | self.update(key, elapsed) 392 | 393 | def update(self, key, value): 394 | """Updates the time value for a given key.""" 395 | self._meters[key].update(value) 396 | 397 | def tic(self, *args): 398 | """Marks the starting time of an event.""" 399 | for key in args: 400 | self._marked_time[key] = time.time() 401 | 402 | def toc(self, *args): 403 | """Records the time elapsed based on the previous call to `tic`.""" 404 | for key in args: 405 | self.update(key, time.time() - self._marked_time[key]) 406 | del self._marked_time[key] 407 | 408 | def reset(self): 409 | """Resets all time meters.""" 410 | for meter in self._meters.values(): 411 | meter.reset() 412 | 413 | def summary(self, reduction='mean'): 414 | """Returns a dictionary of reduced times.""" 415 | time_dict = {k: v.reduce(reduction) for k, v in self._meters.items()} 416 | if 'total' not in time_dict: 417 | time_dict['total'] = sum(time_dict.values()) 418 | 419 | time_dict['steps_per_sec'] = 1.0 / time_dict['total'] 420 | return time_dict 421 | 422 | def summary_str(self, reduction='mean'): 423 | """Returns a string of reduced times.""" 424 | strings = [f'{k}={v:.04f}' for k, v in self.summary(reduction).items()] 425 | return ', '.join(strings) 426 | -------------------------------------------------------------------------------- /hypernerf/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Visualization utilities.""" 16 | import contextlib 17 | import functools 18 | 19 | from matplotlib import cm 20 | from matplotlib import pyplot as plt 21 | from matplotlib.colors import LinearSegmentedColormap 22 | import numpy as np 23 | 24 | 25 | _TURBO_COLORS = np.array( 26 | [[0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149], 27 | [0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844], 28 | [0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314], 29 | [0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558], 30 | [0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578], 31 | [0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373], 32 | [0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942], 33 | [0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286], 34 | [0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406], 35 | [0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300], 36 | [0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968], 37 | [0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412], 38 | [0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631], 39 | [0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624], 40 | [0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393], 41 | [0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936], 42 | [0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254], 43 | [0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347], 44 | [0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214], 45 | [0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857], 46 | [0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275], 47 | [0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461], 48 | [0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303], 49 | [0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773], 50 | [0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896], 51 | [0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697], 52 | [0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202], 53 | [0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436], 54 | [0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423], 55 | [0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190], 56 | [0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761], 57 | [0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161], 58 | [0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416], 59 | [0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550], 60 | [0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590], 61 | [0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559], 62 | [0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484], 63 | [0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389], 64 | [0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299], 65 | [0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240], 66 | [0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237], 67 | [0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316], 68 | [0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500], 69 | [0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651], 70 | [0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627], 71 | [0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448], 72 | [0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137], 73 | [0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713], 74 | [0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199], 75 | [0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614], 76 | [0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981], 77 | [0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321], 78 | [0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654], 79 | [0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002], 80 | [0.36581, 0.98702, 0.43688], [0.38127, 0.98909, 0.42386], 81 | [0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826], 82 | [0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345], 83 | [0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963], 84 | [0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701], 85 | [0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581], 86 | [0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623], 87 | [0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849], 88 | [0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280], 89 | [0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937], 90 | [0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835], 91 | [0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960], 92 | [0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294], 93 | [0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815], 94 | [0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504], 95 | [0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343], 96 | [0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310], 97 | [0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386], 98 | [0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552], 99 | [0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788], 100 | [0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074], 101 | [0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391], 102 | [0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719], 103 | [0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038], 104 | [0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328], 105 | [0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570], 106 | [0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744], 107 | [0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831], 108 | [0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811], 109 | [0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663], 110 | [0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369], 111 | [0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918], 112 | [0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358], 113 | [0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706], 114 | [0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971], 115 | [0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165], 116 | [0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297], 117 | [0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376], 118 | [0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412], 119 | [0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417], 120 | [0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398], 121 | [0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367], 122 | [0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332], 123 | [0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305], 124 | [0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294], 125 | [0.96555, 0.42241, 0.09798], [0.96187, 0.41093, 0.09310], 126 | [0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362], 127 | [0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461], 128 | [0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616], 129 | [0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837], 130 | [0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134], 131 | [0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516], 132 | [0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993], 133 | [0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521], 134 | [0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082], 135 | [0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677], 136 | [0.84662, 0.21407, 0.02487], [0.83926, 0.20654, 0.02305], 137 | [0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966], 138 | [0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660], 139 | [0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387], 140 | [0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148], 141 | [0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942], 142 | [0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769], 143 | [0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629], 144 | [0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522], 145 | [0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449], 146 | [0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408], 147 | [0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401], 148 | [0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427], 149 | [0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486], 150 | [0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579], 151 | [0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705], 152 | [0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863], 153 | [0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055]]) 154 | 155 | _colormap_cache = {} 156 | 157 | 158 | def _build_colormap(name, num_bins=256): 159 | base = cm.get_cmap(name) 160 | color_list = base(np.linspace(0, 1, num_bins)) 161 | cmap_name = base.name + str(num_bins) 162 | colormap = LinearSegmentedColormap.from_list(cmap_name, color_list, num_bins) 163 | colormap = colormap(np.linspace(0, 1, num_bins))[:, :3] 164 | return colormap 165 | 166 | 167 | def sinebow(h): 168 | f = lambda x: np.sin(np.pi * x)**2 169 | return np.stack([f(3/6-h), f(5/6-h), f(7/6-h)], -1) 170 | 171 | 172 | @functools.lru_cache(maxsize=32) 173 | def get_colormap(name, num_bins=256): 174 | """Lazily initializes and returns a colormap.""" 175 | if name == 'turbo': 176 | return _TURBO_COLORS 177 | elif name == 'sinebow': 178 | c = np.array([sinebow(i) for i in np.linspace(0, 1, num_bins)]) 179 | return c 180 | 181 | return _build_colormap(name, num_bins) 182 | 183 | 184 | def interpolate_colormap(values, colormap): 185 | """Interpolates the colormap given values between 0.0 and 1.0.""" 186 | a = np.floor(values * 255) 187 | b = (a + 1).clip(max=255) 188 | f = values * 255.0 - a 189 | a = a.astype(np.uint16).clip(0, 255) 190 | b = b.astype(np.uint16).clip(0, 255) 191 | return colormap[a] + (colormap[b] - colormap[a]) * f[..., np.newaxis] 192 | 193 | 194 | def scale_values(values, vmin, vmax, eps=1e-6): 195 | return (values - vmin) / max(vmax - vmin, eps) 196 | 197 | 198 | def colorize(array, 199 | cmin=None, 200 | cmax=None, 201 | cmap='magma', 202 | eps=1e-6, 203 | invert=False, 204 | clip=False): 205 | """Applies a colormap to an array. 206 | 207 | Args: 208 | array: the array to apply a colormap to. 209 | cmin: the minimum value of the colormap. If None will take the min. 210 | cmax: the maximum value of the colormap. If None will take the max. 211 | cmap: the color mapping to use. 212 | eps: a small value to prevent divide by zero. 213 | invert: if True will invert the colormap. 214 | clip: if True, clip values instead of setting to white/black. 215 | 216 | Returns: 217 | a color mapped version of array. 218 | """ 219 | array = np.asarray(array) 220 | 221 | if cmin is None: 222 | cmin = array.min() 223 | if cmax is None: 224 | cmax = array.max() 225 | 226 | x = scale_values(array, cmin, cmax, eps) 227 | if clip: 228 | x = np.clip(x, a_min=0.0, a_max=1.0) 229 | colormap = get_colormap(cmap) 230 | colorized = interpolate_colormap(1.0 - x if invert else x, colormap) 231 | colorized[x > 1.0] = 0.0 if invert else 1.0 232 | colorized[x < 0.0] = 1.0 if invert else 0.0 233 | 234 | return colorized 235 | 236 | 237 | def colorize_binary_logits(array, cmap=None): 238 | """Colorizes binary logits as a segmentation map.""" 239 | num_classes = array.shape[-1] 240 | if cmap is None: 241 | if num_classes <= 8: 242 | cmap = 'Set3' 243 | elif num_classes <= 10: 244 | cmap = 'tab10' 245 | elif num_classes <= 20: 246 | cmap = 'tab20' 247 | else: 248 | cmap = 'gist_rainbow' 249 | 250 | colormap = get_colormap(cmap, num_classes) 251 | indices = np.argmax(array, axis=-1) 252 | return np.take(colormap, indices, axis=0) 253 | 254 | 255 | @contextlib.contextmanager 256 | def plot_to_array(height, width, rows=1, cols=1, dpi=100, no_axis=False, 257 | use_alpha=False): 258 | """A context manager that plots to a numpy array. 259 | 260 | When the context manager exits the output array will be populated with an 261 | image of the plot. 262 | 263 | Usage: 264 | ``` 265 | with plot_to_array(480, 640, 2, 2) as (fig, axes, out_image): 266 | axes[0][0].plot(...) 267 | ``` 268 | Args: 269 | height: the height of the canvas 270 | width: the width of the canvas 271 | rows: the number of axis rows 272 | cols: the number of axis columns 273 | dpi: the DPI to render at 274 | no_axis: if True will hide the axes of the plot 275 | use_alpha: if True return RGBA images. 276 | 277 | Yields: 278 | A 3-tuple of: a pyplot Figure, array of Axes, and the output np.ndarray. 279 | """ 280 | num_channels = 4 if use_alpha else 3 281 | out_array = np.empty((height, width, num_channels), dtype=np.uint8) 282 | fig, axes = plt.subplots( 283 | rows, cols, figsize=(width / dpi, height / dpi), dpi=dpi) 284 | if no_axis: 285 | for ax in fig.axes: 286 | ax.margins(0, 0) 287 | ax.axis('off') 288 | ax.get_xaxis().set_visible(False) 289 | ax.get_yaxis().set_visible(False) 290 | 291 | yield fig, axes, out_array 292 | 293 | # If we haven't already shown or saved the plot, then we need to 294 | # draw the figure first... 295 | fig.tight_layout(pad=0) 296 | fig.canvas.draw() 297 | 298 | # Now we can save it to a numpy array. 299 | if use_alpha: 300 | data = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8) 301 | else: 302 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 303 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (num_channels,)) 304 | data = np.roll(data, -1, axis=-1) 305 | plt.close() 306 | 307 | np.copyto(out_array, data) 308 | -------------------------------------------------------------------------------- /hypernerf/warping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Warp fields.""" 16 | from typing import Any, Iterable, Optional, Dict 17 | 18 | from flax import linen as nn 19 | import gin 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from hypernerf import model_utils 24 | from hypernerf import utils 25 | from hypernerf import modules 26 | from hypernerf import rigid_body as rigid 27 | from hypernerf import types 28 | 29 | 30 | @gin.configurable(denylist=['name']) 31 | class TranslationField(nn.Module): 32 | """Network that predicts warps as a translation field. 33 | 34 | References: 35 | https://en.wikipedia.org/wiki/Vector_potential 36 | https://en.wikipedia.org/wiki/Helmholtz_decomposition 37 | 38 | Attributes: 39 | metadata_encoder: an encoder for metadata. 40 | alpha: the alpha for the positional encoding. 41 | skips: the index of the layers with skip connections. 42 | depth: the depth of the network excluding the output layer. 43 | hidden_channels: the width of the network hidden layers. 44 | activation: the activation for each layer. 45 | metadata_encoded: whether the metadata parameter is pre-encoded or not. 46 | hidden_initializer: the initializer for the hidden layers. 47 | output_initializer: the initializer for the last output layer. 48 | """ 49 | min_deg: int = 0 50 | max_deg: int = 8 51 | use_posenc_identity: bool = True 52 | 53 | skips: Iterable[int] = (4,) 54 | depth: int = 6 55 | hidden_channels: int = 128 56 | activation: types.Activation = nn.relu 57 | norm: Optional[Any] = None 58 | hidden_init: types.Initializer = jax.nn.initializers.glorot_uniform() 59 | output_init: types.Initializer = jax.nn.initializers.uniform(scale=1e-4) 60 | 61 | def setup(self): 62 | # Note that this must be done this way instead of using mutable list 63 | # operations. 64 | # See https://github.com/google/flax/issues/524. 65 | # pylint: disable=g-complex-comprehension 66 | output_dims = 3 67 | self.mlp = modules.MLP( 68 | width=self.hidden_channels, 69 | depth=self.depth, 70 | skips=self.skips, 71 | hidden_activation=self.activation, 72 | hidden_norm=self.norm, 73 | hidden_init=self.hidden_init, 74 | output_init=self.output_init, 75 | output_channels=output_dims) 76 | 77 | def warp(self, 78 | points: jnp.ndarray, 79 | metadata: jnp.ndarray, 80 | extra_params: Dict[str, Any]): 81 | points_embed = model_utils.posenc(points, 82 | min_deg=self.min_deg, 83 | max_deg=self.max_deg, 84 | use_identity=self.use_posenc_identity, 85 | alpha=extra_params['warp_alpha']) 86 | inputs = jnp.concatenate([points_embed, metadata], axis=-1) 87 | translation = self.mlp(inputs) 88 | warped_points = points + translation 89 | 90 | return warped_points 91 | 92 | def __call__(self, 93 | points: jnp.ndarray, 94 | metadata: jnp.ndarray, 95 | extra_params: Dict[str, Any], 96 | return_jacobian: bool = False): 97 | """Warp the given points using a warp field. 98 | 99 | Args: 100 | points: the points to warp. 101 | metadata: encoded metadata features. 102 | extra_params: extra parameters used in the warp field e.g., the warp 103 | alpha. 104 | return_jacobian: if True compute and return the Jacobian of the warp. 105 | 106 | Returns: 107 | The warped points and the Jacobian of the warp if `return_jacobian` is 108 | True. 109 | """ 110 | out = { 111 | 'warped_points': self.warp(points, metadata, extra_params) 112 | } 113 | 114 | if return_jacobian: 115 | jac_fn = jax.jacfwd(lambda *x: self.warp(*x)[..., :3], argnums=0) 116 | out['jacobian'] = jac_fn(points, metadata, extra_params) 117 | 118 | return out 119 | 120 | 121 | @gin.configurable(denylist=['name']) 122 | class SE3Field(nn.Module): 123 | """Network that predicts warps as an SE(3) field. 124 | 125 | Attributes: 126 | points_encoder: the positional encoder for the points. 127 | metadata_encoder: an encoder for metadata. 128 | alpha: the alpha for the positional encoding. 129 | skips: the index of the layers with skip connections. 130 | depth: the depth of the network excluding the logit layer. 131 | hidden_channels: the width of the network hidden layers. 132 | activation: the activation for each layer. 133 | metadata_encoded: whether the metadata parameter is pre-encoded or not. 134 | hidden_initializer: the initializer for the hidden layers. 135 | output_initializer: the initializer for the last logit layer. 136 | """ 137 | min_deg: int = 0 138 | max_deg: int = 8 139 | use_posenc_identity: bool = False 140 | 141 | activation: types.Activation = nn.relu 142 | norm: Optional[Any] = None 143 | skips: Iterable[int] = (4,) 144 | trunk_depth: int = 6 145 | trunk_width: int = 128 146 | rotation_depth: int = 0 147 | rotation_width: int = 128 148 | pivot_depth: int = 0 149 | pivot_width: int = 128 150 | translation_depth: int = 0 151 | translation_width: int = 128 152 | 153 | default_init: types.Initializer = jax.nn.initializers.xavier_uniform() 154 | rotation_init: types.Initializer = jax.nn.initializers.uniform(scale=1e-4) 155 | translation_init: types.Initializer = jax.nn.initializers.uniform(scale=1e-4) 156 | 157 | # Unused, here for backwards compatibility. 158 | num_hyper_dims: int = 0 159 | hyper_depth: int = 0 160 | hyper_width: int = 0 161 | hyper_init: Optional[types.Initializer] = None 162 | 163 | def setup(self): 164 | self.trunk = modules.MLP( 165 | depth=self.trunk_depth, 166 | width=self.trunk_width, 167 | hidden_activation=self.activation, 168 | hidden_norm=self.norm, 169 | hidden_init=self.default_init, 170 | skips=self.skips) 171 | 172 | branches = { 173 | 'w': 174 | modules.MLP( 175 | depth=self.rotation_depth, 176 | width=self.rotation_width, 177 | hidden_activation=self.activation, 178 | hidden_norm=self.norm, 179 | hidden_init=self.default_init, 180 | output_init=self.rotation_init, 181 | output_channels=3), 182 | 'v': 183 | modules.MLP( 184 | depth=self.translation_depth, 185 | width=self.translation_width, 186 | hidden_activation=self.activation, 187 | hidden_norm=self.norm, 188 | hidden_init=self.default_init, 189 | output_init=self.translation_init, 190 | output_channels=3), 191 | } 192 | 193 | # Note that this must be done this way instead of using mutable operations. 194 | # See https://github.com/google/flax/issues/524. 195 | self.branches = branches 196 | 197 | def warp(self, 198 | points: jnp.ndarray, 199 | metadata_embed: jnp.ndarray, 200 | extra_params: Dict[str, Any]): 201 | points_embed = model_utils.posenc(points, 202 | min_deg=self.min_deg, 203 | max_deg=self.max_deg, 204 | use_identity=self.use_posenc_identity, 205 | alpha=extra_params['warp_alpha']) 206 | inputs = jnp.concatenate([points_embed, metadata_embed], axis=-1) 207 | trunk_output = self.trunk(inputs) 208 | 209 | w = self.branches['w'](trunk_output) 210 | v = self.branches['v'](trunk_output) 211 | theta = jnp.linalg.norm(w, axis=-1) 212 | w = w / theta[..., jnp.newaxis] 213 | v = v / theta[..., jnp.newaxis] 214 | screw_axis = jnp.concatenate([w, v], axis=-1) 215 | transform = rigid.exp_se3(screw_axis, theta) 216 | 217 | warped_points = points 218 | warped_points = rigid.from_homogenous( 219 | utils.matmul(transform, rigid.to_homogenous(warped_points))) 220 | 221 | return warped_points 222 | 223 | def __call__(self, 224 | points: jnp.ndarray, 225 | metadata: jnp.ndarray, 226 | extra_params: Dict[str, Any], 227 | return_jacobian: bool = False): 228 | """Warp the given points using a warp field. 229 | 230 | Args: 231 | points: the points to warp. 232 | metadata: metadata indices if metadata_encoded is False else pre-encoded 233 | metadata. 234 | extra_params: A dictionary containing 235 | 'alpha': the alpha value for the positional encoding. 236 | return_jacobian: if True compute and return the Jacobian of the warp. 237 | 238 | Returns: 239 | The warped points and the Jacobian of the warp if `return_jacobian` is 240 | True. 241 | """ 242 | 243 | out = { 244 | 'warped_points': self.warp(points, metadata, extra_params) 245 | } 246 | 247 | if return_jacobian: 248 | jac_fn = jax.jacfwd(self.warp, argnums=0) 249 | out['jacobian'] = jac_fn(points, metadata, extra_params) 250 | 251 | return out 252 | -------------------------------------------------------------------------------- /notebooks/figures/hypernerf_optim_latent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "hypernerf_optim_latent", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "code", 18 | "metadata": { 19 | "id": "-QWb-snnOf1I" 20 | }, 21 | "source": [ 22 | "def apply_model(rng, state, batch):\n", 23 | " fine_key, coarse_key = random.split(rng, 2)\n", 24 | " model_out = model.apply(\n", 25 | " {'params': state.optimizer.target['model']}, \n", 26 | " batch,\n", 27 | " extra_params=state.extra_params,\n", 28 | " metadata_encoded=True,\n", 29 | " rngs={'fine': fine_key, 'coarse': coarse_key})\n", 30 | " return model_out\n", 31 | "\n", 32 | "\n", 33 | "def loss_fn(rng, state, target, batch):\n", 34 | " batch['metadata'] = jax.tree_map(lambda x: x.reshape((1, -1)), \n", 35 | " target['metadata'])\n", 36 | " model_out = apply_model(rng, state, batch)['fine']\n", 37 | " # loss = ((model_out['rgb'] - batch['rgb']) ** 2).mean(axis=-1)\n", 38 | " loss = jnp.abs(model_out['rgb'] - batch['rgb']).mean(axis=-1)\n", 39 | " return loss.mean()\n", 40 | "\n", 41 | "\n", 42 | "def optim_step(rng, state, optimizer, batch):\n", 43 | " rng, key = random.split(rng, 2)\n", 44 | " grad_fn = jax.value_and_grad(loss_fn, argnums=2)\n", 45 | " loss, grad = grad_fn(key, state, optimizer.target, batch)\n", 46 | " grad = jax.lax.pmean(grad, axis_name='batch')\n", 47 | " loss = jax.lax.pmean(loss, axis_name='batch')\n", 48 | "\n", 49 | " optimizer = optimizer.apply_gradient(grad)\n", 50 | "\n", 51 | " return rng, loss, optimizer\n", 52 | "\n", 53 | "\n", 54 | "p_optim_step = jax.pmap(optim_step, axis_name='batch')\n", 55 | "\n", 56 | "key = random.PRNGKey(0)\n", 57 | "key = key + jax.process_index()\n", 58 | "keys = random.split(key, jax.local_device_count())\n", 59 | "\n", 60 | "optimizer_def = optim.Adam(0.1)\n", 61 | "init_metadata = evaluation.encode_metadata(\n", 62 | " model, \n", 63 | " jax_utils.unreplicate(state.optimizer.target['model']), \n", 64 | " jax.tree_map(lambda x: x[0, 0], data['metadata']))\n", 65 | "# init_metadata = jax.tree_map(lambda x: x[0], init_metadata)\n", 66 | "# Initialize to zero.\n", 67 | "init_metadata = jax.tree_map(lambda x: jnp.zeros_like(x), init_metadata)\n", 68 | "optimizer = optimizer_def.create({'metadata': init_metadata})\n", 69 | "optimizer = jax_utils.replicate(optimizer, jax.local_devices())\n", 70 | "devices = jax.local_devices()\n", 71 | "batch_size = 1024\n" 72 | ], 73 | "execution_count": null, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "metadata": { 79 | "id": "gzqX4gTWOf3b" 80 | }, 81 | "source": [ 82 | "metadata_progression = []\n", 83 | "\n", 84 | "for i in range(25):\n", 85 | " batch_inds = random.choice(keys[0], np.arange(train_data['rgb'].shape[0]), replace=False, shape=(batch_size,))\n", 86 | " batch = jax.tree_map(lambda x: x[batch_inds, ...], train_data)\n", 87 | " batch = datasets.prepare_data(batch)\n", 88 | " keys, loss, optimizer = p_optim_step(keys, state, optimizer, batch)\n", 89 | " loss = jax_utils.unreplicate(loss)\n", 90 | " metadata_progression.append(jax.tree_map(lambda x: np.array(x), jax_utils.unreplicate(optimizer.target['metadata'])))\n", 91 | " print(f'train_loss = {loss.item():.04f}')\n", 92 | " del batch" 93 | ], 94 | "execution_count": null, 95 | "outputs": [] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "metadata": { 100 | "id": "-s9pCeiqZhKi" 101 | }, 102 | "source": [ 103 | "frames = []\n", 104 | "for metadata in metadata_progression:\n", 105 | "# metadata = jax_utils.unreplicate(optimizer.target['metadata'])\n", 106 | " camera = datasource.load_camera(target_id).scale(1.0)\n", 107 | " batch = make_batch(camera, None, metadata['encoded_warp'], metadata['encoded_hyper'])\n", 108 | " render = render_fn(state, batch, rng=rng)\n", 109 | " pred_rgb = np.array(render['rgb'])\n", 110 | " pred_depth_med = np.array(render['med_depth'])\n", 111 | " pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n", 112 | " media.show_images([pred_rgb, pred_depth_viz])\n", 113 | " frames.append({ \n", 114 | " 'rgb': pred_rgb,\n", 115 | " 'depth': pred_depth_med,\n", 116 | " })\n" 117 | ], 118 | "execution_count": null, 119 | "outputs": [] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "metadata": { 124 | "id": "tu8cC9gKe2BS" 125 | }, 126 | "source": [ 127 | "media.show_image(data['rgb'])\n", 128 | "media.show_videos([\n", 129 | " [d['rgb'] for d in frames],\n", 130 | " [viz.colorize(1/d['depth'].squeeze(), cmin=1.5, cmax=2.9) for d in frames],\n", 131 | "], fps=10)" 132 | ], 133 | "execution_count": null, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "metadata": { 139 | "id": "EUpAEa4boPbw" 140 | }, 141 | "source": [ 142 | "" 143 | ], 144 | "execution_count": null, 145 | "outputs": [] 146 | } 147 | ] 148 | } -------------------------------------------------------------------------------- /notebooks/figures/level_set_visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "level_set_visualization.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "code", 21 | "metadata": { 22 | "id": "_QO__KXIxS78" 23 | }, 24 | "source": [ 25 | "import numpy as np\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from matplotlib import cm\n", 28 | "import mediapy as media" 29 | ], 30 | "execution_count": null, 31 | "outputs": [] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "metadata": { 36 | "id": "fiVMLQ0bxmXQ" 37 | }, 38 | "source": [ 39 | "def f(x, y):\n", 40 | " r = np.sqrt(x**2 + y**2)\n", 41 | " n = r*5+1\n", 42 | " z = (np.abs(x)**n + np.abs(y)**n)**(1/n)\n", 43 | " return z\n", 44 | "\n", 45 | "def g(x, y):\n", 46 | " z = 1 - np.minimum(f(x - 0.5, y), f(x + 0.5, y))\n", 47 | " return np.maximum(z, 0)\n", 48 | "\n", 49 | "n = 100\n", 50 | "x, y = np.meshgrid(np.linspace(-1.5, 1.5, 2*n), np.linspace(-1, 1, n), indexing='xy')\n", 51 | "plt.contourf(g(x, y))" 52 | ], 53 | "execution_count": null, 54 | "outputs": [] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "Yv_ZrUADzXk7" 60 | }, 61 | "source": [ 62 | "\n", 63 | "n = 1000\n", 64 | "x, y = np.meshgrid(np.linspace(-1.5, 1.5, 2*n), np.linspace(-1, 1, n), indexing='xy')\n", 65 | "\n", 66 | "fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={\"projection\": \"3d\"})\n", 67 | "\n", 68 | "z = g(x, y)\n", 69 | "\n", 70 | "z0s = [0.2, 0.5, 0.8]\n", 71 | "colors = [cm.tab10(0), cm.tab10(1), cm.tab10(2)]\n", 72 | "\n", 73 | "ax.plot_surface(x, y, z, linewidth=0, antialiased=False, color='gray')\n", 74 | "\n", 75 | "x, y = np.meshgrid([-1.5, 1.5], [-1, 1], indexing='xy')\n", 76 | "xv = np.array([-1.5, -1.5, 1.5, 1.5, -1.5])\n", 77 | "yv = np.array([-1, 1, 1, -1, -1])\n", 78 | "\n", 79 | "for z0, color in zip(z0s, colors):\n", 80 | " ax.plot3D(xv, yv, z0, zorder=10, color=color)\n", 81 | "\n", 82 | "ax.axis(False)" 83 | ], 84 | "execution_count": null, 85 | "outputs": [] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "id": "d1pX7ypI3PE0" 91 | }, 92 | "source": [ 93 | "vis = []\n", 94 | "for z0, color in zip(z0s, colors):\n", 95 | " mask = z > z0\n", 96 | " vis.append(1-mask[:,:,None] * (1-np.array(color[:3])[None,None,:]))\n", 97 | "\n", 98 | "plt.figure(figsize=(8,8))\n", 99 | "plt.imshow(np.concatenate(vis[::-1], 0))\n", 100 | "plt.axis(False)" 101 | ], 102 | "execution_count": null, 103 | "outputs": [] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "metadata": { 108 | "id": "gS1-LG-u6xZs" 109 | }, 110 | "source": [ 111 | "" 112 | ], 113 | "execution_count": null, 114 | "outputs": [] 115 | } 116 | ] 117 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | flax==0.3.4 3 | gin-config @ git+https://github.com/google/gin-config@243ba87b3fcfeb2efb4a920b8f19679b61a6f0dc 4 | imageio==2.9.0 5 | immutabledict==2.2.0 6 | jax==0.2.20 7 | jaxlib==0.1.71+cuda111 8 | Markdown==3.3.4 9 | matplotlib==3.4.3 10 | numpy==1.21.0 11 | oauthlib==3.1.1 12 | opencv-python==4.5.3.56 13 | opt-einsum==3.3.0 14 | optax==0.0.9 15 | Pillow==9.0.0 16 | scipy==1.7.1 17 | tensorflow==2.6.3 18 | tqdm==4.62.2 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import setuptools 16 | 17 | with open("README.md", "r", encoding="utf-8") as fh: 18 | long_description = fh.read() 19 | 20 | setuptools.setup( 21 | name="hypernerf", # Replace with your own username 22 | version="0.0.1", 23 | author="Keunhong Park", 24 | author_email="kpar@cs.washington.edu", 25 | description="Code for 'HyperNeRF'.", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | url="https://github.com/google/hypernerf", 29 | packages=setuptools.find_packages(), 30 | classifiers=[ 31 | "Programming Language :: Python :: 3", 32 | "License :: OSI Approved :: Apache License 2.0", 33 | "Operating System :: OS Independent", 34 | ], 35 | python_requires='>=3.6', 36 | ) 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Training script for Nerf.""" 17 | 18 | import functools 19 | from typing import Dict, Union 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | from flax import jax_utils 25 | from flax import optim 26 | from flax.metrics import tensorboard 27 | from flax.training import checkpoints 28 | import gin 29 | import jax 30 | from jax import numpy as jnp 31 | from jax import random 32 | import numpy as np 33 | import tensorflow as tf 34 | 35 | from hypernerf import configs 36 | from hypernerf import datasets 37 | from hypernerf import gpath 38 | from hypernerf import model_utils 39 | from hypernerf import models 40 | from hypernerf import schedules 41 | from hypernerf import training 42 | from hypernerf import utils 43 | 44 | flags.DEFINE_enum('mode', None, ['jax_cpu', 'jax_gpu', 'jax_tpu'], 45 | 'Distributed strategy approach.') 46 | 47 | flags.DEFINE_string('base_folder', None, 'where to store ckpts and logs') 48 | flags.mark_flag_as_required('base_folder') 49 | flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.') 50 | flags.DEFINE_multi_string('gin_configs', (), 'Gin config files.') 51 | FLAGS = flags.FLAGS 52 | 53 | 54 | def _log_to_tensorboard(writer: tensorboard.SummaryWriter, 55 | state: model_utils.TrainState, 56 | scalar_params: training.ScalarParams, 57 | stats: Dict[str, Union[Dict[str, jnp.ndarray], 58 | jnp.ndarray]], 59 | time_dict: Dict[str, jnp.ndarray]): 60 | """Log statistics to Tensorboard.""" 61 | step = int(state.optimizer.state.step) 62 | 63 | def _log_scalar(tag, value): 64 | if value is not None: 65 | writer.scalar(tag, value, step) 66 | 67 | _log_scalar('params/learning_rate', scalar_params.learning_rate) 68 | _log_scalar('params/nerf_alpha', state.nerf_alpha) 69 | _log_scalar('params/warp_alpha', state.warp_alpha) 70 | _log_scalar('params/hyper_sheet_alpha', state.hyper_sheet_alpha) 71 | _log_scalar('params/elastic_loss/weight', scalar_params.elastic_loss_weight) 72 | 73 | # pmean is applied in train_step so just take the item. 74 | for branch in {'coarse', 'fine'}: 75 | if branch not in stats: 76 | continue 77 | for stat_key, stat_value in stats[branch].items(): 78 | writer.scalar(f'{stat_key}/{branch}', stat_value, step) 79 | 80 | _log_scalar('loss/background', stats.get('background_loss')) 81 | 82 | for k, v in time_dict.items(): 83 | writer.scalar(f'time/{k}', v, step) 84 | 85 | 86 | def _log_histograms(writer: tensorboard.SummaryWriter, 87 | state: model_utils.TrainState, 88 | model_out): 89 | """Log histograms to Tensorboard.""" 90 | step = int(state.optimizer.state.step) 91 | params = state.optimizer.target['model'] 92 | if 'nerf_embed' in params: 93 | embeddings = params['nerf_embed']['embed']['embedding'] 94 | writer.histogram('nerf_embedding', embeddings, step) 95 | if 'hyper_embed' in params: 96 | embeddings = params['hyper_embed']['embed']['embedding'] 97 | writer.histogram('hyper_embedding', embeddings, step) 98 | if 'warp_embed' in params: 99 | embeddings = params['warp_embed']['embed']['embedding'] 100 | writer.histogram('warp_embedding', embeddings, step) 101 | 102 | for branch in {'coarse', 'fine'}: 103 | if 'warped_points' in model_out[branch]: 104 | points = model_out[branch]['points'] 105 | warped_points = model_out[branch]['warped_points'] 106 | writer.histogram(f'{branch}/spatial_points', 107 | warped_points[..., :3], step) 108 | writer.histogram(f'{branch}/spatial_points_delta', 109 | warped_points[..., :3] - points, step) 110 | if warped_points.shape[-1] > 3: 111 | writer.histogram(f'{branch}/hyper_points', 112 | warped_points[..., 3:], step) 113 | 114 | 115 | def _log_grads(writer: tensorboard.SummaryWriter, model: models.NerfModel, 116 | state: model_utils.TrainState): 117 | """Log histograms to Tensorboard.""" 118 | step = int(state.optimizer.state.step) 119 | params = state.optimizer.target['model'] 120 | if 'nerf_metadata_encoder' in params: 121 | embeddings = params['nerf_metadata_encoder']['embed']['embedding'] 122 | writer.histogram('nerf_embedding', embeddings, step) 123 | if 'hyper_metadata_encoder' in params: 124 | embeddings = params['hyper_metadata_encoder']['embed']['embedding'] 125 | writer.histogram('hyper_embedding', embeddings, step) 126 | if 'warp_field' in params and model.warp_metadata_config['type'] == 'glo': 127 | embeddings = params['warp_metadata_encoder']['embed']['embedding'] 128 | writer.histogram('warp_embedding', embeddings, step) 129 | 130 | 131 | def main(argv): 132 | jax.config.parse_flags_with_absl() 133 | tf.config.experimental.set_visible_devices([], 'GPU') 134 | del argv 135 | logging.info('*** Starting experiment') 136 | # Assume G3 path for config files when running locally. 137 | gin_configs = FLAGS.gin_configs 138 | 139 | logging.info('*** Loading Gin configs from: %s', str(gin_configs)) 140 | gin.parse_config_files_and_bindings( 141 | config_files=gin_configs, 142 | bindings=FLAGS.gin_bindings, 143 | skip_unknown=True) 144 | 145 | # Load configurations. 146 | exp_config = configs.ExperimentConfig() 147 | train_config = configs.TrainConfig() 148 | dummy_model = models.NerfModel({}, 0, 0) 149 | 150 | # Get directory information. 151 | exp_dir = gpath.GPath(FLAGS.base_folder) 152 | if exp_config.subname: 153 | exp_dir = exp_dir / exp_config.subname 154 | summary_dir = exp_dir / 'summaries' / 'train' 155 | checkpoint_dir = exp_dir / 'checkpoints' 156 | 157 | # Log and create directories if this is the main process. 158 | if jax.process_index() == 0: 159 | logging.info('exp_dir = %s', exp_dir) 160 | if not exp_dir.exists(): 161 | exp_dir.mkdir(parents=True, exist_ok=True) 162 | 163 | logging.info('summary_dir = %s', summary_dir) 164 | if not summary_dir.exists(): 165 | summary_dir.mkdir(parents=True, exist_ok=True) 166 | 167 | logging.info('checkpoint_dir = %s', checkpoint_dir) 168 | if not checkpoint_dir.exists(): 169 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 170 | 171 | logging.info('Starting process %d. There are %d processes.', 172 | jax.process_index(), jax.process_count()) 173 | logging.info('Found %d accelerator devices: %s.', jax.local_device_count(), 174 | str(jax.local_devices())) 175 | logging.info('Found %d total devices: %s.', jax.device_count(), 176 | str(jax.devices())) 177 | 178 | rng = random.PRNGKey(exp_config.random_seed) 179 | # Shift the numpy random seed by process_index() to shuffle data loaded by 180 | # different processes. 181 | np.random.seed(exp_config.random_seed + jax.process_index()) 182 | 183 | if train_config.batch_size % jax.device_count() != 0: 184 | raise ValueError('Batch size must be divisible by the number of devices.') 185 | 186 | devices = jax.local_devices() 187 | logging.info('Creating datasource') 188 | datasource = exp_config.datasource_cls( 189 | image_scale=exp_config.image_scale, 190 | random_seed=exp_config.random_seed, 191 | # Enable metadata based on model needs. 192 | use_warp_id=dummy_model.use_warp, 193 | use_appearance_id=( 194 | dummy_model.nerf_embed_key == 'appearance' 195 | or dummy_model.hyper_embed_key == 'appearance'), 196 | use_camera_id=dummy_model.nerf_embed_key == 'camera', 197 | use_time=dummy_model.warp_embed_key == 'time') 198 | 199 | # Create Model. 200 | logging.info('Initializing models.') 201 | rng, key = random.split(rng) 202 | params = {} 203 | model, params['model'] = models.construct_nerf( 204 | key, 205 | batch_size=train_config.batch_size, 206 | embeddings_dict=datasource.embeddings_dict, 207 | near=datasource.near, 208 | far=datasource.far) 209 | 210 | # Create Jax iterator. 211 | logging.info('Creating dataset iterator.') 212 | train_iter = datasource.create_iterator( 213 | datasource.train_ids, 214 | flatten=True, 215 | shuffle=True, 216 | batch_size=train_config.batch_size, 217 | prefetch_size=3, 218 | shuffle_buffer_size=train_config.shuffle_buffer_size, 219 | devices=devices, 220 | ) 221 | 222 | points_iter = None 223 | if train_config.use_background_loss: 224 | points = datasource.load_points(shuffle=True) 225 | points_batch_size = min( 226 | len(points), 227 | len(devices) * train_config.background_points_batch_size) 228 | points_batch_size -= points_batch_size % len(devices) 229 | points_dataset = tf.data.Dataset.from_tensor_slices(points) 230 | points_iter = datasets.iterator_from_dataset( 231 | points_dataset, 232 | batch_size=points_batch_size, 233 | prefetch_size=3, 234 | devices=devices) 235 | 236 | learning_rate_sched = schedules.from_config(train_config.lr_schedule) 237 | nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule) 238 | warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule) 239 | hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule) 240 | hyper_sheet_alpha_sched = schedules.from_config( 241 | train_config.hyper_sheet_alpha_schedule) 242 | elastic_loss_weight_sched = schedules.from_config( 243 | train_config.elastic_loss_weight_schedule) 244 | 245 | optimizer_def = optim.Adam(learning_rate_sched(0)) 246 | if train_config.use_weight_norm: 247 | optimizer_def = optim.WeightNorm(optimizer_def) 248 | optimizer = optimizer_def.create(params) 249 | state = model_utils.TrainState( 250 | optimizer=optimizer, 251 | nerf_alpha=nerf_alpha_sched(0), 252 | warp_alpha=warp_alpha_sched(0), 253 | hyper_alpha=hyper_alpha_sched(0), 254 | hyper_sheet_alpha=hyper_sheet_alpha_sched(0)) 255 | scalar_params = training.ScalarParams( 256 | learning_rate=learning_rate_sched(0), 257 | elastic_loss_weight=elastic_loss_weight_sched(0), 258 | warp_reg_loss_weight=train_config.warp_reg_loss_weight, 259 | warp_reg_loss_alpha=train_config.warp_reg_loss_alpha, 260 | warp_reg_loss_scale=train_config.warp_reg_loss_scale, 261 | background_loss_weight=train_config.background_loss_weight, 262 | hyper_reg_loss_weight=train_config.hyper_reg_loss_weight) 263 | state = checkpoints.restore_checkpoint(checkpoint_dir, state) 264 | init_step = state.optimizer.state.step + 1 265 | state = jax_utils.replicate(state, devices=devices) 266 | del params 267 | 268 | summary_writer = None 269 | if jax.process_index() == 0: 270 | config_str = gin.operative_config_str() 271 | logging.info('Configuration: \n%s', config_str) 272 | with (exp_dir / 'config.gin').open('w') as f: 273 | f.write(config_str) 274 | summary_writer = tensorboard.SummaryWriter(str(summary_dir)) 275 | summary_writer.text('gin/train', textdata=gin.markdown(config_str), step=0) 276 | 277 | train_step = functools.partial( 278 | training.train_step, 279 | model, 280 | elastic_reduce_method=train_config.elastic_reduce_method, 281 | elastic_loss_type=train_config.elastic_loss_type, 282 | use_elastic_loss=train_config.use_elastic_loss, 283 | use_background_loss=train_config.use_background_loss, 284 | use_warp_reg_loss=train_config.use_warp_reg_loss, 285 | use_hyper_reg_loss=train_config.use_hyper_reg_loss, 286 | ) 287 | ptrain_step = jax.pmap( 288 | train_step, 289 | axis_name='batch', 290 | devices=devices, 291 | # rng_key, state, batch, scalar_params. 292 | in_axes=(0, 0, 0, None), 293 | # Treat use_elastic_loss as compile-time static. 294 | donate_argnums=(2,), # Donate the 'batch' argument. 295 | ) 296 | 297 | if devices: 298 | n_local_devices = len(devices) 299 | else: 300 | n_local_devices = jax.local_device_count() 301 | 302 | logging.info('Starting training') 303 | # Make random seed separate across processes. 304 | rng = rng + jax.process_index() 305 | keys = random.split(rng, n_local_devices) 306 | time_tracker = utils.TimeTracker() 307 | time_tracker.tic('data', 'total') 308 | for step, batch in zip(range(init_step, train_config.max_steps + 1), 309 | train_iter): 310 | if points_iter is not None: 311 | batch['background_points'] = next(points_iter) 312 | time_tracker.toc('data') 313 | # See: b/162398046. 314 | # pytype: disable=attribute-error 315 | scalar_params = scalar_params.replace( 316 | learning_rate=learning_rate_sched(step), 317 | elastic_loss_weight=elastic_loss_weight_sched(step)) 318 | # pytype: enable=attribute-error 319 | nerf_alpha = jax_utils.replicate(nerf_alpha_sched(step), devices) 320 | warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices) 321 | hyper_alpha = jax_utils.replicate(hyper_alpha_sched(step), devices) 322 | hyper_sheet_alpha = jax_utils.replicate( 323 | hyper_sheet_alpha_sched(step), devices) 324 | state = state.replace(nerf_alpha=nerf_alpha, 325 | warp_alpha=warp_alpha, 326 | hyper_alpha=hyper_alpha, 327 | hyper_sheet_alpha=hyper_sheet_alpha) 328 | 329 | with time_tracker.record_time('train_step'): 330 | state, stats, keys, model_out = ptrain_step( 331 | keys, state, batch, scalar_params) 332 | time_tracker.toc('total') 333 | 334 | if step % train_config.print_every == 0 and jax.process_index() == 0: 335 | logging.info('step=%d, nerf_alpha=%.04f, warp_alpha=%.04f, %s', step, 336 | nerf_alpha_sched(step), 337 | warp_alpha_sched(step), 338 | time_tracker.summary_str('last')) 339 | coarse_metrics_str = ', '.join( 340 | [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()]) 341 | fine_metrics_str = ', '.join( 342 | [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()]) 343 | logging.info('\tcoarse metrics: %s', coarse_metrics_str) 344 | if 'fine' in stats: 345 | logging.info('\tfine metrics: %s', fine_metrics_str) 346 | 347 | if step % train_config.save_every == 0 and jax.process_index() == 0: 348 | training.save_checkpoint(checkpoint_dir, state, keep=2) 349 | 350 | if step % train_config.log_every == 0 and jax.process_index() == 0: 351 | # Only log via process 0. 352 | _log_to_tensorboard( 353 | summary_writer, 354 | jax_utils.unreplicate(state), 355 | scalar_params, 356 | jax_utils.unreplicate(stats), 357 | time_dict=time_tracker.summary('mean')) 358 | time_tracker.reset() 359 | 360 | if step % train_config.histogram_every == 0 and jax.process_index() == 0: 361 | _log_histograms(summary_writer, jax_utils.unreplicate(state), model_out) 362 | 363 | time_tracker.tic('data', 'total') 364 | 365 | if train_config.max_steps % train_config.save_every != 0: 366 | training.save_checkpoint(checkpoint_dir, state, keep=2) 367 | 368 | 369 | if __name__ == '__main__': 370 | app.run(main) 371 | --------------------------------------------------------------------------------