├── CONTRIBUTING.md
├── LICENSE
├── OWNERS
├── README.md
├── configs
    ├── blender.gin
    ├── blender_noextras.gin
    ├── blender_noipe.gin
    ├── llff.gin
    ├── multiblender.gin
    ├── multiblender_noextras.gin
    ├── multiblender_noipe.gin
    └── multiblender_noloss.gin
├── eval.py
├── internal
    ├── datasets.py
    ├── math.py
    ├── math_test.py
    ├── mip.py
    ├── mip_test.py
    ├── models.py
    ├── utils.py
    └── vis.py
├── requirements.txt
├── scripts
    ├── convert_blender_data.py
    ├── eval_blender.sh
    ├── eval_llff.sh
    ├── eval_multiblender.sh
    ├── summarize.ipynb
    ├── train_blender.sh
    ├── train_llff.sh
    └── train_multiblender.sh
└── train.py

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code Reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License.
      Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability.
      In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format.
      We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/OWNERS:
--------------------------------------------------------------------------------
barron
bmild

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# mip-NeRF

This repository contains the code release for
[Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields](https://jonbarron.info/mipnerf/).
This implementation is written in [JAX](https://github.com/google/jax), and
is a fork of Google's [JaxNeRF implementation](https://github.com/google-research/google-research/tree/master/jaxnerf).
Contact [Jon Barron](https://jonbarron.info/) if you encounter any issues.

![rays](https://user-images.githubusercontent.com/3310961/118305131-6ce86700-b49c-11eb-99b8-adcf276e9fe9.jpg)

## Abstract

The rendering procedure used by neural radiance fields (NeRF) samples a scene
with a single ray per pixel and may therefore produce renderings that are
excessively blurred or aliased when training or testing images observe scene
content at different resolutions. The straightforward solution of supersampling
by rendering with multiple rays per pixel is impractical for NeRF, because
rendering each ray requires querying a multilayer perceptron hundreds of times.
Our solution, which we call "mip-NeRF" (à la "mipmap"), extends NeRF to
represent the scene at a continuously-valued scale. By efficiently rendering
anti-aliased conical frustums instead of rays, mip-NeRF reduces objectionable
aliasing artifacts and significantly improves NeRF's ability to represent
fine details, while also being 7% faster than NeRF and half the size. Compared
to NeRF, mip-NeRF reduces average error rates by 17% on the dataset presented
with NeRF and by 60% on a challenging multiscale variant of that dataset that
we present. mip-NeRF is also able to match the accuracy of a brute-force
supersampled NeRF on our multiscale dataset while being 22x faster.


## Installation
We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set
up the environment. Run the following commands:

```
# Clone the repo
git clone https://github.com/google/mipnerf.git; cd mipnerf
# Create a conda environment, note you can use python 3.6-3.8 as
# one of the dependencies (TensorFlow) hasn't supported python 3.9 yet.
conda create --name mipnerf python=3.6.13; conda activate mipnerf
# Prepare pip
conda install pip; pip install --upgrade pip
# Install requirements
pip install -r requirements.txt
```

[Optional] Install GPU and TPU support for JAX
```
# Remember to change cuda101 to your CUDA version, e.g. cuda110 for CUDA 11.0.
pip install --upgrade jax jaxlib==0.1.65+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

## Data

Next, download the datasets
from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
Please download and unzip `nerf_synthetic.zip` and `nerf_llff_data.zip`.

### Generate multiscale dataset
You can generate the multiscale dataset used in the paper by running the following command:
```
python scripts/convert_blender_data.py --blenderdir /nerf_synthetic --outdir /multiscale
```

## Running

Example scripts for training mip-NeRF on individual scenes from the three
datasets used in the paper can be found in `scripts/`. You'll need to change
the paths to point to wherever the datasets are located.
[Gin](https://github.com/google/gin-config) configuration files for our model
and some ablations can be found in `configs/`.
An example script for evaluating on the test set of each scene can be found
in `scripts/`, after which you can use `scripts/summarize.ipynb` to produce
error metrics across all scenes in the same format as the tables in the
paper.

### OOM errors
You may need to reduce the batch size to avoid out-of-memory errors. For example, the model can be run on an NVIDIA 3080 (10 GB) using the following flag.
```
--gin_param="Config.batch_size = 1024"
```

## Citation
If you use this software package, please cite our paper:

```
@misc{barron2021mipnerf,
      title={Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields},
      author={Jonathan T. Barron and Ben Mildenhall and Matthew Tancik and Peter Hedman and Ricardo Martin-Brualla and Pratul P. Srinivasan},
      year={2021},
      eprint={2103.13415},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Acknowledgements
Thanks to [Boyang Deng](https://boyangdeng.com/) for JaxNeRF.

--------------------------------------------------------------------------------
/configs/blender.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'blender'
Config.batching = 'single_image'

--------------------------------------------------------------------------------
/configs/blender_noextras.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'blender'
Config.batching = 'single_image'
MipNerfModel.density_activation = @flax.nn.relu
MipNerfModel.density_bias = 0.0
MipNerfModel.rgb_padding = 0.0

--------------------------------------------------------------------------------
/configs/blender_noipe.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'blender'
Config.batching = 'single_image'
MipNerfModel.disable_integration = True

--------------------------------------------------------------------------------
/configs/llff.gin:
--------------------------------------------------------------------------------
Config.dataset_loader = 'llff'
Config.white_bkgd = False
Config.randomized = True
Config.near = 0.
Config.far = 1.
Config.factor = 4
Config.llffhold = 8
MipNerfModel.use_viewdirs = True
MipNerfModel.ray_shape = 'cylinder'
MipNerfModel.density_noise = 1.

--------------------------------------------------------------------------------
/configs/multiblender.gin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/mipnerf/84c969e0a623edd183b75693aed72a7e7c22902d/configs/multiblender.gin
--------------------------------------------------------------------------------
/configs/multiblender_noextras.gin:
--------------------------------------------------------------------------------
MipNerfModel.density_activation = @flax.nn.relu
MipNerfModel.density_bias = 0.0
MipNerfModel.rgb_padding = 0.0

--------------------------------------------------------------------------------
/configs/multiblender_noipe.gin:
--------------------------------------------------------------------------------
MipNerfModel.disable_integration = True

--------------------------------------------------------------------------------
/configs/multiblender_noloss.gin:
--------------------------------------------------------------------------------
Config.disable_multiscale_loss = True

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Evaluation script for mip-NeRF."""
import functools
from os import path

from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import random
import numpy as np

from internal import datasets
from internal import math
from internal import models
from internal import utils
from internal import vis

FLAGS = flags.FLAGS
utils.define_common_flags()
flags.DEFINE_bool(
    'eval_once', True,
    'If True, evaluate the model only once; otherwise keep evaluating new '
    'checkpoints if any exist.')
flags.DEFINE_bool('save_output', True,
                  'If True, save predicted images to disk.')


def main(unused_argv):
  config = utils.load_config()

  dataset = datasets.get_dataset('test', FLAGS.data_dir, config)
  model, init_variables = models.construct_mipnerf(
      random.PRNGKey(20200823), dataset.peek())
  optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates 'speckle' artifacts.
  def render_eval_fn(variables, _, rays):
    return jax.lax.all_gather(
        model.apply(
            variables,
            random.PRNGKey(0),  # Unused.
            rays,
            randomized=False,
            white_bkgd=config.white_bkgd),
        axis_name='batch')

  # pmap over only the data input.
  render_eval_pfn = jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=2,
      axis_name='batch',
  )

  ssim_fn = jax.jit(functools.partial(math.compute_ssim, max_val=1.))

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      'path_renders' if config.render_path else 'test_preds')
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, 'eval'))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    avg_values = []
    if not FLAGS.eval_once:
      showcase_index = random.randint(random.PRNGKey(step), (), 0, dataset.size)
    for idx in range(dataset.size):
      print(f'Evaluating {idx+1}/{dataset.size}')
      batch = next(dataset)
      pred_color, pred_distance, pred_acc = models.render_image(
          functools.partial(render_eval_pfn, state.optimizer.target),
          batch['rays'],
          None,
          chunk=FLAGS.chunk)

      vis_suite = vis.visualize_suite(pred_distance, pred_acc)

      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_acc = pred_acc
        showcase_vis_suite = vis_suite
        if not config.render_path:
          showcase_gt = batch['pixels']
      if not config.render_path:
        psnr = float(
            math.mse_to_psnr(((pred_color - batch['pixels'])**2).mean()))
        ssim = float(ssim_fn(pred_color, batch['pixels']))
        print(f'PSNR={psnr:.4f} SSIM={ssim:.4f}')
        psnr_values.append(psnr)
        ssim_values.append(ssim)
      if FLAGS.save_output and (config.test_render_interval > 0):
        if (idx % config.test_render_interval) == 0:
          utils.save_img_uint8(
              pred_color, path.join(out_dir, 'color_{:03d}.png'.format(idx)))
          utils.save_img_float32(
              pred_distance,
              path.join(out_dir, 'distance_{:03d}.tiff'.format(idx)))
          utils.save_img_float32(
              pred_acc, path.join(out_dir, 'acc_{:03d}.tiff'.format(idx)))
          for k, v in vis_suite.items():
            utils.save_img_uint8(
                v, path.join(out_dir, k + '_{:03d}.png'.format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image('pred_color', showcase_color, step)
      summary_writer.image('pred_acc', showcase_acc, step)
      for k, v in showcase_vis_suite.items():
        summary_writer.image('pred_' + k, v, step)
      if not config.render_path:
        summary_writer.scalar('psnr', np.mean(np.array(psnr_values)), step)
        summary_writer.scalar('ssim', np.mean(np.array(ssim_values)), step)
        summary_writer.image('target', showcase_gt, step)
    if FLAGS.save_output and (not config.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f'psnrs_{step}.txt'), 'w') as f:
        f.write(' '.join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f'ssims_{step}.txt'), 'w') as f:
        f.write(' '.join([str(v) for v in ssim_values]))
    if FLAGS.eval_once:
      break
    if int(step) >=
        config.max_steps:
      break
    last_step = step


if __name__ == '__main__':
  app.run(main)

--------------------------------------------------------------------------------
/internal/datasets.py:
--------------------------------------------------------------------------------
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Different datasets implementation plus a general port for all the datasets."""
import json
import os
from os import path
import queue
import threading
import cv2
import jax
import numpy as np
from PIL import Image
from internal import utils


def get_dataset(split, train_dir, config):
  return dataset_dict[config.dataset_loader](split, train_dir, config)


def convert_to_ndc(origins, directions, focal, w, h, near=1.):
  """Convert a set of rays to NDC coordinates."""
  # Shift ray origins to near plane
  t = -(near + origins[..., 2]) / directions[..., 2]
  origins = origins + t[..., None] * directions

  dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
  ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))

  # Projection
  o0 = -((2 * focal) / w) * (ox / oz)
  o1 = -((2 * focal) / h) * (oy / oz)
  o2 = 1 + 2 * near / oz

  d0 = -((2 * focal) / w) * (dx / dz - ox / oz)
  d1 = -((2 * focal)
         / h) * (dy / dz - oy / oz)
  d2 = -2 * near / oz

  origins = np.stack([o0, o1, o2], -1)
  directions = np.stack([d0, d1, d2], -1)
  return origins, directions


class Dataset(threading.Thread):
  """Dataset Base Class."""

  def __init__(self, split, data_dir, config):
    super(Dataset, self).__init__()
    self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
    self.daemon = True
    self.split = split
    self.data_dir = data_dir
    self.near = config.near
    self.far = config.far
    if split == 'train':
      self._train_init(config)
    elif split == 'test':
      self._test_init(config)
    else:
      raise ValueError(
          'the split argument should be either \'train\' or \'test\', set '
          'to {} here.'.format(split))
    self.batch_size = config.batch_size // jax.host_count()
    self.batching = config.batching
    self.render_path = config.render_path
    self.start()

  def __iter__(self):
    return self

  def __next__(self):
    """Get the next training batch or test example.

    Returns:
      batch: dict, has 'pixels' and 'rays'.
    """
    x = self.queue.get()
    if self.split == 'train':
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def peek(self):
    """Peek at the next training batch or test example without dequeuing it.

    Returns:
      batch: dict, has 'pixels' and 'rays'.
    """
    x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
    if self.split == 'train':
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def run(self):
    if self.split == 'train':
      next_func = self._next_train
    else:
      next_func = self._next_test
    while True:
      self.queue.put(next_func())

  @property
  def size(self):
    return self.n_examples

  def _train_init(self, config):
    """Initialize training."""
    self._load_renderings(config)
    self._generate_rays()

    if config.batching == 'all_images':
      # flatten the ray and image dimension together.
      self.images = self.images.reshape([-1, 3])
      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                       self.rays)
    elif config.batching == 'single_image':
      self.images = self.images.reshape([-1, self.resolution, 3])
      self.rays = utils.namedtuple_map(
          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
    else:
      raise NotImplementedError(
          f'{config.batching} batching strategy is not implemented.')

  def _test_init(self, config):
    self._load_renderings(config)
    self._generate_rays()
    self.it = 0

  def _next_train(self):
    """Sample next training batch."""

    if self.batching == 'all_images':
      ray_indices = np.random.randint(0, self.rays[0].shape[0],
                                      (self.batch_size,))
      batch_pixels = self.images[ray_indices]
      batch_rays = utils.namedtuple_map(lambda r: r[ray_indices], self.rays)
    elif self.batching == 'single_image':
      image_index = np.random.randint(0, self.n_examples, ())
      ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                      (self.batch_size,))
      batch_pixels = self.images[image_index][ray_indices]
      batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                        self.rays)
    else:
      raise NotImplementedError(
          f'{self.batching} batching strategy is not
implemented.') 160 | 161 | return {'pixels': batch_pixels, 'rays': batch_rays} 162 | 163 | def _next_test(self): 164 | """Sample next test example.""" 165 | idx = self.it 166 | self.it = (self.it + 1) % self.n_examples 167 | 168 | if self.render_path: 169 | return {'rays': utils.namedtuple_map(lambda r: r[idx], self.render_rays)} 170 | else: 171 | return { 172 | 'pixels': self.images[idx], 173 | 'rays': utils.namedtuple_map(lambda r: r[idx], self.rays) 174 | } 175 | 176 | # TODO(bydeng): Swap this function with a more flexible camera model. 177 | def _generate_rays(self): 178 | """Generating rays for all images.""" 179 | x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking 180 | np.arange(self.w, dtype=np.float32), # X-Axis (columns) 181 | np.arange(self.h, dtype=np.float32), # Y-Axis (rows) 182 | indexing='xy') 183 | camera_dirs = np.stack( 184 | [(x - self.w * 0.5 + 0.5) / self.focal, 185 | -(y - self.h * 0.5 + 0.5) / self.focal, -np.ones_like(x)], 186 | axis=-1) 187 | directions = ((camera_dirs[None, ..., None, :] * 188 | self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1)) 189 | origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1], 190 | directions.shape) 191 | viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True) 192 | 193 | # Distance from each unit-norm direction vector to its x-axis neighbor. 194 | dx = np.sqrt( 195 | np.sum((directions[:, :-1, :, :] - directions[:, 1:, :, :])**2, -1)) 196 | dx = np.concatenate([dx, dx[:, -2:-1, :]], 1) 197 | # Cut the distance in half, and then round it out so that it's 198 | # halfway between inscribed by / circumscribed about the pixel. 
199 | 200 | radii = dx[..., None] * 2 / np.sqrt(12) 201 | 202 | ones = np.ones_like(origins[..., :1]) 203 | self.rays = utils.Rays( 204 | origins=origins, 205 | directions=directions, 206 | viewdirs=viewdirs, 207 | radii=radii, 208 | lossmult=ones, 209 | near=ones * self.near, 210 | far=ones * self.far) 211 | 212 | 213 | class Multicam(Dataset): 214 | """Multicam Dataset.""" 215 | 216 | def _load_renderings(self, config): 217 | """Load images from disk.""" 218 | if config.render_path: 219 | raise ValueError('render_path cannot be used for the Multicam dataset.') 220 | with utils.open_file(path.join(self.data_dir, 'metadata.json'), 221 | 'r') as fp: 222 | self.meta = json.load(fp)[self.split] 223 | self.meta = {k: np.array(self.meta[k]) for k in self.meta} 224 | # should now have ['pix2cam', 'cam2world', 'width', 'height'] in self.meta 225 | images = [] 226 | for fbase in self.meta['file_path']: 227 | fname = os.path.join(self.data_dir, fbase) 228 | with utils.open_file(fname, 'rb') as imgin: 229 | image = np.array(Image.open(imgin), dtype=np.float32) / 255. 230 | if config.white_bkgd: 231 | image = image[..., :3] * image[..., -1:] + (1. 
- image[..., -1:]) 232 | images.append(image[..., :3]) 233 | self.images = images 234 | self.n_examples = len(self.images) 235 | 236 | def _train_init(self, config): 237 | """Initialize training.""" 238 | self._load_renderings(config) 239 | self._generate_rays() 240 | 241 | def flatten(x): 242 | # Always flatten out the height x width dimensions 243 | x = [y.reshape([-1, y.shape[-1]]) for y in x] 244 | if config.batching == 'all_images': 245 | # If global batching, also concatenate all data into one list 246 | x = np.concatenate(x, axis=0) 247 | return x 248 | 249 | self.images = flatten(self.images) 250 | self.rays = utils.namedtuple_map(flatten, self.rays) 251 | 252 | def _test_init(self, config): 253 | self._load_renderings(config) 254 | self._generate_rays() 255 | self.it = 0 256 | 257 | def _generate_rays(self): 258 | """Generating rays for all images.""" 259 | pix2cam = self.meta['pix2cam'] 260 | cam2world = self.meta['cam2world'] 261 | width = self.meta['width'] 262 | height = self.meta['height'] 263 | 264 | def res2grid(w, h): 265 | return np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking 266 | np.arange(w, dtype=np.float32) + .5, # X-Axis (columns) 267 | np.arange(h, dtype=np.float32) + .5, # Y-Axis (rows) 268 | indexing='xy') 269 | 270 | xy = [res2grid(w, h) for w, h in zip(width, height)] 271 | pixel_dirs = [np.stack([x, y, np.ones_like(x)], axis=-1) for x, y in xy] 272 | camera_dirs = [v @ p2c[:3, :3].T for v, p2c in zip(pixel_dirs, pix2cam)] 273 | directions = [v @ c2w[:3, :3].T for v, c2w in zip(camera_dirs, cam2world)] 274 | origins = [ 275 | np.broadcast_to(c2w[:3, -1], v.shape) 276 | for v, c2w in zip(directions, cam2world) 277 | ] 278 | viewdirs = [ 279 | v / np.linalg.norm(v, axis=-1, keepdims=True) for v in directions 280 | ] 281 | 282 | def broadcast_scalar_attribute(x): 283 | return [ 284 | np.broadcast_to(x[i], origins[i][..., :1].shape) 285 | for i in range(self.n_examples) 286 | ] 287 | 288 | lossmult = 
broadcast_scalar_attribute(self.meta['lossmult']) 289 | near = broadcast_scalar_attribute(self.meta['near']) 290 | far = broadcast_scalar_attribute(self.meta['far']) 291 | 292 | # Distance from each unit-norm direction vector to its x-axis neighbor. 293 | dx = [ 294 | np.sqrt(np.sum((v[:-1, :, :] - v[1:, :, :])**2, -1)) for v in directions 295 | ] 296 | dx = [np.concatenate([v, v[-2:-1, :]], 0) for v in dx] 297 | # Cut the distance in half, and then round it out so that it's 298 | # halfway between inscribed by / circumscribed about the pixel. 299 | radii = [v[..., None] * 2 / np.sqrt(12) for v in dx] 300 | 301 | self.rays = utils.Rays( 302 | origins=origins, 303 | directions=directions, 304 | viewdirs=viewdirs, 305 | radii=radii, 306 | lossmult=lossmult, 307 | near=near, 308 | far=far) 309 | 310 | 311 | class Blender(Dataset): 312 | """Blender Dataset.""" 313 | 314 | def _load_renderings(self, config): 315 | """Load images from disk.""" 316 | if config.render_path: 317 | raise ValueError('render_path cannot be used for the blender dataset.') 318 | with utils.open_file( 319 | path.join(self.data_dir, 'transforms_{}.json'.format(self.split)), 320 | 'r') as fp: 321 | meta = json.load(fp) 322 | images = [] 323 | cams = [] 324 | for i in range(len(meta['frames'])): 325 | frame = meta['frames'][i] 326 | fname = os.path.join(self.data_dir, frame['file_path'] + '.png') 327 | with utils.open_file(fname, 'rb') as imgin: 328 | image = np.array(Image.open(imgin), dtype=np.float32) / 255. 
329 | if config.factor == 2: 330 | [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]] 331 | image = cv2.resize( 332 | image, (halfres_w, halfres_h), interpolation=cv2.INTER_AREA) 333 | elif config.factor > 0: 334 | raise ValueError('Blender dataset only supports factor=0 or 2, {} ' 335 | 'set.'.format(config.factor)) 336 | cams.append(np.array(frame['transform_matrix'], dtype=np.float32)) 337 | images.append(image) 338 | self.images = np.stack(images, axis=0) 339 | if config.white_bkgd: 340 | self.images = ( 341 | self.images[..., :3] * self.images[..., -1:] + 342 | (1. - self.images[..., -1:])) 343 | else: 344 | self.images = self.images[..., :3] 345 | self.h, self.w = self.images.shape[1:3] 346 | self.resolution = self.h * self.w 347 | self.camtoworlds = np.stack(cams, axis=0) 348 | camera_angle_x = float(meta['camera_angle_x']) 349 | self.focal = .5 * self.w / np.tan(.5 * camera_angle_x) 350 | self.n_examples = self.images.shape[0] 351 | 352 | 353 | class LLFF(Dataset): 354 | """LLFF Dataset.""" 355 | 356 | def _load_renderings(self, config): 357 | """Load images from disk.""" 358 | # Load images. 359 | imgdir_suffix = '' 360 | if config.factor > 0: 361 | imgdir_suffix = '_{}'.format(config.factor) 362 | factor = config.factor 363 | else: 364 | factor = 1 365 | imgdir = path.join(self.data_dir, 'images' + imgdir_suffix) 366 | if not utils.file_exists(imgdir): 367 | raise ValueError('Image folder {} does not exist.'.format(imgdir)) 368 | imgfiles = [ 369 | path.join(imgdir, f) 370 | for f in sorted(utils.listdir(imgdir)) 371 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png') 372 | ] 373 | images = [] 374 | for imgfile in imgfiles: 375 | with utils.open_file(imgfile, 'rb') as imgin: 376 | image = np.array(Image.open(imgin), dtype=np.float32) / 255. 377 | images.append(image) 378 | images = np.stack(images, axis=-1) 379 | 380 | # Load poses and bds. 
381 | with utils.open_file(path.join(self.data_dir, 'poses_bounds.npy'), 382 | 'rb') as fp: 383 | poses_arr = np.load(fp) 384 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 385 | bds = poses_arr[:, -2:].transpose([1, 0]) 386 | if poses.shape[-1] != images.shape[-1]: 387 | raise RuntimeError('Mismatch between imgs {} and poses {}'.format( 388 | images.shape[-1], poses.shape[-1])) 389 | 390 | # Update poses according to downsampling. 391 | poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1]) 392 | poses[2, 4, :] = poses[2, 4, :] * 1. / factor 393 | 394 | # Correct rotation matrix ordering and move variable dim to axis 0. 395 | poses = np.concatenate( 396 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 397 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 398 | images = np.moveaxis(images, -1, 0) 399 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 400 | 401 | # Rescale according to a default bd factor. 402 | scale = 1. / (bds.min() * .75) 403 | poses[:, :3, 3] *= scale 404 | bds *= scale 405 | 406 | # Recenter poses. 407 | poses = self._recenter_poses(poses) 408 | 409 | # Generate a spiral/spherical ray path for rendering videos. 410 | if config.spherify: 411 | poses = self._generate_spherical_poses(poses, bds) 412 | self.spherify = True 413 | else: 414 | self.spherify = False 415 | if not config.spherify and self.split == 'test': 416 | self._generate_spiral_poses(poses, bds) 417 | 418 | # Select the split. 
419 | i_test = np.arange(images.shape[0])[::config.llffhold] 420 | i_train = np.array( 421 | [i for i in np.arange(int(images.shape[0])) if i not in i_test]) 422 | if self.split == 'train': 423 | indices = i_train 424 | else: 425 | indices = i_test 426 | images = images[indices] 427 | poses = poses[indices] 428 | 429 | self.images = images 430 | self.camtoworlds = poses[:, :3, :4] 431 | self.focal = poses[0, -1, -1] 432 | self.h, self.w = images.shape[1:3] 433 | self.resolution = self.h * self.w 434 | if config.render_path: 435 | self.n_examples = self.render_poses.shape[0] 436 | else: 437 | self.n_examples = images.shape[0] 438 | 439 | def _generate_rays(self): 440 | """Generate normalized device coordinate rays for llff.""" 441 | if self.split == 'test': 442 | n_render_poses = self.render_poses.shape[0] 443 | self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds], 444 | axis=0) 445 | 446 | super()._generate_rays() 447 | 448 | if not self.spherify: 449 | ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins, 450 | self.rays.directions, 451 | self.focal, self.w, self.h) 452 | 453 | mat = ndc_origins 454 | # Distance from each unit-norm direction vector to its x-axis neighbor. 455 | dx = np.sqrt(np.sum((mat[:, :-1, :, :] - mat[:, 1:, :, :])**2, -1)) 456 | dx = np.concatenate([dx, dx[:, -2:-1, :]], 1) 457 | 458 | dy = np.sqrt(np.sum((mat[:, :, :-1, :] - mat[:, :, 1:, :])**2, -1)) 459 | dy = np.concatenate([dy, dy[:, :, -2:-1]], 2) 460 | # Cut the distance in half, and then round it out so that it's 461 | # halfway between inscribed by / circumscribed about the pixel. 
462 | radii = (0.5 * (dx + dy))[..., None] * 2 / np.sqrt(12) 463 | 464 | ones = np.ones_like(ndc_origins[..., :1]) 465 | self.rays = utils.Rays( 466 | origins=ndc_origins, 467 | directions=ndc_directions, 468 | viewdirs=self.rays.directions, 469 | radii=radii, 470 | lossmult=ones, 471 | near=ones * self.near, 472 | far=ones * self.far) 473 | 474 | # Split poses from the dataset and generated poses 475 | if self.split == 'test': 476 | self.camtoworlds = self.camtoworlds[n_render_poses:] 477 | split = [np.split(r, [n_render_poses], 0) for r in self.rays] 478 | split0, split1 = zip(*split) 479 | self.render_rays = utils.Rays(*split0) 480 | self.rays = utils.Rays(*split1) 481 | 482 | def _recenter_poses(self, poses): 483 | """Recenter poses according to the original NeRF code.""" 484 | poses_ = poses.copy() 485 | bottom = np.reshape([0, 0, 0, 1.], [1, 4]) 486 | c2w = self._poses_avg(poses) 487 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 488 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 489 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 490 | poses = np.linalg.inv(c2w) @ poses 491 | poses_[:, :3, :4] = poses[:, :3, :4] 492 | poses = poses_ 493 | return poses 494 | 495 | def _poses_avg(self, poses): 496 | """Average poses according to the original NeRF code.""" 497 | hwf = poses[0, :3, -1:] 498 | center = poses[:, :3, 3].mean(0) 499 | vec2 = self._normalize(poses[:, :3, 2].sum(0)) 500 | up = poses[:, :3, 1].sum(0) 501 | c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1) 502 | return c2w 503 | 504 | def _viewmatrix(self, z, up, pos): 505 | """Construct lookat view matrix.""" 506 | vec2 = self._normalize(z) 507 | vec1_avg = up 508 | vec0 = self._normalize(np.cross(vec1_avg, vec2)) 509 | vec1 = self._normalize(np.cross(vec2, vec0)) 510 | m = np.stack([vec0, vec1, vec2, pos], 1) 511 | return m 512 | 513 | def _normalize(self, x): 514 | """Normalization helper function.""" 515 | return x / np.linalg.norm(x) 516 | 517 
| def _generate_spiral_poses(self, poses, bds): 518 | """Generate a spiral path for rendering.""" 519 | c2w = self._poses_avg(poses) 520 | # Get average pose. 521 | up = self._normalize(poses[:, :3, 1].sum(0)) 522 | # Find a reasonable 'focus depth' for this dataset. 523 | close_depth, inf_depth = bds.min() * .9, bds.max() * 5. 524 | dt = .75 525 | mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth)) 526 | focal = mean_dz 527 | # Get radii for spiral path. 528 | tt = poses[:, :3, 3] 529 | rads = np.percentile(np.abs(tt), 90, 0) 530 | c2w_path = c2w 531 | n_views = 120 532 | n_rots = 2 533 | # Generate poses for spiral path. 534 | render_poses = [] 535 | rads = np.array(list(rads) + [1.]) 536 | hwf = c2w_path[:, 4:5] 537 | zrate = .5 538 | for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]: 539 | c = np.dot(c2w[:3, :4], (np.array( 540 | [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)) 541 | z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 542 | render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1)) 543 | self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4] 544 | 545 | def _generate_spherical_poses(self, poses, bds): 546 | """Generate a 360 degree spherical path for rendering.""" 547 | # pylint: disable=g-long-lambda 548 | p34_to_44 = lambda p: np.concatenate([ 549 | p, 550 | np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1]) 551 | ], 1) 552 | rays_d = poses[:, :3, 2:3] 553 | rays_o = poses[:, :3, 3:4] 554 | 555 | def min_line_dist(rays_o, rays_d): 556 | a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 557 | b_i = -a_i @ rays_o 558 | pt_mindist = np.squeeze(-np.linalg.inv( 559 | (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0)) 560 | return pt_mindist 561 | 562 | pt_mindist = min_line_dist(rays_o, rays_d) 563 | center = pt_mindist 564 | up = (poses[:, :3, 3] - center).mean(0) 565 | vec0 = self._normalize(up) 566 | vec1 
= self._normalize(np.cross([.1, .2, .3], vec0)) 567 | vec2 = self._normalize(np.cross(vec0, vec1)) 568 | pos = center 569 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 570 | poses_reset = ( 571 | np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])) 572 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 573 | sc = 1. / rad 574 | poses_reset[:, :3, 3] *= sc 575 | bds *= sc 576 | rad *= sc 577 | centroid = np.mean(poses_reset[:, :3, 3], 0) 578 | zh = centroid[2] 579 | radcircle = np.sqrt(rad**2 - zh**2) 580 | new_poses = [] 581 | 582 | for th in np.linspace(0., 2. * np.pi, 120): 583 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 584 | up = np.array([0, 0, -1.]) 585 | vec2 = self._normalize(camorigin) 586 | vec0 = self._normalize(np.cross(vec2, up)) 587 | vec1 = self._normalize(np.cross(vec2, vec0)) 588 | pos = camorigin 589 | p = np.stack([vec0, vec1, vec2, pos], 1) 590 | new_poses.append(p) 591 | 592 | new_poses = np.stack(new_poses, 0) 593 | new_poses = np.concatenate([ 594 | new_poses, 595 | np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape) 596 | ], -1) 597 | poses_reset = np.concatenate([ 598 | poses_reset[:, :3, :4], 599 | np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape) 600 | ], -1) 601 | if self.split == 'test': 602 | self.render_poses = new_poses[:, :3, :4] 603 | return poses_reset 604 | 605 | 606 | dataset_dict = { 607 | 'blender': Blender, 608 | 'llff': LLFF, 609 | 'multicam': Multicam, 610 | } 611 | -------------------------------------------------------------------------------- /internal/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Mathy utility functions.""" 17 | import jax 18 | import jax.numpy as jnp 19 | import jax.scipy as jsp 20 | 21 | 22 | def matmul(a, b): 23 | """jnp.matmul defaults to bfloat16, but this helper function doesn't.""" 24 | return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST) 25 | 26 | 27 | def safe_trig_helper(x, fn, t=100 * jnp.pi): 28 | return fn(jnp.where(jnp.abs(x) < t, x, x % t)) 29 | 30 | 31 | def safe_cos(x): 32 | """jnp.cos() on a TPU may NaN out for large values.""" 33 | return safe_trig_helper(x, jnp.cos) 34 | 35 | 36 | def safe_sin(x): 37 | """jnp.sin() on a TPU may NaN out for large values.""" 38 | return safe_trig_helper(x, jnp.sin) 39 | 40 | 41 | def mse_to_psnr(mse): 42 | """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" 43 | return -10. / jnp.log(10.) * jnp.log(mse) 44 | 45 | 46 | def psnr_to_mse(psnr): 47 | """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" 48 | return jnp.exp(-0.1 * jnp.log(10.) * psnr) 49 | 50 | 51 | def compute_avg_error(psnr, ssim, lpips): 52 | """The 'average' error used in the paper.""" 53 | mse = psnr_to_mse(psnr) 54 | dssim = jnp.sqrt(1 - ssim) 55 | return jnp.exp(jnp.mean(jnp.log(jnp.array([mse, dssim, lpips])))) 56 | 57 | 58 | def compute_ssim(img0, 59 | img1, 60 | max_val, 61 | filter_size=11, 62 | filter_sigma=1.5, 63 | k1=0.01, 64 | k2=0.03, 65 | return_map=False): 66 | """Computes SSIM from two images. 
67 | 68 | This function was modeled after tf.image.ssim, and should produce comparable 69 | output. 70 | 71 | Args: 72 | img0: array. An image of size [..., width, height, num_channels]. 73 | img1: array. An image of size [..., width, height, num_channels]. 74 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 75 | filter_size: int >= 1. Window size. 76 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 77 | k1: float > 0. One of the SSIM dampening parameters. 78 | k2: float > 0. One of the SSIM dampening parameters. 79 | return_map: Bool. If True, will cause the per-pixel SSIM "map" to be returned. 80 | 81 | Returns: 82 | Each image's mean SSIM, or a tensor of individual values if `return_map`. 83 | """ 84 | # Construct a 1D Gaussian blur filter. 85 | hw = filter_size // 2 86 | shift = (2 * hw - filter_size + 1) / 2 87 | f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2 88 | filt = jnp.exp(-0.5 * f_i) 89 | filt /= jnp.sum(filt) 90 | 91 | # Blur in x and y (faster than the 2D convolution). 92 | def convolve2d(z, f): 93 | return jsp.signal.convolve2d( 94 | z, f, mode='valid', precision=jax.lax.Precision.HIGHEST) 95 | 96 | filt_fn1 = lambda z: convolve2d(z, filt[:, None]) 97 | filt_fn2 = lambda z: convolve2d(z, filt[None, :]) 98 | 99 | # Vmap the blurs to the tensor size, and then compose them. 100 | num_dims = len(img0.shape) 101 | map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1]) 102 | for d in map_axes: 103 | filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d) 104 | filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d) 105 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 106 | 107 | mu0 = filt_fn(img0) 108 | mu1 = filt_fn(img1) 109 | mu00 = mu0 * mu0 110 | mu11 = mu1 * mu1 111 | mu01 = mu0 * mu1 112 | sigma00 = filt_fn(img0**2) - mu00 113 | sigma11 = filt_fn(img1**2) - mu11 114 | sigma01 = filt_fn(img0 * img1) - mu01 115 | 116 | # Clip the variances and covariances to valid values. 
117 | # Variance must be non-negative: 118 | sigma00 = jnp.maximum(0., sigma00) 119 | sigma11 = jnp.maximum(0., sigma11) 120 | sigma01 = jnp.sign(sigma01) * jnp.minimum( 121 | jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01)) 122 | 123 | c1 = (k1 * max_val)**2 124 | c2 = (k2 * max_val)**2 125 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 126 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 127 | ssim_map = numer / denom 128 | ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims))) 129 | return ssim_map if return_map else ssim 130 | 131 | 132 | def linear_to_srgb(linear): 133 | # Assumes `linear` is in [0, 1]. https://en.wikipedia.org/wiki/SRGB 134 | eps = jnp.finfo(jnp.float32).eps 135 | srgb0 = 323 / 25 * linear 136 | srgb1 = (211 * jnp.maximum(eps, linear)**(5 / 12) - 11) / 200 137 | return jnp.where(linear <= 0.0031308, srgb0, srgb1) 138 | 139 | 140 | def srgb_to_linear(srgb): 141 | # Assumes `srgb` is in [0, 1]. https://en.wikipedia.org/wiki/SRGB 142 | eps = jnp.finfo(jnp.float32).eps 143 | linear0 = 25 / 323 * srgb 144 | linear1 = jnp.maximum(eps, ((200 * srgb + 11) / (211)))**(12 / 5) 145 | return jnp.where(srgb <= 0.04045, linear0, linear1) 146 | 147 | 148 | def learning_rate_decay(step, 149 | lr_init, 150 | lr_final, 151 | max_steps, 152 | lr_delay_steps=0, 153 | lr_delay_mult=1): 154 | """Continuous learning rate decay function. 155 | 156 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 157 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 158 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 159 | function of lr_delay_mult, such that the initial learning rate is 160 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 161 | to the normal learning rate when steps>lr_delay_steps. 162 | 163 | Args: 164 | step: int, the current optimization step. 165 | lr_init: float, the initial learning rate. 
166 | lr_final: float, the final learning rate. 167 | max_steps: int, the number of steps during optimization. 168 | lr_delay_steps: int, the number of steps to delay the full learning rate. 169 | lr_delay_mult: float, the multiplier on the rate when delaying it. 170 | 171 | Returns: 172 | lr: the learning rate for the current step 'step'. 173 | """ 174 | if lr_delay_steps > 0: 175 | # A kind of reverse cosine decay. 176 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin( 177 | 0.5 * jnp.pi * jnp.clip(step / lr_delay_steps, 0, 1)) 178 | else: 179 | delay_rate = 1. 180 | t = jnp.clip(step / max_steps, 0, 1) 181 | log_lerp = jnp.exp(jnp.log(lr_init) * (1 - t) + jnp.log(lr_final) * t) 182 | return delay_rate * log_lerp 183 | 184 | 185 | def sorted_piecewise_constant_pdf(key, bins, weights, num_samples, randomized): 186 | """Piecewise-Constant PDF sampling from sorted bins. 187 | 188 | Args: 189 | key: jnp.ndarray(float32), [2,], random number generator. 190 | bins: jnp.ndarray(float32), [batch_size, num_bins + 1]. 191 | weights: jnp.ndarray(float32), [batch_size, num_bins]. 192 | num_samples: int, the number of samples. 193 | randomized: bool, use randomized samples. 194 | 195 | Returns: 196 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 197 | """ 198 | # Pad each weight vector (only if necessary) to bring its sum to `eps`. This 199 | # avoids NaNs when the input is zeros or small, but has no effect otherwise. 200 | eps = 1e-5 201 | weight_sum = jnp.sum(weights, axis=-1, keepdims=True) 202 | padding = jnp.maximum(0, eps - weight_sum) 203 | weights += padding / weights.shape[-1] 204 | weight_sum += padding 205 | 206 | # Compute the PDF and CDF for each weight vector, while ensuring that the CDF 207 | # starts with exactly 0 and ends with exactly 1. 
208 | pdf = weights / weight_sum 209 | cdf = jnp.minimum(1, jnp.cumsum(pdf[..., :-1], axis=-1)) 210 | cdf = jnp.concatenate([ 211 | jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf, 212 | jnp.ones(list(cdf.shape[:-1]) + [1]) 213 | ], 214 | axis=-1) 215 | 216 | # Draw uniform samples. 217 | if randomized: 218 | s = 1 / num_samples 219 | u = jnp.arange(num_samples) * s 220 | u += jax.random.uniform( 221 | key, 222 | list(cdf.shape[:-1]) + [num_samples], 223 | maxval=s - jnp.finfo('float32').eps) 224 | # `u` is in [0, 1) --- it can be zero, but it can never be 1. 225 | u = jnp.minimum(u, 1. - jnp.finfo('float32').eps) 226 | else: 227 | # Match the behavior of jax.random.uniform() by spanning [0, 1-eps]. 228 | u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples) 229 | u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples]) 230 | 231 | # Identify the location in `cdf` that corresponds to a random sample. 232 | # The final `True` index in `mask` will be the start of the sampled interval. 233 | mask = u[..., None, :] >= cdf[..., :, None] 234 | 235 | def find_interval(x): 236 | # Grab the value where `mask` switches from True to False, and vice versa. 237 | # This approach takes advantage of the fact that `x` is sorted. 
238 | x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2) 239 | x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2) 240 | return x0, x1 241 | 242 | bins_g0, bins_g1 = find_interval(bins) 243 | cdf_g0, cdf_g1 = find_interval(cdf) 244 | 245 | t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1) 246 | samples = bins_g0 + t * (bins_g1 - bins_g0) 247 | return samples 248 | -------------------------------------------------------------------------------- /internal/math_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Unit tests for math.""" 17 | import functools 18 | 19 | from absl.testing import absltest 20 | import jax 21 | from jax import random 22 | from jax import test_util as jtu 23 | import jax.numpy as jnp 24 | import numpy as np 25 | import scipy as sp 26 | import tensorflow as tf 27 | 28 | from internal import math 29 | 30 | 31 | def safe_trig_harness(fn, max_exp): 32 | x = 10**np.linspace(-30, max_exp, 10000) 33 | x = np.concatenate([-x[::-1], np.array([0]), x]) 34 | y_true = getattr(np, fn)(x) 35 | y = getattr(math, 'safe_' + fn)(x) 36 | return y_true, y 37 | 38 | 39 | class MathUtilsTest(jtu.JaxTestCase): 40 | 41 | def test_sin(self): 42 | """In [-1e10, 1e10] safe_sin and safe_cos are accurate.""" 43 | for fn in ['sin', 'cos']: 44 | y_true, y = safe_trig_harness(fn, 10) 45 | self.assertLess(np.max(np.abs(y - y_true)), 1e-4) 46 | self.assertFalse(jnp.any(jnp.isnan(y))) 47 | # Beyond that range it's less accurate but we just don't want it to be NaN. 48 | for fn in ['sin', 'cos']: 49 | y_true, y = safe_trig_harness(fn, 60) 50 | self.assertFalse(jnp.any(jnp.isnan(y))) 51 | 52 | def test_psnr_round_trip(self): 53 | """MSE -> PSNR -> MSE is a no-op.""" 54 | mse = 0.07 55 | self.assertAllClose(math.psnr_to_mse(math.mse_to_psnr(mse)), mse) 56 | 57 | def test_learning_rate_decay(self): 58 | np.random.seed(0) 59 | for _ in range(10): 60 | lr_init = np.exp(np.random.normal() - 3) 61 | lr_final = lr_init * np.exp(np.random.normal() - 5) 62 | max_steps = int(np.ceil(100 + 100 * np.exp(np.random.normal()))) 63 | 64 | lr_fn = functools.partial( 65 | math.learning_rate_decay, 66 | lr_init=lr_init, 67 | lr_final=lr_final, 68 | max_steps=max_steps) 69 | 70 | # Test that the rate at the beginning is the initial rate. 71 | self.assertAllClose(lr_fn(0), lr_init) 72 | 73 | # Test that the rate at the end is the final rate. 
74 | self.assertAllClose(lr_fn(max_steps), lr_final) 75 | 76 | # Test that the rate at the middle is the geometric mean of the two rates. 77 | self.assertAllClose(lr_fn(max_steps / 2), np.sqrt(lr_init * lr_final)) 78 | 79 | # Test that the rate past the end is the final rate 80 | self.assertAllClose(lr_fn(max_steps + 100), lr_final) 81 | 82 | def test_delayed_learning_rate_decay(self): 83 | np.random.seed(0) 84 | for _ in range(10): 85 | lr_init = np.exp(np.random.normal() - 3) 86 | lr_final = lr_init * np.exp(np.random.normal() - 5) 87 | max_steps = int(np.ceil(100 + 100 * np.exp(np.random.normal()))) 88 | lr_delay_steps = int(np.random.uniform(low=0.1, high=0.4) * max_steps) 89 | lr_delay_mult = np.exp(np.random.normal() - 3) 90 | 91 | lr_fn = functools.partial( 92 | math.learning_rate_decay, 93 | lr_init=lr_init, 94 | lr_final=lr_final, 95 | max_steps=max_steps, 96 | lr_delay_steps=lr_delay_steps, 97 | lr_delay_mult=lr_delay_mult) 98 | 99 | # Test that the rate at the beginning is the delayed initial rate. 100 | self.assertAllClose(lr_fn(0), lr_delay_mult * lr_init) 101 | 102 | # Test that the rate at the end is the final rate. 103 | self.assertAllClose(lr_fn(max_steps), lr_final) 104 | 105 | # Test that the rate after the delay is over is the usual rate. 106 | self.assertAllClose( 107 | lr_fn(lr_delay_steps), 108 | math.learning_rate_decay(lr_delay_steps, lr_init, lr_final, 109 | max_steps)) 110 | 111 | # Test that the rate at the middle is the geometric mean of the two rates. 
112 | self.assertAllClose(lr_fn(max_steps / 2), np.sqrt(lr_init * lr_final)) 113 | 114 | # Test that the rate past the end is the final rate 115 | self.assertAllClose(lr_fn(max_steps + 100), lr_final) 116 | 117 | def test_ssim_golden(self): 118 | """Test our SSIM implementation against the Tensorflow version.""" 119 | rng = random.PRNGKey(0) 120 | shape = (2, 12, 12, 3) 121 | for _ in range(4): 122 | rng, key = random.split(rng) 123 | max_val = random.uniform(key, minval=0.1, maxval=3.) 124 | rng, key = random.split(rng) 125 | img0 = max_val * random.uniform(key, shape=shape, minval=-1, maxval=1) 126 | rng, key = random.split(rng) 127 | img1 = max_val * random.uniform(key, shape=shape, minval=-1, maxval=1) 128 | rng, key = random.split(rng) 129 | filter_size = random.randint(key, shape=(), minval=1, maxval=10) 130 | rng, key = random.split(rng) 131 | filter_sigma = random.uniform(key, shape=(), minval=0.1, maxval=10.) 132 | rng, key = random.split(rng) 133 | k1 = random.uniform(key, shape=(), minval=0.001, maxval=0.1) 134 | rng, key = random.split(rng) 135 | k2 = random.uniform(key, shape=(), minval=0.001, maxval=0.1) 136 | 137 | ssim_gt = tf.image.ssim( 138 | img0, 139 | img1, 140 | max_val, 141 | filter_size=filter_size, 142 | filter_sigma=filter_sigma, 143 | k1=k1, 144 | k2=k2).numpy() 145 | for return_map in [False, True]: 146 | ssim_fn = jax.jit( 147 | functools.partial( 148 | math.compute_ssim, 149 | max_val=max_val, 150 | filter_size=filter_size, 151 | filter_sigma=filter_sigma, 152 | k1=k1, 153 | k2=k2, 154 | return_map=return_map)) 155 | ssim = ssim_fn(img0, img1) 156 | if not return_map: 157 | self.assertAllClose(ssim, ssim_gt) 158 | else: 159 | self.assertAllClose(np.mean(ssim, [1, 2, 3]), ssim_gt) 160 | self.assertLessEqual(np.max(ssim), 1.) 161 | self.assertGreaterEqual(np.min(ssim), -1.) 
162 | 163 | def test_ssim_lowerbound(self): 164 | """Test the unusual corner case where SSIM is -1.""" 165 | sz = 11 166 | img = np.meshgrid(*([np.linspace(-1, 1, sz)] * 2))[0][None, ..., None] 167 | eps = 1e-5 168 | ssim = math.compute_ssim( 169 | img, -img, 1., filter_size=sz, filter_sigma=1.5, k1=eps, k2=eps) 170 | self.assertAllClose(ssim, -np.ones_like(ssim)) 171 | 172 | def test_srgb_linearize(self): 173 | x = np.linspace(-1, 3, 10000) # Nobody should call this <0 but it works. 174 | # Check that the round-trip transformation is a no-op. 175 | self.assertAllClose(math.linear_to_srgb(math.srgb_to_linear(x)), x) 176 | self.assertAllClose(math.srgb_to_linear(math.linear_to_srgb(x)), x) 177 | # Check that gradients are finite. 178 | self.assertTrue( 179 | np.all(np.isfinite(jax.vmap(jax.grad(math.linear_to_srgb))(x)))) 180 | self.assertTrue( 181 | np.all(np.isfinite(jax.vmap(jax.grad(math.srgb_to_linear))(x)))) 182 | 183 | def test_sorted_piecewise_constant_pdf_train_mode(self): 184 | """Test that piecewise-constant sampling reproduces its distribution.""" 185 | batch_size = 4 186 | num_bins = 16 187 | num_samples = 1000000 188 | precision = 1e5 189 | rng = random.PRNGKey(20202020) 190 | 191 | # Generate a series of random PDFs to sample from. 192 | data = [] 193 | for _ in range(batch_size): 194 | rng, key = random.split(rng) 195 | # Randomly initialize the distances between bins. 196 | # We're rolling our own fixed precision here to make cumsum exact. 197 | bins_delta = jnp.round(precision * jnp.exp( 198 | random.uniform(key, shape=(num_bins + 1,), minval=-3, maxval=3))) 199 | 200 | # Set some of the bin distances to 0. 201 | rng, key = random.split(rng) 202 | bins_delta *= random.uniform(key, shape=bins_delta.shape) < 0.9 203 | 204 | # Integrate the bins. 
205 | bins = jnp.cumsum(bins_delta) / precision 206 | rng, key = random.split(rng) 207 | bins += random.normal(key) * num_bins / 2 208 | rng, key = random.split(rng) 209 | 210 | # Randomly generate weights, allowing some to be zero. 211 | weights = jnp.maximum( 212 | 0, random.uniform(key, shape=(num_bins,), minval=-0.5, maxval=1.)) 213 | gt_hist = weights / weights.sum() 214 | data.append((bins, weights, gt_hist)) 215 | 216 | # Tack on an "all zeros" weight matrix, which is a common cause of NaNs. 217 | weights = jnp.zeros_like(weights) 218 | gt_hist = jnp.ones_like(gt_hist) / num_bins 219 | data.append((bins, weights, gt_hist)) 220 | 221 | bins, weights, gt_hist = [jnp.stack(x) for x in zip(*data)] 222 | 223 | for randomized in [True, False]: 224 | rng, key = random.split(rng) 225 | # Draw samples from the batch of PDFs. 226 | samples = math.sorted_piecewise_constant_pdf( 227 | key, 228 | bins, 229 | weights, 230 | num_samples, 231 | randomized, 232 | ) 233 | self.assertEqual(samples.shape[-1], num_samples) 234 | 235 | # Check that samples are sorted. 236 | self.assertTrue(jnp.all(samples[..., 1:] >= samples[..., :-1])) 237 | 238 | # Verify that each set of samples resembles the target distribution. 239 | for i_samples, i_bins, i_gt_hist in zip(samples, bins, gt_hist): 240 | i_hist = jnp.float32(jnp.histogram(i_samples, i_bins)[0]) / num_samples 241 | i_gt_hist = jnp.array(i_gt_hist) 242 | 243 | # Merge any of the zero-span bins until there aren't any left. 244 | while jnp.any(i_bins[:-1] == i_bins[1:]): 245 | j = int(jnp.where(i_bins[:-1] == i_bins[1:])[0][0]) 246 | i_hist = jnp.concatenate([ 247 | i_hist[:j], 248 | jnp.array([i_hist[j] + i_hist[j + 1]]), i_hist[j + 2:] 249 | ]) 250 | i_gt_hist = jnp.concatenate([ 251 | i_gt_hist[:j], 252 | jnp.array([i_gt_hist[j] + i_gt_hist[j + 1]]), i_gt_hist[j + 2:] 253 | ]) 254 | i_bins = jnp.concatenate([i_bins[:j], i_bins[j + 1:]]) 255 | 256 | # Angle between the two histograms in degrees. 
257 | angle = 180 / jnp.pi * jnp.arccos( 258 | jnp.minimum( 259 | 1., 260 | jnp.mean( 261 | (i_hist * i_gt_hist) / 262 | jnp.sqrt(jnp.mean(i_hist**2) * jnp.mean(i_gt_hist**2))))) 263 | # Jensen-Shannon divergence. 264 | m = (i_hist + i_gt_hist) / 2 265 | js_div = jnp.sum( 266 | sp.special.kl_div(i_hist, m) + sp.special.kl_div(i_gt_hist, m)) / 2 267 | self.assertLessEqual(angle, 0.5) 268 | self.assertLessEqual(js_div, 1e-5) 269 | 270 | def test_sorted_piecewise_constant_pdf_large_flat(self): 271 | """Test sampling when given a large flat distribution.""" 272 | num_samples = 100 273 | num_bins = 100000 274 | key = random.PRNGKey(0) 275 | bins = jnp.arange(num_bins) 276 | weights = np.ones(len(bins) - 1) 277 | samples = math.sorted_piecewise_constant_pdf( 278 | key, 279 | bins[None], 280 | weights[None], 281 | num_samples, 282 | True, 283 | )[0] 284 | # All samples should be within the range of the bins. 285 | self.assertTrue(jnp.all(samples >= bins[0])) 286 | self.assertTrue(jnp.all(samples <= bins[-1])) 287 | 288 | # Samples modded by their bin index should resemble a uniform distribution. 289 | samples_mod = jnp.mod(samples, 1) 290 | self.assertLessEqual( 291 | sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2) 292 | 293 | # All samples should collectively resemble a uniform distribution. 
294 | self.assertLessEqual( 295 | sp.stats.kstest(samples, 'uniform', (bins[0], bins[-1])).statistic, 0.2) 296 | 297 | def test_sorted_piecewise_constant_pdf_sparse_delta(self): 298 | """Test sampling when given a large distribution with a big delta in it.""" 299 | num_samples = 100 300 | num_bins = 100000 301 | key = random.PRNGKey(0) 302 | bins = jnp.arange(num_bins) 303 | weights = np.ones(len(bins) - 1) 304 | delta_idx = len(weights) // 2 305 | weights[delta_idx] = len(weights) - 1 306 | samples = math.sorted_piecewise_constant_pdf( 307 | key, 308 | bins[None], 309 | weights[None], 310 | num_samples, 311 | True, 312 | )[0] 313 | 314 | # All samples should be within the range of the bins. 315 | self.assertTrue(jnp.all(samples >= bins[0])) 316 | self.assertTrue(jnp.all(samples <= bins[-1])) 317 | 318 | # Samples modded by their bin index should resemble a uniform distribution. 319 | samples_mod = jnp.mod(samples, 1) 320 | self.assertLessEqual( 321 | sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic, 0.2) 322 | 323 | # The delta function bin should contain ~half of the samples. 324 | in_delta = (samples >= bins[delta_idx]) & (samples <= bins[delta_idx + 1]) 325 | self.assertAllClose(jnp.mean(in_delta), 0.5, atol=0.05) 326 | 327 | def test_sorted_piecewise_constant_pdf_single_bin(self): 328 | """Test sampling when given a small `one hot' distribution.""" 329 | num_samples = 625 330 | key = random.PRNGKey(0) 331 | bins = jnp.array([0, 1, 3, 6, 10], jnp.float32) 332 | for randomized in [False, True]: 333 | for i in range(len(bins) - 1): 334 | weights = np.zeros(len(bins) - 1, jnp.float32) 335 | weights[i] = 1. 336 | samples = math.sorted_piecewise_constant_pdf( 337 | key, 338 | bins[None], 339 | weights[None], 340 | num_samples, 341 | randomized, 342 | )[0] 343 | 344 | # All samples should be within [bins[i], bins[i+1]]. 
345 | self.assertTrue(jnp.all(samples >= bins[i])) 346 | self.assertTrue(jnp.all(samples <= bins[i + 1])) 347 | 348 | 349 | if __name__ == '__main__': 350 | absltest.main(testLoader=jtu.JaxTestLoader()) 351 | -------------------------------------------------------------------------------- /internal/mip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Helper functions for mip-NeRF.""" 17 | 18 | from jax import lax 19 | from jax import random 20 | import jax.numpy as jnp 21 | 22 | from internal import math 23 | 24 | 25 | def pos_enc(x, min_deg, max_deg, append_identity=True): 26 | """The positional encoding used by the original NeRF paper.""" 27 | scales = jnp.array([2**i for i in range(min_deg, max_deg)]) 28 | xb = jnp.reshape((x[..., None, :] * scales[:, None]), 29 | list(x.shape[:-1]) + [-1]) 30 | four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1)) 31 | if append_identity: 32 | return jnp.concatenate([x] + [four_feat], axis=-1) 33 | else: 34 | return four_feat 35 | 36 | 37 | def expected_sin(x, x_var): 38 | """Estimates mean and variance of sin(z), z ~ N(x, var).""" 39 | # When the variance is wide, shrink sin towards zero. 
40 | y = jnp.exp(-0.5 * x_var) * math.safe_sin(x) 41 | y_var = jnp.maximum( 42 | 0, 0.5 * (1 - jnp.exp(-2 * x_var) * math.safe_cos(2 * x)) - y**2) 43 | return y, y_var 44 | 45 | 46 | def lift_gaussian(d, t_mean, t_var, r_var, diag): 47 | """Lift a Gaussian defined along a ray to 3D coordinates.""" 48 | mean = d[..., None, :] * t_mean[..., None] 49 | 50 | d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True)) 51 | 52 | if diag: 53 | d_outer_diag = d**2 54 | null_outer_diag = 1 - d_outer_diag / d_mag_sq 55 | t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :] 56 | xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :] 57 | cov_diag = t_cov_diag + xy_cov_diag 58 | return mean, cov_diag 59 | else: 60 | d_outer = d[..., :, None] * d[..., None, :] 61 | eye = jnp.eye(d.shape[-1]) 62 | null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :] 63 | t_cov = t_var[..., None, None] * d_outer[..., None, :, :] 64 | xy_cov = r_var[..., None, None] * null_outer[..., None, :, :] 65 | cov = t_cov + xy_cov 66 | return mean, cov 67 | 68 | 69 | def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True): 70 | """Approximate a conical frustum as a Gaussian distribution (mean+cov). 71 | 72 | Assumes the ray originates from the origin, and base_radius is the 73 | radius at dist=1. Doesn't assume `d` is normalized. 74 | 75 | Args: 76 | d: jnp.float32 3-vector, the axis of the cone. 77 | t0: float, the starting distance of the frustum. 78 | t1: float, the ending distance of the frustum. 79 | base_radius: float, the scale of the radius as a function of distance. 80 | diag: boolean, whether the Gaussian will be diagonal or full-covariance. 81 | stable: boolean, whether or not to use the stable computation described in 82 | the paper (setting this to False will cause catastrophic failure). 83 | 84 | Returns: 85 | a Gaussian (mean and covariance).
86 | """ 87 | if stable: 88 | mu = (t0 + t1) / 2 89 | hw = (t1 - t0) / 2 90 | t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2) 91 | t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) / 92 | (3 * mu**2 + hw**2)**2) 93 | r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 * 94 | (hw**4) / (3 * mu**2 + hw**2)) 95 | else: 96 | t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3)) 97 | r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3)) 98 | t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3) 99 | t_var = t_mosq - t_mean**2 100 | return lift_gaussian(d, t_mean, t_var, r_var, diag) 101 | 102 | 103 | def cylinder_to_gaussian(d, t0, t1, radius, diag): 104 | """Approximate a cylinder as a Gaussian distribution (mean+cov). 105 | 106 | Assumes the ray originates from the origin, and radius is the 107 | radius of the cylinder. Does not renormalize `d`. 108 | 109 | Args: 110 | d: jnp.float32 3-vector, the axis of the cylinder. 111 | t0: float, the starting distance of the cylinder. 112 | t1: float, the ending distance of the cylinder. 113 | radius: float, the radius of the cylinder. 114 | diag: boolean, whether the Gaussian will be diagonal or full-covariance. 115 | 116 | Returns: 117 | a Gaussian (mean and covariance). 118 | """ 119 | t_mean = (t0 + t1) / 2 120 | r_var = radius**2 / 4 121 | t_var = (t1 - t0)**2 / 12 122 | return lift_gaussian(d, t_mean, t_var, r_var, diag) 123 | 124 | 125 | def cast_rays(t_vals, origins, directions, radii, ray_shape, diag=True): 126 | """Cast rays (cone- or cylinder-shaped) and featurize sections of them. 127 | 128 | Args: 129 | t_vals: float array, the "fencepost" distances along the ray. 130 | origins: float array, the ray origin coordinates. 131 | directions: float array, the ray direction vectors. 132 | radii: float array, the radii (base radii for cones) of the rays. 133 | ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.
134 | diag: boolean, whether or not the covariance matrices should be diagonal. 135 | 136 | Returns: 137 | a tuple of arrays of means and covariances. 138 | """ 139 | t0 = t_vals[..., :-1] 140 | t1 = t_vals[..., 1:] 141 | if ray_shape == 'cone': 142 | gaussian_fn = conical_frustum_to_gaussian 143 | elif ray_shape == 'cylinder': 144 | gaussian_fn = cylinder_to_gaussian 145 | else: 146 | assert False 147 | means, covs = gaussian_fn(directions, t0, t1, radii, diag) 148 | means = means + origins[..., None, :] 149 | return means, covs 150 | 151 | 152 | def integrated_pos_enc(x_coord, min_deg, max_deg, diag=True): 153 | """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1]. 154 | 155 | Args: 156 | x_coord: a tuple containing: x, jnp.ndarray, variables to be encoded. Should 157 | be in [-pi, pi]. x_cov, jnp.ndarray, covariance matrices for `x`. 158 | min_deg: int, the min degree of the encoding. 159 | max_deg: int, the max degree of the encoding. 160 | diag: bool, if true, expects input covariances to be diagonal (full 161 | otherwise). 162 | 163 | Returns: 164 | encoded: jnp.ndarray, encoded variables. 165 | """ 166 | if diag: 167 | x, x_cov_diag = x_coord 168 | scales = jnp.array([2**i for i in range(min_deg, max_deg)]) 169 | shape = list(x.shape[:-1]) + [-1] 170 | y = jnp.reshape(x[..., None, :] * scales[:, None], shape) 171 | y_var = jnp.reshape(x_cov_diag[..., None, :] * scales[:, None]**2, shape) 172 | else: 173 | x, x_cov = x_coord 174 | num_dims = x.shape[-1] 175 | basis = jnp.concatenate( 176 | [2**i * jnp.eye(num_dims) for i in range(min_deg, max_deg)], 1) 177 | y = math.matmul(x, basis) 178 | # Get the diagonal of a covariance matrix (ie, variance). This is equivalent 179 | # to jax.vmap(jnp.diag)((basis.T @ covs) @ basis). 
180 | y_var = jnp.sum((math.matmul(x_cov, basis)) * basis, -2) 181 | 182 | return expected_sin( 183 | jnp.concatenate([y, y + 0.5 * jnp.pi], axis=-1), 184 | jnp.concatenate([y_var] * 2, axis=-1))[0] 185 | 186 | 187 | def volumetric_rendering(rgb, density, t_vals, dirs, white_bkgd): 188 | """Volumetric Rendering Function. 189 | 190 | Args: 191 | rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]. 192 | density: jnp.ndarray(float32), density, [batch_size, num_samples, 1]. 193 | t_vals: jnp.ndarray(float32), [batch_size, num_samples + 1]. 194 | dirs: jnp.ndarray(float32), [batch_size, 3]. 195 | white_bkgd: bool. 196 | 197 | Returns: 198 | comp_rgb: jnp.ndarray(float32), [batch_size, 3]. 199 | distance: jnp.ndarray(float32), [batch_size]. 200 | acc: jnp.ndarray(float32), [batch_size]. 201 | weights: jnp.ndarray(float32), [batch_size, num_samples]. 202 | """ 203 | t_mids = 0.5 * (t_vals[..., :-1] + t_vals[..., 1:]) 204 | t_dists = t_vals[..., 1:] - t_vals[..., :-1] 205 | delta = t_dists * jnp.linalg.norm(dirs[..., None, :], axis=-1) 206 | # Note that we're quietly turning density from [..., 0] to [...]. 207 | density_delta = density[..., 0] * delta 208 | 209 | alpha = 1 - jnp.exp(-density_delta) 210 | trans = jnp.exp(-jnp.concatenate([ 211 | jnp.zeros_like(density_delta[..., :1]), 212 | jnp.cumsum(density_delta[..., :-1], axis=-1) 213 | ], 214 | axis=-1)) 215 | weights = alpha * trans 216 | 217 | comp_rgb = (weights[..., None] * rgb).sum(axis=-2) 218 | acc = weights.sum(axis=-1) 219 | distance = (weights * t_mids).sum(axis=-1) / acc 220 | distance = jnp.clip( 221 | jnp.nan_to_num(distance, jnp.inf), t_vals[:, 0], t_vals[:, -1]) 222 | if white_bkgd: 223 | comp_rgb = comp_rgb + (1. - acc[..., None]) 224 | return comp_rgb, distance, acc, weights 225 | 226 | 227 | def sample_along_rays(key, origins, directions, radii, num_samples, near, far, 228 | randomized, lindisp, ray_shape): 229 | """Stratified sampling along the rays.
230 | 231 | Args: 232 | key: jnp.ndarray, random generator key. 233 | origins: jnp.ndarray(float32), [batch_size, 3], ray origins. 234 | directions: jnp.ndarray(float32), [batch_size, 3], ray directions. 235 | radii: jnp.ndarray(float32), [batch_size, 3], ray radii. 236 | num_samples: int. 237 | near: jnp.ndarray, [batch_size, 1], near clip. 238 | far: jnp.ndarray, [batch_size, 1], far clip. 239 | randomized: bool, use randomized stratified sampling. 240 | lindisp: bool, sampling linearly in disparity rather than depth. 241 | ray_shape: string, which shape ray to assume. 242 | 243 | Returns: 244 | t_vals: jnp.ndarray, [batch_size, num_samples], sampled z values. 245 | means: jnp.ndarray, [batch_size, num_samples, 3], sampled means. 246 | covs: jnp.ndarray, [batch_size, num_samples, 3, 3], sampled covariances. 247 | """ 248 | batch_size = origins.shape[0] 249 | 250 | t_vals = jnp.linspace(0., 1., num_samples + 1) 251 | if lindisp: 252 | t_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals) 253 | else: 254 | t_vals = near * (1. - t_vals) + far * t_vals 255 | 256 | if randomized: 257 | mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1]) 258 | upper = jnp.concatenate([mids, t_vals[..., -1:]], -1) 259 | lower = jnp.concatenate([t_vals[..., :1], mids], -1) 260 | t_rand = random.uniform(key, [batch_size, num_samples + 1]) 261 | t_vals = lower + (upper - lower) * t_rand 262 | else: 263 | # Broadcast t_vals to make the returned shape consistent. 264 | t_vals = jnp.broadcast_to(t_vals, [batch_size, num_samples + 1]) 265 | means, covs = cast_rays(t_vals, origins, directions, radii, ray_shape) 266 | return t_vals, (means, covs) 267 | 268 | 269 | def resample_along_rays(key, origins, directions, radii, t_vals, weights, 270 | randomized, ray_shape, stop_grad, resample_padding): 271 | """Resampling. 272 | 273 | Args: 274 | key: jnp.ndarray(float32), [2,], random number generator. 275 | origins: jnp.ndarray(float32), [batch_size, 3], ray origins. 
276 | directions: jnp.ndarray(float32), [batch_size, 3], ray directions. 277 | radii: jnp.ndarray(float32), [batch_size, 3], ray radii. 278 | t_vals: jnp.ndarray(float32), [batch_size, num_samples+1]. 279 | weights: jnp.array(float32), weights for t_vals 280 | randomized: bool, use randomized samples. 281 | ray_shape: string, which kind of shape to assume for the ray. 282 | stop_grad: bool, whether or not to backprop through sampling. 283 | resample_padding: float, added to the weights before normalizing. 284 | 285 | Returns: 286 | t_vals: jnp.ndarray(float32), [batch_size, num_samples+1]. 287 | points: jnp.ndarray(float32), [batch_size, num_samples, 3]. 288 | """ 289 | # Do a blurpool. 290 | weights_pad = jnp.concatenate([ 291 | weights[..., :1], 292 | weights, 293 | weights[..., -1:], 294 | ], 295 | axis=-1) 296 | weights_max = jnp.maximum(weights_pad[..., :-1], weights_pad[..., 1:]) 297 | weights_blur = 0.5 * (weights_max[..., :-1] + weights_max[..., 1:]) 298 | 299 | # Add in a constant (the sampling function will renormalize the PDF). 300 | weights = weights_blur + resample_padding 301 | 302 | new_t_vals = math.sorted_piecewise_constant_pdf( 303 | key, 304 | t_vals, 305 | weights, 306 | t_vals.shape[-1], 307 | randomized, 308 | ) 309 | if stop_grad: 310 | new_t_vals = lax.stop_gradient(new_t_vals) 311 | means, covs = cast_rays(new_t_vals, origins, directions, radii, ray_shape) 312 | return new_t_vals, (means, covs) 313 | -------------------------------------------------------------------------------- /internal/mip_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Unit tests for mip.""" 17 | from absl.testing import absltest 18 | import jax 19 | from jax import random 20 | from jax import test_util as jtu 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | from internal import math 25 | from internal import mip 26 | 27 | 28 | def surface_stats(points): 29 | """Get the sample mean and covariance matrix of a set of matrices [..., d].""" 30 | means = jnp.mean(points, -1) 31 | centered = points - means[..., None] 32 | covs = jnp.mean(centered[..., None, :, :] * centered[..., :, None, :], -1) 33 | return means, covs 34 | 35 | 36 | def sqrtm(mat): 37 | """Take the matrix square root of a PSD matrix [..., d, d].""" 38 | eigval, eigvec = jax.scipy.linalg.eigh(mat) 39 | scaling = jnp.sqrt(jnp.maximum(0., eigval))[..., None, :] 40 | return math.matmul(eigvec * scaling, jnp.moveaxis(eigvec, -2, -1)) 41 | 42 | 43 | def control_points(mean, cov): 44 | """Construct "sigma points" using a matrix sqrt (Cholesky or SVD are fine).""" 45 | sqrtm_cov = sqrtm(cov) # or could be jax.scipy.linalg.cholesky(cov) 46 | offsets = jnp.sqrt(mean.shape[-1] + 0.5) * jnp.concatenate( 47 | [jnp.zeros_like(mean[..., None]), sqrtm_cov, -sqrtm_cov], -1) 48 | return mean[..., None] + offsets 49 | 50 | 51 | def inside_conical_frustum(x, d, t0, t1, r, ttol=1e-6, rtol=1e-6): 52 | """Test if `x` is inside the conical frustum specified by the other inputs.""" 53 | d_normsq = jnp.sum(d**2) 54 | d_norm = jnp.sqrt(d_normsq) 55 | x_normsq = jnp.sum(x**2, -1) 56 | x_norm = 
jnp.sqrt(x_normsq) 57 | xd = math.matmul(x, d) 58 | is_inside = ( 59 | (t0 - ttol) <= xd / d_normsq) & (xd / d_normsq <= (t1 + ttol)) & ( 60 | (xd / (d_norm * x_norm)) >= 61 | (1 / jnp.sqrt(1 + r**2 / d_normsq) - rtol)) 62 | return is_inside 63 | 64 | 65 | def stable_pos_enc(x, n): 66 | """A stable posenc for very high degrees, courtesy of Sameer Agrawal.""" 67 | sin_x = np.sin(x) 68 | cos_x = np.cos(x) 69 | output = [] 70 | rotmat = np.array([[cos_x, -sin_x], [sin_x, cos_x]], dtype='double') 71 | for _ in range(n): 72 | output.append(rotmat[::-1, 0, :]) 73 | rotmat = np.einsum('ijn,jkn->ikn', rotmat, rotmat) 74 | return np.reshape(np.transpose(np.stack(output, 0), [2, 1, 0]), [-1, 2 * n]) 75 | 76 | 77 | def sample_conical_frustum(rng, num_samples, d, t0, t1, base_radius): 78 | """Draw random samples from a conical frustum. 79 | 80 | Args: 81 | rng: The RNG seed. 82 | num_samples: int, the number of samples to draw. 83 | d: jnp.float32 3-vector, the axis of the cone. 84 | t0: float, the starting distance of the frustum. 85 | t1: float, the ending distance of the frustum. 86 | base_radius: float, the scale of the radius as a function of distance. 87 | 88 | Returns: 89 | A matrix of samples. 
90 | """ 91 | key, rng = random.split(rng) 92 | u = random.uniform(key, shape=[num_samples]) 93 | t = (t0**3 * (1 - u) + t1**3 * u)**(1 / 3) 94 | key, rng = random.split(rng) 95 | theta = random.uniform(key, shape=[num_samples], minval=0, maxval=jnp.pi * 2) 96 | key, rng = random.split(rng) 97 | r = base_radius * t * jnp.sqrt(random.uniform(key, shape=[num_samples])) 98 | 99 | d_norm = d / jnp.linalg.norm(d) 100 | null = jnp.eye(3) - d_norm[:, None] * d_norm[None, :] 101 | basis = jnp.linalg.svd(null)[0][:, :2] 102 | rot_samples = ((basis[:, 0:1] * r * jnp.cos(theta)) + 103 | (basis[:, 1:2] * r * jnp.sin(theta)) + d[:, None] * t).T 104 | return rot_samples 105 | 106 | 107 | def generate_random_cylinder(rng, num_zs=4): 108 | t0, t1 = [], [] 109 | for _ in range(num_zs): 110 | rng, key = random.split(rng) 111 | z_mean = random.uniform(key, minval=1.5, maxval=3) 112 | rng, key = random.split(rng) 113 | z_delta = random.uniform(key, minval=0.1, maxval=.3) 114 | t0.append(z_mean - z_delta) 115 | t1.append(z_mean + z_delta) 116 | t0 = jnp.array(t0) 117 | t1 = jnp.array(t1) 118 | 119 | rng, key = random.split(rng) 120 | radius = random.uniform(key, minval=0.1, maxval=.2) 121 | 122 | rng, key = random.split(rng) 123 | raydir = random.normal(key, [3]) 124 | raydir = raydir / jnp.sqrt(jnp.sum(raydir**2, -1)) 125 | 126 | rng, key = random.split(rng) 127 | scale = random.uniform(key, minval=0.4, maxval=1.2) 128 | raydir = scale * raydir 129 | 130 | return raydir, t0, t1, radius 131 | 132 | 133 | def generate_random_conical_frustum(rng, num_zs=4): 134 | t0, t1 = [], [] 135 | for _ in range(num_zs): 136 | rng, key = random.split(rng) 137 | z_mean = random.uniform(key, minval=1.5, maxval=3) 138 | rng, key = random.split(rng) 139 | z_delta = random.uniform(key, minval=0.1, maxval=.3) 140 | t0.append(z_mean - z_delta) 141 | t1.append(z_mean + z_delta) 142 | t0 = jnp.array(t0) 143 | t1 = jnp.array(t1) 144 | 145 | rng, key = random.split(rng) 146 | r = random.uniform(key, 
minval=0.01, maxval=.05) 147 | 148 | rng, key = random.split(rng) 149 | raydir = random.normal(key, [3]) 150 | raydir = raydir / jnp.sqrt(jnp.sum(raydir**2, -1)) 151 | 152 | rng, key = random.split(rng) 153 | scale = random.uniform(key, minval=0.8, maxval=1.2) 154 | raydir = scale * raydir 155 | 156 | return raydir, t0, t1, r 157 | 158 | 159 | def cylinder_to_gaussian_sample(key, 160 | raydir, 161 | t0, 162 | t1, 163 | radius, 164 | padding=1, 165 | num_samples=1000000): 166 | # Sample uniformly from a cube that surrounds the entire cylinder. 167 | z_max = max(t0, t1) 168 | samples = random.uniform( 169 | key, [num_samples, 3], 170 | minval=jnp.min(raydir) * z_max - padding, 171 | maxval=jnp.max(raydir) * z_max + padding) 172 | 173 | # Grab only the points within the cylinder. 174 | raydir_magsq = jnp.sum(raydir**2, -1, keepdims=True) 175 | proj = (raydir * (samples @ raydir)[:, None]) / raydir_magsq 176 | dist = samples @ raydir 177 | mask = (dist >= raydir_magsq * t0) & (dist <= raydir_magsq * t1) & ( 178 | jnp.sum((proj - samples)**2, -1) < radius**2) 179 | samples = samples[mask, :] 180 | 181 | # Compute their mean and covariance. 182 | mean = jnp.mean(samples, 0) 183 | cov = jnp.cov(samples.T, bias=False) 184 | return mean, cov 185 | 186 | 187 | def conical_frustum_to_gaussian_sample(key, raydir, t0, t1, r): 188 | """A brute-force numerical approximation to conical_frustum_to_gaussian().""" 189 | # Draw samples directly from the interior of the conical frustum. 190 | samples = sample_conical_frustum(key, 100000, raydir, t0, t1, r) 191 | # Compute their mean and covariance.
192 | return surface_stats(samples.T) 193 | 194 | 195 | class MipUtilsTest(jtu.JaxTestCase): 196 | 197 | def test_posenc(self): 198 | n = 10 199 | x = np.linspace(-1, 1, 100) 200 | z = mip.pos_enc(x[:, None], 0, n, append_identity=False) 201 | z_stable = stable_pos_enc(x, n) 202 | self.assertLess(np.max(np.abs(z - z_stable)), 1e-4) 203 | 204 | def test_cylinder_scaling(self): 205 | d = jnp.array([0., 0., 1.]) 206 | t0 = jnp.array([0.3]) 207 | t1 = jnp.array([0.7]) 208 | radius = jnp.array([0.4]) 209 | mean, cov = mip.cylinder_to_gaussian( 210 | d, 211 | t0, 212 | t1, 213 | radius, 214 | False, 215 | ) 216 | scale = 2.7 217 | scaled_mean, scaled_cov = mip.cylinder_to_gaussian( 218 | scale * d, 219 | t0, 220 | t1, 221 | radius, 222 | False, 223 | ) 224 | self.assertAllClose(scale * mean, scaled_mean) 225 | self.assertAllClose(scale**2 * cov[2, 2], scaled_cov[2, 2]) 226 | control = control_points(mean, cov)[0] 227 | control_scaled = control_points(scaled_mean, scaled_cov)[0] 228 | self.assertAllClose(control[:2, :], control_scaled[:2, :]) 229 | self.assertAllClose(control[2, :] * scale, control_scaled[2, :]) 230 | 231 | def test_conical_frustum_scaling(self): 232 | d = jnp.array([0., 0., 1.]) 233 | t0 = jnp.array([0.3]) 234 | t1 = jnp.array([0.7]) 235 | radius = jnp.array([0.4]) 236 | mean, cov = mip.conical_frustum_to_gaussian( 237 | d, 238 | t0, 239 | t1, 240 | radius, 241 | False, 242 | ) 243 | scale = 2.7 244 | scaled_mean, scaled_cov = mip.conical_frustum_to_gaussian( 245 | scale * d, 246 | t0, 247 | t1, 248 | radius, 249 | False, 250 | ) 251 | self.assertAllClose(scale * mean, scaled_mean) 252 | self.assertAllClose(scale**2 * cov[2, 2], scaled_cov[2, 2]) 253 | control = control_points(mean, cov)[0] 254 | control_scaled = control_points(scaled_mean, scaled_cov)[0] 255 | self.assertAllClose(control[:2, :], control_scaled[:2, :]) 256 | self.assertAllClose(control[2, :] * scale, control_scaled[2, :]) 257 | 258 | def test_expected_sin(self): 259 | normal_samples = 
random.normal(random.PRNGKey(0), (10000,)) 260 | for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]: 261 | sin_mu, sin_var = mip.expected_sin(mu, var) 262 | x = jnp.sin(jnp.sqrt(var) * normal_samples + mu) 263 | self.assertAllClose(sin_mu, jnp.mean(x), atol=1e-2) 264 | self.assertAllClose(sin_var, jnp.var(x), atol=1e-2) 265 | 266 | def test_control_points(self): 267 | rng = random.PRNGKey(0) 268 | batch_size = 10 269 | for num_dims in [1, 2, 3]: 270 | key, rng = random.split(rng) 271 | mean = jax.random.normal(key, [batch_size, num_dims]) 272 | key, rng = random.split(rng) 273 | half_cov = jax.random.normal(key, [batch_size] + [num_dims] * 2) 274 | cov = half_cov @ jnp.moveaxis(half_cov, -1, -2) 275 | 276 | sqrtm_cov = sqrtm(cov) 277 | self.assertArraysAllClose(sqrtm_cov @ sqrtm_cov, cov, atol=1e-5) 278 | 279 | points = control_points(mean, cov) 280 | mean_recon, cov_recon = surface_stats(points) 281 | self.assertArraysAllClose(mean, mean_recon) 282 | self.assertArraysAllClose(cov, cov_recon, atol=1e-5) 283 | 284 | def test_conical_frustum(self): 285 | rng = random.PRNGKey(0) 286 | data = [] 287 | for _ in range(10): 288 | key, rng = random.split(rng) 289 | raydir, t0, t1, r = generate_random_conical_frustum(key) 290 | i_results = [] 291 | for i_t0, i_t1 in zip(t0, t1): 292 | key, rng = random.split(rng) 293 | i_results.append( 294 | conical_frustum_to_gaussian_sample(key, raydir, i_t0, i_t1, r)) 295 | mean_gt, cov_gt = [jnp.stack(x, 0) for x in zip(*i_results)] 296 | data.append((raydir, t0, t1, r, mean_gt, cov_gt)) 297 | raydir, t0, t1, r, mean_gt, cov_gt = [jnp.stack(x, 0) for x in zip(*data)] 298 | diag_cov_gt = jax.vmap(jax.vmap(jnp.diag))(cov_gt) 299 | for diag in [False, True]: 300 | for stable in [False, True]: 301 | mean, cov = mip.conical_frustum_to_gaussian( 302 | raydir, t0, t1, r[..., None], diag, stable=stable) 303 | self.assertAllClose(mean, mean_gt, atol=0.001) 304 | if diag: 305 | self.assertAllClose(cov, diag_cov_gt, atol=0.0002) 306 | else: 307 
| self.assertAllClose(cov, cov_gt, atol=0.0002) 308 | 309 | def test_inside_conical_frustum(self): 310 | """This test only tests helper functions used by other tests.""" 311 | rng = random.PRNGKey(0) 312 | for _ in range(20): 313 | key, rng = random.split(rng) 314 | d, t0, t1, r = generate_random_conical_frustum(key, num_zs=1) 315 | key, rng = random.split(rng) 316 | # Sample some points. 317 | samples = sample_conical_frustum(key, 1000000, d, t0, t1, r) 318 | # Check that they're all inside. 319 | check = lambda x: inside_conical_frustum(x, d, t0, t1, r) 320 | self.assertTrue(jnp.all(check(samples))) 321 | # Check that wiggling them a little puts some outside (potentially flaky). 322 | self.assertFalse(jnp.all(check(samples + 1e-3))) 323 | self.assertFalse(jnp.all(check(samples - 1e-3))) 324 | 325 | def test_conical_frustum_stable(self): 326 | rng = random.PRNGKey(0) 327 | for _ in range(10): 328 | key, rng = random.split(rng) 329 | d, t0, t1, r = generate_random_conical_frustum(key) 330 | for diag in [False, True]: 331 | mean, cov = mip.conical_frustum_to_gaussian( 332 | d, t0, t1, r, diag, stable=False) 333 | mean_stable, cov_stable = mip.conical_frustum_to_gaussian( 334 | d, t0, t1, r, diag, stable=True) 335 | self.assertAllClose(mean, mean_stable, atol=1e-7) 336 | self.assertAllClose(cov, cov_stable, atol=1e-5) 337 | 338 | def test_cylinder(self): 339 | rng = random.PRNGKey(0) 340 | data = [] 341 | for _ in range(10): 342 | key, rng = random.split(rng) 343 | raydir, t0, t1, radius = generate_random_cylinder(rng) 344 | key, rng = random.split(rng) 345 | i_results = [] 346 | for i_t0, i_t1 in zip(t0, t1): 347 | i_results.append( 348 | cylinder_to_gaussian_sample(key, raydir, i_t0, i_t1, radius)) 349 | mean_gt, cov_gt = [jnp.stack(x, 0) for x in zip(*i_results)] 350 | data.append((raydir, t0, t1, radius, mean_gt, cov_gt)) 351 | raydir, t0, t1, radius, mean_gt, cov_gt = [ 352 | jnp.stack(x, 0) for x in zip(*data) 353 | ] 354 | mean, cov = 
mip.cylinder_to_gaussian(raydir, t0, t1, radius[..., None], 355 | False) 356 | self.assertAllClose(mean, mean_gt, atol=0.1) 357 | self.assertAllClose(cov, cov_gt, atol=0.01) 358 | 359 | def test_integrated_pos_enc(self): 360 | num_dims = 2 # The number of input dimensions. 361 | min_deg = 0 362 | max_deg = 4 363 | num_samples = 100000 364 | rng = random.PRNGKey(0) 365 | for _ in range(5): 366 | # Generate a coordinate's mean and covariance matrix. 367 | key, rng = random.split(rng) 368 | mean = random.normal(key, (2,)) 369 | key, rng = random.split(rng) 370 | half_cov = jax.random.normal(key, [num_dims] * 2) 371 | cov = half_cov @ half_cov.T 372 | for diag in [False, True]: 373 | # Generate an IPE. 374 | enc = mip.integrated_pos_enc( 375 | (mean, jnp.diag(cov) if diag else cov), 376 | min_deg, 377 | max_deg, 378 | diag, 379 | ) 380 | 381 | # Draw samples, encode them, and take their mean. 382 | key, rng = random.split(rng) 383 | samples = random.multivariate_normal(key, mean, cov, [num_samples]) 384 | enc_samples = mip.pos_enc( 385 | samples, min_deg, max_deg, append_identity=False) 386 | enc_gt = jnp.mean(enc_samples, 0) 387 | self.assertAllClose(enc, enc_gt, rtol=1e-2, atol=1e-2) 388 | 389 | def test_lift_gaussian_diag(self): 390 | dims, n, m = 3, 10, 4 391 | rng = random.PRNGKey(0) 392 | key, rng = random.split(rng) 393 | d = random.normal(key, [n, dims]) 394 | key, rng = random.split(rng) 395 | z_mean = random.normal(key, [n, m]) 396 | key, rng = random.split(rng) 397 | z_var = jnp.exp(random.normal(key, [n, m])) 398 | key, rng = random.split(rng) 399 | xy_var = jnp.exp(random.normal(key, [n, m])) 400 | mean, cov = mip.lift_gaussian(d, z_mean, z_var, xy_var, diag=False) 401 | mean_diag, cov_diag = mip.lift_gaussian(d, z_mean, z_var, xy_var, diag=True) 402 | self.assertAllClose(mean, mean_diag) 403 | self.assertAllClose(jax.vmap(jax.vmap(jnp.diag))(cov), cov_diag) 404 | 405 | def test_rotated_conic_frustums(self): 406 | # Test that conic frustum Gaussians are 
closed under rotation. 407 | diag = False 408 | rng = random.PRNGKey(0) 409 | for _ in range(10): 410 | rng, key = random.split(rng) 411 | z_mean = random.uniform(key, minval=1.5, maxval=3) 412 | rng, key = random.split(rng) 413 | z_delta = random.uniform(key, minval=0.1, maxval=.3) 414 | t0 = jnp.array(z_mean - z_delta) 415 | t1 = jnp.array(z_mean + z_delta) 416 | 417 | rng, key = random.split(rng) 418 | r = random.uniform(key, minval=0.1, maxval=.2) 419 | 420 | rng, key = random.split(rng) 421 | d = random.normal(key, [3]) 422 | 423 | mean, cov = mip.conical_frustum_to_gaussian(d, t0, t1, r, diag) 424 | 425 | # Make a random rotation matrix. 426 | rng, key = random.split(rng) 427 | x = random.normal(key, [10, 3]) 428 | rot_mat = x.T @ x 429 | u, _, v = jnp.linalg.svd(rot_mat) 430 | rot_mat = u @ v.T 431 | 432 | mean, cov = mip.conical_frustum_to_gaussian(d, t0, t1, r, diag) 433 | rot_mean, rot_cov = mip.conical_frustum_to_gaussian( 434 | rot_mat @ d, t0, t1, r, diag) 435 | gt_rot_mean, gt_rot_cov = surface_stats( 436 | rot_mat @ control_points(mean, cov)) 437 | 438 | self.assertAllClose(rot_mean, gt_rot_mean) 439 | self.assertAllClose(rot_cov, gt_rot_cov) 440 | 441 | 442 | if __name__ == '__main__': 443 | absltest.main(testLoader=jtu.JaxTestLoader()) 444 | -------------------------------------------------------------------------------- /internal/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Different model implementation plus a general port for all the models.""" 17 | import functools 18 | from typing import Any, Callable 19 | from flax import linen as nn 20 | import gin 21 | import jax 22 | from jax import random 23 | import jax.numpy as jnp 24 | 25 | from internal import mip 26 | from internal import utils 27 | 28 | 29 | @gin.configurable 30 | class MipNerfModel(nn.Module): 31 | """Nerf NN Model with both coarse and fine MLPs.""" 32 | num_samples: int = 128 # The number of samples per level. 33 | num_levels: int = 2 # The number of sampling levels. 34 | resample_padding: float = 0.01 # Dirichlet/alpha "padding" on the histogram. 35 | stop_level_grad: bool = True # If True, don't backprop across levels. 36 | use_viewdirs: bool = True # If True, use view directions as a condition. 37 | lindisp: bool = False # If True, sample linearly in disparity, not in depth. 38 | ray_shape: str = 'cone' # The shape of cast rays ('cone' or 'cylinder'). 39 | min_deg_point: int = 0 # Min degree of positional encoding for 3D points. 40 | max_deg_point: int = 16 # Max degree of positional encoding for 3D points. 41 | deg_view: int = 4 # Degree of positional encoding for viewdirs. 42 | density_activation: Callable[..., Any] = nn.softplus # Density activation. 43 | density_noise: float = 0. # Standard deviation of noise added to raw density. 44 | density_bias: float = -1. # The shift added to raw densities pre-activation. 45 | rgb_activation: Callable[..., Any] = nn.sigmoid # The RGB activation. 46 | rgb_padding: float = 0.001 # Padding added to the RGB outputs. 47 | disable_integration: bool = False # If True, use PE instead of IPE. 48 | 49 | @nn.compact 50 | def __call__(self, rng, rays, randomized, white_bkgd): 51 | """The mip-NeRF Model. 52 | 53 | Args: 54 | rng: jnp.ndarray, random number generator. 
55 | rays: utils.Rays, a namedtuple of ray origins, directions, and viewdirs. 56 | randomized: bool, use randomized stratified sampling. 57 | white_bkgd: bool, if True, use white as the background (black o.w.). 58 | 59 | Returns: 60 | ret: list, [*(rgb, distance, acc)] 61 | """ 62 | # Construct the MLP. 63 | mlp = MLP() 64 | 65 | ret = [] 66 | for i_level in range(self.num_levels): 67 | key, rng = random.split(rng) 68 | if i_level == 0: 69 | # Stratified sampling along rays 70 | t_vals, samples = mip.sample_along_rays( 71 | key, 72 | rays.origins, 73 | rays.directions, 74 | rays.radii, 75 | self.num_samples, 76 | rays.near, 77 | rays.far, 78 | randomized, 79 | self.lindisp, 80 | self.ray_shape, 81 | ) 82 | else: 83 | t_vals, samples = mip.resample_along_rays( 84 | key, 85 | rays.origins, 86 | rays.directions, 87 | rays.radii, 88 | t_vals, 89 | weights, 90 | randomized, 91 | self.ray_shape, 92 | self.stop_level_grad, 93 | resample_padding=self.resample_padding, 94 | ) 95 | if self.disable_integration: 96 | samples = (samples[0], jnp.zeros_like(samples[1])) 97 | samples_enc = mip.integrated_pos_enc( 98 | samples, 99 | self.min_deg_point, 100 | self.max_deg_point, 101 | ) 102 | 103 | # Point attribute predictions 104 | if self.use_viewdirs: 105 | viewdirs_enc = mip.pos_enc( 106 | rays.viewdirs, 107 | min_deg=0, 108 | max_deg=self.deg_view, 109 | append_identity=True, 110 | ) 111 | raw_rgb, raw_density = mlp(samples_enc, viewdirs_enc) 112 | else: 113 | raw_rgb, raw_density = mlp(samples_enc) 114 | 115 | # Add noise to regularize the density predictions if needed. 116 | if randomized and (self.density_noise > 0): 117 | key, rng = random.split(rng) 118 | raw_density += self.density_noise * random.normal( 119 | key, raw_density.shape, dtype=raw_density.dtype) 120 | 121 | # Volumetric rendering. 
122 | rgb = self.rgb_activation(raw_rgb) 123 | rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding 124 | density = self.density_activation(raw_density + self.density_bias) 125 | comp_rgb, distance, acc, weights = mip.volumetric_rendering( 126 | rgb, 127 | density, 128 | t_vals, 129 | rays.directions, 130 | white_bkgd=white_bkgd, 131 | ) 132 | ret.append((comp_rgb, distance, acc)) 133 | 134 | return ret 135 | 136 | 137 | def construct_mipnerf(rng, example_batch): 138 | """Construct a Neural Radiance Field. 139 | 140 | Args: 141 | rng: jnp.ndarray. Random number generator. 142 | example_batch: dict, an example of a batch of data. 143 | 144 | Returns: 145 | model: nn.Model. Nerf model with parameters. 146 | state: flax.Module.state. Nerf model state for stateful parameters. 147 | """ 148 | model = MipNerfModel() 149 | key, rng = random.split(rng) 150 | init_variables = model.init( 151 | key, 152 | rng=rng, 153 | rays=utils.namedtuple_map(lambda x: x[0], example_batch['rays']), 154 | randomized=False, 155 | white_bkgd=False) 156 | return model, init_variables 157 | 158 | 159 | @gin.configurable 160 | class MLP(nn.Module): 161 | """A simple MLP.""" 162 | net_depth: int = 8 # The depth of the first part of MLP. 163 | net_width: int = 256 # The width of the first part of MLP. 164 | net_depth_condition: int = 1 # The depth of the second part of MLP. 165 | net_width_condition: int = 128 # The width of the second part of MLP. 166 | net_activation: Callable[..., Any] = nn.relu # The activation function. 167 | skip_layer: int = 4 # Add a skip connection to the output of every N layers. 168 | num_rgb_channels: int = 3 # The number of RGB channels. 169 | num_density_channels: int = 1 # The number of density channels. 170 | 171 | @nn.compact 172 | def __call__(self, x, condition=None): 173 | """Evaluate the MLP. 174 | 175 | Args: 176 | x: jnp.ndarray(float32), [batch, num_samples, feature], points. 
177 | condition: jnp.ndarray(float32), [batch, feature], if not None, this 178 | variable will be part of the input to the second part of the MLP 179 | concatenated with the output vector of the first part of the MLP. If 180 | None, only the first part of the MLP will be used with input x. In the 181 | original paper, this variable is the view direction. 182 | 183 | Returns: 184 | raw_rgb: jnp.ndarray(float32), with a shape of 185 | [batch, num_samples, num_rgb_channels]. 186 | raw_density: jnp.ndarray(float32), with a shape of 187 | [batch, num_samples, num_density_channels]. 188 | """ 189 | feature_dim = x.shape[-1] 190 | num_samples = x.shape[1] 191 | x = x.reshape([-1, feature_dim]) 192 | dense_layer = functools.partial( 193 | nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform()) 194 | inputs = x 195 | for i in range(self.net_depth): 196 | x = dense_layer(self.net_width)(x) 197 | x = self.net_activation(x) 198 | if i % self.skip_layer == 0 and i > 0: 199 | x = jnp.concatenate([x, inputs], axis=-1) 200 | raw_density = dense_layer(self.num_density_channels)(x).reshape( 201 | [-1, num_samples, self.num_density_channels]) 202 | 203 | if condition is not None: 204 | # Output of the first part of MLP. 205 | bottleneck = dense_layer(self.net_width)(x) 206 | # Broadcast condition from [batch, feature] to 207 | # [batch, num_samples, feature] since all the samples along the same ray 208 | # have the same viewdir. 209 | condition = jnp.tile(condition[:, None, :], (1, num_samples, 1)) 210 | # Collapse the [batch, num_samples, feature] tensor to 211 | # [batch * num_samples, feature] so that it can be fed into nn.Dense. 212 | condition = condition.reshape([-1, condition.shape[-1]]) 213 | x = jnp.concatenate([bottleneck, condition], axis=-1) 214 | # Here use 1 extra layer to align with the original nerf model. 
215 | for i in range(self.net_depth_condition): 216 | x = dense_layer(self.net_width_condition)(x) 217 | x = self.net_activation(x) 218 | raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape( 219 | [-1, num_samples, self.num_rgb_channels]) 220 | return raw_rgb, raw_density 221 | 222 | 223 | def render_image(render_fn, rays, rng, chunk=8192): 224 | """Render all the pixels of an image (in test mode). 225 | 226 | Args: 227 | render_fn: function, jit-ed render function. 228 | rays: a `Rays` namedtuple, the rays to be rendered. 229 | rng: jnp.ndarray, random number generator (used in training mode only). 230 | chunk: int, the size of chunks to render sequentially. 231 | 232 | Returns: 233 | rgb: jnp.ndarray, rendered color image. 234 | distance: jnp.ndarray, rendered distance image. 235 | acc: jnp.ndarray, rendered accumulated weights per pixel. 236 | """ 237 | height, width = rays[0].shape[:2] 238 | num_rays = height * width 239 | rays = utils.namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays) 240 | 241 | host_id = jax.host_id() 242 | results = [] 243 | for i in range(0, num_rays, chunk): 244 | # pylint: disable=cell-var-from-loop 245 | chunk_rays = utils.namedtuple_map(lambda r: r[i:i + chunk], rays) 246 | chunk_size = chunk_rays[0].shape[0] 247 | rays_remaining = chunk_size % jax.device_count() 248 | if rays_remaining != 0: 249 | padding = jax.device_count() - rays_remaining 250 | chunk_rays = utils.namedtuple_map( 251 | lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode='edge'), chunk_rays) 252 | else: 253 | padding = 0 254 | # After padding the number of chunk_rays is always divisible by 255 | # host_count. 
256 | rays_per_host = chunk_rays[0].shape[0] // jax.host_count() 257 | start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host 258 | chunk_rays = utils.namedtuple_map(lambda r: utils.shard(r[start:stop]), 259 | chunk_rays) 260 | chunk_results = render_fn(rng, chunk_rays)[-1] 261 | results.append([utils.unshard(x[0], padding) for x in chunk_results]) 262 | # pylint: enable=cell-var-from-loop 263 | rgb, distance, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)] 264 | rgb = rgb.reshape((height, width, -1)) 265 | distance = distance.reshape((height, width)) 266 | acc = acc.reshape((height, width)) 267 | return (rgb, distance, acc) 268 | -------------------------------------------------------------------------------- /internal/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Utility functions.""" 17 | import collections 18 | import os 19 | from os import path 20 | from absl import flags 21 | import dataclasses 22 | import flax 23 | import gin 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | from PIL import Image 28 | 29 | gin.add_config_file_search_path('../') 30 | 31 | 32 | gin.config.external_configurable(flax.nn.relu, module='flax.nn') 33 | gin.config.external_configurable(flax.nn.sigmoid, module='flax.nn') 34 | gin.config.external_configurable(flax.nn.softplus, module='flax.nn') 35 | 36 | 37 | @flax.struct.dataclass 38 | class TrainState: 39 | optimizer: flax.optim.Optimizer 40 | 41 | 42 | @flax.struct.dataclass 43 | class Stats: 44 | loss: float 45 | losses: float 46 | weight_l2: float 47 | psnr: float 48 | psnrs: float 49 | grad_norm: float 50 | grad_abs_max: float 51 | grad_norm_clipped: float 52 | 53 | 54 | Rays = collections.namedtuple( 55 | 'Rays', 56 | ('origins', 'directions', 'viewdirs', 'radii', 'lossmult', 'near', 'far')) 57 | 58 | 59 | # TODO(barron): Do a default.gin thing 60 | @gin.configurable() 61 | @dataclasses.dataclass 62 | class Config: 63 | """Configuration flags for everything.""" 64 | dataset_loader: str = 'multicam' # The type of dataset loader to use. 65 | batching: str = 'all_images' # Batch composition, [single_image, all_images]. 66 | batch_size: int = 4096 # The number of rays/pixels in each batch. 67 | factor: int = 0 # The downsample factor of images, 0 for no downsampling. 68 | spherify: bool = False # Set to True for spherical 360 scenes. 69 | render_path: bool = False # If True, render a path. Used only by LLFF. 70 | llffhold: int = 8 # Use every Nth image for the test set. Used only by LLFF. 71 | lr_init: float = 5e-4 # The initial learning rate. 72 | lr_final: float = 5e-6 # The final learning rate. 73 | lr_delay_steps: int = 2500 # The number of "warmup" learning steps. 
74 | lr_delay_mult: float = 0.01 # How severe the "warmup" should be. 75 | grad_max_norm: float = 0. # Gradient clipping magnitude, disabled if == 0. 76 | grad_max_val: float = 0. # Gradient clipping value, disabled if == 0. 77 | max_steps: int = 1000000 # The number of optimization steps. 78 | save_every: int = 100000 # The number of steps between checkpoint saves. 79 | print_every: int = 100 # The number of steps between reports to tensorboard. 80 | gc_every: int = 10000 # The number of steps between garbage collections. 81 | test_render_interval: int = 1 # The interval between images saved to disk. 82 | disable_multiscale_loss: bool = False # If True, disable multiscale loss. 83 | randomized: bool = True # Use randomized stratified sampling. 84 | near: float = 2. # Near plane distance. 85 | far: float = 6. # Far plane distance. 86 | coarse_loss_mult: float = 0.1 # How much to downweight the coarse loss(es). 87 | weight_decay_mult: float = 0. # The multiplier on weight decay. 88 | white_bkgd: bool = True # If True, use white as the background (black o.w.). 
89 | 90 | 91 | def define_common_flags(): 92 | # Define the flags used by both train.py and eval.py 93 | flags.DEFINE_multi_string('gin_file', None, 94 | 'List of paths to the config files.') 95 | flags.DEFINE_multi_string( 96 | 'gin_param', None, 'Newline separated list of Gin parameter bindings.') 97 | flags.DEFINE_string('train_dir', None, 'where to store ckpts and logs') 98 | flags.DEFINE_string('data_dir', None, 'input data directory.') 99 | flags.DEFINE_integer( 100 | 'chunk', 8192, 101 | 'the size of chunks for evaluation inferences, set to the value that ' 102 | 'fits your GPU/TPU memory.') 103 | 104 | 105 | def load_config(): 106 | gin.parse_config_files_and_bindings(flags.FLAGS.gin_file, 107 | flags.FLAGS.gin_param) 108 | return Config() 109 | 110 | 111 | def open_file(pth, mode='r'): 112 | return open(pth, mode=mode) 113 | 114 | 115 | def file_exists(pth): 116 | return path.exists(pth) 117 | 118 | 119 | def listdir(pth): 120 | return os.listdir(pth) 121 | 122 | 123 | def isdir(pth): 124 | return path.isdir(pth) 125 | 126 | 127 | def makedirs(pth): 128 | os.makedirs(pth) 129 | 130 | 131 | def namedtuple_map(fn, tup): 132 | """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple.""" 133 | return type(tup)(*map(fn, tup)) 134 | 135 | 136 | def shard(xs): 137 | """Split data into shards for multiple devices along the first dimension.""" 138 | return jax.tree_map( 139 | lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs) 140 | 141 | 142 | def to_device(xs): 143 | """Transfer data to devices (GPU/TPU).""" 144 | return jax.tree_map(jnp.array, xs) 145 | 146 | 147 | def unshard(x, padding=0): 148 | """Collect the sharded tensor to the shape before sharding.""" 149 | y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) 150 | if padding > 0: 151 | y = y[:-padding] 152 | return y 153 | 154 | 155 | def save_img_uint8(img, pth): 156 | """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" 157 | with open_file(pth, 
'wb') as f: 158 | Image.fromarray( 159 | (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(jnp.uint8)).save( 160 | f, 'PNG') 161 | 162 | 163 | def save_img_float32(depthmap, pth): 164 | """Save an image (probably a depthmap) to disk as a float32 TIFF.""" 165 | with open_file(pth, 'wb') as f: 166 | Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF') 167 | -------------------------------------------------------------------------------- /internal/vis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Helper functions for visualizing things.""" 17 | import jax 18 | import jax.numpy as jnp 19 | import jax.scipy as jsp 20 | import matplotlib.cm as cm 21 | 22 | 23 | def sinebow(h): 24 | """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" 25 | f = lambda x: jnp.sin(jnp.pi * x)**2 26 | return jnp.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) 27 | 28 | 29 | def convolve2d(z, f): 30 | return jsp.signal.convolve2d( 31 | z, f, mode='same', precision=jax.lax.Precision.HIGHEST) 32 | 33 | 34 | def depth_to_normals(depth): 35 | """Assuming `depth` is orthographic, linearize it to a set of normals.""" 36 | f_blur = jnp.array([1, 2, 1]) / 4 37 | f_edge = jnp.array([-1, 0, 1]) / 2 38 | dy = convolve2d(depth, f_blur[None, :] * f_edge[:, None]) 39 | dx = convolve2d(depth, f_blur[:, None] * f_edge[None, :]) 40 | inv_denom = 1 / jnp.sqrt(1 + dx**2 + dy**2) 41 | normals = jnp.stack([dx * inv_denom, dy * inv_denom, inv_denom], -1) 42 | return normals 43 | 44 | 45 | def visualize_depth(depth, 46 | acc=None, 47 | near=None, 48 | far=None, 49 | ignore_frac=0, 50 | curve_fn=lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps), 51 | modulus=0, 52 | colormap=None): 53 | """Visualize a depth map. 54 | 55 | Args: 56 | depth: A depth map. 57 | acc: An accumulation map, in [0, 1]. 58 | near: The depth of the near plane, if None then just use the min(). 59 | far: The depth of the far plane, if None then just use the max(). 60 | ignore_frac: What fraction of the depth map to ignore when automatically 61 | generating `near` and `far`. Depends on `acc` as well as `depth`. 62 | curve_fn: A curve function that gets applied to `depth`, `near`, and `far` 63 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 64 | Note that the default choice will flip the sign of depths, so that the 65 | default colormap (turbo) renders "near" as red and "far" as blue. 
66 | modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1]. 67 | colormap: A colormap function. If None (default), will be set to 68 | matplotlib's turbo if modulus==0, sinebow otherwise. 69 | 70 | Returns: 71 | An RGB visualization of `depth`. 72 | """ 73 | if acc is None: 74 | acc = jnp.ones_like(depth) 75 | acc = jnp.where(jnp.isnan(depth), jnp.zeros_like(acc), acc) 76 | 77 | # Sort `depth` and `acc` according to `depth`, then identify the depth values 78 | # that span the middle of `acc`, ignoring `ignore_frac` fraction of `acc`. 79 | sortidx = jnp.argsort(depth.reshape([-1])) 80 | depth_sorted = depth.reshape([-1])[sortidx] 81 | acc_sorted = acc.reshape([-1])[sortidx] 82 | cum_acc_sorted = jnp.cumsum(acc_sorted) 83 | mask = ((cum_acc_sorted >= cum_acc_sorted[-1] * ignore_frac) & 84 | (cum_acc_sorted <= cum_acc_sorted[-1] * (1 - ignore_frac))) 85 | depth_keep = depth_sorted[mask] 86 | 87 | # If `near` or `far` are None, use the highest and lowest non-NaN values in 88 | # `depth_keep` as automatic near/far planes. 89 | eps = jnp.finfo(jnp.float32).eps 90 | near = near or depth_keep[0] - eps 91 | far = far or depth_keep[-1] + eps 92 | 93 | # Curve all values. 94 | depth, near, far = [curve_fn(x) for x in [depth, near, far]] 95 | 96 | # Wrap the values around if requested. 97 | if modulus > 0: 98 | value = jnp.mod(depth, modulus) / modulus 99 | colormap = colormap or sinebow 100 | else: 101 | # Scale to [0, 1]. 102 | value = jnp.nan_to_num( 103 | jnp.clip((depth - jnp.minimum(near, far)) / jnp.abs(far - near), 0, 1)) 104 | colormap = colormap or cm.get_cmap('turbo') 105 | 106 | vis = colormap(value)[:, :, :3] 107 | 108 | # Set non-accumulated pixels to white. 
109 | vis = vis * acc[:, :, None] + (1 - acc)[:, :, None] 110 | 111 | return vis 112 | 113 | 114 | def visualize_normals(depth, acc, scaling=None): 115 | """Visualize fake normals of `depth` (optionally scaled to be isotropic).""" 116 | if scaling is None: 117 | mask = ~jnp.isnan(depth) 118 | x, y = jnp.meshgrid( 119 | jnp.arange(depth.shape[1]), jnp.arange(depth.shape[0]), indexing='xy') 120 | xy_var = (jnp.var(x[mask]) + jnp.var(y[mask])) / 2 121 | z_var = jnp.var(depth[mask]) 122 | scaling = jnp.sqrt(xy_var / z_var) 123 | 124 | scaled_depth = scaling * depth 125 | normals = depth_to_normals(scaled_depth) 126 | vis = jnp.isnan(normals) + jnp.nan_to_num((normals + 1) / 2, 0) 127 | 128 | # Set non-accumulated pixels to white. 129 | if acc is not None: 130 | vis = vis * acc[:, :, None] + (1 - acc)[:, :, None] 131 | 132 | return vis 133 | 134 | 135 | def visualize_suite(depth, acc): 136 | """A wrapper around other visualizations for easy integration.""" 137 | vis = { 138 | 'depth': visualize_depth(depth, acc), 139 | 'depth_mod': visualize_depth(depth, acc, modulus=0.1), 140 | 'depth_normals': visualize_normals(depth, acc) 141 | } 142 | return vis 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.4,<1.19.0 2 | jax>=0.2.12 3 | jaxlib>=0.1.65 4 | flax>=0.2.2 5 | opencv-python>=4.4.0 6 | Pillow>=7.2.0 7 | tensorboard>=2.4.0 8 | tensorflow>=2.3.1 9 | gin-config 10 | -------------------------------------------------------------------------------- /scripts/convert_blender_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from os import path 4 | 5 | from absl import app 6 | from absl import flags 7 | import jax 8 | import numpy as np 9 | from PIL import Image 10 | 11 | FLAGS = flags.FLAGS 12 | 13 | flags.DEFINE_string('blenderdir', None, 14 | 'Base 
directory for all Blender data.') 15 | flags.DEFINE_string('outdir', None, 16 | 'Where to save multiscale data.') 17 | flags.DEFINE_integer('n_down', 4, 18 | 'How many levels of downscaling to use.') 19 | 20 | jax.config.parse_flags_with_absl() 21 | 22 | 23 | def load_renderings(data_dir, split): 24 | """Load images and metadata from disk.""" 25 | f = 'transforms_{}.json'.format(split) 26 | with open(path.join(data_dir, f), 'r') as fp: 27 | meta = json.load(fp) 28 | images = [] 29 | cams = [] 30 | print('Loading imgs') 31 | for frame in meta['frames']: 32 | fname = os.path.join(data_dir, frame['file_path'] + '.png') 33 | with open(fname, 'rb') as imgin: 34 | image = np.array(Image.open(imgin), dtype=np.float32) / 255. 35 | cams.append(frame['transform_matrix']) 36 | images.append(image) 37 | ret = {} 38 | ret['images'] = np.stack(images, axis=0) 39 | print('Loaded all images, shape is', ret['images'].shape) 40 | ret['camtoworlds'] = np.stack(cams, axis=0) 41 | w = ret['images'].shape[2] 42 | camera_angle_x = float(meta['camera_angle_x']) 43 | ret['focal'] = .5 * w / np.tan(.5 * camera_angle_x) 44 | return ret 45 | 46 | 47 | def down2(img): 48 | sh = img.shape 49 | return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3)) 50 | 51 | 52 | def convert_to_nerfdata(basedir, newdir, n_down): 53 | """Convert Blender data to multiscale.""" 54 | if not os.path.exists(newdir): 55 | os.makedirs(newdir) 56 | splits = ['train', 'val', 'test'] 57 | bigmeta = {} 58 | # Foreach split in the dataset 59 | for split in splits: 60 | print('Split', split) 61 | # Load everything 62 | data = load_renderings(basedir, split) 63 | 64 | # Save out all the images 65 | imgdir = 'images_{}'.format(split) 66 | os.makedirs(os.path.join(newdir, imgdir), exist_ok=True) 67 | fnames = [] 68 | widths = [] 69 | heights = [] 70 | focals = [] 71 | cam2worlds = [] 72 | lossmults = [] 73 | labels = [] 74 | nears, fars = [], [] 75 | f = data['focal'] 76 | print('Saving images') 77 | for i, 
img in enumerate(data['images']): 78 | for j in range(n_down): 79 | fname = '{}/{:03d}_d{}.png'.format(imgdir, i, j) 80 | fnames.append(fname) 81 | fname = os.path.join(newdir, fname) 82 | with open(fname, 'wb') as imgout: 83 | img8 = Image.fromarray(np.uint8(img * 255)) 84 | img8.save(imgout) 85 | widths.append(img.shape[1]) 86 | heights.append(img.shape[0]) 87 | focals.append(f / 2**j) 88 | cam2worlds.append(data['camtoworlds'][i].tolist()) 89 | lossmults.append(4.**j) 90 | labels.append(j) 91 | nears.append(2.) 92 | fars.append(6.) 93 | img = down2(img) 94 | 95 | # Create metadata 96 | meta = {} 97 | meta['file_path'] = fnames 98 | meta['cam2world'] = cam2worlds 99 | meta['width'] = widths 100 | meta['height'] = heights 101 | meta['focal'] = focals 102 | meta['label'] = labels 103 | meta['near'] = nears 104 | meta['far'] = fars 105 | meta['lossmult'] = lossmults 106 | 107 | fx = np.array(focals) 108 | fy = np.array(focals) 109 | cx = np.array(meta['width']) * .5 110 | cy = np.array(meta['height']) * .5 111 | arr0 = np.zeros_like(cx) 112 | arr1 = np.ones_like(cx) 113 | k_inv = np.array([ 114 | [arr1 / fx, arr0, -cx / fx], 115 | [arr0, -arr1 / fy, cy / fy], 116 | [arr0, arr0, -arr1], 117 | ]) 118 | k_inv = np.moveaxis(k_inv, -1, 0) 119 | meta['pix2cam'] = k_inv.tolist() 120 | 121 | bigmeta[split] = meta 122 | 123 | for k in bigmeta: 124 | for j in bigmeta[k]: 125 | print(k, j, type(bigmeta[k][j]), np.array(bigmeta[k][j]).shape) 126 | 127 | jsonfile = os.path.join(newdir, 'metadata.json') 128 | with open(jsonfile, 'w') as f: 129 | json.dump(bigmeta, f, ensure_ascii=False, indent=4) 130 | 131 | 132 | def main(unused_argv): 133 | 134 | blenderdir = FLAGS.blenderdir 135 | outdir = FLAGS.outdir 136 | n_down = FLAGS.n_down 137 | if not os.path.exists(outdir): 138 | os.makedirs(outdir) 139 | 140 | dirs = [os.path.join(blenderdir, f) for f in os.listdir(blenderdir)] 141 | dirs = [d for d in dirs if os.path.isdir(d)] 142 | print(dirs) 143 | for basedir in dirs: 144 | 
print() 145 | newdir = os.path.join(outdir, os.path.basename(basedir)) 146 | print('Converting from', basedir, 'to', newdir) 147 | convert_to_nerfdata(basedir, newdir, n_down) 148 | 149 | 150 | if __name__ == '__main__': 151 | app.run(main) 152 | -------------------------------------------------------------------------------- /scripts/eval_blender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Script for evaluating on the Blender dataset. 17 | 18 | SCENE=lego 19 | EXPERIMENT=debug 20 | TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE 21 | DATA_DIR=/Users/barron/data/nerf_synthetic/$SCENE 22 | 23 | python -m eval \ 24 | --data_dir=$DATA_DIR \ 25 | --train_dir=$TRAIN_DIR \ 26 | --chunk=3076 \ 27 | --gin_file=configs/blender.gin \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/eval_llff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Script for evaluating on the LLFF dataset. 17 | 18 | SCENE=trex 19 | EXPERIMENT=debug 20 | TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE 21 | DATA_DIR=/Users/barron/data/nerf_llff_data/$SCENE 22 | 23 | python -m eval \ 24 | --data_dir=$DATA_DIR \ 25 | --train_dir=$TRAIN_DIR \ 26 | --chunk=3076 \ 27 | --gin_file=configs/llff.gin \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/eval_multiblender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Script for evaluating on the multiscale Blender dataset. 
17 | 18 | SCENE=lego 19 | EXPERIMENT=debug 20 | TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE 21 | DATA_DIR=/Users/barron/data/down4/$SCENE 22 | 23 | python -m eval \ 24 | --data_dir=$DATA_DIR \ 25 | --train_dir=$TRAIN_DIR \ 26 | --chunk=3076 \ 27 | --gin_file=configs/multiblender.gin \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/summarize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "yTbIyn5Hsob0" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# This script scrapes the results of training and generates numbers in the\n", 12 | "# format of Tables 1 and 2 in the paper (https://arxiv.org/abs/2103.13415).\n", 13 | "# Numbers are slightly different because of implementation details and\n", 14 | "# randomness across different runs of training." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "id": "wbSGA8PNKvIy" 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import os\n", 26 | "import numpy as np\n", 27 | "from google3.pyglib import gfile" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "id": "gKnbxt2AKz8w" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "blender_scenes = ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']\n", 39 | "\n", 40 | "def summarize_results(folder, scene_names, num_buckets):\n", 41 | " metric_names = ['psnrs', 'ssims', 'lpips']\n", 42 | " num_iters = 1000000\n", 43 | " precisions = [3, 4, 4, 4]\n", 44 | "\n", 45 | " results = []\n", 46 | " for scene_name in scene_names:\n", 47 | " test_preds_folder = os.path.join(folder, scene_name, 'test_preds')\n", 48 | " values = []\n", 49 | " for metric_name in metric_names:\n", 50 | " filename = os.path.join(folder, scene_name, 
'test_preds', f'{metric_name}_{num_iters}.txt')\n", 51 | " with gfile.Open(filename) as f:\n", 52 | " v = np.array([float(s) for s in f.readline().split(' ')])\n", 53 | " values.append(np.mean(np.reshape(v, [-1, num_buckets]), 0))\n", 54 | " results.append(np.concatenate(values))\n", 55 | " avg_results = np.mean(np.array(results), 0)\n", 56 | "\n", 57 | " psnr, ssim, lpips = np.mean(np.reshape(avg_results, [-1, num_buckets]), 1)\n", 58 | "\n", 59 | " mse = np.exp(-0.1 * np.log(10.) * psnr)\n", 60 | " dssim = np.sqrt(1 - ssim)\n", 61 | " avg_avg = np.exp(np.mean(np.log(np.array([mse, dssim, lpips]))))\n", 62 | "\n", 63 | " s = []\n", 64 | " for i, v in enumerate(np.reshape(avg_results, [-1, num_buckets])):\n", 65 | " s.append(' '.join([f'{s:0.{precisions[i]}f}' for s in v]))\n", 66 | " s.append(f'{avg_avg:0.{precisions[-1]}f}')\n", 67 | " return ' | '.join(s)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": { 74 | "executionInfo": { 75 | "elapsed": 10288, 76 | "status": "ok", 77 | "timestamp": 1619902906249, 78 | "user": { 79 | "displayName": "", 80 | "photoUrl": "", 81 | "userId": "" 82 | }, 83 | "user_tz": 420 84 | }, 85 | "id": "YxsRO0c5kC7F", 86 | "outputId": "92e7f5e7-ef8b-4694-a89d-321748934bd7" 87 | }, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "mip-NeRF: 32.634 34.350 35.505 35.636 | 0.9578 0.9703 0.9786 0.9834 | 0.0469 0.0260 0.0168 0.0120 | 0.0114\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "# The Multiscale Blender benchmark.\n", 99 | "# These numbers roughly correspond to the \"Mip-NeRF\" row of Table 1.\n", 100 | "print('mip-NeRF: ' + summarize_results('/cns/oz-d/home/barron/nerf/mipnerf/multiblender', blender_scenes, 4))" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "executionInfo": { 108 | "elapsed": 8608, 109 | "status": "ok", 110 | "timestamp": 1619902914863, 111 | "user": { 112 | 
"displayName": "", 113 | "photoUrl": "", 114 | "userId": "" 115 | }, 116 | "user_tz": 420 117 | }, 118 | "id": "00fXHEbBpNPR", 119 | "outputId": "8f31319d-b724-43b2-c27b-cf4d9cc99d0b" 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "mip-NeRF: 33.085 | 0.9605 | 0.0425 | 0.0161\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "# The (single-scale) Blender benchmark.\n", 132 | "# These numbers roughly correspond to the \"Mip-NeRF\" row of Table 2.\n", 133 | "print('mip-NeRF: ' + summarize_results('/cns/oz-d/home/barron/nerf/mipnerf/blender', blender_scenes, 1))" 134 | ] 135 | } 136 | ], 137 | "metadata": { 138 | "colab": { 139 | "collapsed_sections": [], 140 | "last_runtime": { 141 | "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", 142 | "kind": "private" 143 | }, 144 | "name": "summarize.ipynb", 145 | "provenance": [ 146 | { 147 | "file_id": "/piper/depot/google3/experimental/users/barron/prob_nerf/scripts/eval_multi.ipynb?workspaceId=barron:mipnerf_loglin::citc", 148 | "timestamp": 1618455902794 149 | }, 150 | { 151 | "file_id": "/piper/depot/google3/experimental/users/barron/prob_nerf/scripts/Pre_NeRF_Eval.ipynb?workspaceId=barron:jaxnerf_mono5::citc", 152 | "timestamp": 1614038274387 153 | }, 154 | { 155 | "file_id": "10opVizeODokMJ10R7hwq7qVyLmYZx_ZA", 156 | "timestamp": 1613166364224 157 | } 158 | ] 159 | }, 160 | "kernelspec": { 161 | "display_name": "Python 3", 162 | "name": "python3" 163 | } 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 0 167 | } 168 | -------------------------------------------------------------------------------- /scripts/train_blender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Script for training on the Blender dataset. 17 | 18 | SCENE=lego 19 | EXPERIMENT=debug 20 | TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE 21 | DATA_DIR=/Users/barron/data/nerf_synthetic/$SCENE 22 | 23 | rm $TRAIN_DIR/* 24 | python -m train \ 25 | --data_dir=$DATA_DIR \ 26 | --train_dir=$TRAIN_DIR \ 27 | --gin_file=configs/blender.gin \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/train_llff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Script for training on the LLFF dataset. 
17 | 18 | SCENE=trex 19 | EXPERIMENT=debug 20 | TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE 21 | DATA_DIR=/Users/barron/data/nerf_llff_data/$SCENE 22 | 23 | rm $TRAIN_DIR/* 24 | python -m train \ 25 | --data_dir=$DATA_DIR \ 26 | --train_dir=$TRAIN_DIR \ 27 | --gin_file=configs/llff.gin \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/train_multiblender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Script for training on the multiscale Blender dataset. 17 | 18 | SCENE=lego 19 | EXPERIMENT=debug 20 | TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE 21 | DATA_DIR=/Users/barron/data/down4/$SCENE 22 | 23 | rm $TRAIN_DIR/* 24 | python -m train \ 25 | --data_dir=$DATA_DIR \ 26 | --train_dir=$TRAIN_DIR \ 27 | --gin_file=configs/multiblender.gin \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Training script for Nerf.""" 17 | 18 | import functools 19 | import gc 20 | import time 21 | from absl import app 22 | from absl import flags 23 | import flax 24 | from flax.metrics import tensorboard 25 | from flax.training import checkpoints 26 | import jax 27 | from jax import random 28 | import jax.numpy as jnp 29 | import numpy as np 30 | 31 | from internal import datasets 32 | from internal import math 33 | from internal import models 34 | from internal import utils 35 | from internal import vis 36 | 37 | 38 | FLAGS = flags.FLAGS 39 | utils.define_common_flags() 40 | flags.DEFINE_integer('render_every', 5000, 41 | 'The number of steps between test set image renderings.') 42 | 43 | jax.config.parse_flags_with_absl() 44 | 45 | 46 | def train_step(model, config, rng, state, batch, lr): 47 | """One optimization step. 48 | 49 | Args: 50 | model: The linen model. 51 | config: The configuration. 52 | rng: jnp.ndarray, random number generator. 53 | state: utils.TrainState, state of the model/optimizer. 54 | batch: dict, a mini-batch of data for training. 55 | lr: float, real-time learning rate. 56 | 57 | Returns: 58 | new_state: utils.TrainState, new training state. 59 | stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)]. 60 | rng: jnp.ndarray, updated random number generator. 
61 | """ 62 | rng, key = random.split(rng) 63 | 64 | def loss_fn(variables): 65 | 66 | def tree_sum_fn(fn): 67 | return jax.tree_util.tree_reduce( 68 | lambda x, y: x + fn(y), variables, initializer=0) 69 | 70 | weight_l2 = config.weight_decay_mult * ( 71 | tree_sum_fn(lambda z: jnp.sum(z**2)) / 72 | tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape)))) 73 | 74 | ret = model.apply( 75 | variables, 76 | key, 77 | batch['rays'], 78 | randomized=config.randomized, 79 | white_bkgd=config.white_bkgd) 80 | 81 | mask = batch['rays'].lossmult 82 | if config.disable_multiscale_loss: 83 | mask = jnp.ones_like(mask) 84 | 85 | losses = [] 86 | for (rgb, _, _) in ret: 87 | losses.append( 88 | (mask * (rgb - batch['pixels'][..., :3])**2).sum() / mask.sum()) 89 | losses = jnp.array(losses) 90 | 91 | loss = ( 92 | config.coarse_loss_mult * jnp.sum(losses[:-1]) + losses[-1] + weight_l2) 93 | 94 | stats = utils.Stats( 95 | loss=loss, 96 | losses=losses, 97 | weight_l2=weight_l2, 98 | psnr=0.0, 99 | psnrs=0.0, 100 | grad_norm=0.0, 101 | grad_abs_max=0.0, 102 | grad_norm_clipped=0.0, 103 | ) 104 | return loss, stats 105 | 106 | (_, stats), grad = ( 107 | jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target)) 108 | grad = jax.lax.pmean(grad, axis_name='batch') 109 | stats = jax.lax.pmean(stats, axis_name='batch') 110 | 111 | def tree_norm(tree): 112 | return jnp.sqrt( 113 | jax.tree_util.tree_reduce( 114 | lambda x, y: x + jnp.sum(y**2), tree, initializer=0)) 115 | 116 | if config.grad_max_val > 0: 117 | clip_fn = lambda z: jnp.clip(z, -config.grad_max_val, config.grad_max_val) 118 | grad = jax.tree_util.tree_map(clip_fn, grad) 119 | 120 | grad_abs_max = jax.tree_util.tree_reduce( 121 | lambda x, y: jnp.maximum(x, jnp.max(jnp.abs(y))), grad, initializer=0) 122 | 123 | grad_norm = tree_norm(grad) 124 | if config.grad_max_norm > 0: 125 | mult = jnp.minimum(1, config.grad_max_norm / (1e-7 + grad_norm)) 126 | grad = jax.tree_util.tree_map(lambda z: mult * z, grad) 127 | 
grad_norm_clipped = tree_norm(grad) 128 | 129 | new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr) 130 | new_state = state.replace(optimizer=new_optimizer) 131 | 132 | psnrs = math.mse_to_psnr(stats.losses) 133 | stats = utils.Stats( 134 | loss=stats.loss, 135 | losses=stats.losses, 136 | weight_l2=stats.weight_l2, 137 | psnr=psnrs[-1], 138 | psnrs=psnrs, 139 | grad_norm=grad_norm, 140 | grad_abs_max=grad_abs_max, 141 | grad_norm_clipped=grad_norm_clipped, 142 | ) 143 | 144 | return new_state, stats, rng 145 | 146 | 147 | def main(unused_argv): 148 | rng = random.PRNGKey(20200823) 149 | # Shift the numpy random seed by host_id() to shuffle data loaded by different 150 | # hosts. 151 | np.random.seed(20201473 + jax.host_id()) 152 | 153 | config = utils.load_config() 154 | 155 | if config.batch_size % jax.device_count() != 0: 156 | raise ValueError('Batch size must be divisible by the number of devices.') 157 | 158 | dataset = datasets.get_dataset('train', FLAGS.data_dir, config) 159 | test_dataset = datasets.get_dataset('test', FLAGS.data_dir, config) 160 | 161 | rng, key = random.split(rng) 162 | model, variables = models.construct_mipnerf(key, dataset.peek()) 163 | num_params = jax.tree_util.tree_reduce( 164 | lambda x, y: x + jnp.prod(jnp.array(y.shape)), variables, initializer=0) 165 | print(f'Number of parameters being optimized: {num_params}') 166 | optimizer = flax.optim.Adam(config.lr_init).create(variables) 167 | state = utils.TrainState(optimizer=optimizer) 168 | del optimizer, variables 169 | 170 | learning_rate_fn = functools.partial( 171 | math.learning_rate_decay, 172 | lr_init=config.lr_init, 173 | lr_final=config.lr_final, 174 | max_steps=config.max_steps, 175 | lr_delay_steps=config.lr_delay_steps, 176 | lr_delay_mult=config.lr_delay_mult) 177 | 178 | train_pstep = jax.pmap( 179 | functools.partial(train_step, model, config), 180 | axis_name='batch', 181 | in_axes=(0, 0, 0, None), 182 | donate_argnums=(2,)) 183 | 184 | # Because 
this is only used for test set rendering, we disable randomization. 185 | def render_eval_fn(variables, _, rays): 186 | return jax.lax.all_gather( 187 | model.apply( 188 | variables, 189 | random.PRNGKey(0), # Unused. 190 | rays, 191 | randomized=False, 192 | white_bkgd=config.white_bkgd), 193 | axis_name='batch') 194 | 195 | render_eval_pfn = jax.pmap( 196 | render_eval_fn, 197 | in_axes=(None, None, 0), # Only distribute the data input. 198 | donate_argnums=(2,), 199 | axis_name='batch', 200 | ) 201 | 202 | ssim_fn = jax.jit(functools.partial(math.compute_ssim, max_val=1.)) 203 | 204 | if not utils.isdir(FLAGS.train_dir): 205 | utils.makedirs(FLAGS.train_dir) 206 | state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) 207 | # Resume training at the step of the last checkpoint. 208 | init_step = state.optimizer.state.step + 1 209 | state = flax.jax_utils.replicate(state) 210 | 211 | if jax.host_id() == 0: 212 | summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir) 213 | 214 | # Prefetch_buffer_size = 3 x batch_size 215 | pdataset = flax.jax_utils.prefetch_to_device(dataset, 3) 216 | rng = rng + jax.host_id() # Make random seed separate across hosts. 217 | keys = random.split(rng, jax.local_device_count()) # For pmapping RNG keys. 218 | gc.disable() # Disable automatic garbage collection for efficiency. 219 | stats_trace = [] 220 | reset_timer = True 221 | for step, batch in zip(range(init_step, config.max_steps + 1), pdataset): 222 | if reset_timer: 223 | t_loop_start = time.time() 224 | reset_timer = False 225 | lr = learning_rate_fn(step) 226 | state, stats, keys = train_pstep(keys, state, batch, lr) 227 | if jax.host_id() == 0: 228 | stats_trace.append(stats) 229 | if step % config.gc_every == 0: 230 | gc.collect() 231 | 232 | # Log training summaries. This is put behind a host_id check because in 233 | # multi-host evaluation, all hosts need to run inference even though we 234 | # only use host 0 to record results. 
235 | if jax.host_id() == 0: 236 | if step % config.print_every == 0: 237 | summary_writer.scalar('num_params', num_params, step) 238 | summary_writer.scalar('train_loss', stats.loss[0], step) 239 | summary_writer.scalar('train_psnr', stats.psnr[0], step) 240 | for i, l in enumerate(stats.losses[0]): 241 | summary_writer.scalar(f'train_losses_{i}', l, step) 242 | for i, p in enumerate(stats.psnrs[0]): 243 | summary_writer.scalar(f'train_psnrs_{i}', p, step) 244 | summary_writer.scalar('weight_l2', stats.weight_l2[0], step) 245 | avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace])) 246 | avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace])) 247 | max_grad_norm = np.max( 248 | np.concatenate([s.grad_norm for s in stats_trace])) 249 | avg_grad_norm = np.mean( 250 | np.concatenate([s.grad_norm for s in stats_trace])) 251 | max_clipped_grad_norm = np.max( 252 | np.concatenate([s.grad_norm_clipped for s in stats_trace])) 253 | max_grad_max = np.max( 254 | np.concatenate([s.grad_abs_max for s in stats_trace])) 255 | stats_trace = [] 256 | summary_writer.scalar('train_avg_loss', avg_loss, step) 257 | summary_writer.scalar('train_avg_psnr', avg_psnr, step) 258 | summary_writer.scalar('train_max_grad_norm', max_grad_norm, step) 259 | summary_writer.scalar('train_avg_grad_norm', avg_grad_norm, step) 260 | summary_writer.scalar('train_max_clipped_grad_norm', 261 | max_clipped_grad_norm, step) 262 | summary_writer.scalar('train_max_grad_max', max_grad_max, step) 263 | summary_writer.scalar('learning_rate', lr, step) 264 | steps_per_sec = config.print_every / (time.time() - t_loop_start) 265 | reset_timer = True 266 | rays_per_sec = config.batch_size * steps_per_sec 267 | summary_writer.scalar('train_steps_per_sec', steps_per_sec, step) 268 | summary_writer.scalar('train_rays_per_sec', rays_per_sec, step) 269 | precision = int(np.ceil(np.log10(config.max_steps))) + 1 270 | print(('{:' + '{:d}'.format(precision) + 'd}').format(step) + 271 | 
f'/{config.max_steps:d}: ' + f'i_loss={stats.loss[0]:0.4f}, ' + 272 | f'avg_loss={avg_loss:0.4f}, ' + 273 | f'weight_l2={stats.weight_l2[0]:0.2e}, ' + f'lr={lr:0.2e}, ' + 274 | f'{rays_per_sec:0.0f} rays/sec') 275 | if step % config.save_every == 0: 276 | state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state)) 277 | checkpoints.save_checkpoint( 278 | FLAGS.train_dir, state_to_save, int(step), keep=100) 279 | 280 | # Test-set evaluation. 281 | if FLAGS.render_every > 0 and step % FLAGS.render_every == 0: 282 | # We reuse the same random number generator from the optimization step 283 | # here on purpose so that the visualization matches what happened in 284 | # training. 285 | t_eval_start = time.time() 286 | eval_variables = jax.device_get(jax.tree_map(lambda x: x[0], 287 | state)).optimizer.target 288 | test_case = next(test_dataset) 289 | pred_color, pred_distance, pred_acc = models.render_image( 290 | functools.partial(render_eval_pfn, eval_variables), 291 | test_case['rays'], 292 | keys[0], 293 | chunk=FLAGS.chunk) 294 | 295 | vis_suite = vis.visualize_suite(pred_distance, pred_acc) 296 | 297 | # Log eval summaries on host 0. 
298 | if jax.host_id() == 0: 299 | psnr = math.mse_to_psnr(((pred_color - test_case['pixels'])**2).mean()) 300 | ssim = ssim_fn(pred_color, test_case['pixels']) 301 | eval_time = time.time() - t_eval_start 302 | num_rays = jnp.prod(jnp.array(test_case['rays'].directions.shape[:-1])) 303 | rays_per_sec = num_rays / eval_time 304 | summary_writer.scalar('test_rays_per_sec', rays_per_sec, step) 305 | print(f'Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec') 306 | summary_writer.scalar('test_psnr', psnr, step) 307 | summary_writer.scalar('test_ssim', ssim, step) 308 | summary_writer.image('test_pred_color', pred_color, step) 309 | for k, v in vis_suite.items(): 310 | summary_writer.image('test_pred_' + k, v, step) 311 | summary_writer.image('test_pred_acc', pred_acc, step) 312 | summary_writer.image('test_target', test_case['pixels'], step) 313 | 314 | if config.max_steps % config.save_every != 0: 315 | state = jax.device_get(jax.tree_map(lambda x: x[0], state)) 316 | checkpoints.save_checkpoint( 317 | FLAGS.train_dir, state, int(config.max_steps), keep=100) 318 | 319 | 320 | if __name__ == '__main__': 321 | app.run(main) 322 | --------------------------------------------------------------------------------
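Both train.py (through `math.mse_to_psnr`) and scripts/summarize.ipynb rely on the relationship between MSE and PSNR, and the notebook folds PSNR, SSIM, and LPIPS into a single "average" error via a geometric mean. The sketch below reproduces that arithmetic using only the Python standard library. It assumes pixel values in [0, 1] and uses the standard PSNR definition rather than code copied from `internal/math.py`. The names `psnr_to_mse` and `average_error` are hypothetical helpers introduced for illustration.

```python
import math


def mse_to_psnr(mse):
    """Standard PSNR in dB, assuming pixel values lie in [0, 1]."""
    return -10.0 * math.log10(mse)


def psnr_to_mse(psnr):
    """Inverse of mse_to_psnr; equal to exp(-0.1 * ln(10) * psnr)."""
    return 10.0 ** (-psnr / 10.0)


def average_error(psnr, ssim, lpips):
    """Geometric mean of MSE, sqrt(1 - SSIM), and LPIPS.

    Hypothetical helper mirroring the `avg_avg` computation in
    scripts/summarize.ipynb.
    """
    mse = psnr_to_mse(psnr)
    dssim = math.sqrt(1.0 - ssim)
    log_mean = (math.log(mse) + math.log(dssim) + math.log(lpips)) / 3.0
    return math.exp(log_mean)


if __name__ == "__main__":
    # Rounded single-scale Blender averages printed by the notebook.
    print(f"{average_error(33.085, 0.9605, 0.0425):.4f}")
```

Feeding in the rounded single-scale Blender averages from the notebook output (33.085, 0.9605, 0.0425) should give a value close to the 0.0161 printed there. Small differences are expected because the inputs are already rounded.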