├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── 360_dino.gin
│   ├── 360_glo4.gin
│   └── 360_robustnerf.gin
├── eval.py
├── internal
│   ├── camera_utils.py
│   ├── configs.py
│   ├── coord.py
│   ├── datasets.py
│   ├── geopoly.py
│   ├── image.py
│   ├── math.py
│   ├── models.py
│   ├── raw_utils.py
│   ├── ref_utils.py
│   ├── render.py
│   ├── robustnerf.py
│   ├── stepfun.py
│   ├── train_utils.py
│   ├── utils.py
│   └── vis.py
├── media
│   └── teaser.gif
├── render.py
├── requirements.txt
├── scripts
│   ├── download_on-the-go.sh
│   ├── eval_on-the-go.sh
│   ├── eval_on-the-go_HD.sh
│   ├── feature_extract.py
│   ├── feature_extract.sh
│   ├── local_colmap_and_resize.sh
│   ├── render_on-the-go.sh
│   ├── render_on-the-go_HD.sh
│   ├── run_all_unit_tests.sh
│   ├── train_on-the-go.sh
│   └── train_on-the-go_HD.sh
├── tests
│   ├── camera_utils_test.py
│   ├── coord_test.py
│   ├── math_test.py
│   ├── stepfun_test.py
│   └── utils_test.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | internal/pycolmap 2 | __pycache__/ 3 | internal/__pycache__/ 4 | tests/__pycache__/ 5 | .DS_Store 6 | .vscode/ 7 | .idea/ 8 | __MACOSX/ 9 | *.err 10 | *.out 11 | output/ 12 | RobustNerf/ 13 | data/ 14 | jupyter/ 15 | slurm/ 16 | output_8/ 17 | output_2/ 18 | output_ablation/ 19 | zzh_output_8/ 20 | 360/ 21 | output_360/ 22 | output_new_ablation/ 23 | output_highres 24 | scripts/SAM/* 25 | scripts/static/* 26 | scripts/blockview/* 27 | video_maker/ 28 | output_ablation_new/ 29 | tmp_script/ 30 | scripts/SAM/ 31 | Datasets/
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License.
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# NeRF On-the-go: Exploiting Uncertainty for Distractor-free NeRFs in the Wild

Weining Ren* · Zihan Zhu* · Boyang Sun · Jiaqi Chen · Marc Pollefeys · Songyou Peng

(* Equal Contribution)

CVPR 2024

Paper | Video | [Project Page](https://rwn17.github.io/nerf-on-the-go/)

![Logo](media/teaser.gif)

Table of Contents
1. Description
2. Setup
3. Dataset Preparation
4. Running
5. Checkpoint
6. Citation
7. Contact
59 | 60 | 61 | ## Description 62 | 63 | This repository hosts the official Jax implementation of the paper "NeRF on-the-go: Exploiting Uncertainty for Distractor-free NeRFs in the Wild" (CVPR 2024). For more details, please visit our [project webpage](https://rwn17.github.io/nerf-on-the-go/). 64 | 65 | This repo is built upon the [MultiNeRF](https://github.com/google-research/multinerf) codebase. 66 | 67 | ## Setup 68 | 69 | ``` 70 | # Clone the repo. 71 | git clone https://github.com/cvg/nerf-on-the-go 72 | cd nerf-on-the-go 73 | 74 | # Make a conda environment. 75 | conda create --name on-the-go python=3.9 76 | conda activate on-the-go 77 | 78 | # Prepare pip. 79 | conda install pip 80 | pip install --upgrade pip 81 | 82 | 83 | # Install requirements. 84 | pip install -r requirements.txt 85 | 86 | # Manually install rmbrualla's `pycolmap` (don't use pip's! It's different). 87 | git clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap 88 | 89 | # Confirm that all the unit tests pass. 90 | ./scripts/run_all_unit_tests.sh 91 | ``` 92 | You'll also need to update your [JAX](https://jax.readthedocs.io/en/latest/installation.html) installation to support GPUs or TPUs. 93 | 94 | ``` 95 | pip install -U "jax[cuda12]" 96 | ``` 97 | 98 | ### Instructions for ETH Euler 99 |
100 | Click to expand 101 | 102 | On ETH Euler, to get GPU support for JAX, you need to request a GPU in debug mode and then load newer gcc and CUDA modules: 103 | ``` 104 | srun -n 4 --mem-per-cpu=12000 --gpus=rtx_3090:1 --gres=gpumem:20g --time=4:00:00 --pty bash 105 | conda activate on-the-go 106 | module load eth_proxy gcc/8.2.0 cuda/12.1.1 cudnn/8.9.2.26 107 | ``` 108 | 109 | After loading the modules, verify their activation by executing ```module list```. Occasionally, modules may not load correctly, requiring you to load each one individually. Following this, proceed with the JAX installation: 110 | 111 | ``` 112 | # Installs the wheel compatible with CUDA 12 and cuDNN 8.9 or newer. 113 | pip install jax==0.4.26 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 114 | pip install jaxlib==0.4.26+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 115 | ``` 116 | 117 | After successful installation, please rerun ```./scripts/run_all_unit_tests.sh```. 118 | 119 | The installation process outlined above has been verified on the Euler system using an RTX 3090. You may get a warning: 120 | ``` 121 | The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages. 122 | ``` 123 | 124 | This is fine: Euler supports up to CUDA 12.1, while JAX now requires a minimum of CUDA 12.3. As discussed in JAX [Issue #18032](https://github.com/google/jax/issues/18032), this discrepancy primarily impacts compilation speed rather than overall functionality. 125 |
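Before training anything, it can help to confirm that this JAX build actually sees the GPU; a minimal check (not one of the repo's scripts):

```
python -c "import jax; print(jax.default_backend()); print(jax.devices())"
# expect 'gpu' and something like [cuda(id=0)]; a CPU-only device list means
# the CUDA-enabled jaxlib is not active in this environment.
```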
126 | 127 | 128 | ## Dataset Preparation 129 | 130 | ### Downloading the Dataset 131 | 132 | Before processing the data, please make sure mogrify is installed: 133 | 134 | ``` 135 | conda install -y -c conda-forge imagemagick 136 | ``` 137 | 138 | To download the "On-the-go" dataset, execute the following command: 139 | ```bash 140 | bash ./scripts/download_on-the-go.sh 141 | ``` 142 | This script not only downloads the dataset but also downsamples the images as required. **NOTE: Please double check whether the data has been correctly DOWNSAMPLED!** 143 | 144 | ### Feature Extraction with DINOv2 145 | To extract features with DINOv2, use the command below: 146 | ```bash 147 | bash ./scripts/feature_extract.sh 148 | ``` 149 | 150 | 151 | After feature extraction, the dataset should be organized as: 152 | ``` 153 | on-the-go 154 | ├── arcdetriomphe 155 | │ ├── images 156 | │ ├── images_{DOWNSAMPLE_RATE} 157 | │ ├── features_{DOWNSAMPLE_RATE} 158 | │ ├── split.json 159 | │ ├── transforms.json 160 | ├── .... 161 | │ 162 | └── tree 163 | ├── images_{DOWNSAMPLE_RATE} 164 | ├── .... 165 | └── transforms.json 166 | ``` 167 | 168 | ### Dataset Structure and Configuration Files 169 | - **split.json**: This file outlines the train and evaluation splits, following the naming conventions used in the RobustNeRF dataset, categorized as 'clutter' and 'clean'. 170 | - **transforms.json**: Contains pose and intrinsic information, formatted according to the Blender dataset format, derived from COLMAP files. Refer to the [Instant-NGP script](https://github.com/NVlabs/instant-ngp/blob/de507662d4b3398163e426fd426d48ff8f2895f6/scripts/colmap2nerf.py) for more details. 171 | 172 | ### Future Updates 173 | We plan to expand support to include custom datasets in future updates. 174 | 175 | 176 | ## Running 177 | 178 | Example scripts for training, evaluating, and rendering can be found in 179 | `scripts/`. You'll need to change the paths to point to wherever the datasets 180 | are located. [Gin](https://github.com/google/gin-config) configuration files 181 | for our model and some ablations can be found in `configs/`. 182 | 183 | 1. Training on-the-go: 184 | ``` 185 | bash scripts/train_on-the-go.sh 186 | ``` 187 | 188 | 2. Evaluating on-the-go: 189 | ``` 190 | bash scripts/eval_on-the-go.sh 191 | ``` 192 | 193 | 3. Rendering on-the-go: 194 | ``` 195 | bash scripts/render_on-the-go.sh 196 | ``` 197 | 198 | TensorBoard is supported for logging. 199 | 200 | ### Note 201 | Since a different recording device was used for the ***arc de triomphe*** and ***patio*** scenes, their image downsample rate (4 instead of 8) and feature downsample rate (2 instead of 4) differ. Please train them with the separate script: 202 | 203 | ``` 204 | bash scripts/train_on-the-go_HD.sh 205 | ``` 206 | 207 | ### OOM errors 208 | 209 | About **80 GB of GPU memory** is needed to run the current version. You may need to reduce the batch size (`Config.batch_size`) to avoid out-of-memory 210 | errors. If you do this, but want to preserve quality, be sure to increase the number 211 | of training iterations and decrease the learning rate by whatever scale factor you 212 | decrease the batch size by. 213 | 214 | ## Checkpoint 215 | We release the checkpoints for the quantitative scenes [here](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/weining_connect_hku_hk/ER2Esfrn0plAjCa2I6G2BJ4B56qX4B5whdMgk5T90A_F-A?e=3gklLp).
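To evaluate a downloaded checkpoint directly, point `Config.checkpoint_dir` at it when invoking `eval.py`; a minimal sketch, with placeholder paths that you will need to adapt:

```
python eval.py --gin_configs=configs/360_dino.gin \
  --gin_bindings="Config.data_dir = 'data/on-the-go/mountain'" \
  --gin_bindings="Config.checkpoint_dir = 'output/mountain'"
```

The table below compares the numbers reported in the paper against the released checkpoints.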
216 | 217 | Scene | Mountain | Fountain | Corner | Patio | Spot | Patio-High 218 | -- | -- | -- | -- | -- | -- | -- 219 | onthego (paper) | 20.15 | 20.11 | 24.22 | 20.78 | 23.33 | 21.41 220 | onthego (released ckpt) | 20.89 | 19.88 | 24.69 | 22.30 | 24.67 | 22.30 221 | 222 | 223 | ## Citation 224 | 225 | If you use NeRF on-the-go, please cite: 226 | 227 | ``` 228 | @InProceedings{Ren2024NeRF, 229 | title={NeRF on-the-go: Exploiting Uncertainty for Distractor-free NeRFs in the Wild}, 230 | author={Ren, Weining and Zhu, Zihan and Sun, Boyang and Chen, Jiaqi and Pollefeys, Marc and Peng, Songyou}, 231 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 232 | year={2024} 233 | } 234 | ``` 235 | 236 | Also, since this code is built upon MultiNeRF, feel free to cite that entire codebase as: 237 | 238 | ``` 239 | @misc{multinerf2022, 240 | title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}}, 241 | author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron}, 242 | year={2022}, 243 | url={https://github.com/google-research/multinerf}, 244 | } 245 | ``` 246 | 247 | ## Contact 248 | If there is any problem, please contact Weining at weining@connect.hku.hk. 249 |
-------------------------------------------------------------------------------- /configs/360_dino.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'on-the-go' 2 | Config.near = 0.2 3 | Config.far = 1e6 4 | Config.factor = 8 5 | 6 | Config.patch_size = 32 7 | Config.data_loss_type = 'dino_ssim' 8 | Config.enable_robustnerf_loss = False 9 | Config.compute_feature_metrics = True 10 | 11 | Model.raydist_fn = @jnp.reciprocal 12 | Model.opaque_background = True 13 | Model.num_glo_features = 4 14 | 15 | 16 | PropMLP.warp_fn = @coord.contract 17 | PropMLP.net_depth = 4 18 | PropMLP.net_width = 256 19 | PropMLP.disable_density_normals = True 20 | PropMLP.disable_rgb = True 21 | 22 | NerfMLP.warp_fn = @coord.contract 23 | NerfMLP.net_depth = 8 24 | NerfMLP.net_width = 1024 25 | NerfMLP.disable_density_normals = True 26 |
-------------------------------------------------------------------------------- /configs/360_glo4.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'on-the-go' 2 | Config.near = 0.2 3 | Config.far = 1e6 4 | Config.factor = 4 5 | 6 | Model.raydist_fn = @jnp.reciprocal 7 | Model.num_glo_features = 4 8 | Model.opaque_background = True 9 | 10 | PropMLP.warp_fn = @coord.contract 11 | PropMLP.net_depth = 4 12 | PropMLP.net_width = 256 13 | PropMLP.disable_density_normals = True 14 | PropMLP.disable_rgb = True 15 | 16 | NerfMLP.warp_fn = @coord.contract 17 | NerfMLP.net_depth = 8 18 | NerfMLP.net_width = 1024 19 | NerfMLP.disable_density_normals = True 20 |
-------------------------------------------------------------------------------- /configs/360_robustnerf.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'on-the-go' 2 | Config.near = 0.2 3 | Config.far = 1e6 4 | Config.factor = 8 5 | 6 | Config.patch_size = 16 7 | Config.data_loss_type = 'robustnerf' 8 | Config.robustnerf_inlier_quantile = 0.8 9 | Config.enable_robustnerf_loss = True 10 | Config.compute_feature_metrics = True 11 | Model.num_glo_features = 4 12 | 13 | 14 | Model.raydist_fn = @jnp.reciprocal 15 | Model.opaque_background = True 16 | 17 | PropMLP.warp_fn
= @coord.contract 18 | PropMLP.net_depth = 4 19 | PropMLP.net_width = 256 20 | PropMLP.disable_density_normals = True 21 | PropMLP.disable_rgb = True 22 | 23 | NerfMLP.warp_fn = @coord.contract 24 | NerfMLP.net_depth = 8 25 | NerfMLP.net_width = 1024 26 | NerfMLP.disable_density_normals = True 27 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluation script.""" 16 | 17 | import functools 18 | from os import path 19 | import sys 20 | import time 21 | 22 | from absl import app 23 | from flax.metrics import tensorboard 24 | from flax.training import checkpoints 25 | import gin 26 | from internal import configs 27 | from internal import datasets 28 | from internal import image 29 | from internal import models 30 | from internal import raw_utils 31 | from internal import ref_utils 32 | from internal import train_utils 33 | from internal import utils 34 | from internal import vis 35 | import jax 36 | from jax import random 37 | import jax.numpy as jnp 38 | import numpy as np 39 | from matplotlib import cm 40 | from internal.vis import visualize_cmap 41 | 42 | 43 | configs.define_common_flags() 44 | jax.config.parse_flags_with_absl() 45 | 46 | 47 | def main(unused_argv): 48 | config = configs.load_config(save_config=False) 49 | 50 | dataset = datasets.load_dataset('test', config.data_dir, config) 51 | 52 | key = random.PRNGKey(20200823) 53 | _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key) 54 | 55 | if config.rawnerf_mode: 56 | postprocess_fn = dataset.metadata['postprocess_fn'] 57 | else: 58 | postprocess_fn = lambda z: z 59 | 60 | if config.eval_raw_affine_cc: 61 | cc_fun = raw_utils.match_images_affine 62 | else: 63 | cc_fun = image.color_correct 64 | 65 | metric_harness = image.MetricHarnessLPIPS() 66 | 67 | last_step = 0 68 | dir_name = 'train_preds' if config.eval_train else 'test_preds' 69 | out_dir = path.join(config.checkpoint_dir, 70 | 'path_renders' if config.render_path else dir_name) 71 | path_fn = lambda x: path.join(out_dir, x) 72 | 73 | if not config.eval_only_once: 74 | summary_writer = tensorboard.SummaryWriter( 75 | path.join(config.checkpoint_dir, 'eval')) 76 | while True: 77 | state = checkpoints.restore_checkpoint(config.checkpoint_dir, state) 78 | step = int(state.step) 79 | if step <= last_step: 80 | print(f'Checkpoint step {step} <= last step {last_step}, sleeping.') 81 | time.sleep(10) 82 | continue 83 | print(f'Evaluating checkpoint at step {step}.') 84 | if config.eval_save_output and (not utils.isdir(out_dir)): 85 | utils.makedirs(out_dir) 86 | 87 | num_eval = min(dataset.size, config.eval_dataset_limit) 88 | key = random.PRNGKey(0 if config.deterministic_showcase else step) 89 | perm = random.permutation(key, num_eval) 90 | showcase_indices = np.sort(perm[:config.num_showcase_images]) 91 | 
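# Note: `showcase_indices` above selects `num_showcase_images` evaluation
# images to visualize in TensorBoard; with `deterministic_showcase` enabled,
# the same images are chosen at every evaluated checkpoint.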
92 | metrics = [] 93 | metrics_cc = [] 94 | showcases = [] 95 | render_times = [] 96 | for idx in range(dataset.size): 97 | eval_start_time = time.time() 98 | batch = next(dataset) 99 | if idx >= num_eval: 100 | print(f'Skipping image {idx+1}/{dataset.size}') 101 | continue 102 | print(f'Evaluating image {idx+1}/{dataset.size}') 103 | rays = batch.rays 104 | train_frac = state.step / config.max_steps 105 | rendering = models.render_image( 106 | functools.partial( 107 | render_eval_pfn, 108 | state.params, 109 | train_frac, 110 | ), 111 | rays, 112 | None, 113 | config, 114 | ) 115 | 116 | if jax.host_id() != 0: # Only record via host 0. 117 | continue 118 | 119 | render_times.append((time.time() - eval_start_time)) 120 | print(f'Rendered in {render_times[-1]:0.3f}s') 121 | 122 | # Cast to 64-bit to ensure high precision for color correction function. 123 | gt_rgb = np.array(batch.rgb, dtype=np.float64) 124 | rendering['rgb'] = np.array(rendering['rgb'], dtype=np.float64) 125 | 126 | cc_start_time = time.time() 127 | rendering['rgb_cc'] = cc_fun(rendering['rgb'], gt_rgb) 128 | # rendering['rgb_cc'] = rendering['rgb'] 129 | print(f'Color corrected in {(time.time() - cc_start_time):0.3f}s') 130 | 131 | if not config.eval_only_once and idx in showcase_indices: 132 | showcase_idx = idx if config.deterministic_showcase else len(showcases) 133 | showcases.append((showcase_idx, rendering, batch)) 134 | if not config.render_path: 135 | rgb = postprocess_fn(rendering['rgb']) 136 | rgb_cc = postprocess_fn(rendering['rgb_cc']) 137 | rgb_gt = postprocess_fn(gt_rgb) 138 | 139 | if config.eval_quantize_metrics: 140 | # Ensures that the images written to disk reproduce the metrics. 141 | rgb = np.round(rgb * 255) / 255 142 | rgb_cc = np.round(rgb_cc * 255) / 255 143 | 144 | if config.eval_crop_borders > 0: 145 | crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c] 146 | rgb = crop_fn(rgb) 147 | rgb_cc = crop_fn(rgb_cc) 148 | rgb_gt = crop_fn(rgb_gt) 149 | metric = metric_harness(rgb.astype(np.float32), rgb_gt.astype(np.float32)) 150 | metric_cc = metric_harness(rgb_cc.astype(np.float32), rgb_gt.astype(np.float32)) 151 | 152 | for m, v in metric.items(): 153 | print(f'{m:30s} = {v:.4f}') 154 | 155 | metrics.append(metric) 156 | metrics_cc.append(metric_cc) 157 | 158 | if config.eval_save_output and (config.eval_render_interval > 0): 159 | if (idx % config.eval_render_interval) == 0: 160 | utils.save_img_u8(postprocess_fn(rendering['rgb']), 161 | path_fn(f'color_{idx:03d}.png')) 162 | utils.save_img_u8(postprocess_fn(rendering['rgb_cc']), 163 | path_fn(f'color_{idx:03d}_cc.png')) 164 | utils.save_img_u8(rgb_gt, 165 | path_fn(f'gt_color_{idx:03d}.png')) 166 | utils.save_img_u8(postprocess_fn(rendering['rgb_cc']), 167 | path_fn(f'color_cc_{idx:03d}.png')) 168 | 169 | for key in ['distance_mean', 'distance_median']: 170 | if key in rendering: 171 | utils.save_img_f32(rendering[key], 172 | path_fn(f'{key}_{idx:03d}.tiff')) 173 | 174 | for key in ['normals']: 175 | if key in rendering: 176 | utils.save_img_u8(rendering[key] / 2. 
+ 0.5, 177 | path_fn(f'{key}_{idx:03d}.png')) 178 | 179 | vis_uncertainty = visualize_cmap( 180 | rendering['uncer'][...,0], 181 | rendering['acc'], 182 | cm.get_cmap('turbo'), 183 | lo=0.2, 184 | hi=2, 185 | ) 186 | utils.save_img_u8(postprocess_fn(vis_uncertainty), path_fn(f'uncer_{idx:03d}.png')) 187 | utils.save_img_f32(rendering['uncer'][...,0], path_fn(f'uncer_raw_{idx:03d}.tiff')) 188 | 189 | if (not config.eval_only_once) and (jax.host_id() == 0): 190 | summary_writer.scalar('eval_median_render_time', np.median(render_times), 191 | step) 192 | for name in metrics[0]: 193 | scores = [m[name] for m in metrics] 194 | summary_writer.scalar('eval_metrics/' + name, np.mean(scores), step) 195 | summary_writer.histogram('eval_metrics/' + 'perimage_' + name, scores, 196 | step) 197 | for name in metrics_cc[0]: 198 | scores = [m[name] for m in metrics_cc] 199 | summary_writer.scalar('eval_metrics_cc/' + name, np.mean(scores), step) 200 | summary_writer.histogram('eval_metrics_cc/' + 'perimage_' + name, 201 | scores, step) 202 | 203 | for i, r, b in showcases: 204 | if config.vis_decimate > 1: 205 | d = config.vis_decimate 206 | decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d] 207 | else: 208 | decimate_fn = lambda x: x 209 | r = jax.tree_util.tree_map(decimate_fn, r) 210 | b = jax.tree_util.tree_map(decimate_fn, b) 211 | visualizations = vis.visualize_suite(r, b.rays) 212 | for k, v in visualizations.items(): 213 | if k == 'color': 214 | v = postprocess_fn(v) 215 | summary_writer.image(f'output_{k}_{i}', v, step) 216 | if not config.render_path: 217 | target = postprocess_fn(b.rgb) 218 | summary_writer.image(f'true_color_{i}', target, step) 219 | pred = postprocess_fn(visualizations['color']) 220 | residual = np.clip(pred - target + 0.5, 0, 1) 221 | summary_writer.image(f'true_residual_{i}', residual, step) 222 | summary_writer.image(f'uncertainty_{i}', visualizations['uncertainty'], 223 | step) 224 | 225 | if (config.eval_save_output and (not config.render_path) and 226 | (jax.host_id() == 0)): 227 | with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f: 228 | f.write(' '.join([str(r) for r in render_times])) 229 | for name in metrics[0]: 230 | with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f: 231 | f.write(' '.join([str(m[name]) for m in metrics])) 232 | for name in metrics_cc[0]: 233 | with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f: 234 | f.write(' '.join([str(m[name]) for m in metrics_cc])) 235 | if config.eval_save_ray_data: 236 | for i, r, b in showcases: 237 | rays = {k: v for k, v in r.items() if 'ray_' in k} 238 | np.set_printoptions(threshold=sys.maxsize) 239 | with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f: 240 | f.write(repr(rays)) 241 | for name in metrics[0]: 242 | with utils.open_file(path_fn(f'metric_{name}_{step}_avg.txt'), 'w') as f: 243 | avg=np.mean(np.array([m[name] for m in metrics])) 244 | if 0 < avg < 1: 245 | f.write(f'{avg:.3f}'[1:]) 246 | else: 247 | f.write(f'{avg:.2f}') 248 | for name in metrics_cc[0]: 249 | with utils.open_file(path_fn(f'metric_cc_{name}_{step}_avg.txt'), 'w') as f: 250 | avg=np.mean(np.array([m[name] for m in metrics_cc])) 251 | if 0 < avg < 1: 252 | f.write(f'{avg:.3f}'[1:]) 253 | else: 254 | f.write(f'{avg:.2f}') 255 | 256 | 257 | 258 | # A hack that forces Jax to keep all TPUs alive until every TPU is finished. 
259 | x = jnp.ones([jax.local_device_count()]) 260 | x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)) 261 | print(x) 262 | 263 | if config.eval_only_once: 264 | break 265 | if config.early_exit_steps is not None: 266 | num_steps = config.early_exit_steps 267 | else: 268 | num_steps = config.max_steps 269 | if int(step) >= num_steps: 270 | break 271 | last_step = step 272 | 273 | 274 | if __name__ == '__main__': 275 | with gin.config_scope('eval'): 276 | app.run(main) 277 | -------------------------------------------------------------------------------- /internal/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions for handling configurations.""" 16 | 17 | import dataclasses 18 | from typing import Any, Callable, Optional, Tuple 19 | 20 | from absl import flags 21 | from flax.core import FrozenDict 22 | import gin 23 | from internal import utils 24 | import jax 25 | import jax.numpy as jnp 26 | 27 | gin.add_config_file_search_path('experimental/users/barron/mipnerf360/') 28 | 29 | configurables = { 30 | 'jnp': [jnp.reciprocal, jnp.log, jnp.log1p, jnp.exp, jnp.sqrt, jnp.square], 31 | 'jax.nn': [jax.nn.relu, jax.nn.softplus, jax.nn.silu], 32 | 'jax.nn.initializers.he_normal': [jax.nn.initializers.he_normal()], 33 | 'jax.nn.initializers.he_uniform': [jax.nn.initializers.he_uniform()], 34 | 'jax.nn.initializers.glorot_normal': [jax.nn.initializers.glorot_normal()], 35 | 'jax.nn.initializers.glorot_uniform': [ 36 | jax.nn.initializers.glorot_uniform() 37 | ], 38 | } 39 | 40 | for module, configurables in configurables.items(): 41 | for configurable in configurables: 42 | gin.config.external_configurable(configurable, module=module) 43 | 44 | 45 | @gin.configurable() 46 | @dataclasses.dataclass 47 | class Config: 48 | """Configuration flags for everything.""" 49 | dataset_loader: str = 'on-the-go' # The type of dataset loader to use. 50 | batching: str = 'all_images' # Batch composition, [single_image, all_images]. 51 | batch_size: int = 16384 # The number of rays/pixels in each batch. 52 | patch_size: int = 1 # Resolution of patches sampled for training batches. 53 | factor: int = 0 # The downsample factor of images, 0 for no downsampling. 54 | load_alphabetical: bool = True # Load images in COLMAP vs alphabetical 55 | # ordering (affects heldout test set). 56 | forward_facing: bool = False # Set to True for forward-facing captures. 57 | render_path: bool = False # If True, render a path. 58 | 59 | gc_every: int = 10000 # The number of steps between garbage collections. 60 | disable_multiscale_loss: bool = False # If True, disable multiscale loss. 61 | randomized: bool = True # Use randomized stratified sampling. 62 | near: float = 2. # Near plane distance. 63 | far: float = 6. # Far plane distance. 64 | checkpoint_dir: Optional[str] = None # Where to log checkpoints. 
65 | render_dir: Optional[str] = None # Output rendering directory. 66 | data_dir: Optional[str] = None # Input data directory. 67 | render_chunk_size: int = 16384 # Chunk size for whole-image renderings. 68 | num_showcase_images: int = 5 # The number of test-set images to showcase. 69 | deterministic_showcase: bool = True # If True, showcase the same images. 70 | vis_num_rays: int = 16 # The number of rays to visualize. 71 | # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage. 72 | vis_decimate: int = 0 73 | 74 | 75 | # Only used by train.py: 76 | max_steps: int = 250000 # The number of optimization steps. 77 | early_exit_steps: Optional[int] = None # Early stopping, for debugging. 78 | checkpoint_every: int = 25000 # The number of steps to save a checkpoint. 79 | print_every: int = 100 # The number of steps between reports to tensorboard. 80 | train_render_every: int = 5000 # Steps between test set renders when training 81 | cast_rays_in_train_step: bool = False # If True, compute rays in train step. 82 | data_loss_type: str = 'dino_ssim' # What kind of loss to use ('mse', 'charb', 'robustnerf', or 'dino_ssim'). 83 | charb_padding: float = 0.001 # The padding used for Charbonnier loss. 84 | data_loss_mult: float = 0.5 # Mult for the finest data term in the loss. 85 | data_coarse_loss_mult: float = 0. # Multiplier for the coarser data terms. 86 | interlevel_loss_mult: float = 1.0 # Mult. for the loss on the proposal MLP. 87 | orientation_loss_mult: float = 0.0 # Multiplier on the orientation loss. 88 | orientation_coarse_loss_mult: float = 0.0 # Coarser orientation loss weights. 89 | # RobustNerf loss hyperparameters 90 | robustnerf_inlier_quantile: float = 0.5 91 | enable_robustnerf_loss: bool = False 92 | robustnerf_inner_patch_size: int = 8 93 | robustnerf_smoothed_filter_size: int = 3 94 | robustnerf_smoothed_inlier_quantile: float = 0.5 95 | robustnerf_inner_patch_inlier_quantile: float = 0.5 96 | # What that loss is imposed on, options are 'normals' or 'normals_pred'. 97 | orientation_loss_target: str = 'normals_pred' 98 | predicted_normal_loss_mult: float = 0.0 # Mult. on the predicted normal loss. 99 | # Mult. on the coarser predicted normal loss. 100 | predicted_normal_coarse_loss_mult: float = 0.0 101 | weight_decay_mults: FrozenDict[str, Any] = FrozenDict({}) # Weight decays. 102 | # An example that regularizes the NeRF and the first layer of the prop MLP: 103 | # weight_decay_mults: FrozenDict[str, Any] = FrozenDict({ 104 | # 'NerfMLP_0': 0.00001, 105 | # 'PropMLP_0/Dense_0': 0.001, 106 | # 'UncerMLP_0': 0.00001 107 | # }) 108 | # weight_decay_mults: FrozenDict[str, Any] = FrozenDict({ 109 | # 'UncerMLP_0': 0.00001 110 | # }) 111 | # Any model parameter that isn't specified gets a mult of 0. See the 112 | # train_weight_l2_* parameters in TensorBoard to know what can be regularized. 113 | 114 | lr_init: float = 0.002 # The initial learning rate. 115 | lr_final: float = 0.00002 # The final learning rate. 116 | lr_delay_steps: int = 512 # The number of "warmup" learning steps. 117 | lr_delay_mult: float = 0.01 # How severe the "warmup" should be. 118 | adam_beta1: float = 0.9 # Adam's beta1 hyperparameter. 119 | adam_beta2: float = 0.999 # Adam's beta2 hyperparameter. 120 | adam_eps: float = 1e-6 # Adam's epsilon hyperparameter. 121 | grad_max_norm: float = 0.001 # Gradient clipping magnitude, disabled if == 0. 122 | grad_max_val: float = 0. # Gradient clipping value, disabled if == 0. 123 | distortion_loss_mult: float = 0.01 # Multiplier on the distortion loss.
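# For intuition, the four lr_* fields above drive the usual MultiNeRF-style
# "warmup plus log-linear decay" learning-rate schedule. A minimal sketch of
# that schedule (an assumption about train.py, which is not shown here, based
# on the MultiNeRF codebase this repo builds on):
import jax.numpy as jnp

def lr_schedule(step, lr_init=0.002, lr_final=0.00002,
                lr_delay_steps=512, lr_delay_mult=0.01, max_steps=250000):
  # Warmup: the multiplier rises from lr_delay_mult to 1 over lr_delay_steps.
  delay = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin(
      0.5 * jnp.pi * jnp.clip(step / lr_delay_steps, 0, 1))
  # Log-linearly interpolate from lr_init to lr_final over max_steps.
  t = jnp.clip(step / max_steps, 0, 1)
  return delay * jnp.exp((1 - t) * jnp.log(lr_init) + t * jnp.log(lr_final))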
124 | 125 | # Only used by eval.py: 126 | eval_only_once: bool = True # If True evaluate the model only once, otherwise loop. 127 | eval_save_output: bool = True # If True save predicted images to disk. 128 | eval_save_ray_data: bool = False # If True save individual ray traces. 129 | eval_render_interval: int = 1 # The interval between images saved to disk. 130 | eval_dataset_limit: int = jnp.iinfo(jnp.int32).max # Num test images to eval. 131 | eval_quantize_metrics: bool = True # If True, run metrics on 8-bit images. 132 | eval_crop_borders: int = 0 # Ignore c border pixels in eval (x[c:-c, c:-c]). 133 | 134 | # Only used by render.py 135 | is_render: bool = False # Whether we are rendering; works around a feature-assignment bug. 136 | render_video_fps: int = 60 # Framerate in frames-per-second. 137 | render_video_crf: int = 18 # Constant rate factor for ffmpeg video quality. 138 | render_path_frames: int = 120 # Number of frames in render path. 139 | z_variation: float = 0. # How much height variation in render path. 140 | z_phase: float = 0. # Phase offset for height variation in render path. 141 | render_dist_percentile: float = 0.5 # How much to trim from near/far planes. 142 | render_dist_curve_fn: Callable[..., Any] = jnp.log # How depth is curved. 143 | render_path_file: Optional[str] = None # Numpy render pose file to load. 144 | render_job_id: int = 0 # Render job id. 145 | render_num_jobs: int = 1 # Total number of render jobs. 146 | render_resolution: Optional[Tuple[int, int]] = None # Render resolution, as 147 | # (width, height). 148 | render_focal: Optional[float] = None # Render focal length. 149 | render_camtype: Optional[str] = None # 'perspective', 'fisheye', or 'pano'. 150 | render_spherical: bool = False # Render spherical 360 panoramas. 151 | render_save_async: bool = True # Save to CNS using a separate thread. 152 | 153 | render_spline_keyframes: Optional[str] = None # Text file containing names of 154 | # images to be used as spline 155 | # keyframes, OR directory 156 | # containing those images. 157 | render_spline_n_interp: int = 30 # Num. frames to interpolate per keyframe. 158 | render_spline_degree: int = 5 # Polynomial degree of B-spline interpolation. 159 | render_spline_smoothness: float = .03 # B-spline smoothing factor, 0 for 160 | # exact interpolation of keyframes. 161 | # Interpolate per-frame exposure value from spline keyframes. 162 | render_spline_interpolate_exposure: bool = False 163 | 164 | # Flags for raw datasets. 165 | rawnerf_mode: bool = False # Load raw images and train in raw color space. 166 | exposure_percentile: float = 97. # Image percentile to expose as white. 167 | num_border_pixels_to_mask: int = 0 # During training, discard N-pixel border 168 | # around each input image. 169 | apply_bayer_mask: bool = False # During training, apply Bayer mosaic mask. 170 | autoexpose_renders: bool = False # During rendering, autoexpose each image. 171 | # For raw test scenes, use affine raw-space color correction. 172 | eval_raw_affine_cc: bool = False 173 | 174 | # dino configs 175 | dino_var_mult: float = 0.1 # Multiplier for the variance of the DINO features, used as regularization. 176 | compute_feature_metrics: bool = True # If True, compute feature metrics. 177 | feat_rate: int = 4 # Feature sampling rate w.r.t. the original image size.
178 | dilate: int = 8 # The dilation rate for the patch size. 179 | eval_train: bool = True # If True, evaluate the train set instead of the test set (for debugging). 180 | train_clean: bool = False # If True, train on the clean split instead of the clutter split. 181 | reg_mult: float = 0.5 # Weight for the uncertainty regularization term. 182 | uncer_lr_rate: float = 1 # Learning-rate multiplier w.r.t. the NeRF learning rate. 183 | uncer_clip_min: float = 0.1 # Minimum value to clip the uncertainty. 184 | ssim_clip_max: float = 5 # Maximum value to clip the SSIM. 185 | ssim_mult: float = 0.5 # Multiplicative factor for the SSIM loss. 186 | H: int = 3024 # Height of the image. 187 | W: int = 4032 # Width of the image. 188 | ssim_anneal: float = 0.8 # Anneal rate for SSIM. 189 | stop_ssim_gradient: bool = True # Whether to stop gradients flowing from SSIM to the reconstruction. 190 | ssim_window_size: int = 5 # Window size of SSIM. 191 | mask_type: str = 'masks' # Which mask folder to use. 192 | feat_dim: int = 384 # Feature dimension, 384 for dino_s/14. 193 | feat_ds: int = 14 # Feature downsample rate for DINO, 14 for dino_s/14; combines with feat_rate. 194 | 195 | 196 | def define_common_flags(): 197 | # Define the flags used by both train.py and eval.py 198 | flags.DEFINE_string('mode', None, 'Required by GINXM, not used.') 199 | flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.') 200 | flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.') 201 | flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.') 202 | 203 | 204 | def load_config(save_config=True): 205 | """Load the config, and optionally checkpoint it.""" 206 | gin.parse_config_files_and_bindings( 207 | flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True) 208 | config = Config() 209 | if save_config and jax.host_id() == 0: 210 | utils.makedirs(config.checkpoint_dir) 211 | with utils.open_file(config.checkpoint_dir + '/config.gin', 'w') as f: 212 | f.write(gin.config_str()) 213 | return config 214 |
-------------------------------------------------------------------------------- /internal/coord.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tools for manipulating coordinate spaces and distances along rays.""" 15 | 16 | from internal import math 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | 21 | def contract(x): 22 | """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077).""" 23 | eps = jnp.finfo(jnp.float32).eps 24 | # Clamping to eps prevents non-finite gradients when x == 0. 25 | x_mag_sq = jnp.maximum(eps, jnp.sum(x**2, axis=-1, keepdims=True)) 26 | z = jnp.where(x_mag_sq <= 1, x, ((2 * jnp.sqrt(x_mag_sq) - 1) / x_mag_sq) * x) 27 | return z 28 | 29 | 30 | def inv_contract(z): 31 | """The inverse of contract().""" 32 | eps = jnp.finfo(jnp.float32).eps 33 | # Clamping to eps prevents non-finite gradients when z == 0.
34 | z_mag_sq = jnp.maximum(eps, jnp.sum(z**2, axis=-1, keepdims=True)) 35 | x = jnp.where(z_mag_sq <= 1, z, z / (2 * jnp.sqrt(z_mag_sq) - z_mag_sq)) 36 | return x 37 | 38 | 39 | def track_linearize(fn, mean, cov): 40 | """Apply function `fn` to a set of means and covariances, ala a Kalman filter. 41 | 42 | We can analytically transform a Gaussian parameterized by `mean` and `cov` 43 | with a function `fn` by linearizing `fn` around `mean`, and taking advantage 44 | of the fact that Covar[Ax + y] = A(Covar[x])A^T (see 45 | https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details). 46 | 47 | Args: 48 | fn: the function applied to the Gaussians parameterized by (mean, cov). 49 | mean: a tensor of means, where the last axis is the dimension. 50 | cov: a tensor of covariances, where the last two axes are the dimensions. 51 | 52 | Returns: 53 | fn_mean: the transformed means. 54 | fn_cov: the transformed covariances. 55 | """ 56 | if (len(mean.shape) + 1) != len(cov.shape): 57 | raise ValueError('cov must be non-diagonal') 58 | fn_mean, lin_fn = jax.linearize(fn, mean) 59 | fn_cov = jax.vmap(lin_fn, -1, -2)(jax.vmap(lin_fn, -1, -2)(cov)) 60 | return fn_mean, fn_cov 61 | 62 | 63 | def construct_ray_warps(fn, t_near, t_far): 64 | """Construct a bijection between metric distances and normalized distances. 65 | 66 | See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a 67 | detailed explanation. 68 | 69 | Args: 70 | fn: the function to ray distances. 71 | t_near: a tensor of near-plane distances. 72 | t_far: a tensor of far-plane distances. 73 | 74 | Returns: 75 | t_to_s: a function that maps distances to normalized distances in [0, 1]. 76 | s_to_t: the inverse of t_to_s. 77 | """ 78 | if fn is None: 79 | fn_fwd = lambda x: x 80 | fn_inv = lambda x: x 81 | elif fn == 'piecewise': 82 | # Piecewise spacing combining identity and 1/x functions to allow t_near=0. 83 | fn_fwd = lambda x: jnp.where(x < 1, .5 * x, 1 - .5 / x) 84 | fn_inv = lambda x: jnp.where(x < .5, 2 * x, .5 / (1 - x)) 85 | else: 86 | inv_mapping = { 87 | 'reciprocal': jnp.reciprocal, 88 | 'log': jnp.exp, 89 | 'exp': jnp.log, 90 | 'sqrt': jnp.square, 91 | 'square': jnp.sqrt 92 | } 93 | fn_fwd = fn 94 | fn_inv = inv_mapping[fn.__name__] 95 | 96 | s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)] 97 | t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near) 98 | s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near) 99 | return t_to_s, s_to_t 100 | 101 | 102 | def expected_sin(mean, var): 103 | """Compute the mean of sin(x), x ~ N(mean, var).""" 104 | return jnp.exp(-0.5 * var) * math.safe_sin(mean) # large var -> small value. 105 | 106 | 107 | def integrated_pos_enc(mean, var, min_deg, max_deg): 108 | """Encode `x` with sinusoids scaled by 2^[min_deg, max_deg). 109 | 110 | Args: 111 | mean: tensor, the mean coordinates to be encoded 112 | var: tensor, the variance of the coordinates to be encoded. 113 | min_deg: int, the min degree of the encoding. 114 | max_deg: int, the max degree of the encoding. 115 | 116 | Returns: 117 | encoded: jnp.ndarray, encoded variables. 
118 | """ 119 | scales = 2**jnp.arange(min_deg, max_deg) 120 | shape = mean.shape[:-1] + (-1,) 121 | scaled_mean = jnp.reshape(mean[..., None, :] * scales[:, None], shape) 122 | scaled_var = jnp.reshape(var[..., None, :] * scales[:, None]**2, shape) 123 | 124 | return expected_sin( 125 | jnp.concatenate([scaled_mean, scaled_mean + 0.5 * jnp.pi], axis=-1), 126 | jnp.concatenate([scaled_var] * 2, axis=-1)) 127 | 128 | 129 | def lift_and_diagonalize(mean, cov, basis): 130 | """Project `mean` and `cov` onto basis and diagonalize the projected cov.""" 131 | fn_mean = math.matmul(mean, basis) 132 | fn_cov_diag = jnp.sum(basis * math.matmul(cov, basis), axis=-2) 133 | return fn_mean, fn_cov_diag 134 | 135 | 136 | def pos_enc(x, min_deg, max_deg, append_identity=True): 137 | """The positional encoding used by the original NeRF paper.""" 138 | scales = 2**jnp.arange(min_deg, max_deg) 139 | shape = x.shape[:-1] + (-1,) 140 | scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), shape) 141 | # Note that we're not using safe_sin, unlike IPE. 142 | four_feat = jnp.sin( 143 | jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1)) 144 | if append_identity: 145 | return jnp.concatenate([x] + [four_feat], axis=-1) 146 | else: 147 | return four_feat 148 | -------------------------------------------------------------------------------- /internal/geopoly.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools for constructing geodesic polyhedron, which are used as a basis.""" 16 | 17 | import itertools 18 | import numpy as np 19 | 20 | 21 | def compute_sq_dist(mat0, mat1=None): 22 | """Compute the squared Euclidean distance between all pairs of columns.""" 23 | if mat1 is None: 24 | mat1 = mat0 25 | # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y. 26 | sq_norm0 = np.sum(mat0**2, 0) 27 | sq_norm1 = np.sum(mat1**2, 0) 28 | sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1 29 | sq_dist = np.maximum(0, sq_dist) # Negative values must be numerical errors. 30 | return sq_dist 31 | 32 | 33 | def compute_tesselation_weights(v): 34 | """Tesselate the vertices of a triangle by a factor of `v`.""" 35 | if v < 1: 36 | raise ValueError(f'v {v} must be >= 1') 37 | int_weights = [] 38 | for i in range(v + 1): 39 | for j in range(v + 1 - i): 40 | int_weights.append((i, j, v - (i + j))) 41 | int_weights = np.array(int_weights) 42 | weights = int_weights / v # Barycentric weights. 43 | return weights 44 | 45 | 46 | def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4): 47 | """Tesselate the vertices of a geodesic polyhedron. 48 | 49 | Args: 50 | base_verts: tensor of floats, the vertex coordinates of the geodesic. 51 | base_faces: tensor of ints, the indices of the vertices of base_verts that 52 | constitute eachface of the polyhedra. 53 | v: int, the factor of the tesselation (v==1 is a no-op). 
54 | eps: float, a small value used to determine if two vertices are the same. 55 | 56 | Returns: 57 | verts: a tensor of floats, the coordinates of the tesselated vertices. 58 | """ 59 | if not isinstance(v, int): 60 | raise ValueError(f'v {v} must be an integer') 61 | tri_weights = compute_tesselation_weights(v) 62 | 63 | verts = [] 64 | for base_face in base_faces: 65 | new_verts = np.matmul(tri_weights, base_verts[base_face, :]) 66 | new_verts /= np.sqrt(np.sum(new_verts**2, 1, keepdims=True)) 67 | verts.append(new_verts) 68 | verts = np.concatenate(verts, 0) 69 | 70 | sq_dist = compute_sq_dist(verts.T) 71 | assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist]) 72 | unique = np.unique(assignment) 73 | verts = verts[unique, :] 74 | 75 | return verts 76 | 77 | 78 | def generate_basis(base_shape, 79 | angular_tesselation, 80 | remove_symmetries=True, 81 | eps=1e-4): 82 | """Generates a 3D basis by tesselating a geometric polyhedron. 83 | 84 | Args: 85 | base_shape: string, the name of the starting polyhedron, must be either 86 | 'icosahedron' or 'octahedron'. 87 | angular_tesselation: int, the number of times to tesselate the polyhedron, 88 | must be >= 1 (a value of 1 is a no-op to the polyhedron). 89 | remove_symmetries: bool, if True then remove the symmetric basis columns, 90 | which is usually a good idea because otherwise projections onto the basis 91 | will have redundant negative copies of each other. 92 | eps: float, a small number used to determine symmetries. 93 | 94 | Returns: 95 | basis: a matrix with shape [3, n]. 96 | """ 97 | if base_shape == 'icosahedron': 98 | a = (np.sqrt(5) + 1) / 2 99 | verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1), 100 | (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0), 101 | (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2) 102 | faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1), 103 | (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3), 104 | (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6), 105 | (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5), 106 | (7, 2, 11)]) 107 | verts = tesselate_geodesic(verts, faces, angular_tesselation) 108 | elif base_shape == 'octahedron': 109 | verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), 110 | (1, 0, 0)]) 111 | corners = np.array(list(itertools.product([-1, 1], repeat=3))) 112 | pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2) 113 | faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1) 114 | verts = tesselate_geodesic(verts, faces, angular_tesselation) 115 | else: 116 | raise ValueError(f'base_shape {base_shape} not supported') 117 | 118 | if remove_symmetries: 119 | # Remove elements of `verts` that are reflections of each other. 120 | match = compute_sq_dist(verts.T, -verts.T) < eps 121 | verts = verts[np.any(np.triu(match), 1), :] 122 | 123 | basis = verts[:, ::-1] 124 | return basis 125 |
-------------------------------------------------------------------------------- /internal/image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Functions for processing images."""
16 | 
17 | import types
18 | from typing import Optional, Union
19 | 
20 | import dm_pix
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy as np
24 | import lpips_jax
25 | 
26 | _Array = Union[np.ndarray, jnp.ndarray]
27 | 
28 | 
29 | def mse_to_psnr(mse):
30 |   """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
31 |   return -10. / jnp.log(10.) * jnp.log(mse)
32 | 
33 | 
34 | def psnr_to_mse(psnr):
35 |   """Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
36 |   return jnp.exp(-0.1 * jnp.log(10.) * psnr)
37 | 
38 | 
39 | def ssim_to_dssim(ssim):
40 |   """Compute DSSIM given an SSIM."""
41 |   return (1 - ssim) / 2
42 | 
43 | 
44 | def dssim_to_ssim(dssim):
45 |   """Compute SSIM given a DSSIM."""
46 |   return 1 - 2 * dssim
47 | 
48 | 
49 | def linear_to_srgb(linear: _Array,
50 |                    eps: Optional[float] = None,
51 |                    xnp: types.ModuleType = jnp) -> _Array:
52 |   """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
53 |   if eps is None:
54 |     eps = xnp.finfo(xnp.float32).eps
55 |   srgb0 = 323 / 25 * linear
56 |   srgb1 = (211 * xnp.maximum(eps, linear)**(5 / 12) - 11) / 200
57 |   return xnp.where(linear <= 0.0031308, srgb0, srgb1)
58 | 
59 | 
60 | def srgb_to_linear(srgb: _Array,
61 |                    eps: Optional[float] = None,
62 |                    xnp: types.ModuleType = jnp) -> _Array:
63 |   """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
64 |   if eps is None:
65 |     eps = xnp.finfo(xnp.float32).eps
66 |   linear0 = 25 / 323 * srgb
67 |   linear1 = xnp.maximum(eps, ((200 * srgb + 11) / (211)))**(12 / 5)
68 |   return xnp.where(srgb <= 0.04045, linear0, linear1)
69 | 
70 | 
71 | def downsample(img, factor):
72 |   """Area downsample img (factor must evenly divide img height and width)."""
73 |   sh = img.shape
74 |   if not (sh[0] % factor == 0 and sh[1] % factor == 0):
75 |     raise ValueError(f'Downsampling factor {factor} does not '
76 |                      f'evenly divide image shape {sh[:2]}')
77 |   img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:])
78 |   img = img.mean((1, 3))
79 |   return img
80 | 
81 | 
82 | def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
83 |   """Warp `img` to match the colors in `ref`."""
84 |   if img.shape[-1] != ref.shape[-1]:
85 |     raise ValueError(
86 |         f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
87 |     )
88 |   num_channels = img.shape[-1]
89 |   img_mat = img.reshape([-1, num_channels])
90 |   ref_mat = ref.reshape([-1, num_channels])
91 |   is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps))  # z \in [eps, 1-eps].
92 |   mask0 = is_unclipped(img_mat)
93 |   # Because the set of saturated pixels may change after solving for a
94 |   # transformation, we repeatedly solve a system `num_iters` times and update
95 |   # our estimate of which pixels are saturated.
96 |   for _ in range(num_iters):
97 |     # Construct the left hand side of a linear system that contains a quadratic
98 |     # expansion of each pixel of `img`.
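    # For RGB inputs this builds 3+2+1 = 6 quadratic features (rr, rg, rb, gg,
    # gb, bb), plus 3 linear terms and 1 bias, i.e. a 10-column design matrix.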
99 |     a_mat = []
100 |     for c in range(num_channels):
101 |       a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:])  # Quadratic term.
102 |     a_mat.append(img_mat)  # Linear term.
103 |     a_mat.append(jnp.ones_like(img_mat[:, :1]))  # Bias term.
104 |     a_mat = jnp.concatenate(a_mat, axis=-1)
105 |     warp = []
106 |     for c in range(num_channels):
107 |       # Construct the right hand side of a linear system containing each color
108 |       # of `ref`.
109 |       b = ref_mat[:, c]
110 |       # Ignore rows of the linear system that were saturated in the input or are
111 |       # saturated in the current corrected color estimate.
112 |       mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
113 |       ma_mat = jnp.where(mask[:, None], a_mat, 0)
114 |       mb = jnp.where(mask, b, 0)
115 |       # Solve the linear system. We're using np.linalg.lstsq instead of jnp's
116 |       # because it's significantly more stable in this case, for some reason.
117 |       w = np.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
118 |       assert jnp.all(jnp.isfinite(w))
119 |       warp.append(w)
120 |     warp = jnp.stack(warp, axis=-1)
121 |     # Apply the warp to update img_mat.
122 |     img_mat = jnp.clip(
123 |         jnp.matmul(a_mat, warp, precision=jax.lax.Precision.HIGHEST), 0, 1)
124 |   corrected_img = jnp.reshape(img_mat, img.shape)
125 |   return corrected_img
126 | 
127 | 
128 | class MetricHarness:
129 |   """A helper class for evaluating several error metrics."""
130 | 
131 |   def __init__(self):
132 |     self.ssim_fn = jax.jit(dm_pix.ssim)
133 | 
134 |   def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
135 |     """Evaluate the error between a predicted rgb image and the true image."""
136 |     psnr = float(mse_to_psnr(((rgb_pred - rgb_gt)**2).mean()))
137 |     ssim = float(self.ssim_fn(rgb_pred, rgb_gt))
138 | 
139 |     return {
140 |         name_fn('psnr'): psnr,
141 |         name_fn('ssim'): ssim,
142 |     }
143 | 
144 | 
145 | class MetricHarnessLPIPS:
146 |   """A helper class for evaluating several error metrics with vgg16 lpips."""
147 | 
148 |   def __init__(self):
149 |     self.ssim_fn = jax.jit(dm_pix.ssim)
150 |     self.lpips_fn = lpips_jax.LPIPSEvaluator(replicate=False, net='vgg16')  # ['alexnet', 'vgg16']
151 | 
152 |   def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
153 |     """Evaluate the error between a predicted rgb image and the true image."""
154 |     psnr = float(mse_to_psnr(((rgb_pred - rgb_gt)**2).mean()))
155 |     ssim = float(self.ssim_fn(rgb_pred, rgb_gt))
156 |     # TODO: Fix the LPIPS calculation; it is currently broken.
157 |     # lpips = float(self.lpips_fn(rgb_pred[None,...]*2-1, rgb_gt[None,...]*2-1))
158 | 
159 |     return {
160 |         name_fn('psnr'): psnr,
161 |         name_fn('ssim'): ssim,
162 |         # name_fn('lpips'): lpips,
163 |     }
164 | 
--------------------------------------------------------------------------------
/internal/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Mathy utility functions."""
16 | 
17 | import jax
18 | import jax.numpy as jnp
19 | 
20 | 
21 | def matmul(a, b):
22 |   """jnp.matmul defaults to bfloat16, but this helper function doesn't."""
23 |   return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
24 | 
25 | 
26 | def safe_trig_helper(x, fn, t=100 * jnp.pi):
27 |   """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
28 |   return fn(jnp.where(jnp.abs(x) < t, x, x % t))
29 | 
30 | 
31 | def safe_cos(x):
32 |   """jnp.cos() on a TPU may NaN out for large values."""
33 |   return safe_trig_helper(x, jnp.cos)
34 | 
35 | 
36 | def safe_sin(x):
37 |   """jnp.sin() on a TPU may NaN out for large values."""
38 |   return safe_trig_helper(x, jnp.sin)
39 | 
40 | 
41 | @jax.custom_jvp
42 | def safe_exp(x):
43 |   """jnp.exp() but with finite output and gradients for large inputs."""
44 |   return jnp.exp(jnp.minimum(x, 88.))  # jnp.exp(89) is infinity.
45 | 
46 | 
47 | @safe_exp.defjvp
48 | def safe_exp_jvp(primals, tangents):
49 |   """Override safe_exp()'s gradient so that it's large when inputs are large."""
50 |   x, = primals
51 |   x_dot, = tangents
52 |   exp_x = safe_exp(x)
53 |   exp_x_dot = exp_x * x_dot
54 |   return exp_x, exp_x_dot
55 | 
56 | 
57 | def log_lerp(t, v0, v1):
58 |   """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
59 |   if v0 <= 0 or v1 <= 0:
60 |     raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
61 |   lv0 = jnp.log(v0)
62 |   lv1 = jnp.log(v1)
63 |   return jnp.exp(jnp.clip(t, 0, 1) * (lv1 - lv0) + lv0)
64 | 
65 | 
66 | def learning_rate_decay(step,
67 |                         lr_init,
68 |                         lr_final,
69 |                         max_steps,
70 |                         lr_delay_steps=0,
71 |                         lr_delay_mult=1):
72 |   """Continuous learning rate decay function.
73 | 
74 |   The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
75 |   is log-linearly interpolated elsewhere (equivalent to exponential decay).
76 |   If lr_delay_steps>0 then the learning rate will be scaled by some smooth
77 |   function of lr_delay_mult, such that the initial learning rate is
78 |   lr_init*lr_delay_mult at the beginning of optimization but will be eased back
79 |   to the normal learning rate when steps>lr_delay_steps.
80 | 
81 |   Args:
82 |     step: int, the current optimization step.
83 |     lr_init: float, the initial learning rate.
84 |     lr_final: float, the final learning rate.
85 |     max_steps: int, the number of steps during optimization.
86 |     lr_delay_steps: int, the number of steps to delay the full learning rate.
87 |     lr_delay_mult: float, the multiplier on the rate when delaying it.
88 | 
89 |   Returns:
90 |     lr: the learning rate for the current step 'step'.
91 |   """
92 |   if lr_delay_steps > 0:
93 |     # A kind of reverse cosine decay.
94 |     delay_rate = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin(
95 |         0.5 * jnp.pi * jnp.clip(step / lr_delay_steps, 0, 1))
96 |   else:
97 |     delay_rate = 1.
98 |   return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
99 | 
100 | 
101 | def interp(*args):
102 |   """A gather-based (GPU-friendly) vectorized replacement for jnp.interp()."""
103 |   args_flat = [x.reshape([-1, x.shape[-1]]) for x in args]
104 |   ret = jax.vmap(jnp.interp)(*args_flat).reshape(args[0].shape)
105 |   return ret
106 | 
107 | 
108 | def sorted_interp(x, xp, fp):
109 |   """A TPU-friendly version of interp(), where xp and fp must be sorted."""
110 | 
111 |   # Identify the location in `xp` that corresponds to each `x`.
112 |   # The final `True` index in `mask` is the start of the matching interval.
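  # `mask[..., i, j]` is True iff x[..., j] >= xp[..., i], so the reductions
  # below recover each query's bracketing interval without a gather.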
113 | mask = x[..., None, :] >= xp[..., :, None] 114 | 115 | def find_interval(x): 116 | # Grab the value where `mask` switches from True to False, and vice versa. 117 | # This approach takes advantage of the fact that `x` is sorted. 118 | x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2) 119 | x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2) 120 | return x0, x1 121 | 122 | fp0, fp1 = find_interval(fp) 123 | xp0, xp1 = find_interval(xp) 124 | 125 | offset = jnp.clip(jnp.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1) 126 | ret = fp0 + offset * (fp1 - fp0) 127 | return ret 128 | -------------------------------------------------------------------------------- /internal/raw_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for processing and loading raw image data.""" 16 | 17 | import glob 18 | import json 19 | import os 20 | import types 21 | from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple, Union 22 | 23 | from internal import image as lib_image 24 | from internal import math 25 | from internal import utils 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | import rawpy 30 | 31 | _Array = Union[np.ndarray, jnp.ndarray] 32 | _Axis = Optional[Union[int, Tuple[int, ...]]] 33 | 34 | 35 | def postprocess_raw(raw: _Array, 36 | camtorgb: _Array, 37 | exposure: Optional[float] = None, 38 | xnp: types.ModuleType = np) -> _Array: 39 | """Converts demosaicked raw to sRGB with a minimal postprocessing pipeline. 40 | 41 | Numpy array inputs will be automatically converted to Jax arrays. 42 | 43 | Args: 44 | raw: [H, W, 3], demosaicked raw camera image. 45 | camtorgb: [3, 3], color correction transformation to apply to raw image. 46 | exposure: color value to be scaled to pure white after color correction. 47 | If None, "autoexposes" at the 97th percentile. 48 | xnp: either numpy or jax.numpy. 49 | 50 | Returns: 51 | srgb: [H, W, 3], color corrected + exposed + gamma mapped image. 52 | """ 53 | if raw.shape[-1] != 3: 54 | raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3') 55 | if camtorgb.shape != (3, 3): 56 | raise ValueError(f'camtorgb.shape is {camtorgb.shape}, expected (3, 3)') 57 | # Convert from camera color space to standard linear RGB color space. 58 | matmul = math.matmul if xnp == jnp else np.matmul 59 | rgb_linear = matmul(raw, camtorgb.T) 60 | if exposure is None: 61 | exposure = xnp.percentile(rgb_linear, 97) 62 | # "Expose" image by mapping the input exposure level to white and clipping. 63 | rgb_linear_scaled = xnp.clip(rgb_linear / exposure, 0, 1) 64 | # Apply sRGB gamma curve to serve as a simple tonemap. 
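  # linear_to_srgb applies the piecewise sRGB transfer function: a linear
  # segment below 0.0031308 and a 1/2.4-power segment above it.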
65 | srgb = lib_image.linear_to_srgb(rgb_linear_scaled, xnp=xnp) 66 | return srgb 67 | 68 | 69 | def pixels_to_bayer_mask(pix_x: np.ndarray, pix_y: np.ndarray) -> np.ndarray: 70 | """Computes binary RGB Bayer mask values from integer pixel coordinates.""" 71 | # Red is top left (0, 0). 72 | r = (pix_x % 2 == 0) * (pix_y % 2 == 0) 73 | # Green is top right (0, 1) and bottom left (1, 0). 74 | g = (pix_x % 2 == 1) * (pix_y % 2 == 0) + (pix_x % 2 == 0) * (pix_y % 2 == 1) 75 | # Blue is bottom right (1, 1). 76 | b = (pix_x % 2 == 1) * (pix_y % 2 == 1) 77 | return np.stack([r, g, b], -1).astype(np.float32) 78 | 79 | 80 | def bilinear_demosaic(bayer: _Array, 81 | xnp: types.ModuleType) -> _Array: 82 | """Converts Bayer data into a full RGB image using bilinear demosaicking. 83 | 84 | Input data should be ndarray of shape [height, width] with 2x2 mosaic pattern: 85 | ------------- 86 | |red |green| 87 | ------------- 88 | |green|blue | 89 | ------------- 90 | Red and blue channels are bilinearly upsampled 2x, missing green channel 91 | elements are the average of the neighboring 4 values in a cross pattern. 92 | 93 | Args: 94 | bayer: [H, W] array, Bayer mosaic pattern input image. 95 | xnp: either numpy or jax.numpy. 96 | 97 | Returns: 98 | rgb: [H, W, 3] array, full RGB image. 99 | """ 100 | def reshape_quads(*planes): 101 | """Reshape pixels from four input images to make tiled 2x2 quads.""" 102 | planes = xnp.stack(planes, -1) 103 | shape = planes.shape[:-1] 104 | # Create [2, 2] arrays out of 4 channels. 105 | zup = planes.reshape(shape + (2, 2,)) 106 | # Transpose so that x-axis dimensions come before y-axis dimensions. 107 | zup = xnp.transpose(zup, (0, 2, 1, 3)) 108 | # Reshape to 2D. 109 | zup = zup.reshape((shape[0] * 2, shape[1] * 2)) 110 | return zup 111 | 112 | def bilinear_upsample(z): 113 | """2x bilinear image upsample.""" 114 | # Using np.roll makes the right and bottom edges wrap around. The raw image 115 | # data has a few garbage columns/rows at the edges that must be discarded 116 | # anyway, so this does not matter in practice. 117 | # Horizontally interpolated values. 118 | zx = .5 * (z + xnp.roll(z, -1, axis=-1)) 119 | # Vertically interpolated values. 120 | zy = .5 * (z + xnp.roll(z, -1, axis=-2)) 121 | # Diagonally interpolated values. 122 | zxy = .5 * (zx + xnp.roll(zx, -1, axis=-2)) 123 | return reshape_quads(z, zx, zy, zxy) 124 | 125 | def upsample_green(g1, g2): 126 | """Special 2x upsample from the two green channels.""" 127 | z = xnp.zeros_like(g1) 128 | z = reshape_quads(z, g1, g2, z) 129 | alt = 0 130 | # Grab the 4 directly adjacent neighbors in a "cross" pattern. 131 | for i in range(4): 132 | axis = -1 - (i // 2) 133 | roll = -1 + 2 * (i % 2) 134 | alt = alt + .25 * xnp.roll(z, roll, axis=axis) 135 | # For observed pixels, alt = 0, and for unobserved pixels, alt = avg(cross), 136 | # so alt + z will have every pixel filled in. 137 | return alt + z 138 | 139 | r, g1, g2, b = [bayer[(i//2)::2, (i%2)::2] for i in range(4)] 140 | r = bilinear_upsample(r) 141 | # Flip in x and y before and after calling upsample, as bilinear_upsample 142 | # assumes that the samples are at the top-left corner of the 2x2 sample. 
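  # (Red already sits at the top-left of each 2x2 quad; blue sits at the
  # bottom-right, so it is flipped to the top-left, upsampled, then unflipped.)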
143 |   b = bilinear_upsample(b[::-1, ::-1])[::-1, ::-1]
144 |   g = upsample_green(g1, g2)
145 |   rgb = xnp.stack([r, g, b], -1)
146 |   return rgb
147 | 
148 | 
149 | bilinear_demosaic_jax = jax.jit(lambda bayer: bilinear_demosaic(bayer, xnp=jnp))
150 | 
151 | 
152 | def load_raw_images(image_dir: str,
153 |                     image_names: Optional[Sequence[str]] = None
154 |                    ) -> Tuple[np.ndarray, Sequence[Mapping[str, Any]]]:
155 |   """Loads raw images and their metadata from disk.
156 | 
157 |   Args:
158 |     image_dir: directory containing raw image and EXIF data.
159 |     image_names: files to load (ignores file extension), loads all DNGs if None.
160 | 
161 |   Returns:
162 |     A tuple (images, exifs).
163 |     images: [N, height, width, 3] array of raw sensor data.
164 |     exifs: [N] list of dicts, one per image, containing the EXIF data.
165 |   Raises:
166 |     ValueError: The requested `image_dir` does not exist on disk.
167 |   """
168 | 
169 |   if not utils.file_exists(image_dir):
170 |     raise ValueError(f'Raw image folder {image_dir} does not exist.')
171 | 
172 |   # Load raw images (dng files) and exif metadata (json files).
173 |   def load_raw_exif(image_name):
174 |     base = os.path.join(image_dir, os.path.splitext(image_name)[0])
175 |     with utils.open_file(base + '.dng', 'rb') as f:
176 |       raw = rawpy.imread(f).raw_image
177 |     with utils.open_file(base + '.json', 'rb') as f:
178 |       exif = json.load(f)[0]
179 |     return raw, exif
180 | 
181 |   if image_names is None:
182 |     image_names = [
183 |         os.path.basename(f)
184 |         for f in sorted(glob.glob(os.path.join(image_dir, '*.dng')))
185 |     ]
186 | 
187 |   data = [load_raw_exif(x) for x in image_names]
188 |   raws, exifs = zip(*data)
189 |   raws = np.stack(raws, axis=0).astype(np.float32)
190 | 
191 |   return raws, exifs
192 | 
193 | 
194 | # Brightness percentiles to use for re-exposing and tonemapping raw images.
195 | _PERCENTILE_LIST = (80, 90, 97, 99, 100)
196 | 
197 | # Relevant fields to extract from raw image EXIF metadata.
198 | # For details regarding EXIF parameters, see:
199 | # https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf.
200 | _EXIF_KEYS = (
201 |     'BlackLevel',  # Black level offset added to sensor measurements.
202 |     'WhiteLevel',  # Maximum possible sensor measurement.
203 |     'AsShotNeutral',  # RGB white balance coefficients.
204 |     'ColorMatrix2',  # XYZ to camera color space conversion matrix.
205 |     'NoiseProfile',  # Shot and read noise levels.
206 | )
207 | 
208 | # Color conversion from linear RGB color space to reference illuminant XYZ.
209 | # See http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html.
210 | _RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],
211 |                      [0.2126729, 0.7151522, 0.0721750],
212 |                      [0.0193339, 0.1191920, 0.9503041]])
213 | 
214 | 
215 | def process_exif(
216 |     exifs: Sequence[Mapping[str, Any]]) -> MutableMapping[str, Any]:
217 |   """Processes list of raw image EXIF data into useful metadata dict.
218 | 
219 |   Input should be a list of dictionaries loaded from JSON files.
220 |   These JSON files are produced by running
221 |     $ exiftool -json IMAGE.dng > IMAGE.json
222 |   for each input raw file.
223 | 
224 |   We extract only the parameters relevant to
225 |     1. Rescaling the raw data to [0, 1],
226 |     2. White balance and color correction, and
227 |     3. Noise level estimation.
228 | 
229 |   Args:
230 |     exifs: a list of dicts containing EXIF data as loaded from JSON files.
231 | 
232 |   Returns:
233 |     meta: a dict of the relevant metadata for running RawNeRF.
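      This includes the fields listed in _EXIF_KEYS (when present) plus the
      derived 'ShutterSpeed' values and 'cam2rgb' color matrices.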
234 | """ 235 | meta = {} 236 | exif = exifs[0] 237 | # Convert from array of dicts (exifs) to dict of arrays (meta). 238 | for key in _EXIF_KEYS: 239 | exif_value = exif.get(key) 240 | if exif_value is None: 241 | continue 242 | # Values can be a single int or float... 243 | if isinstance(exif_value, int) or isinstance(exif_value, float): 244 | vals = [x[key] for x in exifs] 245 | # Or a string of numbers with ' ' between. 246 | elif isinstance(exif_value, str): 247 | vals = [[float(z) for z in x[key].split(' ')] for x in exifs] 248 | meta[key] = np.squeeze(np.array(vals)) 249 | # Shutter speed is a special case, a string written like 1/N. 250 | meta['ShutterSpeed'] = np.fromiter( 251 | (1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float) 252 | 253 | # Create raw-to-sRGB color transform matrices. Pipeline is: 254 | # cam space -> white balanced cam space ("camwb") -> XYZ space -> RGB space. 255 | # 'AsShotNeutral' is an RGB triplet representing how pure white would measure 256 | # on the sensor, so dividing by these numbers corrects the white balance. 257 | whitebalance = meta['AsShotNeutral'].reshape(-1, 3) 258 | cam2camwb = np.array([np.diag(1. / x) for x in whitebalance]) 259 | # ColorMatrix2 converts from XYZ color space to "reference illuminant" (white 260 | # balanced) camera space. 261 | xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3) 262 | rgb2camwb = xyz2camwb @ _RGB2XYZ 263 | # We normalize the rows of the full color correction matrix, as is done in 264 | # https://github.com/AbdoKamel/simple-camera-pipeline. 265 | rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True) 266 | # Combining color correction with white balance gives the entire transform. 267 | cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb 268 | meta['cam2rgb'] = cam2rgb 269 | 270 | return meta 271 | 272 | 273 | def load_raw_dataset(split: utils.DataSplit, 274 | data_dir: str, 275 | image_names: Sequence[str], 276 | exposure_percentile: float, 277 | n_downsample: int, 278 | ) -> Tuple[np.ndarray, MutableMapping[str, Any], bool]: 279 | """Loads and processes a set of RawNeRF input images. 280 | 281 | Includes logic necessary for special "test" scenes that include a noiseless 282 | ground truth frame, produced by HDR+ merge. 283 | 284 | Args: 285 | split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic. 286 | data_dir: base directory for scene data. 287 | image_names: which images were successfully posed by COLMAP. 288 | exposure_percentile: what brightness percentile to expose to white. 289 | n_downsample: returned images are downsampled by a factor of n_downsample. 290 | 291 | Returns: 292 | A tuple (images, meta, testscene). 293 | images: [N, height // n_downsample, width // n_downsample, 3] array of 294 | demosaicked raw image data. 295 | meta: EXIF metadata and other useful processing parameters. Includes per 296 | image exposure information that can be passed into the NeRF model with 297 | each ray: the set of unique exposure times is determined and each image 298 | assigned a corresponding exposure index (mapping to an exposure value). 299 | These are keys 'unique_shutters', 'exposure_idx', and 'exposure_value' in 300 | the `meta` dictionary. 301 | We rescale so the maximum `exposure_value` is 1 for convenience. 302 | testscene: True when dataset includes ground truth test image, else False. 
303 | """ 304 | 305 | image_dir = os.path.join(data_dir, 'raw') 306 | 307 | testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng') 308 | testscene = utils.file_exists(testimg_file) 309 | if testscene: 310 | # Test scenes have train/ and test/ split subdirectories inside raw/. 311 | image_dir = os.path.join(image_dir, split.value) 312 | if split == utils.DataSplit.TEST: 313 | # COLMAP image names not valid for test split of test scene. 314 | image_names = None 315 | else: 316 | # Discard the first COLMAP image name as it is a copy of the test image. 317 | image_names = image_names[1:] 318 | 319 | raws, exifs = load_raw_images(image_dir, image_names) 320 | meta = process_exif(exifs) 321 | 322 | if testscene and split == utils.DataSplit.TEST: 323 | # Test split for test scene must load the "ground truth" HDR+ merged image. 324 | with utils.open_file(testimg_file, 'rb') as imgin: 325 | testraw = rawpy.imread(imgin).raw_image 326 | # HDR+ output has 2 extra bits of fixed precision, need to divide by 4. 327 | testraw = testraw.astype(np.float32) / 4. 328 | # Need to rescale long exposure test image by fast:slow shutter speed ratio. 329 | fast_shutter = meta['ShutterSpeed'][0] 330 | slow_shutter = meta['ShutterSpeed'][-1] 331 | shutter_ratio = fast_shutter / slow_shutter 332 | # Replace loaded raws with the "ground truth" test image. 333 | raws = testraw[None] 334 | # Test image shares metadata with the first loaded image (fast exposure). 335 | meta = {k: meta[k][:1] for k in meta} 336 | else: 337 | shutter_ratio = 1. 338 | 339 | # Next we determine an index for each unique shutter speed in the data. 340 | shutter_speeds = meta['ShutterSpeed'] 341 | # Sort the shutter speeds from slowest (largest) to fastest (smallest). 342 | # This way index 0 will always correspond to the brightest image. 343 | unique_shutters = np.sort(np.unique(shutter_speeds))[::-1] 344 | exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32) 345 | for i, shutter in enumerate(unique_shutters): 346 | # Assign index `i` to all images with shutter speed `shutter`. 347 | exposure_idx[shutter_speeds == shutter] = i 348 | meta['exposure_idx'] = exposure_idx 349 | meta['unique_shutters'] = unique_shutters 350 | # Rescale to use relative shutter speeds, where 1. is the brightest. 351 | # This way the NeRF output with exposure=1 will always be reasonable. 352 | meta['exposure_values'] = shutter_speeds / unique_shutters[0] 353 | 354 | # Rescale raw sensor measurements to [0, 1] (plus noise). 355 | blacklevel = meta['BlackLevel'].reshape(-1, 1, 1) 356 | whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1) 357 | images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio 358 | 359 | # Calculate value for exposure level when gamma mapping, defaults to 97%. 360 | # Always based on full resolution image 0 (for consistency). 361 | image0_raw_demosaic = np.array(bilinear_demosaic_jax(images[0])) 362 | image0_rgb = image0_raw_demosaic @ meta['cam2rgb'][0].T 363 | exposure = np.percentile(image0_rgb, exposure_percentile) 364 | meta['exposure'] = exposure 365 | # Sweep over various exposure percentiles to visualize in training logs. 366 | exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST} 367 | meta['exposure_levels'] = exposure_levels 368 | 369 | # Create postprocessing function mapping raw images to tonemapped sRGB space. 
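  # `exposure` is bound as a default argument below so that its current value
  # is baked into `postprocess_fn` at definition time.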
370 |   cam2rgb0 = meta['cam2rgb'][0]
371 |   meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x)
372 | 
373 |   # Demosaic Bayer images (preserves the measured RGGB values) and downsample
374 |   # if needed. Moving array to device + running processing function in Jax +
375 |   # copying back to CPU is faster than running directly on CPU.
376 |   def processing_fn(x):
377 |     x_jax = jnp.array(x)
378 |     x_demosaic_jax = bilinear_demosaic_jax(x_jax)
379 |     if n_downsample > 1:
380 |       x_demosaic_jax = lib_image.downsample(x_demosaic_jax, n_downsample)
381 |     return np.array(x_demosaic_jax)
382 |   images = np.stack([processing_fn(im) for im in images], axis=0)
383 | 
384 |   return images, meta, testscene
385 | 
386 | 
387 | def best_fit_affine(x: _Array, y: _Array, axis: _Axis) -> Tuple[_Array, _Array]:
388 |   """Computes best fit a, b such that a * x + b = y, in a least-squares sense."""
389 |   x_m = x.mean(axis=axis)
390 |   y_m = y.mean(axis=axis)
391 |   xy_m = (x * y).mean(axis=axis)
392 |   xx_m = (x * x).mean(axis=axis)
393 |   # slope a = Cov(x, y) / Cov(x, x).
394 |   a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m)
395 |   b = y_m - a * x_m
396 |   return a, b
397 | 
398 | 
399 | def match_images_affine(est: _Array, gt: _Array,
400 |                         axis: _Axis = (0, 1)) -> _Array:
401 |   """Computes affine best fit of gt->est, then maps est back to match gt."""
402 |   # Mapping is computed gt->est to be robust since `est` may be very noisy.
403 |   a, b = best_fit_affine(gt, est, axis=axis)
404 |   # Inverse mapping back to gt ensures we use a consistent space for metrics.
405 |   est_matched = (est - b) / a
406 |   return est_matched
407 | 
--------------------------------------------------------------------------------
/internal/ref_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Functions for reflection directions and directional encodings."""
16 | 
17 | from internal import math
18 | import jax.numpy as jnp
19 | import numpy as np
20 | 
21 | 
22 | def reflect(viewdirs, normals):
23 |   """Reflect view directions about normals.
24 | 
25 |   The reflection of a vector v about a unit vector n is a vector u such that
26 |   dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two
27 |   equations is u = 2 dot(n, v) n - v.
28 | 
29 |   Args:
30 |     viewdirs: [..., 3] array of view directions.
31 |     normals: [..., 3] array of normal directions (assumed to be unit vectors).
32 | 
33 |   Returns:
34 |     [..., 3] array of reflection directions.
35 | """ 36 | return 2.0 * jnp.sum( 37 | normals * viewdirs, axis=-1, keepdims=True) * normals - viewdirs 38 | 39 | 40 | def l2_normalize(x, eps=jnp.finfo(jnp.float32).eps): 41 | """Normalize x to unit length along last axis.""" 42 | return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps)) 43 | 44 | 45 | def compute_weighted_mae(weights, normals, normals_gt): 46 | """Compute weighted mean angular error, assuming normals are unit length.""" 47 | one_eps = 1 - jnp.finfo(jnp.float32).eps 48 | return (weights * jnp.arccos( 49 | jnp.clip((normals * normals_gt).sum(-1), -one_eps, 50 | one_eps))).sum() / weights.sum() * 180.0 / jnp.pi 51 | 52 | 53 | def generalized_binomial_coeff(a, k): 54 | """Compute generalized binomial coefficients.""" 55 | return np.prod(a - np.arange(k)) / np.math.factorial(k) 56 | 57 | 58 | def assoc_legendre_coeff(l, m, k): 59 | """Compute associated Legendre polynomial coefficients. 60 | 61 | Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the 62 | (l, m)th associated Legendre polynomial, P_l^m(cos(theta)). 63 | 64 | Args: 65 | l: associated Legendre polynomial degree. 66 | m: associated Legendre polynomial order. 67 | k: power of cos(theta). 68 | 69 | Returns: 70 | A float, the coefficient of the term corresponding to the inputs. 71 | """ 72 | return ((-1)**m * 2**l * np.math.factorial(l) / np.math.factorial(k) / 73 | np.math.factorial(l - k - m) * 74 | generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l)) 75 | 76 | 77 | def sph_harm_coeff(l, m, k): 78 | """Compute spherical harmonic coefficients.""" 79 | return (np.sqrt( 80 | (2.0 * l + 1.0) * np.math.factorial(l - m) / 81 | (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k)) 82 | 83 | 84 | def get_ml_array(deg_view): 85 | """Create a list with all pairs of (l, m) values to use in the encoding.""" 86 | ml_list = [] 87 | for i in range(deg_view): 88 | l = 2**i 89 | # Only use nonnegative m values, later splitting real and imaginary parts. 90 | for m in range(l + 1): 91 | ml_list.append((m, l)) 92 | 93 | # Convert list into a numpy array. 94 | ml_array = np.array(ml_list).T 95 | return ml_array 96 | 97 | 98 | def generate_ide_fn(deg_view): 99 | """Generate integrated directional encoding (IDE) function. 100 | 101 | This function returns a function that computes the integrated directional 102 | encoding from Equations 6-8 of arxiv.org/abs/2112.03907. 103 | 104 | Args: 105 | deg_view: number of spherical harmonics degrees to use. 106 | 107 | Returns: 108 | A function for evaluating integrated directional encoding. 109 | 110 | Raises: 111 | ValueError: if deg_view is larger than 5. 112 | """ 113 | if deg_view > 5: 114 | raise ValueError('Only deg_view of at most 5 is numerically stable.') 115 | 116 | ml_array = get_ml_array(deg_view) 117 | l_max = 2**(deg_view - 1) 118 | 119 | # Create a matrix corresponding to ml_array holding all coefficients, which, 120 | # when multiplied (from the right) by the z coordinate Vandermonde matrix, 121 | # results in the z component of the encoding. 122 | mat = np.zeros((l_max + 1, ml_array.shape[1])) 123 | for i, (m, l) in enumerate(ml_array.T): 124 | for k in range(l - m + 1): 125 | mat[k, i] = sph_harm_coeff(l, m, k) 126 | 127 | def integrated_dir_enc_fn(xyz, kappa_inv): 128 | """Function returning integrated directional encoding (IDE). 129 | 130 | Args: 131 | xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at. 
132 | kappa_inv: [..., 1] reciprocal of the concentration parameter of the von 133 | Mises-Fisher distribution. 134 | 135 | Returns: 136 | An array with the resulting IDE. 137 | """ 138 | x = xyz[..., 0:1] 139 | y = xyz[..., 1:2] 140 | z = xyz[..., 2:3] 141 | 142 | # Compute z Vandermonde matrix. 143 | vmz = jnp.concatenate([z**i for i in range(mat.shape[0])], axis=-1) 144 | 145 | # Compute x+iy Vandermonde matrix. 146 | vmxy = jnp.concatenate([(x + 1j * y)**m for m in ml_array[0, :]], axis=-1) 147 | 148 | # Get spherical harmonics. 149 | sph_harms = vmxy * math.matmul(vmz, mat) 150 | 151 | # Apply attenuation function using the von Mises-Fisher distribution 152 | # concentration parameter, kappa. 153 | sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1) 154 | ide = sph_harms * jnp.exp(-sigma * kappa_inv) 155 | 156 | # Split into real and imaginary parts and return 157 | return jnp.concatenate([jnp.real(ide), jnp.imag(ide)], axis=-1) 158 | 159 | return integrated_dir_enc_fn 160 | 161 | 162 | def generate_dir_enc_fn(deg_view): 163 | """Generate directional encoding (DE) function. 164 | 165 | Args: 166 | deg_view: number of spherical harmonics degrees to use. 167 | 168 | Returns: 169 | A function for evaluating directional encoding. 170 | """ 171 | integrated_dir_enc_fn = generate_ide_fn(deg_view) 172 | 173 | def dir_enc_fn(xyz): 174 | """Function returning directional encoding (DE).""" 175 | return integrated_dir_enc_fn(xyz, jnp.zeros_like(xyz[..., :1])) 176 | 177 | return dir_enc_fn 178 | -------------------------------------------------------------------------------- /internal/render.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for shooting and rendering rays.""" 16 | 17 | from internal import stepfun 18 | import jax.numpy as jnp 19 | 20 | 21 | def lift_gaussian(d, t_mean, t_var, r_var, diag): 22 | """Lift a Gaussian defined along a ray to 3D coordinates.""" 23 | mean = d[..., None, :] * t_mean[..., None] 24 | 25 | d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True)) 26 | 27 | if diag: 28 | d_outer_diag = d**2 29 | null_outer_diag = 1 - d_outer_diag / d_mag_sq 30 | t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :] 31 | xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :] 32 | cov_diag = t_cov_diag + xy_cov_diag 33 | return mean, cov_diag 34 | else: 35 | d_outer = d[..., :, None] * d[..., None, :] 36 | eye = jnp.eye(d.shape[-1]) 37 | null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :] 38 | t_cov = t_var[..., None, None] * d_outer[..., None, :, :] 39 | xy_cov = r_var[..., None, None] * null_outer[..., None, :, :] 40 | cov = t_cov + xy_cov 41 | return mean, cov 42 | 43 | 44 | def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True): 45 | """Approximate a conical frustum as a Gaussian distribution (mean+cov). 
46 | 
47 |   Assumes the ray is originating from the origin, and base_radius is the
48 |   radius at dist=1. Doesn't assume `d` is normalized.
49 | 
50 |   Args:
51 |     d: jnp.float32 3-vector, the axis of the cone
52 |     t0: float, the starting distance of the frustum.
53 |     t1: float, the ending distance of the frustum.
54 |     base_radius: float, the scale of the radius as a function of distance.
55 |     diag: boolean, whether the Gaussian will be diagonal or full-covariance.
56 |     stable: boolean, whether or not to use the stable computation described in
57 |       the paper (setting this to False will cause catastrophic failure).
58 | 
59 |   Returns:
60 |     a Gaussian (mean and covariance).
61 |   """
62 |   if stable:
63 |     # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).
64 |     mu = (t0 + t1) / 2  # The average of the two `t` values.
65 |     hw = (t1 - t0) / 2  # The half-width of the two `t` values.
66 |     eps = jnp.finfo(jnp.float32).eps
67 |     t_mean = mu + (2 * mu * hw**2) / jnp.maximum(eps, 3 * mu**2 + hw**2)
68 |     denom = jnp.maximum(eps, 3 * mu**2 + hw**2)
69 |     t_var = (hw**2) / 3 - (4 / 15) * hw**4 * (12 * mu**2 - hw**2) / denom**2
70 |     r_var = (mu**2) / 4 + (5 / 12) * hw**2 - (4 / 15) * (hw**4) / denom
71 |   else:
72 |     # Equations 37-39 in the paper.
73 |     t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
74 |     r_var = 3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3)
75 |     t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
76 |     t_var = t_mosq - t_mean**2
77 |   r_var *= base_radius**2
78 |   return lift_gaussian(d, t_mean, t_var, r_var, diag)
79 | 
80 | 
81 | def cylinder_to_gaussian(d, t0, t1, radius, diag):
82 |   """Approximate a cylinder as a Gaussian distribution (mean+cov).
83 | 
84 |   Assumes the ray is originating from the origin, and radius is the
85 |   radius. Does not renormalize `d`.
86 | 
87 |   Args:
88 |     d: jnp.float32 3-vector, the axis of the cylinder
89 |     t0: float, the starting distance of the cylinder.
90 |     t1: float, the ending distance of the cylinder.
91 |     radius: float, the radius of the cylinder
92 |     diag: boolean, whether the Gaussian will be diagonal or full-covariance.
93 | 
94 |   Returns:
95 |     a Gaussian (mean and covariance).
96 |   """
97 |   t_mean = (t0 + t1) / 2
98 |   r_var = radius**2 / 4
99 |   t_var = (t1 - t0)**2 / 12
100 |   return lift_gaussian(d, t_mean, t_var, r_var, diag)
101 | 
102 | 
103 | def cast_rays(tdist, origins, directions, radii, ray_shape, diag=True):
104 |   """Cast rays (cone- or cylinder-shaped) and featurize sections of it.
105 | 
106 |   Args:
107 |     tdist: float array, the "fencepost" distances along the ray.
108 |     origins: float array, the ray origin coordinates.
109 |     directions: float array, the ray direction vectors.
110 |     radii: float array, the radii (base radii for cones) of the rays.
111 |     ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.
112 |     diag: boolean, whether or not the covariance matrices should be diagonal.
113 | 
114 |   Returns:
115 |     a tuple of arrays of means and covariances.
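    Both have one entry per t-interval: means has shape
    [..., num_intervals, 3], and covs has shape [..., num_intervals, 3] when
    diag=True or [..., num_intervals, 3, 3] otherwise, where num_intervals =
    tdist.shape[-1] - 1.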
116 |   """
117 |   t0 = tdist[..., :-1]
118 |   t1 = tdist[..., 1:]
119 |   if ray_shape == 'cone':
120 |     gaussian_fn = conical_frustum_to_gaussian
121 |   elif ray_shape == 'cylinder':
122 |     gaussian_fn = cylinder_to_gaussian
123 |   else:
124 |     raise ValueError('ray_shape must be \'cone\' or \'cylinder\'')
125 |   means, covs = gaussian_fn(directions, t0, t1, radii, diag)
126 |   means = means + origins[..., None, :]
127 |   return means, covs
128 | 
129 | 
130 | def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
131 |   """Helper function for computing alpha compositing weights."""
132 |   t_delta = tdist[..., 1:] - tdist[..., :-1]
133 |   delta = t_delta * jnp.linalg.norm(dirs[..., None, :], axis=-1)
134 |   density_delta = density * delta
135 | 
136 |   if opaque_background:
137 |     # Equivalent to making the final t-interval infinitely wide.
138 |     density_delta = jnp.concatenate([
139 |         density_delta[..., :-1],
140 |         jnp.full_like(density_delta[..., -1:], jnp.inf)
141 |     ],
142 |                                     axis=-1)
143 | 
144 |   alpha = 1 - jnp.exp(-density_delta)
145 |   trans = jnp.exp(-jnp.concatenate([
146 |       jnp.zeros_like(density_delta[..., :1]),
147 |       jnp.cumsum(density_delta[..., :-1], axis=-1)
148 |   ],
149 |                                    axis=-1))
150 |   weights = alpha * trans
151 |   return weights, alpha, trans
152 | 
153 | 
154 | def volumetric_rendering(rgbs,
155 |                          weights,
156 |                          tdist,
157 |                          bg_rgbs,
158 |                          t_far,
159 |                          compute_extras,
160 |                          extras=None):
161 |   """Volumetric Rendering Function.
162 | 
163 |   Args:
164 |     rgbs: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
165 |     weights: jnp.ndarray(float32), weights, [batch_size, num_samples].
166 |     tdist: jnp.ndarray(float32), [batch_size, num_samples].
167 |     bg_rgbs: jnp.ndarray(float32), the color(s) to use for the background.
168 |     t_far: jnp.ndarray(float32), [batch_size, 1], the distance of the far plane.
169 |     compute_extras: bool, if True, compute extra quantities besides color.
170 |     extras: dict, a set of values along rays to render by alpha compositing.
171 | 
172 |   Returns:
173 |     rendering: a dict containing an rgb image of size [batch_size, 3], and other
174 |       visualizations if compute_extras=True.
175 |   """
176 |   eps = jnp.finfo(jnp.float32).eps
177 |   rendering = {}
178 | 
179 |   acc = weights.sum(axis=-1)
180 |   bg_w = jnp.maximum(0, 1 - acc[..., None])  # The weight of the background.
181 |   rgb = (weights[..., None] * rgbs).sum(axis=-2) + bg_w * bg_rgbs
182 |   rendering['rgb'] = rgb
183 | 
184 |   if compute_extras:
185 |     rendering['acc'] = acc
186 | 
187 |     if extras is not None:
188 |       for k, v in extras.items():
189 |         if v is not None:
190 |           rendering[k] = (weights[..., None] * v).sum(axis=-2)
191 | 
192 |     expectation = lambda x: (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)
193 |     t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
194 |     # For numerical stability this expectation is computed using log-distance.
195 |     rendering['distance_mean'] = (
196 |         jnp.clip(
197 |             jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf),
198 |             tdist[..., 0], tdist[..., -1]))
199 | 
200 |     # Add an extra fencepost with the far distance at the end of each ray, with
201 |     # whatever weight is needed to make the new weight vector sum to exactly 1
202 |     # (`weights` is only guaranteed to sum to <= 1, not == 1).
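    # E.g. if a ray's weights sum to 0.9, the appended bg_w entry of 0.1 makes
    # weights_aug a proper distribution for the percentile computation below.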
203 |     t_aug = jnp.concatenate([tdist, t_far], axis=-1)
204 |     weights_aug = jnp.concatenate([weights, bg_w], axis=-1)
205 | 
206 |     ps = [5, 50, 95]
207 |     distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)
208 | 
209 |     for i, p in enumerate(ps):
210 |       s = 'median' if p == 50 else 'percentile_' + str(p)
211 |       rendering['distance_' + s] = distance_percentiles[..., i]
212 | 
213 |   return rendering
214 | 
215 | def volumetric_rendering_nerfw(
216 |     rgbs_static,
217 |     rgbs_transient,
218 |     trans,
219 |     alpha_static,
220 |     alpha_transient,
221 |     tdist,
222 |     bg_rgbs,
223 |     t_far,
224 |     compute_extras,
225 |     extras=None):
226 |   eps = jnp.finfo(jnp.float32).eps
227 |   rendering = {}
228 |   weights_static = trans * alpha_static
229 |   weights_transient = trans * alpha_transient
230 |   weights = weights_static + weights_transient
231 | 
232 |   acc = (weights_static + weights_transient).sum(axis=-1)
233 |   bg_w = jnp.maximum(0, 1 - acc[..., None])  # The weight of the background.
234 |   rgb = (
235 |       weights_static[..., None] * rgbs_static +
236 |       weights_transient[..., None] * rgbs_transient
237 |   ).sum(axis=-2) + bg_w * bg_rgbs
238 |   rendering['rgb'] = rgb
239 | 
240 |   if compute_extras:
241 |     rendering['acc'] = acc
242 | 
243 |     if extras is not None:
244 |       for k, v in extras.items():
245 |         if v is not None:
246 |           rendering[k] = (weights[..., None] * v).sum(axis=-2)
247 | 
248 |     expectation = lambda x: (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)
249 |     t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
250 |     # For numerical stability this expectation is computed using log-distance.
251 |     rendering['distance_mean'] = (
252 |         jnp.clip(
253 |             jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf),
254 |             tdist[..., 0], tdist[..., -1]))
255 | 
256 |     # Add an extra fencepost with the far distance at the end of each ray, with
257 |     # whatever weight is needed to make the new weight vector sum to exactly 1
258 |     # (`weights` is only guaranteed to sum to <= 1, not == 1).
259 |     t_aug = jnp.concatenate([tdist, t_far], axis=-1)
260 |     weights_aug = jnp.concatenate([weights, bg_w], axis=-1)
261 | 
262 |     ps = [5, 50, 95]
263 |     distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)
264 | 
265 |     for i, p in enumerate(ps):
266 |       s = 'median' if p == 50 else 'percentile_' + str(p)
267 |       rendering['distance_' + s] = distance_percentiles[..., i]
268 | 
269 |   return rendering
270 | 
--------------------------------------------------------------------------------
/internal/robustnerf.py:
--------------------------------------------------------------------------------
1 | """Computes RobustNeRF mask."""
2 | from typing import Mapping, Tuple
3 | 
4 | from jax import lax
5 | import jax.numpy as jnp
6 | 
7 | 
8 | def robustnerf_mask(
9 |     errors: jnp.ndarray, loss_threshold: float, config: {str: float}
10 | ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
11 |   """Computes RobustNeRF mask.
12 | 
13 |   Args:
14 |     errors: f32[n,h,w,c]. Per-subpixel errors in a batch of patches.
15 |     loss_threshold: f32[]. Upper bound on per-pixel loss used to determine
16 |       if a pixel is an inlier or not.
17 |     config: Config object containing the RobustNeRF hyperparameters.
18 | 
19 |   Returns:
20 |     mask: f32[n,h,w,c or 1]. Binary mask that broadcasts to shape [n,h,w,c].
21 |     stats: { str: f32[] }. Statistics to pass on.
22 |   """
23 |   epsilon = 1e-3
24 |   error_dtype = errors.dtype
25 |   error_per_pixel = jnp.mean(errors, axis=-1, keepdims=True)  # f32[n,h,w,1]
26 |   next_loss_threshold = jnp.quantile(
27 |       error_per_pixel, config.robustnerf_inlier_quantile
28 |   )
29 |   mask = jnp.ones_like(error_per_pixel, dtype=error_dtype)
30 |   stats = {
31 |       'loss_threshold': next_loss_threshold,
32 |   }
33 |   if config.enable_robustnerf_loss:
34 |     assert (
35 |         config.robustnerf_inner_patch_size <= config.patch_size
36 |     ), 'patch_size must be at least robustnerf_inner_patch_size.'
37 | 
38 |     # Inlier pixels have a value of 1.0 in the mask.
39 |     is_inlier_pixel = (error_per_pixel < loss_threshold).astype(error_dtype)
40 |     stats['is_inlier_loss'] = jnp.mean(is_inlier_pixel)
41 | 
42 |     # Apply an f x f (3x3 by default) box filter 'window' for smoothing (diffusion).
43 |     f = config.robustnerf_smoothed_filter_size
44 |     window = jnp.ones((1, 1, f, f)) / (f * f)
45 |     has_inlier_neighbors = lax.conv(
46 |         jnp.transpose(is_inlier_pixel, [0, 3, 1, 2]), window, (1, 1), 'SAME'
47 |     )
48 |     has_inlier_neighbors = jnp.transpose(has_inlier_neighbors, [0, 2, 3, 1])
49 | 
50 |     # Binarize after smoothing.
51 |     # config.robustnerf_smoothed_inlier_quantile defaults to 0.5, which means
52 |     # at least 50% of neighbouring pixels must be inliers.
53 |     has_inlier_neighbors = (
54 |         has_inlier_neighbors > 1 - config.robustnerf_smoothed_inlier_quantile
55 |     ).astype(error_dtype)
56 |     stats['has_inlier_neighbors'] = jnp.mean(has_inlier_neighbors)
57 |     is_inlier_pixel = (
58 |         has_inlier_neighbors + is_inlier_pixel > epsilon
59 |     ).astype(error_dtype)
60 |     # Construct binary mask for inner pixels. The entire inner patch is either
61 |     # active or inactive.
62 |     # patch_size is the input patch size (h,w); the inner patch size can be any
63 |     # value smaller than patch_size. The default is for the inner patch to be
64 |     # half the input patch size (i.e. 16x16 -> 8x8).
65 |     inner_patch_mask = _robustnerf_inner_patch_mask(
66 |         config.robustnerf_inner_patch_size // config.stride, config.patch_size // config.stride
67 |     )
68 |     is_inlier_patch = jnp.mean(
69 |         is_inlier_pixel, axis=[1, 2], keepdims=True
70 |     )  # f32[n,1,1,1]
71 |     # robustnerf_inner_patch_inlier_quantile sets what percentage of the patch
72 |     # must be inliers for the patch to be counted as an inlier patch.
73 |     is_inlier_patch = (
74 |         is_inlier_patch > 1 - config.robustnerf_inner_patch_inlier_quantile
75 |     ).astype(error_dtype)
76 |     is_inlier_patch = is_inlier_patch * inner_patch_mask
77 |     stats['is_inlier_patch'] = jnp.mean(is_inlier_patch)
78 | 
79 |     # A pixel is an inlier if it is an inlier according to any of the above
80 |     # criteria.
81 |     mask = (
82 |         is_inlier_patch + is_inlier_pixel > epsilon
83 |     ).astype(error_dtype)
84 | 
85 |   stats['mask'] = jnp.mean(mask)
86 |   return mask, stats
87 | 
88 | 
89 | def _robustnerf_inner_patch_mask(
90 |     inner_patch_size, outer_patch_size, *, dtype=jnp.float32
91 | ):
92 |   """Constructs binary mask for inner patch.
93 | 
94 |   Args:
95 |     inner_patch_size: Size of the (square) inside patch.
96 |     outer_patch_size: Size of the (square) outer patch.
97 |     dtype: dtype for the result.
98 | 
99 |   Returns:
100 |     Binary mask of shape (1, outer_patch_size, outer_patch_size, 1). Mask is
101 |     1.0 for the center (inner_patch_size, inner_patch_size) square and 0.0
102 |     elsewhere.
103 | """ 104 | pad_size_lower = (outer_patch_size - inner_patch_size) // 2 105 | pad_size_upper = outer_patch_size - (inner_patch_size + pad_size_lower) 106 | mask = jnp.pad( 107 | jnp.ones((1, inner_patch_size, inner_patch_size, 1), dtype=dtype), 108 | ( 109 | (0, 0), # batch 110 | (pad_size_lower, pad_size_upper), # height 111 | (pad_size_lower, pad_size_upper), # width 112 | (0, 0), # channels 113 | ), 114 | ) 115 | return mask 116 | 117 | 118 | -------------------------------------------------------------------------------- /internal/stepfun.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools for manipulating step functions (piecewise-constant 1D functions). 16 | 17 | We have a shared naming and dimension convention for these functions. 18 | All input/output step functions are assumed to be aligned along the last axis. 19 | `t` always indicates the x coordinates of the *endpoints* of a step function. 20 | `y` indicates unconstrained values for the *bins* of a step function 21 | `w` indicates bin weights that sum to <= 1. `p` indicates non-negative bin 22 | values that *integrate* to <= 1. 23 | """ 24 | 25 | from internal import math 26 | import jax 27 | import jax.numpy as jnp 28 | 29 | 30 | def searchsorted(a, v): 31 | """Find indices where v should be inserted into a to maintain order. 32 | 33 | This behaves like jnp.searchsorted (its second output is the same as 34 | jnp.searchsorted's output if all elements of v are in [a[0], a[-1]]) but is 35 | faster because it wastes memory to save some compute. 36 | 37 | Args: 38 | a: tensor, the sorted reference points that we are scanning to see where v 39 | should lie. 40 | v: tensor, the query points that we are pretending to insert into a. Does 41 | not need to be sorted. All but the last dimensions should match or expand 42 | to those of a, the last dimension can differ. 43 | 44 | Returns: 45 | (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the 46 | range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or 47 | last index of a. 
48 | """ 49 | i = jnp.arange(a.shape[-1]) 50 | v_ge_a = v[..., None, :] >= a[..., :, None] 51 | idx_lo = jnp.max(jnp.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2) 52 | idx_hi = jnp.min(jnp.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2) 53 | return idx_lo, idx_hi 54 | 55 | 56 | def query(tq, t, y, outside_value=0): 57 | """Look up the values of the step function (t, y) at locations tq.""" 58 | idx_lo, idx_hi = searchsorted(t, tq) 59 | yq = jnp.where(idx_lo == idx_hi, outside_value, 60 | jnp.take_along_axis(y, idx_lo, axis=-1)) 61 | return yq 62 | 63 | 64 | def inner_outer(t0, t1, y1): 65 | """Construct inner and outer measures on (t1, y1) for t0.""" 66 | cy1 = jnp.concatenate([jnp.zeros_like(y1[..., :1]), 67 | jnp.cumsum(y1, axis=-1)], 68 | axis=-1) 69 | idx_lo, idx_hi = searchsorted(t1, t0) 70 | 71 | cy1_lo = jnp.take_along_axis(cy1, idx_lo, axis=-1) 72 | cy1_hi = jnp.take_along_axis(cy1, idx_hi, axis=-1) 73 | 74 | y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1] 75 | y0_inner = jnp.where(idx_hi[..., :-1] <= idx_lo[..., 1:], 76 | cy1_lo[..., 1:] - cy1_hi[..., :-1], 0) 77 | return y0_inner, y0_outer 78 | 79 | 80 | def lossfun_outer(t, w, t_env, w_env, eps=jnp.finfo(jnp.float32).eps): 81 | """The proposal weight should be an upper envelope on the nerf weight.""" 82 | _, w_outer = inner_outer(t, t_env, w_env) 83 | # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's 84 | # more effective to pull w_outer up than it is to push w_inner down. 85 | # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0. 86 | return jnp.maximum(0, w - w_outer)**2 / (w + eps) 87 | 88 | 89 | def weight_to_pdf(t, w, eps=jnp.finfo(jnp.float32).eps**2): 90 | """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.""" 91 | return w / jnp.maximum(eps, (t[..., 1:] - t[..., :-1])) 92 | 93 | 94 | def pdf_to_weight(t, p): 95 | """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.""" 96 | return p * (t[..., 1:] - t[..., :-1]) 97 | 98 | 99 | def max_dilate(t, w, dilation, domain=(-jnp.inf, jnp.inf)): 100 | """Dilate (via max-pooling) a non-negative step function.""" 101 | t0 = t[..., :-1] - dilation 102 | t1 = t[..., 1:] + dilation 103 | t_dilate = jnp.sort(jnp.concatenate([t, t0, t1], axis=-1), axis=-1) 104 | t_dilate = jnp.clip(t_dilate, *domain) 105 | w_dilate = jnp.max( 106 | jnp.where( 107 | (t0[..., None, :] <= t_dilate[..., None]) 108 | & (t1[..., None, :] > t_dilate[..., None]), 109 | w[..., None, :], 110 | 0, 111 | ), 112 | axis=-1)[..., :-1] 113 | return t_dilate, w_dilate 114 | 115 | 116 | def max_dilate_weights(t, 117 | w, 118 | dilation, 119 | domain=(-jnp.inf, jnp.inf), 120 | renormalize=False, 121 | eps=jnp.finfo(jnp.float32).eps**2): 122 | """Dilate (via max-pooling) a set of weights.""" 123 | p = weight_to_pdf(t, w) 124 | t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain) 125 | w_dilate = pdf_to_weight(t_dilate, p_dilate) 126 | if renormalize: 127 | w_dilate /= jnp.maximum(eps, jnp.sum(w_dilate, axis=-1, keepdims=True)) 128 | return t_dilate, w_dilate 129 | 130 | 131 | def integrate_weights(w): 132 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1. 133 | 134 | The output's size on the last dimension is one greater than that of the input, 135 | because we're computing the integral corresponding to the endpoints of a step 136 | function, not the integral of the interior/bin values. 137 | 138 | Args: 139 | w: Tensor, which will be integrated along the last axis. 
This is assumed to 140 | sum to 1 along the last axis, and this function will (silently) break if 141 | that is not the case. 142 | 143 | Returns: 144 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 145 | """ 146 | cw = jnp.minimum(1, jnp.cumsum(w[..., :-1], axis=-1)) 147 | shape = cw.shape[:-1] + (1,) 148 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1. 149 | cw0 = jnp.concatenate([jnp.zeros(shape), cw, jnp.ones(shape)], axis=-1) 150 | return cw0 151 | 152 | 153 | def invert_cdf(u, t, w_logits, use_gpu_resampling=False): 154 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" 155 | # Compute the PDF and CDF for each weight vector. 156 | w = jax.nn.softmax(w_logits, axis=-1) 157 | cw = integrate_weights(w) 158 | # Interpolate into the inverse CDF. 159 | interp_fn = math.interp if use_gpu_resampling else math.sorted_interp 160 | t_new = interp_fn(u, cw, t) 161 | return t_new 162 | 163 | 164 | def sample(rng, 165 | t, 166 | w_logits, 167 | num_samples, 168 | single_jitter=False, 169 | deterministic_center=False, 170 | use_gpu_resampling=False): 171 | """Piecewise-Constant PDF sampling from a step function. 172 | 173 | Args: 174 | rng: random number generator (or None for `linspace` sampling). 175 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 176 | w_logits: [..., num_bins], logits corresponding to bin weights 177 | num_samples: int, the number of samples. 178 | single_jitter: bool, if True, jitter every sample along each ray by the same 179 | amount in the inverse CDF. Otherwise, jitter each sample independently. 180 | deterministic_center: bool, if False, when `rng` is None return samples that 181 | linspace the entire PDF. If True, skip the front and back of the linspace 182 | so that the centers of each PDF interval are returned. 183 | use_gpu_resampling: bool, If True this resamples the rays based on a 184 | "gather" instruction, which is fast on GPUs but slow on TPUs. If False, 185 | this resamples the rays based on brute-force searches, which is fast on 186 | TPUs, but slow on GPUs. 187 | 188 | Returns: 189 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 190 | """ 191 | eps = jnp.finfo(jnp.float32).eps 192 | 193 | # Draw uniform samples. 194 | if rng is None: 195 | # Match the behavior of jax.random.uniform() by spanning [0, 1-eps]. 196 | if deterministic_center: 197 | pad = 1 / (2 * num_samples) 198 | u = jnp.linspace(pad, 1. - pad - eps, num_samples) 199 | else: 200 | u = jnp.linspace(0, 1. - eps, num_samples) 201 | u = jnp.broadcast_to(u, t.shape[:-1] + (num_samples,)) 202 | else: 203 | # `u` is in [0, 1) --- it can be zero, but it can never be 1. 204 | u_max = eps + (1 - eps) / num_samples 205 | max_jitter = (1 - u_max) / (num_samples - 1) - eps 206 | d = 1 if single_jitter else num_samples 207 | u = ( 208 | jnp.linspace(0, 1 - u_max, num_samples) + 209 | jax.random.uniform(rng, t.shape[:-1] + (d,), maxval=max_jitter)) 210 | 211 | return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling) 212 | 213 | 214 | def sample_intervals(rng, 215 | t, 216 | w_logits, 217 | num_samples, 218 | single_jitter=False, 219 | domain=(-jnp.inf, jnp.inf), 220 | use_gpu_resampling=False): 221 | """Sample *intervals* (rather than points) from a step function. 222 | 223 | Args: 224 | rng: random number generator (or None for `linspace` sampling). 
225 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 226 | w_logits: [..., num_bins], logits corresponding to bin weights 227 | num_samples: int, the number of intervals to sample. 228 | single_jitter: bool, if True, jitter every sample along each ray by the same 229 | amount in the inverse CDF. Otherwise, jitter each sample independently. 230 | domain: (minval, maxval), the range of valid values for `t`. 231 | use_gpu_resampling: bool, If True this resamples the rays based on a 232 | "gather" instruction, which is fast on GPUs but slow on TPUs. If False, 233 | this resamples the rays based on brute-force searches, which is fast on 234 | TPUs, but slow on GPUs. 235 | 236 | Returns: 237 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 238 | """ 239 | if num_samples <= 1: 240 | raise ValueError(f'num_samples must be > 1, is {num_samples}.') 241 | 242 | # Sample a set of points from the step function. 243 | centers = sample( 244 | rng, 245 | t, 246 | w_logits, 247 | num_samples, 248 | single_jitter, 249 | deterministic_center=True, 250 | use_gpu_resampling=use_gpu_resampling) 251 | 252 | # The intervals we return will span the midpoints of each adjacent sample. 253 | mid = (centers[..., 1:] + centers[..., :-1]) / 2 254 | 255 | # Each first/last fencepost is the reflection of the first/last midpoint 256 | # around the first/last sampled center. We clamp to the limits of the input 257 | # domain, provided by the caller. 258 | minval, maxval = domain 259 | first = jnp.maximum(minval, 2 * centers[..., :1] - mid[..., :1]) 260 | last = jnp.minimum(maxval, 2 * centers[..., -1:] - mid[..., -1:]) 261 | 262 | t_samples = jnp.concatenate([first, mid, last], axis=-1) 263 | return t_samples 264 | 265 | 266 | def lossfun_distortion(t, w): 267 | """Compute iint w[i] w[j] |t[i] - t[j]| di dj.""" 268 | # The loss incurred between all pairs of intervals. 269 | ut = (t[..., 1:] + t[..., :-1]) / 2 270 | dut = jnp.abs(ut[..., :, None] - ut[..., None, :]) 271 | loss_inter = jnp.sum(w * jnp.sum(w[..., None, :] * dut, axis=-1), axis=-1) 272 | 273 | # The loss incurred within each individual interval with itself. 274 | loss_intra = jnp.sum(w**2 * (t[..., 1:] - t[..., :-1]), axis=-1) / 3 275 | 276 | return loss_inter + loss_intra 277 | 278 | 279 | def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi): 280 | """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).""" 281 | # Distortion when the intervals do not overlap. 282 | d_disjoint = jnp.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2) 283 | 284 | # Distortion when the intervals overlap. 285 | d_overlap = (2 * 286 | (jnp.minimum(t0_hi, t1_hi)**3 - jnp.maximum(t0_lo, t1_lo)**3) + 287 | 3 * (t1_hi * t0_hi * jnp.abs(t1_hi - t0_hi) + 288 | t1_lo * t0_lo * jnp.abs(t1_lo - t0_lo) + t1_hi * t0_lo * 289 | (t0_lo - t1_hi) + t1_lo * t0_hi * 290 | (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)) 291 | 292 | # Are the two intervals not overlapping? 293 | are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi) 294 | 295 | return jnp.where(are_disjoint, d_disjoint, d_overlap) 296 | 297 | 298 | def weighted_percentile(t, w, ps): 299 | """Compute the weighted percentiles of a step function. w's must sum to 1.""" 300 | cw = integrate_weights(w) 301 | # We want to interpolate into the integrated weights according to `ps`. 302 | fn = lambda cw_i, t_i: jnp.interp(jnp.array(ps) / 100, cw_i, t_i) 303 | # Vmap fn to an arbitrary number of leading dimensions. 
304 | cw_mat = cw.reshape([-1, cw.shape[-1]]) 305 | t_mat = t.reshape([-1, t.shape[-1]]) 306 | wprctile_mat = (jax.vmap(fn, 0)(cw_mat, t_mat)) 307 | wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),)) 308 | return wprctile 309 | 310 | 311 | def resample(t, tp, vp, use_avg=False, eps=jnp.finfo(jnp.float32).eps): 312 | """Resample a step function defined by (tp, vp) into intervals t. 313 | 314 | Notation roughly matches jnp.interp. Resamples by summation by default. 315 | 316 | Args: 317 | t: tensor with shape (..., n+1), the endpoints to resample into. 318 | tp: tensor with shape (..., m+1), the endpoints of the step function being 319 | resampled. 320 | vp: tensor with shape (..., m), the values of the step function being 321 | resampled. 322 | use_avg: bool, if False, return the sum of the step function for each 323 | interval in `t`. If True, return the average, weighted by the width of 324 | each interval in `t`. 325 | eps: float, a small value to prevent division by zero when use_avg=True. 326 | 327 | Returns: 328 | v: tensor with shape (..., n), the values of the resampled step function. 329 | """ 330 | if use_avg: 331 | wp = jnp.diff(tp, axis=-1) 332 | v_numer = resample(t, tp, vp * wp, use_avg=False) 333 | v_denom = resample(t, tp, wp, use_avg=False) 334 | v = v_numer / jnp.maximum(eps, v_denom) 335 | return v 336 | 337 | acc = jnp.cumsum(vp, axis=-1) 338 | acc0 = jnp.concatenate([jnp.zeros(acc.shape[:-1] + (1,)), acc], axis=-1) 339 | acc0_resampled = jnp.vectorize( 340 | jnp.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0) 341 | v = jnp.diff(acc0_resampled, axis=-1) 342 | return v 343 | -------------------------------------------------------------------------------- /internal/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
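# --- Added usage sketch (for exposition only; not part of the original
# sources). A minimal example of how the internal/stepfun helpers above
# compose: `sample` draws stratified samples from a step function's inverse
# CDF, and `query` reads a piecewise-constant signal back at those points.
# The toy values of `t`, `logits`, and `y` below are hypothetical.
import jax
import jax.numpy as jnp
from internal import stepfun


def _stepfun_usage_sketch():
  t = jnp.array([0.0, 0.25, 0.5, 1.0])  # Sorted bin endpoints.
  logits = jnp.array([0.0, 2.0, 0.0])  # Unnormalized log-weights, one per bin.
  # Stratified samples, concentrated in the middle (heaviest) bin.
  samples = stepfun.sample(jax.random.PRNGKey(0), t, logits, num_samples=8)
  y = jnp.array([1.0, 3.0, 1.0])  # One value per bin.
  # Piecewise-constant lookup; points outside [t[0], t[-1]) return 0.
  return stepfun.query(samples, t, y)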
14 | 15 | """Training step and model creation functions.""" 16 | 17 | import collections 18 | import functools 19 | from typing import Any, Callable, Dict, MutableMapping, Optional, Text, Tuple 20 | 21 | from flax.core.scope import FrozenVariableDict 22 | from flax.training.train_state import TrainState 23 | from internal import camera_utils 24 | from internal import configs 25 | from internal import datasets 26 | from internal import image 27 | from internal import math 28 | from internal import models 29 | from internal import ref_utils 30 | from internal import robustnerf 31 | from internal import stepfun 32 | from internal import utils 33 | import jax 34 | from jax import random 35 | import jax.numpy as jnp 36 | import optax 37 | 38 | from flax import traverse_util 39 | import jax.scipy as jsp 40 | from functools import partial 41 | 42 | def flattened_traversal(fn): 43 | def mask(tree): 44 | flat = traverse_util.flatten_dict(tree) 45 | return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) 46 | 47 | return mask 48 | 49 | def tree_sum(tree): 50 | return jax.tree_util.tree_reduce(lambda x, y: x + y, tree, initializer=0) 51 | 52 | 53 | def tree_norm_sq(tree): 54 | return tree_sum(jax.tree_util.tree_map(lambda x: jnp.sum(x**2), tree)) 55 | 56 | 57 | def tree_norm(tree): 58 | return jnp.sqrt(tree_norm_sq(tree)) 59 | 60 | 61 | def tree_abs_max(tree): 62 | return jax.tree_util.tree_reduce( 63 | lambda x, y: jnp.maximum(x, jnp.max(jnp.abs(y))), tree, initializer=0) 64 | 65 | 66 | def tree_len(tree): 67 | return tree_sum( 68 | jax.tree_util.tree_map(lambda z: jnp.prod(jnp.array(z.shape)), tree)) 69 | 70 | 71 | def summarize_tree(tree, fn, ancestry=(), max_depth=3): 72 | """Flatten 'tree' while 'fn'-ing values and formatting keys like/this.""" 73 | stats = {} 74 | for k, v in tree.items(): 75 | name = ancestry + (k,) 76 | stats['/'.join(name)] = fn(v) 77 | if hasattr(v, 'items') and len(ancestry) < (max_depth - 1): 78 | stats.update(summarize_tree(v, fn, ancestry=name, max_depth=max_depth)) 79 | return stats 80 | 81 | def dino_var_loss(renderings, config): 82 | losses = [] 83 | for rendering in renderings: 84 | losses.append(jnp.mean(rendering['uncer_var'])) 85 | return config.dino_var_mult * jnp.mean(jnp.array(losses)) 86 | 87 | 88 | from jax.lax import conv_general_dilated 89 | 90 | def create_window(window_size, channels): 91 | """Create a window for SSIM computation.""" 92 | window = jnp.ones((window_size, window_size, channels)) / (window_size**2 * channels) 93 | # Reshape for convolution: (spatial_dim_1, spatial_dim_2, in_channels, out_channels) 94 | return window.reshape(window_size, window_size, channels, 1) 95 | 96 | def convolve(img, window): 97 | """Perform a convolution operation in a functional style.""" 98 | # Define the dimension specification for the convolution operation 99 | dimension_numbers = ('NHWC', 'HWIO', 'NHWC') 100 | return conv_general_dilated(img, window, (1, 1), 'SAME', dimension_numbers=dimension_numbers) 101 | 102 | def compute_ssim(img1, img2, window_size=5): 103 | C1 = 0.01 ** 2 104 | C2 = 0.03 ** 2 105 | C3 = C2 / 2 106 | 107 | window = create_window(window_size, 3)# size for channel 108 | 109 | mu1 = convolve(img1, window) 110 | mu2 = convolve(img2, window) 111 | 112 | mu1_sq = mu1 ** 2 113 | mu2_sq = mu2 ** 2 114 | mu1_mu2 = mu1 * mu2 115 | 116 | sigma1_sq = convolve(img1 * img1, window) - mu1_sq 117 | sigma2_sq = convolve(img2 * img2, window) - mu2_sq 118 | sigma12 = convolve(img1 * img2, window) - mu1_mu2 119 | 120 | # Clip the variances 
and covariances to valid values. 121 | # Variance must be non-negative: 122 | epsilon = jnp.finfo(jnp.float32).eps**2 123 | sigma1_sq = jnp.maximum(epsilon, sigma1_sq) 124 | sigma2_sq = jnp.maximum(epsilon, sigma2_sq) 125 | sigma12 = jnp.sign(sigma12) * jnp.minimum( 126 | jnp.sqrt(sigma1_sq * sigma2_sq), jnp.abs(sigma12)) 127 | 128 | l = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1) 129 | c = (2 * jnp.sqrt(sigma1_sq) * jnp.sqrt(sigma2_sq) + C2) / (sigma1_sq + sigma2_sq + C2) 130 | s = (sigma12 + C3) / (jnp.sqrt(sigma1_sq) * jnp.sqrt(sigma2_sq) + C3) 131 | 132 | c = jnp.clip(c, a_max=0.98) 133 | s = jnp.clip(s, a_max=0.98) 134 | 135 | return l, c, s 136 | 137 | def compute_data_loss(batch, renderings, rays, loss_threshold, config, train_frac): 138 | """Computes data loss terms for RGB, normal, and depth outputs.""" 139 | data_losses = [] 140 | stats = collections.defaultdict(lambda: []) 141 | 142 | # lossmult can be used to apply a weight to each ray in the batch. 143 | # For example: masking out rays, applying the Bayer mosaic mask, upweighting 144 | # rays from lower-resolution images, and so on. 145 | lossmult = rays.lossmult 146 | lossmult = jnp.broadcast_to(lossmult, batch.rgb[..., :3].shape) 147 | if config.disable_multiscale_loss: 148 | lossmult = jnp.ones_like(lossmult) 149 | 150 | for rendering in renderings: 151 | resid_sq = (rendering['rgb'] - batch.rgb[..., :3])**2 152 | denom = lossmult.sum() 153 | stats['mses'].append((lossmult * resid_sq).sum() / denom) 154 | 155 | if config.data_loss_type == 'mse': 156 | # Mean-squared error (L2) loss. 157 | data_loss = resid_sq 158 | 159 | elif config.data_loss_type == 'on-the-go': 160 | uncer = rendering['uncer'] 161 | uncer = jnp.clip(uncer, a_min=config.uncer_clip_min) + 1e-3 162 | if config.stop_ssim_gradient: 163 | l, c, s = compute_ssim(jax.lax.stop_gradient(rendering['rgb']), batch.rgb[..., :3], config.ssim_window_size) 164 | else: 165 | l, c, s = compute_ssim(rendering['rgb'], batch.rgb[..., :3], config.ssim_window_size)  # Use the same window size in both branches. 166 | train_frac = jnp.broadcast_to(train_frac, uncer.shape) 167 | 168 | # Calculate the SSIM loss rate, which starts at 100 and can scale up to 1000. 169 | # This is not mentioned in the paper, since its effect is marginal. 170 | bias = lambda x, s: x / (1 + (1 - x) * (1 / s - 2)) 171 | rate = 100 + bias(train_frac, config.ssim_anneal) * 900 172 | my_ssim_loss = jnp.clip(rate * (1 - l) * (1 - s) * (1 - c), a_max=config.ssim_clip_max) 173 | ssim_loss = my_ssim_loss / uncer**2 + config.reg_mult * jnp.log(uncer) 174 | 175 | # Adjust uncertainty to slowly increase based on the SSIM training fraction.
176 | uncer_rate = 1 + 1 * bias(train_frac, config.ssim_anneal) 177 | uncer = (jax.lax.stop_gradient(uncer) - config.uncer_clip_min) * uncer_rate + config.uncer_clip_min 178 | data_loss = 0.5 * resid_sq / (uncer) ** 2 179 | data_loss += config.ssim_mult * ssim_loss 180 | 181 | # robustnerf loss 182 | elif config.data_loss_type == 'robustnerf': 183 | mask, robust_stats = robustnerf.robustnerf_mask(resid_sq, loss_threshold, 184 | config) 185 | data_loss = resid_sq * mask 186 | stats.update(robust_stats) 187 | else: 188 | assert False 189 | data_losses.append((lossmult * data_loss).sum() / denom) 190 | 191 | data_losses = jnp.array(data_losses) 192 | loss = ( 193 | config.data_coarse_loss_mult * jnp.sum(data_losses[:-1]) + 194 | config.data_loss_mult * data_losses[-1]) 195 | stats = {k: jnp.array(stats[k]) for k in stats} 196 | return loss, stats 197 | 198 | 199 | def interlevel_loss(ray_history, config): 200 | """Computes the interlevel loss defined in mip-NeRF 360.""" 201 | # Stop the gradient from the interlevel loss onto the NeRF MLP. 202 | last_ray_results = ray_history[-1] 203 | c = jax.lax.stop_gradient(last_ray_results['sdist']) 204 | w = jax.lax.stop_gradient(last_ray_results['weights']) 205 | loss_interlevel = 0. 206 | for ray_results in ray_history[:-1]: 207 | cp = ray_results['sdist'] 208 | wp = ray_results['weights'] 209 | loss_interlevel += jnp.mean(stepfun.lossfun_outer(c, w, cp, wp)) 210 | return config.interlevel_loss_mult * loss_interlevel 211 | 212 | 213 | def distortion_loss(ray_history, config): 214 | """Computes the distortion loss regularizer defined in mip-NeRF 360.""" 215 | last_ray_results = ray_history[-1] 216 | c = last_ray_results['sdist'] 217 | w = last_ray_results['weights'] 218 | loss = jnp.mean(stepfun.lossfun_distortion(c, w)) 219 | return config.distortion_loss_mult * loss 220 | 221 | 222 | def orientation_loss(rays, model, ray_history, config): 223 | """Computes the orientation loss regularizer defined in ref-NeRF.""" 224 | total_loss = 0. 225 | for i, ray_results in enumerate(ray_history): 226 | w = ray_results['weights'] 227 | n = ray_results[config.orientation_loss_target] 228 | if n is None: 229 | raise ValueError('Normals cannot be None if orientation loss is on.') 230 | # Negate viewdirs to represent normalized vectors from point to camera. 231 | v = -1. * rays.viewdirs 232 | n_dot_v = (n * v[..., None, :]).sum(axis=-1) 233 | loss = jnp.mean((w * jnp.minimum(0.0, n_dot_v)**2).sum(axis=-1)) 234 | if i < model.num_levels - 1: 235 | total_loss += config.orientation_coarse_loss_mult * loss 236 | else: 237 | total_loss += config.orientation_loss_mult * loss 238 | return total_loss 239 | 240 | 241 | def predicted_normal_loss(model, ray_history, config): 242 | """Computes the predicted normal supervision loss defined in ref-NeRF.""" 243 | total_loss = 0. 
244 | for i, ray_results in enumerate(ray_history): 245 | w = ray_results['weights'] 246 | n = ray_results['normals'] 247 | n_pred = ray_results['normals_pred'] 248 | if n is None or n_pred is None: 249 | raise ValueError( 250 | 'Predicted normals and gradient normals cannot be None if ' 251 | 'predicted normal loss is on.') 252 | loss = jnp.mean((w * (1.0 - jnp.sum(n * n_pred, axis=-1))).sum(axis=-1)) 253 | if i < model.num_levels - 1: 254 | total_loss += config.predicted_normal_coarse_loss_mult * loss 255 | else: 256 | total_loss += config.predicted_normal_loss_mult * loss 257 | return total_loss 258 | 259 | 260 | def clip_gradients(grad, config): 261 | """Clips gradients of each MLP individually based on norm and max value.""" 262 | # Clip the gradients of each MLP individually. 263 | grad_clipped = {'params': {}} 264 | for k, g in grad['params'].items(): 265 | # Clip by value. 266 | if config.grad_max_val > 0: 267 | g = jax.tree_util.tree_map( 268 | lambda z: jnp.clip(z, -config.grad_max_val, config.grad_max_val), g) 269 | 270 | # Then clip by norm. 271 | if config.grad_max_norm > 0: 272 | mult = jnp.minimum( 273 | 1, config.grad_max_norm / (jnp.finfo(jnp.float32).eps + tree_norm(g))) 274 | g = jax.tree_util.tree_map(lambda z: mult * z, g) # pylint:disable=cell-var-from-loop 275 | 276 | grad_clipped['params'][k] = g 277 | grad = type(grad)(grad_clipped) 278 | return grad 279 | 280 | 281 | def create_train_step(model: models.Model, 282 | config: configs.Config, 283 | dataset: Optional[datasets.Dataset] = None): 284 | """Creates the pmap'ed Nerf training function. 285 | 286 | Args: 287 | model: The linen model. 288 | config: The configuration. 289 | dataset: Training dataset. 290 | 291 | Returns: 292 | pmap'ed training function. 293 | """ 294 | if dataset is None: 295 | camtype = camera_utils.ProjectionType.PERSPECTIVE 296 | else: 297 | camtype = dataset.camtype 298 | 299 | def train_step( 300 | rng, 301 | state, 302 | batch, 303 | cameras, 304 | train_frac, 305 | loss_threshold, 306 | ): 307 | """One optimization step. 308 | 309 | Args: 310 | rng: jnp.ndarray, random number generator. 311 | state: TrainState, state of the model/optimizer. 312 | batch: dict, a mini-batch of data for training. 313 | cameras: module containing camera poses. 314 | train_frac: float, the fraction of training that is complete. 315 | loss_threshold: float, the loss threshold for inliers (for robustness). 316 | 317 | Returns: 318 | A tuple (new_state, stats, rng) with 319 | new_state: TrainState, new training state. 320 | stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)]. 321 | rng: jnp.ndarray, updated random number generator. 322 | """ 323 | rng, key, dropout_key = random.split(rng, num=3) 324 | 325 | def loss_fn(variables, dropout_key): 326 | rays = batch.rays 327 | if config.cast_rays_in_train_step: 328 | rays = camera_utils.cast_ray_batch(cameras, rays, camtype, xnp=jnp) 329 | 330 | # Indicates whether we need to compute output normal or depth maps in 2D. 
331 | 332 | renderings, ray_history = model.apply( 333 | variables, 334 | key if config.randomized else None, 335 | rays, 336 | train_frac=train_frac, 337 | compute_extras=(), 338 | zero_glo=False, 339 | rngs={'dropout': dropout_key}) 340 | losses = {} 341 | data_loss, stats = compute_data_loss(batch, renderings, rays, 342 | loss_threshold, config, train_frac) 343 | losses['data'] = data_loss 344 | 345 | if config.interlevel_loss_mult > 0: 346 | losses['interlevel'] = interlevel_loss(ray_history, config) 347 | 348 | if config.distortion_loss_mult > 0: 349 | losses['distortion'] = distortion_loss(ray_history, config) 350 | 351 | if (config.orientation_coarse_loss_mult > 0 or 352 | config.orientation_loss_mult > 0): 353 | losses['orientation'] = orientation_loss(rays, model, ray_history, 354 | config) 355 | 356 | if (config.predicted_normal_coarse_loss_mult > 0 or 357 | config.predicted_normal_loss_mult > 0): 358 | losses['predicted_normals'] = predicted_normal_loss( 359 | model, ray_history, config) 360 | 361 | if config.dino_var_mult > 0: 362 | losses['dino_var'] = dino_var_loss(renderings, config) 363 | 364 | stats['weight_l2s'] = summarize_tree(variables['params'], tree_norm_sq) 365 | 366 | if config.weight_decay_mults: 367 | it = config.weight_decay_mults.items 368 | losses['weight'] = jnp.sum( 369 | jnp.array([m * stats['weight_l2s'][k] for k, m in it()])) 370 | stats['loss'] = jnp.sum(jnp.array(list(losses.values()))) 371 | stats['losses'] = losses 372 | 373 | return stats['loss'], stats 374 | 375 | loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 376 | (_, stats), grad = loss_grad_fn(state.params, dropout_key) 377 | 378 | pmean = lambda x: jax.lax.pmean(x, axis_name='batch') 379 | grad = pmean(grad) 380 | stats = pmean(stats) 381 | 382 | stats['grad_norms'] = summarize_tree(grad['params'], tree_norm) 383 | stats['grad_maxes'] = summarize_tree(grad['params'], tree_abs_max) 384 | 385 | grad = clip_gradients(grad, config) 386 | 387 | grad = jax.tree_util.tree_map(jnp.nan_to_num, grad) 388 | 389 | new_state = state.apply_gradients(grads=grad) 390 | 391 | opt_delta = jax.tree_util.tree_map(lambda x, y: x - y, new_state, 392 | state).params['params'] 393 | stats['opt_update_norms'] = summarize_tree(opt_delta, tree_norm) 394 | stats['opt_update_maxes'] = summarize_tree(opt_delta, tree_abs_max) 395 | 396 | stats['psnrs'] = image.mse_to_psnr(stats['mses']) 397 | stats['psnr'] = stats['psnrs'][-1] 398 | return new_state, stats, rng 399 | 400 | train_pstep = jax.pmap( 401 | train_step, 402 | axis_name='batch', 403 | in_axes=(0, 0, 0, None, None, None), 404 | donate_argnums=(0, 1)) 405 | return train_pstep 406 | 407 | 408 | def create_optimizer( 409 | config: configs.Config, 410 | variables: FrozenVariableDict) -> Tuple[TrainState, Callable[[int], float]]: 411 | """Creates optax optimizer for model training.""" 412 | adam_kwargs = { 413 | 'b1': config.adam_beta1, 414 | 'b2': config.adam_beta2, 415 | 'eps': config.adam_eps, 416 | } 417 | lr_kwargs = { 418 | 'max_steps': config.max_steps, 419 | 'lr_delay_steps': config.lr_delay_steps, 420 | 'lr_delay_mult': config.lr_delay_mult, 421 | } 422 | 423 | 424 | 425 | def get_lr_fn(lr_init, lr_final): 426 | return functools.partial( 427 | math.learning_rate_decay, 428 | lr_init=lr_init, 429 | lr_final=lr_final, 430 | **lr_kwargs) 431 | 432 | lr_fn_main = get_lr_fn(config.lr_init, config.lr_final) 433 | tx = optax.adam(learning_rate=lr_fn_main, **adam_kwargs) 434 | 435 | return TrainState.create(apply_fn=None, params=variables, tx=tx), lr_fn_main 
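# --- Added usage sketch (for exposition only; not part of the original
# sources). How the (state, lr_fn) pair returned by create_optimizer above is
# typically consumed in a training loop, outside of pmap for simplicity.
# `grads` is a hypothetical gradient pytree with the same structure as
# state.params; the Adam schedule is applied internally by optax, so lr_fn is
# only queried here for logging.
def _optimizer_step_sketch(state: TrainState, lr_fn, grads):
  lr = lr_fn(int(state.step))  # Scheduled learning rate at the current step.
  new_state = state.apply_gradients(grads=grads)  # One Adam update.
  return new_state, lr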
436 | 437 | 438 | def create_render_fn(model: models.Model): 439 | """Creates pmap'ed function for full image rendering.""" 440 | 441 | def render_eval_fn(variables, train_frac, _, rays): 442 | return jax.lax.all_gather( 443 | model.apply( 444 | variables, 445 | None, # Deterministic. 446 | rays, 447 | train_frac=train_frac, 448 | compute_extras=True, 449 | is_training=False), 450 | axis_name='batch') 451 | 452 | # pmap over only the data input. 453 | render_eval_pfn = jax.pmap( 454 | render_eval_fn, 455 | in_axes=(None, None, None, 0), 456 | axis_name='batch', 457 | ) 458 | return render_eval_pfn 459 | 460 | 461 | def setup_model( 462 | config: configs.Config, 463 | rng: jnp.array, 464 | dataset: Optional[datasets.Dataset] = None, 465 | ) -> Tuple[models.Model, TrainState, Callable[ 466 | [FrozenVariableDict, jnp.array, utils.Rays], 467 | MutableMapping[Text, Any]], Callable[ 468 | [jnp.array, TrainState, utils.Batch, 469 | Optional[Tuple[Any, ...]], float, float], 470 | Tuple[TrainState, Dict[Text, Any], jnp.array]], Callable[[int], float]]: 471 | """Creates NeRF model, optimizer, and pmap-ed train/render functions.""" 472 | feat_dim = config.feat_dim 473 | dummy_rays = utils.dummy_rays( 474 | feat_dim, include_exposure_idx=config.rawnerf_mode, include_exposure_values=True) 475 | model, variables = models.construct_model(rng, dummy_rays, config) 476 | 477 | state, lr_fn = create_optimizer(config, variables) 478 | render_eval_pfn = create_render_fn(model) 479 | train_pstep = create_train_step(model, config, dataset=dataset) 480 | 481 | return model, state, render_eval_pfn, train_pstep, lr_fn 482 | -------------------------------------------------------------------------------- /internal/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
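# --- Added usage sketch (for exposition only; not part of the original
# sources). The shard/unshard helpers defined below pair with jax.pmap by
# adding and removing a leading device axis; this toy example assumes the
# batch size divides jax.local_device_count().
import jax
import jax.numpy as jnp


def _shard_unshard_sketch():
  xs = jnp.arange(8.0)[:, None]  # A host-side [batch, 1] array.
  n = jax.local_device_count()
  sharded = xs.reshape((n, -1) + xs.shape[1:])  # What shard(xs) produces.
  out = jax.pmap(lambda x: 2 * x)(sharded)  # One shard per device.
  return out.reshape((-1,) + out.shape[2:])  # What unshard(out) produces.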
14 | 15 | """Utility functions.""" 16 | 17 | import enum 18 | import os 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import flax 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | from PIL import ExifTags 26 | from PIL import Image 27 | 28 | _Array = Union[np.ndarray, jnp.ndarray] 29 | 30 | 31 | @flax.struct.dataclass 32 | class Pixels: 33 | """All tensors must have the same num_dims and first n-1 dims must match.""" 34 | pix_x_int: _Array 35 | pix_y_int: _Array 36 | lossmult: _Array 37 | near: _Array 38 | far: _Array 39 | cam_idx: _Array 40 | exposure_idx: Optional[_Array] = None 41 | exposure_values: Optional[_Array] = None 42 | features: Optional[_Array] = None 43 | 44 | 45 | @flax.struct.dataclass 46 | class Rays: 47 | """All tensors must have the same num_dims and first n-1 dims must match.""" 48 | origins: _Array 49 | directions: _Array 50 | viewdirs: _Array 51 | radii: _Array 52 | imageplane: _Array 53 | lossmult: _Array 54 | near: _Array 55 | far: _Array 56 | cam_idx: _Array 57 | exposure_idx: Optional[_Array] = None 58 | exposure_values: Optional[_Array] = None 59 | features: Optional[_Array] = None 60 | 61 | 62 | 63 | 64 | # Dummy Rays object that can be used to initialize NeRF model. 65 | def dummy_rays(feat_dim: int, 66 | include_exposure_idx: bool = False, 67 | include_exposure_values: bool = False) -> Rays: 68 | data_fn = lambda n: jnp.zeros((1, n)) 69 | exposure_kwargs = {} 70 | if include_exposure_idx: 71 | exposure_kwargs['exposure_idx'] = data_fn(1).astype(jnp.int32) 72 | if include_exposure_values: 73 | exposure_kwargs['exposure_values'] = data_fn(1) 74 | return Rays( 75 | origins=data_fn(3), 76 | directions=data_fn(3), 77 | viewdirs=data_fn(3), 78 | radii=data_fn(1), 79 | imageplane=data_fn(2), 80 | lossmult=data_fn(1), 81 | near=data_fn(1), 82 | far=data_fn(1), 83 | cam_idx=data_fn(1).astype(jnp.int32), 84 | features=data_fn(feat_dim), 85 | **exposure_kwargs) 86 | 87 | 88 | @flax.struct.dataclass 89 | class Batch: 90 | """Data batch for NeRF training or testing.""" 91 | rays: Union[Pixels, Rays] 92 | rgb: Optional[_Array] = None 93 | disps: Optional[_Array] = None 94 | normals: Optional[_Array] = None 95 | alphas: Optional[_Array] = None 96 | features: Optional[_Array] = None 97 | 98 | 99 | class DataSplit(enum.Enum): 100 | """Dataset split.""" 101 | TRAIN = 'train' 102 | TEST = 'test' 103 | 104 | 105 | class BatchingMethod(enum.Enum): 106 | """Draw rays randomly from a single image or all images, in each batch.""" 107 | ALL_IMAGES = 'all_images' 108 | SINGLE_IMAGE = 'single_image' 109 | 110 | 111 | def open_file(pth, mode='r'): 112 | return open(pth, mode=mode) 113 | 114 | 115 | def file_exists(pth): 116 | return os.path.exists(pth) 117 | 118 | 119 | def listdir(pth): 120 | return os.listdir(pth) 121 | 122 | 123 | def isdir(pth): 124 | return os.path.isdir(pth) 125 | 126 | 127 | def makedirs(pth): 128 | if not file_exists(pth): 129 | os.makedirs(pth) 130 | 131 | 132 | def shard(xs): 133 | """Split data into shards for multiple devices along the first dimension.""" 134 | return jax.tree_util.tree_map( 135 | lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs) 136 | 137 | 138 | def unshard(x, padding=0): 139 | """Collect the sharded tensor to the shape before sharding.""" 140 | y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) 141 | if padding > 0: 142 | y = y[:-padding] 143 | return y 144 | 145 | def load_npy(pth: str) -> np.ndarray: 146 | """Load a numpy array.""" 147 | with open_file(pth, 'rb') as f: 
148 | data = np.load(f).astype(np.float32) 149 | return data 150 | 151 | def load_img(pth: str) -> np.ndarray: 152 | """Load an image and cast to float32.""" 153 | with open_file(pth, 'rb') as f: 154 | image = np.array(Image.open(f), dtype=np.float32) 155 | return image 156 | 157 | 158 | def load_exif(pth: str) -> Dict[str, Any]: 159 | """Load EXIF data for an image.""" 160 | with open_file(pth, 'rb') as f: 161 | image_pil = Image.open(f) 162 | exif_pil = image_pil._getexif() # pylint: disable=protected-access 163 | if exif_pil is not None: 164 | exif = { 165 | ExifTags.TAGS[k]: v for k, v in exif_pil.items() if k in ExifTags.TAGS 166 | } 167 | else: 168 | exif = {} 169 | return exif 170 | 171 | 172 | def save_img_u8(img, pth): 173 | """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" 174 | with open_file(pth, 'wb') as f: 175 | Image.fromarray( 176 | (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save( 177 | f, 'PNG') 178 | 179 | 180 | def save_img_f32(depthmap, pth): 181 | """Save an image (probably a depthmap) to disk as a float32 TIFF.""" 182 | with open_file(pth, 'wb') as f: 183 | Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF') 184 | -------------------------------------------------------------------------------- /internal/vis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for visualizing things.""" 16 | 17 | from internal import stepfun 18 | import jax.numpy as jnp 19 | from matplotlib import cm 20 | 21 | 22 | def weighted_percentile(x, w, ps, assume_sorted=False): 23 | """Compute the weighted percentile(s) of a single vector.""" 24 | x = x.reshape([-1]) 25 | w = w.reshape([-1]) 26 | if not assume_sorted: 27 | sortidx = jnp.argsort(x) 28 | x, w = x[sortidx], w[sortidx] 29 | acc_w = jnp.cumsum(w) 30 | return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x) 31 | 32 | 33 | def sinebow(h): 34 | """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" 35 | f = lambda x: jnp.sin(jnp.pi * x)**2 36 | return jnp.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) 37 | 38 | 39 | def matte(vis, acc, dark=0.8, light=1.0, width=8): 40 | """Set non-accumulated pixels to a Photoshop-esque checker pattern.""" 41 | bg_mask = jnp.logical_xor( 42 | (jnp.arange(acc.shape[0]) % (2 * width) // width)[:, None], 43 | (jnp.arange(acc.shape[1]) % (2 * width) // width)[None, :]) 44 | bg = jnp.where(bg_mask, light, dark) 45 | return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None] 46 | 47 | 48 | def visualize_cmap(value, 49 | weight, 50 | colormap, 51 | lo=None, 52 | hi=None, 53 | percentile=99., 54 | curve_fn=lambda x: x, 55 | modulus=None, 56 | matte_background=True): 57 | """Visualize a 1D image and a 1D weighting according to some colormap. 58 | 59 | Args: 60 | value: A 1D image. 61 | weight: A weight map, in [0, 1]. 
62 | colormap: A colormap function. 63 | lo: The lower bound to use when rendering, if None then use a percentile. 64 | hi: The upper bound to use when rendering, if None then use a percentile. 65 | percentile: What percentile of the value map to crop to when automatically 66 | generating `lo` and `hi`. Depends on `weight` as well as `value`. 67 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 68 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 69 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If 70 | `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. 71 | matte_background: If True, matte the image over a checkerboard. 72 | 73 | Returns: 74 | A colormap rendering. 75 | """ 76 | # Identify the values that bound the middle of `value` according to `weight`. 77 | lo_auto, hi_auto = weighted_percentile( 78 | value, weight, [50 - percentile / 2, 50 + percentile / 2]) 79 | 80 | # If `lo` or `hi` are None, use the automatically-computed bounds above. 81 | eps = jnp.finfo(jnp.float32).eps 82 | lo = lo or (lo_auto - eps) 83 | hi = hi or (hi_auto + eps) 84 | 85 | # Curve all values. 86 | value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] 87 | 88 | # Wrap the values around if requested. 89 | if modulus: 90 | value = jnp.mod(value, modulus) / modulus 91 | else: 92 | # Otherwise, just scale to [0, 1]. 93 | value = jnp.nan_to_num( 94 | jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1)) 95 | 96 | if colormap: 97 | colorized = colormap(value)[:, :, :3] 98 | else: 99 | if len(value.shape) != 3: 100 | raise ValueError(f'value must have 3 dims but has {len(value.shape)}') 101 | if value.shape[-1] != 3: 102 | raise ValueError( 103 | f'value must have 3 channels but has {value.shape[-1]}') 104 | colorized = value 105 | 106 | return matte(colorized, weight) if matte_background else colorized 107 | 108 | 109 | def visualize_coord_mod(coords, acc): 110 | """Visualize the coordinate of each point within its "cell".""" 111 | return matte(((coords + 1) % 2) / 2, acc) 112 | 113 | 114 | def visualize_rays(dist, 115 | dist_range, 116 | weights, 117 | rgbs, 118 | accumulate=False, 119 | renormalize=False, 120 | resolution=2048, 121 | bg_color=0.8): 122 | """Visualize a bundle of rays.""" 123 | dist_vis = jnp.linspace(*dist_range, resolution + 1) 124 | vis_rgb, vis_alpha = [], [] 125 | for ds, ws, rs in zip(dist, weights, rgbs): 126 | vis_rs, vis_ws = [], [] 127 | for d, w, r in zip(ds, ws, rs): 128 | if accumulate: 129 | # Produce the accumulated color and weight at each point along the ray. 130 | w_csum = jnp.cumsum(w, axis=0) 131 | rw_csum = jnp.cumsum((r * w[:, None]), axis=0) 132 | eps = jnp.finfo(jnp.float32).eps 133 | r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum 134 | vis_rs.append(stepfun.resample(dist_vis, d, r.T, use_avg=True).T) 135 | vis_ws.append(stepfun.resample(dist_vis, d, w.T, use_avg=True).T) 136 | vis_rgb.append(jnp.stack(vis_rs)) 137 | vis_alpha.append(jnp.stack(vis_ws)) 138 | vis_rgb = jnp.stack(vis_rgb, axis=1) 139 | vis_alpha = jnp.stack(vis_alpha, axis=1) 140 | 141 | if renormalize: 142 | # Scale the alphas so that the largest value is 1, for visualization.
143 | vis_alpha /= jnp.maximum(jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha)) 144 | 145 | if resolution > vis_rgb.shape[0]: 146 | rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1) 147 | stride = rep * vis_rgb.shape[1] 148 | 149 | vis_rgb = jnp.tile(vis_rgb, (1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:]) 150 | vis_alpha = jnp.tile(vis_alpha, (1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:]) 151 | 152 | # Add a strip of background pixels after each set of levels of rays. 153 | vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:]) 154 | vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:]) 155 | vis_rgb = jnp.concatenate([vis_rgb, jnp.zeros_like(vis_rgb[:, :1])], 156 | axis=1).reshape((-1,) + vis_rgb.shape[2:]) 157 | vis_alpha = jnp.concatenate( 158 | [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])], 159 | axis=1).reshape((-1,) + vis_alpha.shape[2:]) 160 | 161 | # Matte the RGB image over the background. 162 | vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None] 163 | 164 | # Remove the final row of background pixels. 165 | vis = vis[:-1] 166 | vis_alpha = vis_alpha[:-1] 167 | return vis, vis_alpha 168 | 169 | 170 | def visualize_suite(rendering, rays): 171 | """A wrapper around other visualizations for easy integration.""" 172 | 173 | depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps) 174 | 175 | rgb = rendering['rgb'] 176 | acc = rendering['acc'] 177 | 178 | distance_mean = rendering['distance_mean'] 179 | distance_median = rendering['distance_median'] 180 | distance_p5 = rendering['distance_percentile_5'] 181 | distance_p95 = rendering['distance_percentile_95'] 182 | acc = jnp.where(jnp.isnan(distance_mean), jnp.zeros_like(acc), acc) 183 | 184 | # The xyz coordinates where rays terminate. 185 | coords = rays.origins + rays.directions * distance_mean[:, :, None] 186 | 187 | vis_depth_mean, vis_depth_median = [ 188 | visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn) 189 | for x in [distance_mean, distance_median] 190 | ] 191 | 192 | # Render three depth percentiles directly to RGB channels, where the spacing 193 | # determines the color. delta == big change, epsilon = small change. 
194 | # Gray: A strong discontinuity, [x-epsilon, x, x+epsilon] 195 | # Purple: A thin but even density, [x-delta, x, x+delta] 196 | # Red: A thin density, then a thick density, [x-delta, x, x+epsilon] 197 | # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta] 198 | vis_depth_triplet = visualize_cmap( 199 | jnp.stack( 200 | [2 * distance_median - distance_p5, distance_median, distance_p95], 201 | axis=-1), 202 | acc, 203 | None, 204 | curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps)) 205 | 206 | dist = rendering['ray_sdist'] 207 | dist_range = (0, 1) 208 | weights = rendering['ray_weights'] 209 | rgbs = [jnp.clip(r, 0, 1) for r in rendering['ray_rgbs']] 210 | 211 | vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs) 212 | 213 | sqrt_weights = [jnp.sqrt(w) for w in weights] 214 | sqrt_ray_weights, ray_alpha = visualize_rays( 215 | dist, 216 | dist_range, 217 | [jnp.ones_like(lw) for lw in sqrt_weights], 218 | [lw[..., None] for lw in sqrt_weights], 219 | bg_color=0, 220 | ) 221 | sqrt_ray_weights = sqrt_ray_weights[..., 0] 222 | null_color = jnp.array([1., 0., 0.]) 223 | vis_ray_weights = jnp.where( 224 | ray_alpha[:, :, None] == 0, 225 | null_color[None, None], 226 | visualize_cmap( 227 | sqrt_ray_weights, 228 | jnp.ones_like(sqrt_ray_weights), 229 | cm.get_cmap('gray'), 230 | lo=0, 231 | hi=1, 232 | matte_background=False, 233 | ), 234 | ) 235 | vis_uncertainty = visualize_cmap( 236 | rendering['uncer'][..., 0], 237 | acc, 238 | cm.get_cmap('turbo'), 239 | lo=0.2, 240 | hi=2, 241 | ) 242 | vis = { 243 | 'color': rgb, 244 | 'acc': acc, 245 | 'color_matte': matte(rgb, acc), 246 | 'depth_mean': vis_depth_mean, 247 | 'depth_median': vis_depth_median, 248 | 'depth_triplet': vis_depth_triplet, 249 | 'coords_mod': visualize_coord_mod(coords, acc), 250 | 'ray_colors': vis_ray_colors, 251 | 'ray_weights': vis_ray_weights, 252 | 'uncertainty': vis_uncertainty, 253 | 'uncertainty_raw': rendering['uncer'][..., 0], 254 | } 255 | 256 | if 'rgb_cc' in rendering: 257 | vis['color_corrected'] = rendering['rgb_cc'] 258 | 259 | return vis 260 | -------------------------------------------------------------------------------- /media/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/nerf-on-the-go/0c32fbb5fdec68d989d406618c253cc56524f64f/media/teaser.gif -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
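# --- Added usage sketch (for exposition only; not part of the original
# sources). How internal/vis (the module above) turns a depth map into a
# color image, mirroring the curve used in visualize_suite. `depth` and `acc`
# are hypothetical [H, W] arrays of distances and accumulated opacities.
import jax.numpy as jnp
from matplotlib import cm
from internal import vis


def _depth_vis_sketch(depth, acc):
  # A negative-log curve compresses far distances before colormapping.
  curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps)
  return vis.visualize_cmap(depth, acc, cm.get_cmap('turbo'), curve_fn=curve_fn)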
14 | 15 | """Render script.""" 16 | 17 | import concurrent.futures 18 | import functools 19 | import glob 20 | import os 21 | import time 22 | 23 | from absl import app 24 | from flax.training import checkpoints 25 | import gin 26 | from internal import configs 27 | from internal import datasets 28 | from internal import models 29 | from internal import train_utils 30 | from internal import utils 31 | import jax 32 | from jax import random 33 | from matplotlib import cm 34 | import mediapy as media 35 | import numpy as np 36 | 37 | configs.define_common_flags() 38 | jax.config.parse_flags_with_absl() 39 | 40 | 41 | def create_videos(config, base_dir, out_dir, out_name, num_frames): 42 | """Creates videos out of the images saved to disk.""" 43 | names = [n for n in config.checkpoint_dir.split('/') if n] 44 | # Last two parts of checkpoint path are experiment name and scene name. 45 | exp_name, scene_name = names[-2:] 46 | video_prefix = f'{scene_name}_{exp_name}_{out_name}' 47 | 48 | zpad = max(3, len(str(num_frames - 1))) 49 | idx_to_str = lambda idx: str(idx).zfill(zpad) 50 | 51 | utils.makedirs(base_dir) 52 | 53 | # Load one example frame to get image shape and depth range. 54 | depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff') 55 | depth_frame = utils.load_img(depth_file) 56 | shape = depth_frame.shape 57 | p = config.render_dist_percentile 58 | distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p]) 59 | lo, hi = [config.render_dist_curve_fn(x) for x in distance_limits] 60 | print(f'Video shape is {shape[:2]}') 61 | 62 | video_kwargs = { 63 | 'shape': shape[:2], 64 | 'codec': 'h264', 65 | 'fps': config.render_video_fps, 66 | 'crf': config.render_video_crf, 67 | } 68 | 69 | for k in ['color', 'acc', 'distance_mean', 'distance_median']: 70 | video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4') 71 | input_format = 'gray' if k == 'acc' else 'rgb' 72 | file_ext = 'png' if k in ['color', 'normals'] else 'tiff' 73 | idx = 0 74 | file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}') 75 | if not utils.file_exists(file0): 76 | print(f'Images missing for tag {k}') 77 | continue 78 | print(f'Making video {video_file}...') 79 | with media.VideoWriter( 80 | video_file, **video_kwargs, input_format=input_format) as writer: 81 | for idx in range(num_frames): 82 | img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}') 83 | if not utils.file_exists(img_file): 84 | raise ValueError(f'Image file {img_file} does not exist.') 85 | img = utils.load_img(img_file) 86 | if k in ['color', 'normals']: 87 | img = img / 255. 88 | elif k.startswith('distance'): 89 | img = config.render_dist_curve_fn(img) 90 | img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1) 91 | img = cm.get_cmap('turbo')(img)[..., :3] 92 | 93 | frame = (np.clip(np.nan_to_num(img), 0., 1.)
* 255.).astype(np.uint8) 94 | writer.add_image(frame) 95 | idx += 1 96 | 97 | 98 | def main(unused_argv): 99 | 100 | config = configs.load_config(save_config=False) 101 | 102 | dataset = datasets.load_dataset('test', config.data_dir, config) 103 | 104 | key = random.PRNGKey(20200823) 105 | _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key) 106 | 107 | if config.rawnerf_mode: 108 | postprocess_fn = dataset.metadata['postprocess_fn'] 109 | else: 110 | postprocess_fn = lambda z: z 111 | 112 | state = checkpoints.restore_checkpoint(config.checkpoint_dir, state) 113 | step = int(state.step) 114 | print(f'Rendering checkpoint at step {step}.') 115 | 116 | out_name = 'path_renders' if config.render_path else 'test_preds' 117 | out_name = f'{out_name}_step_{step}' 118 | base_dir = config.render_dir 119 | if base_dir is None: 120 | base_dir = os.path.join(config.checkpoint_dir, 'render') 121 | out_dir = os.path.join(base_dir, out_name) 122 | if not utils.isdir(out_dir): 123 | utils.makedirs(out_dir) 124 | 125 | path_fn = lambda x: os.path.join(out_dir, x) 126 | 127 | # Ensure sufficient zero-padding of image indices in output filenames. 128 | zpad = max(3, len(str(dataset.size - 1))) 129 | idx_to_str = lambda idx: str(idx).zfill(zpad) 130 | 131 | if config.render_save_async: 132 | async_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) 133 | async_futures = [] 134 | def save_fn(fn, *args, **kwargs): 135 | async_futures.append(async_executor.submit(fn, *args, **kwargs)) 136 | else: 137 | def save_fn(fn, *args, **kwargs): 138 | fn(*args, **kwargs) 139 | 140 | for idx in range(dataset.size): 141 | if idx % config.render_num_jobs != config.render_job_id: 142 | continue 143 | # If current image and next image both already exist, skip ahead. 144 | idx_str = idx_to_str(idx) 145 | curr_file = path_fn(f'color_{idx_str}.png') 146 | next_idx_str = idx_to_str(idx + config.render_num_jobs) 147 | next_file = path_fn(f'color_{next_idx_str}.png') 148 | if utils.file_exists(curr_file) and utils.file_exists(next_file): 149 | print(f'Image {idx}/{dataset.size} already exists, skipping') 150 | continue 151 | print(f'Evaluating image {idx+1}/{dataset.size}') 152 | eval_start_time = time.time() 153 | rays = dataset.generate_ray_batch(idx).rays 154 | train_frac = 1. 155 | rendering = models.render_image( 156 | functools.partial(render_eval_pfn, state.params, train_frac), 157 | rays, None, config) 158 | print(f'Rendered in {(time.time() - eval_start_time):0.3f}s') 159 | 160 | if jax.host_id() != 0: # Only record via host 0. 161 | continue 162 | 163 | rendering['rgb'] = postprocess_fn(rendering['rgb']) 164 | 165 | save_fn( 166 | utils.save_img_u8, rendering['rgb'], path_fn(f'color_{idx_str}.png')) 167 | save_fn( 168 | utils.save_img_f32, rendering['distance_mean'], 169 | path_fn(f'distance_mean_{idx_str}.tiff')) 170 | save_fn( 171 | utils.save_img_f32, rendering['distance_median'], 172 | path_fn(f'distance_median_{idx_str}.tiff')) 173 | save_fn( 174 | utils.save_img_f32, rendering['acc'], path_fn(f'acc_{idx_str}.tiff')) 175 | 176 | if config.render_save_async: 177 | # Wait until all worker threads finish. 178 | async_executor.shutdown(wait=True) 179 | 180 | # This will ensure that exceptions in child threads are raised to the 181 | # main thread. 
182 | for future in async_futures: 183 | future.result() 184 | 185 | time.sleep(1) 186 | num_files = len(glob.glob(path_fn('acc_*.tiff'))) 187 | time.sleep(10) 188 | if jax.host_id() == 0 and num_files == dataset.size: 189 | print(f'All files found, creating videos (job {config.render_job_id}).') 190 | create_videos(config, base_dir, out_dir, out_name, dataset.size) 191 | 192 | # A hack that forces Jax to keep all TPUs alive until every TPU is finished. 193 | x = jax.numpy.ones([jax.local_device_count()]) 194 | x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)) 195 | print(x) 196 | 197 | 198 | if __name__ == '__main__': 199 | with gin.config_scope('eval'): # Use the same scope as eval.py 200 | app.run(main) 201 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | jax 3 | jaxlib 4 | opencv-python 5 | Pillow 6 | tensorboard 7 | tensorflow 8 | gin-config 9 | dm_pix 10 | rawpy 11 | mediapy 12 | lpips_jax 13 | chex 14 | optax 15 | ml-dtypes 16 | flax 17 | gdown 18 | torch 19 | torchvision 20 | torchaudio 21 | orbax-checkpoint==0.3.5 22 | matplotlib==3.8.4 -------------------------------------------------------------------------------- /scripts/download_on-the-go.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir Datasets 4 | wget https://cvg-data.inf.ethz.ch/on-the-go.zip 5 | unzip on-the-go.zip -d Datasets 6 | rm on-the-go.zip 7 | # Base directory containing the sequence directories 8 | base_dir="./Datasets/on-the-go" 9 | 10 | # Loop through each sequence directory in the base directory 11 | for seq_dir in "$base_dir"/*; do 12 | # Extract just the name of the directory (the sequence name) 13 | seq_name=$(basename "$seq_dir") 14 | echo "Processing sequence: $seq_name" 15 | 16 | # Determine the downsampling rate based on the sequence name 17 | if [ "$seq_name" = "arcdetriomphe" ] || [ "$seq_name" = "patio" ]; then 18 | rate=4 19 | else 20 | rate=8 21 | fi 22 | 23 | # Calculate percentage for resizing based on the downsample rate 24 | percentage=$(bc <<< "scale=2; 100 / $rate") 25 | 26 | # Directory names for images, defined relative to the base_dir 27 | original_images_dir="$seq_dir/images" 28 | downsampled_images_dir="$seq_dir/images_$rate" 29 | 30 | # Copy images to new directory before downsampling, handling both JPG and jpg 31 | cp -r "$original_images_dir" "$downsampled_images_dir" 32 | 33 | # Downsample images using mogrify for both JPG and jpg 34 | pushd "$downsampled_images_dir" 35 | ls | xargs -P 8 -I {} mogrify -resize ${percentage}% {} 36 | popd 37 | 38 | done 39 | -------------------------------------------------------------------------------- /scripts/eval_on-the-go.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -n 4 4 | #SBATCH --time=15:00:00 5 | #SBATCH --mem-per-cpu=20g 6 | #SBATCH --tmp=4000 # per node!! 
7 | #SBATCH --gpus=4090:4 8 | #SBATCH --gres=gpumem:20g 9 | #SBATCH --job-name=yard_high 10 | #SBATCH --output=slurm/yard_high.out 11 | #SBATCH --error=slurm/yard_high.err 12 | 13 | python -m eval \ 14 | --gin_configs=configs/360_dino.gin \ 15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \ 16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \ 17 | --gin_bindings="Config.eval_train = False" \ 18 | --gin_bindings="Config.factor = 8" \ 19 | 20 | -------------------------------------------------------------------------------- /scripts/eval_on-the-go_HD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -n 4 4 | #SBATCH --time=15:00:00 5 | #SBATCH --mem-per-cpu=20g 6 | #SBATCH --tmp=4000 # per node!! 7 | #SBATCH --gpus=4090:4 8 | #SBATCH --gres=gpumem:20g 9 | #SBATCH --job-name=yard_high 10 | #SBATCH --output=slurm/yard_high.out 11 | #SBATCH --error=slurm/yard_high.err 12 | 13 | python -m eval \ 14 | --gin_configs=configs/360_dino.gin \ 15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio'" \ 16 | --gin_bindings="Config.checkpoint_dir = 'output/patio/run_1/checkpoints'" \ 17 | --gin_bindings="Config.eval_train = False" \ 18 | --gin_bindings="Config.factor = 4" \ 19 | --gin_bindings="Config.H = 1080" \ 20 | --gin_bindings="Config.W = 1920" \ 21 | --gin_bindings="Config.factor = 4" \ 22 | --gin_bindings="Config.feat_rate = 2" \ 23 | 24 | 25 | -------------------------------------------------------------------------------- /scripts/feature_extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import torchvision.transforms as T 4 | import os 5 | # import hubconf 6 | from tqdm import tqdm 7 | import shutil 8 | import numpy as np 9 | 10 | if __name__ == '__main__': 11 | import argparse 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--seq', type=str, required=True) 14 | parser.add_argument('--rate', type=int, default=4) 15 | parser.add_argument('--H', type=int, default=3024) 16 | parser.add_argument('--W', type=int, default=4032) 17 | 18 | args = parser.parse_args() 19 | base_path = f"./Datasets/on-the-go/{args.seq}" 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 22 | dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') 23 | dinov2_vits14.to(device) 24 | extractor = dinov2_vits14 25 | 26 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 27 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 28 | RATE = args.rate 29 | RESIZE_H = (args.H // RATE) // 14 * 14 30 | RESIZE_W = (args.W // RATE) // 14 * 14 31 | 32 | if os.path.exists(os.path.join(base_path, f'features_{RATE}')): 33 | shutil.rmtree(os.path.join(base_path, f'features_{RATE}')) 34 | folder = os.path.join(base_path, 'images') 35 | files = os.listdir(folder) 36 | files = [os.path.join(folder, f) for f in files] 37 | features = [] 38 | for f in tqdm(files): 39 | img = Image.open(f).convert('RGB') 40 | transform = T.Compose([ 41 | T.Resize((RESIZE_H, RESIZE_W)), 42 | T.ToTensor(), 43 | T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 44 | ]) 45 | img = transform(img)[:3].unsqueeze(0) 46 | with torch.no_grad(): 47 | features_dict = extractor.forward_features(img.cuda()) 48 | features = features_dict['x_norm_patchtokens'].view(RESIZE_H // 14, RESIZE_W // 14, -1) 49 | img_type = f[-4:] 50 | save_path = f.replace(f'{img_type}', 
'.npy').replace('/images/', f'/features_{RATE}/') 51 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 52 | np.save(save_path, features.detach().cpu().numpy()) 53 | -------------------------------------------------------------------------------- /scripts/feature_extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Base directory containing the sequence directories 4 | base_dir="./Datasets/on-the-go" 5 | 6 | # Define the sequences that need special parameters 7 | special_seqs=("arcdetriomphe" "patio") 8 | 9 | # Loop through each sequence directory in the base directory 10 | for seq_dir in "$base_dir"/*; do 11 | # Extract just the name of the directory (the sequence name) 12 | seq_name=$(basename "$seq_dir") 13 | 14 | # Check if the sequence is one of the special cases 15 | if [[ " ${special_seqs[@]} " =~ " $seq_name " ]]; then 16 | # Run feature extraction with additional parameters for special sequences 17 | python scripts/feature_extract.py --seq "$seq_name" --H 1080 --W 1920 --rate 2 18 | else 19 | # Run feature extraction without additional parameters for all other sequences 20 | python scripts/feature_extract.py --seq "$seq_name" 21 | fi 22 | done 23 | -------------------------------------------------------------------------------- /scripts/local_colmap_and_resize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Set to 0 if you do not have a GPU. 18 | USE_GPU=1 19 | # Path to a directory `base/` with images in `base/images/`. 20 | DATASET_PATH=$1 21 | # Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye. 22 | CAMERA=${2:-OPENCV} 23 | 24 | 25 | # Run COLMAP. 26 | 27 | ### Feature extraction 28 | 29 | colmap feature_extractor \ 30 | --database_path "$DATASET_PATH"/database.db \ 31 | --image_path "$DATASET_PATH"/images \ 32 | --ImageReader.single_camera 1 \ 33 | --ImageReader.camera_model "$CAMERA" \ 34 | --SiftExtraction.use_gpu "$USE_GPU" 35 | 36 | 37 | ### Feature matching 38 | 39 | colmap exhaustive_matcher \ 40 | --database_path "$DATASET_PATH"/database.db \ 41 | --SiftMatching.use_gpu "$USE_GPU" 42 | 43 | ## Use if your scene has > 500 images 44 | ## Replace this path with your own local copy of the file. 45 | ## Download from: https://demuc.de/colmap/#download 46 | # VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin 47 | # colmap vocab_tree_matcher \ 48 | # --database_path "$DATASET_PATH"/database.db \ 49 | # --VocabTreeMatching.vocab_tree_path $VOCABTREE_PATH \ 50 | # --SiftMatching.use_gpu "$USE_GPU" 51 | 52 | 53 | ### Bundle adjustment 54 | 55 | # The default Mapper tolerance is unnecessarily large, 56 | # decreasing it speeds up bundle adjustment steps. 
57 | mkdir -p "$DATASET_PATH"/sparse 58 | colmap mapper \ 59 | --database_path "$DATASET_PATH"/database.db \ 60 | --image_path "$DATASET_PATH"/images \ 61 | --output_path "$DATASET_PATH"/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001 63 | 64 | 65 | ### Image undistortion 66 | 67 | ## Use this if you want to undistort your images into ideal pinhole intrinsics. 68 | # mkdir -p "$DATASET_PATH"/dense 69 | # colmap image_undistorter \ 70 | # --image_path "$DATASET_PATH"/images \ 71 | # --input_path "$DATASET_PATH"/sparse/0 \ 72 | # --output_path "$DATASET_PATH"/dense \ 73 | # --output_type COLMAP 74 | 75 | # Resize images. 76 | 77 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_2 78 | 79 | pushd "$DATASET_PATH"/images_2 80 | ls | xargs -P 8 -I {} mogrify -resize 50% {} 81 | popd 82 | 83 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_4 84 | 85 | pushd "$DATASET_PATH"/images_4 86 | ls | xargs -P 8 -I {} mogrify -resize 25% {} 87 | popd 88 | 89 | cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_8 90 | 91 | pushd "$DATASET_PATH"/images_8 92 | ls | xargs -P 8 -I {} mogrify -resize 12.5% {} 93 | popd 94 | -------------------------------------------------------------------------------- /scripts/render_on-the-go.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -n 4 4 | #SBATCH --time=15:00:00 5 | #SBATCH --mem-per-cpu=20g 6 | #SBATCH --tmp=4000 # per node!! 7 | #SBATCH --gpus=4090:4 8 | #SBATCH --gres=gpumem:20g 9 | #SBATCH --job-name=patio_high 10 | #SBATCH --output=slurm/patio_high.out 11 | #SBATCH --error=slurm/patio_high.err 12 | 13 | python -m render \ 14 | --gin_configs=configs/360_dino.gin \ 15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \ 16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \ 17 | --gin_bindings="Config.render_dir = 'output/patio_high/run_1/checkpoints'" \ 18 | --gin_bindings="Config.render_path = True" \ 19 | --gin_bindings="Config.render_path_frames = 160" \ 20 | --gin_bindings="Config.render_video_fps = 160" \ 21 | -------------------------------------------------------------------------------- /scripts/render_on-the-go_HD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -n 4 4 | #SBATCH --time=15:00:00 5 | #SBATCH --mem-per-cpu=20g 6 | #SBATCH --tmp=4000 # per node!! 
7 | #SBATCH --gpus=4090:4
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=patio_high
10 | #SBATCH --output=slurm/patio_high.out
11 | #SBATCH --error=slurm/patio_high.err
12 | 
13 | python -m render \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \
17 | --gin_bindings="Config.render_dir = 'output/patio_high/run_1/checkpoints'" \
18 | --gin_bindings="Config.render_path = True" \
19 | --gin_bindings="Config.render_path_frames = 160" \
20 | --gin_bindings="Config.render_video_fps = 160" \
21 | --gin_bindings="Config.H = 1080" \
22 | --gin_bindings="Config.W = 1920" \
23 | --gin_bindings="Config.factor = 4" \
24 | --gin_bindings="Config.feat_rate = 2"
25 | 
-------------------------------------------------------------------------------- /scripts/run_all_unit_tests.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2022 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | 
17 | python -m unittest tests.camera_utils_test
18 | python -m unittest tests.stepfun_test
19 | python -m unittest tests.coord_test
20 | python -m unittest tests.math_test
21 | python -m unittest tests.utils_test
22 | 
-------------------------------------------------------------------------------- /scripts/train_on-the-go.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #SBATCH -n 4
4 | #SBATCH --time=36:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=a100_80gb:1
8 | #SBATCH --gres=gpumem:20g
9 | #SBATCH --job-name=patio_high
10 | #SBATCH --output=slurm/patio_high.out
11 | #SBATCH --error=slurm/patio_high.err
12 | 
13 | python -m train \
14 | --gin_configs=configs/360_dino.gin \
15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio_high'" \
16 | --gin_bindings="Config.checkpoint_dir = 'output/patio_high/run_1/checkpoints'" \
17 | --gin_bindings="Config.patch_size = 32" \
18 | --gin_bindings="Config.dilate = 4" \
19 | --gin_bindings="Config.data_loss_type = 'on-the-go'"
-------------------------------------------------------------------------------- /scripts/train_on-the-go_HD.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #SBATCH -n 4
4 | #SBATCH --time=36:00:00
5 | #SBATCH --mem-per-cpu=20g
6 | #SBATCH --tmp=4000 # per node!!
7 | #SBATCH --gpus=a100_80gb:1 8 | #SBATCH --gres=gpumem:20g 9 | #SBATCH --job-name=patio 10 | #SBATCH --output=slurm/patio.out 11 | #SBATCH --error=slurm/patio.err 12 | 13 | python -m train \ 14 | --gin_configs=configs/360_dino.gin \ 15 | --gin_bindings="Config.data_dir = 'Datasets/on-the-go/patio'" \ 16 | --gin_bindings="Config.checkpoint_dir = 'output/patio/run_1/checkpoints'" \ 17 | --gin_bindings="Config.patch_size = 32" \ 18 | --gin_bindings="Config.dilate = 4" \ 19 | --gin_bindings="Config.data_loss_type = 'on-the-go'" \ 20 | --gin_bindings="Config.train_render_every = 5000" \ 21 | --gin_bindings="Config.H = 1080" \ 22 | --gin_bindings="Config.W = 1920" \ 23 | --gin_bindings="Config.factor = 4" \ 24 | --gin_bindings="Config.feat_rate = 2" \ 25 | -------------------------------------------------------------------------------- /tests/camera_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for camera_utils.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from internal import camera_utils 20 | from jax import random 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | class CameraUtilsTest(parameterized.TestCase): 26 | 27 | def test_convert_to_ndc(self): 28 | rng = random.PRNGKey(0) 29 | for _ in range(10): 30 | # Random pinhole camera intrinsics. 31 | key, rng = random.split(rng) 32 | focal, width, height = random.uniform(key, (3,), minval=100., maxval=200.) 33 | camtopix = camera_utils.intrinsic_matrix(focal, focal, width / 2., 34 | height / 2.) 35 | pixtocam = np.linalg.inv(camtopix) 36 | near = 1. 37 | 38 | # Random rays, pointing forward (negative z direction). 39 | num_rays = 1000 40 | key, rng = random.split(rng) 41 | origins = jnp.array([0., 0., 1.]) 42 | origins += random.uniform(key, (num_rays, 3), minval=-1., maxval=1.) 43 | directions = jnp.array([0., 0., -1.]) 44 | directions += random.uniform(key, (num_rays, 3), minval=-.5, maxval=.5) 45 | 46 | # Project world-space points along each ray into NDC space. 47 | t = jnp.linspace(0., 1., 10) 48 | pts_world = origins + t[:, None, None] * directions 49 | pts_ndc = jnp.stack([ 50 | -focal / (.5 * width) * pts_world[..., 0] / pts_world[..., 2], 51 | -focal / (.5 * height) * pts_world[..., 1] / pts_world[..., 2], 52 | 1. + 2. * near / pts_world[..., 2], 53 | ], 54 | axis=-1) 55 | 56 | # Get NDC space rays. 57 | origins_ndc, directions_ndc = camera_utils.convert_to_ndc( 58 | origins, directions, pixtocam, near) 59 | 60 | # Ensure that the NDC space points lie on the calculated rays. 
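      # This is verified by projecting (pts_ndc - origins_ndc) onto the unit
      # ray direction and checking that the projected points match pts_ndc.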
61 | directions_ndc_norm = jnp.linalg.norm( 62 | directions_ndc, axis=-1, keepdims=True) 63 | directions_ndc_unit = directions_ndc / directions_ndc_norm 64 | projection = ((pts_ndc - origins_ndc) * directions_ndc_unit).sum(axis=-1) 65 | pts_ndc_proj = origins_ndc + directions_ndc_unit * projection[..., None] 66 | 67 | # pts_ndc should be close to their projections pts_ndc_proj onto the rays. 68 | np.testing.assert_allclose(pts_ndc, pts_ndc_proj, atol=1e-5, rtol=1e-5) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /tests/coord_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit tests for coord.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from internal import coord 20 | from internal import math 21 | import jax 22 | from jax import random 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | def sample_covariance(rng, batch_size, num_dims): 28 | """Sample a random covariance matrix.""" 29 | half_cov = jax.random.normal(rng, [batch_size] + [num_dims] * 2) 30 | cov = math.matmul(half_cov, jnp.moveaxis(half_cov, -1, -2)) 31 | return cov 32 | 33 | 34 | def stable_pos_enc(x, n): 35 | """A stable pos_enc for very high degrees, courtesy of Sameer Agarwal.""" 36 | sin_x = np.sin(x) 37 | cos_x = np.cos(x) 38 | output = [] 39 | rotmat = np.array([[cos_x, -sin_x], [sin_x, cos_x]], dtype='double') 40 | for _ in range(n): 41 | output.append(rotmat[::-1, 0, :]) 42 | rotmat = np.einsum('ijn,jkn->ikn', rotmat, rotmat) 43 | return np.reshape(np.transpose(np.stack(output, 0), [2, 1, 0]), [-1, 2 * n]) 44 | 45 | 46 | class CoordTest(parameterized.TestCase): 47 | 48 | def test_stable_pos_enc(self): 49 | """Test that the stable posenc implementation works on multiples of pi/2.""" 50 | n = 10 51 | x = np.linspace(-np.pi, np.pi, 5) 52 | z = stable_pos_enc(x, n).reshape([-1, 2, n]) 53 | z0_true = np.zeros_like(z[:, 0, :]) 54 | z1_true = np.ones_like(z[:, 1, :]) 55 | z0_true[:, 0] = [0, -1, 0, 1, 0] 56 | z1_true[:, 0] = [-1, 0, 1, 0, -1] 57 | z1_true[:, 1] = [1, -1, 1, -1, 1] 58 | z_true = np.stack([z0_true, z1_true], axis=1) 59 | np.testing.assert_allclose(z, z_true, atol=1e-10) 60 | 61 | def test_contract_matches_special_case(self): 62 | """Test the math for Figure 2 of https://arxiv.org/abs/2111.12077.""" 63 | n = 10 64 | _, s_to_t = coord.construct_ray_warps(jnp.reciprocal, 1, jnp.inf) 65 | s = jnp.linspace(0, 1 - jnp.finfo(jnp.float32).eps, n + 1) 66 | tc = coord.contract(s_to_t(s)[:, None])[:, 0] 67 | delta_tc = tc[1:] - tc[:-1] 68 | np.testing.assert_allclose( 69 | delta_tc, np.full_like(delta_tc, 1 / n), atol=1E-5, rtol=1E-5) 70 | 71 | def test_contract_is_bounded(self): 72 | n, d = 10000, 3 73 | rng = random.PRNGKey(0) 74 | key0, key1, rng = random.split(rng, 3) 75 | x = 
jnp.where(random.bernoulli(key0, shape=[n, d]), 1, -1) * jnp.exp(
76 |         random.uniform(key1, [n, d], minval=-10, maxval=10))
77 |     y = coord.contract(x)
78 |     self.assertLessEqual(jnp.max(y), 2)
79 | 
80 |   def test_contract_is_noop_when_norm_is_leq_one(self):
81 |     n, d = 10000, 3
82 |     rng = random.PRNGKey(0)
83 |     key, rng = random.split(rng)
84 |     x = random.normal(key, shape=[n, d])
85 |     xc = x / jnp.maximum(1, jnp.linalg.norm(x, axis=-1, keepdims=True))
86 | 
87 |     # Sanity check on the test itself.
88 |     assert jnp.abs(jnp.max(jnp.linalg.norm(xc, axis=-1)) - 1) < 1e-6
89 | 
90 |     yc = coord.contract(xc)
91 |     np.testing.assert_allclose(xc, yc, atol=1E-5, rtol=1E-5)
92 | 
93 |   def test_contract_gradients_are_finite(self):
94 |     # Construct x such that we probe x == 0, where things are unstable.
95 |     x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)
96 |     grad = jax.grad(lambda x: jnp.sum(coord.contract(x)))(x)
97 |     self.assertTrue(jnp.all(jnp.isfinite(grad)))
98 | 
99 |   def test_inv_contract_gradients_are_finite(self):
100 |     z = jnp.stack(jnp.meshgrid(*[jnp.linspace(-2, 2, 21)] * 2), axis=-1)
101 |     z = z.reshape([-1, 2])
102 |     z = z[jnp.sum(z**2, axis=-1) < 2, :]
103 |     grad = jax.grad(lambda z: jnp.sum(coord.inv_contract(z)))(z)
104 |     self.assertTrue(jnp.all(jnp.isfinite(grad)))
105 | 
106 |   def test_inv_contract_inverts_contract(self):
107 |     """Do a round-trip from metric space to contracted space and back."""
108 |     x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)
109 |     x_recon = coord.inv_contract(coord.contract(x))
110 |     np.testing.assert_allclose(x, x_recon, atol=1E-5, rtol=1E-5)
111 | 
112 |   @parameterized.named_parameters(
113 |       ('05_1e-5', 5, 1e-5),
114 |       ('10_1e-4', 10, 1e-4),
115 |       ('15_0.005', 15, 0.005),
116 |       ('20_0.2', 20, 0.2),  # At high degrees, our implementation is unstable.
117 |       ('25_2', 25, 2),  # 2 is the maximum possible error.
118 |       ('30_2', 30, 2),
119 |   )
120 |   def test_pos_enc(self, n, tol):
121 |     """Test pos_enc against a stable recursive implementation."""
122 |     x = np.linspace(-np.pi, np.pi, 10001)
123 |     z = coord.pos_enc(x[:, None], 0, n, append_identity=False)
124 |     z_stable = stable_pos_enc(x, n)
125 |     max_err = np.max(np.abs(z - z_stable))
126 |     print(f'PE of degree {n} has a maximum error of {max_err}')
127 |     self.assertLess(max_err, tol)
128 | 
129 |   def test_pos_enc_matches_integrated(self):
130 |     """Integrated positional encoding with a variance of zero must be pos_enc."""
131 |     min_deg = 0
132 |     max_deg = 10
133 |     # A zero-variance IPE evaluated over a dense grid of scalar inputs
134 |     # must reduce to the plain positional encoding of the same inputs,
135 |     # which is what the comparison below checks.
136 |     x = np.linspace(-jnp.pi, jnp.pi, 10000)
137 |     z_ipe = coord.integrated_pos_enc(x, jnp.zeros_like(x), min_deg, max_deg)
138 |     z_pe = coord.pos_enc(x, min_deg, max_deg, append_identity=False)
139 |     # We're using a pretty wide tolerance because IPE uses safe_sin().
140 |     np.testing.assert_allclose(z_pe, z_ipe, atol=1e-4)
141 | 
142 |   def test_track_linearize(self):
143 |     rng = random.PRNGKey(0)
144 |     batch_size = 20
145 |     for _ in range(30):
146 |       # Construct some random Gaussians with dimensionalities in [1, 10].
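      # For an affine f(x) = A x + b, linearization is exact: the Gaussian
      # mean maps to A mu + b and the covariance to A cov A^T, which is the
      # ground truth that track_linearize is checked against below.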
147 | key, rng = random.split(rng) 148 | in_dims = random.randint(key, (), 1, 10) 149 | key, rng = random.split(rng) 150 | mean = jax.random.normal(key, [batch_size, in_dims]) 151 | key, rng = random.split(rng) 152 | cov = sample_covariance(key, batch_size, in_dims) 153 | key, rng = random.split(rng) 154 | out_dims = random.randint(key, (), 1, 10) 155 | 156 | # Construct a random affine transformation. 157 | key, rng = random.split(rng) 158 | a_mat = jax.random.normal(key, [out_dims, in_dims]) 159 | key, rng = random.split(rng) 160 | b = jax.random.normal(key, [out_dims]) 161 | 162 | def fn(x): 163 | x_vec = x.reshape([-1, x.shape[-1]]) 164 | y_vec = jax.vmap(lambda z: math.matmul(a_mat, z))(x_vec) + b # pylint:disable=cell-var-from-loop 165 | y = y_vec.reshape(list(x.shape[:-1]) + [y_vec.shape[-1]]) 166 | return y 167 | 168 | # Apply the affine function to the Gaussians. 169 | fn_mean_true = fn(mean) 170 | fn_cov_true = math.matmul(math.matmul(a_mat, cov), a_mat.T) 171 | 172 | # Tracking the Gaussians through a linearized function of a linear 173 | # operator should be the same. 174 | fn_mean, fn_cov = coord.track_linearize(fn, mean, cov) 175 | np.testing.assert_allclose(fn_mean, fn_mean_true, atol=1E-5, rtol=1E-5) 176 | np.testing.assert_allclose(fn_cov, fn_cov_true, atol=1e-5, rtol=1e-5) 177 | 178 | @parameterized.named_parameters(('reciprocal', jnp.reciprocal), 179 | ('log', jnp.log), ('sqrt', jnp.sqrt)) 180 | def test_construct_ray_warps_extents(self, fn): 181 | n = 100 182 | rng = random.PRNGKey(0) 183 | key, rng = random.split(rng) 184 | t_near = jnp.exp(jax.random.normal(key, [n])) 185 | key, rng = random.split(rng) 186 | t_far = t_near + jnp.exp(jax.random.normal(key, [n])) 187 | 188 | t_to_s, s_to_t = coord.construct_ray_warps(fn, t_near, t_far) 189 | 190 | np.testing.assert_allclose( 191 | t_to_s(t_near), jnp.zeros_like(t_near), atol=1E-5, rtol=1E-5) 192 | np.testing.assert_allclose( 193 | t_to_s(t_far), jnp.ones_like(t_far), atol=1E-5, rtol=1E-5) 194 | np.testing.assert_allclose( 195 | s_to_t(jnp.zeros_like(t_near)), t_near, atol=1E-5, rtol=1E-5) 196 | np.testing.assert_allclose( 197 | s_to_t(jnp.ones_like(t_near)), t_far, atol=1E-5, rtol=1E-5) 198 | 199 | def test_construct_ray_warps_special_reciprocal(self): 200 | """Test fn=1/x against its closed form.""" 201 | n = 100 202 | rng = random.PRNGKey(0) 203 | key, rng = random.split(rng) 204 | t_near = jnp.exp(jax.random.normal(key, [n])) 205 | key, rng = random.split(rng) 206 | t_far = t_near + jnp.exp(jax.random.normal(key, [n])) 207 | 208 | key, rng = random.split(rng) 209 | u = jax.random.uniform(key, [n]) 210 | t = t_near * (1 - u) + t_far * u 211 | key, rng = random.split(rng) 212 | s = jax.random.uniform(key, [n]) 213 | 214 | t_to_s, s_to_t = coord.construct_ray_warps(jnp.reciprocal, t_near, t_far) 215 | 216 | # Special cases for fn=reciprocal. 
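    # construct_ray_warps normalizes fn between the ray extents,
    #   s = (fn(t) - fn(t_near)) / (fn(t_far) - fn(t_near)),
    # so for fn(x) = 1/x, solving 1/t = s / t_far + (1 - s) / t_near for
    # t (and conversely for s) gives the closed forms below.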
217 | s_to_t_ref = lambda s: 1 / (s / t_far + (1 - s) / t_near) 218 | t_to_s_ref = lambda t: (t_far * (t - t_near)) / (t * (t_far - t_near)) 219 | 220 | np.testing.assert_allclose(t_to_s(t), t_to_s_ref(t), atol=1E-5, rtol=1E-5) 221 | np.testing.assert_allclose(s_to_t(s), s_to_t_ref(s), atol=1E-5, rtol=1E-5) 222 | 223 | def test_expected_sin(self): 224 | normal_samples = random.normal(random.PRNGKey(0), (10000,)) 225 | for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]: 226 | sin_mu = coord.expected_sin(mu, var) 227 | x = jnp.sin(jnp.sqrt(var) * normal_samples + mu) 228 | np.testing.assert_allclose(sin_mu, jnp.mean(x), atol=1e-2) 229 | 230 | def test_integrated_pos_enc(self): 231 | num_dims = 2 # The number of input dimensions. 232 | min_deg = 0 # Must be 0 for this test to work. 233 | max_deg = 4 234 | num_samples = 100000 235 | rng = random.PRNGKey(0) 236 | for _ in range(5): 237 | # Generate a coordinate's mean and covariance matrix. 238 | key, rng = random.split(rng) 239 | mean = random.normal(key, (2,)) 240 | key, rng = random.split(rng) 241 | half_cov = jax.random.normal(key, [num_dims] * 2) 242 | cov = half_cov @ half_cov.T 243 | var = jnp.diag(cov) 244 | # Generate an IPE. 245 | enc = coord.integrated_pos_enc( 246 | mean, 247 | var, 248 | min_deg, 249 | max_deg, 250 | ) 251 | 252 | # Draw samples, encode them, and take their mean. 253 | key, rng = random.split(rng) 254 | samples = random.multivariate_normal(key, mean, cov, [num_samples]) 255 | assert min_deg == 0 256 | enc_samples = np.concatenate( 257 | [stable_pos_enc(x, max_deg) for x in tuple(samples.T)], axis=-1) 258 | # Correct for a different dimension ordering in stable_pos_enc. 259 | enc_gt = jnp.mean(enc_samples, 0) 260 | enc_gt = enc_gt.reshape([num_dims, max_deg * 2]).T.reshape([-1]) 261 | np.testing.assert_allclose(enc, enc_gt, rtol=1e-2, atol=1e-2) 262 | 263 | 264 | if __name__ == '__main__': 265 | absltest.main() 266 | -------------------------------------------------------------------------------- /tests/math_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Unit tests for math.""" 16 | 17 | import functools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from internal import math 22 | import jax 23 | from jax import random 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | def safe_trig_harness(fn, max_exp): 29 | x = 10**np.linspace(-30, max_exp, 10000) 30 | x = np.concatenate([-x[::-1], np.array([0]), x]) 31 | y_true = getattr(np, fn)(x) 32 | y = getattr(math, 'safe_' + fn)(x) 33 | return y_true, y 34 | 35 | 36 | class MathTest(parameterized.TestCase): 37 | 38 | def test_sin(self): 39 | """In [-1e10, 1e10] safe_sin and safe_cos are accurate.""" 40 | for fn in ['sin', 'cos']: 41 | y_true, y = safe_trig_harness(fn, 10) 42 | self.assertLess(jnp.max(jnp.abs(y - y_true)), 1e-4) 43 | self.assertFalse(jnp.any(jnp.isnan(y))) 44 | # Beyond that range it's less accurate but we just don't want it to be NaN. 45 | for fn in ['sin', 'cos']: 46 | y_true, y = safe_trig_harness(fn, 60) 47 | self.assertFalse(jnp.any(jnp.isnan(y))) 48 | 49 | def test_safe_exp_correct(self): 50 | """math.safe_exp() should match np.exp() for not-huge values.""" 51 | x = jnp.linspace(-80, 80, 10001) 52 | y = math.safe_exp(x) 53 | g = jax.vmap(jax.grad(math.safe_exp))(x) 54 | yg_true = jnp.exp(x) 55 | np.testing.assert_allclose(y, yg_true) 56 | np.testing.assert_allclose(g, yg_true) 57 | 58 | def test_safe_exp_finite(self): 59 | """math.safe_exp() behaves reasonably for huge values.""" 60 | x = jnp.linspace(-100000, 100000, 10001) 61 | y = math.safe_exp(x) 62 | g = jax.vmap(jax.grad(math.safe_exp))(x) 63 | # `y` and `g` should both always be finite. 64 | self.assertTrue(jnp.all(jnp.isfinite(y))) 65 | self.assertTrue(jnp.all(jnp.isfinite(g))) 66 | # The derivative of exp() should be exp(). 67 | np.testing.assert_allclose(y, g) 68 | # safe_exp()'s output and gradient should be monotonic. 69 | self.assertTrue(jnp.all(y[1:] >= y[:-1])) 70 | self.assertTrue(jnp.all(g[1:] >= g[:-1])) 71 | 72 | def test_learning_rate_decay(self): 73 | rng = random.PRNGKey(0) 74 | for _ in range(10): 75 | key, rng = random.split(rng) 76 | lr_init = jnp.exp(random.normal(key) - 3) 77 | key, rng = random.split(rng) 78 | lr_final = lr_init * jnp.exp(random.normal(key) - 5) 79 | key, rng = random.split(rng) 80 | max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key)))) 81 | 82 | lr_fn = functools.partial( 83 | math.learning_rate_decay, 84 | lr_init=lr_init, 85 | lr_final=lr_final, 86 | max_steps=max_steps) 87 | 88 | # Test that the rate at the beginning is the initial rate. 89 | np.testing.assert_allclose(lr_fn(0), lr_init, atol=1E-5, rtol=1E-5) 90 | 91 | # Test that the rate at the end is the final rate. 92 | np.testing.assert_allclose( 93 | lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5) 94 | 95 | # Test that the rate at the middle is the geometric mean of the two rates. 
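      # (The decay schedule interpolates log(lr) linearly in the step count,
      # so the midpoint rate is exp((log(lr_init) + log(lr_final)) / 2),
      # i.e. the geometric mean sqrt(lr_init * lr_final).)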
96 | np.testing.assert_allclose( 97 | lr_fn(max_steps / 2), 98 | jnp.sqrt(lr_init * lr_final), 99 | atol=1E-5, 100 | rtol=1E-5) 101 | 102 | # Test that the rate past the end is the final rate 103 | np.testing.assert_allclose( 104 | lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5) 105 | 106 | def test_delayed_learning_rate_decay(self): 107 | rng = random.PRNGKey(0) 108 | for _ in range(10): 109 | key, rng = random.split(rng) 110 | lr_init = jnp.exp(random.normal(key) - 3) 111 | key, rng = random.split(rng) 112 | lr_final = lr_init * jnp.exp(random.normal(key) - 5) 113 | key, rng = random.split(rng) 114 | max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(key)))) 115 | key, rng = random.split(rng) 116 | lr_delay_steps = int( 117 | random.uniform(key, minval=0.1, maxval=0.4) * max_steps) 118 | key, rng = random.split(rng) 119 | lr_delay_mult = jnp.exp(random.normal(key) - 3) 120 | 121 | lr_fn = functools.partial( 122 | math.learning_rate_decay, 123 | lr_init=lr_init, 124 | lr_final=lr_final, 125 | max_steps=max_steps, 126 | lr_delay_steps=lr_delay_steps, 127 | lr_delay_mult=lr_delay_mult) 128 | 129 | # Test that the rate at the beginning is the delayed initial rate. 130 | np.testing.assert_allclose( 131 | lr_fn(0), lr_delay_mult * lr_init, atol=1E-5, rtol=1E-5) 132 | 133 | # Test that the rate at the end is the final rate. 134 | np.testing.assert_allclose( 135 | lr_fn(max_steps), lr_final, atol=1E-5, rtol=1E-5) 136 | 137 | # Test that the rate at after the delay is over is the usual rate. 138 | np.testing.assert_allclose( 139 | lr_fn(lr_delay_steps), 140 | math.learning_rate_decay(lr_delay_steps, lr_init, lr_final, 141 | max_steps), 142 | atol=1E-5, 143 | rtol=1E-5) 144 | 145 | # Test that the rate at the middle is the geometric mean of the two rates. 146 | np.testing.assert_allclose( 147 | lr_fn(max_steps / 2), 148 | jnp.sqrt(lr_init * lr_final), 149 | atol=1E-5, 150 | rtol=1E-5) 151 | 152 | # Test that the rate past the end is the final rate 153 | np.testing.assert_allclose( 154 | lr_fn(max_steps + 100), lr_final, atol=1E-5, rtol=1E-5) 155 | 156 | @parameterized.named_parameters(('', False), ('sort', True)) 157 | def test_interp(self, sort): 158 | n, d0, d1 = 100, 10, 20 159 | rng = random.PRNGKey(0) 160 | 161 | key, rng = random.split(rng) 162 | x = random.normal(key, [n, d0]) 163 | 164 | key, rng = random.split(rng) 165 | xp = random.normal(key, [n, d1]) 166 | 167 | key, rng = random.split(rng) 168 | fp = random.normal(key, [n, d1]) 169 | 170 | if sort: 171 | xp = jnp.sort(xp, axis=-1) 172 | fp = jnp.sort(fp, axis=-1) 173 | z = math.sorted_interp(x, xp, fp) 174 | else: 175 | z = math.interp(x, xp, fp) 176 | 177 | z_true = jnp.stack([jnp.interp(x[i], xp[i], fp[i]) for i in range(n)]) 178 | np.testing.assert_allclose(z, z_true, atol=1e-5, rtol=1e-5) 179 | 180 | 181 | if __name__ == '__main__': 182 | absltest.main() 183 | -------------------------------------------------------------------------------- /tests/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | from absl.testing import absltest 18 | 19 | from internal import utils 20 | 21 | 22 | class UtilsTest(absltest.TestCase): 23 | 24 | def test_dummy_rays(self): 25 | """Ensures that the dummy Rays object is correctly initialized.""" 26 | rays = utils.dummy_rays() 27 | self.assertEqual(rays.origins.shape[-1], 3) 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Training script.""" 16 | 17 | import functools 18 | import gc 19 | import time 20 | 21 | from absl import app 22 | import flax 23 | from flax.metrics import tensorboard 24 | from flax.training import checkpoints 25 | import gin 26 | from internal import configs 27 | from internal import datasets 28 | from internal import image 29 | from internal import models 30 | from internal import train_utils 31 | from internal import utils 32 | from internal import vis 33 | import jax 34 | from jax import random 35 | import jax.numpy as jnp 36 | import numpy as np 37 | 38 | configs.define_common_flags() 39 | jax.config.parse_flags_with_absl() 40 | 41 | TIME_PRECISION = 1000 # Internally represent integer times in milliseconds. 42 | 43 | 44 | def main(unused_argv): 45 | rng = random.PRNGKey(20200823) 46 | # Shift the numpy random seed by host_id() to shuffle data loaded by different 47 | # hosts. 
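  # (jax.host_id() is 0 for single-host jobs, so the offset only matters
  # when training spans multiple hosts.)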
48 |   np.random.seed(20201473 + jax.host_id())
49 | 
50 |   config = configs.load_config()
51 | 
52 |   if config.batch_size % jax.device_count() != 0:
53 |     raise ValueError('Batch size must be divisible by the number of devices.')
54 | 
55 |   dataset = datasets.load_dataset('train', config.data_dir, config)
56 |   test_dataset = datasets.load_dataset('test', config.data_dir, config)
57 |   np_to_jax = lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x
58 |   cameras = tuple(np_to_jax(x) for x in dataset.cameras)
59 | 
60 |   if config.rawnerf_mode:
61 |     postprocess_fn = test_dataset.metadata['postprocess_fn']
62 |   else:
63 |     postprocess_fn = lambda z, _=None: z
64 | 
65 |   rng, key = random.split(rng)
66 |   setup = train_utils.setup_model(config, key, dataset=dataset)
67 |   model, state, render_eval_pfn, train_pstep, lr_fn = setup
68 | 
69 |   variables = state.params
70 |   num_params = jax.tree_util.tree_reduce(
71 |       lambda x, y: x + jnp.prod(jnp.array(y.shape)), variables, initializer=0)
72 |   print(f'Number of parameters being optimized: {num_params}')
73 | 
74 |   if (dataset.size > model.num_glo_embeddings and model.num_glo_features > 0):
75 |     raise ValueError(f'Number of glo embeddings {model.num_glo_embeddings} '
76 |                      f'must be at least equal to number of train images '
77 |                      f'{dataset.size}')
78 | 
79 |   metric_harness = image.MetricHarness()
80 | 
81 |   if not utils.isdir(config.checkpoint_dir):
82 |     utils.makedirs(config.checkpoint_dir)
83 |   state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
84 |   # Resume training at the step of the last checkpoint.
85 |   init_step = state.step + 1
86 |   state = flax.jax_utils.replicate(state)
87 | 
88 |   if jax.host_id() == 0:
89 |     summary_writer = tensorboard.SummaryWriter(config.checkpoint_dir)
90 |     if config.rawnerf_mode:
91 |       for name, data in zip(['train', 'test'], [dataset, test_dataset]):
92 |         # Log shutter speed metadata in TensorBoard for debug purposes.
93 |         for key in ['exposure_idx', 'exposure_values', 'unique_shutters']:
94 |           summary_writer.text(f'{name}_{key}', str(data.metadata[key]), 0)
95 | 
96 |   # Prefetch buffer size = 3 x batch size.
97 |   pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
98 |   rng = rng + jax.host_id()  # Make random seed separate across hosts.
99 |   rngs = random.split(rng, jax.local_device_count())  # For pmapping RNG keys.
100 |   gc.disable()  # Disable automatic garbage collection for efficiency.
101 |   total_time = 0
102 |   total_steps = 0
103 |   reset_stats = True
104 |   if config.early_exit_steps is not None:
105 |     num_steps = config.early_exit_steps
106 |   else:
107 |     num_steps = config.max_steps
108 |   loss_threshold = 1.0
109 |   for step, batch in zip(range(init_step, num_steps + 1), pdataset):
110 |     if reset_stats and (jax.host_id() == 0):
111 |       stats_buffer = []
112 |       train_start_time = time.time()
113 |       reset_stats = False
114 | 
115 |     learning_rate = lr_fn(step)
116 |     train_frac = jnp.clip((step - 1) / (config.max_steps - 1), 0, 1)
117 | 
118 |     state, stats, rngs = train_pstep(
119 |         rngs,
120 |         state,
121 |         batch,
122 |         cameras,
123 |         train_frac,
124 |         loss_threshold,
125 |     )
126 |     if config.enable_robustnerf_loss:
127 |       loss_threshold = jnp.mean(stats['loss_threshold'])
128 | 
129 |     if step % config.gc_every == 0:
130 |       gc.collect()  # Garbage collect manually, since automatic GC is disabled above.
131 | 
132 |     # Log training summaries. This is put behind a host_id check because in
133 |     # multi-host evaluation, all hosts need to run inference even though we
134 |     # only use host 0 to record results.
135 | if jax.host_id() == 0: 136 | stats = flax.jax_utils.unreplicate(stats) 137 | 138 | stats_buffer.append(stats) 139 | 140 | if step == init_step or step % config.print_every == 0: 141 | elapsed_time = time.time() - train_start_time 142 | steps_per_sec = config.print_every / elapsed_time 143 | rays_per_sec = config.batch_size * steps_per_sec 144 | 145 | # A robust approximation of total training time, in case of pre-emption. 146 | total_time += int(round(TIME_PRECISION * elapsed_time)) 147 | total_steps += config.print_every 148 | approx_total_time = int(round(step * total_time / total_steps)) 149 | 150 | # Transpose and stack stats_buffer along axis 0. 151 | fs = [flax.traverse_util.flatten_dict(s, sep='/') for s in stats_buffer] 152 | stats_stacked = {k: jnp.stack([f[k] for f in fs]) for k in fs[0].keys()} 153 | 154 | # Split every statistic that isn't a vector into a set of statistics. 155 | stats_split = {} 156 | for k, v in stats_stacked.items(): 157 | if v.ndim not in [1, 2] and v.shape[0] != len(stats_buffer): 158 | raise ValueError('statistics must be of size [n], or [n, k].') 159 | if v.ndim == 1: 160 | stats_split[k] = v 161 | elif v.ndim == 2: 162 | for i, vi in enumerate(tuple(v.T)): 163 | stats_split[f'{k}/{i}'] = vi 164 | 165 | # Summarize the entire histogram of each statistic. 166 | for k, v in stats_split.items(): 167 | summary_writer.histogram('train_' + k, v, step) 168 | 169 | # Take the mean and max of each statistic since the last summary. 170 | avg_stats = {k: jnp.mean(v) for k, v in stats_split.items()} 171 | max_stats = {k: jnp.max(v) for k, v in stats_split.items()} 172 | 173 | summ_fn = lambda s, v: summary_writer.scalar(s, v, step) # pylint:disable=cell-var-from-loop 174 | 175 | # Summarize the mean and max of each statistic. 176 | for k, v in avg_stats.items(): 177 | summ_fn(f'train_avg_{k}', v) 178 | for k, v in max_stats.items(): 179 | summ_fn(f'train_max_{k}', v) 180 | 181 | summ_fn('train_num_params', num_params) 182 | summ_fn('train_learning_rate', learning_rate) 183 | summ_fn('train_steps_per_sec', steps_per_sec) 184 | summ_fn('train_rays_per_sec', rays_per_sec) 185 | 186 | summary_writer.scalar('train_avg_psnr_timed', avg_stats['psnr'], 187 | total_time // TIME_PRECISION) 188 | summary_writer.scalar('train_avg_psnr_timed_approx', avg_stats['psnr'], 189 | approx_total_time // TIME_PRECISION) 190 | 191 | if dataset.metadata is not None and model.learned_exposure_scaling: 192 | params = state.params['params'] 193 | scalings = params['exposure_scaling_offsets']['embedding'][0] 194 | num_shutter_speeds = dataset.metadata['unique_shutters'].shape[0] 195 | for i_s in range(num_shutter_speeds): 196 | for j_s, value in enumerate(scalings[i_s]): 197 | summary_name = f'exposure/scaling_{i_s}_{j_s}' 198 | summary_writer.scalar(summary_name, value, step) 199 | 200 | precision = int(np.ceil(np.log10(config.max_steps))) + 1 201 | avg_loss = avg_stats['loss'] 202 | avg_psnr = avg_stats['psnr'] 203 | str_losses = { # Grab each "losses_{x}" field and print it as "x[:4]". 204 | k[7:11]: (f'{v:0.5f}' if v >= 1e-4 and v < 10 else f'{v:0.1e}') 205 | for k, v in avg_stats.items() 206 | if k.startswith('losses/') 207 | } 208 | print(f'{step:{precision}d}' + f'/{config.max_steps:d}: ' + 209 | f'loss={avg_loss:0.5f}, ' + f'psnr={avg_psnr:6.3f}, ' + 210 | f'lr={learning_rate:0.2e} | ' + 211 | ', '.join([f'{k}={s}' for k, s in str_losses.items()]) + 212 | f', {rays_per_sec:0.0f} r/s') 213 | 214 | # Reset everything we are tracking between summarizations. 
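        # (stats_buffer and the wall-clock timer are re-created at the top of
        # the next loop iteration once reset_stats is seen to be True.)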
215 | reset_stats = True 216 | 217 | if step == 1 or step % config.checkpoint_every == 0: 218 | state_to_save = jax.device_get( 219 | flax.jax_utils.unreplicate(state)) 220 | checkpoints.save_checkpoint( 221 | config.checkpoint_dir, state_to_save, int(step), keep=100) 222 | 223 | # Test-set evaluation. 224 | if config.train_render_every > 0 and step % config.train_render_every == 0: 225 | # We reuse the same random number generator from the optimization step 226 | # here on purpose so that the visualization matches what happened in 227 | # training. 228 | eval_start_time = time.time() 229 | eval_variables = flax.jax_utils.unreplicate(state).params 230 | test_case = next(test_dataset) 231 | rendering = models.render_image( 232 | functools.partial(render_eval_pfn, eval_variables, train_frac), 233 | test_case.rays, rngs[0], config) 234 | # Log eval summaries on host 0. 235 | if jax.host_id() == 0: 236 | eval_time = time.time() - eval_start_time 237 | num_rays = jnp.prod(jnp.array(test_case.rays.directions.shape[:-1])) 238 | rays_per_sec = num_rays / eval_time 239 | summary_writer.scalar('test_rays_per_sec', rays_per_sec, step) 240 | print(f'Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec') 241 | 242 | metric_start_time = time.time() 243 | metric = metric_harness( 244 | postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb)) 245 | print(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s') 246 | for name, val in metric.items(): 247 | if not np.isnan(val): 248 | print(f'{name} = {val:.4f}') 249 | summary_writer.scalar('train_metrics/' + name, val, step) 250 | 251 | if config.vis_decimate > 1: 252 | d = config.vis_decimate 253 | decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d] 254 | else: 255 | decimate_fn = lambda x: x 256 | rendering = jax.tree_util.tree_map(decimate_fn, rendering) 257 | test_case = jax.tree_util.tree_map(decimate_fn, test_case) 258 | vis_start_time = time.time() 259 | vis_suite = vis.visualize_suite(rendering, test_case.rays) 260 | print(f'Visualized in {(time.time() - vis_start_time):0.3f}s') 261 | if config.rawnerf_mode: 262 | # Unprocess raw output. 263 | vis_suite['color_raw'] = rendering['rgb'] 264 | # Autoexposed colors. 265 | vis_suite['color_auto'] = postprocess_fn(rendering['rgb'], None) 266 | summary_writer.image('test_true_auto', 267 | postprocess_fn(test_case.rgb, None), step) 268 | # Exposure sweep colors. 269 | exposures = test_dataset.metadata['exposure_levels'] 270 | for p, x in list(exposures.items()): 271 | vis_suite[f'color/{p}'] = postprocess_fn(rendering['rgb'], x) 272 | summary_writer.image(f'test_true_color/{p}', 273 | postprocess_fn(test_case.rgb, x), step) 274 | summary_writer.image('test_true_color', test_case.rgb, step) 275 | for k, v in vis_suite.items(): 276 | summary_writer.image('test_output_' + k, v, step) 277 | 278 | if jax.host_id() == 0 and config.max_steps % config.checkpoint_every != 0: 279 | state = jax.device_get(flax.jax_utils.unreplicate(state)) 280 | checkpoints.save_checkpoint( 281 | config.checkpoint_dir, state, int(config.max_steps), keep=100) 282 | 283 | 284 | if __name__ == '__main__': 285 | with gin.config_scope('train'): 286 | app.run(main) 287 | --------------------------------------------------------------------------------