├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── defaults.gin
│   ├── gpu_fullhd.gin
│   ├── gpu_quarterhd.gin
│   ├── gpu_quarterhd_4gpu.gin
│   ├── gpu_vrig_paper.gin
│   ├── test_local.gin
│   ├── test_vrig.gin
│   └── warp_defaults.gin
├── eval.py
├── nerfies
│   ├── __init__.py
│   ├── camera.py
│   ├── configs.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── nerfies.py
│   ├── evaluation.py
│   ├── glo.py
│   ├── gpath.py
│   ├── image_utils.py
│   ├── model_utils.py
│   ├── models.py
│   ├── modules.py
│   ├── quaternion.py
│   ├── rigid_body.py
│   ├── schedules.py
│   ├── tf_camera.py
│   ├── training.py
│   ├── types.py
│   ├── utils.py
│   ├── visualization.py
│   └── warping.py
├── notebooks
│   ├── Nerfies_Capture_Processing.ipynb
│   ├── Nerfies_Render_Video.ipynb
│   └── Nerfies_Training.ipynb
├── requirements.txt
├── setup.py
├── third_party
│   └── pycolmap
│       ├── LICENSE
│       ├── README.md
│       ├── pycolmap
│       │   ├── __init__.py
│       │   ├── camera.py
│       │   ├── database.py
│       │   ├── image.py
│       │   ├── rotation.py
│       │   └── scene_manager.py
│       └── setup.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | .idea
3 | bower_components
4 | node_modules
5 | venv
6 |
7 | *.ts.map
8 |
9 | *~
10 | *.so
11 | .DS_Store
12 | ._.DS_Store
13 | *.swp
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *,cover
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Nerfies: Deformable Neural Radiance Fields
2 |
3 | This is the code for Nerfies: Deformable Neural Radiance Fields.
4 |
5 | * [Project Page](https://nerfies.github.io)
6 | * [Paper](https://arxiv.org/abs/2011.12948)
7 | * [Video](https://www.youtube.com/watch?v=MrKrnHhk8IA)
8 |
9 | This codebase is implemented using [JAX](https://github.com/google/jax),
10 | building on [JaxNeRF](https://github.com/google-research/google-research/tree/master/jaxnerf).
11 |
12 | This repository has been updated to reflect the version used for our ICCV 2021 submission.
13 |
14 | ## Demo
15 |
16 | We provide an easy-to-get-started demo using Google Colab!
17 |
18 | These Colabs will allow you to train a basic version of our method using
19 | Cloud TPUs (or GPUs) on Google Colab.
20 |
21 | Note that, due to the limited compute resources available, these are not the fully
22 | featured models. If you would like to train a fully featured Nerfie, please
23 | refer to the instructions below on how to train on your own machine.
24 |
25 | | Description | Link |
26 | | ----------- | ----------- |
27 | | Process a video into a Nerfie dataset| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb)|
28 | | Train a Nerfie| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Training.ipynb)|
29 | | Render a Nerfie video| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Render_Video.ipynb)|
30 |
31 | ## Setup
32 | The code can be run under any environment with Python 3.8 and above.
33 | (It may run with lower versions, but we have not tested it).
34 |
35 | We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) and setting up an environment:
36 |
37 | conda create --name nerfies python=3.8
38 |
39 | Next, install the required packages:
40 |
41 | pip install -r requirements.txt
42 |
43 | Install the appropriate JAX distribution for your environment by [following the instructions here](https://github.com/google/jax#installation). For example:
44 |
45 |     # For CUDA version 11.1
46 | pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
47 |
48 |
49 | ## Training
50 | After preparing a dataset, you can train a Nerfie by running:
51 |
52 | export DATASET_PATH=/path/to/dataset
53 | export EXPERIMENT_PATH=/path/to/save/experiment/to
54 | python train.py \
55 | --data_dir $DATASET_PATH \
56 | --base_folder $EXPERIMENT_PATH \
57 | --gin_configs configs/test_vrig.gin
58 |
59 | To plot telemetry to Tensorboard and render checkpoints on the fly, also
60 | launch an evaluation job by running:
61 |
62 | python eval.py \
63 | --data_dir $DATASET_PATH \
64 | --base_folder $EXPERIMENT_PATH \
65 | --gin_configs configs/test_vrig.gin
66 |
67 | The two jobs should use mutually exclusive sets of GPUs. This division allows the
68 | training job to run without having to stop for evaluation.
69 |
70 | ## Configuration
71 | * We use [Gin](https://github.com/google/gin-config) for configuration.
72 | * We provide a couple of preset configurations.
73 | * Please refer to `configs.py` for documentation on what each configuration does; a minimal loading sketch follows this list.
74 | * Preset configs:
75 |   - `gpu_vrig_paper.gin`: This is the configuration we used to generate the table in the paper. It requires 8 GPUs for training.
76 |   - `gpu_fullhd.gin`: This is a high-resolution model and will take around 3 days to train on 8 GPUs.
77 |   - `gpu_quarterhd.gin`: This is a low-resolution model and will take around 14 hours to train on 8 GPUs.
78 |   - `test_local.gin`: This is a test configuration to see if the code runs. It will probably not produce a good-looking result.
79 |   - `test_vrig.gin`: This is a test configuration to see if the code runs for validation rig captures. It will probably not produce a good-looking result.
80 | * Training on fewer GPUs will require tuning of the batch size and learning rates. We've provided an example configuration for 4 GPUs in `gpu_quarterhd_4gpu.gin` but we have not tested it, so please only use it as a reference.
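As a concrete illustration, here is a minimal sketch (not a file in this repo) of how a preset can be parsed with Gin and how the values land in the config dataclasses defined in `nerfies/configs.py`; the `TrainConfig.batch_size` binding is a hypothetical override:

```python
import gin

# Importing nerfies.configs registers ModelConfig/TrainConfig/etc. with Gin.
from nerfies import configs

# Parse a preset; explicit bindings override values from the files.
gin.parse_config_files_and_bindings(
    config_files=['configs/test_local.gin'],
    bindings=['TrainConfig.batch_size = 512'])

# Instantiating the configurable dataclasses picks up the parsed bindings.
model_config = configs.ModelConfig()
train_config = configs.TrainConfig()
print(model_config.use_warp)    # True (set by test_local.gin).
print(train_config.batch_size)  # 512 (set by the binding above).
```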
81 |
82 | ## Datasets
83 | A dataset is a directory with the following structure:
84 |
85 |     dataset
86 |     ├── camera
87 |     │   └── ${item_id}.json
88 |     ├── camera-paths
89 |     ├── rgb
90 |     │   ├── ${scale}x
91 |     │   │   └── ${item_id}.png
92 |     ├── metadata.json
93 |     ├── points.npy
94 |     ├── dataset.json
95 |     └── scene.json
96 |
97 | At a high level, a dataset is simply the following:
98 | * A collection of images (e.g., from a video).
99 | * Camera parameters for each image.
100 |
101 | Each image has a unique identifier which we call `item_id`, and this is
102 | used to match cameras to images. An `item_id` can be any string, but typically
103 | it is some alphanumeric string such as `000054`.
104 |
105 | ### `camera`
106 |
107 | * This directory contains cameras corresponding to each image.
108 | * We use a camera model identical to the [OpenCV camera model](https://docs.opencv.org/master/dc/dbb/tutorial_py_calibration.html), which is also supported by COLMAP.
109 | * Each camera is a serialized version of the `Camera` class defined in `camera.py` and looks like this:
110 |
111 | ```javascript
112 | {
113 |   // A 3x3 world-to-camera rotation matrix representing the camera orientation.
114 |   "orientation": [
115 |     [0.9839, -0.0968, 0.1499],
116 |     [-0.0350, -0.9284, -0.3699],
117 |     [0.1749, 0.358, -0.9168]
118 |   ],
119 |   // The 3D position of the camera in world-space.
120 |   "position": [-0.3236, -3.26428, 5.4160],
121 |   // The focal length of the camera.
122 |   "focal_length": 2691,
123 |   // The principal point [u_0, v_0] of the camera.
124 |   "principal_point": [1220, 1652],
125 |   // The skew of the camera.
126 |   "skew": 0.0,
127 |   // The aspect ratio for the camera pixels.
128 |   "pixel_aspect_ratio": 1.0,
129 |   // Parameters for the radial distortion of the camera.
130 |   "radial_distortion": [0.1004, -0.2090, 0.0],
131 |   // Parameters for the tangential distortion of the camera.
132 |   "tangential": [0.001109, -2.5733e-05],
133 |   // The image width and height in pixels.
134 |   "image_size": [2448, 3264]
135 | }
136 | ```
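For reference, a minimal sketch (plain `json`/`numpy` only; the helper name is ours) of reading one of these camera files. The full-featured parser is the `Camera` class in `camera.py`:

```python
import json

import numpy as np


def load_camera_json(path):
  """Loads a serialized camera file into numpy-friendly values."""
  with open(path, 'r') as f:
    camera = json.load(f)
  # 3x3 world-to-camera rotation; rows are the camera axes in world space.
  orientation = np.asarray(camera['orientation'])
  # Camera center in world coordinates.
  position = np.asarray(camera['position'])
  return camera, orientation, position


camera, orientation, position = load_camera_json('dataset/camera/000054.json')
print(camera['focal_length'], camera['image_size'])
```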
137 |
138 | ### `camera-paths`
139 | * This directory contains test-time camera paths which can be used to render videos.
140 | * Each sub-directory in this path should contain a sequence of JSON files.
141 | * The naming scheme does not matter, but the cameras will be sorted by their filenames.
142 |
143 | ### `rgb`
144 | * This directory contains images at various scales.
145 | * Each subdirectory should be named `${scale}x` where `${scale}` is an integer scaling factor. For example, `1x` would contain the original images while `4x` would contain images a quarter of the size.
146 | * We assume the images are in PNG format.
147 | * It is important that the scaled images are integer factors of the original so that area interpolation can be used when downscaling, which prevents Moiré artifacts. A simple way to ensure this is to trim the image borders so that they are divisible by the largest scale factor you want, as in the sketch below.
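A minimal sketch (assuming OpenCV; the function name is ours) of generating the scaled directories this way, trimming first and then downsampling with area interpolation:

```python
import os

import cv2


def write_scaled_images(image, dataset_dir, item_id, scales=(1, 2, 4, 8)):
  """Trims `image` and writes rgb/${scale}x/${item_id}.png for each scale."""
  max_scale = max(scales)
  # Trim the borders so every scale divides the image size evenly.
  height = (image.shape[0] // max_scale) * max_scale
  width = (image.shape[1] // max_scale) * max_scale
  image = image[:height, :width]
  for scale in scales:
    out_dir = os.path.join(dataset_dir, 'rgb', f'{scale}x')
    os.makedirs(out_dir, exist_ok=True)
    # INTER_AREA averages over pixel areas, which avoids Moiré patterns.
    scaled = cv2.resize(image, (width // scale, height // scale),
                        interpolation=cv2.INTER_AREA)
    cv2.imwrite(os.path.join(out_dir, f'{item_id}.png'), scaled)
```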
148 |
149 | ### `metadata.json`
150 | * This defines the 'metadata' IDs used for embedding lookups.
151 | * Contains a dictionary of the following format:
152 |
153 | ```javascript
154 | {
155 |   "${item_id}": {
156 |     // The embedding ID used to fetch the deformation latent code
157 |     // passed to the deformation field.
158 |     "warp_id": 0,
159 |     // The embedding ID used to fetch the appearance latent code
160 |     // which is passed to the second branch of the template NeRF.
161 |     "appearance_id": 0,
162 |     // For validation rig datasets, we use the camera ID instead
163 |     // of the appearance ID. For example, this would be '0' for the
164 |     // left camera and '1' for the right camera. This can also be
165 |     // used for multi-view setups.
166 |     "camera_id": 0
167 |   },
168 |   ...
169 | }
170 | ```
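A minimal, hypothetical sketch of generating this file for a single-camera video where every frame gets its own warp and appearance codes:

```python
import json

# One entry per frame; the IDs here are illustrative.
item_ids = [f'{i:06d}' for i in range(114)]
metadata = {
    item_id: {'warp_id': i, 'appearance_id': i, 'camera_id': 0}
    for i, item_id in enumerate(item_ids)
}
with open('dataset/metadata.json', 'w') as f:
  json.dump(metadata, f, indent=2)
```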
171 | ### `scene.json`
172 | * Contains information about how we will parse the scene.
173 | * See comments inline.
174 |
175 | ```javascript
176 | {
177 |   // The scale factor we will apply to the pointcloud and cameras. This is
178 |   // important since it controls what scale is used when computing the positional
179 |   // encoding.
180 |   "scale": 0.0387243672920458,
181 |   // Defines the origin of the scene. The scene will be translated such that
182 |   // this point becomes the origin. Defined in unscaled coordinates.
183 |   "center": [
184 |     1.1770838526103944e-08,
185 |     -2.58235339289195,
186 |     -1.29117656263135
187 |   ],
188 |   // The distance of the near plane from the camera center in scaled coordinates.
189 |   "near": 0.02057418950149491,
190 |   // The distance of the far plane from the camera center in scaled coordinates.
191 |   "far": 0.8261601717667288
192 | }
193 | ```
194 |
195 | ### `dataset.json`
196 | * Defines the training/validation split of the dataset.
197 | * See inline comments:
198 |
199 | ```javascript
200 | {
201 |   // The total number of images in the dataset.
202 |   "count": 114,
203 |   // The total number of training images (exemplars) in the dataset.
204 |   "num_exemplars": 57,
205 |   // A list containing all item IDs in the dataset.
206 |   "ids": [...],
207 |   // A list containing all training item IDs in the dataset.
208 |   "train_ids": [...],
209 |   // A list containing all validation item IDs in the dataset.
210 |   // This should be mutually exclusive with `train_ids`.
211 |   "val_ids": [...]
212 | }
213 | ```
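A minimal sketch that checks a `dataset.json` against the invariants described above:

```python
import json

with open('dataset/dataset.json', 'r') as f:
  dataset = json.load(f)

assert dataset['count'] == len(dataset['ids'])
assert dataset['num_exemplars'] == len(dataset['train_ids'])
# The train/validation splits must not overlap.
assert not set(dataset['train_ids']) & set(dataset['val_ids'])
```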
214 |
215 | ### `points.npy`
216 |
217 | * A numpy file containing a single array of size `(N,3)` with the background points.
218 | * This is required if you want to use the background regularization loss; a loading sketch follows below.
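A minimal sketch of loading the points and moving them into the scaled scene coordinates defined by `scene.json`; this mirrors what `load_points` in `nerfies/datasets/nerfies.py` does:

```python
import json

import numpy as np

with open('dataset/scene.json', 'r') as f:
  scene = json.load(f)

points = np.load('dataset/points.npy')  # Unscaled world coordinates, (N, 3).
points = (points - np.asarray(scene['center'])) * scene['scale']
points = points.astype(np.float32)
```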
219 |
220 | ## Citing
221 | If you find our work useful, please consider citing:
222 | ```BibTeX
223 | @article{park2021nerfies,
224 | author = {Park, Keunhong
225 | and Sinha, Utkarsh
226 | and Barron, Jonathan T.
227 | and Bouaziz, Sofien
228 | and Goldman, Dan B
229 | and Seitz, Steven M.
230 | and Martin-Brualla, Ricardo},
231 | title = {Nerfies: Deformable Neural Radiance Fields},
232 | journal = {ICCV},
233 | year = {2021},
234 | }
235 | ```
236 |
--------------------------------------------------------------------------------
/configs/defaults.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is the base configuration that is imported by other configurations.
17 | # Do not run this configuration directly.
18 |
19 | num_warp_freqs = 8
20 | elastic_init_weight = 0.01
21 | lr_delay_steps = 2500
22 | lr_delay_mult = 0.01
23 |
24 | # Predefined warp alpha schedules.
25 | ANNEALED_WARP_ALPHA_SCHED = {
26 | 'type': 'linear',
27 | 'initial_value': 0.0,
28 | 'final_value': %num_warp_freqs,
29 | 'num_steps': 80000,
30 | }
31 | CONSTANT_WARP_ALPHA_SCHED = {
32 | 'type': 'constant',
33 | 'value': %num_warp_freqs,
34 | }
35 |
36 | # Predefined elastic loss schedules.
37 | CONSTANT_ELASTIC_LOSS_SCHED = {
38 | 'type': 'constant',
39 | 'value': %elastic_init_weight,
40 | }
41 | DECAYING_ELASTIC_LOSS_SCHED = {
42 | 'type': 'piecewise',
43 | 'schedules': [
44 | (50000, ('constant', %elastic_init_weight)),
45 | (100000, ('cosine_easing', %elastic_init_weight, 1e-8, 100000)),
46 | ]
47 | }
48 |
49 | DEFAULT_LR_SCHEDULE = {
50 | 'type': 'exponential',
51 | 'initial_value': %init_lr,
52 | 'final_value': %final_lr,
53 | 'num_steps': %max_steps,
54 | }
55 |
56 | DELAYED_LR_SCHEDULE = {
57 | 'type': 'delayed',
58 | 'delay_steps': %lr_delay_steps,
59 | 'delay_mult': %lr_delay_mult,
60 | 'base_schedule': %DEFAULT_LR_SCHEDULE,
61 | }
62 |
63 | # Common configs.
64 | ModelConfig.use_viewdirs = True
65 | ModelConfig.use_stratified_sampling = True
66 | ModelConfig.sigma_activation = @nn.softplus
67 | ModelConfig.use_appearance_metadata = False
68 |
69 | # Experiment configs.
70 | ExperimentConfig.image_scale = %image_scale
71 | ExperimentConfig.random_seed = 12345
72 |
73 | # Warp field configs.
74 | ModelConfig.use_warp = False
75 | ModelConfig.warp_field_type = 'se3'
76 | ModelConfig.num_warp_freqs = %num_warp_freqs
77 | ModelConfig.num_warp_features = 8
78 |
79 | # Use macros to make sure these are set somewhere.
80 | TrainConfig.batch_size = %batch_size
81 | TrainConfig.max_steps = %max_steps
82 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE
83 | TrainConfig.warp_alpha_schedule = %CONSTANT_WARP_ALPHA_SCHED
84 |
85 | # Elastic loss.
86 | TrainConfig.use_elastic_loss = False
87 | TrainConfig.elastic_loss_weight_schedule = %CONSTANT_ELASTIC_LOSS_SCHED
88 |
89 | # Background regularization loss.
90 | TrainConfig.use_background_loss = False
91 | TrainConfig.background_loss_weight = 1.0
92 |
93 | # Script interval configs.
94 | TrainConfig.print_every = 100
95 | TrainConfig.log_every = 500
96 | TrainConfig.save_every = 5000
97 |
98 | EvalConfig.eval_once = False
99 | EvalConfig.save_output = True
100 | EvalConfig.chunk = %eval_batch_size
101 |
--------------------------------------------------------------------------------
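To make the schedule dictionaries above concrete, here is a standalone sketch (not the repo's `schedules.py`) of the 'linear' semantics used by `ANNEALED_WARP_ALPHA_SCHED`: the warp positional-encoding alpha ramps from 0 to `num_warp_freqs` over the first 80k steps and then stays constant:

```python
def linear_schedule(step, initial_value, final_value, num_steps):
  """Linearly interpolates from initial_value to final_value, then clamps."""
  t = min(max(step / num_steps, 0.0), 1.0)
  return initial_value + t * (final_value - initial_value)


assert linear_schedule(0, 0.0, 8.0, 80000) == 0.0
assert linear_schedule(40000, 0.0, 8.0, 80000) == 4.0
assert linear_schedule(200000, 0.0, 8.0, 80000) == 8.0  # Clamped after 80k.
```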
/configs/gpu_fullhd.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is a Full HD (1080p-ish) configuration.
17 | # The image scale is based on our input video size of 1920x1080.
18 | # This configuration requires 8 GPUs to train.
19 | #
20 | # To make this runnable on fewer GPUs, decrease the `batch_size` and scale the
21 | # learning rate by the sqrt of the factor by which you decreased it. Expect
22 | # the results to look slightly worse without tuning.
23 |
24 | include 'warp_defaults.gin'
25 |
26 | max_steps = 1000000
27 | lr_decay_steps = 2000000
28 |
29 | image_scale = 1
30 | batch_size = 4096
31 | eval_batch_size = 4096
32 | init_lr = 0.00075
33 | final_lr = 0.000075
34 |
35 | ModelConfig.num_nerf_point_freqs = 10
36 | ModelConfig.nerf_trunk_width = 256
37 | ModelConfig.nerf_trunk_depth = 8
38 | ModelConfig.num_coarse_samples = 256
39 | ModelConfig.num_fine_samples = 256
40 |
41 | TrainConfig.print_every = 200
42 | TrainConfig.log_every = 500
43 | TrainConfig.save_every = 10000
44 |
--------------------------------------------------------------------------------
/configs/gpu_quarterhd.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is a quarter HD configuration.
17 | # The image scale is based on our input video size of 1920x1080.
18 | # This configuration requires 8 GPUs to train.
19 | #
20 | # To make this runnable on fewer GPUs, decrease the `batch_size` and scale the
21 | # learning rate by the sqrt of the factor by which you decreased it. Expect
22 | # the results to look slightly worse without tuning.
23 |
24 | include 'warp_defaults.gin'
25 |
26 | max_steps = 250000
27 | lr_decay_steps = 500000
28 |
29 | image_scale = 4
30 | batch_size = 6144
31 | eval_batch_size = 8096
32 | init_lr = 0.001
33 | final_lr = 0.0001
34 |
35 | ModelConfig.num_nerf_point_freqs = 8
36 | ModelConfig.nerf_trunk_width = 256
37 | ModelConfig.nerf_trunk_depth = 8
38 | ModelConfig.num_coarse_samples = 128
39 | ModelConfig.num_fine_samples = 128
40 |
41 | TrainConfig.print_every = 200
42 | TrainConfig.log_every = 500
43 | TrainConfig.save_every = 5000
44 |
--------------------------------------------------------------------------------
/configs/gpu_quarterhd_4gpu.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is a quarter HD configuration for 4 GPUs.
17 | # The image scale is based on our input video size of 1920x1080.
18 | # This configuration requires 4 GPUs to train.
19 | #
20 | # Note that this configuration has not been tested and may require further
21 | # tuning.
22 |
23 | include 'warp_defaults.gin'
24 |
25 | max_steps = 500000
26 | lr_decay_steps = 1000000
27 |
28 | image_scale = 4
29 | batch_size = 3072
30 | eval_batch_size = 4096
31 | init_lr = 0.0007
32 | final_lr = 0.00007
33 |
34 | ModelConfig.num_nerf_point_freqs = 8
35 | ModelConfig.nerf_trunk_width = 256
36 | ModelConfig.nerf_trunk_depth = 8
37 | ModelConfig.num_coarse_samples = 128
38 | ModelConfig.num_fine_samples = 128
39 |
40 | TrainConfig.print_every = 200
41 | TrainConfig.log_every = 500
42 | TrainConfig.save_every = 5000
43 |
--------------------------------------------------------------------------------
/configs/gpu_vrig_paper.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is the validation rig configuration we used in the quantitative
17 | # evaluation of the paper. The `image_scale` is based on our raw dataset
18 | # resolution of 4032x3024 from the validation rig.
19 | # This configuration requires 8 GPUs to train.
20 |
21 | include 'configs/warp_defaults.gin'
22 |
23 | max_steps = 250000
24 | lr_decay_steps = %max_steps
25 |
26 | image_scale = 4
27 | batch_size = 6144
28 | eval_batch_size = 8096
29 | init_lr = 0.001
30 | final_lr = 0.0001
31 | elastic_init_weight = 0.001
32 | num_warp_freqs = 6
33 |
34 | ModelConfig.use_warp = True
35 | ModelConfig.num_nerf_point_freqs = 8
36 | ModelConfig.nerf_trunk_width = 256
37 | ModelConfig.nerf_trunk_depth = 8
38 | ModelConfig.num_coarse_samples = 128
39 | ModelConfig.num_fine_samples = 128
40 | ModelConfig.use_appearance_metadata = False
41 | ModelConfig.use_camera_metadata = True
42 | ModelConfig.use_stratified_sampling = True
43 | ModelConfig.camera_metadata_dims = 2
44 | ModelConfig.use_sample_at_infinity = True
45 | ModelConfig.warp_field_type = 'se3'
46 |
47 | TrainConfig.print_every = 500
48 | TrainConfig.log_every = 500
49 | TrainConfig.histogram_every = 1000
50 | TrainConfig.save_every = 5000
51 |
52 | TrainConfig.use_elastic_loss = True
53 | TrainConfig.use_background_loss = True
54 | TrainConfig.background_loss_weight = 1.0
55 | TrainConfig.warp_alpha_schedule = %ANNEALED_WARP_ALPHA_SCHED
56 |
57 | TrainConfig.use_warp_reg_loss = False
58 | TrainConfig.warp_reg_loss_weight = 1e-2
59 |
60 | TrainConfig.elastic_reduce_method = 'weight'
61 | TrainConfig.elastic_loss_weight_schedule = {
62 | 'type': 'constant',
63 | 'value': %elastic_init_weight,
64 | }
65 | TrainConfig.lr_schedule = %DEFAULT_LR_SCHEDULE
66 |
67 | EvalConfig.num_val_eval = None
68 | EvalConfig.num_train_eval = None
69 |
--------------------------------------------------------------------------------
/configs/test_local.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is a test configuration for sanity checking.
17 | # It will likely not result in a good quality reconstruction.
18 | # This config will run on a single GPU.
19 |
20 | elastic_init_weight = 0.01
21 | max_steps = 250000
22 |
23 | ExperimentConfig.image_scale = 4
24 |
25 | ModelConfig.num_coarse_samples = 64
26 | ModelConfig.num_fine_samples = 64
27 | ModelConfig.use_viewdirs = True
28 | ModelConfig.use_stratified_sampling = True
29 | ModelConfig.use_appearance_metadata = True
30 | ModelConfig.use_warp = True
31 | ModelConfig.warp_field_type = 'se3'
32 | ModelConfig.num_warp_features = 3
33 | ModelConfig.num_warp_freqs = 8
34 | ModelConfig.sigma_activation = @nn.softplus
35 |
36 | TrainConfig.max_steps = 200000
37 | TrainConfig.lr_schedule = {
38 | 'type': 'exponential',
39 | 'initial_value': 0.001,
40 | 'final_value': 0.0001,
41 | 'num_steps': %max_steps,
42 | }
43 | TrainConfig.batch_size = 1024
44 | TrainConfig.warp_alpha_schedule = {
45 | 'type': 'linear',
46 | 'initial_value': 0.0,
47 | 'final_value': 8.0,
48 | 'num_steps': 80000,
49 | }
50 | TrainConfig.use_elastic_loss = True
51 | TrainConfig.elastic_loss_weight_schedule = {
52 | 'type': 'piecewise',
53 | 'schedules': [
54 | (50000, ('constant', %elastic_init_weight)),
55 | (100000, ('cosine_easing', %elastic_init_weight, 1e-8, 100000)),
56 | ]
57 | }
58 | TrainConfig.use_background_loss = False
59 | TrainConfig.background_loss_weight = 1.0
60 |
61 | TrainConfig.print_every = 10
62 | TrainConfig.log_every = 100
63 | TrainConfig.save_every = 1000
64 |
65 | EvalConfig.eval_once = False
66 | EvalConfig.save_output = True
67 | EvalConfig.chunk = 8192
68 |
--------------------------------------------------------------------------------
/configs/test_vrig.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This is a test configuration for sanity checking.
17 | # It will likely not result in a good quality reconstruction.
18 | # This config will run on a single GPU.
19 |
20 | include 'defaults.gin'
21 |
22 | max_steps = 250000
23 |
24 | image_scale = 8
25 | batch_size = 1024
26 | eval_batch_size = 1024
27 | init_lr = 0.001
28 | final_lr = 0.0001
29 | lr_decay_steps = 500000
30 | elastic_init_weight = 0.001
31 |
32 | ModelConfig.num_nerf_point_freqs = 8
33 | ModelConfig.nerf_trunk_width = 128
34 | ModelConfig.nerf_trunk_depth = 8
35 | ModelConfig.num_coarse_samples = 64
36 | ModelConfig.num_fine_samples = 64
37 | ModelConfig.use_appearance_metadata = False
38 | ModelConfig.use_camera_metadata = True
39 | ModelConfig.use_stratified_sampling = False
40 | ModelConfig.camera_metadata_dims = 2
41 | ModelConfig.use_warp = True
42 |
43 | TrainConfig.use_elastic_loss = True
44 | TrainConfig.use_background_loss = True
45 | TrainConfig.print_every = 1
46 | TrainConfig.log_every = 100
47 | TrainConfig.save_every = 1000
48 |
49 | EvalConfig.chunk = 8192
50 | EvalConfig.num_val_eval = None
51 | EvalConfig.num_train_eval = None
52 |
--------------------------------------------------------------------------------
/configs/warp_defaults.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # This file is the base configuration for Nerfies.
17 | # Do not run this directly, it is for importing from other configurations.
18 |
19 | include 'defaults.gin'
20 |
21 | ModelConfig.use_warp = True
22 | ModelConfig.use_appearance_metadata = True
23 |
24 | TrainConfig.warp_alpha_schedule = %ANNEALED_WARP_ALPHA_SCHED
25 | TrainConfig.elastic_loss_weight_schedule = %DECAYING_ELASTIC_LOSS_SCHED
26 | TrainConfig.use_elastic_loss = True
27 | TrainConfig.use_background_loss = True
28 |
--------------------------------------------------------------------------------
/nerfies/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/nerfies/configs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Configuration classes."""
16 | from typing import Any, Mapping, Optional, Tuple
17 |
18 | import dataclasses
19 | import flax.linen as nn
20 | import gin
21 | import immutabledict
22 |
23 | from nerfies import types
24 |
25 | ScheduleDef = Any
26 |
27 | gin.config.external_configurable(nn.elu, module='flax.nn')
28 | gin.config.external_configurable(nn.relu, module='flax.nn')
29 | gin.config.external_configurable(nn.leaky_relu, module='flax.nn')
30 | gin.config.external_configurable(nn.tanh, module='flax.nn')
31 | gin.config.external_configurable(nn.sigmoid, module='flax.nn')
32 | gin.config.external_configurable(nn.softplus, module='flax.nn')
33 |
34 |
35 | @gin.configurable()
36 | @dataclasses.dataclass
37 | class ModelConfig:
38 | """Parameters for the model."""
39 | # Sample linearly in disparity rather than depth.
40 | use_linear_disparity: bool = False
41 | # Use white as the default background.
42 | use_white_background: bool = False
43 | # Use stratified sampling.
44 | use_stratified_sampling: bool = True
45 | # Use the sample at infinity.
46 | use_sample_at_infinity: bool = True
47 | # The standard deviation of the alpha noise.
48 | noise_std: Optional[float] = None
49 |
50 | # The depth of the NeRF.
51 | nerf_trunk_depth: int = 8
52 | # The width of the NeRF.
53 | nerf_trunk_width: int = 256
54 | # The depth of the conditional part of the MLP.
55 | nerf_rgb_branch_depth: int = 1
56 | # The width of the conditional part of the MLP.
57 | nerf_rgb_branch_width: int = 128
58 | # The intermediate activation for the NeRF.
59 | activation: types.Activation = nn.relu
60 | # The sigma activation for the NeRF.
61 | sigma_activation: types.Activation = nn.relu
62 | # Adds a skip connection every N layers.
63 | nerf_skips: Tuple[int] = (4,)
64 | # The number of alpha channels.
65 | alpha_channels: int = 1
66 | # The number of RGB channels.
67 | rgb_channels: int = 3
68 | # The number of positional encodings for points.
69 | num_nerf_point_freqs: int = 10
70 | # The number of positional encodings for viewdirs.
71 | num_nerf_viewdir_freqs: int = 4
72 | # The number of coarse samples along each ray.
73 | num_coarse_samples: int = 64
74 | # The number of fine samples along each ray.
75 | num_fine_samples: int = 128
76 | # Whether to use view directions.
77 | use_viewdirs: bool = True
78 | # Whether to condition the entire NeRF MLP.
79 | use_trunk_condition: bool = False
80 | # Whether to condition the density of the template NeRF.
81 | use_alpha_condition: bool = False
82 | # Whether to condition the RGB of the template NeRF.
83 | use_rgb_condition: bool = False
84 |
85 | # Whether to use the appearance metadata for the conditional branch.
86 | use_appearance_metadata: bool = False
87 | # The number of dimensions for the appearance metadata.
88 | appearance_metadata_dims: int = 8
89 | # Whether to use the camera metadata for the conditional branch.
90 | use_camera_metadata: bool = False
91 | # The number of dimensions for the camera metadata.
92 | camera_metadata_dims: int = 2
93 |
94 | # Whether to use the warp field.
95 | use_warp: bool = False
96 | # The number of frequencies for the warp field.
97 | num_warp_freqs: int = 8
98 | # The number of dimensions for the warp metadata.
99 | num_warp_features: int = 8
100 | # The type of warp field to use. One of: 'translation', or 'se3'.
101 | warp_field_type: str = 'translation'
102 | # The type of metadata encoder the warp field should use.
103 | warp_metadata_encoder_type: str = 'glo'
104 | # Additional keyword arguments to pass to the warp field.
105 | warp_kwargs: Mapping[str, Any] = immutabledict.immutabledict()
106 |
107 |
108 | @gin.configurable()
109 | @dataclasses.dataclass
110 | class ExperimentConfig:
111 | """Experiment configuration."""
112 | # A subname for the experiment e.g., for parameter sweeps. If this is set
113 | # experiment artifacts will be saves to a subdirectory with this name.
114 | subname: Optional[str] = None
115 | # The image scale to use for the dataset. Should be a power of 2.
116 | image_scale: int = 4
117 | # The random seed used to initialize the RNGs for the experiment.
118 | random_seed: int = 12345
119 | # The type of datasource. Either 'nerfies' or 'dynamic_scene'.
120 | datasource_type: str = 'nerfies'
121 | # Data source specification.
122 | datasource_spec: Optional[Mapping[str, Any]] = None
123 | # Extra keyword arguments to pass to the datasource.
124 | datasource_kwargs: Mapping[str, Any] = immutabledict.immutabledict()
125 |
126 |
127 | @gin.configurable()
128 | @dataclasses.dataclass
129 | class TrainConfig:
130 | """Parameters for training."""
131 | batch_size: int = gin.REQUIRED
132 |
133 | # The definition for the learning rate schedule.
134 | lr_schedule: ScheduleDef = immutabledict.immutabledict({
135 | 'type': 'exponential',
136 | 'initial_value': 0.001,
137 | 'final_value': 0.0001,
138 | 'num_steps': 1000000,
139 | })
140 | # The maximum number of training steps.
141 | max_steps: int = 1000000
142 |
143 | # The start value of the warp alpha.
144 | warp_alpha_schedule: ScheduleDef = immutabledict.immutabledict({
145 | 'type': 'linear',
146 | 'initial_value': 0.0,
147 | 'final_value': 8.0,
148 | 'num_steps': 80000,
149 | })
150 |
151 | # The time encoder alpha schedule.
152 | time_alpha_schedule: ScheduleDef = ('constant', 0.0)
153 |
154 | # Whether to use the elastic regularization loss.
155 | use_elastic_loss: bool = False
156 | # The weight of the elastic regularization loss.
157 | elastic_loss_weight_schedule: ScheduleDef = ('constant', 0.0)
158 | # Which method to use to reduce the samples for the elastic loss.
159 | # 'weight' computes a weighted sum using the density weights, and 'median'
160 | # selects the sample at the median depth point.
161 | elastic_reduce_method: str = 'weight'
162 | # Which loss method to use for the elastic loss.
163 | elastic_loss_type: str = 'log_svals'
164 | # Whether to use background regularization.
165 | use_background_loss: bool = False
166 | # The weight for the background loss.
167 | background_loss_weight: float = 0.0
168 | # The batch size for background regularization loss.
169 | background_points_batch_size: int = 16384
170 | # Whether to use the warp reg loss.
171 | use_warp_reg_loss: bool = False
172 | # The weight for the warp reg loss.
173 | warp_reg_loss_weight: float = 0.0
174 | # The alpha for the warp reg loss.
175 | warp_reg_loss_alpha: float = -2.0
176 | # The scale for the warp reg loss.
177 | warp_reg_loss_scale: float = 0.001
178 |
179 | # The size of the shuffle buffer size when shuffling the training dataset.
180 | # This needs to be sufficiently large to contain a diverse set of images in
181 | # each batch, especially when optimizing GLO embeddings.
182 | shuffle_buffer_size: int = 5000000
183 | # How often to save a checkpoint.
184 | save_every: int = 10000
185 | # How often to log to Tensorboard.
186 | log_every: int = 500
187 | # How often to log histograms to Tensorboard.
188 | histogram_every: int = 5000
189 | # How often to print to the console.
190 | print_every: int = 25
191 |
192 |
193 | @gin.configurable()
194 | @dataclasses.dataclass
195 | class EvalConfig:
196 | """Parameters for evaluation."""
197 | # If True only evaluate the model once, otherwise evaluate any new
198 | # checkpoints.
199 | eval_once: bool = False
200 | # If True save the predicted images to persistent storage.
201 | save_output: bool = True
202 | # The evaluation batch size.
203 | chunk: int = 8192
204 | # Max render checkpoints. The renders will rotate after this many.
205 | max_render_checkpoints = 3
206 |
207 | # The number of validation examples to evaluate. (Default: all).
208 | num_val_eval: Optional[int] = 10
209 | # The number of training examples to evaluate.
210 | num_train_eval: Optional[int] = 10
211 | # The number of test examples to evaluate.
212 | num_test_eval: Optional[int] = 10
213 |
--------------------------------------------------------------------------------
/nerfies/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Dataset definition and utility package."""
16 | from nerfies.datasets.core import *
17 | from nerfies.datasets.nerfies import NerfiesDataSource
18 |
19 |
20 | def from_config(spec, **kwargs):
21 |   """Create a datasource from a config specification."""
22 |   spec = dict(spec)
23 |   ds_type = spec.pop('type')
24 |   if ds_type == 'nerfies':
25 |     return NerfiesDataSource(**spec, **kwargs)
26 |
27 |   raise ValueError(f'Unknown datasource type {ds_type!r}')
28 |
--------------------------------------------------------------------------------
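A minimal usage sketch for `from_config` above (the dataset path is hypothetical); any extra keys in the spec are forwarded to the datasource constructor:

```python
from nerfies import datasets

datasource = datasets.from_config({
    'type': 'nerfies',
    'data_dir': '/path/to/dataset',
    'image_scale': 4,
})
print(datasource.near, datasource.far)
```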
/nerfies/datasets/nerfies.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Casual Volumetric Capture datasets."""
16 | import json
17 | from typing import List, Tuple
18 |
19 | from absl import logging
20 | import cv2
21 | import numpy as np
22 |
23 | from nerfies import gpath
24 | from nerfies import types
25 | from nerfies import utils
26 | from nerfies.datasets import core
27 |
28 |
29 | def load_scene_info(
30 |     data_dir: types.PathType) -> Tuple[np.ndarray, float, float, float]:
31 |   """Loads the scene center, scale, near and far from scene.json.
32 |
33 |   Args:
34 |     data_dir: the path to the dataset.
35 |
36 |   Returns:
37 |     scene_center: the center of the scene (unscaled coordinates).
38 |     scene_scale: the scale of the scene.
39 |     near: the near plane of the scene (scaled coordinates).
40 |     far: the far plane of the scene (scaled coordinates).
41 |
42 |   Raises:
43 |     ValueError if scene.json does not exist.
44 |   """
45 |   scene_json_path = gpath.GPath(data_dir, 'scene.json')
46 |   with scene_json_path.open('r') as f:
47 |     scene_json = json.load(f)
48 |
49 |   scene_center = np.array(scene_json['center'])
50 |   scene_scale = scene_json['scale']
51 |   near = scene_json['near']
52 |   far = scene_json['far']
53 |
54 |   return scene_center, scene_scale, near, far
55 |
56 |
57 | def _load_image(path: types.PathType) -> np.ndarray:
58 |   path = gpath.GPath(path)
59 |   with path.open('rb') as f:
60 |     raw_im = np.asarray(bytearray(f.read()), dtype=np.uint8)
61 |     image = cv2.imdecode(raw_im, cv2.IMREAD_COLOR)[:, :, ::-1]  # BGR -> RGB
62 |     image = np.asarray(image).astype(np.float32) / 255.0
63 |     return image
64 |
65 |
66 | def _load_dataset_ids(data_dir: types.PathType) -> Tuple[List[str], List[str]]:
67 |   """Loads dataset IDs."""
68 |   dataset_json_path = gpath.GPath(data_dir, 'dataset.json')
69 |   logging.info('*** Loading dataset IDs from %s', dataset_json_path)
70 |   with dataset_json_path.open('r') as f:
71 |     dataset_json = json.load(f)
72 |     train_ids = dataset_json['train_ids']
73 |     val_ids = dataset_json['val_ids']
74 |
75 |   train_ids = [str(i) for i in train_ids]
76 |   val_ids = [str(i) for i in val_ids]
77 |
78 |   return train_ids, val_ids
79 |
80 |
81 | class NerfiesDataSource(core.DataSource):
82 |   """Data loader for videos."""
83 |
84 |   def __init__(
85 |       self,
86 |       data_dir,
87 |       image_scale: int,
88 |       shuffle_pixels=False,
89 |       camera_type='json',
90 |       test_camera_trajectory='orbit-extreme',
91 |       **kwargs):
92 |     self.data_dir = gpath.GPath(data_dir)
93 |     # Load IDs from JSON if it exists. This is useful since COLMAP fails on
94 |     # some images so this gives us the ability to skip invalid images.
95 |     train_ids, val_ids = _load_dataset_ids(self.data_dir)
96 |     super().__init__(train_ids=train_ids, val_ids=val_ids,
97 |                      **kwargs)
98 |     self.scene_center, self.scene_scale, self._near, self._far = \
99 |         load_scene_info(self.data_dir)
100 |     self.test_camera_trajectory = test_camera_trajectory
101 |
102 |     self.image_scale = image_scale
103 |     self.shuffle_pixels = shuffle_pixels
104 |
105 |     self.rgb_dir = gpath.GPath(data_dir, 'rgb', f'{image_scale}x')
106 |     self.depth_dir = gpath.GPath(data_dir, 'depth', f'{image_scale}x')
107 |     self.camera_type = camera_type
108 |     self.camera_dir = gpath.GPath(data_dir, 'camera')
109 |
110 |     metadata_path = self.data_dir / 'metadata.json'
111 |     self.metadata_dict = None
112 |     if metadata_path.exists():
113 |       with metadata_path.open('r') as f:
114 |         self.metadata_dict = json.load(f)
115 |
116 |   @property
117 |   def near(self):
118 |     return self._near
119 |
120 |   @property
121 |   def far(self):
122 |     return self._far
123 |
124 |   @property
125 |   def camera_ext(self):
126 |     if self.camera_type == 'json':
127 |       return '.json'
128 |
129 |     raise ValueError(f'Unknown camera_type {self.camera_type}')
130 |
131 |   def get_rgb_path(self, item_id):
132 |     return self.rgb_dir / f'{item_id}.png'
133 |
134 |   def load_rgb(self, item_id):
135 |     return _load_image(self.rgb_dir / f'{item_id}.png')
136 |
137 |   def load_camera(self, item_id, scale_factor=1.0):
138 |     if isinstance(item_id, gpath.GPath):
139 |       camera_path = item_id
140 |     else:
141 |       if self.camera_type == 'json':
142 |         camera_path = self.camera_dir / f'{item_id}{self.camera_ext}'
143 |       else:
144 |         raise ValueError(f'Unknown camera type {self.camera_type!r}.')
145 |
146 |     return core.load_camera(camera_path,
147 |                             scale_factor=scale_factor / self.image_scale,
148 |                             scene_center=self.scene_center,
149 |                             scene_scale=self.scene_scale)
150 |
151 |   def glob_cameras(self, path):
152 |     path = gpath.GPath(path)
153 |     return sorted(path.glob(f'*{self.camera_ext}'))
154 |
155 |   def load_test_cameras(self, count=None):
156 |     camera_dir = (self.data_dir / 'camera-paths' / self.test_camera_trajectory)
157 |     if not camera_dir.exists():
158 |       logging.warning('test camera path does not exist: %s', str(camera_dir))
159 |       return []
160 |     camera_paths = sorted(camera_dir.glob(f'*{self.camera_ext}'))
161 |     if count is not None:
162 |       stride = max(1, len(camera_paths) // count)
163 |       camera_paths = camera_paths[::stride]
164 |     cameras = utils.parallel_map(self.load_camera, camera_paths)
165 |     return cameras
166 |
167 |   def load_points(self, shuffle=False):
168 |     with (self.data_dir / 'points.npy').open('rb') as f:
169 |       points = np.load(f)
170 |     points = (points - self.scene_center) * self.scene_scale
171 |     points = points.astype(np.float32)
172 |     if shuffle:
173 |       logging.info('Shuffling points.')
174 |       shuffled_inds = self.rng.permutation(len(points))
175 |       points = points[shuffled_inds]
176 |     logging.info('Loaded %d points.', len(points))
177 |     return points
178 |
179 |   def get_appearance_id(self, item_id):
180 |     return self.metadata_dict[item_id]['appearance_id']
181 |
182 |   def get_camera_id(self, item_id):
183 |     return self.metadata_dict[item_id]['camera_id']
184 |
185 |   def get_warp_id(self, item_id):
186 |     return self.metadata_dict[item_id]['warp_id']
187 |
188 |   def get_time_id(self, item_id):
189 |     if 'time_id' in self.metadata_dict[item_id]:
190 |       return self.metadata_dict[item_id]['time_id']
191 |     else:
192 |       # Fallback for older datasets.
193 |       return self.metadata_dict[item_id]['warp_id']
194 |
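A minimal usage sketch for this data source (not part of the module). The capture path is a placeholder, and it assumes the `core.DataSource` base class stores `train_ids` and requires no other constructor arguments:

```python
from nerfies.datasets.nerfies import NerfiesDataSource, load_scene_info

data_dir = '/path/to/capture'  # placeholder for a processed capture directory
center, scale, near, far = load_scene_info(data_dir)

source = NerfiesDataSource(data_dir, image_scale=4)
item_id = source.train_ids[0]          # assumes the base class exposes the IDs
rgb = source.load_rgb(item_id)         # float32 (H, W, 3) array in [0, 1]
camera = source.load_camera(item_id)   # camera recentered/rescaled to the scene
```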
--------------------------------------------------------------------------------
/nerfies/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module for evaluating a trained NeRF."""
16 | import math
17 | import time
18 |
19 | from absl import logging
20 | from flax import jax_utils
21 | import jax
22 | from jax import tree_util
23 | import jax.numpy as jnp
24 |
25 | from nerfies import utils
26 |
27 |
28 | def render_image(
29 | state,
30 | rays_dict,
31 | model_fn,
32 | device_count,
33 | rng,
34 | chunk=8192,
35 | default_ret_key=None):
36 | """Render all the pixels of an image (in test mode).
37 |
38 | Args:
39 | state: model_utils.TrainState.
40 | rays_dict: dict, test example.
41 | model_fn: function, jit-ed render function.
42 | device_count: The number of devices to shard batches over.
43 | rng: The random number generator.
44 | chunk: int, the size of chunks to render sequentially.
45 |     default_ret_key: either 'fine' or 'coarse'. If None, defaults to 'fine'
46 |       if present, else 'coarse'.
46 |
47 | Returns:
48 |     out: a dictionary of rendered outputs reshaped to (H, W, ...), typically
49 |     including 'rgb' (color), 'depth' and 'med_depth' (depth), and 'acc'
50 |     (accumulated weights per pixel).
51 | """
52 | h, w = rays_dict['origins'].shape[:2]
53 | rays_dict = tree_util.tree_map(lambda x: x.reshape((h * w, -1)), rays_dict)
54 | num_rays = h * w
55 | _, key_0, key_1 = jax.random.split(rng, 3)
56 | key_0 = jax.random.split(key_0, device_count)
57 | key_1 = jax.random.split(key_1, device_count)
58 | host_id = jax.process_index()
59 | ret_maps = []
60 | start_time = time.time()
61 | num_batches = int(math.ceil(num_rays / chunk))
62 | for batch_idx in range(num_batches):
63 | ray_idx = batch_idx * chunk
64 | logging.log_every_n_seconds(
65 | logging.INFO, 'Rendering batch %d/%d (%d/%d)', 2.0,
66 | batch_idx, num_batches, ray_idx, num_rays)
67 | # pylint: disable=cell-var-from-loop
68 | chunk_slice_fn = lambda x: x[ray_idx:ray_idx + chunk]
69 | chunk_rays_dict = tree_util.tree_map(chunk_slice_fn, rays_dict)
70 | num_chunk_rays = chunk_rays_dict['origins'].shape[0]
71 | remainder = num_chunk_rays % device_count
72 | if remainder != 0:
73 | padding = device_count - remainder
74 | # pylint: disable=cell-var-from-loop
75 | chunk_pad_fn = lambda x: jnp.pad(x, ((0, padding), (0, 0)), mode='edge')
76 | chunk_rays_dict = tree_util.tree_map(chunk_pad_fn, chunk_rays_dict)
77 | else:
78 | padding = 0
79 |     # After padding, the number of chunk rays is always divisible by the
80 |     # device count, so use the padded count when splitting across hosts.
81 |     per_host_rays = chunk_rays_dict['origins'].shape[0] // jax.process_count()
82 | chunk_rays_dict = tree_util.tree_map(
83 | lambda x: x[(host_id * per_host_rays):((host_id + 1) * per_host_rays)],
84 | chunk_rays_dict)
85 | chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
86 | model_out = model_fn(key_0, key_1, state.optimizer.target['model'],
87 | chunk_rays_dict, state.warp_extra)
88 | if not default_ret_key:
89 | ret_key = 'fine' if 'fine' in model_out else 'coarse'
90 | else:
91 | ret_key = default_ret_key
92 | ret_map = jax_utils.unreplicate(model_out[ret_key])
93 | ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map)
94 | ret_maps.append(ret_map)
95 | ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps)
96 |   logging.info('Rendering took %.4f seconds.', time.time() - start_time)
97 | out = {}
98 | for key, value in ret_map.items():
99 | out[key] = value.reshape((h, w, *value.shape[1:]))
100 |
101 | return out
102 |
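The pad/shard/unshard bookkeeping in `render_image` is easy to get wrong, so here is a single-host sketch of the same arithmetic, assuming `utils.shard` and `utils.unshard` have the signatures used above:

```python
import jax
import jax.numpy as jnp
from nerfies import utils

device_count = jax.local_device_count()
rays = jnp.ones((100, 3))  # 100 rays; 100 may not divide evenly across devices.
remainder = rays.shape[0] % device_count
padding = (device_count - remainder) % device_count
padded = jnp.pad(rays, ((0, padding), (0, 0)), mode='edge')
sharded = utils.shard(padded, device_count)  # (device_count, rays_per_device, 3)
restored = utils.unshard(sharded, padding)   # back to (100, 3)
assert restored.shape == rays.shape
```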
--------------------------------------------------------------------------------
/nerfies/glo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A module to help create embeddings in Jax."""
16 | from flax import linen as nn
17 | import jax.numpy as jnp
18 |
19 | from nerfies import types
20 |
21 |
22 | class GloEncoder(nn.Module):
23 | """A GLO encoder module, which is just a thin wrapper around nn.Embed.
24 |
25 | Attributes:
26 | num_embeddings: The number of embeddings.
27 | features: The dimensions of each embedding.
28 |     embedding_init: The initializer to use for each embedding.
29 | """
30 |
31 | num_embeddings: int
32 | features: int
33 |   embedding_init: types.Initializer = nn.initializers.uniform(scale=0.05)
34 |
35 | def setup(self):
36 | self.embed = nn.Embed(
37 | num_embeddings=self.num_embeddings,
38 | features=self.features,
39 | embedding_init=self.embedding_init)
40 |
41 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
42 | """Method to get embeddings for specified indices.
43 |
44 | Args:
45 | inputs: The indices to fetch embeddings for.
46 |
47 | Returns:
48 | The embeddings corresponding to the indices provided.
49 | """
50 | if inputs.shape[-1] == 1:
51 | inputs = jnp.squeeze(inputs, axis=-1)
52 |
53 | return self.embed(inputs)
54 |
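A minimal sketch of how this encoder is typically used (the shapes and indices are illustrative, not taken from the repository):

```python
import jax
import jax.numpy as jnp
from nerfies.glo import GloEncoder

encoder = GloEncoder(num_embeddings=10, features=8)
indices = jnp.array([0, 3, 7], dtype=jnp.uint32)  # e.g. appearance/warp IDs
variables = encoder.init(jax.random.PRNGKey(0), indices)
embeddings = encoder.apply(variables, indices)
assert embeddings.shape == (3, 8)  # one latent code per index
```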
--------------------------------------------------------------------------------
/nerfies/gpath.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A thin wrapper around pathlib."""
16 | import pathlib
17 | import tensorflow as tf
18 |
19 |
20 | class GPath(pathlib.PurePosixPath):
21 | """A thin wrapper around PurePath to support various filesystems."""
22 |
23 | def open(self, *args, **kwargs):
24 | return tf.io.gfile.GFile(self, *args, **kwargs)
25 |
26 | def exists(self):
27 | return tf.io.gfile.exists(self)
28 |
29 | # pylint: disable=unused-argument
30 | def mkdir(self, mode=0o777, parents=False, exist_ok=False):
31 | if not exist_ok:
32 | if self.exists():
33 | raise FileExistsError('Directory already exists.')
34 |
35 | if parents:
36 | return tf.io.gfile.makedirs(self)
37 | else:
38 | return tf.io.gfile.mkdir(self)
39 |
40 | def glob(self, pattern):
41 | return [GPath(x) for x in tf.io.gfile.glob(str(self / pattern))]
42 |
43 | def iterdir(self):
44 | return [GPath(self, x) for x in tf.io.gfile.listdir(self)]
45 |
46 | def is_dir(self):
47 | return tf.io.gfile.isdir(self)
48 |
49 | def rmtree(self):
50 | tf.io.gfile.rmtree(self)
51 |
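A short usage sketch (a local path is shown; any filesystem `tf.io.gfile` understands should behave the same way):

```python
from nerfies.gpath import GPath

path = GPath('/tmp/nerfies_demo')  # placeholder directory
path.mkdir(parents=True, exist_ok=True)
with (path / 'hello.txt').open('w') as f:
  f.write('hello')
print([str(p) for p in path.glob('*.txt')])  # ['/tmp/nerfies_demo/hello.txt']
```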
--------------------------------------------------------------------------------
/nerfies/image_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Image-related utility functions."""
16 | import math
17 | from typing import Tuple
18 |
19 | from absl import logging
20 | import cv2
21 | import imageio
22 | import numpy as np
23 | from PIL import Image
24 |
25 | from nerfies import gpath
26 | from nerfies import types
27 |
28 |
29 | UINT8_MAX = 255
30 | UINT16_MAX = 65535
31 |
32 |
33 | def make_divisible(image: np.ndarray, divisor: int) -> np.ndarray:
34 | """Trim the image if not divisible by the divisor."""
35 | height, width = image.shape[:2]
36 | if height % divisor == 0 and width % divisor == 0:
37 | return image
38 |
39 | new_height = height - height % divisor
40 | new_width = width - width % divisor
41 |
42 | return image[:new_height, :new_width]
43 |
44 |
45 | def downsample_image(image: np.ndarray, scale: int) -> np.ndarray:
46 | """Downsamples the image by an integer factor to prevent artifacts."""
47 | if scale == 1:
48 | return image
49 |
50 | height, width = image.shape[:2]
51 | if height % scale > 0 or width % scale > 0:
52 | raise ValueError(f'Image shape ({height},{width}) must be divisible by the'
53 | f' scale ({scale}).')
54 | out_height, out_width = height // scale, width // scale
55 |   resized = cv2.resize(image, (out_width, out_height),
56 |                        interpolation=cv2.INTER_AREA)
56 | return resized
57 |
58 |
59 | def upsample_image(image: np.ndarray, scale: int) -> np.ndarray:
60 | """Upsamples the image by an integer factor."""
61 | if scale == 1:
62 | return image
63 |
64 | height, width = image.shape[:2]
65 | out_height, out_width = height * scale, width * scale
66 |   resized = cv2.resize(image, (out_width, out_height),
67 |                        interpolation=cv2.INTER_AREA)
67 | return resized
68 |
69 |
70 | def reshape_image(image: np.ndarray, shape: Tuple[int, int]) -> np.ndarray:
71 | """Reshapes the image to the given shape."""
72 | out_height, out_width = shape
73 | return cv2.resize(
74 | image, (out_width, out_height), interpolation=cv2.INTER_AREA)
75 |
76 |
77 | def rescale_image(image: np.ndarray, scale_factor: float) -> np.ndarray:
78 | """Resize an image by a scale factor, using integer resizing if possible."""
79 | scale_factor = float(scale_factor)
80 | if scale_factor <= 0.0:
81 |     raise ValueError('scale_factor must be a positive number.')
82 |
83 | if scale_factor == 1.0:
84 | return image
85 |
86 | height, width = image.shape[:2]
87 | if scale_factor.is_integer():
88 | return upsample_image(image, int(scale_factor))
89 |
90 | inv_scale = 1.0 / scale_factor
91 | if (inv_scale.is_integer() and (scale_factor * height).is_integer() and
92 | (scale_factor * width).is_integer()):
93 | return downsample_image(image, int(inv_scale))
94 |
95 | logging.warning(
96 | 'resizing image by non-integer factor %f, this may lead to artifacts.',
97 | scale_factor)
98 |
99 |
100 | out_height = math.ceil(height * scale_factor)
101 | out_height -= out_height % 2
102 | out_width = math.ceil(width * scale_factor)
103 | out_width -= out_width % 2
104 |
105 | return reshape_image(image, (out_height, out_width))
106 |
107 |
108 | def variance_of_laplacian(image: np.ndarray) -> np.ndarray:
109 | """Compute the variance of the Laplacian which measure the focus."""
110 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
111 |   return cv2.Laplacian(gray, cv2.CV_64F).var()
112 |
113 |
114 | def image_to_uint8(image: np.ndarray) -> np.ndarray:
115 | """Convert the image to a uint8 array."""
116 | if image.dtype == np.uint8:
117 | return image
118 | if not issubclass(image.dtype.type, np.floating):
119 | raise ValueError(
120 | f'Input image should be a floating type but is of type {image.dtype!r}')
121 | return (image * UINT8_MAX).clip(0.0, UINT8_MAX).astype(np.uint8)
122 |
123 |
124 | def image_to_uint16(image: np.ndarray) -> np.ndarray:
125 | """Convert the image to a uint16 array."""
126 | if image.dtype == np.uint16:
127 | return image
128 | if not issubclass(image.dtype.type, np.floating):
129 | raise ValueError(
130 | f'Input image should be a floating type but is of type {image.dtype!r}')
131 | return (image * UINT16_MAX).clip(0.0, UINT16_MAX).astype(np.uint16)
132 |
133 |
134 | def image_to_float32(image: np.ndarray) -> np.ndarray:
135 | """Convert the image to a float32 array and scale values appropriately."""
136 | if image.dtype == np.float32:
137 | return image
138 |
139 | dtype = image.dtype
140 | image = image.astype(np.float32)
141 | if dtype == np.uint8:
142 | return image / UINT8_MAX
143 | elif dtype == np.uint16:
144 | return image / UINT16_MAX
145 | elif dtype == np.float64:
146 | return image
147 | elif dtype == np.float16:
148 | return image
149 |
150 | raise ValueError(f'Not sure how to handle dtype {dtype}')
151 |
152 |
153 | def load_image(path: types.PathType) -> np.ndarray:
154 | """Reads an image."""
155 | if not isinstance(path, gpath.GPath):
156 | path = gpath.GPath(path)
157 |
158 | with path.open('rb') as f:
159 | return imageio.imread(f)
160 |
161 |
162 | def save_image(path: types.PathType, image: np.ndarray) -> None:
163 | """Saves the image to disk or gfile."""
164 | if not isinstance(path, gpath.GPath):
165 | path = gpath.GPath(path)
166 |
167 | with path.open('wb') as f:
168 | image = Image.fromarray(np.asarray(image))
169 | image.save(f, format=path.suffix.lstrip('.'))
170 |
171 |
172 | def save_depth(path: types.PathType, depth: np.ndarray) -> None:
173 | save_image(path, image_to_uint16(depth / 1000.0))
174 |
175 |
176 | def load_depth(path: types.PathType) -> np.ndarray:
177 | depth = load_image(path)
178 | if depth.dtype != np.uint16:
179 | raise ValueError('Depth image must be of type uint16.')
180 | return image_to_float32(depth) * 1000.0
181 |
182 |
183 | def checkerboard(h, w, size=8):
184 | """Creates a checkerboard pattern with height h and width w."""
185 | i = int(math.ceil(h / (size * 2)))
186 | j = int(math.ceil(w / (size * 2)))
187 | pattern = np.kron([[1, 0] * j, [0, 1] * j] * i,
188 | np.ones((size, size)))[:h, :w]
189 | return np.clip(pattern + 0.8, 0.0, 1.0)
190 |
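A small sketch of the integer fast paths in `rescale_image` (the shapes are illustrative):

```python
import numpy as np
from nerfies import image_utils

image = np.random.rand(480, 640, 3).astype(np.float32)
half = image_utils.rescale_image(image, 0.5)    # hits the integer downsample path
assert half.shape == (240, 320, 3)
double = image_utils.rescale_image(image, 2.0)  # hits the integer upsample path
assert double.shape == (960, 1280, 3)
board = image_utils.checkerboard(64, 64, size=8)  # (64, 64) pattern in [0.8, 1.0]
```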
--------------------------------------------------------------------------------
/nerfies/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper functions/classes for model definition."""
16 |
17 | from flax import linen as nn
18 | from flax import optim
19 | from flax import struct
20 | from jax import lax
21 | from jax import random
22 | import jax.numpy as jnp
23 |
24 |
25 | @struct.dataclass
26 | class TrainState:
27 | optimizer: optim.Optimizer
28 | warp_alpha: jnp.ndarray = 0.0
29 | time_alpha: jnp.ndarray = 0.0
30 |
31 | @property
32 | def warp_extra(self):
33 | return {'alpha': self.warp_alpha, 'time_alpha': self.time_alpha}
34 |
35 |
36 | def sample_along_rays(key, origins, directions, num_coarse_samples, near, far,
37 | use_stratified_sampling, use_linear_disparity):
38 | """Stratified sampling along the rays.
39 |
40 | Args:
41 | key: jnp.ndarray, random generator key.
42 | origins: ray origins.
43 | directions: ray directions.
44 | num_coarse_samples: int.
45 | near: float, near clip.
46 | far: float, far clip.
47 | use_stratified_sampling: use stratified sampling.
48 | use_linear_disparity: sampling linearly in disparity rather than depth.
49 |
50 | Returns:
51 | z_vals: jnp.ndarray, [batch_size, num_coarse_samples], sampled z values.
52 | points: jnp.ndarray, [batch_size, num_coarse_samples, 3], sampled points.
53 | """
54 | batch_size = origins.shape[0]
55 |
56 | t_vals = jnp.linspace(0., 1., num_coarse_samples)
57 | if not use_linear_disparity:
58 | z_vals = near * (1. - t_vals) + far * t_vals
59 | else:
60 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
61 | if use_stratified_sampling:
62 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
63 | upper = jnp.concatenate([mids, z_vals[..., -1:]], -1)
64 | lower = jnp.concatenate([z_vals[..., :1], mids], -1)
65 | t_rand = random.uniform(key, [batch_size, num_coarse_samples])
66 | z_vals = lower + (upper - lower) * t_rand
67 | else:
68 | # Broadcast z_vals to make the returned shape consistent.
69 | z_vals = jnp.broadcast_to(z_vals[None, ...],
70 | [batch_size, num_coarse_samples])
71 |
72 | return (z_vals, (origins[..., None, :] +
73 | z_vals[..., :, None] * directions[..., None, :]))
74 |
75 |
76 | def volumetric_rendering(rgb,
77 | sigma,
78 | z_vals,
79 | dirs,
80 | use_white_background,
81 | sample_at_infinity=True,
82 | return_weights=False,
83 | eps=1e-10):
84 | """Volumetric Rendering Function.
85 |
86 | Args:
87 | rgb: an array of size (B,S,3) containing the RGB color values.
88 |     sigma: an array of size (B,S) containing the densities.
89 | z_vals: an array of size (B,S) containing the z-coordinate of the samples.
90 | dirs: an array of size (B,3) containing the directions of rays.
91 | use_white_background: whether to assume a white background or not.
92 | sample_at_infinity: if True adds a sample at infinity.
93 | return_weights: if True returns the weights in the dictionary.
94 | eps: a small number to prevent numerical issues.
95 |
96 | Returns:
97 | A dictionary containing:
98 | rgb: an array of size (B,3) containing the rendered colors.
99 | depth: an array of size (B,) containing the rendered depth.
100 | acc: an array of size (B,) containing the accumulated density.
101 | weights: an array of size (B,S) containing the weight of each sample.
102 | """
103 | # TODO(keunhong): remove this hack.
104 | last_sample_z = 1e10 if sample_at_infinity else 1e-19
105 | dists = jnp.concatenate([
106 | z_vals[..., 1:] - z_vals[..., :-1],
107 | jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
108 | ], -1)
109 | dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
110 | alpha = 1.0 - jnp.exp(-sigma * dists)
111 | # Prepend a 1.0 to make this an 'exclusive' cumprod as in `tf.math.cumprod`.
112 | accum_prod = jnp.concatenate([
113 | jnp.ones_like(alpha[..., :1], alpha.dtype),
114 | jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1),
115 | ], axis=-1)
116 | weights = alpha * accum_prod
117 |
118 | rgb = (weights[..., None] * rgb).sum(axis=-2)
119 | exp_depth = (weights * z_vals).sum(axis=-1)
120 | med_depth = compute_depth_map(weights, z_vals)
121 | acc = weights.sum(axis=-1)
122 | if use_white_background:
123 | rgb = rgb + (1. - acc[..., None])
124 |
125 | if sample_at_infinity:
126 | acc = weights[..., :-1].sum(axis=-1)
127 |
128 | out = {
129 | 'rgb': rgb,
130 | 'depth': exp_depth,
131 | 'med_depth': med_depth,
132 | 'acc': acc,
133 | }
134 | if return_weights:
135 | out['weights'] = weights
136 | return out
137 |
138 |
139 | def piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
140 | use_stratified_sampling):
141 | """Piecewise-Constant PDF sampling.
142 |
143 | Args:
144 | key: jnp.ndarray(float32), [2,], random number generator.
145 | bins: jnp.ndarray(float32), [batch_size, n_bins + 1].
146 | weights: jnp.ndarray(float32), [batch_size, n_bins].
147 | num_coarse_samples: int, the number of samples.
148 |     use_stratified_sampling: bool, whether to use stratified sampling.
149 |
150 | Returns:
151 | z_samples: jnp.ndarray(float32), [batch_size, num_coarse_samples].
152 | """
153 | eps = 1e-5
154 |
155 | # Get pdf
156 | weights += eps # prevent nans
157 | pdf = weights / weights.sum(axis=-1, keepdims=True)
158 | cdf = jnp.cumsum(pdf, axis=-1)
159 | cdf = jnp.concatenate([jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf], axis=-1)
160 |
161 | # Take uniform samples
162 | if use_stratified_sampling:
163 | u = random.uniform(key, list(cdf.shape[:-1]) + [num_coarse_samples])
164 | else:
165 | u = jnp.linspace(0., 1., num_coarse_samples)
166 | u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_coarse_samples])
167 |
168 | # Invert CDF. This takes advantage of the fact that `bins` is sorted.
169 | mask = (u[..., None, :] >= cdf[..., :, None])
170 |
171 | def minmax(x):
172 | x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
173 | x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
174 | x0 = jnp.minimum(x0, x[..., -2:-1])
175 | x1 = jnp.maximum(x1, x[..., 1:2])
176 | return x0, x1
177 |
178 | bins_g0, bins_g1 = minmax(bins)
179 | cdf_g0, cdf_g1 = minmax(cdf)
180 |
181 | denom = (cdf_g1 - cdf_g0)
182 | denom = jnp.where(denom < eps, 1., denom)
183 | t = (u - cdf_g0) / denom
184 | z_samples = bins_g0 + t * (bins_g1 - bins_g0)
185 |
186 | # Prevent gradient from backprop-ing through samples
187 | return lax.stop_gradient(z_samples)
188 |
189 |
190 | def sample_pdf(key, bins, weights, origins, directions, z_vals,
191 | num_coarse_samples, use_stratified_sampling):
192 | """Hierarchical sampling.
193 |
194 | Args:
195 | key: jnp.ndarray(float32), [2,], random number generator.
196 | bins: jnp.ndarray(float32), [batch_size, n_bins + 1].
197 | weights: jnp.ndarray(float32), [batch_size, n_bins].
198 | origins: ray origins.
199 | directions: ray directions.
200 | z_vals: jnp.ndarray(float32), [batch_size, n_coarse_samples].
201 | num_coarse_samples: int, the number of samples.
202 |     use_stratified_sampling: bool, whether to use stratified sampling.
203 |
204 | Returns:
205 | z_vals: jnp.ndarray(float32),
206 | [batch_size, n_coarse_samples + num_fine_samples].
207 | points: jnp.ndarray(float32),
208 | [batch_size, n_coarse_samples + num_fine_samples, 3].
209 | """
210 | z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
211 | use_stratified_sampling)
212 | # Compute united z_vals and sample points
213 | z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
214 | return z_vals, (
215 | origins[..., None, :] + z_vals[..., None] * directions[..., None, :])
216 |
217 |
218 | def compute_opaqueness_mask(weights, depth_threshold=0.5):
219 | """Computes a mask which will be 1.0 at the depth point.
220 |
221 | Args:
222 | weights: the density weights from NeRF.
223 | depth_threshold: the accumulation threshold which will be used as the depth
224 | termination point.
225 |
226 | Returns:
227 |     A tensor containing a mask with the same size as weights that has one
228 |     element along the sample dimension that is 1.0. This element is the point
229 |     where the 'surface' is.
230 | """
231 | cumulative_contribution = jnp.cumsum(weights, axis=-1)
232 | depth_threshold = jnp.array(depth_threshold, dtype=weights.dtype)
233 | opaqueness = cumulative_contribution >= depth_threshold
234 | false_padding = jnp.zeros_like(opaqueness[..., :1])
235 | padded_opaqueness = jnp.concatenate(
236 | [false_padding, opaqueness[..., :-1]], axis=-1)
237 | opaqueness_mask = jnp.logical_xor(opaqueness, padded_opaqueness)
238 | opaqueness_mask = opaqueness_mask.astype(weights.dtype)
239 | return opaqueness_mask
240 |
241 |
242 | def compute_depth_index(weights, depth_threshold=0.5):
243 | """Compute the sample index of the median depth accumulation."""
244 | opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold)
245 | return jnp.argmax(opaqueness_mask, axis=-1)
246 |
247 |
248 | def compute_depth_map(weights, z_vals, depth_threshold=0.5):
249 | """Compute the depth using the median accumulation.
250 |
251 | Note that this differs from the depth computation in NeRF-W's codebase!
252 |
253 | Args:
254 | weights: the density weights from NeRF.
255 | z_vals: the z coordinates of the samples.
256 | depth_threshold: the accumulation threshold which will be used as the depth
257 | termination point.
258 |
259 | Returns:
260 | A tensor containing the depth of each input pixel.
261 | """
262 | opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold)
263 | return jnp.sum(opaqueness_mask * z_vals, axis=-1)
264 |
265 |
266 | def noise_regularize(key, raw, noise_std, use_stratified_sampling):
267 | """Regularize the density prediction by adding gaussian noise.
268 |
269 | Args:
270 | key: jnp.ndarray(float32), [2,], random number generator.
271 | raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4].
272 | noise_std: float, std dev of noise added to regularize sigma output.
273 | use_stratified_sampling: add noise only if use_stratified_sampling is True.
274 |
275 | Returns:
276 | raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4], updated raw.
277 | """
278 | if (noise_std is not None) and noise_std > 0.0 and use_stratified_sampling:
279 | unused_key, key = random.split(key)
280 | noise = random.normal(key, raw[..., 3:4].shape, dtype=raw.dtype) * noise_std
281 | raw = jnp.concatenate([raw[..., :3], raw[..., 3:4] + noise], axis=-1)
282 | return raw
283 |
284 |
285 | def broadcast_feature_to(array: jnp.ndarray, shape):
286 | """Matches the shape dimensions (everything except the channel dims).
287 |
288 |   This is useful when you want to match the shape of two features that have
289 | a different number of channels.
290 |
291 | Args:
292 | array: the array to broadcast.
293 | shape: the shape to broadcast the tensor to.
294 |
295 | Returns:
296 | The broadcasted tensor.
297 | """
298 | out_shape = (*shape[:-1], array.shape[-1])
299 | return jnp.broadcast_to(array, out_shape)
300 |
301 |
302 | def metadata_like(rays, metadata_id):
303 | """Create a metadata array like a ray batch."""
304 | return jnp.full_like(rays[..., :1], fill_value=metadata_id, dtype=jnp.uint32)
305 |
306 |
307 | def vmap_module(module, in_axes=0, out_axes=0, num_batch_dims=1):
308 | """Vectorize a module.
309 |
310 | Args:
311 | module: the module to vectorize.
312 | in_axes: the `in_axes` argument passed to vmap. See `jax.vmap`.
313 | out_axes: the `out_axes` argument passed to vmap. See `jax.vmap`.
314 | num_batch_dims: the number of batch dimensions (how many times to apply vmap
315 | to the module).
316 |
317 | Returns:
318 | A vectorized module.
319 | """
320 | for _ in range(num_batch_dims):
321 | module = nn.vmap(
322 | module,
323 | variable_axes={'params': None},
324 | split_rngs={'params': False},
325 | in_axes=in_axes,
326 | out_axes=out_axes)
327 |
328 | return module
329 |
330 |
331 | def identity_initializer(_, shape):
332 | max_shape = max(shape)
333 | return jnp.eye(max_shape)[:shape[0], :shape[1]]
334 |
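A minimal end-to-end sketch of the sampling and compositing helpers above, using constant stand-ins for the radiance field outputs (the shapes, not the values, are the point):

```python
import jax
import jax.numpy as jnp
from nerfies import model_utils

key = jax.random.PRNGKey(0)
origins = jnp.zeros((4, 3))                                # 4 rays at the origin
directions = jnp.tile(jnp.array([0.0, 0.0, 1.0]), (4, 1))  # all looking down +z
z_vals, points = model_utils.sample_along_rays(
    key, origins, directions, num_coarse_samples=64, near=0.1, far=4.0,
    use_stratified_sampling=True, use_linear_disparity=False)

rgb = 0.5 * jnp.ones((4, 64, 3))  # stand-in colors from a radiance field
sigma = jnp.ones((4, 64))         # stand-in densities
out = model_utils.volumetric_rendering(
    rgb, sigma, z_vals, directions, use_white_background=False)
assert out['rgb'].shape == (4, 3) and out['depth'].shape == (4,)
```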
--------------------------------------------------------------------------------
/nerfies/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Modules for NeRF models."""
16 | import functools
17 | from typing import Optional, Tuple
18 |
19 | from flax import linen as nn
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from nerfies import types
24 |
25 |
26 | class MLP(nn.Module):
27 | """Basic MLP class with hidden layers and an output layers."""
28 | depth: int
29 | width: int
30 | hidden_init: types.Initializer = nn.initializers.xavier_uniform()
31 | hidden_activation: types.Activation = nn.relu
32 | output_init: Optional[types.Initializer] = None
33 | output_channels: int = 0
34 | output_activation: Optional[types.Activation] = lambda x: x
35 | use_bias: bool = True
36 | skips: Tuple[int] = tuple()
37 |
38 | @nn.compact
39 | def __call__(self, x):
40 | inputs = x
41 | for i in range(self.depth):
42 | layer = nn.Dense(
43 | self.width,
44 | use_bias=self.use_bias,
45 | kernel_init=self.hidden_init,
46 | name=f'hidden_{i}')
47 | if i in self.skips:
48 | x = jnp.concatenate([x, inputs], axis=-1)
49 | x = layer(x)
50 | x = self.hidden_activation(x)
51 |
52 | if self.output_channels > 0:
53 | logit_layer = nn.Dense(
54 | self.output_channels,
55 | use_bias=self.use_bias,
56 | kernel_init=self.output_init,
57 | name='logit')
58 | x = logit_layer(x)
59 | if self.output_activation is not None:
60 | x = self.output_activation(x)
61 |
62 | return x
63 |
64 |
65 | class NerfMLP(nn.Module):
66 | """A simple MLP.
67 |
68 | Attributes:
69 |     trunk_depth: int, the depth of the shared trunk of the MLP.
70 |     trunk_width: int, the width of the shared trunk of the MLP.
71 |     rgb_branch_depth: int, the depth of the RGB branch.
72 |     rgb_branch_width: int, the width of the RGB branch.
73 |     rgb_channels: int, the number of RGB output channels.
74 |     alpha_branch_depth: int, the depth of the alpha (density) branch.
75 |     alpha_branch_width: int, the width of the alpha (density) branch.
76 |     alpha_channels: int, the number of alpha output channels.
77 |     activation: function, the activation function used in the MLP.
78 |     skips: which trunk layers to add skip connections to.
79 | """
80 | trunk_depth: int = 8
81 | trunk_width: int = 256
82 |
83 | rgb_branch_depth: int = 1
84 | rgb_branch_width: int = 128
85 | rgb_channels: int = 3
86 |
87 | alpha_branch_depth: int = 0
88 | alpha_branch_width: int = 128
89 | alpha_channels: int = 1
90 |
91 | activation: types.Activation = nn.relu
92 | skips: Tuple[int] = (4,)
93 |
94 | @nn.compact
95 | def __call__(self, x, trunk_condition, alpha_condition, rgb_condition):
96 | """Multi-layer perception for nerf.
97 |
98 | Args:
99 | x: sample points with shape [batch, num_coarse_samples, feature].
100 | trunk_condition: a condition array provided to the trunk.
101 | alpha_condition: a condition array provided to the alpha branch.
102 | rgb_condition: a condition array provided in the RGB branch.
103 |
104 | Returns:
105 | raw: [batch, num_coarse_samples, rgb_channels+alpha_channels].
106 | """
107 | dense = functools.partial(
108 | nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())
109 |
110 | feature_dim = x.shape[-1]
111 | num_samples = x.shape[1]
112 | x = x.reshape([-1, feature_dim])
113 |
114 | def broadcast_condition(c):
115 | # Broadcast condition from [batch, feature] to
116 | # [batch, num_coarse_samples, feature] since all the samples along the
117 |       # same ray have the same viewdir.
118 | c = jnp.tile(c[:, None, :], (1, num_samples, 1))
119 | # Collapse the [batch, num_coarse_samples, feature] tensor to
120 | # [batch * num_coarse_samples, feature] to be fed into nn.Dense.
121 | c = c.reshape([-1, c.shape[-1]])
122 | return c
123 |
124 | trunk_mlp = MLP(depth=self.trunk_depth,
125 | width=self.trunk_width,
126 | hidden_activation=self.activation,
127 | hidden_init=jax.nn.initializers.glorot_uniform(),
128 | skips=self.skips)
129 | rgb_mlp = MLP(depth=self.rgb_branch_depth,
130 | width=self.rgb_branch_width,
131 | hidden_activation=self.activation,
132 | hidden_init=jax.nn.initializers.glorot_uniform(),
133 | output_init=jax.nn.initializers.glorot_uniform(),
134 | output_channels=self.rgb_channels)
135 | alpha_mlp = MLP(depth=self.alpha_branch_depth,
136 | width=self.alpha_branch_width,
137 | hidden_activation=self.activation,
138 | hidden_init=jax.nn.initializers.glorot_uniform(),
139 | output_init=jax.nn.initializers.glorot_uniform(),
140 | output_channels=self.alpha_channels)
141 |
142 | if trunk_condition is not None:
143 | trunk_condition = broadcast_condition(trunk_condition)
144 | trunk_input = jnp.concatenate([x, trunk_condition], axis=-1)
145 | else:
146 | trunk_input = x
147 | x = trunk_mlp(trunk_input)
148 |
149 | if (alpha_condition is not None) or (rgb_condition is not None):
150 | bottleneck = dense(self.trunk_width, name='bottleneck')(x)
151 |
152 | if alpha_condition is not None:
153 | alpha_condition = broadcast_condition(alpha_condition)
154 | alpha_input = jnp.concatenate([bottleneck, alpha_condition], axis=-1)
155 | else:
156 | alpha_input = x
157 | alpha = alpha_mlp(alpha_input)
158 |
159 | if rgb_condition is not None:
160 | rgb_condition = broadcast_condition(rgb_condition)
161 | rgb_input = jnp.concatenate([bottleneck, rgb_condition], axis=-1)
162 | else:
163 | rgb_input = x
164 | rgb = rgb_mlp(rgb_input)
165 |
166 | return {
167 | 'rgb': rgb.reshape((-1, num_samples, self.rgb_channels)),
168 | 'alpha': alpha.reshape((-1, num_samples, self.alpha_channels)),
169 | }
170 |
171 |
172 | class SinusoidalEncoder(nn.Module):
173 | """A vectorized sinusoidal encoding.
174 |
175 | Attributes:
176 | num_freqs: the number of frequency bands in the encoding.
177 | min_freq_log2: the log (base 2) of the lower frequency.
178 | max_freq_log2: the log (base 2) of the upper frequency.
179 | scale: a scaling factor for the positional encoding.
180 | use_identity: if True use the identity encoding as well.
181 | """
182 | num_freqs: int
183 | min_freq_log2: int = 0
184 | max_freq_log2: Optional[int] = None
185 | scale: float = 1.0
186 | use_identity: bool = True
187 |
188 | def setup(self):
189 | if self.max_freq_log2 is None:
190 | max_freq_log2 = self.num_freqs - 1.0
191 | else:
192 | max_freq_log2 = self.max_freq_log2
193 | self.freq_bands = 2.0**jnp.linspace(self.min_freq_log2,
194 | max_freq_log2,
195 | int(self.num_freqs))
196 |
197 | # (F, 1).
198 | self.freqs = jnp.reshape(self.freq_bands, (self.num_freqs, 1))
199 |
200 | def __call__(self, x, alpha: Optional[float] = None):
201 | """A vectorized sinusoidal encoding.
202 |
203 | Args:
204 | x: the input features to encode.
205 | alpha: a dummy argument for API compatibility.
206 |
207 | Returns:
208 | A tensor containing the encoded features.
209 | """
210 | if self.num_freqs == 0:
211 | return x
212 |
213 | x_expanded = jnp.expand_dims(x, axis=-2) # (1, C).
214 | # Will be broadcasted to shape (F, C).
215 | angles = self.scale * x_expanded * self.freqs
216 |
217 | # The shape of the features is (F, 2, C) so that when we reshape it
218 | # it matches the ordering of the original NeRF code.
219 | # Vectorize the computation of the high-frequency (sin, cos) terms.
220 | # We use the trigonometric identity: cos(x) = sin(x + pi/2)
221 | features = jnp.stack((angles, angles + jnp.pi / 2), axis=-2)
222 | features = features.flatten()
223 | features = jnp.sin(features)
224 |
225 | # Prepend the original signal for the identity.
226 | if self.use_identity:
227 | features = jnp.concatenate([x, features], axis=-1)
228 | return features
229 |
230 |
231 | class AnnealedSinusoidalEncoder(nn.Module):
232 | """An annealed sinusoidal encoding."""
233 | num_freqs: int
234 | min_freq_log2: int = 0
235 | max_freq_log2: Optional[int] = None
236 | scale: float = 1.0
237 | use_identity: bool = True
238 |
239 | @nn.compact
240 | def __call__(self, x, alpha):
241 | if alpha is None:
242 | raise ValueError('alpha must be specified.')
243 | if self.num_freqs == 0:
244 | return x
245 |
246 | num_channels = x.shape[-1]
247 |
248 | base_encoder = SinusoidalEncoder(
249 | num_freqs=self.num_freqs,
250 | min_freq_log2=self.min_freq_log2,
251 | max_freq_log2=self.max_freq_log2,
252 | scale=self.scale,
253 | use_identity=self.use_identity)
254 | features = base_encoder(x)
255 |
256 | if self.use_identity:
257 | identity, features = jnp.split(features, (x.shape[-1],), axis=-1)
258 |
259 | # Apply the window by broadcasting to save on memory.
260 | features = jnp.reshape(features, (-1, 2, num_channels))
261 | window = self.cosine_easing_window(
262 | self.min_freq_log2, self.max_freq_log2, self.num_freqs, alpha)
263 | window = jnp.reshape(window, (-1, 1, 1))
264 | features = window * features
265 |
266 | if self.use_identity:
267 | return jnp.concatenate([
268 | identity,
269 | features.flatten(),
270 | ], axis=-1)
271 | else:
272 | return features.flatten()
273 |
274 | @classmethod
275 | def cosine_easing_window(cls, min_freq_log2, max_freq_log2, num_bands, alpha):
276 | """Eases in each frequency one by one with a cosine.
277 |
278 | This is equivalent to taking a Tukey window and sliding it to the right
279 | along the frequency spectrum.
280 |
281 | Args:
282 | min_freq_log2: the lower frequency band.
283 | max_freq_log2: the upper frequency band.
284 | num_bands: the number of frequencies.
285 | alpha: will ease in each frequency as alpha goes from 0.0 to num_freqs.
286 |
287 | Returns:
288 |       A 1-d array with num_bands elements containing the window.
289 | """
290 | if max_freq_log2 is None:
291 | max_freq_log2 = num_bands - 1.0
292 | bands = jnp.linspace(min_freq_log2, max_freq_log2, num_bands)
293 | x = jnp.clip(alpha - bands, 0.0, 1.0)
294 | return 0.5 * (1 + jnp.cos(jnp.pi * x + jnp.pi))
295 |
296 |
297 | class TimeEncoder(nn.Module):
298 | """Encodes a timestamp to an embedding."""
299 | num_freqs: int
300 |
301 | features: int = 10
302 | depth: int = 6
303 | width: int = 64
304 |   skips: Tuple[int] = (4,)
305 | hidden_init: types.Initializer = nn.initializers.xavier_uniform()
306 |   output_init: types.Initializer = nn.initializers.uniform(scale=0.05)
307 |
308 | def setup(self):
309 | self.posenc = AnnealedSinusoidalEncoder(num_freqs=self.num_freqs)
310 | self.mlp = MLP(
311 | depth=self.depth,
312 | width=self.width,
313 | skips=self.skips,
314 | hidden_init=self.hidden_init,
315 | output_channels=self.features,
316 | output_init=self.output_init)
317 |
318 | def __call__(self, time, alpha=None):
319 | if alpha is None:
320 | alpha = self.num_freqs
321 | encoded_time = self.posenc(time, alpha)
322 | return self.mlp(encoded_time)
323 |
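A quick sketch of the positional-encoding dimensions (the input vector is illustrative):

```python
import jax
import jax.numpy as jnp
from nerfies.modules import SinusoidalEncoder

encoder = SinusoidalEncoder(num_freqs=4)
x = jnp.array([0.5, -0.25, 1.0])  # a single 3-vector, e.g. a sample point
variables = encoder.init(jax.random.PRNGKey(0), x)  # no learnable parameters
features = encoder.apply(variables, x)
assert features.shape == (27,)  # 3 identity + 3 inputs * 4 freqs * 2 (sin, cos)
```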
--------------------------------------------------------------------------------
/nerfies/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Quaternion math.
16 |
17 | This module assumes the xyzw quaternion format where xyz is the imaginary part
18 | and w is the real part.
19 |
20 | Functions in this module support both batched and unbatched quaternions.
21 | """
22 | from jax import numpy as jnp
23 | from jax.numpy import linalg
24 |
25 |
26 | def safe_acos(t, eps=1e-8):
27 | """A safe version of arccos which avoids evaluating at -1 or 1."""
28 | return jnp.arccos(jnp.clip(t, -1.0 + eps, 1.0 - eps))
29 |
30 |
31 | def im(q):
32 | """Fetch the imaginary part of the quaternion."""
33 | return q[..., :3]
34 |
35 |
36 | def re(q):
37 | """Fetch the real part of the quaternion."""
38 | return q[..., 3:]
39 |
40 |
41 | def identity():
42 | return jnp.array([0.0, 0.0, 0.0, 1.0])
43 |
44 |
45 | def conjugate(q):
46 | """Compute the conjugate of a quaternion."""
47 | return jnp.concatenate([-im(q), re(q)], axis=-1)
48 |
49 |
50 | def inverse(q):
51 | """Compute the inverse of a quaternion."""
52 | return normalize(conjugate(q))
53 |
54 |
55 | def normalize(q):
56 | """Normalize a quaternion."""
57 | return q / norm(q)
58 |
59 |
60 | def norm(q):
61 | return linalg.norm(q, axis=-1, keepdims=True)
62 |
63 |
64 | def multiply(q1, q2):
65 | """Multiply two quaternions."""
66 | c = (re(q1) * im(q2)
67 | + re(q2) * im(q1)
68 | + jnp.cross(im(q1), im(q2)))
69 |   w = re(q1) * re(q2) - jnp.sum(im(q1) * im(q2), axis=-1, keepdims=True)
70 | return jnp.concatenate([c, w], axis=-1)
71 |
72 |
73 | def rotate(q, v):
74 | """Rotate a vector using a quaternion."""
75 | # Create the quaternion representation of the vector.
76 | q_v = jnp.concatenate([v, jnp.zeros_like(v[..., :1])], axis=-1)
77 | return im(multiply(multiply(q, q_v), conjugate(q)))
78 |
79 |
80 | def log(q, eps=1e-8):
81 | """Computes the quaternion logarithm.
82 |
83 | References:
84 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
85 |
86 | Args:
87 | q: the quaternion in (x,y,z,w) format.
88 | eps: an epsilon value for numerical stability.
89 |
90 | Returns:
91 | The logarithm of q.
92 | """
93 | mag = linalg.norm(q, axis=-1, keepdims=True)
94 | v = im(q)
95 | s = re(q)
96 | w = jnp.log(mag)
97 | denom = jnp.maximum(
98 | linalg.norm(v, axis=-1, keepdims=True), eps * jnp.ones_like(v))
99 |   xyz = v / denom * safe_acos(s / mag)
100 | return jnp.concatenate((xyz, w), axis=-1)
101 |
102 |
103 | def exp(q, eps=1e-8):
104 | """Computes the quaternion exponential.
105 |
106 | References:
107 | https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
108 |
109 | Args:
110 | q: the quaternion in (x,y,z,w) format or (x,y,z) if is_pure is True.
111 | eps: an epsilon value for numerical stability.
112 |
113 | Returns:
114 | The exponential of q.
115 | """
116 | is_pure = q.shape[-1] == 3
117 | if is_pure:
118 | s = jnp.zeros_like(q[..., -1:])
119 | v = q
120 | else:
121 | v = im(q)
122 | s = re(q)
123 |
124 | norm_v = linalg.norm(v, axis=-1, keepdims=True)
125 | exp_s = jnp.exp(s)
126 | w = jnp.cos(norm_v)
127 | xyz = jnp.sin(norm_v) * v / jnp.maximum(norm_v, eps * jnp.ones_like(norm_v))
128 | return exp_s * jnp.concatenate((xyz, w), axis=-1)
129 |
130 |
131 | def to_rotation_matrix(q):
132 | """Constructs a rotation matrix from a quaternion.
133 |
134 | Args:
135 | q: a (*,4) array containing quaternions.
136 |
137 | Returns:
138 | A (*,3,3) array containing rotation matrices.
139 | """
140 | x, y, z, w = jnp.split(q, 4, axis=-1)
141 | s = 1.0 / jnp.sum(q ** 2, axis=-1)
142 | return jnp.stack([
143 | jnp.stack([1 - 2 * s * (y ** 2 + z ** 2),
144 | 2 * s * (x * y - z * w),
145 | 2 * s * (x * z + y * w)], axis=0),
146 | jnp.stack([2 * s * (x * y + z * w),
147 | 1 - s * 2 * (x ** 2 + z ** 2),
148 | 2 * s * (y * z - x * w)], axis=0),
149 | jnp.stack([2 * s * (x * z - y * w),
150 | 2 * s * (y * z + x * w),
151 | 1 - 2 * s * (x ** 2 + y ** 2)], axis=0),
152 | ], axis=0)
153 |
154 |
155 | def from_rotation_matrix(m, eps=1e-9):
156 | """Construct quaternion from a rotation matrix.
157 |
158 | Args:
159 | m: a (*,3,3) array containing rotation matrices.
160 | eps: a small number for numerical stability.
161 |
162 | Returns:
163 | A (*,4) array containing quaternions.
164 | """
165 | trace = jnp.trace(m)
166 | m00 = m[..., 0, 0]
167 | m01 = m[..., 0, 1]
168 | m02 = m[..., 0, 2]
169 | m10 = m[..., 1, 0]
170 | m11 = m[..., 1, 1]
171 | m12 = m[..., 1, 2]
172 | m20 = m[..., 2, 0]
173 | m21 = m[..., 2, 1]
174 | m22 = m[..., 2, 2]
175 |
176 | def tr_positive():
177 | sq = jnp.sqrt(trace + 1.0) * 2. # sq = 4 * w.
178 | w = 0.25 * sq
179 | x = jnp.divide(m21 - m12, sq)
180 | y = jnp.divide(m02 - m20, sq)
181 | z = jnp.divide(m10 - m01, sq)
182 | return jnp.stack((x, y, z, w), axis=-1)
183 |
184 | def cond_1():
185 | sq = jnp.sqrt(1.0 + m00 - m11 - m22 + eps) * 2. # sq = 4 * x.
186 | w = jnp.divide(m21 - m12, sq)
187 | x = 0.25 * sq
188 | y = jnp.divide(m01 + m10, sq)
189 | z = jnp.divide(m02 + m20, sq)
190 | return jnp.stack((x, y, z, w), axis=-1)
191 |
192 | def cond_2():
193 | sq = jnp.sqrt(1.0 + m11 - m00 - m22 + eps) * 2. # sq = 4 * y.
194 | w = jnp.divide(m02 - m20, sq)
195 | x = jnp.divide(m01 + m10, sq)
196 | y = 0.25 * sq
197 | z = jnp.divide(m12 + m21, sq)
198 | return jnp.stack((x, y, z, w), axis=-1)
199 |
200 | def cond_3():
201 | sq = jnp.sqrt(1.0 + m22 - m00 - m11 + eps) * 2. # sq = 4 * z.
202 | w = jnp.divide(m10 - m01, sq)
203 | x = jnp.divide(m02 + m20, sq)
204 | y = jnp.divide(m12 + m21, sq)
205 | z = 0.25 * sq
206 | return jnp.stack((x, y, z, w), axis=-1)
207 |
208 | def cond_idx(cond):
209 | cond = jnp.expand_dims(cond, -1)
210 | cond = jnp.tile(cond, [1] * (len(m.shape) - 2) + [4])
211 | return cond
212 |
213 | where_2 = jnp.where(cond_idx(m11 > m22), cond_2(), cond_3())
214 | where_1 = jnp.where(cond_idx((m00 > m11) & (m00 > m22)), cond_1(), where_2)
215 | return jnp.where(cond_idx(trace > 0), tr_positive(), where_1)
216 |
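A small self-check of the conventions above: a 90-degree rotation about z, written in (x, y, z, w) format, should take the x axis to the y axis:

```python
import jax.numpy as jnp
from nerfies import quaternion

angle = jnp.pi / 2
q = jnp.array([0.0, 0.0, jnp.sin(angle / 2), jnp.cos(angle / 2)])
v = jnp.array([1.0, 0.0, 0.0])
print(quaternion.rotate(q, v))  # approximately [0, 1, 0]
```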
--------------------------------------------------------------------------------
/nerfies/rigid_body.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint: disable=invalid-name
16 | # pytype: disable=attribute-error
17 | import jax
18 | from jax import numpy as jnp
19 |
20 |
21 | @jax.jit
22 | def skew(w: jnp.ndarray) -> jnp.ndarray:
23 | """Build a skew matrix ("cross product matrix") for vector w.
24 |
25 | Modern Robotics Eqn 3.30.
26 |
27 | Args:
28 | w: (3,) A 3-vector
29 |
30 | Returns:
31 | W: (3, 3) A skew matrix such that W @ v == w x v
32 | """
33 |   w = jnp.reshape(w, (3,))
34 |   return jnp.array([[0.0, -w[2], w[1]],
35 |                     [w[2], 0.0, -w[0]],
36 |                     [-w[1], w[0], 0.0]])
37 |
38 |
39 | def rp_to_se3(R: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
40 | """Rotation and translation to homogeneous transform.
41 |
42 | Args:
43 | R: (3, 3) An orthonormal rotation matrix.
44 | p: (3,) A 3-vector representing an offset.
45 |
46 | Returns:
47 | X: (4, 4) The homogeneous transformation matrix described by rotating by R
48 | and translating by p.
49 | """
50 | p = jnp.reshape(p, (3, 1))
51 | return jnp.block([[R, p], [jnp.array([[0.0, 0.0, 0.0, 1.0]])]])
52 |
53 |
54 | def exp_so3(w: jnp.ndarray, theta: float) -> jnp.ndarray:
55 | """Exponential map from Lie algebra so3 to Lie group SO3.
56 |
57 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula.
58 |
59 | Args:
60 | w: (3,) An axis of rotation.
61 | theta: An angle of rotation.
62 |
63 | Returns:
64 | R: (3, 3) An orthonormal rotation matrix representing a rotation of
65 | magnitude theta about axis w.
66 | """
67 | W = skew(w)
68 | return jnp.eye(3) + jnp.sin(theta) * W + (1.0 - jnp.cos(theta)) * W @ W
69 |
70 |
71 | def exp_se3(S: jnp.ndarray, theta: float) -> jnp.ndarray:
72 | """Exponential map from Lie algebra so3 to Lie group SO3.
73 |
74 | Modern Robotics Eqn 3.88.
75 |
76 | Args:
77 | S: (6,) A screw axis of motion.
78 | theta: Magnitude of motion.
79 |
80 | Returns:
81 | a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating
82 | motion of magnitude theta about S for one second.
83 | """
84 | w, v = jnp.split(S, 2)
85 | W = skew(w)
86 | R = exp_so3(w, theta)
87 | p = (theta * jnp.eye(3) + (1.0 - jnp.cos(theta)) * W +
88 | (theta - jnp.sin(theta)) * W @ W) @ v
89 | return rp_to_se3(R, p)
90 |
91 |
92 | def to_homogenous(v):
93 | return jnp.concatenate([v, jnp.ones_like(v[..., :1])], axis=-1)
94 |
95 |
96 | def from_homogenous(v):
97 | return v[..., :3] / v[..., -1:]
98 |
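A small numeric check of the exponential maps (the axis and screw values are illustrative):

```python
import jax.numpy as jnp
from nerfies import rigid_body

w = jnp.array([0.0, 0.0, 1.0])              # rotation axis: z
R = rigid_body.exp_so3(w, jnp.pi / 2)       # 90-degree rotation
print(R @ jnp.array([1.0, 0.0, 0.0]))       # approximately [0, 1, 0]

S = jnp.concatenate([w, jnp.array([0.0, 0.0, 1.0])])  # screw axis (w, v)
T = rigid_body.exp_se3(S, jnp.pi / 2)       # (4, 4) homogeneous transform
p = rigid_body.from_homogenous(
    T @ rigid_body.to_homogenous(jnp.array([1.0, 0.0, 0.0])))
```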
--------------------------------------------------------------------------------
/nerfies/schedules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Annealing Schedules."""
16 | import abc
17 | import collections.abc
18 | import copy
19 | import math
20 | from typing import Any, Iterable, List, Tuple, Union
21 |
22 | from jax import numpy as jnp
23 |
24 |
25 | def from_tuple(x):
26 | schedule_type, *args = x
27 | return SCHEDULE_MAP[schedule_type](*args)
28 |
29 |
30 | def from_dict(d):
31 | d = copy.copy(dict(d))
32 | schedule_type = d.pop('type')
33 | return SCHEDULE_MAP[schedule_type](**d)
34 |
35 |
36 | def from_config(schedule):
37 | if isinstance(schedule, Schedule):
38 | return schedule
39 |   if isinstance(schedule, (tuple, list)):
40 | return from_tuple(schedule)
41 |   if isinstance(schedule, collections.abc.Mapping):
42 | return from_dict(schedule)
43 |
44 | raise ValueError(f'Unknown type {type(schedule)}.')
45 |
46 |
47 | class Schedule(abc.ABC):
48 | """An interface for generic schedules.."""
49 |
50 | @abc.abstractmethod
51 | def get(self, step):
52 | """Get the value for the given step."""
53 | raise NotImplementedError
54 |
55 | def __call__(self, step):
56 | return self.get(step)
57 |
58 |
59 | class ConstantSchedule(Schedule):
60 | """Linearly scaled scheduler."""
61 |
62 | def __init__(self, value):
63 | super().__init__()
64 | self.value = value
65 |
66 | def get(self, step):
67 | """Get the value for the given step."""
68 | return jnp.full_like(step, self.value, dtype=jnp.float32)
69 |
70 |
71 | class LinearSchedule(Schedule):
72 | """Linearly scaled scheduler."""
73 |
74 | def __init__(self, initial_value, final_value, num_steps):
75 | super().__init__()
76 | self.initial_value = initial_value
77 | self.final_value = final_value
78 | self.num_steps = num_steps
79 |
80 | def get(self, step):
81 | """Get the value for the given step."""
82 | if self.num_steps == 0:
83 | return jnp.full_like(step, self.final_value, dtype=jnp.float32)
84 | alpha = jnp.minimum(step / self.num_steps, 1.0)
85 | return (1.0 - alpha) * self.initial_value + alpha * self.final_value
86 |
87 |
88 | class ExponentialSchedule(Schedule):
89 | """Exponentially decaying scheduler."""
90 |
91 | def __init__(self, initial_value, final_value, num_steps, eps=1e-10):
92 | super().__init__()
93 | if initial_value <= final_value:
94 | raise ValueError('Final value must be less than initial value.')
95 |
96 | self.initial_value = initial_value
97 | self.final_value = final_value
98 | self.num_steps = num_steps
99 | self.eps = eps
100 |
101 | def get(self, step):
102 | """Get the value for the given step."""
103 | if step >= self.num_steps:
104 | return jnp.full_like(step, self.final_value, dtype=jnp.float32)
105 |
106 | final_value = max(self.final_value, self.eps)
107 | base = final_value / self.initial_value
108 | exponent = step / (self.num_steps - 1)
109 |     # The early return above already handles step >= num_steps, so the
110 |     # value below is only computed while the schedule is still annealing.
111 | return self.initial_value * base**exponent
112 |
113 |
114 | class CosineEasingSchedule(Schedule):
115 | """Schedule that eases slowsly using a cosine."""
116 |
117 | def __init__(self, initial_value, final_value, num_steps):
118 | super().__init__()
119 | self.initial_value = initial_value
120 | self.final_value = final_value
121 | self.num_steps = num_steps
122 |
123 | def get(self, step):
124 | """Get the value for the given step."""
125 | alpha = jnp.minimum(step / self.num_steps, 1.0)
126 | scale = self.final_value - self.initial_value
127 | x = min(max(alpha, 0.0), 1.0)
128 | return (self.initial_value
129 | + scale * 0.5 * (1 + math.cos(jnp.pi * x + jnp.pi)))
130 |
131 |
132 | class StepSchedule(Schedule):
133 | """Schedule that eases slowsly using a cosine."""
134 |
135 | def __init__(self,
136 | initial_value,
137 | decay_interval,
138 | decay_factor,
139 | max_decays,
140 | final_value=None):
141 | super().__init__()
142 | self.initial_value = initial_value
143 | self.decay_factor = decay_factor
144 | self.decay_interval = decay_interval
145 | self.max_decays = max_decays
146 | if final_value is None:
147 | final_value = self.initial_value * self.decay_factor**self.max_decays
148 | self.final_value = final_value
149 |
150 | def get(self, step):
151 | """Get the value for the given step."""
152 | phase = step // self.decay_interval
153 | if phase >= self.max_decays:
154 | return self.final_value
155 | else:
156 | return self.initial_value * self.decay_factor**phase
157 |
158 |
159 | class PiecewiseSchedule(Schedule):
160 | """A piecewise combination of multiple schedules."""
161 |
162 | def __init__(
163 | self, schedules: Iterable[Tuple[int, Union[Schedule, Iterable[Any]]]]):
164 | self.schedules = [from_config(s) for ms, s in schedules]
165 | milestones = jnp.array([ms for ms, s in schedules])
166 | self.milestones = jnp.cumsum(milestones)[:-1]
167 |
168 | def get(self, step):
169 | idx = jnp.searchsorted(self.milestones, step, side='right')
170 | schedule = self.schedules[idx]
171 | base_idx = self.milestones[idx - 1] if idx >= 1 else 0
172 | return schedule.get(step - base_idx)
173 |
174 |
175 | class DelayedSchedule(Schedule):
176 | """Delays the start of the base schedule."""
177 |
178 | def __init__(self, base_schedule: Schedule, delay_steps, delay_mult):
179 | self.base_schedule = from_config(base_schedule)
180 | self.delay_steps = delay_steps
181 | self.delay_mult = delay_mult
182 |
183 | def get(self, step):
184 | delay_rate = (
185 | self.delay_mult
186 | + (1 - self.delay_mult)
187 | * jnp.sin(0.5 * jnp.pi * jnp.clip(step / self.delay_steps, 0, 1)))
188 |
189 | return delay_rate * self.base_schedule(step)
190 |
191 |
192 | SCHEDULE_MAP = {
193 | 'constant': ConstantSchedule,
194 | 'linear': LinearSchedule,
195 | 'exponential': ExponentialSchedule,
196 | 'cosine_easing': CosineEasingSchedule,
197 | 'step': StepSchedule,
198 | 'piecewise': PiecewiseSchedule,
199 | 'delayed': DelayedSchedule,
200 | }
201 |
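
A usage sketch (added for illustration, not part of the original file): the schedules compose, and the values below are hypothetical. It assumes the `from_config` factory defined elsewhere in this module accepts both `Schedule` instances and `(name, *args)` tuples keyed by `SCHEDULE_MAP`, as `PiecewiseSchedule` above relies on.

    # Warm up linearly for 2,500 steps, decay exponentially for the rest,
    # and ease the whole thing in over the first 100 steps.
    lr_sched = DelayedSchedule(
        base_schedule=PiecewiseSchedule([
            (2500, ('linear', 0.0, 1e-3, 2500)),
            (97500, ('exponential', 1e-3, 1e-5, 97500)),
        ]),
        delay_steps=100,
        delay_mult=0.01)

    print(lr_sched(0))      # 0.0 (warm-up starts at zero, fully delayed)
    print(lr_sched(2500))   # ~1e-3: warm-up done, exponential decay begins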
--------------------------------------------------------------------------------
/nerfies/tf_camera.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A basic camera implementation in Tensorflow."""
16 | from typing import Tuple, Optional
17 |
18 | import tensorflow as tf
19 | from tensorflow.experimental import numpy as tnp
20 |
21 |
22 | def _norm(x):
23 | return tnp.sqrt(tnp.sum(x ** 2, axis=-1, keepdims=True))
24 |
25 |
26 | def _compute_residual_and_jacobian(
27 | x: tnp.ndarray,
28 | y: tnp.ndarray,
29 | xd: tnp.ndarray,
30 | yd: tnp.ndarray,
31 | k1: float = 0.0,
32 | k2: float = 0.0,
33 | k3: float = 0.0,
34 | p1: float = 0.0,
35 | p2: float = 0.0,
36 | ) -> Tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray, tnp.ndarray, tnp.ndarray,
37 | tnp.ndarray]:
38 | """Auxiliary function of radial_and_tangential_undistort()."""
39 | # let r(x, y) = x^2 + y^2;
40 | # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3;
41 | r = x * x + y * y
42 | d = 1.0 + r * (k1 + r * (k2 + k3 * r))
43 |
44 | # The perfect projection is:
45 | # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
46 | # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
47 | #
48 | # Let's define
49 | #
50 | # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
51 | # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
52 | #
53 | # We are looking for a solution that satisfies
54 | # fx(x, y) = fy(x, y) = 0;
55 | fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
56 | fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
57 |
58 | # Compute derivative of d over [x, y]
59 | d_r = (k1 + r * (2.0 * k2 + 3.0 * k3 * r))
60 | d_x = 2.0 * x * d_r
61 | d_y = 2.0 * y * d_r
62 |
63 | # Compute derivative of fx over x and y.
64 | fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
65 | fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
66 |
67 | # Compute derivative of fy over x and y.
68 | fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
69 | fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
70 |
71 | return fx, fy, fx_x, fx_y, fy_x, fy_y
72 |
73 |
74 | def _radial_and_tangential_undistort(
75 | xd: tnp.ndarray,
76 | yd: tnp.ndarray,
77 | k1: float = 0,
78 | k2: float = 0,
79 | k3: float = 0,
80 | p1: float = 0,
81 | p2: float = 0,
82 | eps: float = 1e-9,
83 | max_iterations=10) -> Tuple[tnp.ndarray, tnp.ndarray]:
84 | """Computes undistorted (x, y) from (xd, yd)."""
85 | # Initialize from the distorted point.
86 | x = xd
87 | y = yd
88 |
89 | for _ in range(max_iterations):
90 | fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
91 | x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, p1=p1, p2=p2)
92 | denominator = fy_x * fx_y - fx_x * fy_y
93 | x_numerator = fx * fy_y - fy * fx_y
94 | y_numerator = fy * fx_x - fx * fy_x
95 | step_x = tnp.where(
96 | tnp.abs(denominator) > eps, x_numerator / denominator,
97 | tnp.zeros_like(denominator))
98 | step_y = tnp.where(
99 | tnp.abs(denominator) > eps, y_numerator / denominator,
100 | tnp.zeros_like(denominator))
101 |
102 | x = x + step_x
103 | y = y + step_y
104 |
105 | return x, y
106 |
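# Illustration (added): this Newton solver inverts the forward distortion
# model written out in _compute_residual_and_jacobian. With k1 = 0.1 and an
# undistorted point (x, y) = (0.3, -0.2):
#   r  = 0.3**2 + (-0.2)**2        # 0.13
#   d  = 1.0 + 0.1 * 0.13          # 1.013
#   xd, yd = 0.3 * d, -0.2 * d     # (0.3039, -0.2026)
# _radial_and_tangential_undistort(xd, yd, k1=0.1) recovers approximately
# (0.3, -0.2) well within the default 10 iterations.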
107 |
108 | class TFCamera:
109 |   """A duplicate of our JAX-based camera class.
110 |
111 | This is necessary to use tf.data.Dataset.
112 | """
113 |
114 | def __init__(self,
115 | orientation: tnp.ndarray,
116 | position: tnp.ndarray,
117 | focal_length: float,
118 | principal_point: tnp.ndarray,
119 | image_size: tnp.ndarray,
120 | skew: float = 0.0,
121 | pixel_aspect_ratio: float = 1.0,
122 | radial_distortion: Optional[tnp.ndarray] = None,
123 | tangential_distortion: Optional[tnp.ndarray] = None,
124 | dtype=tnp.float32):
125 | """Constructor for camera class."""
126 | if radial_distortion is None:
127 | radial_distortion = tnp.array([0.0, 0.0, 0.0], dtype)
128 | if tangential_distortion is None:
129 | tangential_distortion = tnp.array([0.0, 0.0], dtype)
130 |
131 | self.orientation = tnp.array(orientation, dtype)
132 | self.position = tnp.array(position, dtype)
133 | self.focal_length = tnp.array(focal_length, dtype)
134 | self.principal_point = tnp.array(principal_point, dtype)
135 | self.skew = tnp.array(skew, dtype)
136 | self.pixel_aspect_ratio = tnp.array(pixel_aspect_ratio, dtype)
137 | self.radial_distortion = tnp.array(radial_distortion, dtype)
138 | self.tangential_distortion = tnp.array(tangential_distortion, dtype)
139 | self.image_size = tnp.array(image_size, dtype)
140 | self.dtype = dtype
141 |
142 | @property
143 | def scale_factor_x(self):
144 | return self.focal_length
145 |
146 | @property
147 | def scale_factor_y(self):
148 | return self.focal_length * self.pixel_aspect_ratio
149 |
150 | @property
151 | def principal_point_x(self):
152 | return self.principal_point[0]
153 |
154 | @property
155 | def principal_point_y(self):
156 | return self.principal_point[1]
157 |
158 | @property
159 | def image_size_y(self):
160 | return self.image_size[1]
161 |
162 | @property
163 | def image_size_x(self):
164 | return self.image_size[0]
165 |
166 | @property
167 | def image_shape(self):
168 | return self.image_size_y, self.image_size_x
169 |
170 | @property
171 | def optical_axis(self):
172 | return self.orientation[2, :]
173 |
174 | def pixel_to_local_rays(self, pixels: tnp.ndarray):
175 | """Returns the local ray directions for the provided pixels."""
176 | y = ((pixels[..., 1] - self.principal_point_y) / self.scale_factor_y)
177 | x = ((pixels[..., 0] - self.principal_point_x - y * self.skew) /
178 | self.scale_factor_x)
179 |
180 | x, y = _radial_and_tangential_undistort(
181 | x,
182 | y,
183 | k1=self.radial_distortion[0],
184 | k2=self.radial_distortion[1],
185 | k3=self.radial_distortion[2],
186 | p1=self.tangential_distortion[0],
187 | p2=self.tangential_distortion[1])
188 |
189 | dirs = tnp.stack([x, y, tnp.ones_like(x)], axis=-1)
190 | return dirs / _norm(dirs)
191 |
192 | def pixels_to_rays(self,
193 | pixels: tnp.ndarray) -> Tuple[tnp.ndarray, tnp.ndarray]:
194 | """Returns the rays for the provided pixels.
195 |
196 | Args:
197 | pixels: [A1, ..., An, 2] tensor or np.array containing 2d pixel positions.
198 |
199 | Returns:
200 | An array containing the normalized ray directions in world coordinates.
201 | """
202 | if pixels.shape[-1] != 2:
203 | raise ValueError('The last dimension of pixels must be 2.')
204 | if pixels.dtype != self.dtype:
205 | raise ValueError(f'pixels dtype ({pixels.dtype!r}) must match camera '
206 | f'dtype ({self.dtype!r})')
207 |
208 | local_rays_dir = self.pixel_to_local_rays(pixels)
209 | rays_dir = tf.linalg.matvec(
210 | self.orientation, local_rays_dir, transpose_a=True)
211 |
212 | # Normalize rays.
213 | rays_dir = rays_dir / _norm(rays_dir)
214 | return rays_dir
215 |
216 | def pixels_to_points(self, pixels: tnp.ndarray, depth: tnp.ndarray):
217 | rays_through_pixels = self.pixels_to_rays(pixels)
218 | cosa = rays_through_pixels @ self.optical_axis
219 | points = (
220 | rays_through_pixels * depth[..., tnp.newaxis] / cosa[..., tnp.newaxis] +
221 | self.position)
222 | return points
223 |
224 | def points_to_local_points(self, points: tnp.ndarray):
225 | translated_points = points - self.position
226 | local_points = (self.orientation @ translated_points.T).T
227 | return local_points
228 |
229 | def get_pixel_centers(self):
230 | """Returns the pixel centers."""
231 | xx, yy = tf.meshgrid(tf.range(self.image_size_x),
232 | tf.range(self.image_size_y))
233 | return tf.cast(tf.stack([xx, yy], axis=-1), self.dtype) + 0.5
234 |
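A usage sketch (added for illustration; the camera parameters are hypothetical): constructing a tiny camera at the origin and casting a ray through every pixel center.

    import numpy as np

    camera = TFCamera(
        orientation=np.eye(3, dtype=np.float32),   # world-to-camera rotation
        position=np.zeros(3, dtype=np.float32),    # camera at the origin
        focal_length=100.0,
        principal_point=np.array([1.5, 2.0], dtype=np.float32),
        image_size=np.array([3, 4], dtype=np.float32))  # (width, height)

    pixels = camera.get_pixel_centers()   # shape (4, 3, 2), offset by 0.5
    rays = camera.pixels_to_rays(pixels)  # unit ray directions, world frame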
--------------------------------------------------------------------------------
/nerfies/training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library for training NeRFs."""
16 | import functools
17 | from typing import Any
18 | from typing import Callable
19 | from typing import Dict
20 |
21 | from absl import logging
22 | from flax import struct
23 | from flax.training import checkpoints
24 | import jax
25 | from jax import lax
26 | from jax import numpy as jnp
27 | from jax import random
28 | from jax import vmap
29 |
30 | from nerfies import model_utils
31 | from nerfies import models
32 | from nerfies import utils
33 |
34 |
35 | @struct.dataclass
36 | class ScalarParams:
37 | learning_rate: float
38 | elastic_loss_weight: float = 0.0
39 | warp_reg_loss_weight: float = 0.0
40 | warp_reg_loss_alpha: float = -2.0
41 | warp_reg_loss_scale: float = 0.001
42 | background_loss_weight: float = 0.0
43 | background_noise_std: float = 0.001
44 |
45 |
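# Illustration (added): train.py and the notebooks in this repository build
# these from schedules evaluated at the current step, e.g.
#   scalar_params = ScalarParams(
#       learning_rate=learning_rate_sched(step),
#       elastic_loss_weight=elastic_loss_weight_sched(step),
#       background_loss_weight=train_config.background_loss_weight)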
46 | def save_checkpoint(path, state, keep=2):
47 | """Save the state to a checkpoint."""
48 | state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
49 | step = state_to_save.optimizer.state.step
50 | checkpoint_path = checkpoints.save_checkpoint(
51 | path, state_to_save, step, keep=keep)
52 | logging.info('Saved checkpoint: step=%d, path=%s', int(step), checkpoint_path)
53 | return checkpoint_path
54 |
55 |
56 | @jax.jit
57 | def nearest_rotation_svd(matrix, eps=1e-6):
58 | """Computes the nearest rotation using SVD."""
59 | # TODO(keunhong): Currently this produces NaNs for some reason.
60 | u, _, vh = jnp.linalg.svd(matrix + eps, compute_uv=True, full_matrices=False)
61 | # Handle the case when there is a flip.
62 | # M will be the identity matrix except when det(UV^T) = -1
63 | # in which case the last diagonal of M will be -1.
64 | det = jnp.linalg.det(u @ vh)
65 | m = jnp.stack([jnp.ones_like(det), jnp.ones_like(det), det], axis=-1)
66 | m = jnp.diag(m)
67 | r = u @ m @ vh
68 | return r
69 |
70 |
71 | def compute_elastic_loss(jacobian, eps=1e-6, loss_type='log_svals'):
72 | """Compute the elastic regularization loss.
73 |
74 |   For the default 'log_svals' type the loss is sum(log(S)^2), where S are
75 |   the singular values of the Jacobian. Singular values that deviate from 1
76 |   (the identity transform) are penalized, since log(1) = 0.
77 |
78 | Args:
79 | jacobian: the Jacobian of the point transformation.
80 | eps: a small value to prevent taking the log of zero.
81 | loss_type: which elastic loss type to use.
82 |
83 | Returns:
84 | The elastic regularization loss.
85 | """
86 | if loss_type == 'log_svals':
87 | svals = jnp.linalg.svd(jacobian, compute_uv=False)
88 | log_svals = jnp.log(jnp.maximum(svals, eps))
89 | sq_residual = jnp.sum(log_svals**2, axis=-1)
90 | elif loss_type == 'svals':
91 | svals = jnp.linalg.svd(jacobian, compute_uv=False)
92 | sq_residual = jnp.sum((svals - 1.0)**2, axis=-1)
93 | elif loss_type == 'jtj':
94 | jtj = jacobian @ jacobian.T
95 | sq_residual = ((jtj - jnp.eye(3)) ** 2).sum() / 4.0
96 | elif loss_type == 'div':
97 | div = utils.jacobian_to_div(jacobian)
98 | sq_residual = div ** 2
99 | elif loss_type == 'det':
100 | det = jnp.linalg.det(jacobian)
101 | sq_residual = (det - 1.0) ** 2
102 | elif loss_type == 'log_det':
103 | det = jnp.linalg.det(jacobian)
104 | sq_residual = jnp.log(jnp.maximum(det, eps)) ** 2
105 | elif loss_type == 'nr':
106 | rot = nearest_rotation_svd(jacobian)
107 | sq_residual = jnp.sum((jacobian - rot) ** 2)
108 | else:
109 | raise NotImplementedError(
110 | f'Unknown elastic loss type {loss_type!r}')
111 | residual = jnp.sqrt(sq_residual)
112 | loss = utils.general_loss_with_squared_residual(
113 | sq_residual, alpha=-2.0, scale=0.03)
114 | return loss, residual
115 |
116 |
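# Illustration (added): for a single 3x3 Jacobian the default 'log_svals'
# loss is zero at the identity and grows with squash or stretch:
#   loss, residual = compute_elastic_loss(jnp.eye(3))        # residual == 0
#   loss, residual = compute_elastic_loss(2.0 * jnp.eye(3))
#   # all singular values are 2, so residual = sqrt(3 * log(2)**2) ~= 1.2
# train_step below applies it per point sample via nested vmaps.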
117 | @functools.partial(jax.jit, static_argnums=0)
118 | def compute_background_loss(
119 | model, state, params, key, points, noise_std, alpha=-2, scale=0.001):
120 | """Compute the background regularization loss."""
121 | metadata = random.choice(key,
122 | jnp.array(model.warp_ids, jnp.uint32),
123 | shape=(points.shape[0], 1))
124 | point_noise = noise_std * random.normal(key, points.shape)
125 | points = points + point_noise
126 |
127 | warp_field = model.create_warp_field(model, num_batch_dims=1)
128 | warp_out = warp_field.apply(
129 | {'params': params['warp_field']},
130 | points, metadata, state.warp_extra, False, False)
131 | warped_points = warp_out['warped_points'][..., :3]
132 | sq_residual = jnp.sum((warped_points - points)**2, axis=-1)
133 | loss = utils.general_loss_with_squared_residual(
134 | sq_residual, alpha=alpha, scale=scale)
135 | return loss
136 |
137 |
138 | def train_step(model: models.NerfModel,
139 |                rng_key: jnp.ndarray,
140 | state,
141 | batch: Dict[str, Any],
142 | scalar_params: ScalarParams,
143 | use_elastic_loss: bool = False,
144 | elastic_reduce_method: str = 'median',
145 | elastic_loss_type: str = 'log_svals',
146 | use_background_loss: bool = False,
147 | use_warp_reg_loss: bool = False):
148 | """One optimization step.
149 |
150 | Args:
151 | model: the model module to evaluate.
152 | rng_key: The random number generator.
153 | state: model_utils.TrainState, state of model and optimizer.
154 | batch: dict. A mini-batch of data for training.
155 | scalar_params: scalar-valued parameters.
156 |     use_elastic_loss: if True use the elastic regularization loss.
157 | elastic_reduce_method: which method to use to reduce the samples for the
158 | elastic loss. 'median' selects the median depth point sample while
159 | 'weight' computes a weighted sum using the density weights.
160 | elastic_loss_type: which method to use for the elastic loss.
161 | use_background_loss: if True use the background regularization loss.
162 | use_warp_reg_loss: if True use the warp regularization loss.
163 |
164 | Returns:
165 | new_state: model_utils.TrainState, new training state.
166 |     stats: dict of statistics such as the losses and PSNR.
167 | """
168 | rng_key, fine_key, coarse_key, reg_key = random.split(rng_key, 4)
169 |
170 | # pylint: disable=unused-argument
171 | def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
172 | rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
173 | stats = {
174 | 'loss/rgb': rgb_loss,
175 | }
176 | loss = rgb_loss
177 | if use_elastic_loss:
178 | elastic_fn = functools.partial(compute_elastic_loss,
179 | loss_type=elastic_loss_type)
180 | v_elastic_fn = jax.jit(vmap(vmap(jax.jit(elastic_fn))))
181 | weights = lax.stop_gradient(model_out['weights'])
182 | jacobian = model_out['warp_jacobian']
183 | # Pick the median point Jacobian.
184 | if elastic_reduce_method == 'median':
185 | depth_indices = model_utils.compute_depth_index(weights)
186 | jacobian = jnp.take_along_axis(
187 | # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
188 | jacobian, depth_indices[..., None, None, None], axis=-3)
189 | # Compute loss using Jacobian.
190 | elastic_loss, elastic_residual = v_elastic_fn(jacobian)
191 | # Multiply weight if weighting by density.
192 | if elastic_reduce_method == 'weight':
193 | elastic_loss = weights * elastic_loss
194 | elastic_loss = elastic_loss.sum(axis=-1).mean()
195 | stats['loss/elastic'] = elastic_loss
196 | stats['residual/elastic'] = jnp.mean(elastic_residual)
197 | loss += scalar_params.elastic_loss_weight * elastic_loss
198 |
199 | if use_warp_reg_loss:
200 | weights = lax.stop_gradient(model_out['weights'])
201 | depth_indices = model_utils.compute_depth_index(weights)
202 | warp_mag = (
203 | (model_out['points'] - model_out['warped_points']) ** 2).sum(axis=-1)
204 | warp_reg_residual = jnp.take_along_axis(
205 | warp_mag, depth_indices[..., None], axis=-1)
206 | warp_reg_loss = utils.general_loss_with_squared_residual(
207 | warp_reg_residual,
208 | alpha=scalar_params.warp_reg_loss_alpha,
209 | scale=scalar_params.warp_reg_loss_scale).mean()
210 | stats['loss/warp_reg'] = warp_reg_loss
211 | stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual))
212 | loss += scalar_params.warp_reg_loss_weight * warp_reg_loss
213 |
214 | if 'warp_jacobian' in model_out:
215 | jacobian = model_out['warp_jacobian']
216 | jacobian_det = jnp.linalg.det(jacobian)
217 | jacobian_div = utils.jacobian_to_div(jacobian)
218 | jacobian_curl = utils.jacobian_to_curl(jacobian)
219 | stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
220 | stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
221 | stats['metric/jacobian_curl'] = jnp.mean(
222 | jnp.linalg.norm(jacobian_curl, axis=-1))
223 |
224 | stats['loss/total'] = loss
225 | stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
226 | return loss, stats
227 |
228 | def _loss_fn(params):
229 | ret = model.apply({'params': params['model']},
230 | batch,
231 | warp_extra=state.warp_extra,
232 | return_points=use_warp_reg_loss,
233 | return_weights=(use_warp_reg_loss or use_elastic_loss),
234 | rngs={
235 | 'fine': fine_key,
236 | 'coarse': coarse_key
237 | })
238 |
239 | losses = {}
240 | stats = {}
241 | if 'fine' in ret:
242 | losses['fine'], stats['fine'] = _compute_loss_and_stats(
243 | params, ret['fine'])
244 | if 'coarse' in ret:
245 | losses['coarse'], stats['coarse'] = _compute_loss_and_stats(
246 | params, ret['coarse'], use_elastic_loss=use_elastic_loss)
247 |
248 | if use_background_loss:
249 | background_loss = compute_background_loss(
250 | model,
251 | state=state,
252 | params=params['model'],
253 | key=reg_key,
254 | points=batch['background_points'],
255 | noise_std=scalar_params.background_noise_std)
256 | background_loss = background_loss.mean()
257 | losses['background'] = (
258 | scalar_params.background_loss_weight * background_loss)
259 | stats['background_loss'] = background_loss
260 |
261 | return sum(losses.values()), stats
262 |
263 | optimizer = state.optimizer
264 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
265 | (_, stats), grad = grad_fn(optimizer.target)
266 | grad = jax.lax.pmean(grad, axis_name='batch')
267 | stats = jax.lax.pmean(stats, axis_name='batch')
268 | new_optimizer = optimizer.apply_gradient(
269 | grad, learning_rate=scalar_params.learning_rate)
270 | new_state = state.replace(optimizer=new_optimizer)
271 | return new_state, stats, rng_key
272 |
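A usage sketch (added; the wiring is abbreviated and the values are hypothetical): train_step is written to run under jax.pmap with axis_name='batch', which is what the lax.pmean calls above assume.

    import functools

    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model,  # bound statically; the remaining args are traced
            use_elastic_loss=True,
            use_background_loss=True),
        axis_name='batch',
        in_axes=(0, 0, 0, None))  # (rng_key, state, batch, scalar_params)

    # keys: one PRNG key per device; state: replicated TrainState;
    # batch: sharded mini-batch; scalar_params: broadcast to all devices.
    state, stats, keys = p_train_step(keys, state, batch, scalar_params)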
--------------------------------------------------------------------------------
/nerfies/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom type annotations."""
16 | import pathlib
17 | from typing import Any, Callable, Tuple, Text, Union
18 |
19 | PRNGKey = Any
20 | Shape = Tuple[int, ...]
21 | Dtype = Any # this could be a real type?
22 | Array = Any
23 |
24 | Activation = Callable[[Array], Array]
25 | Initializer = Callable[[PRNGKey, Shape, Dtype], Array]
26 |
27 | PathType = Union[Text, pathlib.PurePosixPath]
28 |
--------------------------------------------------------------------------------
/nerfies/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Visualization utilities."""
16 | import functools
17 |
18 | from matplotlib import cm
19 | from matplotlib.colors import LinearSegmentedColormap
20 | import numpy as np
21 |
22 |
23 | # Turbo colors by Anton Mikhailov.
24 | # Copyright 2019 Google LLC.
25 | # https://gist.github.com/mikhailov-work/ee72ba4191942acecc03fe6da94fc73f
26 | _TURBO_COLORS = np.array(
27 | [[0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149],
28 | [0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844],
29 | [0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314],
30 | [0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558],
31 | [0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578],
32 | [0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373],
33 | [0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942],
34 | [0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286],
35 | [0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406],
36 | [0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300],
37 | [0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968],
38 | [0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412],
39 | [0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631],
40 | [0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624],
41 | [0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393],
42 | [0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936],
43 | [0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254],
44 | [0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347],
45 | [0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214],
46 | [0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857],
47 | [0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275],
48 | [0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461],
49 | [0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303],
50 | [0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773],
51 | [0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896],
52 | [0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697],
53 | [0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202],
54 | [0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436],
55 | [0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423],
56 | [0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190],
57 | [0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761],
58 | [0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161],
59 | [0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416],
60 | [0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550],
61 | [0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590],
62 | [0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559],
63 | [0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484],
64 | [0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389],
65 | [0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299],
66 | [0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240],
67 | [0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237],
68 | [0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316],
69 | [0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500],
70 | [0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651],
71 | [0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627],
72 | [0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448],
73 | [0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137],
74 | [0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713],
75 | [0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199],
76 | [0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614],
77 | [0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981],
78 | [0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321],
79 | [0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654],
80 | [0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002],
81 | [0.36581, 0.98702, 0.43688], [0.38127, 0.98909, 0.42386],
82 | [0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826],
83 | [0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345],
84 | [0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963],
85 | [0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701],
86 | [0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581],
87 | [0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623],
88 | [0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849],
89 | [0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280],
90 | [0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937],
91 | [0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835],
92 | [0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960],
93 | [0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294],
94 | [0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815],
95 | [0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504],
96 | [0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343],
97 | [0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310],
98 | [0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386],
99 | [0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552],
100 | [0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788],
101 | [0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074],
102 | [0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391],
103 | [0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719],
104 | [0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038],
105 | [0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328],
106 | [0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570],
107 | [0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744],
108 | [0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831],
109 | [0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811],
110 | [0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663],
111 | [0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369],
112 | [0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918],
113 | [0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358],
114 | [0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706],
115 | [0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971],
116 | [0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165],
117 | [0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297],
118 | [0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376],
119 | [0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412],
120 | [0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417],
121 | [0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398],
122 | [0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367],
123 | [0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332],
124 | [0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305],
125 | [0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294],
126 | [0.96555, 0.42241, 0.09798], [0.96187, 0.41093, 0.09310],
127 | [0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362],
128 | [0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461],
129 | [0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616],
130 | [0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837],
131 | [0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134],
132 | [0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516],
133 | [0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993],
134 | [0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521],
135 | [0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082],
136 | [0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677],
137 | [0.84662, 0.21407, 0.02487], [0.83926, 0.20654, 0.02305],
138 | [0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966],
139 | [0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660],
140 | [0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387],
141 | [0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148],
142 | [0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942],
143 | [0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769],
144 | [0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629],
145 | [0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522],
146 | [0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449],
147 | [0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408],
148 | [0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401],
149 | [0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427],
150 | [0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486],
151 | [0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579],
152 | [0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705],
153 | [0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863],
154 | [0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055]])
155 |
156 | _colormap_cache = {}
157 |
158 |
159 | def _build_colormap(name, num_bins=256):
160 | base = cm.get_cmap(name)
161 | color_list = base(np.linspace(0, 1, num_bins))
162 | cmap_name = base.name + str(num_bins)
163 | colormap = LinearSegmentedColormap.from_list(cmap_name, color_list, num_bins)
164 | colormap = colormap(np.linspace(0, 1, num_bins))[:, :3]
165 | return colormap
166 |
167 |
168 | @functools.lru_cache(maxsize=32)
169 | def get_colormap(name, num_bins=256):
170 | """Lazily initializes and returns a colormap."""
171 | if name == 'turbo':
172 | return _TURBO_COLORS
173 |
174 | return _build_colormap(name, num_bins)
175 |
176 |
177 | def interpolate_colormap(values, colormap):
178 | """Interpolates the colormap given values between 0.0 and 1.0."""
179 | a = np.floor(values * 255)
180 | b = (a + 1).clip(max=255)
181 | f = values * 255.0 - a
182 | a = a.astype(np.uint16).clip(0, 255)
183 | b = b.astype(np.uint16).clip(0, 255)
184 | return colormap[a] + (colormap[b] - colormap[a]) * f[..., np.newaxis]
185 |
186 |
187 | def scale_values(values, vmin, vmax, eps=1e-6):
188 | return (values - vmin) / max(vmax - vmin, eps)
189 |
190 |
191 | def colorize(
192 | array, cmin=None, cmax=None, cmap='magma', eps=1e-6, invert=False):
193 | """Applies a colormap to an array.
194 |
195 | Args:
196 | array: the array to apply a colormap to.
197 | cmin: the minimum value of the colormap. If None will take the min.
198 | cmax: the maximum value of the colormap. If None will take the max.
199 | cmap: the color mapping to use.
200 | eps: a small value to prevent divide by zero.
201 | invert: if True will invert the colormap.
202 |
203 | Returns:
204 | a color mapped version of array.
205 | """
206 | array = np.asarray(array)
207 |
208 | if cmin is None:
209 | cmin = array.min()
210 | if cmax is None:
211 | cmax = array.max()
212 |
213 | x = scale_values(array, cmin, cmax, eps)
214 | colormap = get_colormap(cmap)
215 | colorized = interpolate_colormap(1.0 - x if invert else x, colormap)
216 | colorized[x > 1.0] = 0.0 if invert else 1.0
217 | colorized[x < 0.0] = 1.0 if invert else 0.0
218 |
219 | return colorized
220 |
221 |
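# Illustration (added): the render notebook at the end of this repository
# uses colorize() to visualize depth maps:
#   depth_viz = colorize(depth.squeeze(), cmin=near, cmax=far, invert=True)
# which returns an (H, W, 3) float array in [0, 1] under the default
# 'magma' colormap.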
222 | def colorize_binary_logits(array, cmap=None):
223 | """Colorizes binary logits as a segmentation map."""
224 | num_classes = array.shape[-1]
225 | if cmap is None:
226 | if num_classes <= 8:
227 | cmap = 'Set3'
228 | elif num_classes <= 10:
229 | cmap = 'tab10'
230 | elif num_classes <= 20:
231 | cmap = 'tab20'
232 | else:
233 | cmap = 'gist_rainbow'
234 |
235 | colormap = get_colormap(cmap, num_classes)
236 | indices = np.argmax(array, axis=-1)
237 | return np.take(colormap, indices, axis=0)
238 |
--------------------------------------------------------------------------------
/nerfies/warping.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Warp fields."""
16 | from typing import Any, Iterable, Optional, Dict
17 |
18 | from flax import linen as nn
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from nerfies import glo
23 | from nerfies import model_utils
24 | from nerfies import modules
25 | from nerfies import rigid_body as rigid
26 | from nerfies import types
27 |
28 |
29 | def create_warp_field(
30 | field_type: str,
31 | num_freqs: int,
32 | num_embeddings: int,
33 | num_features: int,
34 | num_batch_dims: int,
35 | **kwargs):
36 | """Factory function for warp fields."""
37 | kwargs = {**kwargs}
38 | if field_type == 'translation':
39 | warp_field_cls = TranslationField
40 | elif field_type == 'se3':
41 | warp_field_cls = SE3Field
42 | else:
43 | raise ValueError(f'Unknown warp field type: {field_type!r}')
44 |
45 | if num_batch_dims > 0:
46 | v_warp_field_cls = model_utils.vmap_module(
47 | warp_field_cls,
48 | num_batch_dims=num_batch_dims,
49 | # (points, metadata, extras, return_jacobian,
50 | # metadata_encoded).
51 | in_axes=(0, 0, None, None, None))
52 | else:
53 | v_warp_field_cls = warp_field_cls
54 |
55 | return v_warp_field_cls(
56 | num_freqs=num_freqs,
57 | num_embeddings=num_embeddings,
58 | num_embedding_features=num_features,
59 | **kwargs)
60 |
61 |
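# Illustration (added; hyperparameter values are hypothetical): models.py
# uses this factory to build a warp field vmapped over the sample dimension:
#   warp_field = create_warp_field(
#       field_type='se3',
#       num_freqs=8,
#       num_embeddings=num_warp_embeddings,
#       num_features=8,
#       num_batch_dims=1)
# The returned flax module is then called with (points, metadata, extra,
# return_jacobian, metadata_encoded), matching the in_axes comment above.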
62 | class TranslationField(nn.Module):
63 | """Network that predicts warps as a translation field.
64 |
65 | References:
66 | https://en.wikipedia.org/wiki/Vector_potential
67 | https://en.wikipedia.org/wiki/Helmholtz_decomposition
68 |
69 | Attributes:
70 | points_encoder: the positional encoder for the points.
71 | metadata_encoder: an encoder for metadata.
72 | alpha: the alpha for the positional encoding.
73 | skips: the index of the layers with skip connections.
74 | depth: the depth of the network excluding the output layer.
75 | hidden_channels: the width of the network hidden layers.
76 | activation: the activation for each layer.
77 | metadata_encoded: whether the metadata parameter is pre-encoded or not.
78 | hidden_initializer: the initializer for the hidden layers.
79 | output_initializer: the initializer for the last output layer.
80 | """
81 | num_freqs: int
82 | num_embeddings: int
83 | num_embedding_features: int
84 | min_freq_log2: int = 0
85 | max_freq_log2: Optional[int] = None
86 | use_identity_map: bool = True
87 |
88 | metadata_encoder_type: str = 'glo'
89 | metadata_encoder_num_freqs: int = 1
90 |
91 | skips: Iterable[int] = (4,)
92 | depth: int = 6
93 | hidden_channels: int = 128
94 | activation: types.Activation = nn.relu
95 | hidden_init: types.Initializer = nn.initializers.xavier_uniform()
96 | output_init: types.Initializer = nn.initializers.uniform(scale=1e-4)
97 |
98 | def setup(self):
99 | self.points_encoder = modules.AnnealedSinusoidalEncoder(
100 | num_freqs=self.num_freqs,
101 | min_freq_log2=self.min_freq_log2,
102 | max_freq_log2=self.max_freq_log2,
103 | use_identity=self.use_identity_map)
104 |
105 | if self.metadata_encoder_type == 'glo':
106 | self.metadata_encoder = glo.GloEncoder(
107 | num_embeddings=self.num_embeddings,
108 | features=self.num_embedding_features)
109 | elif self.metadata_encoder_type == 'time':
110 | self.metadata_encoder = modules.TimeEncoder(
111 | num_freqs=self.metadata_encoder_num_freqs,
112 | features=self.num_embedding_features)
113 | elif self.metadata_encoder_type == 'blend':
114 | self.glo_encoder = glo.GloEncoder(
115 | num_embeddings=self.num_embeddings,
116 | features=self.num_embedding_features)
117 | self.time_encoder = modules.TimeEncoder(
118 | num_freqs=self.metadata_encoder_num_freqs,
119 | features=self.num_embedding_features)
120 | else:
121 | raise ValueError(
122 | f'Unknown metadata encoder type {self.metadata_encoder_type}')
123 |
124 | # Note that this must be done this way instead of using mutable list
125 | # operations.
126 | # See https://github.com/google/flax/issues/524.
127 | # pylint: disable=g-complex-comprehension
128 | output_dims = 3
129 | self.mlp = modules.MLP(
130 | width=self.hidden_channels,
131 | depth=self.depth,
132 | skips=self.skips,
133 | hidden_init=self.hidden_init,
134 | output_init=self.output_init,
135 | output_channels=output_dims)
136 |
137 | def encode_metadata(self,
138 | metadata: jnp.ndarray,
139 | time_alpha: Optional[float] = None):
140 | if self.metadata_encoder_type == 'time':
141 | metadata_embed = self.metadata_encoder(metadata, time_alpha)
142 | elif self.metadata_encoder_type == 'blend':
143 | glo_embed = self.glo_encoder(metadata)
144 | time_embed = self.time_encoder(metadata)
145 | metadata_embed = ((1.0 - time_alpha) * glo_embed +
146 | time_alpha * time_embed)
147 | elif self.metadata_encoder_type == 'glo':
148 | metadata_embed = self.metadata_encoder(metadata)
149 | else:
150 | raise RuntimeError(
151 | f'Unknown metadata encoder type {self.metadata_encoder_type}')
152 |
153 | return metadata_embed
154 |
155 | def warp(self,
156 | points: jnp.ndarray,
157 | metadata_embed: jnp.ndarray,
158 | extra: Dict[str, Any]):
159 | points_embed = self.points_encoder(points, alpha=extra.get('alpha'))
160 | inputs = jnp.concatenate([points_embed, metadata_embed], axis=-1)
161 | translation = self.mlp(inputs)
162 | warped_points = points + translation
163 |
164 | return warped_points
165 |
166 | def __call__(self,
167 | points: jnp.ndarray,
168 | metadata: jnp.ndarray,
169 | extra: Dict[str, Any],
170 | return_jacobian: bool = False,
171 | metadata_encoded: bool = False):
172 | """Warp the given points using a warp field.
173 |
174 | Args:
175 | points: the points to warp.
176 | metadata: metadata indices if metadata_encoded is False else pre-encoded
177 | metadata.
178 | extra: extra parameters used in the warp field e.g., the warp alpha.
179 | return_jacobian: if True compute and return the Jacobian of the warp.
180 | metadata_encoded: if True assumes the metadata is already encoded.
181 |
182 | Returns:
183 | The warped points and the Jacobian of the warp if `return_jacobian` is
184 | True.
185 | """
186 | if metadata_encoded:
187 | metadata_embed = metadata
188 | else:
189 | metadata_embed = self.encode_metadata(metadata, extra.get('time_alpha'))
190 |
191 | out = {
192 | 'warped_points': self.warp(points, metadata_embed, extra)
193 | }
194 |
195 | if return_jacobian:
196 | jac_fn = jax.jacfwd(lambda *x: self.warp(*x)[..., :3], argnums=0)
197 | out['jacobian'] = jac_fn(points, metadata_embed, extra)
198 |
199 | return out
200 |
201 |
202 | class SE3Field(nn.Module):
203 | """Network that predicts warps as an SE(3) field.
204 |
205 | Attributes:
206 | points_encoder: the positional encoder for the points.
207 | metadata_encoder: an encoder for metadata.
208 | alpha: the alpha for the positional encoding.
209 | skips: the index of the layers with skip connections.
210 | depth: the depth of the network excluding the logit layer.
211 | hidden_channels: the width of the network hidden layers.
212 | activation: the activation for each layer.
213 | metadata_encoded: whether the metadata parameter is pre-encoded or not.
214 | hidden_initializer: the initializer for the hidden layers.
215 | output_initializer: the initializer for the last logit layer.
216 | """
217 | num_freqs: int
218 | num_embeddings: int
219 | num_embedding_features: int
220 | min_freq_log2: int = 0
221 | max_freq_log2: Optional[int] = None
222 | use_identity_map: bool = True
223 |
224 | activation: types.Activation = nn.relu
225 | skips: Iterable[int] = (4,)
226 | trunk_depth: int = 6
227 | trunk_width: int = 128
228 | rotation_depth: int = 0
229 | rotation_width: int = 128
230 | pivot_depth: int = 0
231 | pivot_width: int = 128
232 | translation_depth: int = 0
233 | translation_width: int = 128
234 | metadata_encoder_type: str = 'glo'
235 | metadata_encoder_num_freqs: int = 1
236 |
237 | default_init: types.Initializer = nn.initializers.xavier_uniform()
238 | rotation_init: types.Initializer = nn.initializers.uniform(scale=1e-4)
239 | pivot_init: types.Initializer = nn.initializers.uniform(scale=1e-4)
240 | translation_init: types.Initializer = nn.initializers.uniform(scale=1e-4)
241 |
242 | use_pivot: bool = False
243 | use_translation: bool = False
244 |
245 | def setup(self):
246 | self.points_encoder = modules.AnnealedSinusoidalEncoder(
247 | num_freqs=self.num_freqs,
248 | min_freq_log2=self.min_freq_log2,
249 | max_freq_log2=self.max_freq_log2,
250 | use_identity=self.use_identity_map)
251 |
252 | if self.metadata_encoder_type == 'glo':
253 | self.metadata_encoder = glo.GloEncoder(
254 | num_embeddings=self.num_embeddings,
255 | features=self.num_embedding_features)
256 | elif self.metadata_encoder_type == 'time':
257 | self.metadata_encoder = modules.TimeEncoder(
258 | num_freqs=self.metadata_encoder_num_freqs,
259 | features=self.num_embedding_features)
260 | else:
261 | raise ValueError(
262 | f'Unknown metadata encoder type {self.metadata_encoder_type}')
263 |
264 | self.trunk = modules.MLP(
265 | depth=self.trunk_depth,
266 | width=self.trunk_width,
267 | hidden_activation=self.activation,
268 | hidden_init=self.default_init,
269 | skips=self.skips)
270 |
271 | branches = {
272 | 'w':
273 | modules.MLP(
274 | depth=self.rotation_depth,
275 | width=self.rotation_width,
276 | hidden_activation=self.activation,
277 | hidden_init=self.default_init,
278 | output_init=self.rotation_init,
279 | output_channels=3),
280 | 'v':
281 | modules.MLP(
282 | depth=self.pivot_depth,
283 | width=self.pivot_width,
284 | hidden_activation=self.activation,
285 | hidden_init=self.default_init,
286 | output_init=self.pivot_init,
287 | output_channels=3),
288 | }
289 | if self.use_pivot:
290 | branches['p'] = modules.MLP(
291 | depth=self.pivot_depth,
292 | width=self.pivot_width,
293 | hidden_activation=self.activation,
294 | hidden_init=self.default_init,
295 | output_init=self.pivot_init,
296 | output_channels=3)
297 | if self.use_translation:
298 | branches['t'] = modules.MLP(
299 | depth=self.translation_depth,
300 | width=self.translation_width,
301 | hidden_activation=self.activation,
302 | hidden_init=self.default_init,
303 | output_init=self.translation_init,
304 | output_channels=3)
305 | # Note that this must be done this way instead of using mutable operations.
306 | # See https://github.com/google/flax/issues/524.
307 | self.branches = branches
308 |
309 | def encode_metadata(self,
310 | metadata: jnp.ndarray,
311 | time_alpha: Optional[float] = None):
312 | if self.metadata_encoder_type == 'time':
313 | metadata_embed = self.metadata_encoder(metadata, time_alpha)
314 | elif self.metadata_encoder_type == 'glo':
315 | metadata_embed = self.metadata_encoder(metadata)
316 | else:
317 | raise RuntimeError(
318 | f'Unknown metadata encoder type {self.metadata_encoder_type}')
319 |
320 | return metadata_embed
321 |
322 | def warp(self,
323 | points: jnp.ndarray,
324 | metadata_embed: jnp.ndarray,
325 | extra: Dict[str, Any]):
326 | points_embed = self.points_encoder(points, alpha=extra.get('alpha'))
327 | inputs = jnp.concatenate([points_embed, metadata_embed], axis=-1)
328 | trunk_output = self.trunk(inputs)
329 |
330 | w = self.branches['w'](trunk_output)
331 | v = self.branches['v'](trunk_output)
332 | theta = jnp.linalg.norm(w, axis=-1)
333 | w = w / theta[..., jnp.newaxis]
334 | v = v / theta[..., jnp.newaxis]
335 | screw_axis = jnp.concatenate([w, v], axis=-1)
336 | transform = rigid.exp_se3(screw_axis, theta)
337 |
338 | warped_points = points
339 | if self.use_pivot:
340 | pivot = self.branches['p'](trunk_output)
341 | warped_points = warped_points + pivot
342 |
343 | warped_points = rigid.from_homogenous(
344 | transform @ rigid.to_homogenous(warped_points))
345 |
346 | if self.use_pivot:
347 | warped_points = warped_points - pivot
348 |
349 | if self.use_translation:
350 | t = self.branches['t'](trunk_output)
351 | warped_points = warped_points + t
352 |
353 | return warped_points
354 |
355 | def __call__(self,
356 | points: jnp.ndarray,
357 | metadata: jnp.ndarray,
358 | extra: Dict[str, Any],
359 | return_jacobian: bool = False,
360 | metadata_encoded: bool = False):
361 | """Warp the given points using a warp field.
362 |
363 | Args:
364 | points: the points to warp.
365 | metadata: metadata indices if metadata_encoded is False else pre-encoded
366 | metadata.
367 | extra: A dictionary containing
368 | 'alpha': the alpha value for the positional encoding.
369 | 'time_alpha': the alpha value for the time positional encoding
370 | (if applicable).
371 | return_jacobian: if True compute and return the Jacobian of the warp.
372 | metadata_encoded: if True assumes the metadata is already encoded.
373 |
374 | Returns:
375 | The warped points and the Jacobian of the warp if `return_jacobian` is
376 | True.
377 | """
378 | if metadata_encoded:
379 | metadata_embed = metadata
380 | else:
381 | metadata_embed = self.encode_metadata(metadata, extra.get('time_alpha'))
382 |
383 | out = {'warped_points': self.warp(points, metadata_embed, extra)}
384 |
385 | if return_jacobian:
386 | jac_fn = jax.jacfwd(self.warp, argnums=0)
387 | out['jacobian'] = jac_fn(points, metadata_embed, extra)
388 |
389 | return out
390 |
--------------------------------------------------------------------------------
/notebooks/Nerfies_Render_Video.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Nerfies Render Video v2.ipynb",
7 | "private_outputs": true,
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "TPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "QMMWf9AQcdlp"
22 | },
23 | "source": [
24 | "# Render a Nerfie video!\n",
25 | "\n",
26 | "**Author**: [Keunhong Park](https://keunhong.com)\n",
27 | "\n",
28 | "[[Project Page](https://nerfies.github.io)]\n",
29 | "[[Paper](https://storage.googleapis.com/nerfies-public/videos/nerfies_paper.pdf)]\n",
30 | "[[Video](https://www.youtube.com/watch?v=MrKrnHhk8IA)]\n",
31 | "[[GitHub](https://github.com/google/nerfies)]\n",
32 | "\n",
33 | "This notebook renders a figure-8 orbit video using the test cameras generated in the capture processing notebook.\n",
34 | "\n",
35 | "You can also load your own custom cameras by modifying the code slightly.\n",
36 | "\n",
37 | "### Instructions\n",
38 | "\n",
39 | "1. Convert a video into our dataset format using the [capture processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).\n",
40 |         "2. Train a Nerfie model using the [training notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Training.ipynb).\n",
41 | "3. Run this notebook!\n",
42 | "\n",
43 | "\n",
44 | "### Notes\n",
45 | " * Please report issues on the [GitHub issue tracker](https://github.com/google/nerfies/issues)."
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {
51 | "id": "gHqkIo4hcGou"
52 | },
53 | "source": [
54 | "## Environment Setup"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "metadata": {
60 | "id": "-GwSf5FfcH4b"
61 | },
62 | "source": [
63 | "!pip install flax immutabledict mediapy\n",
64 | "!pip install git+https://github.com/google/nerfies@v2"
65 | ],
66 | "execution_count": null,
67 | "outputs": []
68 | },
69 | {
70 | "cell_type": "code",
71 | "metadata": {
72 | "id": "-3T2lBKBcIGP",
73 | "cellView": "form"
74 | },
75 | "source": [
76 | "# @title Configure notebook runtime\n",
77 | "# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`. \n",
78 | "# @markdown You will have to use a smaller batch size on GPU.\n",
79 | "\n",
80 | "runtime_type = 'tpu' # @param ['gpu', 'tpu']\n",
81 | "if runtime_type == 'tpu':\n",
82 | " import jax.tools.colab_tpu\n",
83 | " jax.tools.colab_tpu.setup_tpu()\n",
84 | "\n",
85 | "print('Detected Devices:', jax.devices())"
86 | ],
87 | "execution_count": null,
88 | "outputs": []
89 | },
90 | {
91 | "cell_type": "code",
92 | "metadata": {
93 | "id": "82kU-W1NcNTW",
94 | "cellView": "form"
95 | },
96 | "source": [
97 | "# @title Mount Google Drive\n",
98 | "# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.\n",
99 | "\n",
100 | "from google.colab import drive\n",
101 | "drive.mount('/content/gdrive')"
102 | ],
103 | "execution_count": null,
104 | "outputs": []
105 | },
106 | {
107 | "cell_type": "code",
108 | "metadata": {
109 | "id": "YIDbV769cPn1",
110 | "cellView": "form"
111 | },
112 | "source": [
113 | "# @title Define imports and utility functions.\n",
114 | "\n",
115 | "import jax\n",
116 | "from jax.config import config as jax_config\n",
117 | "import jax.numpy as jnp\n",
118 | "from jax import grad, jit, vmap\n",
119 | "from jax import random\n",
120 | "\n",
121 | "import flax\n",
122 | "import flax.linen as nn\n",
123 | "from flax import jax_utils\n",
124 | "from flax import optim\n",
125 | "from flax.metrics import tensorboard\n",
126 | "from flax.training import checkpoints\n",
127 | "\n",
128 | "from absl import logging\n",
129 | "from io import BytesIO\n",
130 | "import random as pyrandom\n",
131 | "import numpy as np\n",
132 | "import PIL\n",
133 | "import IPython\n",
134 | "import tempfile\n",
135 | "import imageio\n",
136 | "import mediapy\n",
137 | "from IPython.display import display, HTML\n",
138 | "from base64 import b64encode\n",
139 | "\n",
140 | "\n",
141 | "# Monkey patch logging.\n",
142 | "def myprint(msg, *args, **kwargs):\n",
143 | " print(msg % args)\n",
144 | "\n",
145 | "logging.info = myprint \n",
146 | "logging.warn = myprint\n",
147 | "logging.error = myprint"
148 | ],
149 | "execution_count": null,
150 | "outputs": []
151 | },
152 | {
153 | "cell_type": "code",
154 | "metadata": {
155 | "id": "2QYJ7dyMcw2f"
156 | },
157 | "source": [
158 | "# @title Model and dataset configuration\n",
159 | "# @markdown Change the directories to where you saved your capture and experiment.\n",
160 | "\n",
161 | "\n",
162 | "from pathlib import Path\n",
163 | "from pprint import pprint\n",
164 | "import gin\n",
165 | "from IPython.display import display, Markdown\n",
166 | "\n",
167 | "from nerfies import configs\n",
168 | "\n",
169 | "\n",
170 | "# @markdown The working directory where the trained model is.\n",
171 | "train_dir = '/content/gdrive/My Drive/nerfies/experiments/capture1/exp1' # @param {type: \"string\"}\n",
172 | "# @markdown The directory to the dataset capture.\n",
173 | "data_dir = '/content/gdrive/My Drive/nerfies/captures/capture1' # @param {type: \"string\"}\n",
174 | "\n",
175 | "checkpoint_dir = Path(train_dir, 'checkpoints')\n",
176 | "checkpoint_dir.mkdir(exist_ok=True, parents=True)\n",
177 | "\n",
178 | "config_path = Path(train_dir, 'config.gin')\n",
179 | "with open(config_path, 'r') as f:\n",
180 | " logging.info('Loading config from %s', config_path)\n",
181 | " config_str = f.read()\n",
182 | "gin.parse_config(config_str)\n",
183 | "\n",
184 | "config_path = Path(train_dir, 'config.gin')\n",
185 | "with open(config_path, 'w') as f:\n",
186 | " logging.info('Saving config to %s', config_path)\n",
187 | " f.write(config_str)\n",
188 | "\n",
189 | "exp_config = configs.ExperimentConfig()\n",
190 | "model_config = configs.ModelConfig()\n",
191 | "train_config = configs.TrainConfig()\n",
192 | "eval_config = configs.EvalConfig()\n",
193 | "\n",
194 | "display(Markdown(\n",
195 | " gin.config.markdown(gin.operative_config_str())))"
196 | ],
197 | "execution_count": null,
198 | "outputs": []
199 | },
200 | {
201 | "cell_type": "code",
202 | "metadata": {
203 | "id": "6T7LQ5QSmu4o",
204 | "cellView": "form"
205 | },
206 | "source": [
207 | "# @title Create datasource and show an example.\n",
208 | "\n",
209 | "from nerfies import datasets\n",
210 | "from nerfies import image_utils\n",
211 | "\n",
212 | "datasource = datasets.from_config(\n",
213 | " exp_config.datasource_spec,\n",
214 | " image_scale=exp_config.image_scale,\n",
215 | " use_appearance_id=model_config.use_appearance_metadata,\n",
216 | " use_camera_id=model_config.use_camera_metadata,\n",
217 | " use_warp_id=model_config.use_warp,\n",
218 | " random_seed=exp_config.random_seed)\n",
219 | "\n",
220 | "mediapy.show_image(datasource.load_rgb(datasource.train_ids[0]))"
221 | ],
222 | "execution_count": null,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "metadata": {
228 | "id": "jEO3xcxpnCqx",
229 | "cellView": "form"
230 | },
231 | "source": [
232 | "# @title Initialize model\n",
233 | "# @markdown Defines the model and initializes its parameters.\n",
234 | "\n",
235 | "from flax.training import checkpoints\n",
236 | "from nerfies import models\n",
237 | "from nerfies import model_utils\n",
238 | "from nerfies import schedules\n",
239 | "from nerfies import training\n",
240 | "\n",
241 | "\n",
242 | "rng = random.PRNGKey(exp_config.random_seed)\n",
243 | "np.random.seed(exp_config.random_seed + jax.process_index())\n",
244 | "devices = jax.devices()\n",
245 | "\n",
246 | "learning_rate_sched = schedules.from_config(train_config.lr_schedule)\n",
247 | "warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)\n",
248 | "elastic_loss_weight_sched = schedules.from_config(\n",
249 | " train_config.elastic_loss_weight_schedule)\n",
250 | "\n",
251 | "rng, key = random.split(rng)\n",
252 | "params = {}\n",
253 | "model, params['model'] = models.construct_nerf(\n",
254 | " key,\n",
255 | " model_config,\n",
256 | " batch_size=train_config.batch_size,\n",
257 | " appearance_ids=datasource.appearance_ids,\n",
258 | " camera_ids=datasource.camera_ids,\n",
259 | " warp_ids=datasource.warp_ids,\n",
260 | " near=datasource.near,\n",
261 | " far=datasource.far,\n",
262 | " use_warp_jacobian=train_config.use_elastic_loss,\n",
263 | " use_weights=train_config.use_elastic_loss)\n",
264 | "\n",
265 | "optimizer_def = optim.Adam(learning_rate_sched(0))\n",
266 | "optimizer = optimizer_def.create(params)\n",
267 | "state = model_utils.TrainState(\n",
268 | " optimizer=optimizer,\n",
269 | " warp_alpha=warp_alpha_sched(0))\n",
270 | "scalar_params = training.ScalarParams(\n",
271 | " learning_rate=learning_rate_sched(0),\n",
272 | " elastic_loss_weight=elastic_loss_weight_sched(0),\n",
273 | " background_loss_weight=train_config.background_loss_weight)\n",
274 | "logging.info('Restoring checkpoint from %s', checkpoint_dir)\n",
275 | "state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n",
276 | "step = state.optimizer.state.step + 1\n",
277 | "state = jax_utils.replicate(state, devices=devices)\n",
278 | "del params"
279 | ],
280 | "execution_count": null,
281 | "outputs": []
282 | },
283 | {
284 | "cell_type": "code",
285 | "metadata": {
286 | "id": "2KYhbpsklwAy",
287 | "cellView": "form"
288 | },
289 | "source": [
290 | "# @title Define pmapped render function.\n",
291 | "\n",
292 | "import functools\n",
293 | "from nerfies import evaluation\n",
294 | "\n",
295 | "devices = jax.devices()\n",
296 | "\n",
297 | "\n",
298 | "def _model_fn(key_0, key_1, params, rays_dict, warp_extra):\n",
299 | " out = model.apply({'params': params},\n",
300 | " rays_dict,\n",
301 | " warp_extra=warp_extra,\n",
302 | " rngs={\n",
303 | " 'coarse': key_0,\n",
304 | " 'fine': key_1\n",
305 | " },\n",
306 | " mutable=False)\n",
307 | " return jax.lax.all_gather(out, axis_name='batch')\n",
308 | "\n",
309 | "pmodel_fn = jax.pmap(\n",
310 | " # Note rng_keys are useless in eval mode since there's no randomness.\n",
311 | " _model_fn,\n",
312 | " in_axes=(0, 0, 0, 0, 0), # Only distribute the data input.\n",
313 | " devices=devices,\n",
314 | " donate_argnums=(3,), # Donate the 'rays' argument.\n",
315 | " axis_name='batch',\n",
316 | ")\n",
317 | "\n",
318 | "render_fn = functools.partial(evaluation.render_image,\n",
319 | " model_fn=pmodel_fn,\n",
320 | " device_count=len(devices),\n",
321 | " chunk=eval_config.chunk)"
322 | ],
323 | "execution_count": null,
324 | "outputs": []
325 | },
326 | {
327 | "cell_type": "code",
328 | "metadata": {
329 | "id": "73Fq0kNcmAra"
330 | },
331 | "source": [
332 | "# @title Load cameras.\n",
333 | "\n",
334 | "from nerfies import utils\n",
335 | "\n",
336 | "camera_path = 'camera-paths/orbit-mild' # @param {type: 'string'}\n",
337 | "\n",
338 | "camera_dir = Path(data_dir, camera_path)\n",
339 | "print(f'Loading cameras from {camera_dir}')\n",
340 | "test_camera_paths = datasource.glob_cameras(camera_dir)\n",
341 | "test_cameras = utils.parallel_map(datasource.load_camera, test_camera_paths, show_pbar=True)"
342 | ],
343 | "execution_count": null,
344 | "outputs": []
345 | },
346 | {
347 | "cell_type": "code",
348 | "metadata": {
349 | "id": "aP9LjiAZmoRc",
350 | "cellView": "form"
351 | },
352 | "source": [
353 | "# @title Render video frames.\n",
354 | "from nerfies import visualization as viz\n",
355 | "\n",
356 | "\n",
357 | "rng = rng + jax.host_id() # Make random seed separate across hosts.\n",
358 | "keys = random.split(rng, len(devices))\n",
359 | "\n",
360 | "results = []\n",
361 | "for i in range(len(test_cameras)):\n",
362 | " print(f'Rendering frame {i+1}/{len(test_cameras)}')\n",
363 | " camera = test_cameras[i]\n",
364 | " batch = datasets.camera_to_rays(camera)\n",
365 | " batch['metadata'] = {\n",
366 | " 'appearance': jnp.zeros_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32),\n",
367 | " 'warp': jnp.zeros_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32),\n",
368 | " }\n",
369 | "\n",
370 | " render = render_fn(state, batch, rng=rng)\n",
371 | " rgb = np.array(render['rgb'])\n",
372 | " depth_med = np.array(render['med_depth'])\n",
373 | " results.append((rgb, depth_med))\n",
374 | " depth_viz = viz.colorize(depth_med.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)\n",
375 | " mediapy.show_images([rgb, depth_viz])"
376 | ],
377 | "execution_count": null,
378 | "outputs": []
379 | },
380 | {
381 | "cell_type": "code",
382 | "metadata": {
383 | "id": "_5hHR9XVm8Ix"
384 | },
385 | "source": [
386 | "# @title Show rendered video.\n",
387 | "\n",
388 | "fps = 30 # @param {type:'number'}\n",
389 | "\n",
390 | "frames = []\n",
391 | "for rgb, depth in results:\n",
392 | " depth_viz = viz.colorize(depth.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)\n",
393 | " frame = np.concatenate([rgb, depth_viz], axis=1)\n",
394 | " frames.append(image_utils.image_to_uint8(frame))\n",
395 | "\n",
396 | "mediapy.show_video(frames, fps=fps)"
397 | ],
398 | "execution_count": null,
399 | "outputs": []
400 | },
401 | {
402 | "cell_type": "code",
403 | "metadata": {
404 | "id": "WW32AVGR0Vwh"
405 | },
406 | "source": [
407 | ""
408 | ],
409 | "execution_count": null,
410 | "outputs": []
411 | }
412 | ]
413 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.13.0
2 | flax==0.3.4
3 | gin-config @ git+https://github.com/google/gin-config@243ba87b3fcfeb2efb4a920b8f19679b61a6f0dc
4 | imageio==2.9.0
5 | immutabledict==2.2.0
6 | jax==0.2.20
7 | numpy==1.19.5
8 | opencv-python==4.5.3.56
9 | Pillow==8.3.2
10 | scikit-image==0.18.3
11 | scipy==1.7.1
12 | tensorboard==2.6.0
13 | tensorflow>=2.6.1
14 |
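15 | # The pins above can be installed with: pip install -r requirements.txt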
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import setuptools
16 |
17 | with open("README.md", "r", encoding="utf-8") as fh:
18 | long_description = fh.read()
19 |
20 | setuptools.setup(
21 | name="nerfies", # Replace with your own username
22 | version="0.0.2",
23 | author="Keunhong Park",
24 | author_email="kpar@cs.washington.edu",
25 | description="Code for 'Nerfies: Deformable Neural Radiance Fields'.",
26 | long_description=long_description,
27 | long_description_content_type="text/markdown",
28 | url="https://github.com/google/nerfies",
29 | packages=setuptools.find_packages(),
30 | classifiers=[
31 | "Programming Language :: Python :: 3",
32 | "License :: OSI Approved :: Apache License 2.0",
33 | "Operating System :: OS Independent",
34 | ],
35 | python_requires='>=3.7',
36 | )
37 |
--------------------------------------------------------------------------------
/third_party/pycolmap/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 True Price, UNC Chapel Hill
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/pycolmap/README.md:
--------------------------------------------------------------------------------
1 | # pycolmap
2 | Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions.
3 |
4 | This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output.
5 |
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/__init__.py:
--------------------------------------------------------------------------------
1 | from .camera import Camera
2 | from .database import COLMAPDatabase
3 | from .image import Image
4 | from .scene_manager import SceneManager
5 | from .rotation import Quaternion, DualQuaternion
6 |
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/camera.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 |
3 | import numpy as np
4 |
5 | from scipy.optimize import root
6 |
7 |
8 | #-------------------------------------------------------------------------------
9 | #
10 | # camera distortion functions for arrays of size (..., 2)
11 | #
12 | #-------------------------------------------------------------------------------
13 |
14 | def simple_radial_distortion(camera, x):
15 | return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True))
16 |
17 |
18 | def radial_distortion(camera, x):
19 | r_sq = np.square(x).sum(axis=-1, keepdims=True)
20 | return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq))
21 |
22 |
23 | def opencv_distortion(camera, x):
24 | x_sq = np.square(x)
25 | xy = np.prod(x, axis=-1, keepdims=True)
26 | r_sq = x_sq.sum(axis=-1, keepdims=True)
27 |
28 |   return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + np.concatenate(
29 |       (2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * x_sq[..., :1]),
30 |        camera.p1 * (r_sq + 2. * x_sq[..., 1:]) + 2. * camera.p2 * xy),
31 |       axis=-1)
32 |
33 |
34 | #-------------------------------------------------------------------------------
35 | #
36 | # Camera
37 | #
38 | #-------------------------------------------------------------------------------
39 |
40 | class Camera:
41 |
42 | @staticmethod
43 | def GetNumParams(type_):
44 | if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
45 | return 3
46 | if type_ == 1 or type_ == 'PINHOLE':
47 | return 4
48 | if type_ == 2 or type_ == 'SIMPLE_RADIAL':
49 | return 4
50 | if type_ == 3 or type_ == 'RADIAL':
51 | return 5
52 | if type_ == 4 or type_ == 'OPENCV':
53 | return 8
54 | #if type_ == 5 or type_ == 'OPENCV_FISHEYE':
55 | # return 8
56 | #if type_ == 6 or type_ == 'FULL_OPENCV':
57 | # return 12
58 | #if type_ == 7 or type_ == 'FOV':
59 | # return 5
60 | #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE':
61 | # return 4
62 | #if type_ == 9 or type_ == 'RADIAL_FISHEYE':
63 | # return 5
64 | #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE':
65 | # return 12
66 |
67 | # TODO: not supporting other camera types, currently
68 | raise Exception('Camera type not supported')
69 |
70 | #---------------------------------------------------------------------------
71 |
72 | @staticmethod
73 | def GetNameFromType(type_):
74 | if type_ == 0:
75 | return 'SIMPLE_PINHOLE'
76 | if type_ == 1:
77 | return 'PINHOLE'
78 | if type_ == 2:
79 | return 'SIMPLE_RADIAL'
80 | if type_ == 3:
81 | return 'RADIAL'
82 | if type_ == 4:
83 | return 'OPENCV'
84 | #if type_ == 5: return 'OPENCV_FISHEYE'
85 | #if type_ == 6: return 'FULL_OPENCV'
86 | #if type_ == 7: return 'FOV'
87 | #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE'
88 | #if type_ == 9: return 'RADIAL_FISHEYE'
89 | #if type_ == 10: return 'THIN_PRISM_FISHEYE'
90 |
91 | raise Exception('Camera type not supported')
92 |
93 | #---------------------------------------------------------------------------
94 |
95 | def __init__(self, type_, width_, height_, params):
96 | self.width = width_
97 | self.height = height_
98 |
99 | if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
100 | self.fx, self.cx, self.cy = params
101 | self.fy = self.fx
102 | self.distortion_func = None
103 | self.camera_type = 0
104 |
105 | elif type_ == 1 or type_ == 'PINHOLE':
106 | self.fx, self.fy, self.cx, self.cy = params
107 | self.distortion_func = None
108 | self.camera_type = 1
109 |
110 | elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
111 | self.fx, self.cx, self.cy, self.k1 = params
112 | self.fy = self.fx
113 | self.distortion_func = simple_radial_distortion
114 | self.camera_type = 2
115 |
116 | elif type_ == 3 or type_ == 'RADIAL':
117 | self.fx, self.cx, self.cy, self.k1, self.k2 = params
118 | self.fy = self.fx
119 | self.distortion_func = radial_distortion
120 | self.camera_type = 3
121 |
122 | elif type_ == 4 or type_ == 'OPENCV':
123 | self.fx, self.fy, self.cx, self.cy = params[:4]
124 | self.k1, self.k2, self.p1, self.p2 = params[4:]
125 | self.distortion_func = opencv_distortion
126 | self.camera_type = 4
127 |
128 | else:
129 | raise Exception('Camera type not supported')
130 |
131 | #---------------------------------------------------------------------------
132 |
133 | def __str__(self):
134 | s = (
135 | self.GetNameFromType(self.camera_type) +
136 | ' {} {} {}'.format(self.width, self.height, self.fx))
137 |
138 | if self.camera_type in (1, 4): # PINHOLE, OPENCV
139 | s += ' {}'.format(self.fy)
140 |
141 | s += ' {} {}'.format(self.cx, self.cy)
142 |
143 | if self.camera_type == 2: # SIMPLE_RADIAL
144 | s += ' {}'.format(self.k1)
145 |
146 | elif self.camera_type == 3: # RADIAL
147 | s += ' {} {}'.format(self.k1, self.k2)
148 |
149 | elif self.camera_type == 4: # OPENCV
150 | s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2)
151 |
152 | return s
153 |
154 | #---------------------------------------------------------------------------
155 |
156 | # return the camera parameters in the same order as the colmap output format
157 | def get_params(self):
158 | if self.camera_type == 0:
159 | return np.array((self.fx, self.cx, self.cy))
160 | if self.camera_type == 1:
161 | return np.array((self.fx, self.fy, self.cx, self.cy))
162 | if self.camera_type == 2:
163 | return np.array((self.fx, self.cx, self.cy, self.k1))
164 | if self.camera_type == 3:
165 | return np.array((self.fx, self.cx, self.cy, self.k1, self.k2))
166 | if self.camera_type == 4:
167 | return np.array((self.fx, self.fy, self.cx, self.cy, self.k1, self.k2,
168 | self.p1, self.p2))
169 |
170 | #---------------------------------------------------------------------------
171 |
172 | def get_camera_matrix(self):
173 | return np.array(((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1)))
174 |
175 | def get_inverse_camera_matrix(self):
176 | return np.array(((1. / self.fx, 0, -self.cx / self.fx),
177 | (0, 1. / self.fy, -self.cy / self.fy), (0, 0, 1)))
178 |
179 | @property
180 | def K(self):
181 | return self.get_camera_matrix()
182 |
183 | @property
184 | def K_inv(self):
185 | return self.get_inverse_camera_matrix()
186 |
187 | #---------------------------------------------------------------------------
188 |
189 | # return the inverse camera matrix
190 | def get_inv_camera_matrix(self):
191 | inv_fx, inv_fy = 1. / self.fx, 1. / self.fy
192 | return np.array(((inv_fx, 0, -inv_fx * self.cx),
193 | (0, inv_fy, -inv_fy * self.cy), (0, 0, 1)))
194 |
195 | #---------------------------------------------------------------------------
196 |
197 | # return an (x, y) pixel coordinate grid for this camera
198 | def get_image_grid(self):
199 | xmin = (0.5 - self.cx) / self.fx
200 | xmax = (self.width - 0.5 - self.cx) / self.fx
201 | ymin = (0.5 - self.cy) / self.fy
202 | ymax = (self.height - 0.5 - self.cy) / self.fy
203 | return np.meshgrid(
204 | np.linspace(xmin, xmax, self.width),
205 | np.linspace(ymin, ymax, self.height))
206 |
207 | #---------------------------------------------------------------------------
208 |
209 | # x: array of shape (N,2) or (2,)
210 | # normalized: False if the input points are in pixel coordinates
211 | # denormalize: True if the points should be put back into pixel coordinates
212 | def distort_points(self, x, normalized=True, denormalize=True):
213 | x = np.atleast_2d(x)
214 |
215 | # put the points into normalized camera coordinates
216 | if not normalized:
217 |       x = x - np.array([[self.cx, self.cy]])  # creates a copy
218 |       x = x / np.array([[self.fx, self.fy]])
219 |
220 | # distort, if necessary
221 | if self.distortion_func is not None:
222 | x = self.distortion_func(self, x)
223 |
224 | if denormalize:
225 | x *= np.array([[self.fx, self.fy]])
226 | x += np.array([[self.cx, self.cy]])
227 |
228 | return x
229 |
230 | #---------------------------------------------------------------------------
231 |
232 | # x: array of shape (N1,N2,...,2), (N,2), or (2,)
233 | # normalized: False if the input points are in pixel coordinates
234 | # denormalize: True if the points should be put back into pixel coordinates
235 | def undistort_points(self, x, normalized=False, denormalize=True):
236 | x = np.atleast_2d(x)
237 |
238 | # put the points into normalized camera coordinates
239 | if not normalized:
240 | x = x - np.array([self.cx, self.cy]) # creates a copy
241 | x /= np.array([self.fx, self.fy])
242 |
243 | # undistort, if necessary
244 | if self.distortion_func is not None:
245 |
246 | def objective(xu):
247 | return (x - self.distortion_func(self, xu.reshape(*x.shape))).ravel()
248 |
249 | xu = root(objective, x).x.reshape(*x.shape)
250 | else:
251 | xu = x
252 |
253 | if denormalize:
254 | xu *= np.array([[self.fx, self.fy]])
255 | xu += np.array([[self.cx, self.cy]])
256 |
257 | return xu
258 |
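259 | 
260 | 
261 | if __name__ == '__main__':
262 |   # Illustrative self-check (not part of the original module): distorting
263 |   # pixel coordinates and then undistorting them should round-trip, since
264 |   # undistort_points numerically inverts the distortion model with a root
265 |   # solver.
266 |   camera = Camera('SIMPLE_RADIAL', 640, 480, (500., 320., 240., -0.05))
267 |   points = np.array([[100., 200.], [320., 240.], [600., 50.]])
268 |   distorted = camera.distort_points(points, normalized=False)
269 |   recovered = camera.undistort_points(distorted)
270 |   print('max round-trip error:', np.abs(recovered - points).max())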
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/database.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sqlite3
4 |
5 | #-------------------------------------------------------------------------------
6 | # convert SQLite BLOBs to/from numpy arrays
7 |
8 |
9 | def array_to_blob(arr):
10 |   return arr.tobytes()  # np.getbuffer does not exist in Python 3.
11 |
12 |
13 | def blob_to_array(blob, dtype, shape=(-1,)):
14 | return np.frombuffer(blob, dtype).reshape(*shape)
15 |
16 |
17 | #-------------------------------------------------------------------------------
18 | # convert to/from image pair ids
19 |
20 | MAX_IMAGE_ID = 2**31 - 1
21 |
22 |
23 | def get_pair_id(image_id1, image_id2):
24 | if image_id1 > image_id2:
25 | image_id1, image_id2 = image_id2, image_id1
26 | return image_id1 * MAX_IMAGE_ID + image_id2
27 |
28 |
29 | def get_image_ids_from_pair_id(pair_id):
30 | image_id2 = pair_id % MAX_IMAGE_ID
31 |   return (pair_id - image_id2) // MAX_IMAGE_ID, image_id2
32 |
33 |
34 | #-------------------------------------------------------------------------------
35 | # create table commands
36 |
37 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
38 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
39 | model INTEGER NOT NULL,
40 | width INTEGER NOT NULL,
41 | height INTEGER NOT NULL,
42 | params BLOB,
43 | prior_focal_length INTEGER NOT NULL)"""
44 |
45 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
46 | image_id INTEGER PRIMARY KEY NOT NULL,
47 | rows INTEGER NOT NULL,
48 | cols INTEGER NOT NULL,
49 | data BLOB,
50 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
51 |
52 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
53 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
54 | name TEXT NOT NULL UNIQUE,
55 | camera_id INTEGER NOT NULL,
56 | prior_qw REAL,
57 | prior_qx REAL,
58 | prior_qy REAL,
59 | prior_qz REAL,
60 | prior_tx REAL,
61 | prior_ty REAL,
62 | prior_tz REAL,
63 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647),
64 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))"""
65 |
66 | CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries (
67 | pair_id INTEGER PRIMARY KEY NOT NULL,
68 | rows INTEGER NOT NULL,
69 | cols INTEGER NOT NULL,
70 | data BLOB,
71 | config INTEGER NOT NULL,
72 | F BLOB,
73 | E BLOB,
74 | H BLOB)"""
75 |
76 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
77 | image_id INTEGER PRIMARY KEY NOT NULL,
78 | rows INTEGER NOT NULL,
79 | cols INTEGER NOT NULL,
80 | data BLOB,
81 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
82 |
83 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
84 | pair_id INTEGER PRIMARY KEY NOT NULL,
85 | rows INTEGER NOT NULL,
86 | cols INTEGER NOT NULL,
87 | data BLOB)"""
88 |
89 | CREATE_NAME_INDEX = \
90 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
91 |
92 | CREATE_ALL = "; ".join([
93 | CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE, CREATE_IMAGES_TABLE,
94 | CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE, CREATE_MATCHES_TABLE,
95 | CREATE_NAME_INDEX
96 | ])
97 |
98 | #-------------------------------------------------------------------------------
99 | # functional interface for adding objects
100 |
101 |
102 | def add_camera(db,
103 | model,
104 | width,
105 | height,
106 | params,
107 | prior_focal_length=False,
108 | camera_id=None):
109 | # TODO: Parameter count checks
110 | params = np.asarray(params, np.float64)
111 | db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
112 | (camera_id, model, width, height, array_to_blob(params),
113 | prior_focal_length))
114 |
115 |
116 | def add_descriptors(db, image_id, descriptors):
117 | descriptors = np.ascontiguousarray(descriptors, np.uint8)
118 | db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)",
119 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
120 |
121 |
122 | def add_image(db,
123 | name,
124 | camera_id,
125 | prior_q=np.zeros(4),
126 | prior_t=np.zeros(3),
127 | image_id=None):
128 | db.execute("INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
129 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
130 | prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
131 |
132 |
133 | # config: defaults to fundamental matrix
134 | def add_inlier_matches(db,
135 | image_id1,
136 | image_id2,
137 | matches,
138 | config=2,
139 | F=None,
140 | E=None,
141 | H=None):
142 | assert (len(matches.shape) == 2)
143 | assert (matches.shape[1] == 2)
144 |
145 | if image_id1 > image_id2:
146 | matches = matches[:, ::-1]
147 |
148 |   if F is not None:
149 |     F = array_to_blob(np.asarray(F, np.float64))
150 |   if E is not None:
151 |     E = array_to_blob(np.asarray(E, np.float64))
152 |   if H is not None:
153 |     H = array_to_blob(np.asarray(H, np.float64))
154 | 
155 |   pair_id = get_pair_id(image_id1, image_id2)
156 |   matches = np.asarray(matches, np.uint32)
157 |   db.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
158 |              (pair_id,) + matches.shape +
159 |              (array_to_blob(matches), config, F, E, H))
160 |
161 |
162 | def add_keypoints(db, image_id, keypoints):
163 | assert (len(keypoints.shape) == 2)
164 | assert (keypoints.shape[1] in [2, 4, 6])
165 |
166 | keypoints = np.asarray(keypoints, np.float32)
167 | db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)",
168 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
169 |
170 |
171 | # config: defaults to fundamental matrix
172 | def add_matches(db, image_id1, image_id2, matches):
173 | assert (len(matches.shape) == 2)
174 | assert (matches.shape[1] == 2)
175 |
176 | if image_id1 > image_id2:
177 | matches = matches[:, ::-1]
178 |
179 | pair_id = get_pair_id(image_id1, image_id2)
180 | matches = np.asarray(matches, np.uint32)
181 | db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)",
182 | (pair_id,) + matches.shape + (array_to_blob(matches),))
183 |
184 |
185 | #-------------------------------------------------------------------------------
186 | # simple functional interface
187 |
188 |
189 | class COLMAPDatabase(sqlite3.Connection):
190 |
191 | @staticmethod
192 | def connect(database_path):
193 | return sqlite3.connect(database_path, factory=COLMAPDatabase)
194 |
195 | def __init__(self, *args, **kwargs):
196 | super(COLMAPDatabase, self).__init__(*args, **kwargs)
197 |
198 | self.initialize_tables = lambda: self.executescript(CREATE_ALL)
199 |
200 | self.initialize_cameras = \
201 | lambda: self.executescript(CREATE_CAMERAS_TABLE)
202 | self.initialize_descriptors = \
203 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
204 | self.initialize_images = \
205 | lambda: self.executescript(CREATE_IMAGES_TABLE)
206 | self.initialize_inlier_matches = \
207 | lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE)
208 | self.initialize_keypoints = \
209 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
210 | self.initialize_matches = \
211 | lambda: self.executescript(CREATE_MATCHES_TABLE)
212 |
213 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
214 |
215 | add_camera = add_camera
216 | add_descriptors = add_descriptors
217 | add_image = add_image
218 | add_inlier_matches = add_inlier_matches
219 | add_keypoints = add_keypoints
220 | add_matches = add_matches
221 |
222 |
223 | #-------------------------------------------------------------------------------
224 |
225 |
226 | def main(args):
227 | import os
228 |
229 | if os.path.exists(args.database_path):
230 | print("Error: database path already exists -- will not modify it.")
231 | exit()
232 |
233 | db = COLMAPDatabase.connect(args.database_path)
234 |
235 | #
236 | # for convenience, try creating all the tables upfront
237 | #
238 |
239 | db.initialize_tables()
240 |
241 | #
242 | # create dummy cameras
243 | #
244 |
245 | model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.))
246 | model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1))
247 |
248 | db.add_camera(model1, w1, h1, params1)
249 | db.add_camera(model2, w2, h2, params2)
250 |
251 | #
252 | # create dummy images
253 | #
254 |
255 | db.add_image("image1.png", 0)
256 | db.add_image("image2.png", 0)
257 | db.add_image("image3.png", 2)
258 | db.add_image("image4.png", 2)
259 |
260 | #
261 | # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y),
262 | # 4D keypoints (x, y, theta, scale), and 6D affine keypoints
263 | # (x, y, a_11, a_12, a_21, a_22)
264 | #
265 |
266 | N = 1000
267 | kp1 = np.random.rand(N, 2) * (1024., 768.)
268 | kp2 = np.random.rand(N, 2) * (1024., 768.)
269 | kp3 = np.random.rand(N, 2) * (1024., 768.)
270 | kp4 = np.random.rand(N, 2) * (1024., 768.)
271 |
272 | db.add_keypoints(1, kp1)
273 | db.add_keypoints(2, kp2)
274 | db.add_keypoints(3, kp3)
275 | db.add_keypoints(4, kp4)
276 |
277 | #
278 | # create dummy matches
279 | #
280 |
281 | M = 50
282 | m12 = np.random.randint(N, size=(M, 2))
283 | m23 = np.random.randint(N, size=(M, 2))
284 | m34 = np.random.randint(N, size=(M, 2))
285 |
286 | db.add_matches(1, 2, m12)
287 | db.add_matches(2, 3, m23)
288 | db.add_matches(3, 4, m34)
289 |
290 | #
291 | # check cameras
292 | #
293 |
294 | rows = db.execute("SELECT * FROM cameras")
295 |
296 | camera_id, model, width, height, params, prior = next(rows)
297 |   params = blob_to_array(params, np.float64)  # stored as float64 above
298 | assert model == model1 and width == w1 and height == h1
299 | assert np.allclose(params, params1)
300 |
301 | camera_id, model, width, height, params, prior = next(rows)
302 |   params = blob_to_array(params, np.float64)
303 | assert model == model2 and width == w2 and height == h2
304 | assert np.allclose(params, params2)
305 |
306 | #
307 | # check keypoints
308 | #
309 |
310 | kps = dict(
311 | (image_id, blob_to_array(data, np.float32, (-1, 2)))
312 | for image_id, data in db.execute("SELECT image_id, data FROM keypoints"))
313 |
314 | assert np.allclose(kps[1], kp1)
315 | assert np.allclose(kps[2], kp2)
316 | assert np.allclose(kps[3], kp3)
317 | assert np.allclose(kps[4], kp4)
318 |
319 | #
320 | # check matches
321 | #
322 |
323 | pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]]
324 |
325 | matches = dict(
326 | (get_image_ids_from_pair_id(pair_id),
327 | blob_to_array(data, np.uint32, (-1, 2)))
328 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches"))
329 |
330 | assert np.all(matches[(1, 2)] == m12)
331 | assert np.all(matches[(2, 3)] == m23)
332 | assert np.all(matches[(3, 4)] == m34)
333 |
334 | #
335 | # clean up
336 | #
337 |
338 | db.close()
339 | os.remove(args.database_path)
340 | 
341 | 
342 | if __name__ == '__main__':
343 |   # Entry point for the self-test in main() above.
344 |   import argparse
345 | 
346 |   parser = argparse.ArgumentParser()
347 |   parser.add_argument('--database_path', default='database.db')
348 |   main(parser.parse_args())
349 | 
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/image.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 |
3 | import numpy as np
4 |
5 | #-------------------------------------------------------------------------------
6 | #
7 | # Image
8 | #
9 | #-------------------------------------------------------------------------------
10 |
11 |
12 | class Image:
13 |
14 | def __init__(self, name_, camera_id_, q_, tvec_):
15 | self.name = name_
16 | self.camera_id = camera_id_
17 | self.q = q_
18 | self.tvec = tvec_
19 |
20 | self.points2D = np.empty((0, 2), dtype=np.float64)
21 | self.point3D_ids = np.empty((0,), dtype=np.uint64)
22 |
23 | #---------------------------------------------------------------------------
24 |
25 | def R(self):
26 | return self.q.ToR()
27 |
28 | #---------------------------------------------------------------------------
29 |
30 | def C(self):
31 | return -self.R().T.dot(self.tvec)
32 |
33 | #---------------------------------------------------------------------------
34 |
35 | @property
36 | def t(self):
37 | return self.tvec
38 |
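39 | 
40 | 
41 | if __name__ == '__main__':
42 |   # Illustrative example (assumes pycolmap is importable): for a
43 |   # world-to-camera pose (R, t), the camera center is C = -R^T t.
44 |   from pycolmap.rotation import Quaternion
45 |   q = Quaternion.FromAxisAngle(np.array([0., 1., 0.]), 0.5)
46 |   image = Image('example.png', 1, q, np.array([0., 0., 2.]))
47 |   print('R =', image.R())
48 |   print('C =', image.C())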
--------------------------------------------------------------------------------
/third_party/pycolmap/pycolmap/rotation.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 |
3 | import numpy as np
4 |
5 | #-------------------------------------------------------------------------------
6 | #
7 | # Axis-Angle Functions
8 | #
9 | #-------------------------------------------------------------------------------
10 |
11 |
12 | # returns the cross product matrix representation of a 3-vector v
13 | def cross_prod_matrix(v):
14 | return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.)))
15 |
16 |
17 | #-------------------------------------------------------------------------------
18 |
19 |
20 | # www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/
21 | # if angle is None, assume ||axis|| == angle, in radians
22 | # if angle is not None, assume that axis is a unit vector
23 | def axis_angle_to_rotation_matrix(axis, angle=None):
24 | if angle is None:
25 | angle = np.linalg.norm(axis)
26 | if np.abs(angle) > np.finfo('float').eps:
27 | axis = axis / angle
28 |
29 | cp_axis = cross_prod_matrix(axis)
30 | return np.eye(3) + (
31 | np.sin(angle) * cp_axis + (1. - np.cos(angle)) * cp_axis.dot(cp_axis))
32 |
33 |
34 | #-------------------------------------------------------------------------------
35 |
36 |
37 | # after some deliberation, I've decided the easiest way to do this is to use
38 | # quaternions as an intermediary
39 | def rotation_matrix_to_axis_angle(R):
40 | return Quaternion.FromR(R).ToAxisAngle()
41 |
42 |
43 | #-------------------------------------------------------------------------------
44 | #
45 | # Quaternion
46 | #
47 | #-------------------------------------------------------------------------------
48 |
49 |
50 | class Quaternion:
51 | # create a quaternion from an existing rotation matrix
52 | # euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/
53 | @staticmethod
54 | def FromR(R):
55 | trace = np.trace(R)
56 |
57 | if trace > 0:
58 | qw = 0.5 * np.sqrt(1. + trace)
59 | qx = (R[2, 1] - R[1, 2]) * 0.25 / qw
60 | qy = (R[0, 2] - R[2, 0]) * 0.25 / qw
61 | qz = (R[1, 0] - R[0, 1]) * 0.25 / qw
62 | elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
63 | s = 2. * np.sqrt(1. + R[0, 0] - R[1, 1] - R[2, 2])
64 | qw = (R[2, 1] - R[1, 2]) / s
65 | qx = 0.25 * s
66 | qy = (R[0, 1] + R[1, 0]) / s
67 | qz = (R[0, 2] + R[2, 0]) / s
68 | elif R[1, 1] > R[2, 2]:
69 | s = 2. * np.sqrt(1. + R[1, 1] - R[0, 0] - R[2, 2])
70 | qw = (R[0, 2] - R[2, 0]) / s
71 | qx = (R[0, 1] + R[1, 0]) / s
72 | qy = 0.25 * s
73 | qz = (R[1, 2] + R[2, 1]) / s
74 | else:
75 | s = 2. * np.sqrt(1. + R[2, 2] - R[0, 0] - R[1, 1])
76 | qw = (R[1, 0] - R[0, 1]) / s
77 | qx = (R[0, 2] + R[2, 0]) / s
78 | qy = (R[1, 2] + R[2, 1]) / s
79 | qz = 0.25 * s
80 |
81 | return Quaternion(np.array((qw, qx, qy, qz)))
82 |
83 | # if angle is None, assume ||axis|| == angle, in radians
84 | # if angle is not None, assume that axis is a unit vector
85 | @staticmethod
86 | def FromAxisAngle(axis, angle=None):
87 | if angle is None:
88 | angle = np.linalg.norm(axis)
89 | if np.abs(angle) > np.finfo('float').eps:
90 | axis = axis / angle
91 |
92 | qw = np.cos(0.5 * angle)
93 | axis = axis * np.sin(0.5 * angle)
94 |
95 | return Quaternion(np.array((qw, axis[0], axis[1], axis[2])))
96 |
97 | #---------------------------------------------------------------------------
98 |
99 | def __init__(self, q=np.array((1., 0., 0., 0.))):
100 | if isinstance(q, Quaternion):
101 | self.q = q.q.copy()
102 | else:
103 | q = np.asarray(q)
104 | if q.size == 4:
105 | self.q = q.copy()
106 | elif q.size == 3: # convert from a 3-vector to a quaternion
107 | self.q = np.empty(4)
108 | self.q[0], self.q[1:] = 0., q.ravel()
109 | else:
110 | raise Exception('Input quaternion should be a 3- or 4-vector')
111 |
112 | def __add__(self, other):
113 | return Quaternion(self.q + other.q)
114 |
115 | def __iadd__(self, other):
116 | self.q += other.q
117 | return self
118 |
119 | # conjugation via the ~ operator
120 | def __invert__(self):
121 | return Quaternion(np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3])))
122 |
123 | # returns: self.q * other.q if other is a Quaternion; otherwise performs
124 | # scalar multiplication
125 | def __mul__(self, other):
126 | if isinstance(other, Quaternion): # quaternion multiplication
127 | return Quaternion(
128 | np.array((self.q[0] * other.q[0] - self.q[1] * other.q[1] -
129 | self.q[2] * other.q[2] - self.q[3] * other.q[3],
130 | self.q[0] * other.q[1] + self.q[1] * other.q[0] +
131 | self.q[2] * other.q[3] - self.q[3] * other.q[2],
132 | self.q[0] * other.q[2] - self.q[1] * other.q[3] +
133 | self.q[2] * other.q[0] + self.q[3] * other.q[1],
134 | self.q[0] * other.q[3] + self.q[1] * other.q[2] -
135 | self.q[2] * other.q[1] + self.q[3] * other.q[0])))
136 | else: # scalar multiplication (assumed)
137 | return Quaternion(other * self.q)
138 |
139 | def __rmul__(self, other):
140 | return self * other
141 |
142 | def __imul__(self, other):
143 | self.q[:] = (self * other).q
144 | return self
145 |
146 | def __irmul__(self, other):
147 | self.q[:] = (self * other).q
148 | return self
149 |
150 | def __neg__(self):
151 | return Quaternion(-self.q)
152 |
153 | def __sub__(self, other):
154 | return Quaternion(self.q - other.q)
155 |
156 | def __isub__(self, other):
157 | self.q -= other.q
158 | return self
159 |
160 | def __str__(self):
161 | return str(self.q)
162 |
163 | def copy(self):
164 | return Quaternion(self)
165 |
166 | def dot(self, other):
167 | return self.q.dot(other.q)
168 |
169 | # assume the quaternion is nonzero!
170 | def inverse(self):
171 | return Quaternion((~self).q / self.q.dot(self.q))
172 |
173 | def norm(self):
174 | return np.linalg.norm(self.q)
175 |
176 | def normalize(self):
177 | self.q /= np.linalg.norm(self.q)
178 | return self
179 |
180 | # assume x is a Nx3 numpy array or a numpy 3-vector
181 | def rotate_points(self, x):
182 | x = np.atleast_2d(x)
183 | return x.dot(self.ToR().T)
184 |
185 | # convert to a rotation matrix
186 | def ToR(self):
187 | return np.eye(3) + 2 * np.array((
188 | (-self.q[2] * self.q[2] - self.q[3] * self.q[3], self.q[1] * self.q[2] -
189 | self.q[3] * self.q[0], self.q[1] * self.q[3] + self.q[2] * self.q[0]),
190 | (self.q[1] * self.q[2] + self.q[3] * self.q[0], -self.q[1] * self.q[1] -
191 | self.q[3] * self.q[3], self.q[2] * self.q[3] - self.q[1] * self.q[0]),
192 | (self.q[1] * self.q[3] - self.q[2] * self.q[0],
193 | self.q[2] * self.q[3] + self.q[1] * self.q[0],
194 | -self.q[1] * self.q[1] - self.q[2] * self.q[2])))
195 |
196 | # convert to axis-angle representation, with angle encoded by the length
197 | def ToAxisAngle(self):
198 | # recall that for axis-angle representation (a, angle), with "a" unit:
199 | # q = (cos(angle/2), a * sin(angle/2))
200 | # below, for readability, "theta" actually means half of the angle
201 |
202 | sin_sq_theta = self.q[1:].dot(self.q[1:])
203 |
204 | # if theta is non-zero, then we can compute a unique rotation
205 | if np.abs(sin_sq_theta) > np.finfo('float').eps:
206 | sin_theta = np.sqrt(sin_sq_theta)
207 | cos_theta = self.q[0]
208 |
209 | # atan2 is more stable, so we use it to compute theta
210 | # note that we multiply by 2 to get the actual angle
211 | angle = 2. * (
212 | np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else np.arctan2(
213 | sin_theta, cos_theta))
214 |
215 | return self.q[1:] * (angle / sin_theta)
216 |
217 | # otherwise, the result is singular, and we avoid dividing by
218 | # sin(angle/2) = 0
219 | return np.zeros(3)
220 |
221 | # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler
222 | # this assumes the quaternion is non-zero
223 | # returns yaw, pitch, roll, with application in that order
224 | def ToEulerAngles(self):
225 | qsq = self.q**2
226 | k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum()
227 |
228 | if (1. - k) < np.finfo('float').eps: # north pole singularity
229 | return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0.
230 | if (1. + k) < np.finfo('float').eps: # south pole singularity
231 | return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0.
232 |
233 | yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]),
234 | qsq[0] + qsq[1] - qsq[2] - qsq[3])
235 | pitch = np.arcsin(k)
236 | roll = np.arctan2(2. * (self.q[0] * self.q[1] - self.q[2] * self.q[3]),
237 | qsq[0] - qsq[1] + qsq[2] - qsq[3])
238 |
239 | return yaw, pitch, roll
240 |
241 |
242 | #-------------------------------------------------------------------------------
243 | #
244 | # DualQuaternion
245 | #
246 | #-------------------------------------------------------------------------------
247 |
248 |
249 | class DualQuaternion:
250 | # DualQuaternion from an existing rotation + translation
251 | @staticmethod
252 | def FromQT(q, t):
253 | return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q)
254 |
255 | def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)):
256 | self.q0, self.qe = Quaternion(q0), Quaternion(qe)
257 |
258 | def __add__(self, other):
259 | return DualQuaternion(self.q0 + other.q0, self.qe + other.qe)
260 |
261 | def __iadd__(self, other):
262 | self.q0 += other.q0
263 | self.qe += other.qe
264 | return self
265 |
266 |   # conjugation via the ~ operator
267 | def __invert__(self):
268 | return DualQuaternion(~self.q0, ~self.qe)
269 |
270 | def __mul__(self, other):
271 | if isinstance(other, DualQuaternion):
272 | return DualQuaternion(self.q0 * other.q0,
273 | self.q0 * other.qe + self.qe * other.q0)
274 | elif isinstance(other, complex): # multiplication by a dual number
275 | return DualQuaternion(self.q0 * other.real,
276 | self.q0 * other.imag + self.qe * other.real)
277 | else: # scalar multiplication (assumed)
278 | return DualQuaternion(other * self.q0, other * self.qe)
279 |
280 | def __rmul__(self, other):
281 | return self.__mul__(other)
282 |
283 | def __imul__(self, other):
284 | tmp = self * other
285 | self.q0, self.qe = tmp.q0, tmp.qe
286 | return self
287 |
288 | def __neg__(self):
289 | return DualQuaternion(-self.q0, -self.qe)
290 |
291 | def __sub__(self, other):
292 | return DualQuaternion(self.q0 - other.q0, self.qe - other.qe)
293 |
294 | def __isub__(self, other):
295 | self.q0 -= other.q0
296 | self.qe -= other.qe
297 | return self
298 |
299 | # q^-1 = q* / ||q||^2
300 | # assume that q0 is nonzero!
301 | def inverse(self):
302 |     normsq = complex(self.q0.dot(self.q0), 2. * self.q0.q.dot(self.qe.q))
303 | inv_len_real = 1. / normsq.real
304 | return ~self * complex(inv_len_real,
305 | -normsq.imag * inv_len_real * inv_len_real)
306 |
307 | # returns a complex representation of the real and imaginary parts of the norm
308 | # assume that q0 is nonzero!
309 | def norm(self):
310 | q0_norm = self.q0.norm()
311 | return complex(q0_norm, self.q0.dot(self.qe) / q0_norm)
312 |
313 | # assume that q0 is nonzero!
314 | def normalize(self):
316 |     # current length is ||q0|| + eps * (<q0, qe> / ||q0||)
316 | # writing this as a + eps * b, the inverse is
317 | # 1/||q|| = 1/a - eps * b / a^2
318 | norm = self.norm()
319 | inv_len_real = 1. / norm.real
320 | self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real)
321 | return self
322 |
323 | # return the translation vector for this dual quaternion
324 | def getT(self):
325 | return 2 * (self.qe * ~self.q0).q[1:]
326 |
327 | def ToQT(self):
328 | return self.q0, self.getT()
329 |
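330 | 
331 | 
332 | if __name__ == '__main__':
333 |   # Illustrative self-checks (not part of the original module): a rotation
334 |   # matrix should survive the R -> Quaternion -> R round trip, and a
335 |   # DualQuaternion built from (q, t) should return the same translation.
336 |   R = axis_angle_to_rotation_matrix(np.array([0.1, -0.2, 0.3]))
337 |   q = Quaternion.FromR(R)
338 |   assert np.allclose(q.ToR(), R)
339 |   dq = DualQuaternion.FromQT(q, np.array([1., 2., 3.]))
340 |   q_out, t_out = dq.ToQT()
341 |   assert np.allclose(q_out.q, q.q)
342 |   assert np.allclose(t_out, np.array([1., 2., 3.]))
343 |   print('rotation.py self-checks passed')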
--------------------------------------------------------------------------------
/third_party/pycolmap/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="pycolmap",
8 | version="0.0.1",
9 | author="True Price",
10 | description="PyColmap",
11 | long_description=long_description,
12 | long_description_content_type="text/markdown",
13 | url="https://github.com/google/nerfies/third_party/pycolmap",
14 | packages=setuptools.find_packages(),
15 | classifiers=[
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent",
19 | ],
20 | python_requires='>=3.6',
21 | )
22 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Training script for Nerf."""
16 |
17 | import functools
18 | from typing import Dict, Union
19 |
20 | from absl import app
21 | from absl import flags
22 | from absl import logging
23 | from flax import jax_utils
24 | from flax import optim
25 | from flax.metrics import tensorboard
26 | from flax.training import checkpoints
27 | import gin
28 | import jax
29 | from jax import numpy as jnp
30 | from jax import random
31 | import numpy as np
32 | import tensorflow as tf
33 |
34 | from nerfies import configs
35 | from nerfies import datasets
36 | from nerfies import gpath
37 | from nerfies import model_utils
38 | from nerfies import models
39 | from nerfies import schedules
40 | from nerfies import training
41 | from nerfies import utils
42 |
43 | flags.DEFINE_enum('mode', None, ['jax_cpu', 'jax_gpu', 'jax_tpu'],
44 | 'Distributed strategy approach.')
45 |
46 | flags.DEFINE_string('base_folder', None, 'where to store ckpts and logs')
47 | flags.mark_flag_as_required('base_folder')
48 | flags.DEFINE_string('data_dir', None, 'input data directory.')
49 | flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
50 | flags.DEFINE_multi_string('gin_configs', (), 'Gin config files.')
51 | FLAGS = flags.FLAGS
52 |
53 | jax.config.parse_flags_with_absl()
54 |
55 |
56 | def _log_to_tensorboard(writer: tensorboard.SummaryWriter,
57 | state: model_utils.TrainState,
58 | scalar_params: training.ScalarParams,
59 | stats: Dict[str, Union[Dict[str, jnp.ndarray],
60 | jnp.ndarray]],
61 | time_dict: Dict[str, jnp.ndarray]):
62 | """Log statistics to Tensorboard."""
63 | step = int(state.optimizer.state.step)
64 | writer.scalar('params/learning_rate', scalar_params.learning_rate, step)
65 | writer.scalar('params/warp_alpha', state.warp_alpha, step)
66 | writer.scalar('params/time_alpha', state.time_alpha, step)
67 | writer.scalar('params/elastic_loss/weight',
68 | scalar_params.elastic_loss_weight, step)
69 |
70 | # pmean is applied in train_step so just take the item.
71 | for branch in {'coarse', 'fine'}:
72 | if branch not in stats:
73 | continue
74 | for stat_key, stat_value in stats[branch].items():
75 | writer.scalar(f'{stat_key}/{branch}', stat_value, step)
76 |
77 | if 'background_loss' in stats:
78 | writer.scalar('loss/background', stats['background_loss'], step)
79 |
80 | for k, v in time_dict.items():
81 | writer.scalar(f'time/{k}', v, step)
82 |
83 |
84 | def _log_histograms(writer: tensorboard.SummaryWriter, model: models.NerfModel,
85 | state: model_utils.TrainState):
86 | """Log histograms to Tensorboard."""
87 | step = int(state.optimizer.state.step)
88 | params = state.optimizer.target['model']
89 | if 'appearance_encoder' in params:
90 | embeddings = params['appearance_encoder']['embed']['embedding']
91 | writer.histogram('appearance_embedding', embeddings, step)
92 | if 'camera_encoder' in params:
93 | embeddings = params['camera_encoder']['embed']['embedding']
94 | writer.histogram('camera_embedding', embeddings, step)
95 | if 'warp_field' in params and model.warp_metadata_encoder_type == 'glo':
96 | embeddings = params['warp_field']['metadata_encoder']['embed']['embedding']
97 | writer.histogram('warp_embedding', embeddings, step)
98 |
99 |
100 | def main(argv):
101 | tf.config.experimental.set_visible_devices([], 'GPU')
102 | del argv
103 | logging.info('*** Starting experiment')
104 | gin_configs = FLAGS.gin_configs
105 |
106 | logging.info('*** Loading Gin configs from: %s', str(gin_configs))
107 | gin.parse_config_files_and_bindings(
108 | config_files=gin_configs,
109 | bindings=FLAGS.gin_bindings,
110 | skip_unknown=True)
111 |
112 | # Load configurations.
113 | exp_config = configs.ExperimentConfig()
114 | model_config = configs.ModelConfig()
115 | train_config = configs.TrainConfig()
116 |
117 | # Get directory information.
118 | exp_dir = gpath.GPath(FLAGS.base_folder)
119 | if exp_config.subname:
120 | exp_dir = exp_dir / exp_config.subname
121 | summary_dir = exp_dir / 'summaries' / 'train'
122 | checkpoint_dir = exp_dir / 'checkpoints'
123 |
124 | # Log and create directories if this is the main host.
125 | if jax.process_index() == 0:
126 | logging.info('exp_dir = %s', exp_dir)
127 | if not exp_dir.exists():
128 | exp_dir.mkdir(parents=True, exist_ok=True)
129 |
130 | logging.info('summary_dir = %s', summary_dir)
131 | if not summary_dir.exists():
132 | summary_dir.mkdir(parents=True, exist_ok=True)
133 |
134 | logging.info('checkpoint_dir = %s', checkpoint_dir)
135 | if not checkpoint_dir.exists():
136 | checkpoint_dir.mkdir(parents=True, exist_ok=True)
137 |
138 | config_str = gin.operative_config_str()
139 | logging.info('Configuration: \n%s', config_str)
140 | with (exp_dir / 'config.gin').open('w') as f:
141 | f.write(config_str)
142 |
143 |   logging.info('Starting host %d. There are %d hosts : %s', jax.process_index(),
144 |                jax.process_count(), str(list(range(jax.process_count()))))
145 | logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
146 | str(jax.local_devices()))
147 | logging.info('Found %d total devices: %s.', jax.device_count(),
148 | str(jax.devices()))
149 |
150 | rng = random.PRNGKey(exp_config.random_seed)
151 |   # Shift the numpy random seed by process_index() so that data loaded by
152 |   # different hosts is shuffled differently.
153 | np.random.seed(exp_config.random_seed + jax.process_index())
154 |
155 | if train_config.batch_size % jax.device_count() != 0:
156 | raise ValueError('Batch size must be divisible by the number of devices.')
157 |
158 | devices = jax.local_devices()
159 | datasource_spec = exp_config.datasource_spec
160 | if datasource_spec is None:
161 | datasource_spec = {
162 | 'type': exp_config.datasource_type,
163 | 'data_dir': FLAGS.data_dir,
164 | }
165 | logging.info('Creating datasource: %s', datasource_spec)
166 | datasource = datasets.from_config(
167 | datasource_spec,
168 | image_scale=exp_config.image_scale,
169 | use_appearance_id=model_config.use_appearance_metadata,
170 | use_camera_id=model_config.use_camera_metadata,
171 | use_warp_id=model_config.use_warp,
172 | use_time=model_config.warp_metadata_encoder_type == 'time',
173 | random_seed=exp_config.random_seed,
174 | **exp_config.datasource_kwargs)
175 | train_iter = datasource.create_iterator(
176 | datasource.train_ids,
177 | flatten=True,
178 | shuffle=True,
179 | batch_size=train_config.batch_size,
180 | prefetch_size=3,
181 | shuffle_buffer_size=train_config.shuffle_buffer_size,
182 | devices=devices,
183 | )
184 |
185 | points_iter = None
186 | if train_config.use_background_loss:
187 | points = datasource.load_points(shuffle=True)
188 | points_batch_size = min(
189 | len(points),
190 | len(devices) * train_config.background_points_batch_size)
191 | points_batch_size -= points_batch_size % len(devices)
192 | points_dataset = tf.data.Dataset.from_tensor_slices(points)
193 | points_iter = datasets.iterator_from_dataset(
194 | points_dataset,
195 | batch_size=points_batch_size,
196 | prefetch_size=3,
197 | devices=devices)
198 |
199 | learning_rate_sched = schedules.from_config(train_config.lr_schedule)
200 | warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
201 | time_alpha_sched = schedules.from_config(train_config.time_alpha_schedule)
202 | elastic_loss_weight_sched = schedules.from_config(
203 | train_config.elastic_loss_weight_schedule)
204 |
205 | rng, key = random.split(rng)
206 | params = {}
207 | model, params['model'] = models.construct_nerf(
208 | key,
209 | model_config,
210 | batch_size=train_config.batch_size,
211 | appearance_ids=datasource.appearance_ids,
212 | camera_ids=datasource.camera_ids,
213 | warp_ids=datasource.warp_ids,
214 | near=datasource.near,
215 | far=datasource.far,
216 | use_warp_jacobian=train_config.use_elastic_loss,
217 | use_weights=train_config.use_elastic_loss)
218 |
219 | optimizer_def = optim.Adam(learning_rate_sched(0))
220 | optimizer = optimizer_def.create(params)
221 | state = model_utils.TrainState(
222 | optimizer=optimizer,
223 | warp_alpha=warp_alpha_sched(0),
224 | time_alpha=time_alpha_sched(0))
225 | scalar_params = training.ScalarParams(
226 | learning_rate=learning_rate_sched(0),
227 | elastic_loss_weight=elastic_loss_weight_sched(0),
228 | warp_reg_loss_weight=train_config.warp_reg_loss_weight,
229 | warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,
230 | warp_reg_loss_scale=train_config.warp_reg_loss_scale,
231 | background_loss_weight=train_config.background_loss_weight)
232 | state = checkpoints.restore_checkpoint(checkpoint_dir, state)
233 | init_step = state.optimizer.state.step + 1
234 | state = jax_utils.replicate(state, devices=devices)
235 | del params
236 |
237 | logging.info('Initializing models')
238 |
239 | summary_writer = None
240 | if jax.process_index() == 0:
241 | summary_writer = tensorboard.SummaryWriter(str(summary_dir))
242 | summary_writer.text(
243 | 'gin/train', textdata=gin.config.markdown(config_str), step=0)
244 |
245 | train_step = functools.partial(
246 | training.train_step,
247 | model,
248 | elastic_reduce_method=train_config.elastic_reduce_method,
249 | elastic_loss_type=train_config.elastic_loss_type,
250 | use_elastic_loss=train_config.use_elastic_loss,
251 | use_background_loss=train_config.use_background_loss,
252 | use_warp_reg_loss=train_config.use_warp_reg_loss,
253 | )
254 | ptrain_step = jax.pmap(
255 | train_step,
256 | axis_name='batch',
257 | devices=devices,
258 | # rng_key, state, batch, scalar_params.
259 | in_axes=(0, 0, 0, None),
260 |       # scalar_params is broadcast to every device (in_axes=None above).
261 | donate_argnums=(2,), # Donate the 'batch' argument.
262 | )
263 |
264 | if devices:
265 | n_local_devices = len(devices)
266 | else:
267 | n_local_devices = jax.local_device_count()
268 |
269 | logging.info('Starting training')
270 | rng = rng + jax.process_index() # Make random seed separate across hosts.
271 | keys = random.split(rng, n_local_devices)
272 | time_tracker = utils.TimeTracker()
273 | time_tracker.tic('data', 'total')
274 | for step, batch in zip(range(init_step, train_config.max_steps + 1),
275 | train_iter):
276 | if points_iter is not None:
277 | batch['background_points'] = next(points_iter)
278 | time_tracker.toc('data')
279 | # pytype: disable=attribute-error
280 | scalar_params = scalar_params.replace(
281 | learning_rate=learning_rate_sched(step),
282 | elastic_loss_weight=elastic_loss_weight_sched(step))
283 | warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
284 | time_alpha = jax_utils.replicate(time_alpha_sched(step), devices)
285 | state = state.replace(warp_alpha=warp_alpha, time_alpha=time_alpha)
286 |
287 | with time_tracker.record_time('train_step'):
288 | state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
289 | time_tracker.toc('total')
290 |
291 | if step % train_config.print_every == 0 and jax.process_index() == 0:
292 | logging.info('step=%d, warp_alpha=%.04f, time_alpha=%.04f, %s', step,
293 | warp_alpha_sched(step), time_alpha_sched(step),
294 | time_tracker.summary_str('last'))
295 |       coarse_metrics_str = ', '.join(
296 |           [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])
297 |       logging.info('\tcoarse metrics: %s', coarse_metrics_str)
298 |       if 'fine' in stats:  # The fine branch may be disabled.
299 |         fine_metrics_str = ', '.join(
300 |             [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])
301 |         logging.info('\tfine metrics: %s', fine_metrics_str)
302 |
303 | if step % train_config.save_every == 0 and jax.process_index() == 0:
304 | training.save_checkpoint(checkpoint_dir, state)
305 |
306 | if step % train_config.log_every == 0 and jax.process_index() == 0:
307 | # Only log via host 0.
308 | _log_to_tensorboard(
309 | summary_writer,
310 | jax_utils.unreplicate(state),
311 | scalar_params,
312 | jax_utils.unreplicate(stats),
313 | time_dict=time_tracker.summary('mean'))
314 | time_tracker.reset()
315 |
316 | if step % train_config.histogram_every == 0 and jax.process_index() == 0:
317 | _log_histograms(summary_writer, model, jax_utils.unreplicate(state))
318 |
319 | time_tracker.tic('data', 'total')
320 |
321 |   if train_config.max_steps % train_config.save_every != 0 and jax.process_index() == 0:
322 | training.save_checkpoint(checkpoint_dir, state)
323 |
324 |
325 | if __name__ == '__main__':
326 | app.run(main)
327 |
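328 | 
329 | # Example invocation (illustrative; substitute your own paths and pick a
330 | # gin config from configs/):
331 | #
332 | #   python train.py \
333 | #     --data_dir /path/to/capture \
334 | #     --base_folder /tmp/nerfies-experiment \
335 | #     --gin_configs configs/test_local.gin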
--------------------------------------------------------------------------------