├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── defaults.gin
│   ├── gpu_fullhd.gin
│   ├── gpu_quarterhd.gin
│   ├── gpu_quarterhd_4gpu.gin
│   ├── gpu_vrig_paper.gin
│   ├── test_local.gin
│   ├── test_vrig.gin
│   └── warp_defaults.gin
├── eval.py
├── nerfies
│   ├── __init__.py
│   ├── camera.py
│   ├── configs.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── nerfies.py
│   ├── evaluation.py
│   ├── glo.py
│   ├── gpath.py
│   ├── image_utils.py
│   ├── model_utils.py
│   ├── models.py
│   ├── modules.py
│   ├── quaternion.py
│   ├── rigid_body.py
│   ├── schedules.py
│   ├── tf_camera.py
│   ├── training.py
│   ├── types.py
│   ├── utils.py
│   ├── visualization.py
│   └── warping.py
├── notebooks
│   ├── Nerfies_Capture_Processing.ipynb
│   ├── Nerfies_Render_Video.ipynb
│   └── Nerfies_Training.ipynb
├── requirements.txt
├── setup.py
├── third_party
│   └── pycolmap
│       ├── LICENSE
│       ├── README.md
│       ├── pycolmap
│       │   ├── __init__.py
│       │   ├── camera.py
│       │   ├── database.py
│       │   ├── image.py
│       │   ├── rotation.py
│       │   └── scene_manager.py
│       └── setup.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | .idea
3 | bower_components
4 | node_modules
5 | venv
6 | 
7 | *.ts.map
8 | 
9 | *~
10 | *.so
11 | .DS_Store
12 | ._.DS_Store
13 | *.swp
14 | 
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | 
19 | # C extensions
20 | *.so
21 | 
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | 
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 | 
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 | 
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | 
60 | # Translations
61 | *.mo
62 | *.pot
63 | 
64 | # Django stuff:
65 | *.log
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 | 
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 | 
19 | ## Code reviews
20 | 
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nerfies: Deformable Neural Radiance Fields 2 | 3 | This is the code for Nerfies: Deformable Neural Radiance Fields. 
4 | 
5 | * [Project Page](https://nerfies.github.io)
6 | * [Paper](https://arxiv.org/abs/2011.12948)
7 | * [Video](https://www.youtube.com/watch?v=MrKrnHhk8IA)
8 | 
9 | This codebase is implemented using [JAX](https://github.com/google/jax),
10 | building on [JaxNeRF](https://github.com/google-research/google-research/tree/master/jaxnerf).
11 | 
12 | This repository has been updated to reflect the version used for our ICCV 2021 submission.
13 | 
14 | ## Demo
15 | 
16 | We provide an easy-to-get-started demo using Google Colab!
17 | 
18 | These Colabs will allow you to train a basic version of our method using
19 | Cloud TPUs (or GPUs) on Google Colab.
20 | 
21 | Note that due to the limited compute resources available, these are not the fully
22 | featured models. If you would like to train a fully featured Nerfie, please
23 | refer to the instructions below on how to train on your own machine.
24 | 
25 | | Description | Link |
26 | | ----------- | ----------- |
27 | | Process a video into a Nerfie dataset| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb)|
28 | | Train a Nerfie| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Training.ipynb)|
29 | | Render a Nerfie video| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Render_Video.ipynb)|
30 | 
31 | ## Setup
32 | The code can be run under any environment with Python 3.8 and above.
33 | (It may run with lower versions, but we have not tested it.)
34 | 
35 | We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) and setting up an environment:
36 | 
37 |     conda create --name nerfies python=3.8
38 | 
39 | Next, install the required packages:
40 | 
41 |     pip install -r requirements.txt
42 | 
43 | Install the appropriate JAX distribution for your environment by [following the instructions here](https://github.com/google/jax#installation). For example:
44 | 
45 |     # For CUDA version 11.1
46 |     pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
47 | 
48 | 
49 | ## Training
50 | After preparing a dataset, you can train a Nerfie by running:
51 | 
52 |     export DATASET_PATH=/path/to/dataset
53 |     export EXPERIMENT_PATH=/path/to/save/experiment/to
54 |     python train.py \
55 |         --data_dir $DATASET_PATH \
56 |         --base_folder $EXPERIMENT_PATH \
57 |         --gin_configs configs/test_vrig.gin
58 | 
59 | To plot telemetry to Tensorboard and render checkpoints on the fly, also
60 | launch an evaluation job by running:
61 | 
62 |     python eval.py \
63 |         --data_dir $DATASET_PATH \
64 |         --base_folder $EXPERIMENT_PATH \
65 |         --gin_configs configs/test_vrig.gin
66 | 
67 | The two jobs should use a mutually exclusive set of GPUs. This division allows the
68 | training job to run without having to stop for evaluation.
69 | 
70 | ## Configuration
71 | * We use [Gin](https://github.com/google/gin-config) for configuration.
72 | * We provide a couple of preset configurations.
73 | * Please refer to `configs.py` for documentation on what each configuration does.
74 | * Preset configs:
75 |     - `gpu_vrig_paper.gin`: This is the configuration we used to generate the table in the paper. It requires 8 GPUs for training.
76 |     - `gpu_fullhd.gin`: This is a high-resolution model and will take around 3 days to train on 8 GPUs.
77 |     - `gpu_quarterhd.gin`: This is a low-resolution model and will take around 14 hours to train on 8 GPUs.
78 |     - `test_local.gin`: This is a test configuration to see if the code runs. It will likely not produce a good-looking result.
79 |     - `test_vrig.gin`: This is a test configuration to see if the code runs for validation rig captures. It will likely not produce a good-looking result.
80 | * Training on fewer GPUs will require tuning of the batch size and learning rates; see the sketch below for a common heuristic.
81 | 
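For example, the square-root scaling heuristic suggested in the comments of `gpu_fullhd.gin` and `gpu_quarterhd.gin` (scale the learning rate by the square root of the factor by which you decreased the batch size) works out as follows. This is a minimal sketch; the 2-GPU batch size is purely illustrative:

```python
import math

# 8-GPU defaults from gpu_quarterhd.gin.
base_batch_size = 6144
base_init_lr, base_final_lr = 0.001, 0.0001

# Illustrative: a 4x smaller batch, e.g. when training on 2 GPUs.
new_batch_size = 1536
factor = base_batch_size / new_batch_size     # 4.0
init_lr = base_init_lr / math.sqrt(factor)    # 0.0005
final_lr = base_final_lr / math.sqrt(factor)  # 0.00005
```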
82 | ## Datasets
83 | A dataset is a directory with the following structure:
84 | 
85 |     dataset
86 |     ├── camera
87 |     │   └── ${item_id}.json
88 |     ├── camera-paths
89 |     ├── rgb
90 |     │   ├── ${scale}x
91 |     │   │   └── ${item_id}.png
92 |     ├── metadata.json
93 |     ├── points.npy
94 |     ├── dataset.json
95 |     └── scene.json
96 | 
97 | At a high level, a dataset is simply the following:
98 | * A collection of images (e.g., from a video).
99 | * Camera parameters for each image.
100 | 
101 | We have a unique identifier for each image which we call `item_id`, and this is
102 | used to match the camera and images. An `item_id` can be any string, but typically
103 | it is some alphanumeric string such as `000054`.
104 | 
105 | ### `camera`
106 | 
107 | * This directory contains cameras corresponding to each image.
108 | * We use a camera model identical to the [OpenCV camera model](https://docs.opencv.org/master/dc/dbb/tutorial_py_calibration.html), which is also supported by COLMAP.
109 | * Each camera is a serialized version of the `Camera` class defined in `camera.py` and looks like this:
110 | 
111 | ```javascript
112 | {
113 |   // A 3x3 world-to-camera rotation matrix representing the camera orientation.
114 |   "orientation": [
115 |     [0.9839, -0.0968, 0.1499],
116 |     [-0.0350, -0.9284, -0.3699],
117 |     [0.1749, 0.358, -0.9168]
118 |   ],
119 |   // The 3D position of the camera in world-space.
120 |   "position": [-0.3236, -3.26428, 5.4160],
121 |   // The focal length of the camera.
122 |   "focal_length": 2691,
123 |   // The principal point [u_0, v_0] of the camera.
124 |   "principal_point": [1220, 1652],
125 |   // The skew of the camera.
126 |   "skew": 0.0,
127 |   // The aspect ratio for the camera pixels.
128 |   "pixel_aspect_ratio": 1.0,
129 |   // Parameters for the radial distortion of the camera.
130 |   "radial_distortion": [0.1004, -0.2090, 0.0],
131 |   // Parameters for the tangential distortion of the camera.
132 |   "tangential": [0.001109, -2.5733e-05],
133 |   // The image width and height in pixels.
134 |   "image_size": [2448, 3264]
135 | }
136 | ```
137 | 
138 | ### `camera-paths`
139 | * This directory contains test-time camera paths which can be used to render videos.
140 | * Each sub-directory in this path should contain a sequence of JSON files.
141 | * The naming scheme does not matter, but the cameras will be sorted by their filenames.
142 | 
143 | ### `rgb`
144 | * This directory contains images at various scales.
145 | * Each subdirectory should be named `${scale}x` where `${scale}` is an integer scaling factor. For example, `1x` would contain the original images while `4x` would contain images a quarter of the size.
146 | * We assume the images are in PNG format.
147 | * It is important that the scaled images are integer factors of the original so that area interpolation can be used when downscaling, which prevents Moiré patterns. A simple way to ensure this is to trim the borders of the image to be divisible by the maximum scale factor you want; see the sketch below.
148 | 
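A minimal sketch of generating such a scaled copy with OpenCV's area interpolation (the file names and scale factor are illustrative):

```python
import cv2

scale = 4
image = cv2.imread('dataset/rgb/1x/000054.png')
height, width = image.shape[:2]
# Trim the borders so the dimensions divide evenly by the scale factor.
image = image[:height - height % scale, :width - width % scale]
# INTER_AREA averages source pixels, which avoids Moiré patterns.
resized = cv2.resize(image, (image.shape[1] // scale, image.shape[0] // scale),
                     interpolation=cv2.INTER_AREA)
cv2.imwrite('dataset/rgb/4x/000054.png', resized)
```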
149 | ### `metadata.json`
150 | * This defines the 'metadata' IDs used for embedding lookups.
151 | * Contains a dictionary of the following format:
152 | 
153 | ```javascript
154 | {
155 |   "${item_id}": {
156 |     // The embedding ID used to fetch the deformation latent code
157 |     // passed to the deformation field.
158 |     "warp_id": 0,
159 |     // The embedding ID used to fetch the appearance latent code
160 |     // which is passed to the second branch of the template NeRF.
161 |     "appearance_id": 0,
162 |     // For validation rig datasets, we use the camera ID instead
163 |     // of the appearance ID. For example, this would be '0' for the
164 |     // left camera and '1' for the right camera. This could also be
165 |     // used for multi-view setups.
166 |     "camera_id": 0
167 |   },
168 |   ...
169 | }
170 | ```
171 | ### `scene.json`
172 | * Contains information about how we will parse the scene.
173 | * See comments inline.
174 | 
175 | ```javascript
176 | {
177 |   // The scale factor we will apply to the pointcloud and cameras. This is
178 |   // important since it controls what scale is used when computing the positional
179 |   // encoding.
180 |   "scale": 0.0387243672920458,
181 |   // Defines the origin of the scene. The scene will be translated such that
182 |   // this point becomes the origin. Defined in unscaled coordinates.
183 |   "center": [
184 |     1.1770838526103944e-08,
185 |     -2.58235339289195,
186 |     -1.29117656263135
187 |   ],
188 |   // The distance of the near plane from the camera center in scaled coordinates.
189 |   "near": 0.02057418950149491,
190 |   // The distance of the far plane from the camera center in scaled coordinates.
191 |   "far": 0.8261601717667288
192 | }
193 | ```
194 | 
195 | ### `dataset.json`
196 | * Defines the training/validation split of the dataset.
197 | * See inline comments:
198 | 
199 | ```javascript
200 | {
201 |   // The total number of images in the dataset.
202 |   "count": 114,
203 |   // The total number of training images (exemplars) in the dataset.
204 |   "num_exemplars": 57,
205 |   // A list containing all item IDs in the dataset.
206 |   "ids": [...],
207 |   // A list containing all training item IDs in the dataset.
208 |   "train_ids": [...],
209 |   // A list containing all validation item IDs in the dataset.
210 |   // This should be mutually exclusive with `train_ids`.
211 |   "val_ids": [...],
212 | }
213 | ```
214 | 
215 | ### `points.npy`
216 | 
217 | * A numpy file containing a single array of size `(N,3)` containing the background points.
218 | * This is required if you want to use the background regularization loss (see the sketch below).
219 | 
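The dataset loader maps these points into the scaled scene coordinates defined by `scene.json`, mirroring `load_points` in `nerfies/datasets/nerfies.py`. A minimal sketch of that transformation, with an illustrative dataset path:

```python
import json
import numpy as np

with open('dataset/scene.json') as f:
  scene = json.load(f)

# Unscaled background points, shape (N, 3).
points = np.load('dataset/points.npy')
# Recenter on the scene origin, then apply the scene scale, exactly as the
# dataset loader does before computing the background regularization loss.
points = ((points - np.array(scene['center'])) * scene['scale']).astype(np.float32)
```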
220 | ## Citing
221 | If you find our work useful, please consider citing:
222 | ```BibTeX
223 | @article{park2021nerfies,
224 |   author = {Park, Keunhong
225 |             and Sinha, Utkarsh
226 |             and Barron, Jonathan T.
227 |             and Bouaziz, Sofien
228 |             and Goldman, Dan B
229 |             and Seitz, Steven M.
230 |             and Martin-Brualla, Ricardo},
231 |   title = {Nerfies: Deformable Neural Radiance Fields},
232 |   journal = {ICCV},
233 |   year = {2021},
234 | }
235 | ```
236 | 
--------------------------------------------------------------------------------
/configs/defaults.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | # This is the base configuration that is imported by other configurations.
17 | # Do not run this configuration directly.
18 | 
19 | num_warp_freqs = 8
20 | elastic_init_weight = 0.01
21 | lr_delay_steps = 2500
22 | lr_delay_mult = 0.01
23 | 
24 | # Predefined warp alpha schedules.
25 | ANNEALED_WARP_ALPHA_SCHED = {
26 |     'type': 'linear',
27 |     'initial_value': 0.0,
28 |     'final_value': %num_warp_freqs,
29 |     'num_steps': 80000,
30 | }
31 | CONSTANT_WARP_ALPHA_SCHED = {
32 |     'type': 'constant',
33 |     'value': %num_warp_freqs,
34 | }
35 | 
36 | # Predefined elastic loss schedules.
37 | CONSTANT_ELASTIC_LOSS_SCHED = {
38 |     'type': 'constant',
39 |     'value': %elastic_init_weight,
40 | }
41 | DECAYING_ELASTIC_LOSS_SCHED = {
42 |     'type': 'piecewise',
43 |     'schedules': [
44 |         (50000, ('constant', %elastic_init_weight)),
45 |         (100000, ('cosine_easing', %elastic_init_weight, 1e-8, 100000)),
46 |     ]
47 | }
48 | 
49 | DEFAULT_LR_SCHEDULE = {
50 |     'type': 'exponential',
51 |     'initial_value': %init_lr,
52 |     'final_value': %final_lr,
53 |     'num_steps': %max_steps,
54 | }
55 | 
56 | DELAYED_LR_SCHEDULE = {
57 |     'type': 'delayed',
58 |     'delay_steps': %lr_delay_steps,
59 |     'delay_mult': %lr_delay_mult,
60 |     'base_schedule': %DEFAULT_LR_SCHEDULE,
61 | }
62 | 
63 | # Common configs.
64 | ModelConfig.use_viewdirs = True
65 | ModelConfig.use_stratified_sampling = True
66 | ModelConfig.sigma_activation = @nn.softplus
67 | ModelConfig.use_appearance_metadata = False
68 | 
69 | # Experiment configs.
70 | ExperimentConfig.image_scale = %image_scale
71 | ExperimentConfig.random_seed = 12345
72 | 
73 | # Warp field configs.
74 | ModelConfig.use_warp = False
75 | ModelConfig.warp_field_type = 'se3'
76 | ModelConfig.num_warp_freqs = %num_warp_freqs
77 | ModelConfig.num_warp_features = 8
78 | 
79 | # Use macros to make sure these are set somewhere.
80 | TrainConfig.batch_size = %batch_size
81 | TrainConfig.max_steps = %max_steps
82 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE
83 | TrainConfig.warp_alpha_schedule = %CONSTANT_WARP_ALPHA_SCHED
84 | 
85 | # Elastic loss.
86 | TrainConfig.use_elastic_loss = False
87 | TrainConfig.elastic_loss_weight_schedule = %CONSTANT_ELASTIC_LOSS_SCHED
88 | 
89 | # Background regularization loss.
90 | TrainConfig.use_background_loss = False
91 | TrainConfig.background_loss_weight = 1.0
92 | 
93 | # Script interval configs.
94 | TrainConfig.print_every = 100 95 | TrainConfig.log_every = 500 96 | TrainConfig.save_every = 5000 97 | 98 | EvalConfig.eval_once = False 99 | EvalConfig.save_output = True 100 | EvalConfig.chunk = %eval_batch_size 101 | -------------------------------------------------------------------------------- /configs/gpu_fullhd.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is a Full HD (1080p-ish) configuration. 17 | # The image scale is based on our input video size of 1920x1080. 18 | # This configuration requires 8 GPUs to train. 19 | # 20 | # To make this runnable on fewer GPUs, decrease the `batch_size` and scale the 21 | # learning rate by the sqrt of the factor by which you decreased it. Expect 22 | # the results to look slightly worse without tuning. 23 | 24 | include 'warp_defaults.gin' 25 | 26 | max_steps = 1000000 27 | lr_decay_steps = 2000000 28 | 29 | image_scale = 1 30 | batch_size = 4096 31 | eval_batch_size = 4096 32 | init_lr = 0.00075 33 | final_lr = 0.000075 34 | 35 | ModelConfig.num_nerf_point_freqs = 10 36 | ModelConfig.nerf_trunk_width = 256 37 | ModelConfig.nerf_trunk_depth = 8 38 | ModelConfig.num_coarse_samples = 256 39 | ModelConfig.num_fine_samples = 256 40 | 41 | TrainConfig.print_every = 200 42 | TrainConfig.log_every = 500 43 | TrainConfig.save_every = 10000 44 | -------------------------------------------------------------------------------- /configs/gpu_quarterhd.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is a quarter HD configuration. 17 | # The image scale is based on our input video size of 1920x1080. 18 | # This configuration requires 8 GPUs to train. 19 | # 20 | # To make this runnable on fewer GPUs, decrease the `batch_size` and scale the 21 | # learning rate by the sqrt of the factor by which you decreased it. Expect 22 | # the results to look slightly worse without tuning. 
23 | 24 | include 'warp_defaults.gin' 25 | 26 | max_steps = 250000 27 | lr_decay_steps = 500000 28 | 29 | image_scale = 4 30 | batch_size = 6144 31 | eval_batch_size = 8096 32 | init_lr = 0.001 33 | final_lr = 0.0001 34 | 35 | ModelConfig.num_nerf_point_freqs = 8 36 | ModelConfig.nerf_trunk_width = 256 37 | ModelConfig.nerf_trunk_depth = 8 38 | ModelConfig.num_coarse_samples = 128 39 | ModelConfig.num_fine_samples = 128 40 | 41 | TrainConfig.print_every = 200 42 | TrainConfig.log_every = 500 43 | TrainConfig.save_every = 5000 44 | -------------------------------------------------------------------------------- /configs/gpu_quarterhd_4gpu.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is a quarter HD configuration for 4 GPUs. 17 | # The image scale is based on our input video size of 1920x1080. 18 | # This configuration requires 4 GPUs to train. 19 | # 20 | # Note that this configuration has not been tested and may require further 21 | # tuning. 22 | 23 | include 'warp_defaults.gin' 24 | 25 | max_steps = 500000 26 | lr_decay_steps = 1000000 27 | 28 | image_scale = 4 29 | batch_size = 3072 30 | eval_batch_size = 4096 31 | init_lr = 0.0007 32 | final_lr = 0.00007 33 | 34 | ModelConfig.num_nerf_point_freqs = 8 35 | ModelConfig.nerf_trunk_width = 256 36 | ModelConfig.nerf_trunk_depth = 8 37 | ModelConfig.num_coarse_samples = 128 38 | ModelConfig.num_fine_samples = 128 39 | 40 | TrainConfig.print_every = 200 41 | TrainConfig.log_every = 500 42 | TrainConfig.save_every = 5000 43 | -------------------------------------------------------------------------------- /configs/gpu_vrig_paper.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is the validation rig configuration we used in the quantitative 17 | # evaluation of the paper. The `image_scale` is based on our raw dataset 18 | # resolution of 4032x3024 from the validation rig. 19 | # This configuration requires 8 GPUs to train. 
20 | 21 | include 'configs/warp_defaults.gin' 22 | 23 | max_steps = 250000 24 | lr_decay_steps = %max_steps 25 | 26 | image_scale = 4 27 | batch_size = 6144 28 | eval_batch_size = 8096 29 | init_lr = 0.001 30 | final_lr = 0.0001 31 | elastic_init_weight = 0.001 32 | num_warp_freqs = 6 33 | 34 | ModelConfig.use_warp = True 35 | ModelConfig.num_nerf_point_freqs = 8 36 | ModelConfig.nerf_trunk_width = 256 37 | ModelConfig.nerf_trunk_depth = 8 38 | ModelConfig.num_coarse_samples = 128 39 | ModelConfig.num_fine_samples = 128 40 | ModelConfig.use_appearance_metadata = False 41 | ModelConfig.use_camera_metadata = True 42 | ModelConfig.use_stratified_sampling = True 43 | ModelConfig.camera_metadata_dims = 2 44 | ModelConfig.use_sample_at_infinity = True 45 | ModelConfig.warp_field_type = 'se3' 46 | 47 | TrainConfig.print_every = 500 48 | TrainConfig.log_every = 500 49 | TrainConfig.histogram_every = 1000 50 | TrainConfig.save_every = 5000 51 | 52 | TrainConfig.use_elastic_loss = True 53 | TrainConfig.use_background_loss = True 54 | TrainConfig.background_loss_weight = 1.0 55 | TrainConfig.warp_alpha_schedule = %ANNEALED_WARP_ALPHA_SCHED 56 | 57 | TrainConfig.use_warp_reg_loss = False 58 | TrainConfig.warp_reg_loss_weight = 1e-2 59 | 60 | TrainConfig.elastic_reduce_method = 'weight' 61 | TrainConfig.elastic_loss_weight_schedule = { 62 | 'type': 'constant', 63 | 'value': %elastic_init_weight, 64 | } 65 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE 66 | 67 | EvalConfig.num_val_eval = None 68 | EvalConfig.num_train_eval = None 69 | -------------------------------------------------------------------------------- /configs/test_local.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is a test configuration for sanity checking. 17 | # It will likely not result in a good quality reconstruction. 18 | # This config will run on a single GPU. 
19 | 20 | elastic_init_weight = 0.01 21 | max_steps = 250000 22 | 23 | ExperimentConfig.image_scale = 4 24 | 25 | ModelConfig.num_coarse_samples = 64 26 | ModelConfig.num_fine_samples = 64 27 | ModelConfig.use_viewdirs = True 28 | ModelConfig.use_stratified_sampling = True 29 | ModelConfig.use_appearance_metadata = True 30 | ModelConfig.use_warp = True 31 | ModelConfig.warp_field_type = 'se3' 32 | ModelConfig.num_warp_features = 3 33 | ModelConfig.num_warp_freqs = 8 34 | ModelConfig.sigma_activation = @nn.softplus 35 | 36 | TrainConfig.max_steps = 200000 37 | TrainConfig.lr_schedule = { 38 | 'type': 'exponential', 39 | 'initial_value': 0.001, 40 | 'final_value': 0.0001, 41 | 'num_steps': %max_steps, 42 | } 43 | TrainConfig.batch_size = 1024 44 | TrainConfig.warp_alpha_schedule = { 45 | 'type': 'linear', 46 | 'initial_value': 0.0, 47 | 'final_value': 8.0, 48 | 'num_steps': 80000, 49 | } 50 | TrainConfig.use_elastic_loss = True 51 | TrainConfig.elastic_loss_weight_schedule = { 52 | 'type': 'piecewise', 53 | 'schedules': [ 54 | (50000, ('constant', %elastic_init_weight)), 55 | (100000, ('cosine_easing', %elastic_init_weight, 1e-8, 100000)), 56 | ] 57 | } 58 | TrainConfig.use_background_loss = False 59 | TrainConfig.background_loss_weight = 1.0 60 | 61 | TrainConfig.print_every = 10 62 | TrainConfig.log_every = 100 63 | TrainConfig.save_every = 1000 64 | 65 | EvalConfig.eval_once = False 66 | EvalConfig.save_output = True 67 | EvalConfig.chunk = 8192 68 | -------------------------------------------------------------------------------- /configs/test_vrig.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This is a test configuration for sanity checking. 17 | # It will likely not result in a good quality reconstruction. 18 | # This config will run on a single GPU. 
19 | 20 | include 'defaults.gin' 21 | 22 | max_steps = 250000 23 | 24 | image_scale = 8 25 | batch_size = 1024 26 | eval_batch_size = 1024 27 | init_lr = 0.001 28 | final_lr = 0.0001 29 | lr_decay_steps = 500000 30 | elastic_init_weight = 0.001 31 | 32 | ModelConfig.num_nerf_point_freqs = 8 33 | ModelConfig.nerf_trunk_width = 128 34 | ModelConfig.nerf_trunk_depth = 8 35 | ModelConfig.num_coarse_samples = 64 36 | ModelConfig.num_fine_samples = 64 37 | ModelConfig.use_appearance_metadata = False 38 | ModelConfig.use_camera_metadata = True 39 | ModelConfig.use_stratified_sampling = False 40 | ModelConfig.camera_metadata_dims = 2 41 | ModelConfig.use_warp = True 42 | 43 | TrainConfig.use_elastic_loss = True 44 | TrainConfig.use_background_loss = True 45 | TrainConfig.print_every = 1 46 | TrainConfig.log_every = 100 47 | TrainConfig.save_every = 1000 48 | 49 | EvalConfig.chunk = 8192 50 | EvalConfig.num_val_eval = None 51 | EvalConfig.num_train_eval = None 52 | -------------------------------------------------------------------------------- /configs/warp_defaults.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # This file is the base configuration for Nerfies. 17 | # Do not run this directly, it is for importing from other configurations. 18 | 19 | include 'defaults.gin' 20 | 21 | ModelConfig.use_warp = True 22 | ModelConfig.use_appearance_metadata = True 23 | 24 | TrainConfig.warp_alpha_schedule = %ANNEALED_WARP_ALPHA_SCHED 25 | TrainConfig.elastic_loss_weight_schedule = %DECAYING_ELASTIC_LOSS_SCHED 26 | TrainConfig.use_elastic_loss = True 27 | TrainConfig.use_background_loss = True 28 | -------------------------------------------------------------------------------- /nerfies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /nerfies/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration classes.""" 16 | from typing import Any, Mapping, Optional, Tuple 17 | 18 | import dataclasses 19 | import flax.linen as nn 20 | import gin 21 | import immutabledict 22 | 23 | from nerfies import types 24 | 25 | ScheduleDef = Any 26 | 27 | gin.config.external_configurable(nn.elu, module='flax.nn') 28 | gin.config.external_configurable(nn.relu, module='flax.nn') 29 | gin.config.external_configurable(nn.leaky_relu, module='flax.nn') 30 | gin.config.external_configurable(nn.tanh, module='flax.nn') 31 | gin.config.external_configurable(nn.sigmoid, module='flax.nn') 32 | gin.config.external_configurable(nn.softplus, module='flax.nn') 33 | 34 | 35 | @gin.configurable() 36 | @dataclasses.dataclass 37 | class ModelConfig: 38 | """Parameters for the model.""" 39 | # Sample linearly in disparity rather than depth. 40 | use_linear_disparity: bool = False 41 | # Use white as the default background. 42 | use_white_background: bool = False 43 | # Use stratified sampling. 44 | use_stratified_sampling: bool = True 45 | # Use the sample at infinity. 46 | use_sample_at_infinity: bool = True 47 | # The standard deviation of the alpha noise. 48 | noise_std: Optional[float] = None 49 | 50 | # The depth of the NeRF. 51 | nerf_trunk_depth: int = 8 52 | # The width of the NeRF. 53 | nerf_trunk_width: int = 256 54 | # The depth of the conditional part of the MLP. 55 | nerf_rgb_branch_depth: int = 1 56 | # The width of the conditional part of the MLP. 57 | nerf_rgb_branch_width: int = 128 58 | # The intermediate activation for the NeRF. 59 | activation: types.Activation = nn.relu 60 | # The sigma activation for the NeRF. 61 | sigma_activation: types.Activation = nn.relu 62 | # Adds a skip connection every N layers. 63 | nerf_skips: Tuple[int] = (4,) 64 | # The number of alpha channels. 65 | alpha_channels: int = 1 66 | # The number of RGB channels. 67 | rgb_channels: int = 3 68 | # The number of positional encodings for points. 69 | num_nerf_point_freqs: int = 10 70 | # The number of positional encodings for viewdirs. 71 | num_nerf_viewdir_freqs: int = 4 72 | # The number of coarse samples along each ray. 73 | num_coarse_samples: int = 64 74 | # The number of fine samples along each ray. 75 | num_fine_samples: int = 128 76 | # Whether to use view directions. 77 | use_viewdirs: bool = True 78 | # Whether to condition the entire NeRF MLP. 79 | use_trunk_condition: bool = False 80 | # Whether to condition the density of the template NeRF. 81 | use_alpha_condition: bool = False 82 | # Whether to condition the RGB of the template NeRF. 83 | use_rgb_condition: bool = False 84 | 85 | # Whether to use the appearance metadata for the conditional branch. 86 | use_appearance_metadata: bool = False 87 | # The number of dimensions for the appearance metadata. 88 | appearance_metadata_dims: int = 8 89 | # Whether to use the camera metadata for the conditional branch. 90 | use_camera_metadata: bool = False 91 | # The number of dimensions for the camera metadata. 92 | camera_metadata_dims: int = 2 93 | 94 | # Whether to use the warp field. 
95 |   use_warp: bool = False
96 |   # The number of frequencies for the warp field.
97 |   num_warp_freqs: int = 8
98 |   # The number of dimensions for the warp metadata.
99 |   num_warp_features: int = 8
100 |   # The type of warp field to use. One of: 'translation', or 'se3'.
101 |   warp_field_type: str = 'translation'
102 |   # The type of metadata encoder the warp field should use.
103 |   warp_metadata_encoder_type: str = 'glo'
104 |   # Additional keyword arguments to pass to the warp field.
105 |   warp_kwargs: Mapping[str, Any] = immutabledict.immutabledict()
106 | 
107 | 
108 | @gin.configurable()
109 | @dataclasses.dataclass
110 | class ExperimentConfig:
111 |   """Experiment configuration."""
112 |   # A subname for the experiment, e.g., for parameter sweeps. If this is set,
113 |   # experiment artifacts will be saved to a subdirectory with this name.
114 |   subname: Optional[str] = None
115 |   # The image scale to use for the dataset. Should be a power of 2.
116 |   image_scale: int = 4
117 |   # The random seed used to initialize the RNGs for the experiment.
118 |   random_seed: int = 12345
119 |   # The type of datasource. Either 'nerfies' or 'dynamic_scene'.
120 |   datasource_type: str = 'nerfies'
121 |   # Data source specification.
122 |   datasource_spec: Optional[Mapping[str, Any]] = None
123 |   # Extra keyword arguments to pass to the datasource.
124 |   datasource_kwargs: Mapping[str, Any] = immutabledict.immutabledict()
125 | 
126 | 
127 | @gin.configurable()
128 | @dataclasses.dataclass
129 | class TrainConfig:
130 |   """Parameters for training."""
131 |   batch_size: int = gin.REQUIRED
132 | 
133 |   # The definition for the learning rate schedule.
134 |   lr_schedule: ScheduleDef = immutabledict.immutabledict({
135 |       'type': 'exponential',
136 |       'initial_value': 0.001,
137 |       'final_value': 0.0001,
138 |       'num_steps': 1000000,
139 |   })
140 |   # The maximum number of training steps.
141 |   max_steps: int = 1000000
142 | 
143 |   # The schedule for the warp alpha.
144 |   warp_alpha_schedule: ScheduleDef = immutabledict.immutabledict({
145 |       'type': 'linear',
146 |       'initial_value': 0.0,
147 |       'final_value': 8.0,
148 |       'num_steps': 80000,
149 |   })
150 | 
151 |   # The time encoder alpha schedule.
152 |   time_alpha_schedule: ScheduleDef = ('constant', 0.0)
153 | 
154 |   # Whether to use the elastic regularization loss.
155 |   use_elastic_loss: bool = False
156 |   # The weight of the elastic regularization loss.
157 |   elastic_loss_weight_schedule: ScheduleDef = ('constant', 0.0)
158 |   # Which method to use to reduce the samples for the elastic loss.
159 |   # 'weight' computes a weighted sum using the density weights, and 'median'
160 |   # selects the sample at the median depth point.
161 |   elastic_reduce_method: str = 'weight'
162 |   # Which loss method to use for the elastic loss.
163 |   elastic_loss_type: str = 'log_svals'
164 |   # Whether to use background regularization.
165 |   use_background_loss: bool = False
166 |   # The weight for the background loss.
167 |   background_loss_weight: float = 0.0
168 |   # The batch size for background regularization loss.
169 |   background_points_batch_size: int = 16384
170 |   # Whether to use the warp reg loss.
171 |   use_warp_reg_loss: bool = False
172 |   # The weight for the warp reg loss.
173 |   warp_reg_loss_weight: float = 0.0
174 |   # The alpha for the warp reg loss.
175 |   warp_reg_loss_alpha: float = -2.0
176 |   # The scale for the warp reg loss.
177 |   warp_reg_loss_scale: float = 0.001
178 | 
179 |   # The size of the shuffle buffer when shuffling the training dataset.
180 |   # This needs to be sufficiently large to contain a diverse set of images in
181 |   # each batch, especially when optimizing GLO embeddings.
182 |   shuffle_buffer_size: int = 5000000
183 |   # How often to save a checkpoint.
184 |   save_every: int = 10000
185 |   # How often to log to Tensorboard.
186 |   log_every: int = 500
187 |   # How often to log histograms to Tensorboard.
188 |   histogram_every: int = 5000
189 |   # How often to print to the console.
190 |   print_every: int = 25
191 | 
192 | 
193 | @gin.configurable()
194 | @dataclasses.dataclass
195 | class EvalConfig:
196 |   """Parameters for evaluation."""
197 |   # If True only evaluate the model once, otherwise evaluate any new
198 |   # checkpoints.
199 |   eval_once: bool = False
200 |   # If True save the predicted images to persistent storage.
201 |   save_output: bool = True
202 |   # The evaluation batch size.
203 |   chunk: int = 8192
204 |   # Max render checkpoints. The renders will rotate after this many.
205 |   max_render_checkpoints: int = 3
206 | 
207 |   # The number of validation examples to evaluate. None evaluates all.
208 |   num_val_eval: Optional[int] = 10
209 |   # The number of training examples to evaluate.
210 |   num_train_eval: Optional[int] = 10
211 |   # The number of test examples to evaluate.
212 |   num_test_eval: Optional[int] = 10
213 | 
--------------------------------------------------------------------------------
/nerfies/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Dataset definition and utility package."""
16 | from nerfies.datasets.core import *
17 | from nerfies.datasets.nerfies import NerfiesDataSource
18 | 
19 | 
20 | def from_config(spec, **kwargs):
21 |   """Create a datasource from a config specification."""
22 |   spec = dict(spec)
23 |   ds_type = spec.pop('type')
24 |   if ds_type == 'nerfies':
25 |     return NerfiesDataSource(**spec, **kwargs)
26 | 
27 |   raise ValueError(f'Unknown datasource type {ds_type!r}')
28 | 
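For reference, a minimal sketch of constructing a datasource through `from_config`. The path and `image_scale` are illustrative, and any additional keyword arguments required by `core.DataSource` are omitted here:

```python
from nerfies import datasets

# 'type' selects the datasource class; the remaining keys are forwarded to
# NerfiesDataSource as keyword arguments.
datasource = datasets.from_config({
    'type': 'nerfies',
    'data_dir': '/path/to/dataset',
    'image_scale': 4,
})
```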
--------------------------------------------------------------------------------
/nerfies/datasets/nerfies.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Casual Volumetric Capture datasets."""
16 | import json
17 | from typing import List, Tuple
18 | 
19 | from absl import logging
20 | import cv2
21 | import numpy as np
22 | 
23 | from nerfies import gpath
24 | from nerfies import types
25 | from nerfies import utils
26 | from nerfies.datasets import core
27 | 
28 | 
29 | def load_scene_info(
30 |     data_dir: types.PathType) -> Tuple[np.ndarray, float, float, float]:
31 |   """Loads the scene center, scale, near, and far plane from scene.json.
32 | 
33 |   Args:
34 |     data_dir: the path to the dataset.
35 | 
36 |   Returns:
37 |     scene_center: the center of the scene (unscaled coordinates).
38 |     scene_scale: the scale of the scene.
39 |     near: the near plane of the scene (scaled coordinates).
40 |     far: the far plane of the scene (scaled coordinates).
41 | 
42 |   Raises:
43 |     ValueError if scene.json does not exist.
44 |   """
45 |   scene_json_path = gpath.GPath(data_dir, 'scene.json')
46 |   with scene_json_path.open('r') as f:
47 |     scene_json = json.load(f)
48 | 
49 |   scene_center = np.array(scene_json['center'])
50 |   scene_scale = scene_json['scale']
51 |   near = scene_json['near']
52 |   far = scene_json['far']
53 | 
54 |   return scene_center, scene_scale, near, far
55 | 
56 | 
57 | def _load_image(path: types.PathType) -> np.ndarray:
58 |   path = gpath.GPath(path)
59 |   with path.open('rb') as f:
60 |     raw_im = np.asarray(bytearray(f.read()), dtype=np.uint8)
61 |     image = cv2.imdecode(raw_im, cv2.IMREAD_COLOR)[:, :, ::-1]  # BGR -> RGB
62 |     image = np.asarray(image).astype(np.float32) / 255.0
63 |     return image
64 | 
65 | 
66 | def _load_dataset_ids(data_dir: types.PathType) -> Tuple[List[str], List[str]]:
67 |   """Loads dataset IDs."""
68 |   dataset_json_path = gpath.GPath(data_dir, 'dataset.json')
69 |   logging.info('*** Loading dataset IDs from %s', dataset_json_path)
70 |   with dataset_json_path.open('r') as f:
71 |     dataset_json = json.load(f)
72 |     train_ids = dataset_json['train_ids']
73 |     val_ids = dataset_json['val_ids']
74 | 
75 |   train_ids = [str(i) for i in train_ids]
76 |   val_ids = [str(i) for i in val_ids]
77 | 
78 |   return train_ids, val_ids
79 | 
80 | 
81 | class NerfiesDataSource(core.DataSource):
82 |   """Data loader for videos."""
83 | 
84 |   def __init__(
85 |       self,
86 |       data_dir,
87 |       image_scale: int,
88 |       shuffle_pixels=False,
89 |       camera_type='json',
90 |       test_camera_trajectory='orbit-extreme',
91 |       **kwargs):
92 |     self.data_dir = gpath.GPath(data_dir)
93 |     # Load IDs from JSON if it exists. This is useful since COLMAP fails on
94 |     # some images so this gives us the ability to skip invalid images.
95 | train_ids, val_ids = _load_dataset_ids(self.data_dir) 96 | super().__init__(train_ids=train_ids, val_ids=val_ids, 97 | **kwargs) 98 | self.scene_center, self.scene_scale, self._near, self._far = \ 99 | load_scene_info(self.data_dir) 100 | self.test_camera_trajectory = test_camera_trajectory 101 | 102 | self.image_scale = image_scale 103 | self.shuffle_pixels = shuffle_pixels 104 | 105 | self.rgb_dir = gpath.GPath(data_dir, 'rgb', f'{image_scale}x') 106 | self.depth_dir = gpath.GPath(data_dir, 'depth', f'{image_scale}x') 107 | self.camera_type = camera_type 108 | self.camera_dir = gpath.GPath(data_dir, 'camera') 109 | 110 | metadata_path = self.data_dir / 'metadata.json' 111 | self.metadata_dict = None 112 | if metadata_path.exists(): 113 | with metadata_path.open('r') as f: 114 | self.metadata_dict = json.load(f) 115 | 116 | @property 117 | def near(self): 118 | return self._near 119 | 120 | @property 121 | def far(self): 122 | return self._far 123 | 124 | @property 125 | def camera_ext(self): 126 | if self.camera_type == 'json': 127 | return '.json' 128 | 129 | raise ValueError(f'Unknown camera_type {self.camera_type}') 130 | 131 | def get_rgb_path(self, item_id): 132 | return self.rgb_dir / f'{item_id}.png' 133 | 134 | def load_rgb(self, item_id): 135 | return _load_image(self.rgb_dir / f'{item_id}.png') 136 | 137 | def load_camera(self, item_id, scale_factor=1.0): 138 | if isinstance(item_id, gpath.GPath): 139 | camera_path = item_id 140 | else: 141 | if self.camera_type == 'json': 142 | camera_path = self.camera_dir / f'{item_id}{self.camera_ext}' 143 | else: 144 | raise ValueError(f'Unknown camera type {self.camera_type!r}.') 145 | 146 | return core.load_camera(camera_path, 147 | scale_factor=scale_factor / self.image_scale, 148 | scene_center=self.scene_center, 149 | scene_scale=self.scene_scale) 150 | 151 | def glob_cameras(self, path): 152 | path = gpath.GPath(path) 153 | return sorted(path.glob(f'*{self.camera_ext}')) 154 | 155 | def load_test_cameras(self, count=None): 156 | camera_dir = (self.data_dir / 'camera-paths' / self.test_camera_trajectory) 157 | if not camera_dir.exists(): 158 | logging.warning('test camera path does not exist: %s', str(camera_dir)) 159 | return [] 160 | camera_paths = sorted(camera_dir.glob(f'*{self.camera_ext}')) 161 | if count is not None: 162 | stride = max(1, len(camera_paths) // count) 163 | camera_paths = camera_paths[::stride] 164 | cameras = utils.parallel_map(self.load_camera, camera_paths) 165 | return cameras 166 | 167 | def load_points(self, shuffle=False): 168 | with (self.data_dir / 'points.npy').open('rb') as f: 169 | points = np.load(f) 170 | points = (points - self.scene_center) * self.scene_scale 171 | points = points.astype(np.float32) 172 | if shuffle: 173 | logging.info('Shuffling points.') 174 | shuffled_inds = self.rng.permutation(len(points)) 175 | points = points[shuffled_inds] 176 | logging.info('Loaded %d points.', len(points)) 177 | return points 178 | 179 | def get_appearance_id(self, item_id): 180 | return self.metadata_dict[item_id]['appearance_id'] 181 | 182 | def get_camera_id(self, item_id): 183 | return self.metadata_dict[item_id]['camera_id'] 184 | 185 | def get_warp_id(self, item_id): 186 | return self.metadata_dict[item_id]['warp_id'] 187 | 188 | def get_time_id(self, item_id): 189 | if 'time_id' in self.metadata_dict[item_id]: 190 | return self.metadata_dict[item_id]['time_id'] 191 | else: 192 | # Fallback for older datasets. 
193 | return self.metadata_dict[item_id]['warp_id'] 194 | -------------------------------------------------------------------------------- /nerfies/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for evaluating a trained NeRF.""" 16 | import math 17 | import time 18 | 19 | from absl import logging 20 | from flax import jax_utils 21 | import jax 22 | from jax import tree_util 23 | import jax.numpy as jnp 24 | 25 | from nerfies import utils 26 | 27 | 28 | def render_image( 29 | state, 30 | rays_dict, 31 | model_fn, 32 | device_count, 33 | rng, 34 | chunk=8192, 35 | default_ret_key=None): 36 | """Render all the pixels of an image (in test mode). 37 | 38 | Args: 39 | state: model_utils.TrainState. 40 | rays_dict: dict, test example. 41 | model_fn: function, jit-ed render function. 42 | device_count: The number of devices to shard batches over. 43 | rng: The random number generator. 44 | chunk: int, the size of chunks to render sequentially. 45 | default_ret_key: either 'fine' or 'coarse'. If None will default to highest. 46 | 47 | Returns: 48 | rgb: jnp.ndarray, rendered color image. 49 | depth: jnp.ndarray, rendered depth. 50 | acc: jnp.ndarray, rendered accumulated weights per pixel. 51 | """ 52 | h, w = rays_dict['origins'].shape[:2] 53 | rays_dict = tree_util.tree_map(lambda x: x.reshape((h * w, -1)), rays_dict) 54 | num_rays = h * w 55 | _, key_0, key_1 = jax.random.split(rng, 3) 56 | key_0 = jax.random.split(key_0, device_count) 57 | key_1 = jax.random.split(key_1, device_count) 58 | host_id = jax.process_index() 59 | ret_maps = [] 60 | start_time = time.time() 61 | num_batches = int(math.ceil(num_rays / chunk)) 62 | for batch_idx in range(num_batches): 63 | ray_idx = batch_idx * chunk 64 | logging.log_every_n_seconds( 65 | logging.INFO, 'Rendering batch %d/%d (%d/%d)', 2.0, 66 | batch_idx, num_batches, ray_idx, num_rays) 67 | # pylint: disable=cell-var-from-loop 68 | chunk_slice_fn = lambda x: x[ray_idx:ray_idx + chunk] 69 | chunk_rays_dict = tree_util.tree_map(chunk_slice_fn, rays_dict) 70 | num_chunk_rays = chunk_rays_dict['origins'].shape[0] 71 | remainder = num_chunk_rays % device_count 72 | if remainder != 0: 73 | padding = device_count - remainder 74 | # pylint: disable=cell-var-from-loop 75 | chunk_pad_fn = lambda x: jnp.pad(x, ((0, padding), (0, 0)), mode='edge') 76 | chunk_rays_dict = tree_util.tree_map(chunk_pad_fn, chunk_rays_dict) 77 | else: 78 | padding = 0 79 | # After padding the number of chunk_rays is always divisible by 80 | # host_count. 
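    # Editor's note -- a worked example of the padding arithmetic above
    # (numbers illustrative only): a final partial chunk of 5001 rays on
    # device_count=8 leaves remainder 1, so padding = 7 edge-replicated rays
    # brings it to 5008 = 8 * 626, which shards evenly. The padded rays are
    # stripped again by utils.unshard(x, padding) below.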
81 |     per_host_rays = num_chunk_rays // jax.process_count()
82 |     chunk_rays_dict = tree_util.tree_map(
83 |         lambda x: x[(host_id * per_host_rays):((host_id + 1) * per_host_rays)],
84 |         chunk_rays_dict)
85 |     chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
86 |     model_out = model_fn(key_0, key_1, state.optimizer.target['model'],
87 |                          chunk_rays_dict, state.warp_extra)
88 |     if not default_ret_key:
89 |       ret_key = 'fine' if 'fine' in model_out else 'coarse'
90 |     else:
91 |       ret_key = default_ret_key
92 |     ret_map = jax_utils.unreplicate(model_out[ret_key])
93 |     ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map)
94 |     ret_maps.append(ret_map)
95 |   ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps)
96 |   logging.info('Rendering took %.4f seconds.', time.time() - start_time)
97 |   out = {}
98 |   for key, value in ret_map.items():
99 |     out[key] = value.reshape((h, w, *value.shape[1:]))
100 | 
101 |   return out
102 | 
--------------------------------------------------------------------------------
/nerfies/glo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """A module to help create embeddings in Jax."""
16 | from flax import linen as nn
17 | import jax.numpy as jnp
18 | 
19 | from nerfies import types
20 | 
21 | 
22 | class GloEncoder(nn.Module):
23 |   """A GLO encoder module, which is just a thin wrapper around nn.Embed.
24 | 
25 |   Attributes:
26 |     num_embeddings: The number of embeddings.
27 |     features: The dimensions of each embedding.
28 |     embedding_init: The initializer to use for each embedding.
29 |   """
30 | 
31 |   num_embeddings: int
32 |   features: int
33 |   embedding_init: types.Initializer = nn.initializers.uniform(scale=0.05)
34 | 
35 |   def setup(self):
36 |     self.embed = nn.Embed(
37 |         num_embeddings=self.num_embeddings,
38 |         features=self.features,
39 |         embedding_init=self.embedding_init)
40 | 
41 |   def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
42 |     """Method to get embeddings for specified indices.
43 | 
44 |     Args:
45 |       inputs: The indices to fetch embeddings for.
46 | 
47 |     Returns:
48 |       The embeddings corresponding to the indices provided.
49 |     """
50 |     if inputs.shape[-1] == 1:
51 |       inputs = jnp.squeeze(inputs, axis=-1)
52 | 
53 |     return self.embed(inputs)
54 | 
--------------------------------------------------------------------------------
/nerfies/gpath.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A thin wrapper around pathlib.""" 16 | import pathlib 17 | import tensorflow as tf 18 | 19 | 20 | class GPath(pathlib.PurePosixPath): 21 | """A thin wrapper around PurePath to support various filesystems.""" 22 | 23 | def open(self, *args, **kwargs): 24 | return tf.io.gfile.GFile(self, *args, **kwargs) 25 | 26 | def exists(self): 27 | return tf.io.gfile.exists(self) 28 | 29 | # pylint: disable=unused-argument 30 | def mkdir(self, mode=0o777, parents=False, exist_ok=False): 31 | if not exist_ok: 32 | if self.exists(): 33 | raise FileExistsError('Directory already exists.') 34 | 35 | if parents: 36 | return tf.io.gfile.makedirs(self) 37 | else: 38 | return tf.io.gfile.mkdir(self) 39 | 40 | def glob(self, pattern): 41 | return [GPath(x) for x in tf.io.gfile.glob(str(self / pattern))] 42 | 43 | def iterdir(self): 44 | return [GPath(self, x) for x in tf.io.gfile.listdir(self)] 45 | 46 | def is_dir(self): 47 | return tf.io.gfile.isdir(self) 48 | 49 | def rmtree(self): 50 | tf.io.gfile.rmtree(self) 51 | -------------------------------------------------------------------------------- /nerfies/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
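# Editor's note -- a minimal usage sketch for the GPath class defined in
# nerfies/gpath.py above (paths are hypothetical, illustrative only). GPath
# keeps pathlib-style path composition but routes all I/O through
# tf.io.gfile, so the same code works on local disk and on remote
# filesystems such as GCS:
#
#   from nerfies.gpath import GPath
#
#   root = GPath('/tmp/capture')      # or e.g. GPath('gs://bucket/capture')
#   root.mkdir(parents=True, exist_ok=True)
#   with (root / 'scene.json').open('w') as f:
#     f.write('{}')
#   assert (root / 'scene.json').exists()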
14 | 15 | """Image-related utility functions.""" 16 | import math 17 | from typing import Tuple 18 | 19 | from absl import logging 20 | import cv2 21 | import imageio 22 | import numpy as np 23 | from PIL import Image 24 | 25 | from nerfies import gpath 26 | from nerfies import types 27 | 28 | 29 | UINT8_MAX = 255 30 | UINT16_MAX = 65535 31 | 32 | 33 | def make_divisible(image: np.ndarray, divisor: int) -> np.ndarray: 34 | """Trim the image if not divisible by the divisor.""" 35 | height, width = image.shape[:2] 36 | if height % divisor == 0 and width % divisor == 0: 37 | return image 38 | 39 | new_height = height - height % divisor 40 | new_width = width - width % divisor 41 | 42 | return image[:new_height, :new_width] 43 | 44 | 45 | def downsample_image(image: np.ndarray, scale: int) -> np.ndarray: 46 | """Downsamples the image by an integer factor to prevent artifacts.""" 47 | if scale == 1: 48 | return image 49 | 50 | height, width = image.shape[:2] 51 | if height % scale > 0 or width % scale > 0: 52 | raise ValueError(f'Image shape ({height},{width}) must be divisible by the' 53 | f' scale ({scale}).') 54 | out_height, out_width = height // scale, width // scale 55 | resized = cv2.resize(image, (out_width, out_height), cv2.INTER_AREA) 56 | return resized 57 | 58 | 59 | def upsample_image(image: np.ndarray, scale: int) -> np.ndarray: 60 | """Upsamples the image by an integer factor.""" 61 | if scale == 1: 62 | return image 63 | 64 | height, width = image.shape[:2] 65 | out_height, out_width = height * scale, width * scale 66 | resized = cv2.resize(image, (out_width, out_height), cv2.INTER_AREA) 67 | return resized 68 | 69 | 70 | def reshape_image(image: np.ndarray, shape: Tuple[int, int]) -> np.ndarray: 71 | """Reshapes the image to the given shape.""" 72 | out_height, out_width = shape 73 | return cv2.resize( 74 | image, (out_width, out_height), interpolation=cv2.INTER_AREA) 75 | 76 | 77 | def rescale_image(image: np.ndarray, scale_factor: float) -> np.ndarray: 78 | """Resize an image by a scale factor, using integer resizing if possible.""" 79 | scale_factor = float(scale_factor) 80 | if scale_factor <= 0.0: 81 | raise ValueError('scale_factor must be a non-negative number.') 82 | 83 | if scale_factor == 1.0: 84 | return image 85 | 86 | height, width = image.shape[:2] 87 | if scale_factor.is_integer(): 88 | return upsample_image(image, int(scale_factor)) 89 | 90 | inv_scale = 1.0 / scale_factor 91 | if (inv_scale.is_integer() and (scale_factor * height).is_integer() and 92 | (scale_factor * width).is_integer()): 93 | return downsample_image(image, int(inv_scale)) 94 | 95 | logging.warning( 96 | 'resizing image by non-integer factor %f, this may lead to artifacts.', 97 | scale_factor) 98 | 99 | height, width = image.shape[:2] 100 | out_height = math.ceil(height * scale_factor) 101 | out_height -= out_height % 2 102 | out_width = math.ceil(width * scale_factor) 103 | out_width -= out_width % 2 104 | 105 | return reshape_image(image, (out_height, out_width)) 106 | 107 | 108 | def variance_of_laplacian(image: np.ndarray) -> np.ndarray: 109 | """Compute the variance of the Laplacian which measure the focus.""" 110 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 111 | return cv2.Laplacian(gray, cv2.CVX_64F).var() 112 | 113 | 114 | def image_to_uint8(image: np.ndarray) -> np.ndarray: 115 | """Convert the image to a uint8 array.""" 116 | if image.dtype == np.uint8: 117 | return image 118 | if not issubclass(image.dtype.type, np.floating): 119 | raise ValueError( 120 | f'Input image should 
be a floating type but is of type {image.dtype!r}') 121 | return (image * UINT8_MAX).clip(0.0, UINT8_MAX).astype(np.uint8) 122 | 123 | 124 | def image_to_uint16(image: np.ndarray) -> np.ndarray: 125 | """Convert the image to a uint16 array.""" 126 | if image.dtype == np.uint16: 127 | return image 128 | if not issubclass(image.dtype.type, np.floating): 129 | raise ValueError( 130 | f'Input image should be a floating type but is of type {image.dtype!r}') 131 | return (image * UINT16_MAX).clip(0.0, UINT16_MAX).astype(np.uint16) 132 | 133 | 134 | def image_to_float32(image: np.ndarray) -> np.ndarray: 135 | """Convert the image to a float32 array and scale values appropriately.""" 136 | if image.dtype == np.float32: 137 | return image 138 | 139 | dtype = image.dtype 140 | image = image.astype(np.float32) 141 | if dtype == np.uint8: 142 | return image / UINT8_MAX 143 | elif dtype == np.uint16: 144 | return image / UINT16_MAX 145 | elif dtype == np.float64: 146 | return image 147 | elif dtype == np.float16: 148 | return image 149 | 150 | raise ValueError(f'Not sure how to handle dtype {dtype}') 151 | 152 | 153 | def load_image(path: types.PathType) -> np.ndarray: 154 | """Reads an image.""" 155 | if not isinstance(path, gpath.GPath): 156 | path = gpath.GPath(path) 157 | 158 | with path.open('rb') as f: 159 | return imageio.imread(f) 160 | 161 | 162 | def save_image(path: types.PathType, image: np.ndarray) -> None: 163 | """Saves the image to disk or gfile.""" 164 | if not isinstance(path, gpath.GPath): 165 | path = gpath.GPath(path) 166 | 167 | with path.open('wb') as f: 168 | image = Image.fromarray(np.asarray(image)) 169 | image.save(f, format=path.suffix.lstrip('.')) 170 | 171 | 172 | def save_depth(path: types.PathType, depth: np.ndarray) -> None: 173 | save_image(path, image_to_uint16(depth / 1000.0)) 174 | 175 | 176 | def load_depth(path: types.PathType) -> np.ndarray: 177 | depth = load_image(path) 178 | if depth.dtype != np.uint16: 179 | raise ValueError('Depth image must be of type uint16.') 180 | return image_to_float32(depth) * 1000.0 181 | 182 | 183 | def checkerboard(h, w, size=8): 184 | """Creates a checkerboard pattern with height h and width w.""" 185 | i = int(math.ceil(h / (size * 2))) 186 | j = int(math.ceil(w / (size * 2))) 187 | pattern = np.kron([[1, 0] * j, [0, 1] * j] * i, 188 | np.ones((size, size)))[:h, :w] 189 | return np.clip(pattern + 0.8, 0.0, 1.0) 190 | -------------------------------------------------------------------------------- /nerfies/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
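# Editor's note -- a small usage sketch for the image utilities above
# (shapes and values illustrative only). rescale_image() prefers exact
# integer paths over generic interpolation:
#
#   import numpy as np
#   from nerfies import image_utils
#
#   image = np.random.uniform(size=(480, 640, 3)).astype(np.float32)
#   small = image_utils.rescale_image(image, 0.5)  # integer downsample path
#   big = image_utils.rescale_image(small, 2.0)    # integer upsample path
#   as_uint8 = image_utils.image_to_uint8(big)     # [0, 1] floats -> uint8
#   back = image_utils.image_to_float32(as_uint8)  # back to [0, 1] floats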
14 | 15 | """Helper functions/classes for model definition.""" 16 | 17 | from flax import linen as nn 18 | from flax import optim 19 | from flax import struct 20 | from jax import lax 21 | from jax import random 22 | import jax.numpy as jnp 23 | 24 | 25 | @struct.dataclass 26 | class TrainState: 27 | optimizer: optim.Optimizer 28 | warp_alpha: jnp.ndarray = 0.0 29 | time_alpha: jnp.ndarray = 0.0 30 | 31 | @property 32 | def warp_extra(self): 33 | return {'alpha': self.warp_alpha, 'time_alpha': self.time_alpha} 34 | 35 | 36 | def sample_along_rays(key, origins, directions, num_coarse_samples, near, far, 37 | use_stratified_sampling, use_linear_disparity): 38 | """Stratified sampling along the rays. 39 | 40 | Args: 41 | key: jnp.ndarray, random generator key. 42 | origins: ray origins. 43 | directions: ray directions. 44 | num_coarse_samples: int. 45 | near: float, near clip. 46 | far: float, far clip. 47 | use_stratified_sampling: use stratified sampling. 48 | use_linear_disparity: sampling linearly in disparity rather than depth. 49 | 50 | Returns: 51 | z_vals: jnp.ndarray, [batch_size, num_coarse_samples], sampled z values. 52 | points: jnp.ndarray, [batch_size, num_coarse_samples, 3], sampled points. 53 | """ 54 | batch_size = origins.shape[0] 55 | 56 | t_vals = jnp.linspace(0., 1., num_coarse_samples) 57 | if not use_linear_disparity: 58 | z_vals = near * (1. - t_vals) + far * t_vals 59 | else: 60 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals) 61 | if use_stratified_sampling: 62 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 63 | upper = jnp.concatenate([mids, z_vals[..., -1:]], -1) 64 | lower = jnp.concatenate([z_vals[..., :1], mids], -1) 65 | t_rand = random.uniform(key, [batch_size, num_coarse_samples]) 66 | z_vals = lower + (upper - lower) * t_rand 67 | else: 68 | # Broadcast z_vals to make the returned shape consistent. 69 | z_vals = jnp.broadcast_to(z_vals[None, ...], 70 | [batch_size, num_coarse_samples]) 71 | 72 | return (z_vals, (origins[..., None, :] + 73 | z_vals[..., :, None] * directions[..., None, :])) 74 | 75 | 76 | def volumetric_rendering(rgb, 77 | sigma, 78 | z_vals, 79 | dirs, 80 | use_white_background, 81 | sample_at_infinity=True, 82 | return_weights=False, 83 | eps=1e-10): 84 | """Volumetric Rendering Function. 85 | 86 | Args: 87 | rgb: an array of size (B,S,3) containing the RGB color values. 88 | sigma: an array of size (B,S,1) containing the densities. 89 | z_vals: an array of size (B,S) containing the z-coordinate of the samples. 90 | dirs: an array of size (B,3) containing the directions of rays. 91 | use_white_background: whether to assume a white background or not. 92 | sample_at_infinity: if True adds a sample at infinity. 93 | return_weights: if True returns the weights in the dictionary. 94 | eps: a small number to prevent numerical issues. 95 | 96 | Returns: 97 | A dictionary containing: 98 | rgb: an array of size (B,3) containing the rendered colors. 99 | depth: an array of size (B,) containing the rendered depth. 100 | acc: an array of size (B,) containing the accumulated density. 101 | weights: an array of size (B,S) containing the weight of each sample. 102 | """ 103 | # TODO(keunhong): remove this hack. 
104 | last_sample_z = 1e10 if sample_at_infinity else 1e-19 105 | dists = jnp.concatenate([ 106 | z_vals[..., 1:] - z_vals[..., :-1], 107 | jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape) 108 | ], -1) 109 | dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1) 110 | alpha = 1.0 - jnp.exp(-sigma * dists) 111 | # Prepend a 1.0 to make this an 'exclusive' cumprod as in `tf.math.cumprod`. 112 | accum_prod = jnp.concatenate([ 113 | jnp.ones_like(alpha[..., :1], alpha.dtype), 114 | jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1), 115 | ], axis=-1) 116 | weights = alpha * accum_prod 117 | 118 | rgb = (weights[..., None] * rgb).sum(axis=-2) 119 | exp_depth = (weights * z_vals).sum(axis=-1) 120 | med_depth = compute_depth_map(weights, z_vals) 121 | acc = weights.sum(axis=-1) 122 | if use_white_background: 123 | rgb = rgb + (1. - acc[..., None]) 124 | 125 | if sample_at_infinity: 126 | acc = weights[..., :-1].sum(axis=-1) 127 | 128 | out = { 129 | 'rgb': rgb, 130 | 'depth': exp_depth, 131 | 'med_depth': med_depth, 132 | 'acc': acc, 133 | } 134 | if return_weights: 135 | out['weights'] = weights 136 | return out 137 | 138 | 139 | def piecewise_constant_pdf(key, bins, weights, num_coarse_samples, 140 | use_stratified_sampling): 141 | """Piecewise-Constant PDF sampling. 142 | 143 | Args: 144 | key: jnp.ndarray(float32), [2,], random number generator. 145 | bins: jnp.ndarray(float32), [batch_size, n_bins + 1]. 146 | weights: jnp.ndarray(float32), [batch_size, n_bins]. 147 | num_coarse_samples: int, the number of samples. 148 | use_stratified_sampling: bool, use use_stratified_sampling samples. 149 | 150 | Returns: 151 | z_samples: jnp.ndarray(float32), [batch_size, num_coarse_samples]. 152 | """ 153 | eps = 1e-5 154 | 155 | # Get pdf 156 | weights += eps # prevent nans 157 | pdf = weights / weights.sum(axis=-1, keepdims=True) 158 | cdf = jnp.cumsum(pdf, axis=-1) 159 | cdf = jnp.concatenate([jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf], axis=-1) 160 | 161 | # Take uniform samples 162 | if use_stratified_sampling: 163 | u = random.uniform(key, list(cdf.shape[:-1]) + [num_coarse_samples]) 164 | else: 165 | u = jnp.linspace(0., 1., num_coarse_samples) 166 | u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_coarse_samples]) 167 | 168 | # Invert CDF. This takes advantage of the fact that `bins` is sorted. 169 | mask = (u[..., None, :] >= cdf[..., :, None]) 170 | 171 | def minmax(x): 172 | x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2) 173 | x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2) 174 | x0 = jnp.minimum(x0, x[..., -2:-1]) 175 | x1 = jnp.maximum(x1, x[..., 1:2]) 176 | return x0, x1 177 | 178 | bins_g0, bins_g1 = minmax(bins) 179 | cdf_g0, cdf_g1 = minmax(cdf) 180 | 181 | denom = (cdf_g1 - cdf_g0) 182 | denom = jnp.where(denom < eps, 1., denom) 183 | t = (u - cdf_g0) / denom 184 | z_samples = bins_g0 + t * (bins_g1 - bins_g0) 185 | 186 | # Prevent gradient from backprop-ing through samples 187 | return lax.stop_gradient(z_samples) 188 | 189 | 190 | def sample_pdf(key, bins, weights, origins, directions, z_vals, 191 | num_coarse_samples, use_stratified_sampling): 192 | """Hierarchical sampling. 193 | 194 | Args: 195 | key: jnp.ndarray(float32), [2,], random number generator. 196 | bins: jnp.ndarray(float32), [batch_size, n_bins + 1]. 197 | weights: jnp.ndarray(float32), [batch_size, n_bins]. 198 | origins: ray origins. 199 | directions: ray directions. 200 | z_vals: jnp.ndarray(float32), [batch_size, n_coarse_samples]. 
201 |     num_coarse_samples: int, the number of new samples to draw (these
202 |       become the fine samples). use_stratified_sampling: bool, whether to use stratified sampling.
203 | 
204 |   Returns:
205 |     z_vals: jnp.ndarray(float32),
206 |       [batch_size, n_coarse_samples + num_fine_samples].
207 |     points: jnp.ndarray(float32),
208 |       [batch_size, n_coarse_samples + num_fine_samples, 3].
209 |   """
210 |   z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
211 |                                      use_stratified_sampling)
212 |   # Combine the coarse and fine z_vals and compute the sample points.
213 |   z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
214 |   return z_vals, (
215 |       origins[..., None, :] + z_vals[..., None] * directions[..., None, :])
216 | 
217 | 
218 | def compute_opaqueness_mask(weights, depth_threshold=0.5):
219 |   """Computes a mask which will be 1.0 at the depth point.
220 | 
221 |   Args:
222 |     weights: the density weights from NeRF.
223 |     depth_threshold: the accumulation threshold which will be used as the depth
224 |       termination point.
225 | 
226 |   Returns:
227 |     A tensor containing a mask with the same size as weights that has one
228 |     element along the sample dimension that is 1.0. This element marks the
229 |     point where the 'surface' is.
230 |   """
231 |   cumulative_contribution = jnp.cumsum(weights, axis=-1)
232 |   depth_threshold = jnp.array(depth_threshold, dtype=weights.dtype)
233 |   opaqueness = cumulative_contribution >= depth_threshold
234 |   false_padding = jnp.zeros_like(opaqueness[..., :1])
235 |   padded_opaqueness = jnp.concatenate(
236 |       [false_padding, opaqueness[..., :-1]], axis=-1)
237 |   opaqueness_mask = jnp.logical_xor(opaqueness, padded_opaqueness)
238 |   opaqueness_mask = opaqueness_mask.astype(weights.dtype)
239 |   return opaqueness_mask
240 | 
241 | 
242 | def compute_depth_index(weights, depth_threshold=0.5):
243 |   """Compute the sample index of the median depth accumulation."""
244 |   opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold)
245 |   return jnp.argmax(opaqueness_mask, axis=-1)
246 | 
247 | 
248 | def compute_depth_map(weights, z_vals, depth_threshold=0.5):
249 |   """Compute the depth using the median accumulation.
250 | 
251 |   Note that this differs from the depth computation in NeRF-W's codebase!
252 | 
253 |   Args:
254 |     weights: the density weights from NeRF.
255 |     z_vals: the z coordinates of the samples.
256 |     depth_threshold: the accumulation threshold which will be used as the depth
257 |       termination point.
258 | 
259 |   Returns:
260 |     A tensor containing the depth of each input pixel.
261 |   """
262 |   opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold)
263 |   return jnp.sum(opaqueness_mask * z_vals, axis=-1)
264 | 
265 | 
266 | def noise_regularize(key, raw, noise_std, use_stratified_sampling):
267 |   """Regularizes the density prediction by adding Gaussian noise.
268 | 
269 |   Args:
270 |     key: jnp.ndarray(float32), [2,], random number generator.
271 |     raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4].
272 |     noise_std: float, std dev of noise added to regularize sigma output.
273 |     use_stratified_sampling: add noise only if use_stratified_sampling is True.
274 | 
275 |   Returns:
276 |     raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4], updated raw.
277 | """ 278 | if (noise_std is not None) and noise_std > 0.0 and use_stratified_sampling: 279 | unused_key, key = random.split(key) 280 | noise = random.normal(key, raw[..., 3:4].shape, dtype=raw.dtype) * noise_std 281 | raw = jnp.concatenate([raw[..., :3], raw[..., 3:4] + noise], axis=-1) 282 | return raw 283 | 284 | 285 | def broadcast_feature_to(array: jnp.ndarray, shape: jnp.shape): 286 | """Matches the shape dimensions (everything except the channel dims). 287 | 288 | This is useful when you watch to match the shape of two features that have 289 | a different number of channels. 290 | 291 | Args: 292 | array: the array to broadcast. 293 | shape: the shape to broadcast the tensor to. 294 | 295 | Returns: 296 | The broadcasted tensor. 297 | """ 298 | out_shape = (*shape[:-1], array.shape[-1]) 299 | return jnp.broadcast_to(array, out_shape) 300 | 301 | 302 | def metadata_like(rays, metadata_id): 303 | """Create a metadata array like a ray batch.""" 304 | return jnp.full_like(rays[..., :1], fill_value=metadata_id, dtype=jnp.uint32) 305 | 306 | 307 | def vmap_module(module, in_axes=0, out_axes=0, num_batch_dims=1): 308 | """Vectorize a module. 309 | 310 | Args: 311 | module: the module to vectorize. 312 | in_axes: the `in_axes` argument passed to vmap. See `jax.vmap`. 313 | out_axes: the `out_axes` argument passed to vmap. See `jax.vmap`. 314 | num_batch_dims: the number of batch dimensions (how many times to apply vmap 315 | to the module). 316 | 317 | Returns: 318 | A vectorized module. 319 | """ 320 | for _ in range(num_batch_dims): 321 | module = nn.vmap( 322 | module, 323 | variable_axes={'params': None}, 324 | split_rngs={'params': False}, 325 | in_axes=in_axes, 326 | out_axes=out_axes) 327 | 328 | return module 329 | 330 | 331 | def identity_initializer(_, shape): 332 | max_shape = max(shape) 333 | return jnp.eye(max_shape)[:shape[0], :shape[1]] 334 | -------------------------------------------------------------------------------- /nerfies/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Modules for NeRF models.""" 16 | import functools 17 | from typing import Optional, Tuple 18 | 19 | from flax import linen as nn 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from nerfies import types 24 | 25 | 26 | class MLP(nn.Module): 27 | """Basic MLP class with hidden layers and an output layers.""" 28 | depth: int 29 | width: int 30 | hidden_init: types.Initializer = nn.initializers.xavier_uniform() 31 | hidden_activation: types.Activation = nn.relu 32 | output_init: Optional[types.Initializer] = None 33 | output_channels: int = 0 34 | output_activation: Optional[types.Activation] = lambda x: x 35 | use_bias: bool = True 36 | skips: Tuple[int] = tuple() 37 | 38 | @nn.compact 39 | def __call__(self, x): 40 | inputs = x 41 | for i in range(self.depth): 42 | layer = nn.Dense( 43 | self.width, 44 | use_bias=self.use_bias, 45 | kernel_init=self.hidden_init, 46 | name=f'hidden_{i}') 47 | if i in self.skips: 48 | x = jnp.concatenate([x, inputs], axis=-1) 49 | x = layer(x) 50 | x = self.hidden_activation(x) 51 | 52 | if self.output_channels > 0: 53 | logit_layer = nn.Dense( 54 | self.output_channels, 55 | use_bias=self.use_bias, 56 | kernel_init=self.output_init, 57 | name='logit') 58 | x = logit_layer(x) 59 | if self.output_activation is not None: 60 | x = self.output_activation(x) 61 | 62 | return x 63 | 64 | 65 | class NerfMLP(nn.Module): 66 | """A simple MLP. 67 | 68 | Attributes: 69 | nerf_trunk_depth: int, the depth of the first part of MLP. 70 | nerf_trunk_width: int, the width of the first part of MLP. 71 | nerf_rgb_branch_depth: int, the depth of the second part of MLP. 72 | nerf_rgb_branch_width: int, the width of the second part of MLP. 73 | activation: function, the activation function used in the MLP. 74 | skips: which layers to add skip layers to. 75 | alpha_channels: int, the number of alpha_channelss. 76 | rgb_channels: int, the number of rgb_channelss. 77 | condition_density: if True put the condition at the begining which 78 | conditions the density of the field. 79 | """ 80 | trunk_depth: int = 8 81 | trunk_width: int = 256 82 | 83 | rgb_branch_depth: int = 1 84 | rgb_branch_width: int = 128 85 | rgb_channels: int = 3 86 | 87 | alpha_branch_depth: int = 0 88 | alpha_branch_width: int = 128 89 | alpha_channels: int = 1 90 | 91 | activation: types.Activation = nn.relu 92 | skips: Tuple[int] = (4,) 93 | 94 | @nn.compact 95 | def __call__(self, x, trunk_condition, alpha_condition, rgb_condition): 96 | """Multi-layer perception for nerf. 97 | 98 | Args: 99 | x: sample points with shape [batch, num_coarse_samples, feature]. 100 | trunk_condition: a condition array provided to the trunk. 101 | alpha_condition: a condition array provided to the alpha branch. 102 | rgb_condition: a condition array provided in the RGB branch. 103 | 104 | Returns: 105 | raw: [batch, num_coarse_samples, rgb_channels+alpha_channels]. 106 | """ 107 | dense = functools.partial( 108 | nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform()) 109 | 110 | feature_dim = x.shape[-1] 111 | num_samples = x.shape[1] 112 | x = x.reshape([-1, feature_dim]) 113 | 114 | def broadcast_condition(c): 115 | # Broadcast condition from [batch, feature] to 116 | # [batch, num_coarse_samples, feature] since all the samples along the 117 | # same ray has the same viewdir. 118 | c = jnp.tile(c[:, None, :], (1, num_samples, 1)) 119 | # Collapse the [batch, num_coarse_samples, feature] tensor to 120 | # [batch * num_coarse_samples, feature] to be fed into nn.Dense. 
121 | c = c.reshape([-1, c.shape[-1]]) 122 | return c 123 | 124 | trunk_mlp = MLP(depth=self.trunk_depth, 125 | width=self.trunk_width, 126 | hidden_activation=self.activation, 127 | hidden_init=jax.nn.initializers.glorot_uniform(), 128 | skips=self.skips) 129 | rgb_mlp = MLP(depth=self.rgb_branch_depth, 130 | width=self.rgb_branch_width, 131 | hidden_activation=self.activation, 132 | hidden_init=jax.nn.initializers.glorot_uniform(), 133 | output_init=jax.nn.initializers.glorot_uniform(), 134 | output_channels=self.rgb_channels) 135 | alpha_mlp = MLP(depth=self.alpha_branch_depth, 136 | width=self.alpha_branch_width, 137 | hidden_activation=self.activation, 138 | hidden_init=jax.nn.initializers.glorot_uniform(), 139 | output_init=jax.nn.initializers.glorot_uniform(), 140 | output_channels=self.alpha_channels) 141 | 142 | if trunk_condition is not None: 143 | trunk_condition = broadcast_condition(trunk_condition) 144 | trunk_input = jnp.concatenate([x, trunk_condition], axis=-1) 145 | else: 146 | trunk_input = x 147 | x = trunk_mlp(trunk_input) 148 | 149 | if (alpha_condition is not None) or (rgb_condition is not None): 150 | bottleneck = dense(self.trunk_width, name='bottleneck')(x) 151 | 152 | if alpha_condition is not None: 153 | alpha_condition = broadcast_condition(alpha_condition) 154 | alpha_input = jnp.concatenate([bottleneck, alpha_condition], axis=-1) 155 | else: 156 | alpha_input = x 157 | alpha = alpha_mlp(alpha_input) 158 | 159 | if rgb_condition is not None: 160 | rgb_condition = broadcast_condition(rgb_condition) 161 | rgb_input = jnp.concatenate([bottleneck, rgb_condition], axis=-1) 162 | else: 163 | rgb_input = x 164 | rgb = rgb_mlp(rgb_input) 165 | 166 | return { 167 | 'rgb': rgb.reshape((-1, num_samples, self.rgb_channels)), 168 | 'alpha': alpha.reshape((-1, num_samples, self.alpha_channels)), 169 | } 170 | 171 | 172 | class SinusoidalEncoder(nn.Module): 173 | """A vectorized sinusoidal encoding. 174 | 175 | Attributes: 176 | num_freqs: the number of frequency bands in the encoding. 177 | min_freq_log2: the log (base 2) of the lower frequency. 178 | max_freq_log2: the log (base 2) of the upper frequency. 179 | scale: a scaling factor for the positional encoding. 180 | use_identity: if True use the identity encoding as well. 181 | """ 182 | num_freqs: int 183 | min_freq_log2: int = 0 184 | max_freq_log2: Optional[int] = None 185 | scale: float = 1.0 186 | use_identity: bool = True 187 | 188 | def setup(self): 189 | if self.max_freq_log2 is None: 190 | max_freq_log2 = self.num_freqs - 1.0 191 | else: 192 | max_freq_log2 = self.max_freq_log2 193 | self.freq_bands = 2.0**jnp.linspace(self.min_freq_log2, 194 | max_freq_log2, 195 | int(self.num_freqs)) 196 | 197 | # (F, 1). 198 | self.freqs = jnp.reshape(self.freq_bands, (self.num_freqs, 1)) 199 | 200 | def __call__(self, x, alpha: Optional[float] = None): 201 | """A vectorized sinusoidal encoding. 202 | 203 | Args: 204 | x: the input features to encode. 205 | alpha: a dummy argument for API compatibility. 206 | 207 | Returns: 208 | A tensor containing the encoded features. 209 | """ 210 | if self.num_freqs == 0: 211 | return x 212 | 213 | x_expanded = jnp.expand_dims(x, axis=-2) # (1, C). 214 | # Will be broadcasted to shape (F, C). 215 | angles = self.scale * x_expanded * self.freqs 216 | 217 | # The shape of the features is (F, 2, C) so that when we reshape it 218 | # it matches the ordering of the original NeRF code. 219 | # Vectorize the computation of the high-frequency (sin, cos) terms. 
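    # Editor's note -- a concrete shape trace (C and F illustrative): for an
    # input x of shape (C,), x_expanded is (1, C), angles is (F, C), the
    # stacked features are (F, 2, C), and flattening yields 2 * F * C values
    # ordered as [sin(f1 * x), cos(f1 * x), sin(f2 * x), cos(f2 * x), ...].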
220 | # We use the trigonometric identity: cos(x) = sin(x + pi/2) 221 | features = jnp.stack((angles, angles + jnp.pi / 2), axis=-2) 222 | features = features.flatten() 223 | features = jnp.sin(features) 224 | 225 | # Prepend the original signal for the identity. 226 | if self.use_identity: 227 | features = jnp.concatenate([x, features], axis=-1) 228 | return features 229 | 230 | 231 | class AnnealedSinusoidalEncoder(nn.Module): 232 | """An annealed sinusoidal encoding.""" 233 | num_freqs: int 234 | min_freq_log2: int = 0 235 | max_freq_log2: Optional[int] = None 236 | scale: float = 1.0 237 | use_identity: bool = True 238 | 239 | @nn.compact 240 | def __call__(self, x, alpha): 241 | if alpha is None: 242 | raise ValueError('alpha must be specified.') 243 | if self.num_freqs == 0: 244 | return x 245 | 246 | num_channels = x.shape[-1] 247 | 248 | base_encoder = SinusoidalEncoder( 249 | num_freqs=self.num_freqs, 250 | min_freq_log2=self.min_freq_log2, 251 | max_freq_log2=self.max_freq_log2, 252 | scale=self.scale, 253 | use_identity=self.use_identity) 254 | features = base_encoder(x) 255 | 256 | if self.use_identity: 257 | identity, features = jnp.split(features, (x.shape[-1],), axis=-1) 258 | 259 | # Apply the window by broadcasting to save on memory. 260 | features = jnp.reshape(features, (-1, 2, num_channels)) 261 | window = self.cosine_easing_window( 262 | self.min_freq_log2, self.max_freq_log2, self.num_freqs, alpha) 263 | window = jnp.reshape(window, (-1, 1, 1)) 264 | features = window * features 265 | 266 | if self.use_identity: 267 | return jnp.concatenate([ 268 | identity, 269 | features.flatten(), 270 | ], axis=-1) 271 | else: 272 | return features.flatten() 273 | 274 | @classmethod 275 | def cosine_easing_window(cls, min_freq_log2, max_freq_log2, num_bands, alpha): 276 | """Eases in each frequency one by one with a cosine. 277 | 278 | This is equivalent to taking a Tukey window and sliding it to the right 279 | along the frequency spectrum. 280 | 281 | Args: 282 | min_freq_log2: the lower frequency band. 283 | max_freq_log2: the upper frequency band. 284 | num_bands: the number of frequencies. 285 | alpha: will ease in each frequency as alpha goes from 0.0 to num_freqs. 286 | 287 | Returns: 288 | A 1-d numpy array with num_sample elements containing the window. 
289 | """ 290 | if max_freq_log2 is None: 291 | max_freq_log2 = num_bands - 1.0 292 | bands = jnp.linspace(min_freq_log2, max_freq_log2, num_bands) 293 | x = jnp.clip(alpha - bands, 0.0, 1.0) 294 | return 0.5 * (1 + jnp.cos(jnp.pi * x + jnp.pi)) 295 | 296 | 297 | class TimeEncoder(nn.Module): 298 | """Encodes a timestamp to an embedding.""" 299 | num_freqs: int 300 | 301 | features: int = 10 302 | depth: int = 6 303 | width: int = 64 304 | skips: int = (4,) 305 | hidden_init: types.Initializer = nn.initializers.xavier_uniform() 306 | output_init: types.Activation = nn.initializers.uniform(scale=0.05) 307 | 308 | def setup(self): 309 | self.posenc = AnnealedSinusoidalEncoder(num_freqs=self.num_freqs) 310 | self.mlp = MLP( 311 | depth=self.depth, 312 | width=self.width, 313 | skips=self.skips, 314 | hidden_init=self.hidden_init, 315 | output_channels=self.features, 316 | output_init=self.output_init) 317 | 318 | def __call__(self, time, alpha=None): 319 | if alpha is None: 320 | alpha = self.num_freqs 321 | encoded_time = self.posenc(time, alpha) 322 | return self.mlp(encoded_time) 323 | -------------------------------------------------------------------------------- /nerfies/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Quaternion math. 16 | 17 | This module assumes the xyzw quaternion format where xyz is the imaginary part 18 | and w is the real part. 19 | 20 | Functions in this module support both batched and unbatched quaternions. 21 | """ 22 | from jax import numpy as jnp 23 | from jax.numpy import linalg 24 | 25 | 26 | def safe_acos(t, eps=1e-8): 27 | """A safe version of arccos which avoids evaluating at -1 or 1.""" 28 | return jnp.arccos(jnp.clip(t, -1.0 + eps, 1.0 - eps)) 29 | 30 | 31 | def im(q): 32 | """Fetch the imaginary part of the quaternion.""" 33 | return q[..., :3] 34 | 35 | 36 | def re(q): 37 | """Fetch the real part of the quaternion.""" 38 | return q[..., 3:] 39 | 40 | 41 | def identity(): 42 | return jnp.array([0.0, 0.0, 0.0, 1.0]) 43 | 44 | 45 | def conjugate(q): 46 | """Compute the conjugate of a quaternion.""" 47 | return jnp.concatenate([-im(q), re(q)], axis=-1) 48 | 49 | 50 | def inverse(q): 51 | """Compute the inverse of a quaternion.""" 52 | return normalize(conjugate(q)) 53 | 54 | 55 | def normalize(q): 56 | """Normalize a quaternion.""" 57 | return q / norm(q) 58 | 59 | 60 | def norm(q): 61 | return linalg.norm(q, axis=-1, keepdims=True) 62 | 63 | 64 | def multiply(q1, q2): 65 | """Multiply two quaternions.""" 66 | c = (re(q1) * im(q2) 67 | + re(q2) * im(q1) 68 | + jnp.cross(im(q1), im(q2))) 69 | w = re(q1) * re(q2) - jnp.dot(im(q1), im(q2)) 70 | return jnp.concatenate([c, w], axis=-1) 71 | 72 | 73 | def rotate(q, v): 74 | """Rotate a vector using a quaternion.""" 75 | # Create the quaternion representation of the vector. 
76 |   q_v = jnp.concatenate([v, jnp.zeros_like(v[..., :1])], axis=-1)
77 |   return im(multiply(multiply(q, q_v), conjugate(q)))
78 | 
79 | 
80 | def log(q, eps=1e-8):
81 |   """Computes the quaternion logarithm.
82 | 
83 |   References:
84 |     https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
85 | 
86 |   Args:
87 |     q: the quaternion in (x,y,z,w) format.
88 |     eps: an epsilon value for numerical stability.
89 | 
90 |   Returns:
91 |     The logarithm of q.
92 |   """
93 |   mag = linalg.norm(q, axis=-1, keepdims=True)
94 |   v = im(q)
95 |   s = re(q)
96 |   w = jnp.log(mag)
97 |   denom = jnp.maximum(
98 |       linalg.norm(v, axis=-1, keepdims=True), eps * jnp.ones_like(v))
99 |   xyz = v / denom * safe_acos(s / mag)
100 |   return jnp.concatenate((xyz, w), axis=-1)
101 | 
102 | 
103 | def exp(q, eps=1e-8):
104 |   """Computes the quaternion exponential.
105 | 
106 |   References:
107 |     https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
108 | 
109 |   Args:
110 |     q: the quaternion in (x,y,z,w) format, or (x,y,z) for a pure quaternion.
111 |     eps: an epsilon value for numerical stability.
112 | 
113 |   Returns:
114 |     The exponential of q.
115 |   """
116 |   is_pure = q.shape[-1] == 3
117 |   if is_pure:
118 |     s = jnp.zeros_like(q[..., -1:])
119 |     v = q
120 |   else:
121 |     v = im(q)
122 |     s = re(q)
123 | 
124 |   norm_v = linalg.norm(v, axis=-1, keepdims=True)
125 |   exp_s = jnp.exp(s)
126 |   w = jnp.cos(norm_v)
127 |   xyz = jnp.sin(norm_v) * v / jnp.maximum(norm_v, eps * jnp.ones_like(norm_v))
128 |   return exp_s * jnp.concatenate((xyz, w), axis=-1)
129 | 
130 | 
131 | def to_rotation_matrix(q):
132 |   """Constructs a rotation matrix from a quaternion.
133 | 
134 |   Args:
135 |     q: a (*,4) array containing quaternions.
136 | 
137 |   Returns:
138 |     A (*,3,3) array containing rotation matrices.
139 |   """
140 |   x, y, z, w = jnp.split(q, 4, axis=-1)
141 |   s = 1.0 / jnp.sum(q ** 2, axis=-1, keepdims=True)
142 |   return jnp.stack([
143 |       jnp.concatenate([1 - 2 * s * (y ** 2 + z ** 2),
144 |                        2 * s * (x * y - z * w),
145 |                        2 * s * (x * z + y * w)], axis=-1),
146 |       jnp.concatenate([2 * s * (x * y + z * w),
147 |                        1 - 2 * s * (x ** 2 + z ** 2),
148 |                        2 * s * (y * z - x * w)], axis=-1),
149 |       jnp.concatenate([2 * s * (x * z - y * w),
150 |                        2 * s * (y * z + x * w),
151 |                        1 - 2 * s * (x ** 2 + y ** 2)], axis=-1),
152 |   ], axis=-2)
153 | 
154 | 
155 | def from_rotation_matrix(m, eps=1e-9):
156 |   """Construct quaternion from a rotation matrix.
157 | 
158 |   Args:
159 |     m: a (*,3,3) array containing rotation matrices.
160 |     eps: a small number for numerical stability.
161 | 
162 |   Returns:
163 |     A (*,4) array containing quaternions.
164 |   """
165 |   trace = jnp.trace(m, axis1=-2, axis2=-1)
166 |   m00 = m[..., 0, 0]
167 |   m01 = m[..., 0, 1]
168 |   m02 = m[..., 0, 2]
169 |   m10 = m[..., 1, 0]
170 |   m11 = m[..., 1, 1]
171 |   m12 = m[..., 1, 2]
172 |   m20 = m[..., 2, 0]
173 |   m21 = m[..., 2, 1]
174 |   m22 = m[..., 2, 2]
175 | 
176 |   def tr_positive():
177 |     sq = jnp.sqrt(trace + 1.0) * 2.  # sq = 4 * w.
178 |     w = 0.25 * sq
179 |     x = jnp.divide(m21 - m12, sq)
180 |     y = jnp.divide(m02 - m20, sq)
181 |     z = jnp.divide(m10 - m01, sq)
182 |     return jnp.stack((x, y, z, w), axis=-1)
183 | 
184 |   def cond_1():
185 |     sq = jnp.sqrt(1.0 + m00 - m11 - m22 + eps) * 2.  # sq = 4 * x.
186 |     w = jnp.divide(m21 - m12, sq)
187 |     x = 0.25 * sq
188 |     y = jnp.divide(m01 + m10, sq)
189 |     z = jnp.divide(m02 + m20, sq)
190 |     return jnp.stack((x, y, z, w), axis=-1)
191 | 
192 |   def cond_2():
193 |     sq = jnp.sqrt(1.0 + m11 - m00 - m22 + eps) * 2.  # sq = 4 * y.
194 | w = jnp.divide(m02 - m20, sq) 195 | x = jnp.divide(m01 + m10, sq) 196 | y = 0.25 * sq 197 | z = jnp.divide(m12 + m21, sq) 198 | return jnp.stack((x, y, z, w), axis=-1) 199 | 200 | def cond_3(): 201 | sq = jnp.sqrt(1.0 + m22 - m00 - m11 + eps) * 2. # sq = 4 * z. 202 | w = jnp.divide(m10 - m01, sq) 203 | x = jnp.divide(m02 + m20, sq) 204 | y = jnp.divide(m12 + m21, sq) 205 | z = 0.25 * sq 206 | return jnp.stack((x, y, z, w), axis=-1) 207 | 208 | def cond_idx(cond): 209 | cond = jnp.expand_dims(cond, -1) 210 | cond = jnp.tile(cond, [1] * (len(m.shape) - 2) + [4]) 211 | return cond 212 | 213 | where_2 = jnp.where(cond_idx(m11 > m22), cond_2(), cond_3()) 214 | where_1 = jnp.where(cond_idx((m00 > m11) & (m00 > m22)), cond_1(), where_2) 215 | return jnp.where(cond_idx(trace > 0), tr_positive(), where_1) 216 | -------------------------------------------------------------------------------- /nerfies/rigid_body.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=invalid-name 16 | # pytype: disable=attribute-error 17 | import jax 18 | from jax import numpy as jnp 19 | 20 | 21 | @jax.jit 22 | def skew(w: jnp.ndarray) -> jnp.ndarray: 23 | """Build a skew matrix ("cross product matrix") for vector w. 24 | 25 | Modern Robotics Eqn 3.30. 26 | 27 | Args: 28 | w: (3,) A 3-vector 29 | 30 | Returns: 31 | W: (3, 3) A skew matrix such that W @ v == w x v 32 | """ 33 | w = jnp.reshape(w, (3)) 34 | return jnp.array([[0.0, -w[2], w[1]], \ 35 | [w[2], 0.0, -w[0]], \ 36 | [-w[1], w[0], 0.0]]) 37 | 38 | 39 | def rp_to_se3(R: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray: 40 | """Rotation and translation to homogeneous transform. 41 | 42 | Args: 43 | R: (3, 3) An orthonormal rotation matrix. 44 | p: (3,) A 3-vector representing an offset. 45 | 46 | Returns: 47 | X: (4, 4) The homogeneous transformation matrix described by rotating by R 48 | and translating by p. 49 | """ 50 | p = jnp.reshape(p, (3, 1)) 51 | return jnp.block([[R, p], [jnp.array([[0.0, 0.0, 0.0, 1.0]])]]) 52 | 53 | 54 | def exp_so3(w: jnp.ndarray, theta: float) -> jnp.ndarray: 55 | """Exponential map from Lie algebra so3 to Lie group SO3. 56 | 57 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula. 58 | 59 | Args: 60 | w: (3,) An axis of rotation. 61 | theta: An angle of rotation. 62 | 63 | Returns: 64 | R: (3, 3) An orthonormal rotation matrix representing a rotation of 65 | magnitude theta about axis w. 66 | """ 67 | W = skew(w) 68 | return jnp.eye(3) + jnp.sin(theta) * W + (1.0 - jnp.cos(theta)) * W @ W 69 | 70 | 71 | def exp_se3(S: jnp.ndarray, theta: float) -> jnp.ndarray: 72 | """Exponential map from Lie algebra so3 to Lie group SO3. 73 | 74 | Modern Robotics Eqn 3.88. 75 | 76 | Args: 77 | S: (6,) A screw axis of motion. 78 | theta: Magnitude of motion. 
79 | 
80 |   Returns:
81 |     a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating
82 |       motion of magnitude theta about S for one second.
83 |   """
84 |   w, v = jnp.split(S, 2)
85 |   W = skew(w)
86 |   R = exp_so3(w, theta)
87 |   p = (theta * jnp.eye(3) + (1.0 - jnp.cos(theta)) * W +
88 |        (theta - jnp.sin(theta)) * W @ W) @ v
89 |   return rp_to_se3(R, p)
90 | 
91 | 
92 | def to_homogenous(v):
93 |   return jnp.concatenate([v, jnp.ones_like(v[..., :1])], axis=-1)
94 | 
95 | 
96 | def from_homogenous(v):
97 |   return v[..., :3] / v[..., -1:]
98 | 
--------------------------------------------------------------------------------
/nerfies/schedules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Annealing Schedules."""
16 | import abc
17 | import collections.abc
18 | import copy
19 | import math
20 | from typing import Any, Iterable, List, Tuple, Union
21 | 
22 | from jax import numpy as jnp
23 | 
24 | 
25 | def from_tuple(x):
26 |   schedule_type, *args = x
27 |   return SCHEDULE_MAP[schedule_type](*args)
28 | 
29 | 
30 | def from_dict(d):
31 |   d = copy.copy(dict(d))
32 |   schedule_type = d.pop('type')
33 |   return SCHEDULE_MAP[schedule_type](**d)
34 | 
35 | 
36 | def from_config(schedule):
37 |   if isinstance(schedule, Schedule):
38 |     return schedule
39 |   if isinstance(schedule, (tuple, list)):
40 |     return from_tuple(schedule)
41 |   if isinstance(schedule, collections.abc.Mapping):
42 |     return from_dict(schedule)
43 | 
44 |   raise ValueError(f'Unknown type {type(schedule)}.')
45 | 
46 | 
47 | class Schedule(abc.ABC):
48 |   """An interface for generic schedules."""
49 | 
50 |   @abc.abstractmethod
51 |   def get(self, step):
52 |     """Get the value for the given step."""
53 |     raise NotImplementedError
54 | 
55 |   def __call__(self, step):
56 |     return self.get(step)
57 | 
58 | 
59 | class ConstantSchedule(Schedule):
60 |   """Constant scheduler."""
61 | 
62 |   def __init__(self, value):
63 |     super().__init__()
64 |     self.value = value
65 | 
66 |   def get(self, step):
67 |     """Get the value for the given step."""
68 |     return jnp.full_like(step, self.value, dtype=jnp.float32)
69 | 
70 | 
71 | class LinearSchedule(Schedule):
72 |   """Linearly scaled scheduler."""
73 | 
74 |   def __init__(self, initial_value, final_value, num_steps):
75 |     super().__init__()
76 |     self.initial_value = initial_value
77 |     self.final_value = final_value
78 |     self.num_steps = num_steps
79 | 
80 |   def get(self, step):
81 |     """Get the value for the given step."""
82 |     if self.num_steps == 0:
83 |       return jnp.full_like(step, self.final_value, dtype=jnp.float32)
84 |     alpha = jnp.minimum(step / self.num_steps, 1.0)
85 |     return (1.0 - alpha) * self.initial_value + alpha * self.final_value
86 | 
87 | 
88 | class ExponentialSchedule(Schedule):
89 |   """Exponentially decaying scheduler."""
90 | 
91 |   def __init__(self, initial_value, final_value, num_steps, eps=1e-10):
92 |     super().__init__()
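    # Editor's note -- a worked example of this schedule (numbers
    # illustrative only): with initial_value=1e-3, final_value=1e-5 and
    # num_steps=100, get(0) == 1e-3, get(99) == 1e-5, and the value halves
    # roughly every 15 steps since (1e-5 / 1e-3) ** (15 / 99) ~= 0.5.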
93 |     if initial_value <= final_value:
94 |       raise ValueError('Final value must be less than initial value.')
95 | 
96 |     self.initial_value = initial_value
97 |     self.final_value = final_value
98 |     self.num_steps = num_steps
99 |     self.eps = eps
100 | 
101 |   def get(self, step):
102 |     """Get the value for the given step."""
103 |     if step >= self.num_steps:
104 |       return jnp.full_like(step, self.final_value, dtype=jnp.float32)
105 | 
106 |     final_value = max(self.final_value, self.eps)
107 |     base = final_value / self.initial_value
108 |     # Interpolate log-linearly so that get(num_steps - 1) == final_value.
109 |     exponent = step / (self.num_steps - 1)
110 | 
111 |     return self.initial_value * base**exponent
112 | 
113 | 
114 | class CosineEasingSchedule(Schedule):
115 |   """Schedule that eases slowly using a cosine."""
116 | 
117 |   def __init__(self, initial_value, final_value, num_steps):
118 |     super().__init__()
119 |     self.initial_value = initial_value
120 |     self.final_value = final_value
121 |     self.num_steps = num_steps
122 | 
123 |   def get(self, step):
124 |     """Get the value for the given step."""
125 |     alpha = jnp.minimum(step / self.num_steps, 1.0)
126 |     scale = self.final_value - self.initial_value
127 |     x = jnp.clip(alpha, 0.0, 1.0)
128 |     return (self.initial_value
129 |             + scale * 0.5 * (1 + jnp.cos(jnp.pi * x + jnp.pi)))
130 | 
131 | 
132 | class StepSchedule(Schedule):
133 |   """Schedule that decays by a fixed factor at regular intervals."""
134 | 
135 |   def __init__(self,
136 |                initial_value,
137 |                decay_interval,
138 |                decay_factor,
139 |                max_decays,
140 |                final_value=None):
141 |     super().__init__()
142 |     self.initial_value = initial_value
143 |     self.decay_factor = decay_factor
144 |     self.decay_interval = decay_interval
145 |     self.max_decays = max_decays
146 |     if final_value is None:
147 |       final_value = self.initial_value * self.decay_factor**self.max_decays
148 |     self.final_value = final_value
149 | 
150 |   def get(self, step):
151 |     """Get the value for the given step."""
152 |     phase = step // self.decay_interval
153 |     if phase >= self.max_decays:
154 |       return self.final_value
155 |     else:
156 |       return self.initial_value * self.decay_factor**phase
157 | 
158 | 
159 | class PiecewiseSchedule(Schedule):
160 |   """A piecewise combination of multiple schedules."""
161 | 
162 |   def __init__(
163 |       self, schedules: Iterable[Tuple[int, Union[Schedule, Iterable[Any]]]]):
164 |     self.schedules = [from_config(s) for ms, s in schedules]
165 |     milestones = jnp.array([ms for ms, s in schedules])
166 |     self.milestones = jnp.cumsum(milestones)[:-1]
167 | 
168 |   def get(self, step):
169 |     idx = jnp.searchsorted(self.milestones, step, side='right')
170 |     schedule = self.schedules[idx]
171 |     base_idx = self.milestones[idx - 1] if idx >= 1 else 0
172 |     return schedule.get(step - base_idx)
173 | 
174 | 
175 | class DelayedSchedule(Schedule):
176 |   """Delays the start of the base schedule."""
177 | 
178 |   def __init__(self, base_schedule: Schedule, delay_steps, delay_mult):
179 |     self.base_schedule = from_config(base_schedule)
180 |     self.delay_steps = delay_steps
181 |     self.delay_mult = delay_mult
182 | 
183 |   def get(self, step):
184 |     delay_rate = (
185 |         self.delay_mult
186 |         + (1 - self.delay_mult)
187 |         * jnp.sin(0.5 * jnp.pi * jnp.clip(step / self.delay_steps, 0, 1)))
188 | 
189 |     return delay_rate * self.base_schedule(step)
190 | 
191 | 
192 | SCHEDULE_MAP = {
193 |     'constant': ConstantSchedule,
194 |     'linear': LinearSchedule,
195 |     'exponential': ExponentialSchedule,
196 |     'cosine_easing': CosineEasingSchedule,
197 | 
'step': StepSchedule, 198 | 'piecewise': PiecewiseSchedule, 199 | 'delayed': DelayedSchedule, 200 | } 201 | -------------------------------------------------------------------------------- /nerfies/tf_camera.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A basic camera implementation in Tensorflow.""" 16 | from typing import Tuple, Optional 17 | 18 | import tensorflow as tf 19 | from tensorflow.experimental import numpy as tnp 20 | 21 | 22 | def _norm(x): 23 | return tnp.sqrt(tnp.sum(x ** 2, axis=-1, keepdims=True)) 24 | 25 | 26 | def _compute_residual_and_jacobian( 27 | x: tnp.ndarray, 28 | y: tnp.ndarray, 29 | xd: tnp.ndarray, 30 | yd: tnp.ndarray, 31 | k1: float = 0.0, 32 | k2: float = 0.0, 33 | k3: float = 0.0, 34 | p1: float = 0.0, 35 | p2: float = 0.0, 36 | ) -> Tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray, tnp.ndarray, tnp.ndarray, 37 | tnp.ndarray]: 38 | """Auxiliary function of radial_and_tangential_undistort().""" 39 | # let r(x, y) = x^2 + y^2; 40 | # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3; 41 | r = x * x + y * y 42 | d = 1.0 + r * (k1 + r * (k2 + k3 * r)) 43 | 44 | # The perfect projection is: 45 | # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2); 46 | # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2); 47 | # 48 | # Let's define 49 | # 50 | # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd; 51 | # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd; 52 | # 53 | # We are looking for a solution that satisfies 54 | # fx(x, y) = fy(x, y) = 0; 55 | fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd 56 | fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd 57 | 58 | # Compute derivative of d over [x, y] 59 | d_r = (k1 + r * (2.0 * k2 + 3.0 * k3 * r)) 60 | d_x = 2.0 * x * d_r 61 | d_y = 2.0 * y * d_r 62 | 63 | # Compute derivative of fx over x and y. 64 | fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x 65 | fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y 66 | 67 | # Compute derivative of fy over x and y. 68 | fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x 69 | fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y 70 | 71 | return fx, fy, fx_x, fx_y, fy_x, fy_y 72 | 73 | 74 | def _radial_and_tangential_undistort( 75 | xd: tnp.ndarray, 76 | yd: tnp.ndarray, 77 | k1: float = 0, 78 | k2: float = 0, 79 | k3: float = 0, 80 | p1: float = 0, 81 | p2: float = 0, 82 | eps: float = 1e-9, 83 | max_iterations=10) -> Tuple[tnp.ndarray, tnp.ndarray]: 84 | """Computes undistorted (x, y) from (xd, yd).""" 85 | # Initialize from the distorted point. 
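  # Editor's note -- the loop below is a Newton iteration on the residual
  # (fx, fy) returned by _compute_residual_and_jacobian(): each step solves
  # the 2x2 linear system J @ [step_x, step_y]^T = -[fx, fy]^T by Cramer's
  # rule, and the step is zeroed wherever the Jacobian determinant is within
  # eps of zero to avoid division blow-ups.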
86 |   x = xd
87 |   y = yd
88 | 
89 |   for _ in range(max_iterations):
90 |     fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
91 |         x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, p1=p1, p2=p2)
92 |     denominator = fy_x * fx_y - fx_x * fy_y
93 |     x_numerator = fx * fy_y - fy * fx_y
94 |     y_numerator = fy * fx_x - fx * fy_x
95 |     step_x = tnp.where(
96 |         tnp.abs(denominator) > eps, x_numerator / denominator,
97 |         tnp.zeros_like(denominator))
98 |     step_y = tnp.where(
99 |         tnp.abs(denominator) > eps, y_numerator / denominator,
100 |         tnp.zeros_like(denominator))
101 | 
102 |     x = x + step_x
103 |     y = y + step_y
104 | 
105 |   return x, y
106 | 
107 | 
108 | class TFCamera:
109 |   """A duplicate of our JAX-based camera class.
110 | 
111 |   This is necessary to use tf.data.Dataset.
112 |   """
113 | 
114 |   def __init__(self,
115 |                orientation: tnp.ndarray,
116 |                position: tnp.ndarray,
117 |                focal_length: float,
118 |                principal_point: tnp.ndarray,
119 |                image_size: tnp.ndarray,
120 |                skew: float = 0.0,
121 |                pixel_aspect_ratio: float = 1.0,
122 |                radial_distortion: Optional[tnp.ndarray] = None,
123 |                tangential_distortion: Optional[tnp.ndarray] = None,
124 |                dtype=tnp.float32):
125 |     """Constructor for camera class."""
126 |     if radial_distortion is None:
127 |       radial_distortion = tnp.array([0.0, 0.0, 0.0], dtype)
128 |     if tangential_distortion is None:
129 |       tangential_distortion = tnp.array([0.0, 0.0], dtype)
130 | 
131 |     self.orientation = tnp.array(orientation, dtype)
132 |     self.position = tnp.array(position, dtype)
133 |     self.focal_length = tnp.array(focal_length, dtype)
134 |     self.principal_point = tnp.array(principal_point, dtype)
135 |     self.skew = tnp.array(skew, dtype)
136 |     self.pixel_aspect_ratio = tnp.array(pixel_aspect_ratio, dtype)
137 |     self.radial_distortion = tnp.array(radial_distortion, dtype)
138 |     self.tangential_distortion = tnp.array(tangential_distortion, dtype)
139 |     self.image_size = tnp.array(image_size, dtype)
140 |     self.dtype = dtype
141 | 
142 |   @property
143 |   def scale_factor_x(self):
144 |     return self.focal_length
145 | 
146 |   @property
147 |   def scale_factor_y(self):
148 |     return self.focal_length * self.pixel_aspect_ratio
149 | 
150 |   @property
151 |   def principal_point_x(self):
152 |     return self.principal_point[0]
153 | 
154 |   @property
155 |   def principal_point_y(self):
156 |     return self.principal_point[1]
157 | 
158 |   @property
159 |   def image_size_y(self):
160 |     return self.image_size[1]
161 | 
162 |   @property
163 |   def image_size_x(self):
164 |     return self.image_size[0]
165 | 
166 |   @property
167 |   def image_shape(self):
168 |     return self.image_size_y, self.image_size_x
169 | 
170 |   @property
171 |   def optical_axis(self):
172 |     return self.orientation[2, :]
173 | 
174 |   def pixel_to_local_rays(self, pixels: tnp.ndarray):
175 |     """Returns the local ray directions for the provided pixels."""
176 |     y = ((pixels[..., 1] - self.principal_point_y) / self.scale_factor_y)
177 |     x = ((pixels[..., 0] - self.principal_point_x - y * self.skew) /
178 |          self.scale_factor_x)
179 | 
180 |     x, y = _radial_and_tangential_undistort(
181 |         x,
182 |         y,
183 |         k1=self.radial_distortion[0],
184 |         k2=self.radial_distortion[1],
185 |         k3=self.radial_distortion[2],
186 |         p1=self.tangential_distortion[0],
187 |         p2=self.tangential_distortion[1])
188 | 
189 |     dirs = tnp.stack([x, y, tnp.ones_like(x)], axis=-1)
190 |     return dirs / _norm(dirs)
191 | 
192 |   def pixels_to_rays(self,
193 |                      pixels: tnp.ndarray) -> tnp.ndarray:
194 |     """Returns the rays for the provided pixels.
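
    Pixels are undistorted and lifted to local camera-space directions, then
    rotated into world space (`orientation` holds the world-to-camera
    rotation, so its transpose is applied).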
195 | 196 | Args: 197 | pixels: [A1, ..., An, 2] tensor or np.array containing 2d pixel positions. 198 | 199 | Returns: 200 | An array containing the normalized ray directions in world coordinates. 201 | """ 202 | if pixels.shape[-1] != 2: 203 | raise ValueError('The last dimension of pixels must be 2.') 204 | if pixels.dtype != self.dtype: 205 | raise ValueError(f'pixels dtype ({pixels.dtype!r}) must match camera ' 206 | f'dtype ({self.dtype!r})') 207 | 208 | local_rays_dir = self.pixel_to_local_rays(pixels) 209 | rays_dir = tf.linalg.matvec( 210 | self.orientation, local_rays_dir, transpose_a=True) 211 | 212 | # Normalize rays. 213 | rays_dir = rays_dir / _norm(rays_dir) 214 | return rays_dir 215 | 216 | def pixels_to_points(self, pixels: tnp.ndarray, depth: tnp.ndarray): 217 | rays_through_pixels = self.pixels_to_rays(pixels) 218 | cosa = rays_through_pixels @ self.optical_axis 219 | points = ( 220 | rays_through_pixels * depth[..., tnp.newaxis] / cosa[..., tnp.newaxis] + 221 | self.position) 222 | return points 223 | 224 | def points_to_local_points(self, points: tnp.ndarray): 225 | translated_points = points - self.position 226 | local_points = (self.orientation @ translated_points.T).T 227 | return local_points 228 | 229 | def get_pixel_centers(self): 230 | """Returns the pixel centers.""" 231 | xx, yy = tf.meshgrid(tf.range(self.image_size_x), 232 | tf.range(self.image_size_y)) 233 | return tf.cast(tf.stack([xx, yy], axis=-1), self.dtype) + 0.5 234 | -------------------------------------------------------------------------------- /nerfies/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Library for training NeRFs."""
16 | import functools
17 | from typing import Any
18 | from typing import Dict
19 | 
20 | 
21 | from absl import logging
22 | from flax import struct
23 | from flax.training import checkpoints
24 | import jax
25 | from jax import lax
26 | from jax import numpy as jnp
27 | from jax import random
28 | from jax import vmap
29 | 
30 | from nerfies import model_utils
31 | from nerfies import models
32 | from nerfies import utils
33 | 
34 | 
35 | @struct.dataclass
36 | class ScalarParams:
37 |   learning_rate: float
38 |   elastic_loss_weight: float = 0.0
39 |   warp_reg_loss_weight: float = 0.0
40 |   warp_reg_loss_alpha: float = -2.0
41 |   warp_reg_loss_scale: float = 0.001
42 |   background_loss_weight: float = 0.0
43 |   background_noise_std: float = 0.001
44 | 
45 | 
46 | def save_checkpoint(path, state, keep=2):
47 |   """Save the state to a checkpoint."""
48 |   state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
49 |   step = state_to_save.optimizer.state.step
50 |   checkpoint_path = checkpoints.save_checkpoint(
51 |       path, state_to_save, step, keep=keep)
52 |   logging.info('Saved checkpoint: step=%d, path=%s', int(step), checkpoint_path)
53 |   return checkpoint_path
54 | 
55 | 
56 | @jax.jit
57 | def nearest_rotation_svd(matrix, eps=1e-6):
58 |   """Computes the rotation nearest to `matrix` (Frobenius norm) via SVD."""
59 |   # TODO(keunhong): Currently this produces NaNs for some reason.
60 |   u, _, vh = jnp.linalg.svd(matrix + eps, compute_uv=True, full_matrices=False)
61 |   # Handle the case when there is a flip.
62 |   # M will be the identity matrix except when det(UV^T) = -1,
63 |   # in which case the last diagonal entry of M will be -1.
64 |   det = jnp.linalg.det(u @ vh)
65 |   m = jnp.stack([jnp.ones_like(det), jnp.ones_like(det), det], axis=-1)
66 |   m = jnp.diag(m)
67 |   r = u @ m @ vh
68 |   return r
69 | 
70 | 
71 | def compute_elastic_loss(jacobian, eps=1e-6, loss_type='log_svals'):
72 |   """Compute the elastic regularization loss.
73 | 
74 |   For the default loss type, the loss is given by sum(log(S)^2), where S
75 |   holds the singular values of the Jacobian. This penalizes singular values
76 |   that deviate from 1 since log(1) = 0; e.g. (2, 1, 0.5) gives 2 * log(2)^2.
77 | 
78 |   Args:
79 |     jacobian: the Jacobian of the point transformation.
80 |     eps: a small value to prevent taking the log of zero.
81 |     loss_type: which elastic loss type to use.
82 | 
83 |   Returns:
84 |     The elastic regularization loss.
85 |   """
86 |   if loss_type == 'log_svals':
87 |     svals = jnp.linalg.svd(jacobian, compute_uv=False)
88 |     log_svals = jnp.log(jnp.maximum(svals, eps))
89 |     sq_residual = jnp.sum(log_svals**2, axis=-1)
90 |   elif loss_type == 'svals':
91 |     svals = jnp.linalg.svd(jacobian, compute_uv=False)
92 |     sq_residual = jnp.sum((svals - 1.0)**2, axis=-1)
93 |   elif loss_type == 'jtj':
94 |     jtj = jacobian @ jacobian.T
95 |     sq_residual = ((jtj - jnp.eye(3)) ** 2).sum() / 4.0
96 |   elif loss_type == 'div':
97 |     div = utils.jacobian_to_div(jacobian)
98 |     sq_residual = div ** 2
99 |   elif loss_type == 'det':
100 |     det = jnp.linalg.det(jacobian)
101 |     sq_residual = (det - 1.0) ** 2
102 |   elif loss_type == 'log_det':
103 |     det = jnp.linalg.det(jacobian)
104 |     sq_residual = jnp.log(jnp.maximum(det, eps)) ** 2
105 |   elif loss_type == 'nr':
106 |     rot = nearest_rotation_svd(jacobian)
107 |     sq_residual = jnp.sum((jacobian - rot) ** 2)
108 |   else:
109 |     raise NotImplementedError(
110 |         f'Unknown elastic loss type {loss_type!r}')
111 |   residual = jnp.sqrt(sq_residual)
112 |   loss = utils.general_loss_with_squared_residual(
113 |       sq_residual, alpha=-2.0, scale=0.03)
114 |   return loss, residual
115 | 
116 | 
117 | @functools.partial(jax.jit, static_argnums=0)
118 | def compute_background_loss(
119 |     model, state, params, key, points, noise_std, alpha=-2, scale=0.001):
120 |   """Compute the background regularization loss."""
121 |   metadata = random.choice(key,
122 |                            jnp.array(model.warp_ids, jnp.uint32),
123 |                            shape=(points.shape[0], 1))
124 |   point_noise = noise_std * random.normal(key, points.shape)
125 |   points = points + point_noise
126 | 
127 |   warp_field = model.create_warp_field(model, num_batch_dims=1)
128 |   warp_out = warp_field.apply(
129 |       {'params': params['warp_field']},
130 |       points, metadata, state.warp_extra, False, False)
131 |   warped_points = warp_out['warped_points'][..., :3]
132 |   sq_residual = jnp.sum((warped_points - points)**2, axis=-1)
133 |   loss = utils.general_loss_with_squared_residual(
134 |       sq_residual, alpha=alpha, scale=scale)
135 |   return loss
136 | 
137 | 
138 | def train_step(model: models.NerfModel,
139 |                rng_key: jnp.ndarray,
140 |                state,
141 |                batch: Dict[str, Any],
142 |                scalar_params: ScalarParams,
143 |                use_elastic_loss: bool = False,
144 |                elastic_reduce_method: str = 'median',
145 |                elastic_loss_type: str = 'log_svals',
146 |                use_background_loss: bool = False,
147 |                use_warp_reg_loss: bool = False):
148 |   """One optimization step.
149 | 
150 |   Args:
151 |     model: the model module to evaluate.
152 |     rng_key: the random number generator key.
153 |     state: model_utils.TrainState, state of model and optimizer.
154 |     batch: dict. A mini-batch of data for training.
155 |     scalar_params: scalar-valued parameters.
156 |     use_elastic_loss: if True, use the elastic regularization loss.
157 |     elastic_reduce_method: which method to use to reduce the samples for the
158 |       elastic loss. 'median' selects the median depth point sample while
159 |       'weight' computes a weighted sum using the density weights.
160 |     elastic_loss_type: which method to use for the elastic loss.
161 |     use_background_loss: if True, use the background regularization loss.
162 |     use_warp_reg_loss: if True, use the warp regularization loss.
163 | 
164 |   Returns:
165 |     A tuple (new_state, stats, rng_key): the new model_utils.TrainState,
166 |     a dict of loss and metric values, and the advanced RNG key.
167 | """ 168 | rng_key, fine_key, coarse_key, reg_key = random.split(rng_key, 4) 169 | 170 | # pylint: disable=unused-argument 171 | def _compute_loss_and_stats(params, model_out, use_elastic_loss=False): 172 | rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean() 173 | stats = { 174 | 'loss/rgb': rgb_loss, 175 | } 176 | loss = rgb_loss 177 | if use_elastic_loss: 178 | elastic_fn = functools.partial(compute_elastic_loss, 179 | loss_type=elastic_loss_type) 180 | v_elastic_fn = jax.jit(vmap(vmap(jax.jit(elastic_fn)))) 181 | weights = lax.stop_gradient(model_out['weights']) 182 | jacobian = model_out['warp_jacobian'] 183 | # Pick the median point Jacobian. 184 | if elastic_reduce_method == 'median': 185 | depth_indices = model_utils.compute_depth_index(weights) 186 | jacobian = jnp.take_along_axis( 187 | # Unsqueeze axes: sample axis, Jacobian row, Jacobian col. 188 | jacobian, depth_indices[..., None, None, None], axis=-3) 189 | # Compute loss using Jacobian. 190 | elastic_loss, elastic_residual = v_elastic_fn(jacobian) 191 | # Multiply weight if weighting by density. 192 | if elastic_reduce_method == 'weight': 193 | elastic_loss = weights * elastic_loss 194 | elastic_loss = elastic_loss.sum(axis=-1).mean() 195 | stats['loss/elastic'] = elastic_loss 196 | stats['residual/elastic'] = jnp.mean(elastic_residual) 197 | loss += scalar_params.elastic_loss_weight * elastic_loss 198 | 199 | if use_warp_reg_loss: 200 | weights = lax.stop_gradient(model_out['weights']) 201 | depth_indices = model_utils.compute_depth_index(weights) 202 | warp_mag = ( 203 | (model_out['points'] - model_out['warped_points']) ** 2).sum(axis=-1) 204 | warp_reg_residual = jnp.take_along_axis( 205 | warp_mag, depth_indices[..., None], axis=-1) 206 | warp_reg_loss = utils.general_loss_with_squared_residual( 207 | warp_reg_residual, 208 | alpha=scalar_params.warp_reg_loss_alpha, 209 | scale=scalar_params.warp_reg_loss_scale).mean() 210 | stats['loss/warp_reg'] = warp_reg_loss 211 | stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual)) 212 | loss += scalar_params.warp_reg_loss_weight * warp_reg_loss 213 | 214 | if 'warp_jacobian' in model_out: 215 | jacobian = model_out['warp_jacobian'] 216 | jacobian_det = jnp.linalg.det(jacobian) 217 | jacobian_div = utils.jacobian_to_div(jacobian) 218 | jacobian_curl = utils.jacobian_to_curl(jacobian) 219 | stats['metric/jacobian_det'] = jnp.mean(jacobian_det) 220 | stats['metric/jacobian_div'] = jnp.mean(jacobian_div) 221 | stats['metric/jacobian_curl'] = jnp.mean( 222 | jnp.linalg.norm(jacobian_curl, axis=-1)) 223 | 224 | stats['loss/total'] = loss 225 | stats['metric/psnr'] = utils.compute_psnr(rgb_loss) 226 | return loss, stats 227 | 228 | def _loss_fn(params): 229 | ret = model.apply({'params': params['model']}, 230 | batch, 231 | warp_extra=state.warp_extra, 232 | return_points=use_warp_reg_loss, 233 | return_weights=(use_warp_reg_loss or use_elastic_loss), 234 | rngs={ 235 | 'fine': fine_key, 236 | 'coarse': coarse_key 237 | }) 238 | 239 | losses = {} 240 | stats = {} 241 | if 'fine' in ret: 242 | losses['fine'], stats['fine'] = _compute_loss_and_stats( 243 | params, ret['fine']) 244 | if 'coarse' in ret: 245 | losses['coarse'], stats['coarse'] = _compute_loss_and_stats( 246 | params, ret['coarse'], use_elastic_loss=use_elastic_loss) 247 | 248 | if use_background_loss: 249 | background_loss = compute_background_loss( 250 | model, 251 | state=state, 252 | params=params['model'], 253 | key=reg_key, 254 | points=batch['background_points'], 255 | 
noise_std=scalar_params.background_noise_std)
256 |     background_loss = background_loss.mean()
257 |     losses['background'] = (
258 |         scalar_params.background_loss_weight * background_loss)
259 |     stats['background_loss'] = background_loss
260 | 
261 |     return sum(losses.values()), stats
262 | 
263 |   optimizer = state.optimizer
264 |   grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
265 |   (_, stats), grad = grad_fn(optimizer.target)
266 |   grad = jax.lax.pmean(grad, axis_name='batch')
267 |   stats = jax.lax.pmean(stats, axis_name='batch')
268 |   new_optimizer = optimizer.apply_gradient(
269 |       grad, learning_rate=scalar_params.learning_rate)
270 |   new_state = state.replace(optimizer=new_optimizer)
271 |   return new_state, stats, rng_key
272 | 
--------------------------------------------------------------------------------
/nerfies/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Custom type annotations."""
16 | import pathlib
17 | from typing import Any, Callable, Tuple, Text, Union
18 | 
19 | PRNGKey = Any
20 | Shape = Tuple[int, ...]
21 | Dtype = Any  # this could be a real type?
22 | Array = Any
23 | 
24 | Activation = Callable[[Array], Array]
25 | Initializer = Callable[[PRNGKey, Shape, Dtype], Array]
26 | 
27 | PathType = Union[Text, pathlib.PurePosixPath]
28 | 
--------------------------------------------------------------------------------
/nerfies/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Visualization utilities."""
16 | import functools
17 | 
18 | from matplotlib import cm
19 | from matplotlib.colors import LinearSegmentedColormap
20 | import numpy as np
21 | 
22 | 
23 | # Turbo colors by Anton Mikhailov.
24 | # Copyright 2019 Google LLC.
25 | # https://gist.github.com/mikhailov-work/ee72ba4191942acecc03fe6da94fc73f 26 | _TURBO_COLORS = np.array( 27 | [[0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149], 28 | [0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844], 29 | [0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314], 30 | [0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558], 31 | [0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578], 32 | [0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373], 33 | [0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942], 34 | [0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286], 35 | [0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406], 36 | [0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300], 37 | [0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968], 38 | [0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412], 39 | [0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631], 40 | [0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624], 41 | [0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393], 42 | [0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936], 43 | [0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254], 44 | [0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347], 45 | [0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214], 46 | [0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857], 47 | [0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275], 48 | [0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461], 49 | [0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303], 50 | [0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773], 51 | [0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896], 52 | [0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697], 53 | [0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202], 54 | [0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436], 55 | [0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423], 56 | [0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190], 57 | [0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761], 58 | [0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161], 59 | [0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416], 60 | [0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550], 61 | [0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590], 62 | [0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559], 63 | [0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484], 64 | [0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389], 65 | [0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299], 66 | [0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240], 67 | [0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237], 68 | [0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316], 69 | [0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500], 70 | [0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651], 71 | [0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627], 72 | [0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448], 73 | [0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137], 74 | [0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713], 75 | [0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199], 76 | [0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614], 77 | [0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981], 78 | [0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321], 79 | [0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654], 80 | [0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002], 81 | [0.36581, 0.98702, 0.43688], 
[0.38127, 0.98909, 0.42386], 82 | [0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826], 83 | [0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345], 84 | [0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963], 85 | [0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701], 86 | [0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581], 87 | [0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623], 88 | [0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849], 89 | [0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280], 90 | [0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937], 91 | [0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835], 92 | [0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960], 93 | [0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294], 94 | [0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815], 95 | [0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504], 96 | [0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343], 97 | [0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310], 98 | [0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386], 99 | [0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552], 100 | [0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788], 101 | [0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074], 102 | [0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391], 103 | [0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719], 104 | [0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038], 105 | [0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328], 106 | [0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570], 107 | [0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744], 108 | [0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831], 109 | [0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811], 110 | [0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663], 111 | [0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369], 112 | [0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918], 113 | [0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358], 114 | [0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706], 115 | [0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971], 116 | [0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165], 117 | [0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297], 118 | [0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376], 119 | [0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412], 120 | [0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417], 121 | [0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398], 122 | [0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367], 123 | [0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332], 124 | [0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305], 125 | [0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294], 126 | [0.96555, 0.42241, 0.09798], [0.96187, 0.41093, 0.09310], 127 | [0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362], 128 | [0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461], 129 | [0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616], 130 | [0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837], 131 | [0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134], 132 | [0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516], 133 | [0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993], 134 | [0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521], 135 | [0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082], 136 | [0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677], 137 | [0.84662, 
0.21407, 0.02487], [0.83926, 0.20654, 0.02305], 138 | [0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966], 139 | [0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660], 140 | [0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387], 141 | [0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148], 142 | [0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942], 143 | [0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769], 144 | [0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629], 145 | [0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522], 146 | [0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449], 147 | [0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408], 148 | [0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401], 149 | [0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427], 150 | [0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486], 151 | [0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579], 152 | [0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705], 153 | [0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863], 154 | [0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055]]) 155 | 156 | _colormap_cache = {} 157 | 158 | 159 | def _build_colormap(name, num_bins=256): 160 | base = cm.get_cmap(name) 161 | color_list = base(np.linspace(0, 1, num_bins)) 162 | cmap_name = base.name + str(num_bins) 163 | colormap = LinearSegmentedColormap.from_list(cmap_name, color_list, num_bins) 164 | colormap = colormap(np.linspace(0, 1, num_bins))[:, :3] 165 | return colormap 166 | 167 | 168 | @functools.lru_cache(maxsize=32) 169 | def get_colormap(name, num_bins=256): 170 | """Lazily initializes and returns a colormap.""" 171 | if name == 'turbo': 172 | return _TURBO_COLORS 173 | 174 | return _build_colormap(name, num_bins) 175 | 176 | 177 | def interpolate_colormap(values, colormap): 178 | """Interpolates the colormap given values between 0.0 and 1.0.""" 179 | a = np.floor(values * 255) 180 | b = (a + 1).clip(max=255) 181 | f = values * 255.0 - a 182 | a = a.astype(np.uint16).clip(0, 255) 183 | b = b.astype(np.uint16).clip(0, 255) 184 | return colormap[a] + (colormap[b] - colormap[a]) * f[..., np.newaxis] 185 | 186 | 187 | def scale_values(values, vmin, vmax, eps=1e-6): 188 | return (values - vmin) / max(vmax - vmin, eps) 189 | 190 | 191 | def colorize( 192 | array, cmin=None, cmax=None, cmap='magma', eps=1e-6, invert=False): 193 | """Applies a colormap to an array. 194 | 195 | Args: 196 | array: the array to apply a colormap to. 197 | cmin: the minimum value of the colormap. If None will take the min. 198 | cmax: the maximum value of the colormap. If None will take the max. 199 | cmap: the color mapping to use. 200 | eps: a small value to prevent divide by zero. 201 | invert: if True will invert the colormap. 202 | 203 | Returns: 204 | a color mapped version of array. 
205 | """ 206 | array = np.asarray(array) 207 | 208 | if cmin is None: 209 | cmin = array.min() 210 | if cmax is None: 211 | cmax = array.max() 212 | 213 | x = scale_values(array, cmin, cmax, eps) 214 | colormap = get_colormap(cmap) 215 | colorized = interpolate_colormap(1.0 - x if invert else x, colormap) 216 | colorized[x > 1.0] = 0.0 if invert else 1.0 217 | colorized[x < 0.0] = 1.0 if invert else 0.0 218 | 219 | return colorized 220 | 221 | 222 | def colorize_binary_logits(array, cmap=None): 223 | """Colorizes binary logits as a segmentation map.""" 224 | num_classes = array.shape[-1] 225 | if cmap is None: 226 | if num_classes <= 8: 227 | cmap = 'Set3' 228 | elif num_classes <= 10: 229 | cmap = 'tab10' 230 | elif num_classes <= 20: 231 | cmap = 'tab20' 232 | else: 233 | cmap = 'gist_rainbow' 234 | 235 | colormap = get_colormap(cmap, num_classes) 236 | indices = np.argmax(array, axis=-1) 237 | return np.take(colormap, indices, axis=0) 238 | -------------------------------------------------------------------------------- /nerfies/warping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Warp fields.""" 16 | from typing import Any, Iterable, Optional, Dict 17 | 18 | from flax import linen as nn 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | from nerfies import glo 23 | from nerfies import model_utils 24 | from nerfies import modules 25 | from nerfies import rigid_body as rigid 26 | from nerfies import types 27 | 28 | 29 | def create_warp_field( 30 | field_type: str, 31 | num_freqs: int, 32 | num_embeddings: int, 33 | num_features: int, 34 | num_batch_dims: int, 35 | **kwargs): 36 | """Factory function for warp fields.""" 37 | kwargs = {**kwargs} 38 | if field_type == 'translation': 39 | warp_field_cls = TranslationField 40 | elif field_type == 'se3': 41 | warp_field_cls = SE3Field 42 | else: 43 | raise ValueError(f'Unknown warp field type: {field_type!r}') 44 | 45 | if num_batch_dims > 0: 46 | v_warp_field_cls = model_utils.vmap_module( 47 | warp_field_cls, 48 | num_batch_dims=num_batch_dims, 49 | # (points, metadata, extras, return_jacobian, 50 | # metadata_encoded). 51 | in_axes=(0, 0, None, None, None)) 52 | else: 53 | v_warp_field_cls = warp_field_cls 54 | 55 | return v_warp_field_cls( 56 | num_freqs=num_freqs, 57 | num_embeddings=num_embeddings, 58 | num_embedding_features=num_features, 59 | **kwargs) 60 | 61 | 62 | class TranslationField(nn.Module): 63 | """Network that predicts warps as a translation field. 64 | 65 | References: 66 | https://en.wikipedia.org/wiki/Vector_potential 67 | https://en.wikipedia.org/wiki/Helmholtz_decomposition 68 | 69 | Attributes: 70 | points_encoder: the positional encoder for the points. 71 | metadata_encoder: an encoder for metadata. 72 | alpha: the alpha for the positional encoding. 73 | skips: the index of the layers with skip connections. 
74 | depth: the depth of the network excluding the output layer. 75 | hidden_channels: the width of the network hidden layers. 76 | activation: the activation for each layer. 77 | metadata_encoded: whether the metadata parameter is pre-encoded or not. 78 | hidden_initializer: the initializer for the hidden layers. 79 | output_initializer: the initializer for the last output layer. 80 | """ 81 | num_freqs: int 82 | num_embeddings: int 83 | num_embedding_features: int 84 | min_freq_log2: int = 0 85 | max_freq_log2: Optional[int] = None 86 | use_identity_map: bool = True 87 | 88 | metadata_encoder_type: str = 'glo' 89 | metadata_encoder_num_freqs: int = 1 90 | 91 | skips: Iterable[int] = (4,) 92 | depth: int = 6 93 | hidden_channels: int = 128 94 | activation: types.Activation = nn.relu 95 | hidden_init: types.Initializer = nn.initializers.xavier_uniform() 96 | output_init: types.Initializer = nn.initializers.uniform(scale=1e-4) 97 | 98 | def setup(self): 99 | self.points_encoder = modules.AnnealedSinusoidalEncoder( 100 | num_freqs=self.num_freqs, 101 | min_freq_log2=self.min_freq_log2, 102 | max_freq_log2=self.max_freq_log2, 103 | use_identity=self.use_identity_map) 104 | 105 | if self.metadata_encoder_type == 'glo': 106 | self.metadata_encoder = glo.GloEncoder( 107 | num_embeddings=self.num_embeddings, 108 | features=self.num_embedding_features) 109 | elif self.metadata_encoder_type == 'time': 110 | self.metadata_encoder = modules.TimeEncoder( 111 | num_freqs=self.metadata_encoder_num_freqs, 112 | features=self.num_embedding_features) 113 | elif self.metadata_encoder_type == 'blend': 114 | self.glo_encoder = glo.GloEncoder( 115 | num_embeddings=self.num_embeddings, 116 | features=self.num_embedding_features) 117 | self.time_encoder = modules.TimeEncoder( 118 | num_freqs=self.metadata_encoder_num_freqs, 119 | features=self.num_embedding_features) 120 | else: 121 | raise ValueError( 122 | f'Unknown metadata encoder type {self.metadata_encoder_type}') 123 | 124 | # Note that this must be done this way instead of using mutable list 125 | # operations. 126 | # See https://github.com/google/flax/issues/524. 
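    # The MLP below maps the concatenated (encoded point, metadata embedding)
    # input to a 3-vector that `warp` adds to the input point as a translation.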
127 | # pylint: disable=g-complex-comprehension 128 | output_dims = 3 129 | self.mlp = modules.MLP( 130 | width=self.hidden_channels, 131 | depth=self.depth, 132 | skips=self.skips, 133 | hidden_init=self.hidden_init, 134 | output_init=self.output_init, 135 | output_channels=output_dims) 136 | 137 | def encode_metadata(self, 138 | metadata: jnp.ndarray, 139 | time_alpha: Optional[float] = None): 140 | if self.metadata_encoder_type == 'time': 141 | metadata_embed = self.metadata_encoder(metadata, time_alpha) 142 | elif self.metadata_encoder_type == 'blend': 143 | glo_embed = self.glo_encoder(metadata) 144 | time_embed = self.time_encoder(metadata) 145 | metadata_embed = ((1.0 - time_alpha) * glo_embed + 146 | time_alpha * time_embed) 147 | elif self.metadata_encoder_type == 'glo': 148 | metadata_embed = self.metadata_encoder(metadata) 149 | else: 150 | raise RuntimeError( 151 | f'Unknown metadata encoder type {self.metadata_encoder_type}') 152 | 153 | return metadata_embed 154 | 155 | def warp(self, 156 | points: jnp.ndarray, 157 | metadata_embed: jnp.ndarray, 158 | extra: Dict[str, Any]): 159 | points_embed = self.points_encoder(points, alpha=extra.get('alpha')) 160 | inputs = jnp.concatenate([points_embed, metadata_embed], axis=-1) 161 | translation = self.mlp(inputs) 162 | warped_points = points + translation 163 | 164 | return warped_points 165 | 166 | def __call__(self, 167 | points: jnp.ndarray, 168 | metadata: jnp.ndarray, 169 | extra: Dict[str, Any], 170 | return_jacobian: bool = False, 171 | metadata_encoded: bool = False): 172 | """Warp the given points using a warp field. 173 | 174 | Args: 175 | points: the points to warp. 176 | metadata: metadata indices if metadata_encoded is False else pre-encoded 177 | metadata. 178 | extra: extra parameters used in the warp field e.g., the warp alpha. 179 | return_jacobian: if True compute and return the Jacobian of the warp. 180 | metadata_encoded: if True assumes the metadata is already encoded. 181 | 182 | Returns: 183 | The warped points and the Jacobian of the warp if `return_jacobian` is 184 | True. 185 | """ 186 | if metadata_encoded: 187 | metadata_embed = metadata 188 | else: 189 | metadata_embed = self.encode_metadata(metadata, extra.get('time_alpha')) 190 | 191 | out = { 192 | 'warped_points': self.warp(points, metadata_embed, extra) 193 | } 194 | 195 | if return_jacobian: 196 | jac_fn = jax.jacfwd(lambda *x: self.warp(*x)[..., :3], argnums=0) 197 | out['jacobian'] = jac_fn(points, metadata_embed, extra) 198 | 199 | return out 200 | 201 | 202 | class SE3Field(nn.Module): 203 | """Network that predicts warps as an SE(3) field. 204 | 205 | Attributes: 206 | points_encoder: the positional encoder for the points. 207 | metadata_encoder: an encoder for metadata. 208 | alpha: the alpha for the positional encoding. 209 | skips: the index of the layers with skip connections. 210 | depth: the depth of the network excluding the logit layer. 211 | hidden_channels: the width of the network hidden layers. 212 | activation: the activation for each layer. 213 | metadata_encoded: whether the metadata parameter is pre-encoded or not. 214 | hidden_initializer: the initializer for the hidden layers. 215 | output_initializer: the initializer for the last logit layer. 
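
  The field predicts, per point, a vector w (whose norm is taken as the
  rotation angle theta) and a vector v; the normalized screw axis (w, v)
  and theta are mapped to an SE(3) transform via `rigid_body.exp_se3`,
  with optional pivot and translation branches applied around it.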
216 | """ 217 | num_freqs: int 218 | num_embeddings: int 219 | num_embedding_features: int 220 | min_freq_log2: int = 0 221 | max_freq_log2: Optional[int] = None 222 | use_identity_map: bool = True 223 | 224 | activation: types.Activation = nn.relu 225 | skips: Iterable[int] = (4,) 226 | trunk_depth: int = 6 227 | trunk_width: int = 128 228 | rotation_depth: int = 0 229 | rotation_width: int = 128 230 | pivot_depth: int = 0 231 | pivot_width: int = 128 232 | translation_depth: int = 0 233 | translation_width: int = 128 234 | metadata_encoder_type: str = 'glo' 235 | metadata_encoder_num_freqs: int = 1 236 | 237 | default_init: types.Initializer = nn.initializers.xavier_uniform() 238 | rotation_init: types.Initializer = nn.initializers.uniform(scale=1e-4) 239 | pivot_init: types.Initializer = nn.initializers.uniform(scale=1e-4) 240 | translation_init: types.Initializer = nn.initializers.uniform(scale=1e-4) 241 | 242 | use_pivot: bool = False 243 | use_translation: bool = False 244 | 245 | def setup(self): 246 | self.points_encoder = modules.AnnealedSinusoidalEncoder( 247 | num_freqs=self.num_freqs, 248 | min_freq_log2=self.min_freq_log2, 249 | max_freq_log2=self.max_freq_log2, 250 | use_identity=self.use_identity_map) 251 | 252 | if self.metadata_encoder_type == 'glo': 253 | self.metadata_encoder = glo.GloEncoder( 254 | num_embeddings=self.num_embeddings, 255 | features=self.num_embedding_features) 256 | elif self.metadata_encoder_type == 'time': 257 | self.metadata_encoder = modules.TimeEncoder( 258 | num_freqs=self.metadata_encoder_num_freqs, 259 | features=self.num_embedding_features) 260 | else: 261 | raise ValueError( 262 | f'Unknown metadata encoder type {self.metadata_encoder_type}') 263 | 264 | self.trunk = modules.MLP( 265 | depth=self.trunk_depth, 266 | width=self.trunk_width, 267 | hidden_activation=self.activation, 268 | hidden_init=self.default_init, 269 | skips=self.skips) 270 | 271 | branches = { 272 | 'w': 273 | modules.MLP( 274 | depth=self.rotation_depth, 275 | width=self.rotation_width, 276 | hidden_activation=self.activation, 277 | hidden_init=self.default_init, 278 | output_init=self.rotation_init, 279 | output_channels=3), 280 | 'v': 281 | modules.MLP( 282 | depth=self.pivot_depth, 283 | width=self.pivot_width, 284 | hidden_activation=self.activation, 285 | hidden_init=self.default_init, 286 | output_init=self.pivot_init, 287 | output_channels=3), 288 | } 289 | if self.use_pivot: 290 | branches['p'] = modules.MLP( 291 | depth=self.pivot_depth, 292 | width=self.pivot_width, 293 | hidden_activation=self.activation, 294 | hidden_init=self.default_init, 295 | output_init=self.pivot_init, 296 | output_channels=3) 297 | if self.use_translation: 298 | branches['t'] = modules.MLP( 299 | depth=self.translation_depth, 300 | width=self.translation_width, 301 | hidden_activation=self.activation, 302 | hidden_init=self.default_init, 303 | output_init=self.translation_init, 304 | output_channels=3) 305 | # Note that this must be done this way instead of using mutable operations. 306 | # See https://github.com/google/flax/issues/524. 
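    # Heads: 'w' (rotation axis-angle) and 'v' (screw translation) are always
    # present; 'p' (pivot) and 't' (extra translation) are optional.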
307 | self.branches = branches 308 | 309 | def encode_metadata(self, 310 | metadata: jnp.ndarray, 311 | time_alpha: Optional[float] = None): 312 | if self.metadata_encoder_type == 'time': 313 | metadata_embed = self.metadata_encoder(metadata, time_alpha) 314 | elif self.metadata_encoder_type == 'glo': 315 | metadata_embed = self.metadata_encoder(metadata) 316 | else: 317 | raise RuntimeError( 318 | f'Unknown metadata encoder type {self.metadata_encoder_type}') 319 | 320 | return metadata_embed 321 | 322 | def warp(self, 323 | points: jnp.ndarray, 324 | metadata_embed: jnp.ndarray, 325 | extra: Dict[str, Any]): 326 | points_embed = self.points_encoder(points, alpha=extra.get('alpha')) 327 | inputs = jnp.concatenate([points_embed, metadata_embed], axis=-1) 328 | trunk_output = self.trunk(inputs) 329 | 330 | w = self.branches['w'](trunk_output) 331 | v = self.branches['v'](trunk_output) 332 | theta = jnp.linalg.norm(w, axis=-1) 333 | w = w / theta[..., jnp.newaxis] 334 | v = v / theta[..., jnp.newaxis] 335 | screw_axis = jnp.concatenate([w, v], axis=-1) 336 | transform = rigid.exp_se3(screw_axis, theta) 337 | 338 | warped_points = points 339 | if self.use_pivot: 340 | pivot = self.branches['p'](trunk_output) 341 | warped_points = warped_points + pivot 342 | 343 | warped_points = rigid.from_homogenous( 344 | transform @ rigid.to_homogenous(warped_points)) 345 | 346 | if self.use_pivot: 347 | warped_points = warped_points - pivot 348 | 349 | if self.use_translation: 350 | t = self.branches['t'](trunk_output) 351 | warped_points = warped_points + t 352 | 353 | return warped_points 354 | 355 | def __call__(self, 356 | points: jnp.ndarray, 357 | metadata: jnp.ndarray, 358 | extra: Dict[str, Any], 359 | return_jacobian: bool = False, 360 | metadata_encoded: bool = False): 361 | """Warp the given points using a warp field. 362 | 363 | Args: 364 | points: the points to warp. 365 | metadata: metadata indices if metadata_encoded is False else pre-encoded 366 | metadata. 367 | extra: A dictionary containing 368 | 'alpha': the alpha value for the positional encoding. 369 | 'time_alpha': the alpha value for the time positional encoding 370 | (if applicable). 371 | return_jacobian: if True compute and return the Jacobian of the warp. 372 | metadata_encoded: if True assumes the metadata is already encoded. 373 | 374 | Returns: 375 | The warped points and the Jacobian of the warp if `return_jacobian` is 376 | True. 
377 | """ 378 | if metadata_encoded: 379 | metadata_embed = metadata 380 | else: 381 | metadata_embed = self.encode_metadata(metadata, extra.get('time_alpha')) 382 | 383 | out = {'warped_points': self.warp(points, metadata_embed, extra)} 384 | 385 | if return_jacobian: 386 | jac_fn = jax.jacfwd(self.warp, argnums=0) 387 | out['jacobian'] = jac_fn(points, metadata_embed, extra) 388 | 389 | return out 390 | -------------------------------------------------------------------------------- /notebooks/Nerfies_Render_Video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Nerfies Render Video v2.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "TPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "QMMWf9AQcdlp" 22 | }, 23 | "source": [ 24 | "# Render a Nerfie video!\n", 25 | "\n", 26 | "**Author**: [Keunhong Park](https://keunhong.com)\n", 27 | "\n", 28 | "[[Project Page](https://nerfies.github.io)]\n", 29 | "[[Paper](https://storage.googleapis.com/nerfies-public/videos/nerfies_paper.pdf)]\n", 30 | "[[Video](https://www.youtube.com/watch?v=MrKrnHhk8IA)]\n", 31 | "[[GitHub](https://github.com/google/nerfies)]\n", 32 | "\n", 33 | "This notebook renders a figure-8 orbit video using the test cameras generated in the capture processing notebook.\n", 34 | "\n", 35 | "You can also load your own custom cameras by modifying the code slightly.\n", 36 | "\n", 37 | "### Instructions\n", 38 | "\n", 39 | "1. Convert a video into our dataset format using the [capture processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).\n", 40 | "2. Train a Nerfie model using the [training notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Training.ipynb)\n", 41 | "3. Run this notebook!\n", 42 | "\n", 43 | "\n", 44 | "### Notes\n", 45 | " * Please report issues on the [GitHub issue tracker](https://github.com/google/nerfies/issues)." 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "id": "gHqkIo4hcGou" 52 | }, 53 | "source": [ 54 | "## Environment Setup" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "metadata": { 60 | "id": "-GwSf5FfcH4b" 61 | }, 62 | "source": [ 63 | "!pip install flax immutabledict mediapy\n", 64 | "!pip install git+https://github.com/google/nerfies@v2" 65 | ], 66 | "execution_count": null, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "metadata": { 72 | "id": "-3T2lBKBcIGP", 73 | "cellView": "form" 74 | }, 75 | "source": [ 76 | "# @title Configure notebook runtime\n", 77 | "# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`. 
\n", 78 | "# @markdown You will have to use a smaller batch size on GPU.\n", 79 | "\n", 80 | "runtime_type = 'tpu' # @param ['gpu', 'tpu']\n", 81 | "if runtime_type == 'tpu':\n", 82 | " import jax.tools.colab_tpu\n", 83 | " jax.tools.colab_tpu.setup_tpu()\n", 84 | "\n", 85 | "print('Detected Devices:', jax.devices())" 86 | ], 87 | "execution_count": null, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "metadata": { 93 | "id": "82kU-W1NcNTW", 94 | "cellView": "form" 95 | }, 96 | "source": [ 97 | "# @title Mount Google Drive\n", 98 | "# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.\n", 99 | "\n", 100 | "from google.colab import drive\n", 101 | "drive.mount('/content/gdrive')" 102 | ], 103 | "execution_count": null, 104 | "outputs": [] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "metadata": { 109 | "id": "YIDbV769cPn1", 110 | "cellView": "form" 111 | }, 112 | "source": [ 113 | "# @title Define imports and utility functions.\n", 114 | "\n", 115 | "import jax\n", 116 | "from jax.config import config as jax_config\n", 117 | "import jax.numpy as jnp\n", 118 | "from jax import grad, jit, vmap\n", 119 | "from jax import random\n", 120 | "\n", 121 | "import flax\n", 122 | "import flax.linen as nn\n", 123 | "from flax import jax_utils\n", 124 | "from flax import optim\n", 125 | "from flax.metrics import tensorboard\n", 126 | "from flax.training import checkpoints\n", 127 | "\n", 128 | "from absl import logging\n", 129 | "from io import BytesIO\n", 130 | "import random as pyrandom\n", 131 | "import numpy as np\n", 132 | "import PIL\n", 133 | "import IPython\n", 134 | "import tempfile\n", 135 | "import imageio\n", 136 | "import mediapy\n", 137 | "from IPython.display import display, HTML\n", 138 | "from base64 import b64encode\n", 139 | "\n", 140 | "\n", 141 | "# Monkey patch logging.\n", 142 | "def myprint(msg, *args, **kwargs):\n", 143 | " print(msg % args)\n", 144 | "\n", 145 | "logging.info = myprint \n", 146 | "logging.warn = myprint\n", 147 | "logging.error = myprint" 148 | ], 149 | "execution_count": null, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "id": "2QYJ7dyMcw2f" 156 | }, 157 | "source": [ 158 | "# @title Model and dataset configuration\n", 159 | "# @markdown Change the directories to where you saved your capture and experiment.\n", 160 | "\n", 161 | "\n", 162 | "from pathlib import Path\n", 163 | "from pprint import pprint\n", 164 | "import gin\n", 165 | "from IPython.display import display, Markdown\n", 166 | "\n", 167 | "from nerfies import configs\n", 168 | "\n", 169 | "\n", 170 | "# @markdown The working directory where the trained model is.\n", 171 | "train_dir = '/content/gdrive/My Drive/nerfies/experiments/capture1/exp1' # @param {type: \"string\"}\n", 172 | "# @markdown The directory to the dataset capture.\n", 173 | "data_dir = '/content/gdrive/My Drive/nerfies/captures/capture1' # @param {type: \"string\"}\n", 174 | "\n", 175 | "checkpoint_dir = Path(train_dir, 'checkpoints')\n", 176 | "checkpoint_dir.mkdir(exist_ok=True, parents=True)\n", 177 | "\n", 178 | "config_path = Path(train_dir, 'config.gin')\n", 179 | "with open(config_path, 'r') as f:\n", 180 | " logging.info('Loading config from %s', config_path)\n", 181 | " config_str = f.read()\n", 182 | "gin.parse_config(config_str)\n", 183 | "\n", 184 | "config_path = Path(train_dir, 'config.gin')\n", 185 | "with open(config_path, 'w') as f:\n", 186 | " logging.info('Saving config to %s', config_path)\n", 187 | " 
f.write(config_str)\n", 188 | "\n", 189 | "exp_config = configs.ExperimentConfig()\n", 190 | "model_config = configs.ModelConfig()\n", 191 | "train_config = configs.TrainConfig()\n", 192 | "eval_config = configs.EvalConfig()\n", 193 | "\n", 194 | "display(Markdown(\n", 195 | " gin.config.markdown(gin.operative_config_str())))" 196 | ], 197 | "execution_count": null, 198 | "outputs": [] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "metadata": { 203 | "id": "6T7LQ5QSmu4o", 204 | "cellView": "form" 205 | }, 206 | "source": [ 207 | "# @title Create datasource and show an example.\n", 208 | "\n", 209 | "from nerfies import datasets\n", 210 | "from nerfies import image_utils\n", 211 | "\n", 212 | "datasource = datasets.from_config(\n", 213 | " exp_config.datasource_spec,\n", 214 | " image_scale=exp_config.image_scale,\n", 215 | " use_appearance_id=model_config.use_appearance_metadata,\n", 216 | " use_camera_id=model_config.use_camera_metadata,\n", 217 | " use_warp_id=model_config.use_warp,\n", 218 | " random_seed=exp_config.random_seed)\n", 219 | "\n", 220 | "mediapy.show_image(datasource.load_rgb(datasource.train_ids[0]))" 221 | ], 222 | "execution_count": null, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "metadata": { 228 | "id": "jEO3xcxpnCqx", 229 | "cellView": "form" 230 | }, 231 | "source": [ 232 | "# @title Initialize model\n", 233 | "# @markdown Defines the model and initializes its parameters.\n", 234 | "\n", 235 | "from flax.training import checkpoints\n", 236 | "from nerfies import models\n", 237 | "from nerfies import model_utils\n", 238 | "from nerfies import schedules\n", 239 | "from nerfies import training\n", 240 | "\n", 241 | "\n", 242 | "rng = random.PRNGKey(exp_config.random_seed)\n", 243 | "np.random.seed(exp_config.random_seed + jax.process_index())\n", 244 | "devices = jax.devices()\n", 245 | "\n", 246 | "learning_rate_sched = schedules.from_config(train_config.lr_schedule)\n", 247 | "warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)\n", 248 | "elastic_loss_weight_sched = schedules.from_config(\n", 249 | " train_config.elastic_loss_weight_schedule)\n", 250 | "\n", 251 | "rng, key = random.split(rng)\n", 252 | "params = {}\n", 253 | "model, params['model'] = models.construct_nerf(\n", 254 | " key,\n", 255 | " model_config,\n", 256 | " batch_size=train_config.batch_size,\n", 257 | " appearance_ids=datasource.appearance_ids,\n", 258 | " camera_ids=datasource.camera_ids,\n", 259 | " warp_ids=datasource.warp_ids,\n", 260 | " near=datasource.near,\n", 261 | " far=datasource.far,\n", 262 | " use_warp_jacobian=train_config.use_elastic_loss,\n", 263 | " use_weights=train_config.use_elastic_loss)\n", 264 | "\n", 265 | "optimizer_def = optim.Adam(learning_rate_sched(0))\n", 266 | "optimizer = optimizer_def.create(params)\n", 267 | "state = model_utils.TrainState(\n", 268 | " optimizer=optimizer,\n", 269 | " warp_alpha=warp_alpha_sched(0))\n", 270 | "scalar_params = training.ScalarParams(\n", 271 | " learning_rate=learning_rate_sched(0),\n", 272 | " elastic_loss_weight=elastic_loss_weight_sched(0),\n", 273 | " background_loss_weight=train_config.background_loss_weight)\n", 274 | "logging.info('Restoring checkpoint from %s', checkpoint_dir)\n", 275 | "state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n", 276 | "step = state.optimizer.state.step + 1\n", 277 | "state = jax_utils.replicate(state, devices=devices)\n", 278 | "del params" 279 | ], 280 | "execution_count": null, 281 | "outputs": [] 282 | }, 283 | { 284 | 
"cell_type": "code", 285 | "metadata": { 286 | "id": "2KYhbpsklwAy", 287 | "cellView": "form" 288 | }, 289 | "source": [ 290 | "# @title Define pmapped render function.\n", 291 | "\n", 292 | "import functools\n", 293 | "from nerfies import evaluation\n", 294 | "\n", 295 | "devices = jax.devices()\n", 296 | "\n", 297 | "\n", 298 | "def _model_fn(key_0, key_1, params, rays_dict, warp_extra):\n", 299 | " out = model.apply({'params': params},\n", 300 | " rays_dict,\n", 301 | " warp_extra=warp_extra,\n", 302 | " rngs={\n", 303 | " 'coarse': key_0,\n", 304 | " 'fine': key_1\n", 305 | " },\n", 306 | " mutable=False)\n", 307 | " return jax.lax.all_gather(out, axis_name='batch')\n", 308 | "\n", 309 | "pmodel_fn = jax.pmap(\n", 310 | " # Note rng_keys are useless in eval mode since there's no randomness.\n", 311 | " _model_fn,\n", 312 | " in_axes=(0, 0, 0, 0, 0), # Only distribute the data input.\n", 313 | " devices=devices,\n", 314 | " donate_argnums=(3,), # Donate the 'rays' argument.\n", 315 | " axis_name='batch',\n", 316 | ")\n", 317 | "\n", 318 | "render_fn = functools.partial(evaluation.render_image,\n", 319 | " model_fn=pmodel_fn,\n", 320 | " device_count=len(devices),\n", 321 | " chunk=eval_config.chunk)" 322 | ], 323 | "execution_count": null, 324 | "outputs": [] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "metadata": { 329 | "id": "73Fq0kNcmAra" 330 | }, 331 | "source": [ 332 | "# @title Load cameras.\n", 333 | "\n", 334 | "from nerfies import utils\n", 335 | "\n", 336 | "camera_path = 'camera-paths/orbit-mild' # @param {type: 'string'}\n", 337 | "\n", 338 | "camera_dir = Path(data_dir, camera_path)\n", 339 | "print(f'Loading cameras from {camera_dir}')\n", 340 | "test_camera_paths = datasource.glob_cameras(camera_dir)\n", 341 | "test_cameras = utils.parallel_map(datasource.load_camera, test_camera_paths, show_pbar=True)" 342 | ], 343 | "execution_count": null, 344 | "outputs": [] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "metadata": { 349 | "id": "aP9LjiAZmoRc", 350 | "cellView": "form" 351 | }, 352 | "source": [ 353 | "# @title Render video frames.\n", 354 | "from nerfies import visualization as viz\n", 355 | "\n", 356 | "\n", 357 | "rng = rng + jax.host_id() # Make random seed separate across hosts.\n", 358 | "keys = random.split(rng, len(devices))\n", 359 | "\n", 360 | "results = []\n", 361 | "for i in range(len(test_cameras)):\n", 362 | " print(f'Rendering frame {i+1}/{len(test_cameras)}')\n", 363 | " camera = test_cameras[i]\n", 364 | " batch = datasets.camera_to_rays(camera)\n", 365 | " batch['metadata'] = {\n", 366 | " 'appearance': jnp.zeros_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32),\n", 367 | " 'warp': jnp.zeros_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32),\n", 368 | " }\n", 369 | "\n", 370 | " render = render_fn(state, batch, rng=rng)\n", 371 | " rgb = np.array(render['rgb'])\n", 372 | " depth_med = np.array(render['med_depth'])\n", 373 | " results.append((rgb, depth_med))\n", 374 | " depth_viz = viz.colorize(depth_med.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)\n", 375 | " mediapy.show_images([rgb, depth_viz])" 376 | ], 377 | "execution_count": null, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "_5hHR9XVm8Ix" 384 | }, 385 | "source": [ 386 | "# @title Show rendered video.\n", 387 | "\n", 388 | "fps = 30 # @param {type:'number'}\n", 389 | "\n", 390 | "frames = []\n", 391 | "for rgb, depth in results:\n", 392 | " depth_viz = viz.colorize(depth.squeeze(), 
cmin=datasource.near, cmax=datasource.far, invert=True)\n", 393 | " frame = np.concatenate([rgb, depth_viz], axis=1)\n", 394 | " frames.append(image_utils.image_to_uint8(frame))\n", 395 | "\n", 396 | "mediapy.show_video(frames, fps=fps)" 397 | ], 398 | "execution_count": null, 399 | "outputs": [] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "metadata": { 404 | "id": "WW32AVGR0Vwh" 405 | }, 406 | "source": [ 407 | "" 408 | ], 409 | "execution_count": null, 410 | "outputs": [] 411 | } 412 | ] 413 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | flax==0.3.4 3 | gin-config @ git+https://github.com/google/gin-config@243ba87b3fcfeb2efb4a920b8f19679b61a6f0dc 4 | imageio==2.9.0 5 | immutabledict==2.2.0 6 | jax==0.2.20 7 | numpy==1.19.5 8 | opencv-python==4.5.3.56 9 | Pillow==8.3.2 10 | scikit-image==0.18.3 11 | scipy==1.7.1 12 | tensorboard==2.6.0 13 | tensorflow>=2.6.1 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import setuptools 16 | 17 | with open("README.md", "r", encoding="utf-8") as fh: 18 | long_description = fh.read() 19 | 20 | setuptools.setup( 21 | name="nerfies", # Replace with your own username 22 | version="0.0.2", 23 | author="Keunhong Park", 24 | author_email="kpar@cs.washington.edu", 25 | description="Code for 'Nerfies: Deformable Neural Radiance Fields'.", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | url="https://github.com/google/nerfies", 29 | packages=setuptools.find_packages(), 30 | classifiers=[ 31 | "Programming Language :: Python :: 3", 32 | "License :: OSI Approved :: Apache License 2.0", 33 | "Operating System :: OS Independent", 34 | ], 35 | python_requires='>=3.7', 36 | ) 37 | -------------------------------------------------------------------------------- /third_party/pycolmap/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 True Price, UNC Chapel Hill 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/third_party/pycolmap/README.md:
--------------------------------------------------------------------------------
1 | # pycolmap
2 | Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions.
3 | 
4 | This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output.
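5 | 
6 | ## Example
7 | 
8 | A minimal sketch of reading a reconstruction (illustrative only; the model
9 | directory below is hypothetical and should point at a COLMAP sparse model
10 | containing the `cameras`, `images`, and `points3D` files):
11 | 
12 | ```python
13 | from pycolmap import SceneManager
14 | 
15 | scene = SceneManager('/path/to/sparse/0/')  # hypothetical model folder
16 | scene.load_cameras()
17 | scene.load_images()
18 | scene.load_points3D()
19 | 
20 | # Print each registered image with its camera size and camera center -R^T t.
21 | for image_id, image in scene.images.items():
22 |     camera = scene.cameras[image.camera_id]
23 |     print(image_id, image.name, camera.width, camera.height, image.C())
24 | ```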
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/__init__.py:
--------------------------------------------------------------------------------
1 | from .camera import Camera
2 | from .database import COLMAPDatabase
3 | from .image import Image
4 | from .scene_manager import SceneManager
5 | from .rotation import Quaternion, DualQuaternion
6 | 
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/camera.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 | 
3 | import numpy as np
4 | 
5 | from scipy.optimize import root
6 | 
7 | 
8 | #-------------------------------------------------------------------------------
9 | #
10 | # camera distortion functions for arrays of size (..., 2)
11 | #
12 | #-------------------------------------------------------------------------------
13 | 
14 | def simple_radial_distortion(camera, x):
15 |   return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True))
16 | 
17 | 
18 | def radial_distortion(camera, x):
19 |   r_sq = np.square(x).sum(axis=-1, keepdims=True)
20 |   return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq))
21 | 
22 | 
23 | def opencv_distortion(camera, x):
24 |   x_sq, y_sq = np.split(np.square(x), 2, axis=-1)  # fixes undefined y_sq below
25 |   xy = np.prod(x, axis=-1, keepdims=True)
26 |   r_sq = x_sq + y_sq
27 | 
28 |   return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + np.concatenate(
29 |       (2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * x_sq), camera.p1 *
30 |        (r_sq + 2. * y_sq) + 2. * camera.p2 * xy),
31 |       axis=-1)
32 | 
33 | 
34 | #-------------------------------------------------------------------------------
35 | #
36 | # Camera
37 | #
38 | #-------------------------------------------------------------------------------
39 | 
40 | class Camera:
41 | 
42 |   @staticmethod
43 |   def GetNumParams(type_):
44 |     if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
45 |       return 3
46 |     if type_ == 1 or type_ == 'PINHOLE':
47 |       return 4
48 |     if type_ == 2 or type_ == 'SIMPLE_RADIAL':
49 |       return 4
50 |     if type_ == 3 or type_ == 'RADIAL':
51 |       return 5
52 |     if type_ == 4 or type_ == 'OPENCV':
53 |       return 8
54 |     #if type_ == 5 or type_ == 'OPENCV_FISHEYE':
55 |     #  return 8
56 |     #if type_ == 6 or type_ == 'FULL_OPENCV':
57 |     #  return 12
58 |     #if type_ == 7 or type_ == 'FOV':
59 |     #  return 5
60 |     #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE':
61 |     #  return 4
62 |     #if type_ == 9 or type_ == 'RADIAL_FISHEYE':
63 |     #  return 5
64 |     #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE':
65 |     #  return 12
66 | 
67 |     # TODO: not supporting other camera types, currently
68 |     raise Exception('Camera type not supported')
69 | 
70 |   #---------------------------------------------------------------------------
71 | 
72 |   @staticmethod
73 |   def GetNameFromType(type_):
74 |     if type_ == 0:
75 |       return 'SIMPLE_PINHOLE'
76 |     if type_ == 1:
77 |       return 'PINHOLE'
78 |     if type_ == 2:
79 |       return 'SIMPLE_RADIAL'
80 |     if type_ == 3:
81 |       return 'RADIAL'
82 |     if type_ == 4:
83 |       return 'OPENCV'
84 |     #if type_ == 5: return 'OPENCV_FISHEYE'
85 |     #if type_ == 6: return 'FULL_OPENCV'
86 |     #if type_ == 7: return 'FOV'
87 |     #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE'
88 |     #if type_ == 9: return 'RADIAL_FISHEYE'
89 |     #if type_ == 10: return 'THIN_PRISM_FISHEYE'
90 | 
91 |     raise Exception('Camera type not supported')
92 | 
93 |   #---------------------------------------------------------------------------
94 | 
95 |   def __init__(self, type_, width_, height_, params):
96 |     self.width = width_
97 |     self.height = height_
98 | 
99 |     if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
100 |       self.fx, self.cx, self.cy = params
101 |       self.fy = self.fx
102 |       self.distortion_func = None
103 |       self.camera_type = 0
104 | 
105 |     elif type_ == 1 or type_ == 'PINHOLE':
106 |       self.fx, self.fy, self.cx, self.cy = params
107 |       self.distortion_func = None
108 |       self.camera_type = 1
109 | 
110 |     elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
111 |       self.fx, self.cx, self.cy, self.k1 = params
112 |       self.fy = self.fx
113 |       self.distortion_func = simple_radial_distortion
114 |       self.camera_type = 2
115 | 
116 |     elif type_ == 3 or type_ == 'RADIAL':
117 |       self.fx, self.cx, self.cy, self.k1, self.k2 = params
118 |       self.fy = self.fx
119 |       self.distortion_func = radial_distortion
120 |       self.camera_type = 3
121 | 
122 |     elif type_ == 4 or type_ == 'OPENCV':
123 |       self.fx, self.fy, self.cx, self.cy = params[:4]
124 |       self.k1, self.k2, self.p1, self.p2 = params[4:]
125 |       self.distortion_func = opencv_distortion
126 |       self.camera_type = 4
127 | 
128 |     else:
129 |       raise Exception('Camera type not supported')
130 | 
131 |   #---------------------------------------------------------------------------
132 | 
133 |   def __str__(self):
134 |     s = (
135 |         self.GetNameFromType(self.camera_type) +
136 |         ' {} {} {}'.format(self.width, self.height, self.fx))
137 | 
138 |     if self.camera_type in (1, 4):  # PINHOLE, OPENCV
139 |       s += ' {}'.format(self.fy)
140 | 
141 |     s += ' {} {}'.format(self.cx, self.cy)
142 | 
143 |     if self.camera_type == 2:  # SIMPLE_RADIAL
144 |       s += ' {}'.format(self.k1)
145 | 
146 |     elif self.camera_type == 3:  # RADIAL
147 |       s += ' {} {}'.format(self.k1, self.k2)
148 | 
149 |     elif self.camera_type == 4:  # OPENCV
150 |       s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2)
151 | 
152 |     return s
153 | 
154 |   #---------------------------------------------------------------------------
155 | 
156 |   # return the camera parameters in the same order as the colmap output format
157 |   def get_params(self):
158 |     if self.camera_type == 0:
159 |       return np.array((self.fx, self.cx, self.cy))
160 |     if self.camera_type == 1:
161 |       return np.array((self.fx, self.fy, self.cx, self.cy))
162 |     if self.camera_type == 2:
163 |       return np.array((self.fx, self.cx, self.cy, self.k1))
164 |     if self.camera_type == 3:
165 |       return np.array((self.fx, self.cx, self.cy, self.k1, self.k2))
166 |     if self.camera_type == 4:
167 |       return np.array((self.fx, self.fy, self.cx, self.cy, self.k1, self.k2,
168 |                        self.p1, self.p2))
169 | 
170 |   #---------------------------------------------------------------------------
171 | 
172 |   def get_camera_matrix(self):
173 |     return np.array(((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1)))
174 | 
175 |   def get_inverse_camera_matrix(self):
176 |     return np.array(((1. / self.fx, 0, -self.cx / self.fx),
177 |                      (0, 1. / self.fy, -self.cy / self.fy), (0, 0, 1)))
178 | 
179 |   @property
180 |   def K(self):
181 |     return self.get_camera_matrix()
182 | 
183 |   @property
184 |   def K_inv(self):
185 |     return self.get_inverse_camera_matrix()
186 | 
187 |   #---------------------------------------------------------------------------
188 | 
189 |   # return the inverse camera matrix
190 |   def get_inv_camera_matrix(self):
191 |     inv_fx, inv_fy = 1. / self.fx, 1. / self.fy
192 |     return np.array(((inv_fx, 0, -inv_fx * self.cx),
193 |                      (0, inv_fy, -inv_fy * self.cy), (0, 0, 1)))
194 | 
195 |   #---------------------------------------------------------------------------
196 | 
197 |   # return an (x, y) pixel coordinate grid for this camera
198 |   def get_image_grid(self):
199 |     xmin = (0.5 - self.cx) / self.fx
200 |     xmax = (self.width - 0.5 - self.cx) / self.fx
201 |     ymin = (0.5 - self.cy) / self.fy
202 |     ymax = (self.height - 0.5 - self.cy) / self.fy
203 |     return np.meshgrid(
204 |         np.linspace(xmin, xmax, self.width),
205 |         np.linspace(ymin, ymax, self.height))
206 | 
207 |   #---------------------------------------------------------------------------
208 | 
209 |   # x: array of shape (N,2) or (2,)
210 |   # normalized: False if the input points are in pixel coordinates
211 |   # denormalize: True if the points should be put back into pixel coordinates
212 |   def distort_points(self, x, normalized=True, denormalize=True):
213 |     x = np.atleast_2d(x).astype(np.float64)  # copy; never mutate the input
214 | 
215 |     # put the points into normalized camera coordinates
216 |     if not normalized:
217 |       x -= np.array([[self.cx, self.cy]])
218 |       x /= np.array([[self.fx, self.fy]])
219 | 
220 |     # distort, if necessary
221 |     if self.distortion_func is not None:
222 |       x = self.distortion_func(self, x)
223 | 
224 |     if denormalize:
225 |       x *= np.array([[self.fx, self.fy]])
226 |       x += np.array([[self.cx, self.cy]])
227 | 
228 |     return x
229 | 
230 |   #---------------------------------------------------------------------------
231 | 
232 |   # x: array of shape (N1,N2,...,2), (N,2), or (2,)
233 |   # normalized: False if the input points are in pixel coordinates
234 |   # denormalize: True if the points should be put back into pixel coordinates
235 |   def undistort_points(self, x, normalized=False, denormalize=True):
236 |     x = np.atleast_2d(x).astype(np.float64)  # copy; never mutate the input
237 | 
238 |     # put the points into normalized camera coordinates
239 |     if not normalized:
240 |       x = x - np.array([self.cx, self.cy])  # creates a copy
241 |       x /= np.array([self.fx, self.fy])
242 | 
243 |     # undistort, if necessary
244 |     if self.distortion_func is not None:
245 | 
246 |       def objective(xu):
247 |         return (x - self.distortion_func(self, xu.reshape(*x.shape))).ravel()
248 | 
249 |       xu = root(objective, x).x.reshape(*x.shape)
250 |     else:
251 |       xu = x
252 | 
253 |     if denormalize:
254 |       xu *= np.array([[self.fx, self.fy]])
255 |       xu += np.array([[self.cx, self.cy]])
256 | 
257 |     return xu
258 | 
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/database.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sqlite3
4 | 
5 | #-------------------------------------------------------------------------------
6 | # convert SQLite BLOBs to/from numpy arrays
7 | 
8 | 
9 | def array_to_blob(arr):
10 |   return arr.tobytes()  # np.getbuffer() is Python 2 only; removed in NumPy 1.17
11 | 
12 | 
13 | def blob_to_array(blob, dtype, shape=(-1,)):
14 |   return np.frombuffer(blob, dtype).reshape(*shape)
15 | 
16 | 
17 | #-------------------------------------------------------------------------------
18 | # convert to/from image pair ids
19 | 
20 | MAX_IMAGE_ID = 2**31 - 1
21 | 
22 | 
23 | def get_pair_id(image_id1, image_id2):
24 |   if image_id1 > image_id2:
25 |     image_id1, image_id2 = image_id2, image_id1
26 |   return image_id1 * MAX_IMAGE_ID + image_id2
27 | 
28 | 
29 | def get_image_ids_from_pair_id(pair_id):
30 |   image_id2 = pair_id % MAX_IMAGE_ID
31 |   return (pair_id - image_id2) // MAX_IMAGE_ID, image_id2  # floor div keeps ids integral
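32 | # Worked example (illustrative): get_pair_id(1, 2) == get_pair_id(2, 1)
33 | #   == 1 * (2**31 - 1) + 2 == 2147483649, which the function above inverts.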
34 | #-------------------------------------------------------------------------------
35 | # create table commands
36 | 
37 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
38 |     camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
39 |     model INTEGER NOT NULL,
40 |     width INTEGER NOT NULL,
41 |     height INTEGER NOT NULL,
42 |     params BLOB,
43 |     prior_focal_length INTEGER NOT NULL)"""
44 | 
45 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
46 |     image_id INTEGER PRIMARY KEY NOT NULL,
47 |     rows INTEGER NOT NULL,
48 |     cols INTEGER NOT NULL,
49 |     data BLOB,
50 |     FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
51 | 
52 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
53 |     image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
54 |     name TEXT NOT NULL UNIQUE,
55 |     camera_id INTEGER NOT NULL,
56 |     prior_qw REAL,
57 |     prior_qx REAL,
58 |     prior_qy REAL,
59 |     prior_qz REAL,
60 |     prior_tx REAL,
61 |     prior_ty REAL,
62 |     prior_tz REAL,
63 |     CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647),
64 |     FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))"""
65 | 
66 | CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries (
67 |     pair_id INTEGER PRIMARY KEY NOT NULL,
68 |     rows INTEGER NOT NULL,
69 |     cols INTEGER NOT NULL,
70 |     data BLOB,
71 |     config INTEGER NOT NULL,
72 |     F BLOB,
73 |     E BLOB,
74 |     H BLOB)"""
75 | 
76 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
77 |     image_id INTEGER PRIMARY KEY NOT NULL,
78 |     rows INTEGER NOT NULL,
79 |     cols INTEGER NOT NULL,
80 |     data BLOB,
81 |     FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
82 | 
83 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
84 |     pair_id INTEGER PRIMARY KEY NOT NULL,
85 |     rows INTEGER NOT NULL,
86 |     cols INTEGER NOT NULL,
87 |     data BLOB)"""
88 | 
89 | CREATE_NAME_INDEX = \
90 |     "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
91 | 
92 | CREATE_ALL = "; ".join([
93 |     CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE, CREATE_IMAGES_TABLE,
94 |     CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE, CREATE_MATCHES_TABLE,
95 |     CREATE_NAME_INDEX
96 | ])
97 | 
98 | #-------------------------------------------------------------------------------
99 | # functional interface for adding objects
100 | 
101 | 
102 | def add_camera(db,
103 |                model,
104 |                width,
105 |                height,
106 |                params,
107 |                prior_focal_length=False,
108 |                camera_id=None):
109 |   # TODO: Parameter count checks
110 |   params = np.asarray(params, np.float64)
111 |   db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
112 |              (camera_id, model, width, height, array_to_blob(params),
113 |               prior_focal_length))
114 | 
115 | 
116 | def add_descriptors(db, image_id, descriptors):
117 |   descriptors = np.ascontiguousarray(descriptors, np.uint8)
118 |   db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)",
119 |              (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
120 | 
121 | 
122 | def add_image(db,
123 |               name,
124 |               camera_id,
125 |               prior_q=np.zeros(4),
126 |               prior_t=np.zeros(3),
127 |               image_id=None):
128 |   db.execute("INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
129 |              (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
130 |               prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
131 | 
132 | 
133 | # config: defaults to fundamental matrix
134 | def add_inlier_matches(db,
135 |                        image_id1,
136 |                        image_id2,
137 |                        matches,
138 |                        config=2,
139 |                        F=None,
140 |                        E=None,
141 |                        H=None):
142 |   assert (len(matches.shape) == 2)
143 |   assert (matches.shape[1] == 2)
144 | 
145 |   if image_id1 > image_id2:
146 |     matches = matches[:, ::-1]
147 | 
148 |   if F is not None:
149 |     F = np.asarray(F, np.float64)
150 |   if E is not None:
151 |     E = np.asarray(E, np.float64)
152 |   if H is not None:
153 |     H = np.asarray(H, np.float64)
154 | 
155 |   pair_id = get_pair_id(image_id1, image_id2)
156 |   matches = np.asarray(matches, np.uint32)
157 |   db.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
158 |              (pair_id,) + matches.shape +
159 |              (array_to_blob(matches), config, F, E, H))
160 | 
161 | 
162 | def add_keypoints(db, image_id, keypoints):
163 |   assert (len(keypoints.shape) == 2)
164 |   assert (keypoints.shape[1] in [2, 4, 6])
165 | 
166 |   keypoints = np.asarray(keypoints, np.float32)
167 |   db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)",
168 |              (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
169 | 
170 | 
171 | # matches must reference the lower image_id first; columns are swapped if needed
172 | def add_matches(db, image_id1, image_id2, matches):
173 |   assert (len(matches.shape) == 2)
174 |   assert (matches.shape[1] == 2)
175 | 
176 |   if image_id1 > image_id2:
177 |     matches = matches[:, ::-1]
178 | 
179 |   pair_id = get_pair_id(image_id1, image_id2)
180 |   matches = np.asarray(matches, np.uint32)
181 |   db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)",
182 |              (pair_id,) + matches.shape + (array_to_blob(matches),))
183 | 
184 | 
185 | #-------------------------------------------------------------------------------
186 | # database wrapper class
187 | 
188 | 
189 | class COLMAPDatabase(sqlite3.Connection):
190 | 
191 |   @staticmethod
192 |   def connect(database_path):
193 |     return sqlite3.connect(database_path, factory=COLMAPDatabase)
194 | 
195 |   def __init__(self, *args, **kwargs):
196 |     super(COLMAPDatabase, self).__init__(*args, **kwargs)
197 | 
198 |     self.initialize_tables = lambda:
self.executescript(CREATE_ALL) 199 | 200 | self.initialize_cameras = \ 201 | lambda: self.executescript(CREATE_CAMERAS_TABLE) 202 | self.initialize_descriptors = \ 203 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) 204 | self.initialize_images = \ 205 | lambda: self.executescript(CREATE_IMAGES_TABLE) 206 | self.initialize_inlier_matches = \ 207 | lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE) 208 | self.initialize_keypoints = \ 209 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 210 | self.initialize_matches = \ 211 | lambda: self.executescript(CREATE_MATCHES_TABLE) 212 | 213 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 214 | 215 | add_camera = add_camera 216 | add_descriptors = add_descriptors 217 | add_image = add_image 218 | add_inlier_matches = add_inlier_matches 219 | add_keypoints = add_keypoints 220 | add_matches = add_matches 221 | 222 | 223 | #------------------------------------------------------------------------------- 224 | 225 | 226 | def main(args): 227 | import os 228 | 229 | if os.path.exists(args.database_path): 230 | print("Error: database path already exists -- will not modify it.") 231 | exit() 232 | 233 | db = COLMAPDatabase.connect(args.database_path) 234 | 235 | # 236 | # for convenience, try creating all the tables upfront 237 | # 238 | 239 | db.initialize_tables() 240 | 241 | # 242 | # create dummy cameras 243 | # 244 | 245 | model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.)) 246 | model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1)) 247 | 248 | db.add_camera(model1, w1, h1, params1) 249 | db.add_camera(model2, w2, h2, params2) 250 | 251 | # 252 | # create dummy images 253 | # 254 | 255 | db.add_image("image1.png", 0) 256 | db.add_image("image2.png", 0) 257 | db.add_image("image3.png", 2) 258 | db.add_image("image4.png", 2) 259 | 260 | # 261 | # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y), 262 | # 4D keypoints (x, y, theta, scale), and 6D affine keypoints 263 | # (x, y, a_11, a_12, a_21, a_22) 264 | # 265 | 266 | N = 1000 267 | kp1 = np.random.rand(N, 2) * (1024., 768.) 268 | kp2 = np.random.rand(N, 2) * (1024., 768.) 269 | kp3 = np.random.rand(N, 2) * (1024., 768.) 270 | kp4 = np.random.rand(N, 2) * (1024., 768.) 
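271 |   # image_ids were assigned by AUTOINCREMENT in add_image above, so 1-4 here.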
272 |   db.add_keypoints(1, kp1)
273 |   db.add_keypoints(2, kp2)
274 |   db.add_keypoints(3, kp3)
275 |   db.add_keypoints(4, kp4)
276 | 
277 |   #
278 |   # create dummy matches
279 |   #
280 | 
281 |   M = 50
282 |   m12 = np.random.randint(N, size=(M, 2))
283 |   m23 = np.random.randint(N, size=(M, 2))
284 |   m34 = np.random.randint(N, size=(M, 2))
285 | 
286 |   db.add_matches(1, 2, m12)
287 |   db.add_matches(2, 3, m23)
288 |   db.add_matches(3, 4, m34)
289 | 
290 |   #
291 |   # check cameras
292 |   #
293 | 
294 |   rows = db.execute("SELECT * FROM cameras")
295 | 
296 |   camera_id, model, width, height, params, prior = next(rows)
297 |   params = blob_to_array(params, np.float64)  # stored as float64 by add_camera
298 |   assert model == model1 and width == w1 and height == h1
299 |   assert np.allclose(params, params1)
300 | 
301 |   camera_id, model, width, height, params, prior = next(rows)
302 |   params = blob_to_array(params, np.float64)
303 |   assert model == model2 and width == w2 and height == h2
304 |   assert np.allclose(params, params2)
305 | 
306 |   #
307 |   # check keypoints
308 |   #
309 | 
310 |   kps = dict(
311 |       (image_id, blob_to_array(data, np.float32, (-1, 2)))
312 |       for image_id, data in db.execute("SELECT image_id, data FROM keypoints"))
313 | 
314 |   assert np.allclose(kps[1], kp1)
315 |   assert np.allclose(kps[2], kp2)
316 |   assert np.allclose(kps[3], kp3)
317 |   assert np.allclose(kps[4], kp4)
318 | 
319 |   #
320 |   # check matches
321 |   #
322 | 
323 |   pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]]
324 | 
325 |   matches = dict(
326 |       (get_image_ids_from_pair_id(pair_id),
327 |        blob_to_array(data, np.uint32, (-1, 2)))
328 |       for pair_id, data in db.execute("SELECT pair_id, data FROM matches"))
329 | 
330 |   assert np.all(matches[(1, 2)] == m12)
331 |   assert np.all(matches[(2, 3)] == m23)
332 |   assert np.all(matches[(3, 4)] == m34)
333 | 
334 |   #
335 |   # clean up
336 |   #
337 | 
338 |   db.close()
339 |   os.remove(args.database_path)
340 | 
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/image.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 | 
3 | import numpy as np
4 | 
5 | #-------------------------------------------------------------------------------
6 | #
7 | # Image
8 | #
9 | #-------------------------------------------------------------------------------
10 | 
11 | 
12 | class Image:
13 | 
14 |   def __init__(self, name_, camera_id_, q_, tvec_):
15 |     self.name = name_
16 |     self.camera_id = camera_id_
17 |     self.q = q_
18 |     self.tvec = tvec_
19 | 
20 |     self.points2D = np.empty((0, 2), dtype=np.float64)
21 |     self.point3D_ids = np.empty((0,), dtype=np.uint64)
22 | 
23 |   #---------------------------------------------------------------------------
24 | 
25 |   def R(self):
26 |     return self.q.ToR()
27 | 
28 |   #---------------------------------------------------------------------------
29 | 
30 |   def C(self):
31 |     return -self.R().T.dot(self.tvec)
32 | 
33 |   #---------------------------------------------------------------------------
34 | 
35 |   @property
36 |   def t(self):
37 |     return self.tvec
38 | 
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/rotation.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 | 
3 | import numpy as np
4 | 
5 | #-------------------------------------------------------------------------------
6 | #
7 | # Axis-Angle Functions
8 | #
9 | 
#------------------------------------------------------------------------------- 10 | 11 | 12 | # returns the cross product matrix representation of a 3-vector v 13 | def cross_prod_matrix(v): 14 | return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.))) 15 | 16 | 17 | #------------------------------------------------------------------------------- 18 | 19 | 20 | # www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/ 21 | # if angle is None, assume ||axis|| == angle, in radians 22 | # if angle is not None, assume that axis is a unit vector 23 | def axis_angle_to_rotation_matrix(axis, angle=None): 24 | if angle is None: 25 | angle = np.linalg.norm(axis) 26 | if np.abs(angle) > np.finfo('float').eps: 27 | axis = axis / angle 28 | 29 | cp_axis = cross_prod_matrix(axis) 30 | return np.eye(3) + ( 31 | np.sin(angle) * cp_axis + (1. - np.cos(angle)) * cp_axis.dot(cp_axis)) 32 | 33 | 34 | #------------------------------------------------------------------------------- 35 | 36 | 37 | # after some deliberation, I've decided the easiest way to do this is to use 38 | # quaternions as an intermediary 39 | def rotation_matrix_to_axis_angle(R): 40 | return Quaternion.FromR(R).ToAxisAngle() 41 | 42 | 43 | #------------------------------------------------------------------------------- 44 | # 45 | # Quaternion 46 | # 47 | #------------------------------------------------------------------------------- 48 | 49 | 50 | class Quaternion: 51 | # create a quaternion from an existing rotation matrix 52 | # euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/ 53 | @staticmethod 54 | def FromR(R): 55 | trace = np.trace(R) 56 | 57 | if trace > 0: 58 | qw = 0.5 * np.sqrt(1. + trace) 59 | qx = (R[2, 1] - R[1, 2]) * 0.25 / qw 60 | qy = (R[0, 2] - R[2, 0]) * 0.25 / qw 61 | qz = (R[1, 0] - R[0, 1]) * 0.25 / qw 62 | elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]: 63 | s = 2. * np.sqrt(1. + R[0, 0] - R[1, 1] - R[2, 2]) 64 | qw = (R[2, 1] - R[1, 2]) / s 65 | qx = 0.25 * s 66 | qy = (R[0, 1] + R[1, 0]) / s 67 | qz = (R[0, 2] + R[2, 0]) / s 68 | elif R[1, 1] > R[2, 2]: 69 | s = 2. * np.sqrt(1. + R[1, 1] - R[0, 0] - R[2, 2]) 70 | qw = (R[0, 2] - R[2, 0]) / s 71 | qx = (R[0, 1] + R[1, 0]) / s 72 | qy = 0.25 * s 73 | qz = (R[1, 2] + R[2, 1]) / s 74 | else: 75 | s = 2. * np.sqrt(1. 
+ R[2, 2] - R[0, 0] - R[1, 1]) 76 | qw = (R[1, 0] - R[0, 1]) / s 77 | qx = (R[0, 2] + R[2, 0]) / s 78 | qy = (R[1, 2] + R[2, 1]) / s 79 | qz = 0.25 * s 80 | 81 | return Quaternion(np.array((qw, qx, qy, qz))) 82 | 83 | # if angle is None, assume ||axis|| == angle, in radians 84 | # if angle is not None, assume that axis is a unit vector 85 | @staticmethod 86 | def FromAxisAngle(axis, angle=None): 87 | if angle is None: 88 | angle = np.linalg.norm(axis) 89 | if np.abs(angle) > np.finfo('float').eps: 90 | axis = axis / angle 91 | 92 | qw = np.cos(0.5 * angle) 93 | axis = axis * np.sin(0.5 * angle) 94 | 95 | return Quaternion(np.array((qw, axis[0], axis[1], axis[2]))) 96 | 97 | #--------------------------------------------------------------------------- 98 | 99 | def __init__(self, q=np.array((1., 0., 0., 0.))): 100 | if isinstance(q, Quaternion): 101 | self.q = q.q.copy() 102 | else: 103 | q = np.asarray(q) 104 | if q.size == 4: 105 | self.q = q.copy() 106 | elif q.size == 3: # convert from a 3-vector to a quaternion 107 | self.q = np.empty(4) 108 | self.q[0], self.q[1:] = 0., q.ravel() 109 | else: 110 | raise Exception('Input quaternion should be a 3- or 4-vector') 111 | 112 | def __add__(self, other): 113 | return Quaternion(self.q + other.q) 114 | 115 | def __iadd__(self, other): 116 | self.q += other.q 117 | return self 118 | 119 | # conjugation via the ~ operator 120 | def __invert__(self): 121 | return Quaternion(np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3]))) 122 | 123 | # returns: self.q * other.q if other is a Quaternion; otherwise performs 124 | # scalar multiplication 125 | def __mul__(self, other): 126 | if isinstance(other, Quaternion): # quaternion multiplication 127 | return Quaternion( 128 | np.array((self.q[0] * other.q[0] - self.q[1] * other.q[1] - 129 | self.q[2] * other.q[2] - self.q[3] * other.q[3], 130 | self.q[0] * other.q[1] + self.q[1] * other.q[0] + 131 | self.q[2] * other.q[3] - self.q[3] * other.q[2], 132 | self.q[0] * other.q[2] - self.q[1] * other.q[3] + 133 | self.q[2] * other.q[0] + self.q[3] * other.q[1], 134 | self.q[0] * other.q[3] + self.q[1] * other.q[2] - 135 | self.q[2] * other.q[1] + self.q[3] * other.q[0]))) 136 | else: # scalar multiplication (assumed) 137 | return Quaternion(other * self.q) 138 | 139 | def __rmul__(self, other): 140 | return self * other 141 | 142 | def __imul__(self, other): 143 | self.q[:] = (self * other).q 144 | return self 145 | 146 | def __irmul__(self, other): 147 | self.q[:] = (self * other).q 148 | return self 149 | 150 | def __neg__(self): 151 | return Quaternion(-self.q) 152 | 153 | def __sub__(self, other): 154 | return Quaternion(self.q - other.q) 155 | 156 | def __isub__(self, other): 157 | self.q -= other.q 158 | return self 159 | 160 | def __str__(self): 161 | return str(self.q) 162 | 163 | def copy(self): 164 | return Quaternion(self) 165 | 166 | def dot(self, other): 167 | return self.q.dot(other.q) 168 | 169 | # assume the quaternion is nonzero! 
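  # (q^-1 = ~q / <q, q>, i.e. the conjugate scaled by the inverse squared norm)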
170 | def inverse(self): 171 | return Quaternion((~self).q / self.q.dot(self.q)) 172 | 173 | def norm(self): 174 | return np.linalg.norm(self.q) 175 | 176 | def normalize(self): 177 | self.q /= np.linalg.norm(self.q) 178 | return self 179 | 180 | # assume x is a Nx3 numpy array or a numpy 3-vector 181 | def rotate_points(self, x): 182 | x = np.atleast_2d(x) 183 | return x.dot(self.ToR().T) 184 | 185 | # convert to a rotation matrix 186 | def ToR(self): 187 | return np.eye(3) + 2 * np.array(( 188 | (-self.q[2] * self.q[2] - self.q[3] * self.q[3], self.q[1] * self.q[2] - 189 | self.q[3] * self.q[0], self.q[1] * self.q[3] + self.q[2] * self.q[0]), 190 | (self.q[1] * self.q[2] + self.q[3] * self.q[0], -self.q[1] * self.q[1] - 191 | self.q[3] * self.q[3], self.q[2] * self.q[3] - self.q[1] * self.q[0]), 192 | (self.q[1] * self.q[3] - self.q[2] * self.q[0], 193 | self.q[2] * self.q[3] + self.q[1] * self.q[0], 194 | -self.q[1] * self.q[1] - self.q[2] * self.q[2]))) 195 | 196 | # convert to axis-angle representation, with angle encoded by the length 197 | def ToAxisAngle(self): 198 | # recall that for axis-angle representation (a, angle), with "a" unit: 199 | # q = (cos(angle/2), a * sin(angle/2)) 200 | # below, for readability, "theta" actually means half of the angle 201 | 202 | sin_sq_theta = self.q[1:].dot(self.q[1:]) 203 | 204 | # if theta is non-zero, then we can compute a unique rotation 205 | if np.abs(sin_sq_theta) > np.finfo('float').eps: 206 | sin_theta = np.sqrt(sin_sq_theta) 207 | cos_theta = self.q[0] 208 | 209 | # atan2 is more stable, so we use it to compute theta 210 | # note that we multiply by 2 to get the actual angle 211 | angle = 2. * ( 212 | np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else np.arctan2( 213 | sin_theta, cos_theta)) 214 | 215 | return self.q[1:] * (angle / sin_theta) 216 | 217 | # otherwise, the result is singular, and we avoid dividing by 218 | # sin(angle/2) = 0 219 | return np.zeros(3) 220 | 221 | # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler 222 | # this assumes the quaternion is non-zero 223 | # returns yaw, pitch, roll, with application in that order 224 | def ToEulerAngles(self): 225 | qsq = self.q**2 226 | k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum() 227 | 228 | if (1. - k) < np.finfo('float').eps: # north pole singularity 229 | return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0. 230 | if (1. + k) < np.finfo('float').eps: # south pole singularity 231 | return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0. 232 | 233 | yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]), 234 | qsq[0] + qsq[1] - qsq[2] - qsq[3]) 235 | pitch = np.arcsin(k) 236 | roll = np.arctan2(2. 
* (self.q[0] * self.q[1] - self.q[2] * self.q[3]),
237 |                      qsq[0] - qsq[1] + qsq[2] - qsq[3])
238 | 
239 |     return yaw, pitch, roll
240 | 
241 | 
242 | #-------------------------------------------------------------------------------
243 | #
244 | # DualQuaternion
245 | #
246 | #-------------------------------------------------------------------------------
247 | 
248 | 
249 | class DualQuaternion:
250 |   # DualQuaternion from an existing rotation + translation
251 |   @staticmethod
252 |   def FromQT(q, t):
253 |     return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q)
254 | 
255 |   def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)):
256 |     self.q0, self.qe = Quaternion(q0), Quaternion(qe)
257 | 
258 |   def __add__(self, other):
259 |     return DualQuaternion(self.q0 + other.q0, self.qe + other.qe)
260 | 
261 |   def __iadd__(self, other):
262 |     self.q0 += other.q0
263 |     self.qe += other.qe
264 |     return self
265 | 
266 |   # conjugation via the ~ operator
267 |   def __invert__(self):
268 |     return DualQuaternion(~self.q0, ~self.qe)
269 | 
270 |   def __mul__(self, other):
271 |     if isinstance(other, DualQuaternion):
272 |       return DualQuaternion(self.q0 * other.q0,
273 |                             self.q0 * other.qe + self.qe * other.q0)
274 |     elif isinstance(other, complex):  # multiplication by a dual number
275 |       return DualQuaternion(self.q0 * other.real,
276 |                             self.q0 * other.imag + self.qe * other.real)
277 |     else:  # scalar multiplication (assumed)
278 |       return DualQuaternion(other * self.q0, other * self.qe)
279 | 
280 |   def __rmul__(self, other):
281 |     return self.__mul__(other)
282 | 
283 |   def __imul__(self, other):
284 |     tmp = self * other
285 |     self.q0, self.qe = tmp.q0, tmp.qe
286 |     return self
287 | 
288 |   def __neg__(self):
289 |     return DualQuaternion(-self.q0, -self.qe)
290 | 
291 |   def __sub__(self, other):
292 |     return DualQuaternion(self.q0 - other.q0, self.qe - other.qe)
293 | 
294 |   def __isub__(self, other):
295 |     self.q0 -= other.q0
296 |     self.qe -= other.qe
297 |     return self
298 | 
299 |   # q^-1 = q* / ||q||^2
300 |   # assume that q0 is nonzero!
301 |   def inverse(self):
302 |     normsq = complex(self.q0.dot(self.q0), 2. * self.q0.q.dot(self.qe.q))
303 |     inv_len_real = 1. / normsq.real
304 |     return ~self * complex(inv_len_real,
305 |                            -normsq.imag * inv_len_real * inv_len_real)
306 | 
307 |   # returns a complex representation of the real and imaginary parts of the norm
308 |   # assume that q0 is nonzero!
309 |   def norm(self):
310 |     q0_norm = self.q0.norm()
311 |     return complex(q0_norm, self.q0.dot(self.qe) / q0_norm)
312 | 
313 |   # assume that q0 is nonzero!
314 |   def normalize(self):
315 |     # current length is ||q0|| + eps * (<q0, qe> / ||q0||)
316 |     # writing this as a + eps * b, the inverse is
317 |     # 1/||q|| = 1/a - eps * b / a^2
318 |     norm = self.norm()
319 |     inv_len_real = 1.
/ norm.real 320 | self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real) 321 | return self 322 | 323 | # return the translation vector for this dual quaternion 324 | def getT(self): 325 | return 2 * (self.qe * ~self.q0).q[1:] 326 | 327 | def ToQT(self): 328 | return self.q0, self.getT() 329 | -------------------------------------------------------------------------------- /third_party/pycolmap/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="pycolmap", 8 | version="0.0.1", 9 | author="True Price", 10 | description="PyColmap", 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/google/nerfies/third_party/pycolmap", 14 | packages=setuptools.find_packages(), 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ], 20 | python_requires='>=3.6', 21 | ) 22 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Training script for Nerf.""" 16 | 17 | import functools 18 | from typing import Dict, Union 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | from flax import jax_utils 24 | from flax import optim 25 | from flax.metrics import tensorboard 26 | from flax.training import checkpoints 27 | import gin 28 | import jax 29 | from jax import numpy as jnp 30 | from jax import random 31 | import numpy as np 32 | import tensorflow as tf 33 | 34 | from nerfies import configs 35 | from nerfies import datasets 36 | from nerfies import gpath 37 | from nerfies import model_utils 38 | from nerfies import models 39 | from nerfies import schedules 40 | from nerfies import training 41 | from nerfies import utils 42 | 43 | flags.DEFINE_enum('mode', None, ['jax_cpu', 'jax_gpu', 'jax_tpu'], 44 | 'Distributed strategy approach.') 45 | 46 | flags.DEFINE_string('base_folder', None, 'where to store ckpts and logs') 47 | flags.mark_flag_as_required('base_folder') 48 | flags.DEFINE_string('data_dir', None, 'input data directory.') 49 | flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.') 50 | flags.DEFINE_multi_string('gin_configs', (), 'Gin config files.') 51 | FLAGS = flags.FLAGS 52 | 53 | jax.config.parse_flags_with_absl() 54 | 55 | 56 | def _log_to_tensorboard(writer: tensorboard.SummaryWriter, 57 | state: model_utils.TrainState, 58 | scalar_params: training.ScalarParams, 59 | stats: Dict[str, Union[Dict[str, jnp.ndarray], 60 | jnp.ndarray]], 61 | time_dict: Dict[str, jnp.ndarray]): 62 | """Log statistics to Tensorboard.""" 63 | step = int(state.optimizer.state.step) 64 | writer.scalar('params/learning_rate', scalar_params.learning_rate, step) 65 | writer.scalar('params/warp_alpha', state.warp_alpha, step) 66 | writer.scalar('params/time_alpha', state.time_alpha, step) 67 | writer.scalar('params/elastic_loss/weight', 68 | scalar_params.elastic_loss_weight, step) 69 | 70 | # pmean is applied in train_step so just take the item. 
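  # Each branch ('coarse'/'fine') writes its scalars under a '<stat>/<branch>' tag.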
71 |   for branch in ('coarse', 'fine'):
72 |     if branch not in stats:
73 |       continue
74 |     for stat_key, stat_value in stats[branch].items():
75 |       writer.scalar(f'{stat_key}/{branch}', stat_value, step)
76 | 
77 |   if 'background_loss' in stats:
78 |     writer.scalar('loss/background', stats['background_loss'], step)
79 | 
80 |   for k, v in time_dict.items():
81 |     writer.scalar(f'time/{k}', v, step)
82 | 
83 | 
84 | def _log_histograms(writer: tensorboard.SummaryWriter, model: models.NerfModel,
85 |                     state: model_utils.TrainState):
86 |   """Log histograms to Tensorboard."""
87 |   step = int(state.optimizer.state.step)
88 |   params = state.optimizer.target['model']
89 |   if 'appearance_encoder' in params:
90 |     embeddings = params['appearance_encoder']['embed']['embedding']
91 |     writer.histogram('appearance_embedding', embeddings, step)
92 |   if 'camera_encoder' in params:
93 |     embeddings = params['camera_encoder']['embed']['embedding']
94 |     writer.histogram('camera_embedding', embeddings, step)
95 |   if 'warp_field' in params and model.warp_metadata_encoder_type == 'glo':
96 |     embeddings = params['warp_field']['metadata_encoder']['embed']['embedding']
97 |     writer.histogram('warp_embedding', embeddings, step)
98 | 
99 | 
100 | def main(argv):
101 |   tf.config.experimental.set_visible_devices([], 'GPU')
102 |   del argv
103 |   logging.info('*** Starting experiment')
104 |   gin_configs = FLAGS.gin_configs
105 | 
106 |   logging.info('*** Loading Gin configs from: %s', str(gin_configs))
107 |   gin.parse_config_files_and_bindings(
108 |       config_files=gin_configs,
109 |       bindings=FLAGS.gin_bindings,
110 |       skip_unknown=True)
111 | 
112 |   # Load configurations.
113 |   exp_config = configs.ExperimentConfig()
114 |   model_config = configs.ModelConfig()
115 |   train_config = configs.TrainConfig()
116 | 
117 |   # Get directory information.
118 |   exp_dir = gpath.GPath(FLAGS.base_folder)
119 |   if exp_config.subname:
120 |     exp_dir = exp_dir / exp_config.subname
121 |   summary_dir = exp_dir / 'summaries' / 'train'
122 |   checkpoint_dir = exp_dir / 'checkpoints'
123 | 
124 |   # Log and create directories if this is the main host.
125 |   if jax.process_index() == 0:
126 |     logging.info('exp_dir = %s', exp_dir)
127 |     if not exp_dir.exists():
128 |       exp_dir.mkdir(parents=True, exist_ok=True)
129 | 
130 |     logging.info('summary_dir = %s', summary_dir)
131 |     if not summary_dir.exists():
132 |       summary_dir.mkdir(parents=True, exist_ok=True)
133 | 
134 |     logging.info('checkpoint_dir = %s', checkpoint_dir)
135 |     if not checkpoint_dir.exists():
136 |       checkpoint_dir.mkdir(parents=True, exist_ok=True)
137 | 
138 |     config_str = gin.operative_config_str()
139 |     logging.info('Configuration: \n%s', config_str)
140 |     with (exp_dir / 'config.gin').open('w') as f:
141 |       f.write(config_str)
142 | 
143 |   logging.info('Starting host %d. There are %d hosts.', jax.process_index(),
144 |                jax.process_count())
145 |   logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
146 |                str(jax.local_devices()))
147 |   logging.info('Found %d total devices: %s.', jax.device_count(),
148 |                str(jax.devices()))
149 | 
150 |   rng = random.PRNGKey(exp_config.random_seed)
151 |   # Shift the numpy random seed by process_index() to shuffle data loaded by
152 |   # different hosts.
153 | np.random.seed(exp_config.random_seed + jax.process_index()) 154 | 155 | if train_config.batch_size % jax.device_count() != 0: 156 | raise ValueError('Batch size must be divisible by the number of devices.') 157 | 158 | devices = jax.local_devices() 159 | datasource_spec = exp_config.datasource_spec 160 | if datasource_spec is None: 161 | datasource_spec = { 162 | 'type': exp_config.datasource_type, 163 | 'data_dir': FLAGS.data_dir, 164 | } 165 | logging.info('Creating datasource: %s', datasource_spec) 166 | datasource = datasets.from_config( 167 | datasource_spec, 168 | image_scale=exp_config.image_scale, 169 | use_appearance_id=model_config.use_appearance_metadata, 170 | use_camera_id=model_config.use_camera_metadata, 171 | use_warp_id=model_config.use_warp, 172 | use_time=model_config.warp_metadata_encoder_type == 'time', 173 | random_seed=exp_config.random_seed, 174 | **exp_config.datasource_kwargs) 175 | train_iter = datasource.create_iterator( 176 | datasource.train_ids, 177 | flatten=True, 178 | shuffle=True, 179 | batch_size=train_config.batch_size, 180 | prefetch_size=3, 181 | shuffle_buffer_size=train_config.shuffle_buffer_size, 182 | devices=devices, 183 | ) 184 | 185 | points_iter = None 186 | if train_config.use_background_loss: 187 | points = datasource.load_points(shuffle=True) 188 | points_batch_size = min( 189 | len(points), 190 | len(devices) * train_config.background_points_batch_size) 191 | points_batch_size -= points_batch_size % len(devices) 192 | points_dataset = tf.data.Dataset.from_tensor_slices(points) 193 | points_iter = datasets.iterator_from_dataset( 194 | points_dataset, 195 | batch_size=points_batch_size, 196 | prefetch_size=3, 197 | devices=devices) 198 | 199 | learning_rate_sched = schedules.from_config(train_config.lr_schedule) 200 | warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule) 201 | time_alpha_sched = schedules.from_config(train_config.time_alpha_schedule) 202 | elastic_loss_weight_sched = schedules.from_config( 203 | train_config.elastic_loss_weight_schedule) 204 | 205 | rng, key = random.split(rng) 206 | params = {} 207 | model, params['model'] = models.construct_nerf( 208 | key, 209 | model_config, 210 | batch_size=train_config.batch_size, 211 | appearance_ids=datasource.appearance_ids, 212 | camera_ids=datasource.camera_ids, 213 | warp_ids=datasource.warp_ids, 214 | near=datasource.near, 215 | far=datasource.far, 216 | use_warp_jacobian=train_config.use_elastic_loss, 217 | use_weights=train_config.use_elastic_loss) 218 | 219 | optimizer_def = optim.Adam(learning_rate_sched(0)) 220 | optimizer = optimizer_def.create(params) 221 | state = model_utils.TrainState( 222 | optimizer=optimizer, 223 | warp_alpha=warp_alpha_sched(0), 224 | time_alpha=time_alpha_sched(0)) 225 | scalar_params = training.ScalarParams( 226 | learning_rate=learning_rate_sched(0), 227 | elastic_loss_weight=elastic_loss_weight_sched(0), 228 | warp_reg_loss_weight=train_config.warp_reg_loss_weight, 229 | warp_reg_loss_alpha=train_config.warp_reg_loss_alpha, 230 | warp_reg_loss_scale=train_config.warp_reg_loss_scale, 231 | background_loss_weight=train_config.background_loss_weight) 232 | state = checkpoints.restore_checkpoint(checkpoint_dir, state) 233 | init_step = state.optimizer.state.step + 1 234 | state = jax_utils.replicate(state, devices=devices) 235 | del params 236 | 237 | logging.info('Initializing models') 238 | 239 | summary_writer = None 240 | if jax.process_index() == 0: 241 | summary_writer = tensorboard.SummaryWriter(str(summary_dir)) 
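    # Record the operative Gin config in TensorBoard so the run is self-documenting.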
242 |     summary_writer.text(
243 |         'gin/train', textdata=gin.config.markdown(config_str), step=0)
244 | 
245 |   train_step = functools.partial(
246 |       training.train_step,
247 |       model,
248 |       elastic_reduce_method=train_config.elastic_reduce_method,
249 |       elastic_loss_type=train_config.elastic_loss_type,
250 |       use_elastic_loss=train_config.use_elastic_loss,
251 |       use_background_loss=train_config.use_background_loss,
252 |       use_warp_reg_loss=train_config.use_warp_reg_loss,
253 |   )
254 |   ptrain_step = jax.pmap(
255 |       train_step,
256 |       axis_name='batch',
257 |       devices=devices,
258 |       # rng_key, state, batch, scalar_params.
259 |       in_axes=(0, 0, 0, None),
260 |       # Treat use_elastic_loss as compile-time static.
261 |       donate_argnums=(2,),  # Donate the 'batch' argument.
262 |   )
263 | 
264 |   if devices:
265 |     n_local_devices = len(devices)
266 |   else:
267 |     n_local_devices = jax.local_device_count()
268 | 
269 |   logging.info('Starting training')
270 |   rng = rng + jax.process_index()  # Make random seed separate across hosts.
271 |   keys = random.split(rng, n_local_devices)
272 |   time_tracker = utils.TimeTracker()
273 |   time_tracker.tic('data', 'total')
274 |   for step, batch in zip(range(init_step, train_config.max_steps + 1),
275 |                          train_iter):
276 |     if points_iter is not None:
277 |       batch['background_points'] = next(points_iter)
278 |     time_tracker.toc('data')
279 |     # pytype: disable=attribute-error
280 |     scalar_params = scalar_params.replace(
281 |         learning_rate=learning_rate_sched(step),
282 |         elastic_loss_weight=elastic_loss_weight_sched(step))
283 |     warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
284 |     time_alpha = jax_utils.replicate(time_alpha_sched(step), devices)
285 |     state = state.replace(warp_alpha=warp_alpha, time_alpha=time_alpha)
286 | 
287 |     with time_tracker.record_time('train_step'):
288 |       state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
289 |       time_tracker.toc('total')
290 | 
291 |     if step % train_config.print_every == 0 and jax.process_index() == 0:
292 |       logging.info('step=%d, warp_alpha=%.04f, time_alpha=%.04f, %s', step,
293 |                    warp_alpha_sched(step), time_alpha_sched(step),
294 |                    time_tracker.summary_str('last'))
295 |       coarse_metrics_str = ', '.join(
296 |           [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])
297 |       logging.info('\tcoarse metrics: %s', coarse_metrics_str)
298 |       if 'fine' in stats:  # 'fine' stats exist only when a fine level is used.
299 |         fine_metrics_str = ', '.join(
300 |             [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])
301 |         logging.info('\tfine metrics: %s', fine_metrics_str)
302 | 
303 |     if step % train_config.save_every == 0 and jax.process_index() == 0:
304 |       training.save_checkpoint(checkpoint_dir, state)
305 | 
306 |     if step % train_config.log_every == 0 and jax.process_index() == 0:
307 |       # Only log via host 0.
308 |       _log_to_tensorboard(
309 |           summary_writer,
310 |           jax_utils.unreplicate(state),
311 |           scalar_params,
312 |           jax_utils.unreplicate(stats),
313 |           time_dict=time_tracker.summary('mean'))
314 |       time_tracker.reset()
315 | 
316 |     if step % train_config.histogram_every == 0 and jax.process_index() == 0:
317 |       _log_histograms(summary_writer, model, jax_utils.unreplicate(state))
318 | 
319 |     time_tracker.tic('data', 'total')
320 | 
321 |   if train_config.max_steps % train_config.save_every != 0 and jax.process_index() == 0:
322 |     training.save_checkpoint(checkpoint_dir, state)
323 | 
324 | 
325 | if __name__ == '__main__':
326 |   app.run(main)
327 | 
--------------------------------------------------------------------------------
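Example invocation of train.py (illustrative; the experiment and capture paths
below are hypothetical):

    python train.py \
      --base_folder /tmp/nerfies_experiment \
      --data_dir /path/to/processed/capture \
      --gin_configs configs/test_local.gin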