├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
├── 360_dino.gin
├── 360_glo4.gin
└── 360_robustnerf.gin
├── eval.py
├── internal
├── camera_utils.py
├── configs.py
├── coord.py
├── datasets.py
├── geopoly.py
├── image.py
├── math.py
├── models.py
├── raw_utils.py
├── ref_utils.py
├── render.py
├── robustnerf.py
├── stepfun.py
├── train_utils.py
├── utils.py
└── vis.py
├── media
└── teaser.gif
├── render.py
├── requirements.txt
├── scripts
├── download_on-the-go.sh
├── eval_on-the-go.sh
├── eval_on-the-go_HD.sh
├── feature_extract.py
├── feature_extract.sh
├── local_colmap_and_resize.sh
├── render_on-the-go.sh
├── render_on-the-go_HD.sh
├── run_all_unit_tests.sh
├── train_on-the-go.sh
└── train_on-the-go_HD.sh
├── tests
├── camera_utils_test.py
├── coord_test.py
├── math_test.py
├── stepfun_test.py
└── utils_test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | internal/pycolmap
2 | __pycache__/
3 | internal/__pycache__/
4 | tests/__pycache__/
5 | .DS_Store
6 | .vscode/
7 | .idea/
8 | __MACOSX/
9 | *.err
10 | *.out
11 | output/
12 | RobustNerf/
13 | data/
14 | jupyter/
15 | slurm/
16 | output_8/
17 | output_2/
18 | output_ablation/
19 | zzh_output_8/
20 | 360/
21 | output_360/
22 | output_new_ablation/
23 | output_highres
24 | scripts/SAM/*
25 | scripts/static/*
26 | scripts/blockview/*
27 | video_maker/
28 | output_ablation_new/
29 | tmp_script/
30 | scripts/SAM/
31 | Datasets/
32 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
29 |
30 |
31 |
32 |
33 |
34 | Table of Contents
35 |
36 | -
37 | Description
38 |
39 | -
40 | Setup
41 |
42 | -
43 | Dataset Preparation
44 |
45 | -
46 | Running
47 |
48 | -
49 | Checkpoint
50 |
51 | -
52 | Citation
53 |
54 | -
55 | Contact
56 |
57 |
58 |
59 |
60 |
61 | ## Description
62 |
63 | This repository hosts the official Jax implementation of the paper "NeRF on-the-go: Exploiting Uncertainty for Distractor-free NeRFs in the Wild" (CVPR 2024). For more details, please visit our [project webpage](https://rwn17.github.io/nerf-on-the-go/).
64 |
65 | This Repo is built upon [Multinerf](https://github.com/google-research/multinerf) codebase.
66 |
67 | ## Setup
68 |
69 | ```
70 | # Clone the repo.
71 | git clone https://github.com/cvg/nerf-on-the-go
72 | cd nerf-on-the-go
73 |
74 | # Make a conda environment.
75 | conda create --name on-the-go python=3.9
76 | conda activate on-the-go
77 |
78 | # Prepare pip.
79 | conda install pip
80 | pip install --upgrade pip
81 |
82 |
83 | # Install requirements.
84 | pip install -r requirements.txt
85 |
86 | # Manually install rmbrualla's `pycolmap` (don't use pip's! It's different).
87 | git clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap
88 |
89 | # Confirm that all the unit tests pass.
90 | ./scripts/run_all_unit_tests.sh
91 | ```
92 | You'll also need to update your [JAX](https://jax.readthedocs.io/en/latest/installation.html) installation to support GPUs or TPUs.
93 |
94 | ```
95 | pip install -U "jax[cuda12]"
96 | ```
97 |
98 | ### Instructions for ETH Euler
99 |
100 | Click to expand
101 |
102 | on ETH Euler, to support for GPU jax, you need to apply for a debug mode gpu and then upgrade the gcc and cuda
103 | ```
104 | srun -n 4 --mem-per-cpu=12000 --gpus=rtx_3090:1 --gres=gpumem:20g --time=4:00:00 --pty bash
105 | conda activate on-the-go
106 | module load eth_proxy gcc/8.2.0 cuda/12.1.1 cudnn/8.9.2.26
107 | ```
108 |
109 | After loading the modules, verify their activation by executing ```module list```. Occasionally, modules may not load correctly, requiring you to load each one individually. Following this, proceed with the Jax installation:
110 |
111 | ```
112 | # Installs the wheel compatible with CUDA 12 and cuDNN 8.9 or newer.
113 | pip install jax==0.4.26 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
114 | pip install jaxlib==0.4.26+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
115 | ```
116 |
117 | After successful installation, please rerun ```./scripts/run_all_unit_tests.sh```.
118 |
119 | The installation process outlined above has been verified on the Euler system using an RTX 3090. You may get a warning
120 | ```
121 | The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
122 | ```
123 |
124 | But it's fine. The Euler supports up to CUDA 12.1, while JAX now requires a minimum of CUDA 12.3. As discussed in the JAX [Issue #18032](https://github.com/google/jax/issues/18032), this discrepancy primarily impacts compilation speed rather than overall functionality.
125 |
126 |
127 |
128 | ## Dataset Preparation
129 |
130 | ### Downloading the Dataset
131 |
132 | Before processing the data, please make sure you have installed mogrify by running
133 |
134 | ```
135 | conda install -y -c conda-forge imagemagick
136 | ```
137 |
138 | To download the "On-the-go" dataset, execute the following command:
139 | ```bash
140 | bash ./scripts/download_on-the-go.sh
141 | ```
142 | This script not only downloads the dataset but also downsamples the images as required. **NOTE: Please double check whether the data has been correctly DOWNSAMPLED!**
143 |
144 | ### Feature Extraction with DINOv2
145 | For extracting features using the DINOv2, use the command below:
146 | ```bash
147 | bash ./scripts/feature_extract.sh
148 | ```
149 |
150 |
151 | After feature extraction, the dataset should be organized as
152 | ```
153 | on-the-go
154 | ├── arcdetriomphe
155 | │ ├── images
156 | │ ├── images_{DOWNSAMPLE_RATE}
157 | │ ├── features_{DOWNSAMPLE_RATE}
158 | │ ├── split.json
159 | │ ├── transforms.json
160 | ├── ....
161 | │
162 | └── tree
163 | ├── images_{DOWNSAMPLE_RATE}
164 | ├── ....
165 | └── transforms.json
166 | ```
167 |
168 | ### Dataset Structure and Configuration Files
169 | - **split.json**: This file outlines the train and evaluation splits, following the naming conventions used in the RobustNeRF dataset, categorized as 'clutter' and 'clean'.
170 | - **transforms.json**: Contains pose and intrinsic information, formatted according to the Blender dataset format, derived from COLMAP files. Refer to the [Instant-NGP script](https://github.com/NVlabs/instant-ngp/blob/de507662d4b3398163e426fd426d48ff8f2895f6/scripts/colmap2nerf.py) for more details.
171 |
172 | ### Future Updates
173 | We plan to expand support to include custom datasets in future updates.
174 |
175 |
176 | ## Running
177 |
178 | Example scripts for training, evaluating, and rendering can be found in
179 | `scripts/`. You'll need to change the paths to point to wherever the datasets
180 | are located. [Gin](https://github.com/google/gin-config) configuration files
181 | for our model and some ablations can be found in `configs/`.
182 |
183 | 1. Training on-the-go:
184 | ```
185 | bash scripts/train_on-the-go.sh
186 | ```
187 |
188 | 2. Evaluating on-the-go:
189 | ```
190 | bash scripts/eval_on-the-go.sh
191 | ```
192 |
193 | 3. Rendering on-the-go:
194 | ```
195 | bash scripts/render_on-the-go.sh
196 | ```
197 |
198 | Tensorboard is supported for logging.
199 |
200 | ### Note
201 | Since we use a different recording device for the ***arc de triomphe*** and ***patio*** scenes, the image downsample rate (4 instead of 8) and feature downsample rate (2 instead of 4) are different. Please use a separate script to train them:
202 |
203 | ```
204 | bash scripts/train_on-the-go_HD.sh
205 | ```
206 |
207 | ### OOM errors
208 |
209 | About **80G of GPU memory** is needed to run the current version. You may need to reduce the batch size (`Config.batch_size`) to avoid out of memory
210 | errors. If you do this, but want to preserve quality, be sure to increase the number
211 | of training iterations and decrease the learning rate by whatever scale factor you
212 | decrease batch size by.
213 |
214 | ## Checkpoint
215 | We release the checkpoint for the quantitative scenes [here](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/weining_connect_hku_hk/ER2Esfrn0plAjCa2I6G2BJ4B56qX4B5whdMgk5T90A_F-A?e=3gklLp).
216 |
217 | Scene | Mountain | Fountain | Corner | Patio | Spot | Patio-High
218 | -- | -- | -- | -- | -- | -- | --
219 | onthego (paper) | 20.15 | 20.11 | 24.22 | 20.78 | 23.33 | 21.41
220 | onthego (released ckpt) | 20.89 | 19.88 | 24.69 | 22.30 | 24.67 | 22.30
221 |
222 |
223 | ## Citation
224 |
225 | If you use NeRF on-the-go, please cite
226 |
227 | ```
228 | @InProceedings{Ren2024NeRF,
229 | title={NeRF on-the-go: Exploiting Uncertainty for Distractor-free NeRFs in the Wild},
230 | author={Ren, Weining and Zhu, Zihan and Sun, Boyang and Chen, Jiaqi and Pollefeys, Marc and Peng, Songyou},
231 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
232 | year={2024}
233 | }
234 | ```
235 |
236 | Also, this code is built upon multinerf, feel free to cite this entire codebase as:
237 |
238 | ```
239 | @misc{multinerf2022,
240 | title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}},
241 | author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron},
242 | year={2022},
243 | url={https://github.com/google-research/multinerf},
244 | }
245 | ```
246 |
247 | ## Contact
248 | If there is any problem, please contact Weining by weining@connect.hku.hk
249 |
--------------------------------------------------------------------------------
/configs/360_dino.gin:
--------------------------------------------------------------------------------
1 | Config.dataset_loader = 'on-the-go'
2 | Config.near = 0.2
3 | Config.far = 1e6
4 | Config.factor = 8
5 |
6 | Config.patch_size = 32
7 | Config.data_loss_type = 'dino_ssim'
8 | Config.enable_robustnerf_loss = False
9 | Config.compute_feature_metrics = True
10 |
11 | Model.raydist_fn = @jnp.reciprocal
12 | Model.opaque_background = True
13 | Model.num_glo_features = 4
14 |
15 |
16 | PropMLP.warp_fn = @coord.contract
17 | PropMLP.net_depth = 4
18 | PropMLP.net_width = 256
19 | PropMLP.disable_density_normals = True
20 | PropMLP.disable_rgb = True
21 |
22 | NerfMLP.warp_fn = @coord.contract
23 | NerfMLP.net_depth = 8
24 | NerfMLP.net_width = 1024
25 | NerfMLP.disable_density_normals = True
26 |
--------------------------------------------------------------------------------
/configs/360_glo4.gin:
--------------------------------------------------------------------------------
1 | Config.dataset_loader = 'on-the-go'
2 | Config.near = 0.2
3 | Config.far = 1e6
4 | Config.factor = 4
5 |
6 | Model.raydist_fn = @jnp.reciprocal
7 | Model.num_glo_features = 4
8 | Model.opaque_background = True
9 |
10 | PropMLP.warp_fn = @coord.contract
11 | PropMLP.net_depth = 4
12 | PropMLP.net_width = 256
13 | PropMLP.disable_density_normals = True
14 | PropMLP.disable_rgb = True
15 |
16 | NerfMLP.warp_fn = @coord.contract
17 | NerfMLP.net_depth = 8
18 | NerfMLP.net_width = 1024
19 | NerfMLP.disable_density_normals = True
20 |
--------------------------------------------------------------------------------
/configs/360_robustnerf.gin:
--------------------------------------------------------------------------------
1 | Config.dataset_loader = 'on-the-go'
2 | Config.near = 0.2
3 | Config.far = 1e6
4 | Config.factor = 8
5 |
6 | Config.patch_size = 16
7 | Config.data_loss_type = 'robustnerf'
8 | Config.robustnerf_inlier_quantile = 0.8
9 | Config.enable_robustnerf_loss = True
10 | Config.compute_feature_metrics = True
11 | Model.num_glo_features = 4
12 |
13 |
14 | Model.raydist_fn = @jnp.reciprocal
15 | Model.opaque_background = True
16 |
17 | PropMLP.warp_fn = @coord.contract
18 | PropMLP.net_depth = 4
19 | PropMLP.net_width = 256
20 | PropMLP.disable_density_normals = True
21 | PropMLP.disable_rgb = True
22 |
23 | NerfMLP.warp_fn = @coord.contract
24 | NerfMLP.net_depth = 8
25 | NerfMLP.net_width = 1024
26 | NerfMLP.disable_density_normals = True
27 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evaluation script."""
16 |
17 | import functools
18 | from os import path
19 | import sys
20 | import time
21 |
22 | from absl import app
23 | from flax.metrics import tensorboard
24 | from flax.training import checkpoints
25 | import gin
26 | from internal import configs
27 | from internal import datasets
28 | from internal import image
29 | from internal import models
30 | from internal import raw_utils
31 | from internal import ref_utils
32 | from internal import train_utils
33 | from internal import utils
34 | from internal import vis
35 | import jax
36 | from jax import random
37 | import jax.numpy as jnp
38 | import numpy as np
39 | from matplotlib import cm
40 | from internal.vis import visualize_cmap
41 |
42 |
43 | configs.define_common_flags()
44 | jax.config.parse_flags_with_absl()
45 |
46 |
47 | def main(unused_argv):
48 | config = configs.load_config(save_config=False)
49 |
50 | dataset = datasets.load_dataset('test', config.data_dir, config)
51 |
52 | key = random.PRNGKey(20200823)
53 | _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key)
54 |
55 | if config.rawnerf_mode:
56 | postprocess_fn = dataset.metadata['postprocess_fn']
57 | else:
58 | postprocess_fn = lambda z: z
59 |
60 | if config.eval_raw_affine_cc:
61 | cc_fun = raw_utils.match_images_affine
62 | else:
63 | cc_fun = image.color_correct
64 |
65 | metric_harness = image.MetricHarnessLPIPS()
66 |
67 | last_step = 0
68 | dir_name = 'train_preds' if config.eval_train else 'test_preds'
69 | out_dir = path.join(config.checkpoint_dir,
70 | 'path_renders' if config.render_path else dir_name)
71 | path_fn = lambda x: path.join(out_dir, x)
72 |
73 | if not config.eval_only_once:
74 | summary_writer = tensorboard.SummaryWriter(
75 | path.join(config.checkpoint_dir, 'eval'))
76 | while True:
77 | state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
78 | step = int(state.step)
79 | if step <= last_step:
80 | print(f'Checkpoint step {step} <= last step {last_step}, sleeping.')
81 | time.sleep(10)
82 | continue
83 | print(f'Evaluating checkpoint at step {step}.')
84 | if config.eval_save_output and (not utils.isdir(out_dir)):
85 | utils.makedirs(out_dir)
86 |
87 | num_eval = min(dataset.size, config.eval_dataset_limit)
88 | key = random.PRNGKey(0 if config.deterministic_showcase else step)
89 | perm = random.permutation(key, num_eval)
90 | showcase_indices = np.sort(perm[:config.num_showcase_images])
91 |
92 | metrics = []
93 | metrics_cc = []
94 | showcases = []
95 | render_times = []
96 | for idx in range(dataset.size):
97 | eval_start_time = time.time()
98 | batch = next(dataset)
99 | if idx >= num_eval:
100 | print(f'Skipping image {idx+1}/{dataset.size}')
101 | continue
102 | print(f'Evaluating image {idx+1}/{dataset.size}')
103 | rays = batch.rays
104 | train_frac = state.step / config.max_steps
105 | rendering = models.render_image(
106 | functools.partial(
107 | render_eval_pfn,
108 | state.params,
109 | train_frac,
110 | ),
111 | rays,
112 | None,
113 | config,
114 | )
115 |
116 | if jax.host_id() != 0: # Only record via host 0.
117 | continue
118 |
119 | render_times.append((time.time() - eval_start_time))
120 | print(f'Rendered in {render_times[-1]:0.3f}s')
121 |
122 | # Cast to 64-bit to ensure high precision for color correction function.
123 | gt_rgb = np.array(batch.rgb, dtype=np.float64)
124 | rendering['rgb'] = np.array(rendering['rgb'], dtype=np.float64)
125 |
126 | cc_start_time = time.time()
127 | rendering['rgb_cc'] = cc_fun(rendering['rgb'], gt_rgb)
128 | # rendering['rgb_cc'] = rendering['rgb']
129 | print(f'Color corrected in {(time.time() - cc_start_time):0.3f}s')
130 |
131 | if not config.eval_only_once and idx in showcase_indices:
132 | showcase_idx = idx if config.deterministic_showcase else len(showcases)
133 | showcases.append((showcase_idx, rendering, batch))
134 | if not config.render_path:
135 | rgb = postprocess_fn(rendering['rgb'])
136 | rgb_cc = postprocess_fn(rendering['rgb_cc'])
137 | rgb_gt = postprocess_fn(gt_rgb)
138 |
139 | if config.eval_quantize_metrics:
140 | # Ensures that the images written to disk reproduce the metrics.
141 | rgb = np.round(rgb * 255) / 255
142 | rgb_cc = np.round(rgb_cc * 255) / 255
143 |
144 | if config.eval_crop_borders > 0:
145 | crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]
146 | rgb = crop_fn(rgb)
147 | rgb_cc = crop_fn(rgb_cc)
148 | rgb_gt = crop_fn(rgb_gt)
149 | metric = metric_harness(rgb.astype(np.float32), rgb_gt.astype(np.float32))
150 | metric_cc = metric_harness(rgb_cc.astype(np.float32), rgb_gt.astype(np.float32))
151 |
152 | for m, v in metric.items():
153 | print(f'{m:30s} = {v:.4f}')
154 |
155 | metrics.append(metric)
156 | metrics_cc.append(metric_cc)
157 |
158 | if config.eval_save_output and (config.eval_render_interval > 0):
159 | if (idx % config.eval_render_interval) == 0:
160 | utils.save_img_u8(postprocess_fn(rendering['rgb']),
161 | path_fn(f'color_{idx:03d}.png'))
162 | utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),
163 | path_fn(f'color_{idx:03d}_cc.png'))
164 | utils.save_img_u8(rgb_gt,
165 | path_fn(f'gt_color_{idx:03d}.png'))
166 | utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),
167 | path_fn(f'color_cc_{idx:03d}.png'))
168 |
169 | for key in ['distance_mean', 'distance_median']:
170 | if key in rendering:
171 | utils.save_img_f32(rendering[key],
172 | path_fn(f'{key}_{idx:03d}.tiff'))
173 |
174 | for key in ['normals']:
175 | if key in rendering:
176 | utils.save_img_u8(rendering[key] / 2. + 0.5,
177 | path_fn(f'{key}_{idx:03d}.png'))
178 |
179 | vis_uncertainty = visualize_cmap(
180 | rendering['uncer'][...,0],
181 | rendering['acc'],
182 | cm.get_cmap('turbo'),
183 | lo=0.2,
184 | hi=2,
185 | )
186 | utils.save_img_u8(postprocess_fn(vis_uncertainty), path_fn(f'uncer_{idx:03d}.png'))
187 | utils.save_img_f32(rendering['uncer'][...,0], path_fn(f'uncer_raw_{idx:03d}.tiff'))
188 |
189 | if (not config.eval_only_once) and (jax.host_id() == 0):
190 | summary_writer.scalar('eval_median_render_time', np.median(render_times),
191 | step)
192 | for name in metrics[0]:
193 | scores = [m[name] for m in metrics]
194 | summary_writer.scalar('eval_metrics/' + name, np.mean(scores), step)
195 | summary_writer.histogram('eval_metrics/' + 'perimage_' + name, scores,
196 | step)
197 | for name in metrics_cc[0]:
198 | scores = [m[name] for m in metrics_cc]
199 | summary_writer.scalar('eval_metrics_cc/' + name, np.mean(scores), step)
200 | summary_writer.histogram('eval_metrics_cc/' + 'perimage_' + name,
201 | scores, step)
202 |
203 | for i, r, b in showcases:
204 | if config.vis_decimate > 1:
205 | d = config.vis_decimate
206 | decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
207 | else:
208 | decimate_fn = lambda x: x
209 | r = jax.tree_util.tree_map(decimate_fn, r)
210 | b = jax.tree_util.tree_map(decimate_fn, b)
211 | visualizations = vis.visualize_suite(r, b.rays)
212 | for k, v in visualizations.items():
213 | if k == 'color':
214 | v = postprocess_fn(v)
215 | summary_writer.image(f'output_{k}_{i}', v, step)
216 | if not config.render_path:
217 | target = postprocess_fn(b.rgb)
218 | summary_writer.image(f'true_color_{i}', target, step)
219 | pred = postprocess_fn(visualizations['color'])
220 | residual = np.clip(pred - target + 0.5, 0, 1)
221 | summary_writer.image(f'true_residual_{i}', residual, step)
222 | summary_writer.image(f'uncertainty_{i}', visualizations['uncertainty'],
223 | step)
224 |
225 | if (config.eval_save_output and (not config.render_path) and
226 | (jax.host_id() == 0)):
227 | with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
228 | f.write(' '.join([str(r) for r in render_times]))
229 | for name in metrics[0]:
230 | with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
231 | f.write(' '.join([str(m[name]) for m in metrics]))
232 | for name in metrics_cc[0]:
233 | with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f:
234 | f.write(' '.join([str(m[name]) for m in metrics_cc]))
235 | if config.eval_save_ray_data:
236 | for i, r, b in showcases:
237 | rays = {k: v for k, v in r.items() if 'ray_' in k}
238 | np.set_printoptions(threshold=sys.maxsize)
239 | with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:
240 | f.write(repr(rays))
241 | for name in metrics[0]:
242 | with utils.open_file(path_fn(f'metric_{name}_{step}_avg.txt'), 'w') as f:
243 | avg=np.mean(np.array([m[name] for m in metrics]))
244 | if 0 < avg < 1:
245 | f.write(f'{avg:.3f}'[1:])
246 | else:
247 | f.write(f'{avg:.2f}')
248 | for name in metrics_cc[0]:
249 | with utils.open_file(path_fn(f'metric_cc_{name}_{step}_avg.txt'), 'w') as f:
250 | avg=np.mean(np.array([m[name] for m in metrics_cc]))
251 | if 0 < avg < 1:
252 | f.write(f'{avg:.3f}'[1:])
253 | else:
254 | f.write(f'{avg:.2f}')
255 |
256 |
257 |
258 | # A hack that forces Jax to keep all TPUs alive until every TPU is finished.
259 | x = jnp.ones([jax.local_device_count()])
260 | x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
261 | print(x)
262 |
263 | if config.eval_only_once:
264 | break
265 | if config.early_exit_steps is not None:
266 | num_steps = config.early_exit_steps
267 | else:
268 | num_steps = config.max_steps
269 | if int(step) >= num_steps:
270 | break
271 | last_step = step
272 |
273 |
if __name__ == '__main__':
  # Enter the 'eval' gin scope so bindings qualified with 'eval/' apply to this
  # binary before absl parses flags and hands control to main().
  with gin.config_scope('eval'):
    app.run(main)
277 |
--------------------------------------------------------------------------------
/internal/configs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions for handling configurations."""
16 |
17 | import dataclasses
18 | from typing import Any, Callable, Optional, Tuple
19 |
20 | from absl import flags
21 | from flax.core import FrozenDict
22 | import gin
23 | from internal import utils
24 | import jax
25 | import jax.numpy as jnp
26 |
gin.add_config_file_search_path('experimental/users/barron/mipnerf360/')

# Functions and initializers that gin config files may reference by name,
# keyed by the module name they are registered under.
configurables = {
    'jnp': [jnp.reciprocal, jnp.log, jnp.log1p, jnp.exp, jnp.sqrt, jnp.square],
    'jax.nn': [jax.nn.relu, jax.nn.softplus, jax.nn.silu],
    'jax.nn.initializers.he_normal': [jax.nn.initializers.he_normal()],
    'jax.nn.initializers.he_uniform': [jax.nn.initializers.he_uniform()],
    'jax.nn.initializers.glorot_normal': [jax.nn.initializers.glorot_normal()],
    'jax.nn.initializers.glorot_uniform': [
        jax.nn.initializers.glorot_uniform()
    ],
}

# Register every configurable with gin. The loop variables are named so they
# do not shadow (and, after the loop, clobber) the `configurables` dict — the
# previous `for module, configurables in configurables.items()` rebound the
# module-level name to the last value list.
for _module, _fns in configurables.items():
  for _fn in _fns:
    gin.config.external_configurable(_fn, module=_module)
43 |
44 |
@gin.configurable()
@dataclasses.dataclass
class Config:
  """Configuration flags for everything.

  Fields are grouped by the binary that consumes them (shared, train.py,
  eval.py, render.py, raw datasets, DINO-feature settings). Every field can be
  overridden via gin bindings.
  """
  dataset_loader: str = 'on-the-go' # The type of dataset loader to use.
  batching: str = 'all_images' # Batch composition, [single_image, all_images].
  batch_size: int = 16384 # The number of rays/pixels in each batch.
  patch_size: int = 1 # Resolution of patches sampled for training batches.
  factor: int = 0 # The downsample factor of images, 0 for no downsampling.
  load_alphabetical: bool = True # Load images in COLMAP vs alphabetical
  # ordering (affects heldout test set).
  forward_facing: bool = False # Set to True for forward-facing captures.
  render_path: bool = False # If True, render a path.

  gc_every: int = 10000 # The number of steps between garbage collections.
  disable_multiscale_loss: bool = False # If True, disable multiscale loss.
  randomized: bool = True # Use randomized stratified sampling.
  near: float = 2. # Near plane distance.
  far: float = 6. # Far plane distance.
  checkpoint_dir: Optional[str] = None # Where to log checkpoints.
  render_dir: Optional[str] = None # Output rendering directory.
  data_dir: Optional[str] = None # Input data directory.
  render_chunk_size: int = 16384 # Chunk size for whole-image renderings.
  num_showcase_images: int = 5 # The number of test-set images to showcase.
  deterministic_showcase: bool = True # If True, showcase the same images.
  vis_num_rays: int = 16 # The number of rays to visualize.
  # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage.
  vis_decimate: int = 0


  # Only used by train.py:
  max_steps: int = 250000 # The number of optimization steps.
  early_exit_steps: Optional[int] = None # Early stopping, for debugging.
  checkpoint_every: int = 25000 # The number of steps to save a checkpoint.
  print_every: int = 100 # The number of steps between reports to tensorboard.
  train_render_every: int = 5000 # Steps between test set renders when training
  cast_rays_in_train_step: bool = False # If True, compute rays in train step.
  data_loss_type: str = 'dino_ssim' # Loss kind: 'mse', 'charb', or 'dino_ssim'.
  charb_padding: float = 0.001 # The padding used for Charbonnier loss.
  data_loss_mult: float = 0.5 # Mult for the finest data term in the loss.
  data_coarse_loss_mult: float = 0. # Multiplier for the coarser data terms.
  interlevel_loss_mult: float = 1.0 # Mult. for the loss on the proposal MLP.
  orientation_loss_mult: float = 0.0 # Multiplier on the orientation loss.
  orientation_coarse_loss_mult: float = 0.0 # Coarser orientation loss weights.
  # RobustNerf loss hyperparameters.
  robustnerf_inlier_quantile: float = 0.5
  enable_robustnerf_loss: bool = False
  robustnerf_inner_patch_size: int = 8
  robustnerf_smoothed_filter_size: int = 3
  robustnerf_smoothed_inlier_quantile: float = 0.5
  robustnerf_inner_patch_inlier_quantile: float = 0.5
  # What that loss is imposed on, options are 'normals' or 'normals_pred'.
  orientation_loss_target: str = 'normals_pred'
  predicted_normal_loss_mult: float = 0.0 # Mult. on the predicted normal loss.
  # Mult. on the coarser predicted normal loss.
  predicted_normal_coarse_loss_mult: float = 0.0
  weight_decay_mults: FrozenDict[str, Any] = FrozenDict({}) # Weight decays.
  # An example that regularizes the NeRF and the first layer of the prop MLP:
  # weight_decay_mults:FrozenDict[str, Any] = FrozenDict({
  # 'NerfMLP_0': 0.00001,
  # 'PropMLP_0/Dense_0': 0.001,
  # 'UncerMLP_0': 0.00001
  # })
  # weight_decay_mults:FrozenDict[str, Any] = FrozenDict({
  # 'UncerMLP_0': 0.00001
  # })
  # Any model parameter that isn't specified gets a mult of 0. See the
  # train_weight_l2_* parameters in TensorBoard to know what can be regularized.

  lr_init: float = 0.002 # The initial learning rate.
  lr_final: float = 0.00002 # The final learning rate.
  lr_delay_steps: int = 512 # The number of "warmup" learning steps.
  lr_delay_mult: float = 0.01 # How severe the "warmup" should be.
  adam_beta1: float = 0.9 # Adam's beta1 hyperparameter.
  adam_beta2: float = 0.999 # Adam's beta2 hyperparameter.
  adam_eps: float = 1e-6 # Adam's epsilon hyperparameter.
  grad_max_norm: float = 0.001 # Gradient clipping magnitude, disabled if == 0.
  grad_max_val: float = 0. # Gradient clipping value, disabled if == 0.
  distortion_loss_mult: float = 0.01 # Multiplier on the distortion loss.

  # Only used by eval.py:
  eval_only_once: bool = True # If True evaluate the model only once, ow loop.
  eval_save_output: bool = True # If True save predicted images to disk.
  eval_save_ray_data: bool = False # If True save individual ray traces.
  eval_render_interval: int = 1 # The interval between images saved to disk.
  eval_dataset_limit: int = jnp.iinfo(jnp.int32).max # Num test images to eval.
  eval_quantize_metrics: bool = True # If True, run metrics on 8-bit images.
  eval_crop_borders: int = 0 # Ignore c border pixels in eval (x[c:-c, c:-c]).

  # Only used by render.py
  is_render: bool = False # whether is rendering, for feature assignment bug
  render_video_fps: int = 60 # Framerate in frames-per-second.
  render_video_crf: int = 18 # Constant rate factor for ffmpeg video quality.
  render_path_frames: int = 120 # Number of frames in render path.
  z_variation: float = 0. # How much height variation in render path.
  z_phase: float = 0. # Phase offset for height variation in render path.
  render_dist_percentile: float = 0.5 # How much to trim from near/far planes.
  render_dist_curve_fn: Callable[..., Any] = jnp.log # How depth is curved.
  render_path_file: Optional[str] = None # Numpy render pose file to load.
  render_job_id: int = 0 # Render job id.
  render_num_jobs: int = 1 # Total number of render jobs.
  render_resolution: Optional[Tuple[int, int]] = None # Render resolution, as
  # (width, height).
  render_focal: Optional[float] = None # Render focal length.
  render_camtype: Optional[str] = None # 'perspective', 'fisheye', or 'pano'.
  render_spherical: bool = False # Render spherical 360 panoramas.
  render_save_async: bool = True # Save to CNS using a separate thread.

  render_spline_keyframes: Optional[str] = None # Text file containing names of
  # images to be used as spline
  # keyframes, OR directory
  # containing those images.
  render_spline_n_interp: int = 30 # Num. frames to interpolate per keyframe.
  render_spline_degree: int = 5 # Polynomial degree of B-spline interpolation.
  render_spline_smoothness: float = .03 # B-spline smoothing factor, 0 for
  # exact interpolation of keyframes.
  # Interpolate per-frame exposure value from spline keyframes.
  render_spline_interpolate_exposure: bool = False

  # Flags for raw datasets.
  rawnerf_mode: bool = False # Load raw images and train in raw color space.
  exposure_percentile: float = 97. # Image percentile to expose as white.
  num_border_pixels_to_mask: int = 0 # During training, discard N-pixel border
  # around each input image.
  apply_bayer_mask: bool = False # During training, apply Bayer mosaic mask.
  autoexpose_renders: bool = False # During rendering, autoexpose each image.
  # For raw test scenes, use affine raw-space color correction.
  eval_raw_affine_cc: bool = False

  # dino configs
  dino_var_mult: float = 0.1 # multiplier for the variance of the dino features as regularization
  compute_feature_metrics: bool = True # If True, compute feature metrics.
  feat_rate: int = 4 # Feature sampling rate w.r.t original image size.
  dilate: int = 8 # The dilate rate for the patch size
  eval_train: bool = True # evaluate test set or train set, for debug
  train_clean: bool = False # train on clean set or clutter set
  reg_mult: float = 0.5 # reg weight for uncertainty reg
  uncer_lr_rate: float = 1 # lr rate w.r.t nerf
  uncer_clip_min: float = 0.1 # minimum value to clip the uncertainty
  ssim_clip_max: float = 5 # maximum value to clip the ssim
  ssim_mult: float = 0.5 # multiplicative factor for ssim loss
  H: int = 3024 # height of the image in pixels
  W: int = 4032 # width of the image in pixels
  ssim_anneal: float= 0.8 # anneal rate for ssim
  stop_ssim_gradient: bool = True # whether to stop the gradient flow from ssim to reconstruction
  ssim_window_size: int = 5 # window size of ssim
  mask_type: str = 'masks' # use the mask of which folder
  feat_dim: int = 384 # feature dimension, 384 for dino_s/14
  feat_ds: int = 14 # feature downsample rate for dino, 14 for dino_s/14, combine together with feat_rate
194 |
195 |
def define_common_flags():
  """Define the command-line flags shared by train.py and eval.py."""
  # These two flags exist only to satisfy the GINXM launcher.
  for unused_name in ('mode', 'base_folder'):
    flags.DEFINE_string(unused_name, None, 'Required by GINXM, not used.')
  # Gin configuration: ad-hoc parameter overrides plus config files.
  flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
  flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')
202 |
203 |
def load_config(save_config=True):
  """Load the config, and optionally checkpoint it.

  Args:
    save_config: bool, if True, host 0 writes the fully-resolved gin config
      string to `<checkpoint_dir>/config.gin` for reproducibility.

  Returns:
    config: a `Config` dataclass populated from the files in
      `--gin_configs` and the overrides in `--gin_bindings`.
  """
  # Parse gin bindings first so the Config() constructed below picks them up.
  gin.parse_config_files_and_bindings(
      flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)
  config = Config()
  if save_config and jax.host_id() == 0:
    utils.makedirs(config.checkpoint_dir)
    with utils.open_file(config.checkpoint_dir + '/config.gin', 'w') as f:
      f.write(gin.config_str())
  return config
214 |
--------------------------------------------------------------------------------
/internal/coord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tools for manipulating coordinate spaces and distances along rays."""
15 |
16 | from internal import math
17 | import jax
18 | import jax.numpy as jnp
19 |
20 |
def contract(x):
  """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077)."""
  eps = jnp.finfo(jnp.float32).eps
  # Clamp the squared magnitude away from zero so the gradient at x == 0 is
  # finite.
  sq_mag = jnp.maximum(eps, jnp.sum(x**2, axis=-1, keepdims=True))
  contracted = x * ((2 * jnp.sqrt(sq_mag) - 1) / sq_mag)
  # Points inside the unit ball are left untouched.
  return jnp.where(sq_mag <= 1, x, contracted)
28 |
29 |
def inv_contract(z):
  """The inverse of contract()."""
  eps = jnp.finfo(jnp.float32).eps
  # Clamp the squared magnitude away from zero so the gradient at z == 0 is
  # finite.
  sq_mag = jnp.maximum(eps, jnp.sum(z**2, axis=-1, keepdims=True))
  unwarped = z / (2 * jnp.sqrt(sq_mag) - sq_mag)
  # Points inside the unit ball map to themselves.
  return jnp.where(sq_mag <= 1, z, unwarped)
37 |
38 |
def track_linearize(fn, mean, cov):
  """Apply function `fn` to a set of means and covariances, ala a Kalman filter.

  We can analytically transform a Gaussian parameterized by `mean` and `cov`
  with a function `fn` by linearizing `fn` around `mean`, and taking advantage
  of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
  https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).

  Args:
    fn: the function applied to the Gaussians parameterized by (mean, cov).
    mean: a tensor of means, where the last axis is the dimension.
    cov: a tensor of covariances, where the last two axes are the dimensions.

  Returns:
    fn_mean: the transformed means.
    fn_cov: the transformed covariances.
  """
  if (len(mean.shape) + 1) != len(cov.shape):
    raise ValueError('cov must be non-diagonal')
  # lin_fn is the Jacobian-vector product of fn at `mean`.
  fn_mean, lin_fn = jax.linearize(fn, mean)
  # Apply the Jacobian on both sides (J @ cov @ J^T) by vmapping the JVP over
  # the last covariance axis twice; each pass writes its output to axis -2, so
  # the result stays a (..., d, d) covariance.
  fn_cov = jax.vmap(lin_fn, -1, -2)(jax.vmap(lin_fn, -1, -2)(cov))
  return fn_mean, fn_cov
61 |
62 |
def construct_ray_warps(fn, t_near, t_far):
  """Construct a bijection between metric distances and normalized distances.

  See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
  detailed explanation.

  Args:
    fn: the function to ray distances.
    t_near: a tensor of near-plane distances.
    t_far: a tensor of far-plane distances.

  Returns:
    t_to_s: a function that maps distances to normalized distances in [0, 1].
    s_to_t: the inverse of t_to_s.
  """
  if fn is None:
    # No curvature: distances are normalized linearly.
    fn_fwd = lambda x: x
    fn_inv = lambda x: x
  elif fn == 'piecewise':
    # Piecewise spacing combining identity and 1/x functions to allow t_near=0.
    fn_fwd = lambda x: jnp.where(x < 1, .5 * x, 1 - .5 / x)
    fn_inv = lambda x: jnp.where(x < .5, 2 * x, .5 / (1 - x))
  else:
    # Look up the analytic inverse of the supplied jnp function by its name.
    inv_mapping = {
        'reciprocal': jnp.reciprocal,
        'log': jnp.exp,
        'exp': jnp.log,
        'sqrt': jnp.square,
        'square': jnp.sqrt
    }
    fn_fwd = fn
    fn_inv = inv_mapping[fn.__name__]

  s_near = fn_fwd(t_near)
  s_far = fn_fwd(t_far)

  def t_to_s(t):
    # Normalize curved distance to [0, 1] between the curved near/far planes.
    return (fn_fwd(t) - s_near) / (s_far - s_near)

  def s_to_t(s):
    # Invert: lerp in curved space, then map back to metric distance.
    return fn_inv(s * s_far + (1 - s) * s_near)

  return t_to_s, s_to_t
100 |
101 |
def expected_sin(mean, var):
  """Compute the mean of sin(x), x ~ N(mean, var)."""
  # E[sin(x)] = exp(-var/2) * sin(mean): a large variance damps toward zero.
  damping = jnp.exp(-0.5 * var)
  return damping * math.safe_sin(mean)
105 |
106 |
def integrated_pos_enc(mean, var, min_deg, max_deg):
  """Encode `x` with sinusoids scaled by 2^[min_deg, max_deg).

  Args:
    mean: tensor, the mean coordinates to be encoded
    var: tensor, the variance of the coordinates to be encoded.
    min_deg: int, the min degree of the encoding.
    max_deg: int, the max degree of the encoding.

  Returns:
    encoded: jnp.ndarray, encoded variables.
  """
  scales = 2**jnp.arange(min_deg, max_deg)
  flat_shape = mean.shape[:-1] + (-1,)
  scaled_mean = jnp.reshape(mean[..., None, :] * scales[:, None], flat_shape)
  # Variances scale with the square of each frequency.
  scaled_var = jnp.reshape(var[..., None, :] * scales[:, None]**2, flat_shape)

  # Evaluate E[sin(.)] at the scaled means and at a quarter-period shift
  # (which yields the cosine terms), duplicating the variances to match.
  return expected_sin(
      jnp.concatenate([scaled_mean, scaled_mean + 0.5 * jnp.pi], axis=-1),
      jnp.concatenate([scaled_var, scaled_var], axis=-1))
127 |
128 |
def lift_and_diagonalize(mean, cov, basis):
  """Project `mean` and `cov` onto basis and diagonalize the projected cov."""
  proj_mean = math.matmul(mean, basis)
  # diag(B^T cov B), computed without forming the full projected covariance.
  proj_cov_diag = jnp.sum(basis * math.matmul(cov, basis), axis=-2)
  return proj_mean, proj_cov_diag
134 |
135 |
def pos_enc(x, min_deg, max_deg, append_identity=True):
  """The positional encoding used by the original NeRF paper."""
  scales = 2**jnp.arange(min_deg, max_deg)
  flat_shape = x.shape[:-1] + (-1,)
  scaled_x = jnp.reshape(x[..., None, :] * scales[:, None], flat_shape)
  # sin at the scaled inputs and at a quarter-period shift (i.e. cos).
  # Note that we're not using safe_sin, unlike IPE.
  features = jnp.sin(
      jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))
  if append_identity:
    features = jnp.concatenate([x, features], axis=-1)
  return features
148 |
--------------------------------------------------------------------------------
/internal/geopoly.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools for constructing geodesic polyhedron, which are used as a basis."""
16 |
17 | import itertools
18 | import numpy as np
19 |
20 |
def compute_sq_dist(mat0, mat1=None):
  """Compute the squared Euclidean distance between all pairs of columns."""
  if mat1 is None:
    mat1 = mat0
  # ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y, which avoids materializing a
  # pairwise-difference tensor.
  norms0 = np.sum(mat0**2, 0)
  norms1 = np.sum(mat1**2, 0)
  sq_dist = norms0[:, None] + norms1[None, :] - 2 * mat0.T @ mat1
  # Clamp tiny negatives caused by floating-point cancellation.
  return np.maximum(0, sq_dist)
31 |
32 |
def compute_tesselation_weights(v):
  """Tesselate the vertices of a triangle by a factor of `v`."""
  if v < 1:
    raise ValueError(f'v {v} must be >= 1')
  # Enumerate every integer triple (i, j, k) with i + j + k == v, then divide
  # by v to obtain barycentric weights.
  grid = [(i, j, v - (i + j))
          for i in range(v + 1)
          for j in range(v + 1 - i)]
  return np.array(grid) / v
44 |
45 |
def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):
  """Tesselate the vertices of a geodesic polyhedron.

  Args:
    base_verts: tensor of floats, the vertex coordinates of the geodesic.
    base_faces: tensor of ints, the indices of the vertices of base_verts that
      constitute each face of the polyhedra.
    v: int, the factor of the tesselation (v==1 is a no-op).
    eps: float, a small value used to determine if two vertices are the same.

  Returns:
    verts: a tensor of floats, the coordinates of the tesselated vertices.

  Raises:
    ValueError: if `v` is not an integer.
  """
  if not isinstance(v, int):
    # Fixed grammar of the error message ('must an integer').
    raise ValueError(f'v {v} must be an integer')
  tri_weights = compute_tesselation_weights(v)

  # Tesselate each face independently and project the new vertices onto the
  # unit sphere.
  verts = []
  for base_face in base_faces:
    new_verts = np.matmul(tri_weights, base_verts[base_face, :])
    new_verts /= np.sqrt(np.sum(new_verts**2, 1, keepdims=True))
    verts.append(new_verts)
  verts = np.concatenate(verts, 0)

  # Adjacent faces duplicate vertices along shared edges; keep only the first
  # occurrence of each vertex (any pair within eps squared distance is
  # considered the same vertex).
  sq_dist = compute_sq_dist(verts.T)
  assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])
  unique = np.unique(assignment)
  verts = verts[unique, :]

  return verts
76 |
77 |
def generate_basis(base_shape,
                   angular_tesselation,
                   remove_symmetries=True,
                   eps=1e-4):
  """Generates a 3D basis by tesselating a geometric polyhedron.

  Args:
    base_shape: string, the name of the starting polyhedron, must be either
      'icosahedron' or 'octahedron'.
    angular_tesselation: int, the number of times to tesselate the polyhedron,
      must be >= 1 (a value of 1 is a no-op to the polyhedron).
    remove_symmetries: bool, if True then remove the symmetric basis columns,
      which is usually a good idea because otherwise projections onto the basis
      will have redundant negative copies of each other.
    eps: float, a small number used to determine symmetries.

  Returns:
    basis: a matrix with shape [3, n].

  Raises:
    ValueError: if `base_shape` is not one of the supported polyhedron names.
  """
  if base_shape == 'icosahedron':
    # Golden-ratio construction of the 12 icosahedron vertices, scaled so each
    # vertex lies on the unit sphere.
    a = (np.sqrt(5) + 1) / 2
    verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1),
                      (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0),
                      (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2)
    faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1),
                      (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3),
                      (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6),
                      (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5),
                      (7, 2, 11)])
    verts = tesselate_geodesic(verts, faces, angular_tesselation)
  elif base_shape == 'octahedron':
    verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0),
                      (1, 0, 0)])
    # Recover the 8 triangular faces: each face's three vertices are exactly
    # at squared distance 2 from one corner of the enclosing cube.
    corners = np.array(list(itertools.product([-1, 1], repeat=3)))
    pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)
    faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)
    verts = tesselate_geodesic(verts, faces, angular_tesselation)
  else:
    raise ValueError(f'base_shape {base_shape} not supported')

  if remove_symmetries:
    # Remove elements of `verts` that are reflections of each other.
    match = compute_sq_dist(verts.T, -verts.T) < eps
    verts = verts[np.any(np.triu(match), 1), :]

  basis = verts[:, ::-1]
  return basis
125 |
--------------------------------------------------------------------------------
/internal/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for processing images."""
16 |
17 | import types
18 | from typing import Optional, Union
19 |
20 | import dm_pix
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy as np
24 | import lpips_jax
25 |
26 | _Array = Union[np.ndarray, jnp.ndarray]
27 |
28 |
def mse_to_psnr(mse):
  """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
  # PSNR = -10 * log10(mse), expressed via natural logarithms.
  ln10 = jnp.log(10.)
  return -10. * jnp.log(mse) / ln10
32 |
33 |
def psnr_to_mse(psnr):
  """Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
  # Inverse of mse_to_psnr: mse = 10 ** (-psnr / 10).
  ln10 = jnp.log(10.)
  return jnp.exp(-0.1 * ln10 * psnr)
37 |
38 |
def ssim_to_dssim(ssim):
  """Compute DSSIM given an SSIM."""
  # DSSIM maps SSIM's [-1, 1] range onto [0, 1] (1 -> 0, -1 -> 1).
  return 0.5 * (1 - ssim)
42 |
43 |
def dssim_to_ssim(dssim):
  """Compute SSIM given a DSSIM (the inverse of ssim_to_dssim).

  The previous docstring was a copy-paste of ssim_to_dssim's and described the
  opposite conversion.
  """
  return 1 - 2 * dssim
47 |
48 |
def linear_to_srgb(linear: _Array,
                   eps: Optional[float] = None,
                   xnp: types.ModuleType = jnp) -> _Array:
  """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
  if eps is None:
    eps = xnp.finfo(xnp.float32).eps
  # Linear segment used for very dark values.
  low = 323 / 25 * linear
  # Gamma-compressed segment; the max() keeps the fractional power away from
  # zero/negative bases.
  high = (211 * xnp.maximum(eps, linear)**(5 / 12) - 11) / 200
  return xnp.where(linear <= 0.0031308, low, high)
58 |
59 |
def srgb_to_linear(srgb: _Array,
                   eps: Optional[float] = None,
                   xnp: types.ModuleType = jnp) -> _Array:
  """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
  if eps is None:
    eps = xnp.finfo(xnp.float32).eps
  # Linear segment used for very dark values.
  low = 25 / 323 * srgb
  # Inverse gamma segment; the max() keeps the fractional power away from
  # zero/negative bases.
  high = xnp.maximum(eps, (200 * srgb + 11) / 211)**(12 / 5)
  return xnp.where(srgb <= 0.04045, low, high)
69 |
70 |
def downsample(img, factor):
  """Area downsample img (factor must evenly divide img height and width)."""
  sh = img.shape
  if sh[0] % factor != 0 or sh[1] % factor != 0:
    raise ValueError(f'Downsampling factor {factor} does not '
                     f'evenly divide image shape {sh[:2]}')
  # Split each spatial axis into (blocks, factor) and average over the two
  # factor axes: a box filter followed by subsampling.
  blocked = img.reshape(
      (sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:])
  return blocked.mean((1, 3))
80 |
81 |
def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
  """Warp `img` to match the colors in `ref_img`.

  Fits, per channel, a quadratic color transform (quadratic cross-terms,
  linear terms, and a bias) from `img` to `ref` by iteratively re-solving a
  masked least-squares problem that ignores pixels clipped near 0 or 1.

  Args:
    img: image to be color-corrected; channels are in the last axis.
    ref: reference image; must have the same channel count as `img`.
    num_iters: int, number of mask/solve iterations.
    eps: float, values outside [eps, 1 - eps] are treated as saturated.

  Returns:
    A color-corrected copy of `img`, clipped to [0, 1].

  Raises:
    ValueError: if `img` and `ref` have different channel counts.
  """
  if img.shape[-1] != ref.shape[-1]:
    raise ValueError(
        f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
    )
  num_channels = img.shape[-1]
  img_mat = img.reshape([-1, num_channels])
  ref_mat = ref.reshape([-1, num_channels])
  is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps)) # z \in [eps, 1-eps].
  mask0 = is_unclipped(img_mat)
  # Because the set of saturated pixels may change after solving for a
  # transformation, we repeatedly solve a system `num_iters` times and update
  # our estimate of which pixels are saturated.
  for _ in range(num_iters):
    # Construct the left hand side of a linear system that contains a quadratic
    # expansion of each pixel of `img`.
    a_mat = []
    for c in range(num_channels):
      a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:]) # Quadratic term.
    a_mat.append(img_mat) # Linear term.
    a_mat.append(jnp.ones_like(img_mat[:, :1])) # Bias term.
    a_mat = jnp.concatenate(a_mat, axis=-1)
    warp = []
    for c in range(num_channels):
      # Construct the right hand side of a linear system containing each color
      # of `ref`.
      b = ref_mat[:, c]
      # Ignore rows of the linear system that were saturated in the input or are
      # saturated in the current corrected color estimate.
      mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
      ma_mat = jnp.where(mask[:, None], a_mat, 0)
      mb = jnp.where(mask, b, 0)
      # Solve the linear system. We're using the np.lstsq instead of jnp because
      # it's significantly more stable in this case, for some reason.
      w = np.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
      assert jnp.all(jnp.isfinite(w))
      warp.append(w)
    warp = jnp.stack(warp, axis=-1)
    # Apply the warp to update img_mat.
    img_mat = jnp.clip(
        jnp.matmul(a_mat, warp, precision=jax.lax.Precision.HIGHEST), 0, 1)
  corrected_img = jnp.reshape(img_mat, img.shape)
  return corrected_img
126 |
127 |
class MetricHarness:
  """A helper class for evaluating several error metrics."""

  def __init__(self):
    # JIT-compile SSIM once so repeated __call__s reuse the compiled function.
    self.ssim_fn = jax.jit(dm_pix.ssim)

  def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
    """Evaluate the error between a predicted rgb image and the true image."""
    mse = ((rgb_pred - rgb_gt)**2).mean()
    return {
        name_fn('psnr'): float(mse_to_psnr(mse)),
        name_fn('ssim'): float(self.ssim_fn(rgb_pred, rgb_gt)),
    }
143 |
144 |
class MetricHarnessLPIPS:
  """A helper class for evaluating several error metrics with vgg16 lpips."""

  def __init__(self):
    # JIT-compile SSIM once so repeated __call__s reuse the compiled function.
    self.ssim_fn = jax.jit(dm_pix.ssim)
    self.lpips_fn = lpips_jax.LPIPSEvaluator(replicate=False, net='vgg16') # ['alexnet', 'vgg16']

  def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
    """Evaluate the error between a predicted rgb image and the true image."""
    psnr = float(mse_to_psnr(((rgb_pred - rgb_gt)**2).mean()))
    ssim = float(self.ssim_fn(rgb_pred, rgb_gt))
    # TODO: fix the LPIPS calculation; it is currently broken and disabled,
    # so only psnr/ssim are reported despite the class name.
    # lpips = float(self.lpips_fn(rgb_pred[None,...]*2-1, rgb_gt[None,...]*2-1))

    return {
        name_fn('psnr'): psnr,
        name_fn('ssim'): ssim,
        # name_fn('lpips'): lpips,
    }
164 |
--------------------------------------------------------------------------------
/internal/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mathy utility functions."""
16 |
17 | import jax
18 | import jax.numpy as jnp
19 |
20 |
def matmul(a, b):
  """Matrix-multiply `a` by `b` at the highest available precision.

  jnp.matmul defaults to bfloat16 accumulation on some backends; this wrapper
  always requests full precision.
  """
  highest = jax.lax.Precision.HIGHEST
  return jnp.matmul(a, b, precision=highest)
24 |
25 |
def safe_trig_helper(x, fn, t=100 * jnp.pi):
  """Helper for safe_cos/safe_sin: reduces |x| >= t modulo t before fn()."""
  # Values already inside (-t, t) pass through unchanged; larger magnitudes
  # are wrapped into range so fn() stays numerically well-behaved.
  wrapped = jnp.where(jnp.abs(x) < t, x, x % t)
  return fn(wrapped)
29 |
30 |
def safe_cos(x):
  """Cosine that avoids the NaNs jnp.cos() can produce on TPU for large x."""
  return safe_trig_helper(x, jnp.cos)
34 |
35 |
def safe_sin(x):
  """Sine that avoids the NaNs jnp.sin() can produce on TPU for large x."""
  return safe_trig_helper(x, jnp.sin)
39 |
40 |
41 | @jax.custom_jvp
42 | def safe_exp(x):
43 | """jnp.exp() but with finite output and gradients for large inputs."""
44 | return jnp.exp(jnp.minimum(x, 88.)) # jnp.exp(89) is infinity.
45 |
46 |
@safe_exp.defjvp
def safe_exp_jvp(primals, tangents):
  """Custom JVP: the derivative of safe_exp is safe_exp itself.

  Without this override the clamped branch would have zero gradient for
  large inputs; here the tangent stays large (and finite) instead.
  """
  (x,), (x_dot,) = primals, tangents
  y = safe_exp(x)
  return y, y * x_dot
55 |
56 |
def log_lerp(t, v0, v1):
  """Interpolate from `v0` (t=0) to `v1` (t=1), linearly in log space."""
  if v0 <= 0 or v1 <= 0:
    raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
  log_v0 = jnp.log(v0)
  log_v1 = jnp.log(v1)
  # t is clamped so extrapolation beyond the endpoints is impossible.
  t_clipped = jnp.clip(t, 0, 1)
  return jnp.exp(t_clipped * (log_v1 - log_v0) + log_v0)
64 |
65 |
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  The returned rate is lr_init at step=0 and lr_final at step=max_steps,
  log-linearly interpolated in between (i.e. exponential decay). When
  lr_delay_steps > 0 the rate is additionally scaled by a smooth warmup
  factor: it starts at lr_init * lr_delay_mult and eases back to the
  undelayed schedule once step > lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning for current step 'step'.
  """
  if lr_delay_steps > 0:
    # Reverse-cosine-style warmup: the scale rises from lr_delay_mult to 1
    # along a quarter sine wave over the first lr_delay_steps steps.
    progress = jnp.clip(step / lr_delay_steps, 0, 1)
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin(
        0.5 * jnp.pi * progress)
  else:
    delay_rate = 1.
  return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
99 |
100 |
def interp(*args):
  """A gather-based (GPU-friendly) vectorized replacement for jnp.interp()."""
  # Flatten all leading dims so jnp.interp can be vmapped over 1D rows,
  # then restore the first argument's original shape.
  flat_args = [a.reshape([-1, a.shape[-1]]) for a in args]
  out_shape = args[0].shape
  return jax.vmap(jnp.interp)(*flat_args).reshape(out_shape)
106 |
107 |
def sorted_interp(x, xp, fp):
  """A TPU-friendly version of interp(), where xp and fp must be sorted.

  Args:
    x: [..., n] query positions.
    xp: [..., m] sorted sample positions.
    fp: [..., m] values at the sample positions.

  Returns:
    [..., n] interpolated values, clamped to the range of `fp` endpoints.
  """

  # For each x, mark which xp entries lie at or below it; the final True in
  # each column marks the start of the containing interval.
  mask = x[..., None, :] >= xp[..., :, None]

  def find_interval(vals):
    # Largest vals where mask is True (left neighbor), smallest where it is
    # False (right neighbor). Out-of-range queries clamp to the endpoints.
    v0 = jnp.max(jnp.where(mask, vals[..., None], vals[..., :1, None]), -2)
    v1 = jnp.min(jnp.where(~mask, vals[..., None], vals[..., -1:, None]), -2)
    return v0, v1

  fp0, fp1 = find_interval(fp)
  xp0, xp1 = find_interval(xp)

  # Guard against 0/0 when xp0 == xp1 (degenerate interval). BUGFIX: the old
  # call `jnp.nan_to_num(t, 0)` bound 0 to the `copy` parameter, not the NaN
  # fill value; it only worked because the default fill is already 0.0.
  offset = jnp.clip(jnp.nan_to_num((x - xp0) / (xp1 - xp0), nan=0.0), 0, 1)
  ret = fp0 + offset * (fp1 - fp0)
  return ret
128 |
--------------------------------------------------------------------------------
/internal/raw_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for processing and loading raw image data."""
16 |
17 | import glob
18 | import json
19 | import os
20 | import types
21 | from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
22 |
23 | from internal import image as lib_image
24 | from internal import math
25 | from internal import utils
26 | import jax
27 | import jax.numpy as jnp
28 | import numpy as np
29 | import rawpy
30 |
31 | _Array = Union[np.ndarray, jnp.ndarray]
32 | _Axis = Optional[Union[int, Tuple[int, ...]]]
33 |
34 |
def postprocess_raw(raw: _Array,
                    camtorgb: _Array,
                    exposure: Optional[float] = None,
                    xnp: types.ModuleType = np) -> _Array:
  """Converts demosaicked raw to sRGB with a minimal postprocessing pipeline.

  Numpy array inputs will be automatically converted to Jax arrays.

  Args:
    raw: [H, W, 3], demosaicked raw camera image.
    camtorgb: [3, 3], color correction transformation to apply to raw image.
    exposure: color value to be scaled to pure white after color correction.
      If None, "autoexposes" at the 97th percentile.
    xnp: either numpy or jax.numpy.

  Returns:
    srgb: [H, W, 3], color corrected + exposed + gamma mapped image.
  """
  if raw.shape[-1] != 3:
    raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3')
  if camtorgb.shape != (3, 3):
    raise ValueError(f'camtorgb.shape is {camtorgb.shape}, expected (3, 3)')

  # Color-correct from camera space into linear RGB, using the high-precision
  # matmul wrapper when running under jax.
  matmul = math.matmul if xnp == jnp else np.matmul
  rgb_linear = matmul(raw, camtorgb.T)

  # Autoexpose at the 97th percentile when no explicit exposure is given.
  if exposure is None:
    exposure = xnp.percentile(rgb_linear, 97)

  # Map the exposure level to white and clip everything brighter.
  rgb_linear_scaled = xnp.clip(rgb_linear / exposure, 0, 1)
  # sRGB gamma curve serves as a simple tonemap.
  return lib_image.linear_to_srgb(rgb_linear_scaled, xnp=xnp)
67 |
68 |
def pixels_to_bayer_mask(pix_x: np.ndarray, pix_y: np.ndarray) -> np.ndarray:
  """Computes binary RGB Bayer mask values from integer pixel coordinates."""
  x_odd = pix_x % 2 == 1
  y_odd = pix_y % 2 == 1
  # RGGB layout: red at (0, 0), greens at (0, 1) and (1, 0), blue at (1, 1),
  # indexed as (row, col) == (y, x).
  r = ~x_odd & ~y_odd
  g = (x_odd & ~y_odd) | (~x_odd & y_odd)
  b = x_odd & y_odd
  return np.stack([r, g, b], -1).astype(np.float32)
78 |
79 |
def bilinear_demosaic(bayer: _Array,
                      xnp: types.ModuleType) -> _Array:
  """Converts Bayer data into a full RGB image using bilinear demosaicking.

  Input data should be ndarray of shape [height, width] with 2x2 mosaic pattern:
  -------------
  |red |green|
  -------------
  |green|blue |
  -------------
  Red and blue channels are bilinearly upsampled 2x, missing green channel
  elements are the average of the neighboring 4 values in a cross pattern.

  Args:
    bayer: [H, W] array, Bayer mosaic pattern input image.
    xnp: either numpy or jax.numpy.

  Returns:
    rgb: [H, W, 3] array, full RGB image.
  """
  def reshape_quads(*planes):
    """Reshape pixels from four input images to make tiled 2x2 quads."""
    planes = xnp.stack(planes, -1)
    shape = planes.shape[:-1]
    # Create [2, 2] arrays out of 4 channels.
    zup = planes.reshape(shape + (2, 2,))
    # Transpose so that x-axis dimensions come before y-axis dimensions.
    zup = xnp.transpose(zup, (0, 2, 1, 3))
    # Reshape to 2D.
    zup = zup.reshape((shape[0] * 2, shape[1] * 2))
    return zup

  def bilinear_upsample(z):
    """2x bilinear image upsample."""
    # Using np.roll makes the right and bottom edges wrap around. The raw image
    # data has a few garbage columns/rows at the edges that must be discarded
    # anyway, so this does not matter in practice.
    # Horizontally interpolated values.
    zx = .5 * (z + xnp.roll(z, -1, axis=-1))
    # Vertically interpolated values.
    zy = .5 * (z + xnp.roll(z, -1, axis=-2))
    # Diagonally interpolated values.
    zxy = .5 * (zx + xnp.roll(zx, -1, axis=-2))
    # Tile original + three interpolants into 2x2 quads -> 2x resolution.
    return reshape_quads(z, zx, zy, zxy)

  def upsample_green(g1, g2):
    """Special 2x upsample from the two green channels."""
    z = xnp.zeros_like(g1)
    # Place g1/g2 at their Bayer positions; the two zeroed slots are the
    # unobserved (red/blue) sites to be filled in below.
    z = reshape_quads(z, g1, g2, z)
    alt = 0
    # Grab the 4 directly adjacent neighbors in a "cross" pattern.
    for i in range(4):
      axis = -1 - (i // 2)
      roll = -1 + 2 * (i % 2)
      alt = alt + .25 * xnp.roll(z, roll, axis=axis)
    # For observed pixels, alt = 0, and for unobserved pixels, alt = avg(cross),
    # so alt + z will have every pixel filled in.
    return alt + z

  # Subsample into the four Bayer planes: r=(0,0), g1=(0,1), g2=(1,0), b=(1,1).
  r, g1, g2, b = [bayer[(i//2)::2, (i%2)::2] for i in range(4)]
  r = bilinear_upsample(r)
  # Flip in x and y before and after calling upsample, as bilinear_upsample
  # assumes that the samples are at the top-left corner of the 2x2 sample.
  b = bilinear_upsample(b[::-1, ::-1])[::-1, ::-1]
  g = upsample_green(g1, g2)
  rgb = xnp.stack([r, g, b], -1)
  return rgb
147 |
148 |
149 | bilinear_demosaic_jax = jax.jit(lambda bayer: bilinear_demosaic(bayer, xnp=jnp))
150 |
151 |
def load_raw_images(image_dir: str,
                    image_names: Optional[Sequence[str]] = None
                    ) -> Tuple[np.ndarray, Sequence[Mapping[str, Any]]]:
  """Loads raw images and their metadata from disk.

  Args:
    image_dir: directory containing raw image and EXIF data.
    image_names: files to load (ignores file extension), loads all DNGs if None.

  Returns:
    A tuple (images, exifs).
    images: [N, height, width, 3] array of raw sensor data.
    exifs: [N] list of dicts, one per image, containing the EXIF data.
  Raises:
    ValueError: The requested `image_dir` does not exist on disk.
  """

  if not utils.file_exists(image_dir):
    raise ValueError(f'Raw image folder {image_dir} does not exist.')

  if image_names is None:
    # Default to every DNG file in the directory, in sorted order.
    dng_paths = sorted(glob.glob(os.path.join(image_dir, '*.dng')))
    image_names = [os.path.basename(p) for p in dng_paths]

  def load_one(name):
    """Load one raw (.dng) image and its sidecar EXIF (.json) dict."""
    stem = os.path.join(image_dir, os.path.splitext(name)[0])
    with utils.open_file(stem + '.dng', 'rb') as f:
      sensor = rawpy.imread(f).raw_image
    with utils.open_file(stem + '.json', 'rb') as f:
      exif = json.load(f)[0]
    return sensor, exif

  raws, exifs = zip(*(load_one(name) for name in image_names))
  raws = np.stack(raws, axis=0).astype(np.float32)

  return raws, exifs
192 |
193 |
# Brightness percentiles to use for re-exposing and tonemapping raw images.
_PERCENTILE_LIST = (80, 90, 97, 99, 100)

# Relevant fields to extract from raw image EXIF metadata.
# For details regarding EXIF parameters, see:
# https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf.
_EXIF_KEYS = (
    'BlackLevel', # Black level offset added to sensor measurements.
    'WhiteLevel', # Maximum possible sensor measurement.
    'AsShotNeutral', # RGB white balance coefficients.
    'ColorMatrix2', # XYZ to camera color space conversion matrix.
    'NoiseProfile', # Shot and read noise levels.
)

# Matrix converting linear sRGB to XYZ under the D65 reference illuminant,
# i.e. the RGB -> XYZ direction (it is composed with an XYZ -> camera matrix
# as `xyz2camwb @ _RGB2XYZ` to build an RGB -> camera transform).
# See http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html.
_RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],
                     [0.2126729, 0.7151522, 0.0721750],
                     [0.0193339, 0.1191920, 0.9503041]])
213 |
214 |
def process_exif(
    exifs: Sequence[Mapping[str, Any]]) -> MutableMapping[str, Any]:
  """Processes list of raw image EXIF data into useful metadata dict.

  Input should be a list of dictionaries loaded from JSON files.
  These JSON files are produced by running
    $ exiftool -json IMAGE.dng > IMAGE.json
  for each input raw file.

  We extract only the parameters relevant to
  1. Rescaling the raw data to [0, 1],
  2. White balance and color correction, and
  3. Noise level estimation.

  Args:
    exifs: a list of dicts containing EXIF data as loaded from JSON files.

  Returns:
    meta: a dict of the relevant metadata for running RawNeRF.
  """
  meta = {}
  exif = exifs[0]
  # Convert from array of dicts (exifs) to dict of arrays (meta).
  # Keys present in the first EXIF dict determine which keys are collected
  # (and how they are parsed) for every image.
  for key in _EXIF_KEYS:
    exif_value = exif.get(key)
    if exif_value is None:
      continue
    # Values can be a single int or float...
    if isinstance(exif_value, int) or isinstance(exif_value, float):
      vals = [x[key] for x in exifs]
    # Or a string of numbers with ' ' between.
    elif isinstance(exif_value, str):
      vals = [[float(z) for z in x[key].split(' ')] for x in exifs]
    meta[key] = np.squeeze(np.array(vals))
  # Shutter speed is a special case, a string written like 1/N.
  # NOTE(review): this assumes every 'ShutterSpeed' string is literally of the
  # form '1/N'; a whole-second exposure like '2' would raise IndexError here —
  # confirm inputs are always fractional.
  meta['ShutterSpeed'] = np.fromiter(
      (1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float)

  # Create raw-to-sRGB color transform matrices. Pipeline is:
  # cam space -> white balanced cam space ("camwb") -> XYZ space -> RGB space.
  # 'AsShotNeutral' is an RGB triplet representing how pure white would measure
  # on the sensor, so dividing by these numbers corrects the white balance.
  whitebalance = meta['AsShotNeutral'].reshape(-1, 3)
  cam2camwb = np.array([np.diag(1. / x) for x in whitebalance])
  # ColorMatrix2 converts from XYZ color space to "reference illuminant" (white
  # balanced) camera space.
  xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3)
  rgb2camwb = xyz2camwb @ _RGB2XYZ
  # We normalize the rows of the full color correction matrix, as is done in
  # https://github.com/AbdoKamel/simple-camera-pipeline.
  rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True)
  # Combining color correction with white balance gives the entire transform.
  cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb
  meta['cam2rgb'] = cam2rgb

  return meta
271 |
272 |
def load_raw_dataset(split: utils.DataSplit,
                     data_dir: str,
                     image_names: Sequence[str],
                     exposure_percentile: float,
                     n_downsample: int,
                     ) -> Tuple[np.ndarray, MutableMapping[str, Any], bool]:
  """Loads and processes a set of RawNeRF input images.

  Includes logic necessary for special "test" scenes that include a noiseless
  ground truth frame, produced by HDR+ merge.

  Args:
    split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic.
    data_dir: base directory for scene data.
    image_names: which images were successfully posed by COLMAP.
    exposure_percentile: what brightness percentile to expose to white.
    n_downsample: returned images are downsampled by a factor of n_downsample.

  Returns:
    A tuple (images, meta, testscene).
    images: [N, height // n_downsample, width // n_downsample, 3] array of
      demosaicked raw image data.
    meta: EXIF metadata and other useful processing parameters. Includes per
      image exposure information that can be passed into the NeRF model with
      each ray: the set of unique exposure times is determined and each image
      assigned a corresponding exposure index (mapping to an exposure value).
      These are keys 'unique_shutters', 'exposure_idx', and 'exposure_value' in
      the `meta` dictionary.
      We rescale so the maximum `exposure_value` is 1 for convenience.
    testscene: True when dataset includes ground truth test image, else False.
  """

  image_dir = os.path.join(data_dir, 'raw')

  # A scene counts as a "test scene" iff the HDR+ merged frame is on disk.
  testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng')
  testscene = utils.file_exists(testimg_file)
  if testscene:
    # Test scenes have train/ and test/ split subdirectories inside raw/.
    image_dir = os.path.join(image_dir, split.value)
    if split == utils.DataSplit.TEST:
      # COLMAP image names not valid for test split of test scene.
      image_names = None
    else:
      # Discard the first COLMAP image name as it is a copy of the test image.
      image_names = image_names[1:]

  raws, exifs = load_raw_images(image_dir, image_names)
  meta = process_exif(exifs)

  if testscene and split == utils.DataSplit.TEST:
    # Test split for test scene must load the "ground truth" HDR+ merged image.
    with utils.open_file(testimg_file, 'rb') as imgin:
      testraw = rawpy.imread(imgin).raw_image
    # HDR+ output has 2 extra bits of fixed precision, need to divide by 4.
    testraw = testraw.astype(np.float32) / 4.
    # Need to rescale long exposure test image by fast:slow shutter speed ratio.
    fast_shutter = meta['ShutterSpeed'][0]
    slow_shutter = meta['ShutterSpeed'][-1]
    shutter_ratio = fast_shutter / slow_shutter
    # Replace loaded raws with the "ground truth" test image.
    raws = testraw[None]
    # Test image shares metadata with the first loaded image (fast exposure).
    meta = {k: meta[k][:1] for k in meta}
  else:
    shutter_ratio = 1.

  # Next we determine an index for each unique shutter speed in the data.
  shutter_speeds = meta['ShutterSpeed']
  # Sort the shutter speeds from slowest (largest) to fastest (smallest).
  # This way index 0 will always correspond to the brightest image.
  unique_shutters = np.sort(np.unique(shutter_speeds))[::-1]
  exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32)
  for i, shutter in enumerate(unique_shutters):
    # Assign index `i` to all images with shutter speed `shutter`.
    exposure_idx[shutter_speeds == shutter] = i
  meta['exposure_idx'] = exposure_idx
  meta['unique_shutters'] = unique_shutters
  # Rescale to use relative shutter speeds, where 1. is the brightest.
  # This way the NeRF output with exposure=1 will always be reasonable.
  meta['exposure_values'] = shutter_speeds / unique_shutters[0]

  # Rescale raw sensor measurements to [0, 1] (plus noise).
  blacklevel = meta['BlackLevel'].reshape(-1, 1, 1)
  whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1)
  images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio

  # Calculate value for exposure level when gamma mapping, defaults to 97%.
  # Always based on full resolution image 0 (for consistency).
  image0_raw_demosaic = np.array(bilinear_demosaic_jax(images[0]))
  image0_rgb = image0_raw_demosaic @ meta['cam2rgb'][0].T
  exposure = np.percentile(image0_rgb, exposure_percentile)
  meta['exposure'] = exposure
  # Sweep over various exposure percentiles to visualize in training logs.
  exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST}
  meta['exposure_levels'] = exposure_levels

  # Create postprocessing function mapping raw images to tonemapped sRGB space.
  # `exposure` is bound via a default argument so later rebinding of the local
  # cannot change the stored function's behavior.
  cam2rgb0 = meta['cam2rgb'][0]
  meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x)

  # Demosaic Bayer images (preserves the measured RGGB values) and downsample
  # if needed. Moving array to device + running processing function in Jax +
  # copying back to CPU is faster than running directly on CPU.
  def processing_fn(x):
    x_jax = jnp.array(x)
    x_demosaic_jax = bilinear_demosaic_jax(x_jax)
    if n_downsample > 1:
      x_demosaic_jax = lib_image.downsample(x_demosaic_jax, n_downsample)
    return np.array(x_demosaic_jax)
  images = np.stack([processing_fn(im) for im in images], axis=0)

  return images, meta, testscene
385 |
386 |
387 | def best_fit_affine(x: _Array, y: _Array, axis: _Axis) -> _Array:
388 | """Computes best fit a, b such that a * x + b = y, in a least square sense."""
389 | x_m = x.mean(axis=axis)
390 | y_m = y.mean(axis=axis)
391 | xy_m = (x * y).mean(axis=axis)
392 | xx_m = (x * x).mean(axis=axis)
393 | # slope a = Cov(x, y) / Cov(x, x).
394 | a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m)
395 | b = y_m - a * x_m
396 | return a, b
397 |
398 |
def match_images_affine(est: _Array, gt: _Array,
                        axis: _Axis = (0, 1)) -> _Array:
  """Computes affine best fit of gt->est, then maps est back to match gt."""
  # Fit in the gt -> est direction (robust since `est` may be very noisy),
  # then apply the inverse so metrics live in a consistent (gt) space.
  slope, intercept = best_fit_affine(gt, est, axis=axis)
  return (est - intercept) / slope
407 |
--------------------------------------------------------------------------------
/internal/ref_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for reflection directions and directional encodings."""
16 |
17 | from internal import math
18 | import jax.numpy as jnp
19 | import numpy as np
20 |
21 |
def reflect(viewdirs, normals):
  """Reflect view directions about normals.

  The reflection of a vector v about a unit vector n is a vector u such that
  dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two
  equations is u = 2 dot(n, v) n - v.

  Args:
    viewdirs: [..., 3] array of view directions.
    normals: [..., 3] array of normal directions (assumed to be unit vectors).

  Returns:
    [..., 3] array of reflection directions.
  """
  cos_term = jnp.sum(normals * viewdirs, axis=-1, keepdims=True)
  return 2.0 * cos_term * normals - viewdirs
38 |
39 |
def l2_normalize(x, eps=jnp.finfo(jnp.float32).eps):
  """Normalize x to unit length along last axis."""
  # Clamp the squared norm at eps so near-zero vectors don't divide by zero.
  sq_norm = jnp.sum(x**2, axis=-1, keepdims=True)
  return x / jnp.sqrt(jnp.maximum(sq_norm, eps))
43 |
44 |
def compute_weighted_mae(weights, normals, normals_gt):
  """Compute weighted mean angular error, assuming normals are unit length."""
  one_eps = 1 - jnp.finfo(jnp.float32).eps
  # Clamp the cosine strictly inside [-1, 1] so arccos stays finite/stable.
  cos_angle = jnp.clip((normals * normals_gt).sum(-1), -one_eps, one_eps)
  mean_angle = (weights * jnp.arccos(cos_angle)).sum() / weights.sum()
  return mean_angle * 180.0 / jnp.pi
51 |
52 |
def generalized_binomial_coeff(a, k):
  """Compute generalized binomial coefficients.

  Uses the falling-factorial form a * (a-1) * ... * (a-k+1) / k!, which is
  valid for non-integer `a`.

  Args:
    a: float, the (possibly non-integral) upper argument.
    k: int, number of factors in the falling factorial.

  Returns:
    A float, the generalized binomial coefficient.
  """
  # BUGFIX: `np.math` was a private alias of the stdlib `math` module and is
  # removed in NumPy >= 2.0. Import locally because a top-level `import math`
  # would collide with this module's `from internal import math`.
  from math import factorial
  return np.prod(a - np.arange(k)) / factorial(k)
56 |
57 |
def assoc_legendre_coeff(l, m, k):
  """Compute associated Legendre polynomial coefficients.

  Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the
  (l, m)th associated Legendre polynomial, P_l^m(cos(theta)).

  Args:
    l: associated Legendre polynomial degree.
    m: associated Legendre polynomial order.
    k: power of cos(theta).

  Returns:
    A float, the coefficient of the term corresponding to the inputs.
  """
  # BUGFIX: `np.math` is removed in NumPy >= 2.0; local import avoids a name
  # clash with this module's `from internal import math`.
  from math import factorial
  return ((-1)**m * 2**l * factorial(l) / factorial(k) /
          factorial(l - k - m) *
          generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l))
75 |
76 |
def sph_harm_coeff(l, m, k):
  """Compute spherical harmonic coefficients.

  Args:
    l: spherical harmonic degree.
    m: spherical harmonic order.
    k: power of cos(theta) (forwarded to assoc_legendre_coeff).

  Returns:
    A float, the normalized spherical harmonic coefficient.
  """
  # BUGFIX: `np.math` is removed in NumPy >= 2.0 (local import avoids clashing
  # with this module's `from internal import math`).
  from math import factorial
  return (np.sqrt(
      (2.0 * l + 1.0) * factorial(l - m) /
      (4.0 * np.pi * factorial(l + m))) * assoc_legendre_coeff(l, m, k))
82 |
83 |
def get_ml_array(deg_view):
  """Create a list with all pairs of (l, m) values to use in the encoding."""
  # Degrees used are powers of two: l = 1, 2, 4, ..., 2**(deg_view - 1).
  # Only nonnegative m are kept; real/imaginary parts are split downstream.
  pairs = [(m, 2**i) for i in range(deg_view) for m in range(2**i + 1)]
  # Transpose so row 0 holds the m values and row 1 the l values.
  return np.array(pairs).T
96 |
97 |
def generate_ide_fn(deg_view):
  """Generate integrated directional encoding (IDE) function.

  This function returns a function that computes the integrated directional
  encoding from Equations 6-8 of arxiv.org/abs/2112.03907.

  Args:
    deg_view: number of spherical harmonics degrees to use.

  Returns:
    A function for evaluating integrated directional encoding.

  Raises:
    ValueError: if deg_view is larger than 5.
  """
  if deg_view > 5:
    raise ValueError('Only deg_view of at most 5 is numerically stable.')

  ml_array = get_ml_array(deg_view)
  l_max = 2**(deg_view - 1)

  # Create a matrix corresponding to ml_array holding all coefficients, which,
  # when multiplied (from the right) by the z coordinate Vandermonde matrix,
  # results in the z component of the encoding.
  # Computed once here with numpy so the returned closure only does jnp work.
  mat = np.zeros((l_max + 1, ml_array.shape[1]))
  for i, (m, l) in enumerate(ml_array.T):
    for k in range(l - m + 1):
      mat[k, i] = sph_harm_coeff(l, m, k)

  def integrated_dir_enc_fn(xyz, kappa_inv):
    """Function returning integrated directional encoding (IDE).

    Args:
      xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at.
      kappa_inv: [..., 1] reciprocal of the concentration parameter of the von
        Mises-Fisher distribution.

    Returns:
      An array with the resulting IDE.
    """
    x = xyz[..., 0:1]
    y = xyz[..., 1:2]
    z = xyz[..., 2:3]

    # Compute z Vandermonde matrix.
    vmz = jnp.concatenate([z**i for i in range(mat.shape[0])], axis=-1)

    # Compute x+iy Vandermonde matrix (complex-valued).
    vmxy = jnp.concatenate([(x + 1j * y)**m for m in ml_array[0, :]], axis=-1)

    # Get spherical harmonics.
    sph_harms = vmxy * math.matmul(vmz, mat)

    # Apply attenuation function using the von Mises-Fisher distribution
    # concentration parameter, kappa. sigma = l(l+1)/2 per harmonic degree.
    sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1)
    ide = sph_harms * jnp.exp(-sigma * kappa_inv)

    # Split into real and imaginary parts and return
    return jnp.concatenate([jnp.real(ide), jnp.imag(ide)], axis=-1)

  return integrated_dir_enc_fn
160 |
161 |
def generate_dir_enc_fn(deg_view):
  """Generate directional encoding (DE) function.

  Args:
    deg_view: number of spherical harmonics degrees to use.

  Returns:
    A function for evaluating directional encoding.
  """
  # A plain DE is an IDE evaluated with zero kappa_inv, i.e. with no
  # roughness-based attenuation of the harmonics.
  ide_fn = generate_ide_fn(deg_view)

  def dir_enc_fn(xyz):
    """Function returning directional encoding (DE)."""
    return ide_fn(xyz, jnp.zeros_like(xyz[..., :1]))

  return dir_enc_fn
178 |
--------------------------------------------------------------------------------
/internal/render.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper functions for shooting and rendering rays."""
16 |
17 | from internal import stepfun
18 | import jax.numpy as jnp
19 |
20 |
def lift_gaussian(d, t_mean, t_var, r_var, diag):
  """Lift a Gaussian defined along a ray to 3D coordinates."""
  # The 3D mean is the ray direction scaled by the mean distance along it.
  mean = d[..., None, :] * t_mean[..., None]

  # Squared direction magnitude, clamped away from zero for stability.
  d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True))

  if diag:
    # Diagonal covariance: along-ray term plus perpendicular (null-space) term.
    d_outer_diag = d**2
    null_outer_diag = 1 - d_outer_diag / d_mag_sq
    t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
    xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
    return mean, t_cov_diag + xy_cov_diag

  # Full covariance: outer products along the ray and in its null space.
  d_outer = d[..., :, None] * d[..., None, :]
  eye = jnp.eye(d.shape[-1])
  null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
  t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
  xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
  return mean, t_cov + xy_cov
42 |
43 |
def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
  """Approximate a conical frustum as a Gaussian distribution (mean+cov).

  Assumes the ray is originating from the origin, and base_radius is the
  radius at dist=1. Doesn't assume `d` is normalized.

  Args:
    d: jnp.float32 3-vector, the axis of the cone
    t0: float, the starting distance of the frustum.
    t1: float, the ending distance of the frustum.
    base_radius: float, the scale of the radius as a function of distance.
    diag: boolean, whether or the Gaussian will be diagonal or full-covariance.
    stable: boolean, whether or not to use the stable computation described in
      the paper (setting this to False will cause catastrophic failure).

  Returns:
    a Gaussian (mean and covariance).
  """
  if stable:
    # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).
    mu = (t0 + t1) / 2  # The average of the two `t` values.
    hw = (t1 - t0) / 2  # The half-width of the two `t` values.
    eps = jnp.finfo(jnp.float32).eps
    # Shared denominator 3*mu^2 + hw^2, clamped at eps (computed once).
    denom = jnp.maximum(eps, 3 * mu**2 + hw**2)
    t_mean = mu + (2 * mu * hw**2) / denom
    t_var = (hw**2) / 3 - (4 / 15) * hw**4 * (12 * mu**2 - hw**2) / denom**2
    r_var = (mu**2) / 4 + (5 / 12) * hw**2 - (4 / 15) * (hw**4) / denom
  else:
    # Equations 37-39 in the paper.
    t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
    r_var = 3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3)
    t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
    t_var = t_mosq - t_mean**2
  r_var *= base_radius**2
  return lift_gaussian(d, t_mean, t_var, r_var, diag)
79 |
80 |
def cylinder_to_gaussian(d, t0, t1, radius, diag):
  """Approximate a cylinder as a Gaussian distribution (mean+cov).

  Assumes the ray is originating from the origin, and radius is the
  radius. Does not renormalize `d`.

  Args:
    d: jnp.float32 3-vector, the axis of the cylinder
    t0: float, the starting distance of the cylinder.
    t1: float, the ending distance of the cylinder.
    radius: float, the radius of the cylinder
    diag: boolean, whether or the Gaussian will be diagonal or full-covariance.

  Returns:
    a Gaussian (mean and covariance).
  """
  # Mean sits at the cylinder's midpoint. The along-axis variance is that of
  # a uniform distribution over [t0, t1]; the radial variance is that of a
  # uniform disk of the given radius.
  t_mean = 0.5 * (t0 + t1)
  r_var = 0.25 * radius**2
  t_var = (t1 - t0)**2 / 12
  return lift_gaussian(d, t_mean, t_var, r_var, diag)
101 |
102 |
def cast_rays(tdist, origins, directions, radii, ray_shape, diag=True):
  """Cast rays (cone- or cylinder-shaped) and featurize sections of it.

  Args:
    tdist: float array, the "fencepost" distances along the ray.
    origins: float array, the ray origin coordinates.
    directions: float array, the ray direction vectors.
    radii: float array, the radii (base radii for cones) of the rays.
    ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.
    diag: boolean, whether or not the covariance matrices should be diagonal.

  Returns:
    a tuple of arrays of means and covariances.
  """
  # Adjacent fenceposts bound each ray section.
  t0, t1 = tdist[..., :-1], tdist[..., 1:]
  if ray_shape == 'cone':
    gaussian_fn = conical_frustum_to_gaussian
  elif ray_shape == 'cylinder':
    gaussian_fn = cylinder_to_gaussian
  else:
    raise ValueError('ray_shape must be \'cone\' or \'cylinder\'')
  means, covs = gaussian_fn(directions, t0, t1, radii, diag)
  # Shift the section means from ray-local space into world space.
  return means + origins[..., None, :], covs
128 |
129 |
def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
  """Helper function for computing alpha compositing weights."""
  # Convert t-intervals into metric distances, scaled by the direction norm.
  interval = tdist[..., 1:] - tdist[..., :-1]
  delta = interval * jnp.linalg.norm(dirs[..., None, :], axis=-1)
  density_delta = density * delta

  if opaque_background:
    # Equivalent to making the final t-interval infinitely wide.
    inf_tail = jnp.full_like(density_delta[..., -1:], jnp.inf)
    density_delta = jnp.concatenate(
        [density_delta[..., :-1], inf_tail], axis=-1)

  alpha = 1 - jnp.exp(-density_delta)
  # Transmittance: probability of reaching each sample without absorption.
  accumulated = jnp.cumsum(density_delta[..., :-1], axis=-1)
  leading_zero = jnp.zeros_like(density_delta[..., :1])
  trans = jnp.exp(-jnp.concatenate([leading_zero, accumulated], axis=-1))
  weights = alpha * trans
  return weights, alpha, trans
152 |
153 |
def volumetric_rendering(rgbs,
                         weights,
                         tdist,
                         bg_rgbs,
                         t_far,
                         compute_extras,
                         extras=None):
  """Volumetric Rendering Function.

  Args:
    rgbs: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
    weights: jnp.ndarray(float32), weights, [batch_size, num_samples].
    tdist: jnp.ndarray(float32), [batch_size, num_samples].
    bg_rgbs: jnp.ndarray(float32), the color(s) to use for the background.
    t_far: jnp.ndarray(float32), [batch_size, 1], the distance of the far plane.
    compute_extras: bool, if True, compute extra quantities besides color.
    extras: dict, a set of values along rays to render by alpha compositing.

  Returns:
    rendering: a dict containing an rgb image of size [batch_size, 3], and other
      visualizations if compute_extras=True.
  """
  eps = jnp.finfo(jnp.float32).eps
  rendering = {}

  acc = weights.sum(axis=-1)
  # Any weight not used by the samples goes to the background color.
  bg_w = jnp.maximum(0, 1 - acc[..., None])
  rendering['rgb'] = (weights[..., None] * rgbs).sum(axis=-2) + bg_w * bg_rgbs

  if not compute_extras:
    return rendering

  rendering['acc'] = acc

  if extras is not None:
    for key, value in extras.items():
      if value is not None:
        rendering[key] = (weights[..., None] * value).sum(axis=-2)

  def expectation(x):
    return (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)

  t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
  # For numerical stability this expectation is computed using log-distance.
  rendering['distance_mean'] = jnp.clip(
      jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf),
      tdist[..., 0], tdist[..., -1])

  # Add an extra fencepost with the far distance at the end of each ray, with
  # whatever weight is needed to make the new weight vector sum to exactly 1
  # (`weights` is only guaranteed to sum to <= 1, not == 1).
  t_aug = jnp.concatenate([tdist, t_far], axis=-1)
  weights_aug = jnp.concatenate([weights, bg_w], axis=-1)

  ps = [5, 50, 95]
  distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)

  for i, p in enumerate(ps):
    suffix = 'median' if p == 50 else 'percentile_' + str(p)
    rendering['distance_' + suffix] = distance_percentiles[..., i]

  return rendering
214 |
def volumetric_rendering_nerfw(
    rgbs_static,
    rgbs_transient,
    trans,
    alpha_static,
    alpha_transient,
    tdist,
    bg_rgbs,
    t_far,
    compute_extras,
    extras=None):
  """Volumetric rendering with separate static and transient components.

  The static and transient fields share a single transmittance `trans` but
  have separate alphas; their colors are composited jointly and the combined
  weights are used for all auxiliary (distance/extras) renderings.

  Args:
    rgbs_static: f32[..., num_samples, 3], per-sample static color.
    rgbs_transient: f32[..., num_samples, 3], per-sample transient color.
    trans: f32[..., num_samples], shared transmittance along each ray.
    alpha_static: f32[..., num_samples], per-sample static alpha.
    alpha_transient: f32[..., num_samples], per-sample transient alpha.
    tdist: f32[..., num_samples + 1], "fencepost" distances along each ray.
    bg_rgbs: the color(s) to use for the background.
    t_far: f32[..., 1], the distance of the far plane.
    compute_extras: bool, if True, compute extra quantities besides color.
    extras: dict, a set of values along rays to render by alpha compositing.

  Returns:
    rendering: a dict containing an rgb image, plus accumulated alpha and
      distance visualizations when compute_extras=True.
  """
  eps = jnp.finfo(jnp.float32).eps
  rendering = {}
  weights_static = trans * alpha_static
  weights_transient = trans * alpha_transient
  weights = weights_static + weights_transient

  # Reuse the combined `weights` instead of re-summing the two components.
  acc = weights.sum(axis=-1)
  bg_w = jnp.maximum(0, 1 - acc[..., None])  # The weight of the background.
  rgb = (
      weights_static[..., None] * rgbs_static +
      weights_transient[..., None] * rgbs_transient
  ).sum(axis=-2) + bg_w * bg_rgbs
  rendering['rgb'] = rgb

  if compute_extras:
    rendering['acc'] = acc

    if extras is not None:
      for k, v in extras.items():
        if v is not None:
          rendering[k] = (weights[..., None] * v).sum(axis=-2)

    expectation = lambda x: (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)
    t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
    # For numerical stability this expectation is computed using log-distance.
    rendering['distance_mean'] = (
        jnp.clip(
            jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf),
            tdist[..., 0], tdist[..., -1]))

    # Add an extra fencepost with the far distance at the end of each ray, with
    # whatever weight is needed to make the new weight vector sum to exactly 1
    # (`weights` is only guaranteed to sum to <= 1, not == 1).
    t_aug = jnp.concatenate([tdist, t_far], axis=-1)
    weights_aug = jnp.concatenate([weights, bg_w], axis=-1)

    ps = [5, 50, 95]
    distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)

    for i, p in enumerate(ps):
      s = 'median' if p == 50 else 'percentile_' + str(p)
      rendering['distance_' + s] = distance_percentiles[..., i]

  return rendering
270 |
--------------------------------------------------------------------------------
/internal/robustnerf.py:
--------------------------------------------------------------------------------
1 | """Computes RobustNeRF mask."""
2 | from typing import Mapping, Tuple
3 |
4 | from jax import lax
5 | import jax.numpy as jnp
6 |
7 |
def robustnerf_mask(
    errors: jnp.ndarray, loss_threshold: float, config: {str: float}
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
  """Computes RobustNeRF mask.

  Args:
    errors: f32[n,h,w,c]. Per-subpixel errors in a batch of patches.
    loss_threshold: f32[]. Upper bound on per-pixel loss to use to determine
      if a pixel is an inlier or not.
    config: Config object. A dictionary of hyperparameters.
      NOTE(review): despite the `{str: float}` annotation, `config` is read
      via attribute access below (e.g. `config.robustnerf_inlier_quantile`),
      so it must be a config object, not a plain dict.

  Returns:
    mask: f32[n,h,w,c or 1]. Binary mask that broadcasts to shape [n,h,w,c].
    stats: { str: f32[] }. Statistics to pass on.
  """
  epsilon = 1e-3
  error_dtype = errors.dtype
  # Collapse channels so every pixel contributes a single error value.
  error_per_pixel = jnp.mean(errors, axis=-1, keepdims=True)  # f32[n,h,w,1]
  # Threshold for the *next* step: the chosen quantile of this batch's
  # per-pixel errors (the caller feeds it back in as `loss_threshold`).
  next_loss_threshold = jnp.quantile(
      error_per_pixel, config.robustnerf_inlier_quantile
  )
  # Default mask: everything is treated as an inlier. Used verbatim when the
  # RobustNeRF loss is disabled.
  mask = jnp.ones_like(error_per_pixel, dtype=error_dtype)
  stats = {
      'loss_threshold': next_loss_threshold,
  }
  if config.enable_robustnerf_loss:
    assert (
        config.robustnerf_inner_patch_size <= config.patch_size
    ), 'patch_size must be larger than robustnerf_inner_patch_size.'

    # Inlier pixels have a value of 1.0 in the mask.
    is_inlier_pixel = (error_per_pixel < loss_threshold).astype(error_dtype)
    stats['is_inlier_loss'] = jnp.mean(is_inlier_pixel)

    # Apply fxf (3x3) box filter 'window' for smoothing (diffusion).
    # lax.conv expects NCHW layout, hence the transposes in and back out.
    f = config.robustnerf_smoothed_filter_size
    window = jnp.ones((1, 1, f, f)) / (f * f)
    has_inlier_neighbors = lax.conv(
        jnp.transpose(is_inlier_pixel, [0, 3, 1, 2]), window, (1, 1), 'SAME'
    )
    has_inlier_neighbors = jnp.transpose(has_inlier_neighbors, [0, 2, 3, 1])

    # Binarize after smoothing.
    # config.robustnerf_smoothed_inlier_quantile default is 0.5 which means at
    # least 50% of neighbouring pixels are inliers.
    has_inlier_neighbors = (
        has_inlier_neighbors > 1 - config.robustnerf_smoothed_inlier_quantile
    ).astype(error_dtype)
    stats['has_inlier_neighbors'] = jnp.mean(has_inlier_neighbors)
    # A pixel stays an inlier if either it or its smoothed neighborhood
    # qualified (logical OR via sum-and-threshold).
    is_inlier_pixel = (
        has_inlier_neighbors + is_inlier_pixel > epsilon
    ).astype(error_dtype)
    # Construct binary mask for inner pixels. The entire inner patch is either
    # active or inactive.
    # patch_size is the input patch (h,w), inner patch size can be any value
    # smaller than patch_size. Default is for the inner patch size to be half
    # the input patch size (i.e. 16x16 -> 8x8).
    # Both sizes are divided by `config.stride` so the mask matches the
    # patch resolution after striding.
    inner_patch_mask = _robustnerf_inner_patch_mask(
        config.robustnerf_inner_patch_size // config.stride, config.patch_size // config.stride
    )
    is_inlier_patch = jnp.mean(
        is_inlier_pixel, axis=[1, 2], keepdims=True
    )  # f32[n,1,1,1]
    # robustnerf_inner_patch_inlier_quantile what percentage of the patch
    # should be inliers so that the patch is counted as an inlier patch.
    is_inlier_patch = (
        is_inlier_patch > 1 - config.robustnerf_inner_patch_inlier_quantile
    ).astype(error_dtype)
    is_inlier_patch = is_inlier_patch * inner_patch_mask
    stats['is_inlier_patch'] = jnp.mean(is_inlier_patch)

    # A pixel is an inlier if it is an inlier according to any of the above
    # criteria.
    mask = (
        is_inlier_patch + is_inlier_pixel > epsilon
    ).astype(error_dtype)

  stats['mask'] = jnp.mean(mask)
  return mask, stats
87 |
88 |
89 | def _robustnerf_inner_patch_mask(
90 | inner_patch_size, outer_patch_size, *, dtype=jnp.float32
91 | ):
92 | """Constructs binary mask for inner patch.
93 |
94 | Args:
95 | inner_patch_size: Size of the (square) inside patch.
96 | outer_patch_size: Size of the (square) outer patch.
97 | dtype: dtype for result
98 |
99 | Returns:
100 | Binary mask of shape (1, outer_patch_size, outer_patch_size, 1). Mask is
101 | 1.0 for the center (inner_patch_size, inner_patch_size) square and 0.0
102 | elsewhere.
103 | """
104 | pad_size_lower = (outer_patch_size - inner_patch_size) // 2
105 | pad_size_upper = outer_patch_size - (inner_patch_size + pad_size_lower)
106 | mask = jnp.pad(
107 | jnp.ones((1, inner_patch_size, inner_patch_size, 1), dtype=dtype),
108 | (
109 | (0, 0), # batch
110 | (pad_size_lower, pad_size_upper), # height
111 | (pad_size_lower, pad_size_upper), # width
112 | (0, 0), # channels
113 | ),
114 | )
115 | return mask
116 |
117 |
118 |
--------------------------------------------------------------------------------
/internal/stepfun.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools for manipulating step functions (piecewise-constant 1D functions).
16 |
17 | We have a shared naming and dimension convention for these functions.
18 | All input/output step functions are assumed to be aligned along the last axis.
19 | `t` always indicates the x coordinates of the *endpoints* of a step function.
20 | `y` indicates unconstrained values for the *bins* of a step function
21 | `w` indicates bin weights that sum to <= 1. `p` indicates non-negative bin
22 | values that *integrate* to <= 1.
23 | """
24 |
25 | from internal import math
26 | import jax
27 | import jax.numpy as jnp
28 |
29 |
def searchsorted(a, v):
  """Find indices where v should be inserted into a to maintain order.

  This behaves like jnp.searchsorted (its second output is the same as
  jnp.searchsorted's output if all elements of v are in [a[0], a[-1]]) but is
  faster because it wastes memory to save some compute.

  Args:
    a: tensor, the sorted reference points that we are scanning to see where v
      should lie.
    v: tensor, the query points that we are pretending to insert into a. Does
      not need to be sorted. All but the last dimensions should match or expand
      to those of a, the last dimension can differ.

  Returns:
    (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the
    range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or
    last index of a.
  """
  idx = jnp.arange(a.shape[-1])
  # Pairwise comparison of every query against every reference point.
  v_ge_a = v[..., None, :] >= a[..., :, None]
  # For the lower index, keep the reference index wherever the query is >= it
  # (fall back to the first index), then take the max over references.
  lo_candidates = jnp.where(v_ge_a, idx[..., :, None], idx[..., :1, None])
  # For the upper index, keep the reference index wherever the query is < it
  # (fall back to the last index), then take the min over references.
  hi_candidates = jnp.where(v_ge_a, idx[..., -1:, None], idx[..., :, None])
  idx_lo = jnp.max(lo_candidates, axis=-2)
  idx_hi = jnp.min(hi_candidates, axis=-2)
  return idx_lo, idx_hi
54 |
55 |
def query(tq, t, y, outside_value=0):
  """Look up the values of the step function (t, y) at locations tq."""
  idx_lo, idx_hi = searchsorted(t, tq)
  # Queries outside [t[0], t[-1]] have idx_lo == idx_hi; those get the
  # caller-provided fallback value.
  inside = idx_lo != idx_hi
  y_at_lo = jnp.take_along_axis(y, idx_lo, axis=-1)
  return jnp.where(inside, y_at_lo, outside_value)
62 |
63 |
def inner_outer(t0, t1, y1):
  """Construct inner and outer measures on (t1, y1) for t0."""
  # Cumulative sum of y1 with a leading zero, so that differences of entries
  # give the total mass of any contiguous run of t1-bins.
  cy1 = jnp.concatenate([jnp.zeros_like(y1[..., :1]),
                         jnp.cumsum(y1, axis=-1)],
                        axis=-1)
  # Bracket each t0 endpoint by its neighboring t1 endpoints.
  idx_lo, idx_hi = searchsorted(t1, t0)

  cy1_lo = jnp.take_along_axis(cy1, idx_lo, axis=-1)
  cy1_hi = jnp.take_along_axis(cy1, idx_hi, axis=-1)

  # Outer measure: all y1 mass that could possibly overlap each t0 interval.
  y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
  # Inner measure: only the y1 mass fully contained in each t0 interval;
  # zero when no t1 bin lies entirely inside it.
  y0_inner = jnp.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
                       cy1_lo[..., 1:] - cy1_hi[..., :-1], 0)
  return y0_inner, y0_outer
78 |
79 |
def lossfun_outer(t, w, t_env, w_env, eps=jnp.finfo(jnp.float32).eps):
  """The proposal weight should be an upper envelope on the nerf weight."""
  _, w_outer = inner_outer(t, t_env, w_env)
  # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's
  # more effective to pull w_outer up than it is to push w_inner down.
  # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.
  excess = jnp.maximum(0, w - w_outer)
  return excess * excess / (w + eps)
87 |
88 |
def weight_to_pdf(t, w, eps=jnp.finfo(jnp.float32).eps**2):
  """Turn a vector of weights that sums to 1 into a PDF that integrates to 1."""
  # Divide each bin's weight by its (eps-clamped) width.
  bin_widths = t[..., 1:] - t[..., :-1]
  return w / jnp.maximum(eps, bin_widths)
92 |
93 |
def pdf_to_weight(t, p):
  """Turn a PDF that integrates to 1 into a vector of weights that sums to 1."""
  # Multiply each bin's density by its width.
  bin_widths = t[..., 1:] - t[..., :-1]
  return p * bin_widths
97 |
98 |
def max_dilate(t, w, dilation, domain=(-jnp.inf, jnp.inf)):
  """Dilate (via max-pooling) a non-negative step function."""
  # Expand every bin by `dilation` on both sides.
  lo = t[..., :-1] - dilation
  hi = t[..., 1:] + dilation
  # New endpoints: the originals plus all dilated boundaries, sorted and
  # clipped to the valid domain.
  t_dilate = jnp.clip(
      jnp.sort(jnp.concatenate([t, lo, hi], axis=-1), axis=-1), *domain)
  # Each dilated endpoint takes the max weight over every dilated interval
  # [lo, hi) that covers it; endpoints covered by nothing get zero.
  covered = ((lo[..., None, :] <= t_dilate[..., None])
             & (hi[..., None, :] > t_dilate[..., None]))
  w_dilate = jnp.max(
      jnp.where(covered, w[..., None, :], 0), axis=-1)[..., :-1]
  return t_dilate, w_dilate
114 |
115 |
def max_dilate_weights(t,
                       w,
                       dilation,
                       domain=(-jnp.inf, jnp.inf),
                       renormalize=False,
                       eps=jnp.finfo(jnp.float32).eps**2):
  """Dilate (via max-pooling) a set of weights."""
  # Dilation is defined on densities, so convert weights -> PDF, dilate,
  # then convert back to weights on the dilated endpoints.
  t_dilate, p_dilate = max_dilate(t, weight_to_pdf(t, w), dilation,
                                  domain=domain)
  w_dilate = pdf_to_weight(t_dilate, p_dilate)
  if renormalize:
    total = jnp.sum(w_dilate, axis=-1, keepdims=True)
    w_dilate = w_dilate / jnp.maximum(eps, total)
  return t_dilate, w_dilate
129 |
130 |
def integrate_weights(w):
  """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

  The output's size on the last dimension is one greater than that of the input,
  because we're computing the integral corresponding to the endpoints of a step
  function, not the integral of the interior/bin values.

  Args:
    w: Tensor, which will be integrated along the last axis. This is assumed to
      sum to 1 along the last axis, and this function will (silently) break if
      that is not the case.

  Returns:
    cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
  """
  # Interior CDF values, clamped at 1 to absorb floating-point drift.
  cw = jnp.minimum(1, jnp.cumsum(w[..., :-1], axis=-1))
  pad_shape = cw.shape[:-1] + (1,)
  # Pin the CDF endpoints to exactly 0 and 1.
  return jnp.concatenate(
      [jnp.zeros(pad_shape), cw, jnp.ones(pad_shape)], axis=-1)
151 |
152 |
def invert_cdf(u, t, w_logits, use_gpu_resampling=False):
  """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
  # Softmax the logits into weights, then integrate them into a CDF over `t`.
  cw = integrate_weights(jax.nn.softmax(w_logits, axis=-1))
  # `math.interp` gathers (fast on GPU); `math.sorted_interp` searches
  # (fast on TPU).
  if use_gpu_resampling:
    interp_fn = math.interp
  else:
    interp_fn = math.sorted_interp
  return interp_fn(u, cw, t)
162 |
163 |
def sample(rng,
           t,
           w_logits,
           num_samples,
           single_jitter=False,
           deterministic_center=False,
           use_gpu_resampling=False):
  """Piecewise-Constant PDF sampling from a step function.

  Args:
    rng: random number generator (or None for `linspace` sampling).
    t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
    w_logits: [..., num_bins], logits corresponding to bin weights
    num_samples: int, the number of samples.
    single_jitter: bool, if True, jitter every sample along each ray by the same
      amount in the inverse CDF. Otherwise, jitter each sample independently.
    deterministic_center: bool, if False, when `rng` is None return samples that
      linspace the entire PDF. If True, skip the front and back of the linspace
      so that the centers of each PDF interval are returned.
    use_gpu_resampling: bool, If True this resamples the rays based on a
      "gather" instruction, which is fast on GPUs but slow on TPUs. If False,
      this resamples the rays based on brute-force searches, which is fast on
      TPUs, but slow on GPUs.

  Returns:
    t_samples: jnp.ndarray(float32), [batch_size, num_samples].
  """
  eps = jnp.finfo(jnp.float32).eps

  if rng is None:
    # Deterministic sampling. Match the behavior of jax.random.uniform() by
    # spanning [0, 1-eps].
    if deterministic_center:
      half_bin = 1 / (2 * num_samples)
      u = jnp.linspace(half_bin, 1. - half_bin - eps, num_samples)
    else:
      u = jnp.linspace(0, 1. - eps, num_samples)
    u = jnp.broadcast_to(u, t.shape[:-1] + (num_samples,))
  else:
    # `u` is in [0, 1) --- it can be zero, but it can never be 1.
    u_max = eps + (1 - eps) / num_samples
    max_jitter = (1 - u_max) / (num_samples - 1) - eps
    jitter_dims = 1 if single_jitter else num_samples
    base = jnp.linspace(0, 1 - u_max, num_samples)
    jitter = jax.random.uniform(
        rng, t.shape[:-1] + (jitter_dims,), maxval=max_jitter)
    u = base + jitter

  return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling)
212 |
213 |
def sample_intervals(rng,
                     t,
                     w_logits,
                     num_samples,
                     single_jitter=False,
                     domain=(-jnp.inf, jnp.inf),
                     use_gpu_resampling=False):
  """Sample *intervals* (rather than points) from a step function.

  Args:
    rng: random number generator (or None for `linspace` sampling).
    t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
    w_logits: [..., num_bins], logits corresponding to bin weights
    num_samples: int, the number of intervals to sample.
    single_jitter: bool, if True, jitter every sample along each ray by the same
      amount in the inverse CDF. Otherwise, jitter each sample independently.
    domain: (minval, maxval), the range of valid values for `t`.
    use_gpu_resampling: bool, If True this resamples the rays based on a
      "gather" instruction, which is fast on GPUs but slow on TPUs. If False,
      this resamples the rays based on brute-force searches, which is fast on
      TPUs, but slow on GPUs.

  Returns:
    t_samples: jnp.ndarray(float32), [batch_size, num_samples].
  """
  if num_samples <= 1:
    raise ValueError(f'num_samples must be > 1, is {num_samples}.')

  # Sample interval centers from the step function.
  centers = sample(
      rng,
      t,
      w_logits,
      num_samples,
      single_jitter,
      deterministic_center=True,
      use_gpu_resampling=use_gpu_resampling)

  # Interior fenceposts sit at the midpoints between adjacent centers.
  mid = (centers[..., 1:] + centers[..., :-1]) / 2

  # The first/last fencepost reflects the nearest midpoint around the
  # first/last center, clamped to the caller-provided domain.
  minval, maxval = domain
  first = jnp.maximum(minval, 2 * centers[..., :1] - mid[..., :1])
  last = jnp.minimum(maxval, 2 * centers[..., -1:] - mid[..., -1:])

  return jnp.concatenate([first, mid, last], axis=-1)
264 |
265 |
def lossfun_distortion(t, w):
  """Compute iint w[i] w[j] |t[i] - t[j]| di dj."""
  # Interval midpoints and their pairwise distances.
  midpoints = (t[..., 1:] + t[..., :-1]) / 2
  pairwise = jnp.abs(midpoints[..., :, None] - midpoints[..., None, :])
  # Interaction between all pairs of distinct intervals.
  loss_inter = jnp.sum(
      w * jnp.sum(w[..., None, :] * pairwise, axis=-1), axis=-1)

  # Self-interaction of each interval with itself.
  widths = t[..., 1:] - t[..., :-1]
  loss_intra = jnp.sum(w**2 * widths, axis=-1) / 3

  return loss_inter + loss_intra
277 |
278 |
def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
  """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi])."""
  # When the intervals don't overlap, the mean distance is just the distance
  # between their centers.
  d_disjoint = jnp.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)

  # Closed-form solution for the overlapping case.
  cube_term = jnp.minimum(t0_hi, t1_hi)**3 - jnp.maximum(t0_lo, t1_lo)**3
  cross_term = (t1_hi * t0_hi * jnp.abs(t1_hi - t0_hi) +
                t1_lo * t0_lo * jnp.abs(t1_lo - t0_lo) +
                t1_hi * t0_lo * (t0_lo - t1_hi) +
                t1_lo * t0_hi * (t1_lo - t0_hi))
  d_overlap = (2 * cube_term + 3 * cross_term) / (
      6 * (t0_hi - t0_lo) * (t1_hi - t1_lo))

  # Are the two intervals not overlapping?
  are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)

  return jnp.where(are_disjoint, d_disjoint, d_overlap)
296 |
297 |
def weighted_percentile(t, w, ps):
  """Compute the weighted percentiles of a step function. w's must sum to 1."""
  cw = integrate_weights(w)
  # Interpolate the requested quantiles into the CDF, one row at a time.
  quantiles = jnp.array(ps) / 100
  interp_row = lambda cw_i, t_i: jnp.interp(quantiles, cw_i, t_i)
  # Collapse leading dims so vmap can handle arbitrary batch shapes.
  cw_mat = cw.reshape([-1, cw.shape[-1]])
  t_mat = t.reshape([-1, t.shape[-1]])
  wprctile_mat = jax.vmap(interp_row, 0)(cw_mat, t_mat)
  return wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))
309 |
310 |
def resample(t, tp, vp, use_avg=False, eps=jnp.finfo(jnp.float32).eps):
  """Resample a step function defined by (tp, vp) into intervals t.

  Notation roughly matches jnp.interp. Resamples by summation by default.

  Args:
    t: tensor with shape (..., n+1), the endpoints to resample into.
    tp: tensor with shape (..., m+1), the endpoints of the step function being
      resampled.
    vp: tensor with shape (..., m), the values of the step function being
      resampled.
    use_avg: bool, if False, return the sum of the step function for each
      interval in `t`. If True, return the average, weighted by the width of
      each interval in `t`.
    eps: float, a small value to prevent division by zero when use_avg=True.

  Returns:
    v: tensor with shape (..., n), the values of the resampled step function.
  """
  if use_avg:
    # Average = (resampled sum of vp * width) / (resampled sum of widths).
    wp = jnp.diff(tp, axis=-1)
    v_numer = resample(t, tp, vp * wp, use_avg=False)
    v_denom = resample(t, tp, wp, use_avg=False)
    return v_numer / jnp.maximum(eps, v_denom)

  # Integrate vp (with a leading zero), resample the integral at the new
  # endpoints, then differentiate to recover per-interval sums.
  acc0 = jnp.concatenate(
      [jnp.zeros(vp.shape[:-1] + (1,)), jnp.cumsum(vp, axis=-1)], axis=-1)
  acc0_resampled = jnp.vectorize(
      jnp.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0)
  return jnp.diff(acc0_resampled, axis=-1)
343 |
--------------------------------------------------------------------------------
/internal/train_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Training step and model creation functions."""
16 |
17 | import collections
18 | import functools
19 | from typing import Any, Callable, Dict, MutableMapping, Optional, Text, Tuple
20 |
21 | from flax.core.scope import FrozenVariableDict
22 | from flax.training.train_state import TrainState
23 | from internal import camera_utils
24 | from internal import configs
25 | from internal import datasets
26 | from internal import image
27 | from internal import math
28 | from internal import models
29 | from internal import ref_utils
30 | from internal import robustnerf
31 | from internal import stepfun
32 | from internal import utils
33 | import jax
34 | from jax import random
35 | import jax.numpy as jnp
36 | import optax
37 |
38 | from flax import traverse_util
39 | import jax.scipy as jsp
40 | from functools import partial
41 |
def flattened_traversal(fn):
  """Returns a mask fn that applies `fn(path, leaf)` over a flattened pytree."""
  def mask(tree):
    flat = traverse_util.flatten_dict(tree)
    mapped = {path: fn(path, leaf) for path, leaf in flat.items()}
    return traverse_util.unflatten_dict(mapped)

  return mask
48 |
def tree_sum(tree):
  """Sum of all leaves of a pytree (0 for an empty tree)."""
  add = lambda acc, leaf: acc + leaf
  return jax.tree_util.tree_reduce(add, tree, initializer=0)
51 |
52 |
def tree_norm_sq(tree):
  """Sum of squared values over all leaves of a pytree."""
  squared_sums = jax.tree_util.tree_map(lambda x: jnp.sum(x**2), tree)
  return tree_sum(squared_sums)
55 |
56 |
def tree_norm(tree):
  """Euclidean (L2) norm over all leaves of a pytree."""
  norm_sq = tree_norm_sq(tree)
  return jnp.sqrt(norm_sq)
59 |
60 |
def tree_abs_max(tree):
  """Maximum absolute value over all leaves of a pytree (0 for empty)."""
  def running_max(acc, leaf):
    return jnp.maximum(acc, jnp.max(jnp.abs(leaf)))

  return jax.tree_util.tree_reduce(running_max, tree, initializer=0)
64 |
65 |
def tree_len(tree):
  """Total number of scalar elements across all leaves of a pytree."""
  leaf_sizes = jax.tree_util.tree_map(
      lambda z: jnp.prod(jnp.array(z.shape)), tree)
  return tree_sum(leaf_sizes)
69 |
70 |
def summarize_tree(tree, fn, ancestry=(), max_depth=3):
  """Flatten 'tree' while 'fn'-ing values and formatting keys like/this."""
  stats = {}
  for key, value in tree.items():
    path = ancestry + (key,)
    stats['/'.join(path)] = fn(value)
    # Recurse into nested mappings until the depth budget is exhausted.
    if hasattr(value, 'items') and len(path) < max_depth:
      stats.update(
          summarize_tree(value, fn, ancestry=path, max_depth=max_depth))
  return stats
80 |
def dino_var_loss(renderings, config):
  """Mean 'uncer_var' across rendering levels, scaled by config.dino_var_mult."""
  per_level_means = [jnp.mean(r['uncer_var']) for r in renderings]
  return config.dino_var_mult * jnp.mean(jnp.array(per_level_means))
86 |
87 |
88 | from jax.lax import conv_general_dilated
89 |
def create_window(window_size, channels):
  """Create a window for SSIM computation."""
  # Uniform box filter normalized so all taps sum to 1.
  num_taps = window_size**2 * channels
  window = jnp.ones((window_size, window_size, channels)) / num_taps
  # (spatial_dim_1, spatial_dim_2, in_channels, out_channels) for convolution.
  return window.reshape(window_size, window_size, channels, 1)
95 |
def convolve(img, window):
  """Apply a 2D convolution (stride 1, SAME padding) over an NHWC image."""
  # Input is NHWC, kernel is HWIO, output stays NHWC.
  dn = ('NHWC', 'HWIO', 'NHWC')
  return conv_general_dilated(img, window, (1, 1), 'SAME',
                              dimension_numbers=dn)
101 |
def compute_ssim(img1, img2, window_size=5):
  """Per-pixel SSIM components (luminance, contrast, structure) of two images.

  Assumes 3-channel NHWC inputs; returns the three SSIM component maps with
  contrast and structure capped at 0.98.
  """
  C1 = 0.01 ** 2
  C2 = 0.03 ** 2
  C3 = C2 / 2

  window = create_window(window_size, 3)  # window sized for 3 channels

  # Local means.
  mu1 = convolve(img1, window)
  mu2 = convolve(img2, window)

  mu1_sq = mu1 * mu1
  mu2_sq = mu2 * mu2
  mu1_mu2 = mu1 * mu2

  # Local (co)variances via E[x^2] - E[x]^2.
  sigma1_sq = convolve(img1 * img1, window) - mu1_sq
  sigma2_sq = convolve(img2 * img2, window) - mu2_sq
  sigma12 = convolve(img1 * img2, window) - mu1_mu2

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative; the covariance magnitude is bounded by
  # Cauchy-Schwarz.
  epsilon = jnp.finfo(jnp.float32).eps**2
  sigma1_sq = jnp.maximum(epsilon, sigma1_sq)
  sigma2_sq = jnp.maximum(epsilon, sigma2_sq)
  sigma12 = jnp.sign(sigma12) * jnp.minimum(
      jnp.sqrt(sigma1_sq * sigma2_sq), jnp.abs(sigma12))

  sigma1 = jnp.sqrt(sigma1_sq)
  sigma2 = jnp.sqrt(sigma2_sq)
  l = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
  c = (2 * sigma1 * sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
  s = (sigma12 + C3) / (sigma1 * sigma2 + C3)

  # Cap the contrast and structure terms below 1.
  c = jnp.clip(c, a_max=0.98)
  s = jnp.clip(s, a_max=0.98)

  return l, c, s
136 |
def compute_data_loss(batch, renderings, rays, loss_threshold, config, train_frac):
  """Computes data loss terms for RGB, normal, and depth outputs.

  Args:
    batch: ground-truth data for this step; batch.rgb is used here.
    renderings: list of per-level rendering dicts (coarse to fine); each must
      contain 'rgb' and, for the 'on-the-go' loss, 'uncer'.
    rays: the ray batch; rays.lossmult weights each ray's contribution.
    loss_threshold: float, inlier threshold used by the 'robustnerf' loss.
    config: configs.Config with the data-loss settings.
    train_frac: float in [0, 1], fraction of training completed; anneals the
      SSIM-based uncertainty loss in 'on-the-go' mode.

  Returns:
    A (loss, stats) tuple: scalar total data loss and a dict of per-level
    statistics (each value converted to a jnp array).

  Raises:
    ValueError: if config.data_loss_type is not one of the supported modes.
  """
  data_losses = []
  stats = collections.defaultdict(lambda: [])

  # lossmult can be used to apply a weight to each ray in the batch.
  # For example: masking out rays, applying the Bayer mosaic mask, upweighting
  # rays from lower resolution images and so on.
  lossmult = rays.lossmult
  lossmult = jnp.broadcast_to(lossmult, batch.rgb[..., :3].shape)
  if config.disable_multiscale_loss:
    lossmult = jnp.ones_like(lossmult)

  for rendering in renderings:
    resid_sq = (rendering['rgb'] - batch.rgb[..., :3])**2
    denom = lossmult.sum()
    stats['mses'].append((lossmult * resid_sq).sum() / denom)

    if config.data_loss_type == 'mse':
      # Mean-squared error (L2) loss.
      data_loss = resid_sq

    elif config.data_loss_type == 'on-the-go':
      uncer = rendering['uncer']
      uncer = jnp.clip(uncer, a_min=config.uncer_clip_min)+1e-3
      # Optionally detach the rendered RGB so SSIM only trains uncertainty.
      # Fix: use the configured window size in both branches; previously the
      # no-stop-gradient branch silently fell back to the default of 5.
      rgb_for_ssim = (
          jax.lax.stop_gradient(rendering['rgb'])
          if config.stop_ssim_gradient else rendering['rgb'])
      l, c, s = compute_ssim(rgb_for_ssim, batch.rgb[..., :3],
                             config.ssim_window_size)
      train_frac = jnp.broadcast_to(train_frac, uncer.shape)

      # Calculate the SSIM loss rate, which starts at 100 and can scale up to
      # 1000. This is not mentioned in the paper since its effect is marginal.
      bias = lambda x, s: x / (1 + (1 - x)*(1 / s - 2))
      rate = 100 + bias(train_frac, config.ssim_anneal) * 900
      my_ssim_loss = jnp.clip(rate * (1-l)*(1-s)*(1-c), a_max=config.ssim_clip_max)
      # Uncertainty-weighted SSIM term with a log-uncertainty regularizer.
      ssim_loss = my_ssim_loss / uncer**2 + config.reg_mult * jnp.log(uncer)

      # Adjust uncertainty to slowly increase based on the SSIM training
      # fraction (detached so this schedule carries no gradient).
      uncer_rate = 1 + 1 * bias(train_frac, config.ssim_anneal)
      uncer = (jax.lax.stop_gradient(uncer) - config.uncer_clip_min) * uncer_rate + config.uncer_clip_min
      data_loss = 0.5 * resid_sq / (uncer) ** 2
      data_loss += config.ssim_mult * ssim_loss

    elif config.data_loss_type == 'robustnerf':
      # RobustNeRF: mask residuals that are deemed outliers.
      mask, robust_stats = robustnerf.robustnerf_mask(resid_sq, loss_threshold,
                                                      config)
      data_loss = resid_sq * mask
      stats.update(robust_stats)
    else:
      # Fix: raise instead of `assert False`, which is stripped under -O.
      raise ValueError(f'Unknown data loss type: {config.data_loss_type}')
    data_losses.append((lossmult * data_loss).sum() / denom)

  data_losses = jnp.array(data_losses)
  # All coarse levels share one multiplier; the final (fine) level gets its own.
  loss = (
      config.data_coarse_loss_mult * jnp.sum(data_losses[:-1]) +
      config.data_loss_mult * data_losses[-1])
  stats = {k: jnp.array(stats[k]) for k in stats}
  return loss, stats
197 |
198 |
def interlevel_loss(ray_history, config):
  """Computes the interlevel loss defined in mip-NeRF 360."""
  # Stop the gradient from the interlevel loss onto the NeRF MLP: the finest
  # level is the (detached) supervision target for all proposal levels.
  final = ray_history[-1]
  target_sdist = jax.lax.stop_gradient(final['sdist'])
  target_weights = jax.lax.stop_gradient(final['weights'])
  total = 0.
  for level in ray_history[:-1]:
    total += jnp.mean(
        stepfun.lossfun_outer(target_sdist, target_weights,
                              level['sdist'], level['weights']))
  return config.interlevel_loss_mult * total
211 |
212 |
def distortion_loss(ray_history, config):
  """Computes the distortion loss regularizer defined in mip-NeRF 360."""
  # Only the final (finest) level is regularized.
  final_level = ray_history[-1]
  mean_distortion = jnp.mean(
      stepfun.lossfun_distortion(final_level['sdist'], final_level['weights']))
  return config.distortion_loss_mult * mean_distortion
220 |
221 |
def orientation_loss(rays, model, ray_history, config):
  """Computes the orientation loss regularizer defined in ref-NeRF."""
  total_loss = 0.
  # Negate viewdirs to represent normalized vectors from point to camera.
  to_camera = -1. * rays.viewdirs
  for level, ray_results in enumerate(ray_history):
    normals = ray_results[config.orientation_loss_target]
    if normals is None:
      raise ValueError('Normals cannot be None if orientation loss is on.')
    weights = ray_results['weights']
    # Penalize normals that face away from the camera (negative dot product).
    n_dot_v = (normals * to_camera[..., None, :]).sum(axis=-1)
    penalty = jnp.mean((weights * jnp.minimum(0.0, n_dot_v)**2).sum(axis=-1))
    if level < model.num_levels - 1:
      total_loss += config.orientation_coarse_loss_mult * penalty
    else:
      total_loss += config.orientation_loss_mult * penalty
  return total_loss
239 |
240 |
def predicted_normal_loss(model, ray_history, config):
  """Computes the predicted normal supervision loss defined in ref-NeRF."""
  total_loss = 0.
  for level, ray_results in enumerate(ray_history):
    weights = ray_results['weights']
    normals_grad = ray_results['normals']
    normals_pred = ray_results['normals_pred']
    if normals_grad is None or normals_pred is None:
      raise ValueError(
          'Predicted normals and gradient normals cannot be None if '
          'predicted normal loss is on.')
    # 1 - <n, n_pred>: zero when predicted normals match gradient normals.
    misalignment = 1.0 - jnp.sum(normals_grad * normals_pred, axis=-1)
    level_loss = jnp.mean((weights * misalignment).sum(axis=-1))
    if level < model.num_levels - 1:
      total_loss += config.predicted_normal_coarse_loss_mult * level_loss
    else:
      total_loss += config.predicted_normal_loss_mult * level_loss
  return total_loss
258 |
259 |
def clip_gradients(grad, config):
  """Clips gradients of each MLP individually based on norm and max value."""
  clipped = {}
  for name, g in grad['params'].items():
    # First clip elementwise by value.
    if config.grad_max_val > 0:
      max_val = config.grad_max_val
      g = jax.tree_util.tree_map(
          lambda z: jnp.clip(z, -max_val, max_val), g)

    # Then rescale so this sub-tree's global norm is at most grad_max_norm.
    if config.grad_max_norm > 0:
      scale = jnp.minimum(
          1, config.grad_max_norm / (jnp.finfo(jnp.float32).eps + tree_norm(g)))
      g = jax.tree_util.tree_map(lambda z: scale * z, g)  # pylint:disable=cell-var-from-loop

    clipped[name] = g
  return type(grad)({'params': clipped})
279 |
280 |
def create_train_step(model: models.Model,
                      config: configs.Config,
                      dataset: Optional[datasets.Dataset] = None):
  """Creates the pmap'ed Nerf training function.

  Args:
    model: The linen model.
    config: The configuration.
    dataset: Training dataset; if None, a perspective camera model is assumed.

  Returns:
    pmap'ed training function.
  """
  # Without a dataset there is no camera model to read; default to perspective.
  if dataset is None:
    camtype = camera_utils.ProjectionType.PERSPECTIVE
  else:
    camtype = dataset.camtype

  def train_step(
      rng,
      state,
      batch,
      cameras,
      train_frac,
      loss_threshold,
  ):
    """One optimization step.

    Args:
      rng: jnp.ndarray, random number generator.
      state: TrainState, state of the model/optimizer.
      batch: dict, a mini-batch of data for training.
      cameras: module containing camera poses.
      train_frac: float, the fraction of training that is complete.
      loss_threshold: float, the loss threshold for inliers (for robustness).

    Returns:
      A tuple (new_state, stats, rng) with
        new_state: TrainState, new training state.
        stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)].
        rng: jnp.ndarray, updated random number generator.
    """
    # Split off independent keys for model randomness and dropout.
    rng, key, dropout_key = random.split(rng, num=3)

    def loss_fn(variables, dropout_key):
      # Total training loss: data term plus every enabled regularizer.
      rays = batch.rays
      if config.cast_rays_in_train_step:
        rays = camera_utils.cast_ray_batch(cameras, rays, camtype, xnp=jnp)

      # Indicates whether we need to compute output normal or depth maps in 2D.

      renderings, ray_history = model.apply(
          variables,
          key if config.randomized else None,
          rays,
          train_frac=train_frac,
          compute_extras=(),
          zero_glo=False,
          rngs={'dropout': dropout_key})
      losses = {}
      data_loss, stats = compute_data_loss(batch, renderings, rays,
                                           loss_threshold, config, train_frac)
      losses['data'] = data_loss

      # Each regularizer is computed only when its multiplier is non-zero.
      if config.interlevel_loss_mult > 0:
        losses['interlevel'] = interlevel_loss(ray_history, config)

      if config.distortion_loss_mult > 0:
        losses['distortion'] = distortion_loss(ray_history, config)

      if (config.orientation_coarse_loss_mult > 0 or
          config.orientation_loss_mult > 0):
        losses['orientation'] = orientation_loss(rays, model, ray_history,
                                                 config)

      if (config.predicted_normal_coarse_loss_mult > 0 or
          config.predicted_normal_loss_mult > 0):
        losses['predicted_normals'] = predicted_normal_loss(
            model, ray_history, config)

      if config.dino_var_mult > 0:
        losses['dino_var'] = dino_var_loss(renderings, config)

      # Squared L2 norm of every parameter group, keyed by flattened name.
      stats['weight_l2s'] = summarize_tree(variables['params'], tree_norm_sq)

      # Optional per-parameter-group weight decay, keyed by flattened name.
      if config.weight_decay_mults:
        it = config.weight_decay_mults.items
        losses['weight'] = jnp.sum(
            jnp.array([m * stats['weight_l2s'][k] for k, m in it()]))
      stats['loss'] = jnp.sum(jnp.array(list(losses.values())))
      stats['losses'] = losses

      return stats['loss'], stats

    loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, stats), grad = loss_grad_fn(state.params, dropout_key)

    # Average gradients and stats across devices.
    pmean = lambda x: jax.lax.pmean(x, axis_name='batch')
    grad = pmean(grad)
    stats = pmean(stats)

    stats['grad_norms'] = summarize_tree(grad['params'], tree_norm)
    stats['grad_maxes'] = summarize_tree(grad['params'], tree_abs_max)

    grad = clip_gradients(grad, config)

    # Zero out NaN/Inf gradients so one bad step cannot poison the model.
    grad = jax.tree_util.tree_map(jnp.nan_to_num, grad)

    new_state = state.apply_gradients(grads=grad)

    # Log the size of the parameter update actually applied by the optimizer.
    opt_delta = jax.tree_util.tree_map(lambda x, y: x - y, new_state,
                                       state).params['params']
    stats['opt_update_norms'] = summarize_tree(opt_delta, tree_norm)
    stats['opt_update_maxes'] = summarize_tree(opt_delta, tree_abs_max)

    stats['psnrs'] = image.mse_to_psnr(stats['mses'])
    stats['psnr'] = stats['psnrs'][-1]
    return new_state, stats, rng

  # rng/state/batch are sharded per device; the rest is replicated.
  train_pstep = jax.pmap(
      train_step,
      axis_name='batch',
      in_axes=(0, 0, 0, None, None, None),
      donate_argnums=(0, 1))
  return train_pstep
406 |
407 |
def create_optimizer(
    config: configs.Config,
    variables: FrozenVariableDict) -> Tuple[TrainState, Callable[[int], float]]:
  """Creates optax optimizer for model training."""
  adam_kwargs = dict(
      b1=config.adam_beta1,
      b2=config.adam_beta2,
      eps=config.adam_eps)

  def get_lr_fn(lr_init, lr_final):
    # NOTE: `math` here is the project's internal math module, not stdlib.
    return functools.partial(
        math.learning_rate_decay,
        lr_init=lr_init,
        lr_final=lr_final,
        max_steps=config.max_steps,
        lr_delay_steps=config.lr_delay_steps,
        lr_delay_mult=config.lr_delay_mult)

  lr_fn_main = get_lr_fn(config.lr_init, config.lr_final)
  tx = optax.adam(learning_rate=lr_fn_main, **adam_kwargs)

  return TrainState.create(apply_fn=None, params=variables, tx=tx), lr_fn_main
436 |
437 |
def create_render_fn(model: models.Model):
  """Creates pmap'ed function for full image rendering."""

  def render_eval_fn(variables, train_frac, _, rays):
    # Deterministic forward pass (rng=None) with all extra outputs enabled,
    # gathered across devices so every host sees the full result.
    rendering = model.apply(
        variables,
        None,  # Deterministic.
        rays,
        train_frac=train_frac,
        compute_extras=True,
        is_training=False)
    return jax.lax.all_gather(rendering, axis_name='batch')

  # pmap over only the data input.
  return jax.pmap(
      render_eval_fn,
      in_axes=(None, None, None, 0),
      axis_name='batch',
  )
459 |
460 |
def setup_model(
    config: configs.Config,
    rng: jnp.array,
    dataset: Optional[datasets.Dataset] = None,
) -> Tuple[models.Model, TrainState, Callable[
    [FrozenVariableDict, jnp.array, utils.Rays],
    MutableMapping[Text, Any]], Callable[
        [jnp.array, TrainState, utils.Batch,
        Optional[Tuple[Any, ...]], float, float],
        Tuple[TrainState, Dict[Text, Any], jnp.array]], Callable[[int], float]]:
  """Creates NeRF model, optimizer, and pmap-ed train/render functions."""
  # Initialize the model with a placeholder all-zeros ray batch that carries
  # `config.feat_dim`-dimensional per-ray features.
  dummy_rays = utils.dummy_rays(
      config.feat_dim,
      include_exposure_idx=config.rawnerf_mode,
      include_exposure_values=True)
  model, variables = models.construct_model(rng, dummy_rays, config)

  state, lr_fn = create_optimizer(config, variables)
  render_eval_pfn = create_render_fn(model)
  train_pstep = create_train_step(model, config, dataset=dataset)

  return model, state, render_eval_pfn, train_pstep, lr_fn
482 |
--------------------------------------------------------------------------------
/internal/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions."""
16 |
17 | import enum
18 | import os
19 | from typing import Any, Dict, Optional, Union
20 |
21 | import flax
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | from PIL import ExifTags
26 | from PIL import Image
27 |
# Either a NumPy or a JAX array; utilities in this module accept both.
_Array = Union[np.ndarray, jnp.ndarray]
29 |
30 |
@flax.struct.dataclass
class Pixels:
  """All tensors must have the same num_dims and first n-1 dims must match."""
  pix_x_int: _Array  # Integer x pixel coordinate.
  pix_y_int: _Array  # Integer y pixel coordinate.
  lossmult: _Array  # Per-pixel weight applied to the training loss.
  near: _Array  # Near-plane distance for the ray through this pixel.
  far: _Array  # Far-plane distance for the ray through this pixel.
  cam_idx: _Array  # Index of the camera that captured this pixel.
  exposure_idx: Optional[_Array] = None  # Optional exposure index (raw mode).
  exposure_values: Optional[_Array] = None  # Optional exposure values.
  features: Optional[_Array] = None  # Optional per-pixel image features.
43 |
44 |
@flax.struct.dataclass
class Rays:
  """All tensors must have the same num_dims and first n-1 dims must match."""
  origins: _Array  # Ray origins in world space.
  directions: _Array  # Ray directions.
  viewdirs: _Array  # Viewing directions — presumably normalized; see datasets.
  radii: _Array  # Per-ray radii — presumably for cone casting; confirm.
  imageplane: _Array  # Per-ray coordinates on the image plane.
  lossmult: _Array  # Per-ray weight applied to the training loss.
  near: _Array  # Near-plane distance per ray.
  far: _Array  # Far-plane distance per ray.
  cam_idx: _Array  # Index of the camera that produced each ray.
  exposure_idx: Optional[_Array] = None  # Optional exposure index (raw mode).
  exposure_values: Optional[_Array] = None  # Optional exposure values.
  features: Optional[_Array] = None  # Optional per-ray image features.
60 |
61 |
62 |
63 |
64 | # Dummy Rays object that can be used to initialize NeRF model.
def dummy_rays(feat_dim: int,
               include_exposure_idx: bool = False,
               include_exposure_values: bool = False) -> Rays:
  """Builds an all-zeros Rays batch (batch size 1) for model initialization."""
  zeros = lambda n: jnp.zeros((1, n))
  extras = {}
  if include_exposure_idx:
    extras['exposure_idx'] = zeros(1).astype(jnp.int32)
  if include_exposure_values:
    extras['exposure_values'] = zeros(1)
  return Rays(
      origins=zeros(3),
      directions=zeros(3),
      viewdirs=zeros(3),
      radii=zeros(1),
      imageplane=zeros(2),
      lossmult=zeros(1),
      near=zeros(1),
      far=zeros(1),
      cam_idx=zeros(1).astype(jnp.int32),
      features=zeros(feat_dim),
      **extras)
86 |
87 |
@flax.struct.dataclass
class Batch:
  """Data batch for NeRF training or testing."""
  rays: Union[Pixels, Rays]  # Ray or pixel parameters for this batch.
  rgb: Optional[_Array] = None  # Ground-truth RGB colors.
  disps: Optional[_Array] = None  # Ground-truth disparity maps, if any.
  normals: Optional[_Array] = None  # Ground-truth normals, if any.
  alphas: Optional[_Array] = None  # Ground-truth alpha masks, if any.
  features: Optional[_Array] = None  # Precomputed image features, if any.
97 |
98 |
class DataSplit(enum.Enum):
  """Dataset split."""
  TRAIN = 'train'  # Images used for optimization.
  TEST = 'test'  # Held-out images used for evaluation.
103 |
104 |
class BatchingMethod(enum.Enum):
  """Draw rays randomly from a single image or all images, in each batch."""
  ALL_IMAGES = 'all_images'  # Each batch mixes rays from all training images.
  SINGLE_IMAGE = 'single_image'  # Each batch draws rays from one image only.
109 |
110 |
def open_file(pth, mode='r'):
  """Opens `pth`; single indirection point for swapping in other filesystems."""
  return open(pth, mode)
113 |
114 |
def file_exists(pth):
  """Returns True if something (file or directory) exists at `pth`."""
  return os.path.exists(pth)
117 |
118 |
def listdir(pth):
  """Returns the names of the entries in directory `pth`."""
  return os.listdir(pth)
121 |
122 |
def isdir(pth):
  """Returns True if `pth` refers to an existing directory."""
  return os.path.isdir(pth)
125 |
126 |
def makedirs(pth):
  """Creates directory `pth` (and any missing parents) if it does not exist.

  Uses `exist_ok=True` instead of the previous check-then-create pattern,
  which could raise FileExistsError if another process created the directory
  between the existence check and the mkdir call.
  """
  os.makedirs(pth, exist_ok=True)
130 |
131 |
def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  n_devices = jax.local_device_count()
  reshape_leaf = lambda x: x.reshape((n_devices, -1) + x.shape[1:])
  return jax.tree_util.tree_map(reshape_leaf, xs)
136 |
137 |
def unshard(x, padding=0):
  """Collect the sharded tensor to the shape before sharding."""
  # Merge the leading (device, per-device) axes, then drop any padding rows.
  merged = x.reshape((x.shape[0] * x.shape[1],) + tuple(x.shape[2:]))
  return merged[:-padding] if padding > 0 else merged
144 |
def load_npy(pth: str) -> np.ndarray:
  """Reads a .npy file from `pth` and returns it as float32."""
  with open_file(pth, 'rb') as f:
    return np.load(f).astype(np.float32)
150 |
def load_img(pth: str) -> np.ndarray:
  """Load an image and cast to float32."""
  with open_file(pth, 'rb') as f:
    pil_image = Image.open(f)
    # Decode while the file is still open (PIL loads lazily).
    return np.array(pil_image, dtype=np.float32)
156 |
157 |
def load_exif(pth: str) -> Dict[str, Any]:
  """Load EXIF data for an image, keyed by human-readable tag name."""
  with open_file(pth, 'rb') as f:
    raw_exif = Image.open(f)._getexif()  # pylint: disable=protected-access
  if raw_exif is None:
    return {}
  # Translate numeric tag ids to names, dropping unknown tags.
  return {
      ExifTags.TAGS[k]: v for k, v in raw_exif.items() if k in ExifTags.TAGS
  }
170 |
171 |
def save_img_u8(img, pth):
  """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG."""
  # Replace NaNs, clamp to [0, 1], and quantize to 8 bits before encoding.
  quantized = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
  with open_file(pth, 'wb') as f:
    Image.fromarray(quantized).save(f, 'PNG')
178 |
179 |
def save_img_f32(depthmap, pth):
  """Save an image (probably a depthmap) to disk as a float32 TIFF."""
  cleaned = np.nan_to_num(depthmap).astype(np.float32)
  with open_file(pth, 'wb') as f:
    Image.fromarray(cleaned).save(f, 'TIFF')
184 |
--------------------------------------------------------------------------------
/internal/vis.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper functions for visualizing things."""
16 |
17 | from internal import stepfun
18 | import jax.numpy as jnp
19 | from matplotlib import cm
20 |
21 |
def weighted_percentile(x, w, ps, assume_sorted=False):
  """Compute the weighted percentile(s) of a single vector."""
  x_flat = x.reshape([-1])
  w_flat = w.reshape([-1])
  if not assume_sorted:
    order = jnp.argsort(x_flat)
    x_flat = x_flat[order]
    w_flat = w_flat[order]
  # Interpolate the cumulative weight curve at the requested percentiles.
  acc_w = jnp.cumsum(w_flat)
  return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x_flat)
31 |
32 |
def sinebow(h):
  """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows."""
  squared_sine = lambda t: jnp.sin(jnp.pi * t)**2
  # R, G, B are phase-shifted copies of the same squared sine.
  return jnp.stack(
      [squared_sine(offset - h) for offset in (3 / 6, 5 / 6, 7 / 6)], -1)
37 |
38 |
def matte(vis, acc, dark=0.8, light=1.0, width=8):
  """Set non-accumulated pixels to a Photoshop-esque checker pattern."""
  row_parity = jnp.arange(acc.shape[0]) % (2 * width) // width
  col_parity = jnp.arange(acc.shape[1]) % (2 * width) // width
  checker = jnp.where(
      jnp.logical_xor(row_parity[:, None], col_parity[None, :]), light, dark)
  # Alpha-blend the image over the checkerboard using `acc` as alpha.
  return vis * acc[:, :, None] + (checker * (1 - acc))[:, :, None]
46 |
47 |
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
  """Visualize a 1D image and a 1D weighting according to some colormap.

  Args:
    value: A 1D image.
    weight: A weight map, in [0, 1].
    colormap: A colormap function.
    lo: The lower bound to use when rendering, if None then use a percentile.
    hi: The upper bound to use when rendering, if None then use a percentile.
    percentile: What percentile of the value map to crop to when automatically
      generating `lo` and `hi`. Depends on `weight` as well as `value'.
    curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If
      `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
    matte_background: If True, matte the image over a checkerboard.

  Returns:
    A colormap rendering.
  """
  # Identify the values that bound the middle of `value' according to `weight`.
  lo_auto, hi_auto = weighted_percentile(
      value, weight, [50 - percentile / 2, 50 + percentile / 2])

  # If `lo` or `hi` are None, use the automatically-computed bounds above.
  # Fix: test `is None` explicitly — the previous `lo or ...` treated an
  # explicit bound of 0 (e.g. visualize_suite passes lo=0) as unset.
  eps = jnp.finfo(jnp.float32).eps
  lo = (lo_auto - eps) if lo is None else lo
  hi = (hi_auto + eps) if hi is None else hi

  # Curve all values.
  value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

  # Wrap the values around if requested.
  if modulus:
    value = jnp.mod(value, modulus) / modulus
  else:
    # Otherwise, just scale to [0, 1].
    value = jnp.nan_to_num(
        jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

  if colormap:
    colorized = colormap(value)[:, :, :3]
  else:
    if len(value.shape) != 3:
      raise ValueError(f'value must have 3 dims but has {len(value.shape)}')
    if value.shape[-1] != 3:
      # Fix: report the channel count itself; `len()` of an int would raise
      # TypeError instead of producing this error message.
      raise ValueError(
          f'value must have 3 channels but has {value.shape[-1]}')
    colorized = value

  return matte(colorized, weight) if matte_background else colorized
107 |
108 |
def visualize_coord_mod(coords, acc):
  """Visualize the coordinate of each point within its "cell"."""
  # Map each coordinate into [0, 1) within a period-2 cell, then matte.
  cell_coords = ((coords + 1) % 2) / 2
  return matte(cell_coords, acc)
112 |
113 |
def visualize_rays(dist,
                   dist_range,
                   weights,
                   rgbs,
                   accumulate=False,
                   renormalize=False,
                   resolution=2048,
                   bg_color=0.8):
  """Visualize a bundle of rays.

  Args:
    dist: per-level sample distances for each ray.
    dist_range: (near, far) interval to resample the visualization onto.
    weights: per-level sample weights for each ray.
    rgbs: per-level sample colors for each ray.
    accumulate: if True, show accumulated color/weight along each ray.
    renormalize: if True, scale alphas so the maximum is 1.
    resolution: horizontal resolution of the output image.
    bg_color: gray level used for the background matte.

  Returns:
    A (vis, vis_alpha) tuple: the RGB visualization and its alpha map.
  """
  # Common distance grid that every ray's samples get resampled onto.
  dist_vis = jnp.linspace(*dist_range, resolution + 1)
  vis_rgb, vis_alpha = [], []
  for ds, ws, rs in zip(dist, weights, rgbs):
    vis_rs, vis_ws = [], []
    for d, w, r in zip(ds, ws, rs):
      if accumulate:
        # Produce the accumulated color and weight at each point along the ray.
        w_csum = jnp.cumsum(w, axis=0)
        rw_csum = jnp.cumsum((r * w[:, None]), axis=0)
        eps = jnp.finfo(jnp.float32).eps
        r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum
      # Resample this ray's colors and weights onto the shared grid.
      vis_rs.append(stepfun.resample(dist_vis, d, r.T, use_avg=True).T)
      vis_ws.append(stepfun.resample(dist_vis, d, w.T, use_avg=True).T)
    vis_rgb.append(jnp.stack(vis_rs))
    vis_alpha.append(jnp.stack(vis_ws))
  vis_rgb = jnp.stack(vis_rgb, axis=1)
  vis_alpha = jnp.stack(vis_alpha, axis=1)

  if renormalize:
    # Scale the alphas so that the largest value is 1, for visualization.
    vis_alpha /= jnp.maximum(jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha))

  if resolution > vis_rgb.shape[0]:
    # Tile each ray's row `rep` times so the image height approaches
    # `resolution` pixels, then flatten the (ray, level, rep) axes.
    rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1)
    stride = rep * vis_rgb.shape[1]

    vis_rgb = jnp.tile(vis_rgb, (1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:])
    vis_alpha = jnp.tile(vis_alpha, (1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:])

    # Add a strip of background pixels after each set of levels of rays.
    vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:])
    vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:])
    vis_rgb = jnp.concatenate([vis_rgb, jnp.zeros_like(vis_rgb[:, :1])],
                              axis=1).reshape((-1,) + vis_rgb.shape[2:])
    vis_alpha = jnp.concatenate(
        [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])],
        axis=1).reshape((-1,) + vis_alpha.shape[2:])

  # Matte the RGB image over the background.
  vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None]

  # Remove the final row of background pixels.
  vis = vis[:-1]
  vis_alpha = vis_alpha[:-1]
  return vis, vis_alpha
168 |
169 |
def visualize_suite(rendering, rays):
  """A wrapper around other visualizations for easy integration.

  Args:
    rendering: dict of rendered outputs; must contain 'rgb', 'acc', 'uncer',
      the 'distance_*' maps, and the 'ray_*' sample-level outputs.
    rays: the rays that produced `rendering`; origins/directions are used to
      compute the 3D points where rays terminate.

  Returns:
    A dict mapping visualization names to images.
  """

  # Depths are visualized in negative-log space so near depths get contrast.
  depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps)

  rgb = rendering['rgb']
  acc = rendering['acc']

  distance_mean = rendering['distance_mean']
  distance_median = rendering['distance_median']
  distance_p5 = rendering['distance_percentile_5']
  distance_p95 = rendering['distance_percentile_95']
  # Zero out accumulation wherever depth is NaN so those pixels matte out.
  acc = jnp.where(jnp.isnan(distance_mean), jnp.zeros_like(acc), acc)

  # The xyz coordinates where rays terminate.
  coords = rays.origins + rays.directions * distance_mean[:, :, None]

  vis_depth_mean, vis_depth_median = [
      visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn)
      for x in [distance_mean, distance_median]
  ]

  # Render three depth percentiles directly to RGB channels, where the spacing
  # determines the color. delta == big change, epsilon = small change.
  # Gray: A strong discontinuitiy, [x-epsilon, x, x+epsilon]
  # Purple: A thin but even density, [x-delta, x, x+delta]
  # Red: A thin density, then a thick density, [x-delta, x, x+epsilon]
  # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta]
  vis_depth_triplet = visualize_cmap(
      jnp.stack(
          [2 * distance_median - distance_p5, distance_median, distance_p95],
          axis=-1),
      acc,
      None,
      curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps))

  # Per-sample colors and weights along each ray (sdist lives in [0, 1]).
  dist = rendering['ray_sdist']
  dist_range = (0, 1)
  weights = rendering['ray_weights']
  rgbs = [jnp.clip(r, 0, 1) for r in rendering['ray_rgbs']]

  vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs)

  # Visualize sqrt-weights as grayscale "colors" with all-ones weights.
  sqrt_weights = [jnp.sqrt(w) for w in weights]
  sqrt_ray_weights, ray_alpha = visualize_rays(
      dist,
      dist_range,
      [jnp.ones_like(lw) for lw in sqrt_weights],
      [lw[..., None] for lw in sqrt_weights],
      bg_color=0,
  )
  sqrt_ray_weights = sqrt_ray_weights[..., 0]
  # Pixels with zero alpha (no ray samples) are flagged in red.
  null_color = jnp.array([1., 0., 0.])
  vis_ray_weights = jnp.where(
      ray_alpha[:, :, None] == 0,
      null_color[None, None],
      visualize_cmap(
          sqrt_ray_weights,
          jnp.ones_like(sqrt_ray_weights),
          cm.get_cmap('gray'),
          lo=0,
          hi=1,
          matte_background=False,
      ),
  )
  # Predicted uncertainty map, colormapped over a fixed value range.
  vis_uncertainty = visualize_cmap(
      rendering['uncer'][...,0],
      acc,
      cm.get_cmap('turbo'),
      lo=0.2,
      hi=2,
  )
  vis = {
      'color': rgb,
      'acc': acc,
      'color_matte': matte(rgb, acc),
      'depth_mean': vis_depth_mean,
      'depth_median': vis_depth_median,
      'depth_triplet': vis_depth_triplet,
      'coords_mod': visualize_coord_mod(coords, acc),
      'ray_colors': vis_ray_colors,
      'ray_weights': vis_ray_weights,
      'uncertainty': vis_uncertainty,
      'uncertainty_raw': rendering['uncer'][...,0],
  }

  # Color-corrected output is only present in some modes (e.g. raw data).
  if 'rgb_cc' in rendering:
    vis['color_corrected'] = rendering['rgb_cc']

  return vis
260 |
--------------------------------------------------------------------------------
/media/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvg/nerf-on-the-go/0c32fbb5fdec68d989d406618c253cc56524f64f/media/teaser.gif
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Render script."""
16 |
17 | import concurrent.futures
18 | import functools
19 | import glob
20 | import os
21 | import time
22 |
23 | from absl import app
24 | from flax.training import checkpoints
25 | import gin
26 | from internal import configs
27 | from internal import datasets
28 | from internal import models
29 | from internal import train_utils
30 | from internal import utils
31 | import jax
32 | from jax import random
33 | from matplotlib import cm
34 | import mediapy as media
35 | import numpy as np
36 |
37 | configs.define_common_flags()
38 | jax.config.parse_flags_with_absl()
39 |
40 |
def create_videos(config, base_dir, out_dir, out_name, num_frames):
  """Creates videos out of the images saved to disk.

  Args:
    config: experiment config; supplies the checkpoint path, video fps/crf,
      and the curve/percentile used to normalize distance maps.
    base_dir: directory in which to write the output videos.
    out_dir: directory containing the per-frame images written by rendering.
    out_name: tag appended to the video filename prefix.
    num_frames: number of frames expected per video.

  Raises:
    ValueError: if an expected per-frame image file is missing.
  """
  names = [n for n in config.checkpoint_dir.split('/') if n]
  # Last two parts of checkpoint path are experiment name and scene name.
  exp_name, scene_name = names[-2:]
  video_prefix = f'{scene_name}_{exp_name}_{out_name}'

  zpad = max(3, len(str(num_frames - 1)))
  idx_to_str = lambda idx: str(idx).zfill(zpad)

  utils.makedirs(base_dir)

  # Load one example frame to get image shape and depth range.
  depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff')
  depth_frame = utils.load_img(depth_file)
  shape = depth_frame.shape
  p = config.render_dist_percentile
  distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])
  lo, hi = [config.render_dist_curve_fn(x) for x in distance_limits]
  print(f'Video shape is {shape[:2]}')

  video_kwargs = {
      'shape': shape[:2],
      'codec': 'h264',
      'fps': config.render_video_fps,
      'crf': config.render_video_crf,
  }

  for k in ['color', 'acc', 'distance_mean', 'distance_median']:
    video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
    input_format = 'gray' if k == 'acc' else 'rgb'
    file_ext = 'png' if k in ['color', 'normals'] else 'tiff'
    file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}')
    if not utils.file_exists(file0):
      print(f'Images missing for tag {k}')
      continue
    print(f'Making video {video_file}...')
    with media.VideoWriter(
        video_file, **video_kwargs, input_format=input_format) as writer:
      for idx in range(num_frames):
        img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
        if not utils.file_exists(img_file):
          # Fix: the original built this ValueError without raising it, so
          # a missing frame fell through to load_img and failed confusingly.
          raise ValueError(f'Image file {img_file} does not exist.')
        img = utils.load_img(img_file)
        if k in ['color', 'normals']:
          img = img / 255.
        elif k.startswith('distance'):
          # Normalize depth via the configured curve, then colormap it.
          img = config.render_dist_curve_fn(img)
          img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)
          img = cm.get_cmap('turbo')(img)[..., :3]

        frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
        writer.add_image(frame)
96 |
97 |
def main(unused_argv):
  """Renders every frame of the test split (or render path) from a checkpoint.

  Loads the config and dataset, restores the latest checkpoint, renders each
  frame (optionally sharded across jobs via render_num_jobs / render_job_id),
  saves color / distance / acc images to disk, and assembles videos once the
  job that finishes last observes the complete frame set.
  """

  config = configs.load_config(save_config=False)

  dataset = datasets.load_dataset('test', config.data_dir, config)

  key = random.PRNGKey(20200823)
  # Only the state template and the eval-time render function are used here.
  _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key)

  if config.rawnerf_mode:
    postprocess_fn = dataset.metadata['postprocess_fn']
  else:
    # Identity postprocess when not in rawnerf mode.
    postprocess_fn = lambda z: z

  state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
  step = int(state.step)
  print(f'Rendering checkpoint at step {step}.')

  # Output directory is keyed by split kind and checkpoint step, so renders
  # from different checkpoints never collide.
  out_name = 'path_renders' if config.render_path else 'test_preds'
  out_name = f'{out_name}_step_{step}'
  base_dir = config.render_dir
  if base_dir is None:
    base_dir = os.path.join(config.checkpoint_dir, 'render')
  out_dir = os.path.join(base_dir, out_name)
  if not utils.isdir(out_dir):
    utils.makedirs(out_dir)

  path_fn = lambda x: os.path.join(out_dir, x)

  # Ensure sufficient zero-padding of image indices in output filenames.
  zpad = max(3, len(str(dataset.size - 1)))
  idx_to_str = lambda idx: str(idx).zfill(zpad)

  if config.render_save_async:
    # Offload image saving to worker threads so rendering is not blocked on
    # disk I/O; futures are collected so exceptions can be re-raised below.
    async_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
    async_futures = []
    def save_fn(fn, *args, **kwargs):
      async_futures.append(async_executor.submit(fn, *args, **kwargs))
  else:
    # Synchronous fallback with the same call signature.
    def save_fn(fn, *args, **kwargs):
      fn(*args, **kwargs)

  for idx in range(dataset.size):
    # Shard frames across jobs: this job renders every render_num_jobs-th one.
    if idx % config.render_num_jobs != config.render_job_id:
      continue
    # If current image and next image both already exist, skip ahead.
    idx_str = idx_to_str(idx)
    curr_file = path_fn(f'color_{idx_str}.png')
    next_idx_str = idx_to_str(idx + config.render_num_jobs)
    next_file = path_fn(f'color_{next_idx_str}.png')
    if utils.file_exists(curr_file) and utils.file_exists(next_file):
      print(f'Image {idx}/{dataset.size} already exists, skipping')
      continue
    print(f'Evaluating image {idx+1}/{dataset.size}')
    eval_start_time = time.time()
    rays = dataset.generate_ray_batch(idx).rays
    # NOTE(review): train_frac = 1. presumably selects fully-annealed
    # (end-of-training) model behavior for rendering — confirm in setup_model.
    train_frac = 1.
    rendering = models.render_image(
        functools.partial(render_eval_pfn, state.params, train_frac),
        rays, None, config)
    print(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

    if jax.host_id() != 0:  # Only record via host 0.
      continue

    rendering['rgb'] = postprocess_fn(rendering['rgb'])

    save_fn(
        utils.save_img_u8, rendering['rgb'], path_fn(f'color_{idx_str}.png'))
    save_fn(
        utils.save_img_f32, rendering['distance_mean'],
        path_fn(f'distance_mean_{idx_str}.tiff'))
    save_fn(
        utils.save_img_f32, rendering['distance_median'],
        path_fn(f'distance_median_{idx_str}.tiff'))
    save_fn(
        utils.save_img_f32, rendering['acc'], path_fn(f'acc_{idx_str}.tiff'))

  if config.render_save_async:
    # Wait until all worker threads finish.
    async_executor.shutdown(wait=True)

    # This will ensure that exceptions in child threads are raised to the
    # main thread.
    for future in async_futures:
      future.result()

  time.sleep(1)
  # Count saved frames; only the job that observes the complete set (after a
  # grace period for other jobs' writes to land) assembles the videos.
  num_files = len(glob.glob(path_fn('acc_*.tiff')))
  time.sleep(10)
  if jax.host_id() == 0 and num_files == dataset.size:
    print(f'All files found, creating videos (job {config.render_job_id}).')
    create_videos(config, base_dir, out_dir, out_name, dataset.size)

  # A hack that forces Jax to keep all TPUs alive until every TPU is finished.
  x = jax.numpy.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  print(x)
196 |
197 |
if __name__ == '__main__':
  # Run under the 'eval' gin scope so eval-scoped bindings apply here too.
  with gin.config_scope('eval'):  # Use the same scope as eval.py
    app.run(main)
201 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | jax
3 | jaxlib
4 | opencv-python
5 | Pillow
6 | tensorboard
7 | tensorflow
8 | gin-config
9 | dm_pix
10 | rawpy
11 | mediapy
12 | lpips_jax
13 | chex
14 | optax
15 | ml-dtypes
16 | flax
17 | gdown
18 | torch
19 | torchvision
20 | torchaudio
21 | orbax-checkpoint==0.3.5
22 | matplotlib==3.8.4
--------------------------------------------------------------------------------
/scripts/download_on-the-go.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Download the on-the-go dataset and create downsampled image copies for
# every sequence.

# Abort on the first error or undefined variable so a failed download /
# unzip does not silently produce a half-built dataset.
set -euo pipefail

# Bug fix: use -p so re-running the script does not fail if the directory
# already exists.
mkdir -p Datasets
wget https://cvg-data.inf.ethz.ch/on-the-go.zip
unzip on-the-go.zip -d Datasets
rm on-the-go.zip
# Base directory containing the sequence directories
base_dir="./Datasets/on-the-go"

# Loop through each sequence directory in the base directory
for seq_dir in "$base_dir"/*; do
    # Extract just the name of the directory (the sequence name)
    seq_name=$(basename "$seq_dir")
    echo "Processing sequence: $seq_name"

    # Determine the downsampling rate based on the sequence name
    if [ "$seq_name" = "arcdetriomphe" ] || [ "$seq_name" = "patio" ]; then
        rate=4
    else
        rate=8
    fi

    # Calculate percentage for resizing based on the downsample rate
    percentage=$(bc <<< "scale=2; 100 / $rate")

    # Directory names for images, defined relative to the base_dir
    original_images_dir="$seq_dir/images"
    downsampled_images_dir="$seq_dir/images_$rate"

    # Copy images to new directory before downsampling, handling both JPG and jpg
    cp -r "$original_images_dir" "$downsampled_images_dir"

    # Downsample images in place with ImageMagick, 8 files in parallel.
    pushd "$downsampled_images_dir"
    ls | xargs -P 8 -I {} mogrify -resize ${percentage}% {}
    popd

done
39 |
--------------------------------------------------------------------------------
/scripts/eval_on-the-go.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -n 4
4 | #SBATCH --time=15:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=4090:4
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=yard_high
10 | #SBATCH --output=slurm/yard_high.out
11 | #SBATCH --error=slurm/yard_high.err
12 |
13 | python -m eval \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \
17 | --gin_bindings="Config.eval_train = False" \
18 | --gin_bindings="Config.factor = 8" \
19 |
20 |
--------------------------------------------------------------------------------
/scripts/eval_on-the-go_HD.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# SLURM batch script: evaluate a trained patio checkpoint at HD (1920x1080)
# resolution. Submit with `sbatch`.

#SBATCH -n 4
#SBATCH --time=15:00:00
#SBATCH --mem-per-cpu=20g
#SBATCH --tmp=4000 # per node!!
#SBATCH --gpus=4090:4
#SBATCH --gres=gpumem:20g
#SBATCH --job-name=yard_high
#SBATCH --output=slurm/yard_high.out
#SBATCH --error=slurm/yard_high.err

# Evaluate the test split at 4x downsampling with HD image size overrides.
# Bug fix: "Config.factor = 4" was previously passed twice; keep one binding.
python -m eval \
--gin_configs=configs/360_dino.gin \
--gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio'" \
--gin_bindings="Config.checkpoint_dir = 'output/patio/run_1/checkpoints'" \
--gin_bindings="Config.eval_train = False" \
--gin_bindings="Config.factor = 4" \
--gin_bindings="Config.H = 1080" \
--gin_bindings="Config.W = 1920" \
--gin_bindings="Config.feat_rate = 2" \
23 |
24 |
25 |
--------------------------------------------------------------------------------
/scripts/feature_extract.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import torchvision.transforms as T
4 | import os
5 | # import hubconf
6 | from tqdm import tqdm
7 | import shutil
8 | import numpy as np
9 |
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq', type=str, required=True)
    parser.add_argument('--rate', type=int, default=4)
    parser.add_argument('--H', type=int, default=3024)
    parser.add_argument('--W', type=int, default=4032)

    args = parser.parse_args()
    base_path = f"./Datasets/on-the-go/{args.seq}"

    # Load the DINOv2 ViT-S/14 feature extractor from torch hub.
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
    dinov2_vits14.to(device)
    extractor = dinov2_vits14

    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    RATE = args.rate
    # DINOv2 uses 14x14 patches, so round the target size down to a
    # multiple of 14.
    RESIZE_H = (args.H // RATE) // 14 * 14
    RESIZE_W = (args.W // RATE) // 14 * 14

    # Start from a clean feature directory for this downsample rate.
    if os.path.exists(os.path.join(base_path, f'features_{RATE}')):
        shutil.rmtree(os.path.join(base_path, f'features_{RATE}'))
    folder = os.path.join(base_path, 'images')
    files = [os.path.join(folder, f) for f in os.listdir(folder)]

    # The transform is identical for every image; build it once, not per file.
    transform = T.Compose([
        T.Resize((RESIZE_H, RESIZE_W)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    for f in tqdm(files):
        img = Image.open(f).convert('RGB')
        img = transform(img)[:3].unsqueeze(0)
        with torch.no_grad():
            # Bug fix: move the batch to the selected device instead of the
            # hard-coded .cuda(), so the CPU fallback chosen above also works.
            features_dict = extractor.forward_features(img.to(device))
        features = features_dict['x_norm_patchtokens'].view(
            RESIZE_H // 14, RESIZE_W // 14, -1)
        # Bug fix: swap only the file extension. The old code replaced the
        # first occurrence of the last 4 characters anywhere in the path,
        # which corrupts paths containing that substring elsewhere.
        root, _ = os.path.splitext(f)
        save_path = (root + '.npy').replace('/images/', f'/features_{RATE}/')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        np.save(save_path, features.detach().cpu().numpy())
53 |
--------------------------------------------------------------------------------
/scripts/feature_extract.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run DINOv2 feature extraction (scripts/feature_extract.py) on every
# sequence in the on-the-go dataset.

# Base directory containing the sequence directories
base_dir="./Datasets/on-the-go"

# Define the sequences that need special parameters
# (these are run with --H 1080 --W 1920 --rate 2 below).
special_seqs=("arcdetriomphe" "patio")

# Loop through each sequence directory in the base directory
for seq_dir in "$base_dir"/*; do
    # Extract just the name of the directory (the sequence name)
    seq_name=$(basename "$seq_dir")

    # Check if the sequence is one of the special cases
    # (space-padded substring match against the flattened array).
    if [[ " ${special_seqs[@]} " =~ " $seq_name " ]]; then
        # Run feature extraction with additional parameters for special sequences
        python scripts/feature_extract.py --seq "$seq_name" --H 1080 --W 1920 --rate 2
    else
        # Run feature extraction without additional parameters for all other sequences
        python scripts/feature_extract.py --seq "$seq_name"
    fi
done
23 |
--------------------------------------------------------------------------------
/scripts/local_colmap_and_resize.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2022 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | # Set to 0 if you do not have a GPU.
18 | USE_GPU=1
19 | # Path to a directory `base/` with images in `base/images/`.
20 | DATASET_PATH=$1
21 | # Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye.
22 | CAMERA=${2:-OPENCV}
23 |
24 |
25 | # Run COLMAP.
26 |
27 | ### Feature extraction
28 |
29 | colmap feature_extractor \
30 | --database_path "$DATASET_PATH"/database.db \
31 | --image_path "$DATASET_PATH"/images \
32 | --ImageReader.single_camera 1 \
33 | --ImageReader.camera_model "$CAMERA" \
34 | --SiftExtraction.use_gpu "$USE_GPU"
35 |
36 |
37 | ### Feature matching
38 |
39 | colmap exhaustive_matcher \
40 | --database_path "$DATASET_PATH"/database.db \
41 | --SiftMatching.use_gpu "$USE_GPU"
42 |
43 | ## Use if your scene has > 500 images
44 | ## Replace this path with your own local copy of the file.
45 | ## Download from: https://demuc.de/colmap/#download
46 | # VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin
47 | # colmap vocab_tree_matcher \
48 | # --database_path "$DATASET_PATH"/database.db \
49 | # --VocabTreeMatching.vocab_tree_path $VOCABTREE_PATH \
50 | # --SiftMatching.use_gpu "$USE_GPU"
51 |
52 |
53 | ### Bundle adjustment
54 |
55 | # The default Mapper tolerance is unnecessarily large,
56 | # decreasing it speeds up bundle adjustment steps.
57 | mkdir -p "$DATASET_PATH"/sparse
58 | colmap mapper \
59 | --database_path "$DATASET_PATH"/database.db \
60 | --image_path "$DATASET_PATH"/images \
61 | --output_path "$DATASET_PATH"/sparse \
62 | --Mapper.ba_global_function_tolerance=0.000001
63 |
64 |
65 | ### Image undistortion
66 |
67 | ## Use this if you want to undistort your images into ideal pinhole intrinsics.
68 | # mkdir -p "$DATASET_PATH"/dense
69 | # colmap image_undistorter \
70 | # --image_path "$DATASET_PATH"/images \
71 | # --input_path "$DATASET_PATH"/sparse/0 \
72 | # --output_path "$DATASET_PATH"/dense \
73 | # --output_type COLMAP
74 |
75 | # Resize images.
76 |
77 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_2
78 |
79 | pushd "$DATASET_PATH"/images_2
80 | ls | xargs -P 8 -I {} mogrify -resize 50% {}
81 | popd
82 |
83 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_4
84 |
85 | pushd "$DATASET_PATH"/images_4
86 | ls | xargs -P 8 -I {} mogrify -resize 25% {}
87 | popd
88 |
89 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_8
90 |
91 | pushd "$DATASET_PATH"/images_8
92 | ls | xargs -P 8 -I {} mogrify -resize 12.5% {}
93 | popd
94 |
--------------------------------------------------------------------------------
/scripts/render_on-the-go.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -n 4
4 | #SBATCH --time=15:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=4090:4
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=patio_high
10 | #SBATCH --output=slurm/patio_high.out
11 | #SBATCH --error=slurm/patio_high.err
12 |
13 | python -m render \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \
17 | --gin_bindings="Config.render_dir = 'output/patio_high/run_1/checkpoints'" \
18 | --gin_bindings="Config.render_path = True" \
19 | --gin_bindings="Config.render_path_frames = 160" \
20 | --gin_bindings="Config.render_video_fps = 160" \
21 |
--------------------------------------------------------------------------------
/scripts/render_on-the-go_HD.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -n 4
4 | #SBATCH --time=15:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=4090:4
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=patio_high
10 | #SBATCH --output=slurm/patio_high.out
11 | #SBATCH --error=slurm/patio_high.err
12 |
13 | python -m render \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \
17 | --gin_bindings="Config.render_dir = 'output/patio_high/run_1/checkpoints'" \
18 | --gin_bindings="Config.render_path = True" \
19 | --gin_bindings="Config.render_path_frames = 160" \
20 | --gin_bindings="Config.render_video_fps = 160" \
21 | --gin_bindings="Config.H = 1080" \
22 | --gin_bindings="Config.W = 1920" \
23 | --gin_bindings="Config.factor = 4" \
24 | --gin_bindings="Config.feat_rate = 2" \
25 |
--------------------------------------------------------------------------------
/scripts/run_all_unit_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Run every unit-test module under tests/.

# Stop at the first failing suite so failures are not silently ignored.
set -e

python -m unittest tests.camera_utils_test
python -m unittest tests.stepfun_test
python -m unittest tests.coord_test
python -m unittest tests.math_test
# Bug fix: tests/utils_test.py exists in the repo but was never run.
python -m unittest tests.utils_test
21 |
--------------------------------------------------------------------------------
/scripts/train_on-the-go.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -n 4
4 | #SBATCH --time=36:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=a100_80gb:1
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=patio_high
10 | #SBATCH --output=slurm/patio_high.out
11 | #SBATCH --error=slurm/patio_high.err
12 |
13 | python -m train \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \
17 | --gin_bindings="Config.patch_size = 32" \
18 | --gin_bindings="Config.dilate = 4" \
19 | --gin_bindings="Config.data_loss_type = 'on-the-go'" \
--------------------------------------------------------------------------------
/scripts/train_on-the-go_HD.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -n 4
4 | #SBATCH --time=36:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=a100_80gb:1
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=patio
10 | #SBATCH --output=slurm/patio.out
11 | #SBATCH --error=slurm/patio.err
12 |
13 | python -m train \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio/run_1/checkpoints'" \
17 | --gin_bindings="Config.patch_size = 32" \
18 | --gin_bindings="Config.dilate = 4" \
19 | --gin_bindings="Config.data_loss_type = 'on-the-go'" \
20 | --gin_bindings="Config.train_render_every = 5000" \
21 | --gin_bindings="Config.H = 1080" \
22 | --gin_bindings="Config.W = 1920" \
23 | --gin_bindings="Config.factor = 4" \
24 | --gin_bindings="Config.feat_rate = 2" \
25 |
--------------------------------------------------------------------------------
/tests/camera_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for camera_utils."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from internal import camera_utils
20 | from jax import random
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 |
class CameraUtilsTest(parameterized.TestCase):

  def test_convert_to_ndc(self):
    """World points projected to NDC must lie on the converted NDC rays."""
    rng = random.PRNGKey(0)
    for _ in range(10):
      # Draw random pinhole intrinsics and invert to get the pix-to-cam map.
      key, rng = random.split(rng)
      focal, width, height = random.uniform(key, (3,), minval=100., maxval=200.)
      camtopix = camera_utils.intrinsic_matrix(focal, focal, width / 2.,
                                               height / 2.)
      pixtocam = np.linalg.inv(camtopix)
      near = 1.

      # Random rays pointing roughly down the -z axis.
      num_rays = 1000
      key, rng = random.split(rng)
      origins = jnp.array([0., 0., 1.]) + random.uniform(
          key, (num_rays, 3), minval=-1., maxval=1.)
      directions = jnp.array([0., 0., -1.]) + random.uniform(
          key, (num_rays, 3), minval=-.5, maxval=.5)

      # Sample world-space points along each ray and project them into NDC.
      t = jnp.linspace(0., 1., 10)
      pts_world = origins + t[:, None, None] * directions
      x_ndc = -focal / (.5 * width) * pts_world[..., 0] / pts_world[..., 2]
      y_ndc = -focal / (.5 * height) * pts_world[..., 1] / pts_world[..., 2]
      z_ndc = 1. + 2. * near / pts_world[..., 2]
      pts_ndc = jnp.stack([x_ndc, y_ndc, z_ndc], axis=-1)

      # Rays produced by the function under test.
      origins_ndc, directions_ndc = camera_utils.convert_to_ndc(
          origins, directions, pixtocam, near)

      # Orthogonally project each NDC point onto its ray; the projection
      # should coincide with the point itself if it lies on the ray.
      dir_norm = jnp.linalg.norm(directions_ndc, axis=-1, keepdims=True)
      dir_unit = directions_ndc / dir_norm
      coeff = ((pts_ndc - origins_ndc) * dir_unit).sum(axis=-1)
      pts_ndc_proj = origins_ndc + dir_unit * coeff[..., None]
      np.testing.assert_allclose(pts_ndc, pts_ndc_proj, atol=1e-5, rtol=1e-5)
69 |
70 |
if __name__ == '__main__':
  # Run all tests in this module via absl's test runner.
  absltest.main()
73 |
--------------------------------------------------------------------------------
/tests/coord_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for coord."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from internal import coord
20 | from internal import math
21 | import jax
22 | from jax import random
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 |
def sample_covariance(rng, batch_size, num_dims):
  """Sample a batch of random PSD covariance matrices as A @ A^T."""
  factor = jax.random.normal(rng, [batch_size] + [num_dims] * 2)
  return math.matmul(factor, jnp.moveaxis(factor, -1, -2))
32 |
33 |
def stable_pos_enc(x, n):
  """A stable pos_enc for very high degrees, courtesy of Sameer Agarwal."""
  sin_x, cos_x = np.sin(x), np.cos(x)
  # One 2x2 rotation matrix per input element; squaring it doubles the angle,
  # which generates the 2^k frequency ladder without evaluating sin(2^k x).
  rot = np.array([[cos_x, -sin_x], [sin_x, cos_x]], dtype='double')
  terms = []
  for _ in range(n):
    # Row-reversed first column of the current rotation = (sin, cos) pair.
    terms.append(rot[::-1, 0, :])
    rot = np.einsum('ijn,jkn->ikn', rot, rot)
  stacked = np.stack(terms, 0)
  # Reorder to (batch, {sin,cos}, frequency) and flatten the trailing axes.
  return np.reshape(np.transpose(stacked, [2, 1, 0]), [-1, 2 * n])
44 |
45 |
46 | class CoordTest(parameterized.TestCase):
47 |
48 | def test_stable_pos_enc(self):
49 | """Test that the stable posenc implementation works on multiples of pi/2."""
50 | n = 10
51 | x = np.linspace(-np.pi, np.pi, 5)
52 | z = stable_pos_enc(x, n).reshape([-1, 2, n])
53 | z0_true = np.zeros_like(z[:, 0, :])
54 | z1_true = np.ones_like(z[:, 1, :])
55 | z0_true[:, 0] = [0, -1, 0, 1, 0]
56 | z1_true[:, 0] = [-1, 0, 1, 0, -1]
57 | z1_true[:, 1] = [1, -1, 1, -1, 1]
58 | z_true = np.stack([z0_true, z1_true], axis=1)
59 | np.testing.assert_allclose(z, z_true, atol=1e-10)
60 |
61 | def test_contract_matches_special_case(self):
62 | """Test the math for Figure 2 of https://arxiv.org/abs/2111.12077."""
63 | n = 10
64 | _, s_to_t = coord.construct_ray_warps(jnp.reciprocal, 1, jnp.inf)
65 | s = jnp.linspace(0, 1 - jnp.finfo(jnp.float32).eps, n + 1)
66 | tc = coord.contract(s_to_t(s)[:, None])[:, 0]
67 | delta_tc = tc[1:] - tc[:-1]
68 | np.testing.assert_allclose(
69 | delta_tc, np.full_like(delta_tc, 1 / n), atol=1E-5, rtol=1E-5)
70 |
71 | def test_contract_is_bounded(self):
72 | n, d = 10000, 3
73 | rng = random.PRNGKey(0)
74 | key0, key1, rng = random.split(rng, 3)
75 | x = jnp.where(random.bernoulli(key0, shape=[n, d]), 1, -1) * jnp.exp(
76 | random.uniform(key1, [n, d], minval=-10, maxval=10))
77 | y = coord.contract(x)
78 | self.assertLessEqual(jnp.max(y), 2)
79 |
80 | def test_contract_is_noop_when_norm_is_leq_one(self):
81 | n, d = 10000, 3
82 | rng = random.PRNGKey(0)
83 | key, rng = random.split(rng)
84 | x = random.normal(key, shape=[n, d])
85 | xc = x / jnp.maximum(1, jnp.linalg.norm(x, axis=-1, keepdims=True))
86 |
87 | # Sanity check on the test itself.
88 | assert jnp.abs(jnp.max(jnp.linalg.norm(xc, axis=-1)) - 1) < 1e-6
89 |
90 | yc = coord.contract(xc)
91 | np.testing.assert_allclose(xc, yc, atol=1E-5, rtol=1E-5)
92 |
93 | def test_contract_gradients_are_finite(self):
94 | # Construct x such that we probe x == 0, where things are unstable.
95 | x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)
96 | grad = jax.grad(lambda x: jnp.sum(coord.contract(x)))(x)
97 | self.assertTrue(jnp.all(jnp.isfinite(grad)))
98 |
99 | def test_inv_contract_gradients_are_finite(self):
100 | z = jnp.stack(jnp.meshgrid(*[jnp.linspace(-2, 2, 21)] * 2), axis=-1)
101 | z = z.reshape([-1, 2])
102 | z = z[jnp.sum(z**2, axis=-1) < 2, :]
103 | grad = jax.grad(lambda z: jnp.sum(coord.inv_contract(z)))(z)
104 | self.assertTrue(jnp.all(jnp.isfinite(grad)))
105 |
106 | def test_inv_contract_inverts_contract(self):
107 | """Do a round-trip from metric space to contracted space and back."""
108 | x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)
109 | x_recon = coord.inv_contract(coord.contract(x))
110 | np.testing.assert_allclose(x, x_recon, atol=1E-5, rtol=1E-5)
111 |
112 | @parameterized.named_parameters(
113 | ('05_1e-5', 5, 1e-5),
114 | ('10_1e-4', 10, 1e-4),
115 | ('15_0.005', 15, 0.005),
116 | ('20_0.2', 20, 0.2), # At high degrees, our implementation is unstable.
117 | ('25_2', 25, 2), # 2 is the maximum possible error.
118 | ('30_2', 30, 2),
119 | )
120 | def test_pos_enc(self, n, tol):
121 | """test pos_enc against a stable recursive implementation."""
122 | x = np.linspace(-np.pi, np.pi, 10001)
123 | z = coord.pos_enc(x[:, None], 0, n, append_identity=False)
124 | z_stable = stable_pos_enc(x, n)
125 | max_err = np.max(np.abs(z - z_stable))
126 | print(f'PE of degree {n} has a maximum error of {max_err}')
127 | self.assertLess(max_err, tol)
128 |
129 | def test_pos_enc_matches_integrated(self):
130 | """Integrated positional encoding with a variance of zero must be pos_enc."""
131 | min_deg = 0
132 | max_deg = 10
133 | np.linspace(-jnp.pi, jnp.pi, 10)
134 | x = jnp.stack(
135 | jnp.meshgrid(*[np.linspace(-jnp.pi, jnp.pi, 10)] * 2), axis=-1)
136 | x = np.linspace(-jnp.pi, jnp.pi, 10000)
137 | z_ipe = coord.integrated_pos_enc(x, jnp.zeros_like(x), min_deg, max_deg)
138 | z_pe = coord.pos_enc(x, min_deg, max_deg, append_identity=False)
139 | # We're using a pretty wide tolerance because IPE uses safe_sin().
140 | np.testing.assert_allclose(z_pe, z_ipe, atol=1e-4)
141 |
142 | def test_track_linearize(self):
143 | rng = random.PRNGKey(0)
144 | batch_size = 20
145 | for _ in range(30):
146 | # Construct some random Gaussians with dimensionalities in [1, 10].
147 | key, rng = random.split(rng)
148 | in_dims = random.randint(key, (), 1, 10)
149 | key, rng = random.split(rng)
150 | mean = jax.random.normal(key, [batch_size, in_dims])
151 | key, rng = random.split(rng)
152 | cov = sample_covariance(key, batch_size, in_dims)
153 | key, rng = random.split(rng)
154 | out_dims = random.randint(key, (), 1, 10)
155 |
156 | # Construct a random affine transformation.
157 | key, rng = random.split(rng)
158 | a_mat = jax.random.normal(key, [out_dims, in_dims])
159 | key, rng = random.split(rng)
160 | b = jax.random.normal(key, [out_dims])
161 |
162 | def fn(x):
163 | x_vec = x.reshape([-1, x.shape[-1]])
164 | y_vec = jax.vmap(lambda z: math.matmul(a_mat, z))(x_vec) + b # pylint:disable=cell-var-from-loop
165 | y = y_vec.reshape(list(x.shape[:-1]) + [y_vec.shape[-1]])
166 | return y
167 |
168 | # Apply the affine function to the Gaussians.
169 | fn_mean_true = fn(mean)
170 | fn_cov_true = math.matmul(math.matmul(a_mat, cov), a_mat.T)
171 |
172 | # Tracking the Gaussians through a linearized function of a linear
173 | # operator should be the same.
174 | fn_mean, fn_cov = coord.track_linearize(fn, mean, cov)
175 | np.testing.assert_allclose(fn_mean, fn_mean_true, atol=1E-5, rtol=1E-5)
176 | np.testing.assert_allclose(fn_cov, fn_cov_true, atol=1e-5, rtol=1e-5)
177 |
178 | @parameterized.named_parameters(('reciprocal', jnp.reciprocal),
179 | ('log', jnp.log), ('sqrt', jnp.sqrt))
180 | def test_construct_ray_warps_extents(self, fn):
181 | n = 100
182 | rng = random.PRNGKey(0)
183 | key, rng = random.split(rng)
184 | t_near = jnp.exp(jax.random.normal(key, [n]))
185 | key, rng = random.split(rng)
186 | t_far = t_near + jnp.exp(jax.random.normal(key, [n]))
187 |
188 | t_to_s, s_to_t = coord.construct_ray_warps(fn, t_near, t_far)
189 |
190 | np.testing.assert_allclose(
191 | t_to_s(t_near), jnp.zeros_like(t_near), atol=1E-5, rtol=1E-5)
192 | np.testing.assert_allclose(
193 | t_to_s(t_far), jnp.ones_like(t_far), atol=1E-5, rtol=1E-5)
194 | np.testing.assert_allclose(
195 | s_to_t(jnp.zeros_like(t_near)), t_near, atol=1E-5, rtol=1E-5)
196 | np.testing.assert_allclose(
197 | s_to_t(jnp.ones_like(t_near)), t_far, atol=1E-5, rtol=1E-5)
198 |
199 | def test_construct_ray_warps_special_reciprocal(self):
200 | """Test fn=1/x against its closed form."""
201 | n = 100
202 | rng = random.PRNGKey(0)
203 | key, rng = random.split(rng)
204 | t_near = jnp.exp(jax.random.normal(key, [n]))
205 | key, rng = random.split(rng)
206 | t_far = t_near + jnp.exp(jax.random.normal(key, [n]))
207 |
208 | key, rng = random.split(rng)
209 | u = jax.random.uniform(key, [n])
210 | t = t_near * (1 - u) + t_far * u
211 | key, rng = random.split(rng)
212 | s = jax.random.uniform(key, [n])
213 |
214 | t_to_s, s_to_t = coord.construct_ray_warps(jnp.reciprocal, t_near, t_far)
215 |
216 | # Special cases for fn=reciprocal.
217 | s_to_t_ref = lambda s: 1 / (s / t_far + (1 - s) / t_near)
218 | t_to_s_ref = lambda t: (t_far * (t - t_near)) / (t * (t_far - t_near))
219 |
220 | np.testing.assert_allclose(t_to_s(t), t_to_s_ref(t), atol=1E-5, rtol=1E-5)
221 | np.testing.assert_allclose(s_to_t(s), s_to_t_ref(s), atol=1E-5, rtol=1E-5)
222 |
223 | def test_expected_sin(self):
224 | normal_samples = random.normal(random.PRNGKey(0), (10000,))
225 | for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]:
226 | sin_mu = coord.expected_sin(mu, var)
227 | x = jnp.sin(jnp.sqrt(var) * normal_samples + mu)
228 | np.testing.assert_allclose(sin_mu, jnp.mean(x), atol=1e-2)
229 |
230 | def test_integrated_pos_enc(self):
231 | num_dims = 2 # The number of input dimensions.
232 | min_deg = 0 # Must be 0 for this test to work.
233 | max_deg = 4
234 | num_samples = 100000
235 | rng = random.PRNGKey(0)
236 | for _ in range(5):
237 | # Generate a coordinate's mean and covariance matrix.
238 | key, rng = random.split(rng)
239 | mean = random.normal(key, (2,))
240 | key, rng = random.split(rng)
241 | half_cov = jax.random.normal(key, [num_dims] * 2)
242 | cov = half_cov @ half_cov.T
243 | var = jnp.diag(cov)
244 | # Generate an IPE.
245 | enc = coord.integrated_pos_enc(
246 | mean,
247 | var,
248 | min_deg,
249 | max_deg,
250 | )
251 |
252 | # Draw samples, encode them, and take their mean.
253 | key, rng = random.split(rng)
254 | samples = random.multivariate_normal(key, mean, cov, [num_samples])
255 | assert min_deg == 0
256 | enc_samples = np.concatenate(
257 | [stable_pos_enc(x, max_deg) for x in tuple(samples.T)], axis=-1)
258 | # Correct for a different dimension ordering in stable_pos_enc.
259 | enc_gt = jnp.mean(enc_samples, 0)
260 | enc_gt = enc_gt.reshape([num_dims, max_deg * 2]).T.reshape([-1])
261 | np.testing.assert_allclose(enc, enc_gt, rtol=1e-2, atol=1e-2)
262 |
263 |
# Run every test in this module when executed directly.
if __name__ == '__main__':
  absltest.main()
266 |
--------------------------------------------------------------------------------
/tests/math_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for math."""
16 |
17 | import functools
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from internal import math
22 | import jax
23 | from jax import random
24 | import jax.numpy as jnp
25 | import numpy as np
26 |
27 |
def safe_trig_harness(fn, max_exp):
  """Evaluate np.<fn> and math.safe_<fn> on a shared symmetric test grid.

  Args:
    fn: str, a numpy function name ('sin' or 'cos' in the tests below).
    max_exp: the largest base-10 exponent of |x| to probe.

  Returns:
    A (reference, safe) pair of arrays evaluated on the same inputs.
  """
  # Log-spaced positive magnitudes from 1e-30 up to 10**max_exp.
  magnitudes = 10**np.linspace(-30, max_exp, 10000)
  # Mirror around zero so negative values and exactly-zero are covered too.
  grid = np.concatenate([-magnitudes[::-1], np.array([0]), magnitudes])
  reference = getattr(np, fn)(grid)
  safe = getattr(math, 'safe_' + fn)(grid)
  return reference, safe
34 |
35 |
class MathTest(parameterized.TestCase):
  """Unit tests for the numerical helpers in internal/math.py."""

  def test_sin(self):
    """In [-1e10, 1e10] safe_sin and safe_cos are accurate."""
    for fn in ['sin', 'cos']:
      y_true, y = safe_trig_harness(fn, 10)
      self.assertLess(jnp.max(jnp.abs(y - y_true)), 1e-4)
      self.assertFalse(jnp.any(jnp.isnan(y)))
    # Beyond that range it's less accurate but we just don't want it to be NaN.
    for fn in ['sin', 'cos']:
      y_true, y = safe_trig_harness(fn, 60)
      self.assertFalse(jnp.any(jnp.isnan(y)))

  def test_safe_exp_correct(self):
    """math.safe_exp() should match np.exp() for not-huge values."""
    x = jnp.linspace(-80, 80, 10001)
    y = math.safe_exp(x)
    g = jax.vmap(jax.grad(math.safe_exp))(x)
    # Both the value and the gradient should equal exp(x) in this range.
    yg_true = jnp.exp(x)
    np.testing.assert_allclose(y, yg_true)
    np.testing.assert_allclose(g, yg_true)

  def test_safe_exp_finite(self):
    """math.safe_exp() behaves reasonably for huge values."""
    x = jnp.linspace(-100000, 100000, 10001)
    y = math.safe_exp(x)
    g = jax.vmap(jax.grad(math.safe_exp))(x)
    # `y` and `g` should both always be finite.
    self.assertTrue(jnp.all(jnp.isfinite(y)))
    self.assertTrue(jnp.all(jnp.isfinite(g)))
    # The derivative of exp() should be exp().
    np.testing.assert_allclose(y, g)
    # safe_exp()'s output and gradient should be monotonic.
    self.assertTrue(jnp.all(y[1:] >= y[:-1]))
    self.assertTrue(jnp.all(g[1:] >= g[:-1]))

  def test_learning_rate_decay(self):
    """Decay should hit both endpoints and the geometric mean at the middle."""
    rng = random.PRNGKey(0)
    for _ in range(10):
      # Draw random positive initial/final rates and a random step budget.
      # NOTE: the key-split order below determines the sampled values; keep it.
      key, rng = random.split(rng)
      lr_init = jnp.exp(random.normal(key) - 3)
      key, rng = random.split(rng)
      lr_final = lr_init * jnp.exp(random.normal(key) - 5)
      key, rng = random.split(rng)
      max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key))))

      lr_fn = functools.partial(
          math.learning_rate_decay,
          lr_init=lr_init,
          lr_final=lr_final,
          max_steps=max_steps)

      # Test that the rate at the beginning is the initial rate.
      np.testing.assert_allclose(lr_fn(0), lr_init, atol=1E-5, rtol=1E-5)

      # Test that the rate at the end is the final rate.
      np.testing.assert_allclose(
          lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5)

      # Test that the rate at the middle is the geometric mean of the two rates.
      np.testing.assert_allclose(
          lr_fn(max_steps / 2),
          jnp.sqrt(lr_init * lr_final),
          atol=1E-5,
          rtol=1E-5)

      # Test that the rate past the end is the final rate
      np.testing.assert_allclose(
          lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5)

  def test_delayed_learning_rate_decay(self):
    """Decay with a delayed start should match the un-delayed schedule later."""
    rng = random.PRNGKey(0)
    for _ in range(10):
      # Draw the same quantities as above, plus a delay window and multiplier.
      key, rng = random.split(rng)
      lr_init = jnp.exp(random.normal(key) - 3)
      key, rng = random.split(rng)
      lr_final = lr_init * jnp.exp(random.normal(key) - 5)
      key, rng = random.split(rng)
      max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key))))
      key, rng = random.split(rng)
      lr_delay_steps = int(
          random.uniform(key, minval=0.1, maxval=0.4) * max_steps)
      key, rng = random.split(rng)
      lr_delay_mult = jnp.exp(random.normal(key) - 3)

      lr_fn = functools.partial(
          math.learning_rate_decay,
          lr_init=lr_init,
          lr_final=lr_final,
          max_steps=max_steps,
          lr_delay_steps=lr_delay_steps,
          lr_delay_mult=lr_delay_mult)

      # Test that the rate at the beginning is the delayed initial rate.
      np.testing.assert_allclose(
          lr_fn(0), lr_delay_mult * lr_init, atol=1E-5, rtol=1E-5)

      # Test that the rate at the end is the final rate.
      np.testing.assert_allclose(
          lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5)

      # Test that the rate at after the delay is over is the usual rate.
      np.testing.assert_allclose(
          lr_fn(lr_delay_steps),
          math.learning_rate_decay(lr_delay_steps, lr_init, lr_final,
                                   max_steps),
          atol=1E-5,
          rtol=1E-5)

      # Test that the rate at the middle is the geometric mean of the two rates.
      np.testing.assert_allclose(
          lr_fn(max_steps / 2),
          jnp.sqrt(lr_init * lr_final),
          atol=1E-5,
          rtol=1E-5)

      # Test that the rate past the end is the final rate
      np.testing.assert_allclose(
          lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5)

  @parameterized.named_parameters(('', False), ('sort', True))
  def test_interp(self, sort):
    """math.interp/sorted_interp should match jnp.interp applied row-wise."""
    n, d0, d1 = 100, 10, 20
    rng = random.PRNGKey(0)

    key, rng = random.split(rng)
    x = random.normal(key, [n, d0])

    key, rng = random.split(rng)
    xp = random.normal(key, [n, d1])

    key, rng = random.split(rng)
    fp = random.normal(key, [n, d1])

    if sort:
      # sorted_interp assumes sorted inputs, so sort coordinates and values.
      xp = jnp.sort(xp, axis=-1)
      fp = jnp.sort(fp, axis=-1)
      z = math.sorted_interp(x, xp, fp)
    else:
      z = math.interp(x, xp, fp)

    # Reference: 1-D jnp.interp applied independently to each of the n rows.
    z_true = jnp.stack([jnp.interp(x[i], xp[i], fp[i]) for i in range(n)])
    np.testing.assert_allclose(z, z_true, atol=1e-5, rtol=1e-5)
179 |
180 |
# Run every test in this module when executed directly.
if __name__ == '__main__':
  absltest.main()
183 |
--------------------------------------------------------------------------------
/tests/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for utils."""
16 |
17 | from absl.testing import absltest
18 |
19 | from internal import utils
20 |
21 |
class UtilsTest(absltest.TestCase):
  """Tests for helpers in internal/utils.py."""

  def test_dummy_rays(self):
    """The dummy Rays object should carry 3-vector ray origins."""
    dummy = utils.dummy_rays()
    origin_dim = dummy.origins.shape[-1]
    self.assertEqual(origin_dim, 3)
28 |
29 |
# Run every test in this module when executed directly.
if __name__ == '__main__':
  absltest.main()
32 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Training script."""
16 |
17 | import functools
18 | import gc
19 | import time
20 |
21 | from absl import app
22 | import flax
23 | from flax.metrics import tensorboard
24 | from flax.training import checkpoints
25 | import gin
26 | from internal import configs
27 | from internal import datasets
28 | from internal import image
29 | from internal import models
30 | from internal import train_utils
31 | from internal import utils
32 | from internal import vis
33 | import jax
34 | from jax import random
35 | import jax.numpy as jnp
36 | import numpy as np
37 |
# Register the project's shared command-line flags, then let JAX pick up any
# of its own flags that absl parsed.
configs.define_common_flags()
jax.config.parse_flags_with_absl()

TIME_PRECISION = 1000  # Internally represent integer times in milliseconds.
42 |
43 |
def main(unused_argv):
  """Runs the full training loop.

  Loads config and datasets, sets up the (replicated) model/optimizer state,
  then alternates pmapped training steps with host-0 logging, periodic
  checkpointing, and periodic test-set rendering/metric evaluation.
  """
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by different
  # hosts.
  np.random.seed(20201473 + jax.host_id())

  config = configs.load_config()

  if config.batch_size % jax.device_count() != 0:
    raise ValueError('Batch size must be divisible by the number of devices.')

  dataset = datasets.load_dataset('train', config.data_dir, config)
  test_dataset = datasets.load_dataset('test', config.data_dir, config)
  # Move any numpy camera parameters onto the accelerator once, up front.
  np_to_jax = lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x
  cameras = tuple(np_to_jax(x) for x in dataset.cameras)

  if config.rawnerf_mode:
    postprocess_fn = test_dataset.metadata['postprocess_fn']
  else:
    # Identity postprocessing; the optional second argument mirrors the
    # rawnerf postprocess signature and is ignored.
    postprocess_fn = lambda z, _=None: z

  rng, key = random.split(rng)
  setup = train_utils.setup_model(config, key, dataset=dataset)
  model, state, render_eval_pfn, train_pstep, lr_fn = setup

  # Count parameters by summing the product of each leaf's shape.
  variables = state.params
  num_params = jax.tree_util.tree_reduce(
      lambda x, y: x + jnp.prod(jnp.array(y.shape)), variables, initializer=0)
  print(f'Number of parameters being optimized: {num_params}')

  if (dataset.size > model.num_glo_embeddings and model.num_glo_features > 0):
    raise ValueError(f'Number of glo embeddings {model.num_glo_embeddings} '
                     f'must be at least equal to number of train images '
                     f'{dataset.size}')

  metric_harness = image.MetricHarness()

  if not utils.isdir(config.checkpoint_dir):
    utils.makedirs(config.checkpoint_dir)
  # Restores the latest checkpoint if one exists (otherwise returns `state`).
  state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(config.checkpoint_dir)
    if config.rawnerf_mode:
      for name, data in zip(['train', 'test'], [dataset, test_dataset]):
        # Log shutter speed metadata in TensorBoard for debug purposes.
        for key in ['exposure_idx', 'exposure_values', 'unique_shutters']:
          summary_writer.text(f'{name}_{key}', str(data.metadata[key]), 0)

  # Prefetch_buffer_size = 3 x batch_size.
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  rngs = random.split(rng, jax.local_device_count())  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  total_time = 0
  total_steps = 0
  reset_stats = True
  if config.early_exit_steps is not None:
    num_steps = config.early_exit_steps
  else:
    num_steps = config.max_steps
  loss_threshold = 1.0
  for step, batch in zip(range(init_step, num_steps + 1), pdataset):
    if reset_stats and (jax.host_id() == 0):
      stats_buffer = []
      train_start_time = time.time()
      reset_stats = False

    learning_rate = lr_fn(step)
    # Fraction of training completed in [0, 1], used by the model/loss.
    train_frac = jnp.clip((step - 1) / (config.max_steps - 1), 0, 1)

    state, stats, rngs = train_pstep(
        rngs,
        state,
        batch,
        cameras,
        train_frac,
        loss_threshold,
    )
    if config.enable_robustnerf_loss:
      # Feed the threshold computed by this step back into the next one.
      loss_threshold = jnp.mean(stats['loss_threshold'])

    if step % config.gc_every == 0:
      gc.collect()  # Manual collection pass (automatic GC is disabled above).

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      stats = flax.jax_utils.unreplicate(stats)

      stats_buffer.append(stats)

      if step == init_step or step % config.print_every == 0:
        elapsed_time = time.time() - train_start_time
        steps_per_sec = config.print_every / elapsed_time
        rays_per_sec = config.batch_size * steps_per_sec

        # A robust approximation of total training time, in case of pre-emption.
        total_time += int(round(TIME_PRECISION * elapsed_time))
        total_steps += config.print_every
        approx_total_time = int(round(step * total_time / total_steps))

        # Transpose and stack stats_buffer along axis 0.
        fs = [flax.traverse_util.flatten_dict(s, sep='/') for s in stats_buffer]
        stats_stacked = {k: jnp.stack([f[k] for f in fs]) for k in fs[0].keys()}

        # Split every statistic that isn't a vector into a set of statistics.
        stats_split = {}
        for k, v in stats_stacked.items():
          # NOTE(review): this guard only raises when BOTH conditions hold, so
          # a stat with e.g. v.ndim == 3 but matching length slips through and
          # is silently dropped by the branches below. `or` was likely
          # intended — confirm before changing.
          if v.ndim not in [1, 2] and v.shape[0] != len(stats_buffer):
            raise ValueError('statistics must be of size [n], or [n, k].')
          if v.ndim == 1:
            stats_split[k] = v
          elif v.ndim == 2:
            for i, vi in enumerate(tuple(v.T)):
              stats_split[f'{k}/{i}'] = vi

        # Summarize the entire histogram of each statistic.
        for k, v in stats_split.items():
          summary_writer.histogram('train_' + k, v, step)

        # Take the mean and max of each statistic since the last summary.
        avg_stats = {k: jnp.mean(v) for k, v in stats_split.items()}
        max_stats = {k: jnp.max(v) for k, v in stats_split.items()}

        summ_fn = lambda s, v: summary_writer.scalar(s, v, step)  # pylint:disable=cell-var-from-loop

        # Summarize the mean and max of each statistic.
        for k, v in avg_stats.items():
          summ_fn(f'train_avg_{k}', v)
        for k, v in max_stats.items():
          summ_fn(f'train_max_{k}', v)

        summ_fn('train_num_params', num_params)
        summ_fn('train_learning_rate', learning_rate)
        summ_fn('train_steps_per_sec', steps_per_sec)
        summ_fn('train_rays_per_sec', rays_per_sec)

        # PSNR plotted against wall-clock time (exact and approximated).
        summary_writer.scalar('train_avg_psnr_timed', avg_stats['psnr'],
                              total_time // TIME_PRECISION)
        summary_writer.scalar('train_avg_psnr_timed_approx', avg_stats['psnr'],
                              approx_total_time // TIME_PRECISION)

        if dataset.metadata is not None and model.learned_exposure_scaling:
          # Log each learned per-shutter exposure scaling coefficient.
          params = state.params['params']
          scalings = params['exposure_scaling_offsets']['embedding'][0]
          num_shutter_speeds = dataset.metadata['unique_shutters'].shape[0]
          for i_s in range(num_shutter_speeds):
            for j_s, value in enumerate(scalings[i_s]):
              summary_name = f'exposure/scaling_{i_s}_{j_s}'
              summary_writer.scalar(summary_name, value, step)

        precision = int(np.ceil(np.log10(config.max_steps))) + 1
        avg_loss = avg_stats['loss']
        avg_psnr = avg_stats['psnr']
        str_losses = {  # Grab each "losses_{x}" field and print it as "x[:4]".
            k[7:11]: (f'{v:0.5f}' if v >= 1e-4 and v < 10 else f'{v:0.1e}')
            for k, v in avg_stats.items()
            if k.startswith('losses/')
        }
        print(f'{step:{precision}d}' + f'/{config.max_steps:d}: ' +
              f'loss={avg_loss:0.5f}, ' + f'psnr={avg_psnr:6.3f}, ' +
              f'lr={learning_rate:0.2e} | ' +
              ', '.join([f'{k}={s}' for k, s in str_losses.items()]) +
              f', {rays_per_sec:0.0f} r/s')

        # Reset everything we are tracking between summarizations.
        reset_stats = True

    # Periodic checkpointing (also at step 1, to validate saving early).
    if step == 1 or step % config.checkpoint_every == 0:
      state_to_save = jax.device_get(
          flax.jax_utils.unreplicate(state))
      checkpoints.save_checkpoint(
          config.checkpoint_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if config.train_render_every > 0 and step % config.train_render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      eval_start_time = time.time()
      eval_variables = flax.jax_utils.unreplicate(state).params
      test_case = next(test_dataset)
      rendering = models.render_image(
          functools.partial(render_eval_pfn, eval_variables, train_frac),
          test_case.rays, rngs[0], config)
      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        eval_time = time.time() - eval_start_time
        num_rays = jnp.prod(jnp.array(test_case.rays.directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar('test_rays_per_sec', rays_per_sec, step)
        print(f'Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec')

        # Compute image metrics (skipping any that come back NaN).
        metric_start_time = time.time()
        metric = metric_harness(
            postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))
        print(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')
        for name, val in metric.items():
          if not np.isnan(val):
            print(f'{name} = {val:.4f}')
            summary_writer.scalar('train_metrics/' + name, val, step)

        # Optionally subsample images spatially before writing them out.
        if config.vis_decimate > 1:
          d = config.vis_decimate
          decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
        else:
          decimate_fn = lambda x: x
        rendering = jax.tree_util.tree_map(decimate_fn, rendering)
        test_case = jax.tree_util.tree_map(decimate_fn, test_case)
        vis_start_time = time.time()
        vis_suite = vis.visualize_suite(rendering, test_case.rays)
        print(f'Visualized in {(time.time() - vis_start_time):0.3f}s')
        if config.rawnerf_mode:
          # Unprocess raw output.
          vis_suite['color_raw'] = rendering['rgb']
          # Autoexposed colors.
          vis_suite['color_auto'] = postprocess_fn(rendering['rgb'], None)
          summary_writer.image('test_true_auto',
                               postprocess_fn(test_case.rgb, None), step)
          # Exposure sweep colors.
          exposures = test_dataset.metadata['exposure_levels']
          for p, x in list(exposures.items()):
            vis_suite[f'color/{p}'] = postprocess_fn(rendering['rgb'], x)
            summary_writer.image(f'test_true_color/{p}',
                                 postprocess_fn(test_case.rgb, x), step)
        summary_writer.image('test_true_color', test_case.rgb, step)
        for k, v in vis_suite.items():
          summary_writer.image('test_output_' + k, v, step)

  # Final checkpoint, unless the loop already saved one at max_steps.
  if jax.host_id() == 0 and config.max_steps % config.checkpoint_every != 0:
    state = jax.device_get(flax.jax_utils.unreplicate(state))
    checkpoints.save_checkpoint(
        config.checkpoint_dir, state, int(config.max_steps), keep=100)
282 |
283 |
# Run training under the 'train' gin scope so scoped bindings apply.
if __name__ == '__main__':
  with gin.config_scope('train'):
    app.run(main)
287 |
--------------------------------------------------------------------------------