├── CHANGELOG.md ├── CODE_OF_CONDUCT ├── CONTRIBUTING ├── LICENSE ├── README.md ├── batch_builder.py ├── batch_composer.py ├── binary_dataset.py ├── canonical_model.py ├── canonical_model_ngp.py ├── configs ├── .DS_Store └── default.txt ├── data_handler.py ├── data_loader.py ├── data_loader_blender.py ├── deformation_model.py ├── deformation_model_ngp.py ├── environment.yml ├── evaluation.py ├── losses.py ├── misc └── teaser.png ├── multi_gpu.py ├── optimizer.py ├── path_renderer.py ├── post_correction.py ├── pre_correction.py ├── preprocess.py ├── pruning.py ├── ray_builder.py ├── renderer.py ├── rendering.py ├── requirements.txt ├── scene.py ├── scheduler.py ├── settings.py ├── smart_adam.py ├── state_loader_saver.py ├── train.py ├── trainer.py ├── utils.py └── visualizer.py /CHANGELOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SceNeRFlow/4c9b6b3e1b83935e43e33768b0b2760405e68641/CHANGELOG.md -------------------------------------------------------------------------------- /CODE_OF_CONDUCT: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | # Contributing to SceNeRFlow 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## License 26 | By contributing to SceNeRFlow, you agree that your contributions will be licensed 27 | under the LICENSE file in the root directory of this source tree. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SceNeRFlow 2 | 3 | This is the official code release for *SceNeRFlow: Time-Consistent Reconstruction of General Dynamic Scenes* (3DV 2024), a NeRF-based method to reconstruct a general, non-rigid scene in a time-consistent manner, including large motion. This work was done by Edith Tretschk, Vladislav Golyanik, Michael Zollhöfer, Aljaž Božič, Christoph Lassner, and Christian Theobalt. 
4 | [[Project Page]](https://vcai.mpi-inf.mpg.de/projects/scenerflow/) [[ArXiv]](https://arxiv.org/abs/2308.08258) 5 | 6 | ![Teaser figure](misc/teaser.png) 7 | 8 | * [Installation](https://github.com/facebookresearch/SceNeRFlow#installation) 9 | * [Run on a Scene](https://github.com/facebookresearch/SceNeRFlow#run-on-a-scene) 10 | - [Input Format](https://github.com/facebookresearch/SceNeRFlow#input-format) 11 | - [Input Conversion](https://github.com/facebookresearch/SceNeRFlow#input-conversion) 12 | - [Reconstruction](https://github.com/facebookresearch/SceNeRFlow#reconstruction) 13 | - [Rendering](https://github.com/facebookresearch/SceNeRFlow#rendering) 14 | - [Evaluation](https://github.com/facebookresearch/SceNeRFlow#evaluation) 15 | * [Special Features](https://github.com/facebookresearch/SceNeRFlow#special-features) 16 | * [Other Methods](https://github.com/facebookresearch/SceNeRFlow#other-methods) 17 | - [Instructions for SNF-A, SNF-AG](https://github.com/facebookresearch/SceNeRFlow#instructions-for-snf-a-snf-ag) 18 | - [Instructions for Ablations](https://github.com/facebookresearch/SceNeRFlow#instructions-for-ablations) 19 | - [Instructions for Re-Implementations of PREF, NR-NeRF, and D-NeRF](https://github.com/facebookresearch/SceNeRFlow#instructions-for-re-implementations-of-pref-nr-nerf-and-d-nerf) 20 | * [Citation](https://github.com/facebookresearch/SceNeRFlow#citation) 21 | * [License](https://github.com/facebookresearch/SceNeRFlow#license) 22 | 23 | ## Installation 24 | 25 | 1) Install [Miniconda](https://docs.conda.io/projects/miniconda/en/latest/) (this repository is tested with version 23.9.0) 26 | 2) Clone this repository: `git clone https://github.com/facebookresearch/SceNeRFlow` 27 | 3) Navigate into the root directory: `cd SceNeRFlow` 28 | 4) Create the environment: `conda env create -f environment.yml` 29 | 5) Activate the environment: `conda activate snf` 30 | 6) Install further dependencies: `pip install -r requirements.txt` 31 | 7) Install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn): 32 | * `git clone --recursive https://github.com/nvlabs/tiny-cuda-nn` 33 | * `cd tiny-cuda-nn` 34 | * `git checkout 9a17f05` 35 | * in `include/tiny-cuda-nn/encodings/grid.h`: uncomment "case 1" in `create_grid_encoding_templated` 36 | * in `include/tiny-cuda-nn/common.h`: replace `#define TCNN_HALF_PRECISION (!(TCNN_MIN_GPU_ARCH == 61 || TCNN_MIN_GPU_ARCH <= 52))` with `#define TCNN_HALF_PRECISION 0` 37 | * in tiny-cuda-nn's root directory, run `cmake . -B build -DCMAKE_CUDA_COMPILER=/usr/lib/cuda-11.3/bin/nvcc && cd build && cmake --build . --config RelWithDebInfo -j && cd ../bindings/torch/ && python setup.py install` 38 | 39 | ## Run on a Scene 40 | 41 | If you want to try out the code with an existing scene, a short sequence compatible with this codebase is available [here](https://4dqv.mpi-inf.mpg.de/data/synthetic_scene.zip.zip). 42 | 43 | If you use this scene, please set `weight_hard_surface_loss = 100` in `configs/default.txt` instead of the default `weight_hard_surface_loss = 1` that the scenes in the paper use. (See below for how to set this weight in general.) 44 | 45 | ### Config File 46 | 47 | Adapt `datadir` and `expname` in `configs/default.txt`. 48 | 49 | The test cameras (used only for test-time rendering) are specified via their extrinsic names as `test_cameras = ['some_cam', 'another_cam']` (to specify test images freely, change `self.test_imageids = ...` in `data_loader_blender.py`).
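As a concrete example, the adapted entries in `configs/default.txt` could look roughly like this (the experiment name, data path, and test camera name below are placeholders and depend on your setup, not values shipped with the scene):
```
expname = synthetic_scene_test
datadir = ./synthetic_scene/
test_cameras = ['StaticCamera.008']
weight_hard_surface_loss = 100
```
The remaining settings in `configs/default.txt` can typically be left at their defaults.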
50 | 51 | In general, `weight_hard_surface_loss` can be determined as follows: Look at the renderings of the canonical model (i.e. of the first timestep, after 20k training iterations) in the `1_renderings` subfolder. If the dynamic foreground disappears, keep decreasing this weight by a factor of 3 until this first rendering looks reasonable. If the dynamic foreground turns into a large cloud without a recognizable surface, keep increasing it by a factor of 3. 52 | 53 | ### Input Format 54 | 55 | The method takes as input multi-view RGB images with extrinsics, intrinsics, near and far planes, synchronized timestamps, and background images. 56 | 57 | _Images_: The `images` folder contains the multi-view RGB images, while `background` contains the background images. The images in `images` and in `background` can be named arbitrarily. 58 | 59 | _Associations_: The images in `images` are associated with their extrinsics, intrinsics, and timestamps in `frame_to_extrin_intrin_time.json`. The names of the extrinsic and intrinsic cameras (e.g. `StaticCamera.001`) can be arbitrary. Furthermore, the background images are associated with extrinsics and intrinsics in `background_to_extrin_intrin.json`. 60 | 61 | _Extrinsics_: The camera extrinsic translation is the position of the camera in world space. The same unit of length is also used for the near and far plane values. The camera extrinsic rotation is camera-to-world, `R * c = w`. The camera coordinate system has the x-axis pointing to the right, y up, and z back. 62 | 63 | _Intrinsics_: The intrinsics use pixel units. The distortion parameters (k1, k2, p1, p2, k3, s1, s2, s3, s4) follow the OpenCV definitions. The codebase handles distortions by iteratively optimizing for undistorted ray directions for each pixel in the distorted image (thanks [Nerfies](https://github.com/google/nerfies)). This is experimental and prone to divergence. If this happens, consider instead providing undistorted images and setting all distortion parameters in `intrinsics.json` to zeros. 64 | 65 | ### Input Conversion 66 | 67 | To convert the image files into the custom dataset format used by this codebase, run the following command: 68 | ``` 69 | python preprocess.py ./SOME_INPUT_SCENE_FOLDER 70 | ``` 71 | A different output folder can be used as well: `python preprocess.py ./SOME_INPUT_SCENE_FOLDER ./SOME_OTHER_FOLDER`. Make sure to enter the new output folder in the config file under `datadir`. 72 | 73 | _Pruning_: Setting `debug = True` in `_space_carving()` will output the voxel grids that are used at each timestep as `.obj` point clouds, which makes it possible to check whether the voxel grids are plausible. If the dataset is an outside-in 360° camera setup like in a studio or light stage where the entire foreground lies inside the bounding box spanned by the cameras, also try to replace both occurrences of `studio_bounding_box = False` with `studio_bounding_box = True` in `_space_carving()` for a nicer voxel grid for pruning. 74 | 75 | ### Reconstruction 76 | 77 | To train the reconstruction, use: 78 | ``` 79 | python train.py --config ./configs/default.txt --no_reload 80 | ``` 81 | The results will be in `./results/SOME_EXPERIMENT_NAME`, where `SOME_EXPERIMENT_NAME` is `expname` in the config file. 82 | 83 | _Note_: (1) To continue an interrupted training, remove the `--no_reload` flag. (2) Training creates a lot of large files, namely a few hundred MBs per timestep in the scene.
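As a minimal sketch of this start-and-resume workflow (assuming the default config file; see also _Robustness to interruption_ below), chained jobs could be launched as follows:
```
# first job: starts training from scratch
python train.py --config ./configs/default.txt --no_reload
# any follow-up job: reloads the latest saved state and continues training
python train.py --config ./configs/default.txt
```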
84 | 85 | ### Rendering 86 | 87 | To render into the test cameras: 88 | ``` 89 | python rendering.py ./results/SOME_EXPERIMENT_NAME 90 | ``` 91 | The results will be in `./results/SOME_EXPERIMENT_NAME/4_outputs`. 92 | 93 | _Circular_: It is also possible to render an inward-facing circular trajectory: 94 | ``` 95 | python rendering.py ./results/SOME_EXPERIMENT_NAME --circular 96 | ``` 97 | The parameters of the circular rendering are at the top of `test_time_rendering_circular()`, before `output_folder = ...`. The path of the circular trajectory is defined in `get_circular_trajectory()`. This function uses `mode = ...` to specify whether the y-axis or z-axis is the vertical axis. 98 | 99 | _Editing_: Several basic geometry and appearance editing tasks can be performed in canonical space. To this end, please refer to `_editing()` in `rendering.py`. In the code, joints and their surroundings are used as regions to modify. Specifically, the lines below `MODIFY HERE` in `_editing()` offer multiple different editing possibilities in a `radius` around any joint. To provide the joints as input, create a file `joints.json` in the root directory of the dataset, e.g. `./INPUT_SCENE_FOLDER/joints.json`. The content of `joints.json` should be structured as follows: 100 | ``` 101 | { 102 | "0": { # person id, starting from 0 103 | "0": { # joint id, starting from 0 104 | "joint_name": "top_of_head", # joint name 105 | "joint_position": [ # joint position in world space at t=0 106 | [ 107 | 2392.63, 108 | 1658.34, 109 | -231.905 110 | ] 111 | }, 112 | ... # more joints 113 | } 114 | ... # more people 115 | } 116 | ``` 117 | Editing is turned off by default. To turn it on, set `do_edit = True` at the beginning of either `test_time_rendering_circular()` or `test_time_rendering_test_cameras()`. Note that instead of joint positions, any arbitrary location in canonical space can also easily be used inside `_editing()`; joints are just easier to determine as an input to that function. 118 | 119 | ### Evaluation 120 | 121 | To evaluate the (unmasked) test images, use: 122 | ``` 123 | python evaluation.py ./results/SOME_EXPERIMENT_NAME 124 | ``` 125 | The results will be in `./results/SOME_EXPERIMENT_NAME/4_outputs/novel_view_eval`. All test images will be rendered from scratch. The images from `rendering.py` or from training are not re-used. See `quantitative_evaluation_novel_views()` before `output_folder = ...` for further settings. 126 | 127 | _Masked_: The paper also uses masked evaluation. The code allows for this with `python evaluation.py ./results/SOME_EXPERIMENT_NAME --masked`. The masks are generated with a simple thresholding technique; see the supplementary material for details. The parameters of this thresholding need to be adjusted per scene. This can be done at the beginning of `quantitative_evaluation_novel_views()`. 128 | 129 | _LPIPS_: To enable LPIPS evaluation, install [LPIPS](https://github.com/richzhang/PerceptualSimilarity): 130 | ``` 131 | pip install lpips 132 | ``` 133 | 134 | ## Special Features 135 | 136 | * _Automatic batch splitting_: The code adaptively determines a virtual batch size, such that any batch size should fit into limited GPU memory. See `trainer.py` for details. 137 | * _Robustness to interruption_: The code is robust to being split into short-term server jobs. The results in the paper were obtained by chaining one-hour jobs. The first job needs to have `no_reload = True`, while all follow-up jobs need to have `no_reload = False`.
(Note that command line flags take precedence over flags in the config file.) See `state_loader_saver.py` for details. 138 | * _Fast dataset format_: The code uses a basic custom dataset format to allow for fast random access without querying the filesystem. This is ideal for clusters with a shared filesystem. See `binary_dataset.py` for details. 139 | * _Experimental multi-GPU support_: The code should be able to use multiple GPUs. However, this feature was never tested. See `multi_gpu.py` for details and all references to it. 140 | 141 | ## Other Methods 142 | 143 | The following other methods have not been tested with the cleaned code in this repository. However, they should still work since they did work with the original code used in the paper. 144 | 145 | ### Instructions for SNF-A, SNF-AG 146 | 147 | In addition to a time-invariant canonical model, the paper also presents variants that allow the appearance to change (SNF-A) or both appearance and geometry to change (SNF-AG). Please refer to the paper and the supplementary material for details. 148 | 149 | These variants can be trained by setting `variant = snfa` or `variant = snfag`. 150 | 151 | ### Instructions for Ablations 152 | 153 | The ablations in the paper can be obtained as follows: 154 | 155 | * _Online optimization_: `tracking_mode = plain` and `optimization_mode = all` 156 | * _Extending the deformation field_: in `losses.py`, set `max_windowed=False` and `smart_maxpool=False` everywhere in `compute()`; in `deformation_model_ngp.py`, set `self.hash_encoding_config["base_resolution"] = 512` 157 | * _Coarse deformations_: `weight_coarse_smooth_deformations = 30` 158 | * _Fine deformations_: at the beginning of `test_time_rendering_test_cameras()` in `rendering.py`, set `only_coarse = True` 159 | 160 | ### Instructions for Re-Implementations of PREF, NR-NeRF, and D-NeRF 161 | 162 | Before using these re-implementations, please see the supplementary material of SceNeRFlow to learn which modifications were made to the original methods. 163 | 164 | If these modifications are acceptable for your use case, you can set `do_dnerf = True` in the config file for [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html) or `do_nrnerf = True` for [NR-NeRF](https://vcai.mpi-inf.mpg.de/projects/nonrigid_nerf/). See the bottom of `train.py` for what happens internally when setting these flags. 165 | 166 | For [PREF](https://lsongx.github.io/projects/pref.html), in addition to setting `do_pref = True`, multiple independent trainings need to be run because PREF splits the scene into temporal windows. First, start a training with `pref_dataset_index = 0`. Then, start another training with `pref_dataset_index = 1`, another one with `2`, etc., until the latest one throws an error that the dataset is not that large. 167 | 168 | ## Citation 169 | 170 | ``` 171 | @inproceedings{tretschk2024scenerflow, 172 | title = {SceNeRFlow: Time-Consistent Reconstruction of General Dynamic Scenes}, 173 | author = {Tretschk, Edith and Golyanik, Vladislav and Zollh\"{o}fer, Michael and Bozic, Aljaz and Lassner, Christoph and Theobalt, Christian}, 174 | year = {2024}, 175 | booktitle={International Conference on 3D Vision (3DV)}, 176 | } 177 | ``` 178 | 179 | ## License 180 | 181 | This repository is released under a CC-BY-NC 4.0 license; please refer to the `LICENSE` file for details. 182 | 183 | Several functions in `utils.py` are modified versions of functions from the [FFJORD codebase](https://github.com/rtqichen/ffjord).
The modified AdamW implementation in `smart_adam.py` is based on PyTorch's Adam implementation. The volumetric rendering in `renderer.py` is based on Yen-Chen Lin's [NeRF code](https://github.com/yenchenlin/nerf-pytorch), which in turn is based on the [original NeRF code](https://github.com/bmild/nerf). The iterative optimization in `ray_builder.py` is inspired by code from [Nerfies](https://github.com/google/nerfies). We thank all of them for releasing their code. 184 | -------------------------------------------------------------------------------- /batch_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import numpy as np 9 | import torch 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | class BatchBuilder: 15 | def __init__(self, settings, ray_builder): 16 | 17 | self.ray_builder = ray_builder 18 | 19 | self.do_vignetting_correction = settings.do_vignetting_correction 20 | self.do_ngp_mip_nerf = settings.do_ngp_mip_nerf 21 | self.debug = settings.debug 22 | 23 | def build(self, batch_size=None, active_imageids=None, precomputed=None, single_image=None): 24 | 25 | batch = {} 26 | 27 | if single_image is not None: 28 | # this assumes that a single full image is requested 29 | # single_image: extrin, intrin, timestep, intrinid, (background) 30 | 31 | rays_dict = self.ray_builder.build( 32 | single_image["extrin"], single_image["intrin"] 33 | ) # cuda 34 | rays_origin = rays_dict["rays_origin"].view(-1, 3) # H * W x 3 35 | rays_dir = rays_dict["rays_dir"].view(-1, 3) 36 | 37 | num_rays = rays_origin.shape[0] 38 | timesteps = single_image["timestep"].repeat(num_rays) 39 | near = torch.tensor(single_image["extrin"]["near"]).repeat(num_rays) 40 | far = torch.tensor(single_image["extrin"]["far"]).repeat(num_rays) 41 | intrinids = torch.tensor(single_image["intrin"]["intrinid"]).repeat(num_rays) 42 | if "background" in single_image: 43 | batch["background"] = single_image["background"].view(-1, 3) 44 | 45 | if self.do_vignetting_correction or self.do_ngp_mip_nerf: 46 | width = single_image["intrin"]["width"] 47 | height = single_image["intrin"]["height"] 48 | y_coordinates = torch.arange(height).repeat_interleave( 49 | width 50 | ) # [0,0,0,1,1,1,2,2,2,3,3,3] 51 | x_coordinates = torch.arange(width).repeat(height) # [0,1,2,0,1,2,0,1,2,0,1,2] 52 | 53 | x_center = torch.tensor(single_image["intrin"]["center_x"]) 54 | y_center = torch.tensor(single_image["intrin"]["center_y"]) 55 | 56 | if self.do_ngp_mip_nerf: 57 | x_center = x_center.repeat(num_rays) 58 | y_center = y_center.repeat(num_rays) 59 | x_focal = torch.tensor(single_image["intrin"]["focal_x"]).repeat(num_rays) 60 | y_focal = torch.tensor(single_image["intrin"]["focal_y"]).repeat(num_rays) 61 | 62 | batch["rotation"] = single_image["extrin"]["rotation"].repeat( 63 | num_rays, 1, 1 64 | ) # num_rays x 3 x 3 65 | 66 | elif precomputed is not None: 67 | 68 | num_images = precomputed["rays_origin"].shape[0] 69 | if active_imageids is None: 70 | active_imageids = torch.arange(num_images) 71 | 72 | def flatten_pixel_dimensions(tensor): 73 | # turns N x H x W x F into N x H * W x F 74 | return tensor.view(tensor.shape[0], -1, tensor.shape[-1]) 75 | 76 | rgb = flatten_pixel_dimensions(precomputed["rgb"]) 77 | rays_origin = 
flatten_pixel_dimensions(precomputed["rays_origin"]) 78 | rays_dir = flatten_pixel_dimensions(precomputed["rays_dir"]) 79 | 80 | num_rays_per_image = rgb.shape[1] 81 | if batch_size is None: 82 | # used by scheduler and state_loader_saver, which use a named subset as batch. see data_handler. 83 | assert num_images == len(active_imageids) 84 | image_indices = torch.arange(num_images).repeat_interleave( 85 | num_rays_per_image 86 | ) # [0,0,0,1,1,1,2,2,2,3,3,3] 87 | flattened_indices = torch.arange(num_rays_per_image).repeat( 88 | num_images 89 | ) # [0,1,2,0,1,2,0,1,2,0,1,2] 90 | else: 91 | # standard training batch 92 | image_indices = torch.randint( 93 | len(active_imageids), size=(batch_size,) 94 | ) # among active training images 95 | flattened_indices = torch.randint(num_rays_per_image, size=(batch_size,)) 96 | 97 | all_train_image_indices = active_imageids[image_indices] # among all training images 98 | image_indices = precomputed["train_to_loaded_train_ids"][ 99 | all_train_image_indices 100 | ] # among loaded training images 101 | if self.debug and torch.any(image_indices == -1): 102 | raise AssertionError("mapping is broken") 103 | 104 | rgb = rgb[image_indices, flattened_indices] 105 | batch["rgb"] = rgb 106 | 107 | rays_origin = rays_origin[image_indices, flattened_indices] 108 | rays_dir = rays_dir[image_indices, flattened_indices] 109 | timesteps = precomputed["timesteps"][image_indices] 110 | near = precomputed["near"][image_indices] 111 | far = precomputed["far"][image_indices] 112 | intrinids = precomputed["intrinids"][image_indices] 113 | 114 | if "background" in precomputed or self.do_vignetting_correction or self.do_ngp_mip_nerf: 115 | if "coordinate_subsets" in precomputed: 116 | # not using full images 117 | yx_coordinates = precomputed["coordinate_subsets"][ 118 | image_indices, flattened_indices 119 | ] 120 | y_coordinates = yx_coordinates[:, 0] 121 | x_coordinates = yx_coordinates[:, 1] 122 | if self.do_vignetting_correction: 123 | width = precomputed["intrins"]["width"][0] 124 | else: 125 | # when using a standard training batch and data_handler loaded full images. 
126 | width = precomputed["background"].shape[2] 127 | y_coordinates = torch.div(flattened_indices, width, rounding_mode="floor") 128 | x_coordinates = flattened_indices % width 129 | 130 | if self.do_vignetting_correction or self.do_ngp_mip_nerf: 131 | x_center = precomputed["intrins"]["center_x"][intrinids] 132 | y_center = precomputed["intrins"]["center_y"][intrinids] 133 | 134 | if "background" in precomputed: 135 | exintrinids = precomputed["exintrinids"][image_indices] 136 | batch["background"] = precomputed["background"][ 137 | exintrinids, y_coordinates, x_coordinates 138 | ] 139 | 140 | if self.do_ngp_mip_nerf: 141 | x_focal = precomputed["intrins"]["focal_x"][intrinids] 142 | y_focal = precomputed["intrins"]["focal_y"][intrinids] 143 | batch["rotation"] = precomputed["extrins"]["rotation"][ 144 | image_indices, :, : 145 | ] # num_rays x 3 x 3 146 | 147 | batch["image_indices"] = all_train_image_indices 148 | 149 | if self.do_vignetting_correction: 150 | batch["normalized_x_coordinate"] = (x_coordinates - x_center) / width 151 | batch["normalized_y_coordinate"] = ( 152 | y_coordinates - y_center 153 | ) / width # same normalization as x 154 | 155 | if self.do_ngp_mip_nerf: 156 | batch["x_coordinate"] = x_coordinates 157 | batch["y_coordinate"] = y_coordinates 158 | batch["x_center"] = x_center 159 | batch["y_center"] = y_center 160 | batch["x_focal"] = x_focal 161 | batch["y_focal"] = y_focal 162 | 163 | batch.update( 164 | { 165 | "rays_origin": rays_origin, 166 | "rays_dir": rays_dir, 167 | "timesteps": timesteps, 168 | "near": near, 169 | "far": far, 170 | "intrinids": intrinids, 171 | } 172 | ) 173 | 174 | batch = {key: tensor.cuda() for key, tensor in batch.items()} 175 | 176 | return batch 177 | -------------------------------------------------------------------------------- /batch_composer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | from random import shuffle 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | import torch 9 | 10 | 11 | class BatchComposer: 12 | def __init__(self, batch_builder): 13 | self.batch_builder = batch_builder 14 | 15 | def set_batch_composition(self, batch_composition): 16 | self.batch_composition = batch_composition 17 | 18 | def _determine_subbatch_sizes(self, batch_size): 19 | 20 | shuffle(self.batch_composition) # in-place shuffling 21 | 22 | subbatch_sizes = [] 23 | for subbatch_info in self.batch_composition: 24 | if len(subbatch_info["imageids"]) == 0: 25 | subbatch_size = 0 26 | else: 27 | subbatch_size = int(subbatch_info["fraction"] * batch_size) 28 | subbatch_sizes.append(subbatch_size) 29 | 30 | final_subbatch_sizes = [] 31 | remaining = batch_size - sum(subbatch_sizes) 32 | for subbatch_size in subbatch_sizes: 33 | if subbatch_size > 0 and remaining > 0: 34 | final_subbatch_sizes.append(subbatch_size + remaining) 35 | else: 36 | final_subbatch_sizes.append(subbatch_size) 37 | 38 | return final_subbatch_sizes 39 | 40 | def compose(self, batch_size, precomputed): 41 | 42 | subbatch_sizes = self._determine_subbatch_sizes(batch_size) 43 | 44 | subbatches = [] 45 | for subbatch_size, subbatch_info in zip(subbatch_sizes, self.batch_composition): 46 | if subbatch_size > 0: 47 | imageids = subbatch_info["imageids"] # among all training images 48 | 49 | subbatch = self.batch_builder.build( 50 | active_imageids=imageids, batch_size=subbatch_size, precomputed=precomputed 51 | ) 52 | 53 | subbatches.append(subbatch) 54 | 55 | batch = { 56 | key: torch.cat([subbatch[key] for subbatch in subbatches], axis=0) 57 | for key in subbatches[0].keys() 58 | } 59 | 60 | return batch 61 | -------------------------------------------------------------------------------- /binary_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import json 7 | import os 8 | 9 | 10 | class BinaryDataset: 11 | def __init__(self, folder, name=None, delete_existing=False, read_only=False): 12 | 13 | if name is None: 14 | name = "dataset" 15 | self.name = name 16 | self.folder = folder 17 | 18 | self.read_only = read_only 19 | 20 | self.open(delete_existing=delete_existing) 21 | 22 | def _get_dataset_index_file_name(self): 23 | return os.path.join(self.folder, self.name + "_index.json") 24 | 25 | def _get_dataset_file_name(self): 26 | return os.path.join(self.folder, self.name + ".bin") 27 | 28 | def open(self, delete_existing=False): 29 | 30 | dataset_file = self._get_dataset_file_name() 31 | if self.read_only: 32 | mode = "br" 33 | else: 34 | if delete_existing: 35 | mode = "bw+" 36 | else: 37 | mode = "ba+" 38 | self._dataset_bin = open(dataset_file, mode) 39 | 40 | dataset_index_file = self._get_dataset_index_file_name() 41 | if os.path.exists(dataset_index_file) and not delete_existing: 42 | with open(dataset_index_file, "r") as json_file: 43 | self._dataset_index = json.load(json_file) 44 | if len(self._dataset_index) > 0: 45 | self._start = max([entry["end"] for entry in self._dataset_index.values()]) 46 | else: 47 | self._start = 0 48 | self._modified = False 49 | else: 50 | self._dataset_index = {} 51 | self._start = 0 52 | self._modified = True 53 | 54 | def maybe_add_entry(self, entry_bytes, key): 55 | 56 | if self.read_only: 57 | raise RuntimeError("trying to add to BinaryDataset in read_only mode") 58 | 59 | key = str(key) 60 | 61 | if key in self: 62 | return 63 | 64 | self._modified = True 65 | 66 | self._dataset_bin.seek(self._start) 67 | 68 | self._end = self._start + entry_bytes.getbuffer().nbytes 69 | self._dataset_index[key] = { 70 | "start": self._start, # inclusive 71 | "end": self._end, # exclusive 72 | } 73 | self._start = self._end 74 | 75 | self._dataset_bin.write(entry_bytes.getbuffer()) 76 | 77 | def __contains__(self, key): 78 | return str(key) in self._dataset_index 79 | 80 | def keys(self): 81 | return self._dataset_index.keys() 82 | 83 | def get_entry(self, key): 84 | 85 | key = str(key) 86 | 87 | start = self._dataset_index[key]["start"] 88 | end = self._dataset_index[key]["end"] 89 | 90 | self._dataset_bin.seek(start) 91 | entry_bytes = self._dataset_bin.read(end - start) 92 | 93 | return entry_bytes 94 | 95 | def close(self): 96 | 97 | self._dataset_bin.close() 98 | 99 | if self._modified: 100 | dataset_index_file = self._get_dataset_index_file_name() 101 | with open(dataset_index_file, "w", encoding="utf-8") as json_file: 102 | json.dump(self._dataset_index, json_file, ensure_ascii=False, indent=4) 103 | 104 | def flush(self): 105 | self.close() 106 | self.open() 107 | -------------------------------------------------------------------------------- /canonical_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import torch 7 | 8 | 9 | def get_canonical_model(settings): 10 | from canonical_model_ngp import CanonicalModelNGP 11 | 12 | if settings.backbone == "ngp": 13 | return CanonicalModelNGP(settings) 14 | 15 | 16 | class CanonicalModel(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SceNeRFlow/4c9b6b3e1b83935e43e33768b0b2760405e68641/configs/.DS_Store -------------------------------------------------------------------------------- /configs/default.txt: -------------------------------------------------------------------------------- 1 | expname = SOME_EXPERIMENT_NAME 2 | datadir = ./SOME_INPUT_SCENE_FOLDER/ 3 | test_cameras = ['StaticCamera.008'] 4 | 5 | basedir = ./results/ 6 | 7 | debug = False 8 | backbone = ngp 9 | prefer_cutlass_over_fullyfused_mlp = False 10 | slurm = False 11 | multi_gpu = False 12 | 13 | factor = 1 14 | render_factor = 2 15 | 16 | tracking_mode = temporal 17 | optimization_mode = per_timestep 18 | variant = snf 19 | 20 | netdepth = 8 21 | netwidth = 256 22 | netdepth_fine = 8 23 | netwidth_fine = 256 24 | 25 | batch_size = 1024 26 | num_points_per_ray = 3072 27 | 28 | points_per_chunk = 4194304 29 | no_batching = False 30 | ft_path = None 31 | reconstruction_loss_type = L1 32 | smooth_deformations_type = norm_preserving 33 | num_iterations = -1 34 | learning_rate_decay_autodecoding_fraction = 0.01 35 | learning_rate_decay_autodecoding_iterations = 1000000 36 | learning_rate_decay_mlp_fraction = 1e-2 37 | learning_rate_decay_mlp_iterations = -1 38 | activation_function = LeakyReLU 39 | use_visualizer = False 40 | N_importance = 0 41 | perturb = 1.0 42 | 43 | use_viewdirs = False 44 | do_ngp_mip_nerf = False 45 | do_pref = False 46 | pref_tau_window = 3 47 | pref_dataset_index = -1 48 | do_nrnerf = False 49 | do_dnerf = False 50 | 51 | i_embed = 0 52 | multires = 10 53 | multires_views = 4 54 | raw_noise_std = 1.0 55 | use_background = True 56 | brightness_variability = 0.0 57 | render_only = False 58 | render_test = False 59 | color_calibration_mode = none 60 | 61 | weight_smooth_deformations = 0.0 62 | weight_coarse_smooth_deformations = 1000.0 63 | weight_fine_smooth_deformations = 30.0 64 | weight_parameter_regularization = 0.01 65 | weight_background_loss = 0.001 66 | weight_brightness_change_regularization = 0.0 67 | weight_hard_surface_loss = 1.0 68 | 69 | pure_mlp_bending = False 70 | coarse_parametrization = hashgrid # MLP, hashgrid 71 | use_global_transform = False 72 | coarse_and_fine = True 73 | fine_range = 0.1 74 | 75 | deformation_per_timestep_decay_rate = 0.1 76 | slow_canonical_per_timestep_learning_rate = 1e-4 77 | fix_coarse_after_a_while = True 78 | let_canonical_vary_at_last = False 79 | let_only_brightness_vary = False 80 | keep_coarse_mlp_constant = False 81 | 82 | use_pruning = True 83 | voxel_grid_size = 128 84 | no_pruning_probability = 0.0 85 | 86 | do_vignetting_correction = True 87 | coarse_mlp_weight_decay = 0 88 | coarse_mlp_skip_connections = 2 89 | smoothness_robustness_threshold = 0.0 90 | use_half_precision = False 91 | do_zero_out = True 92 | dataset_type = blender 93 | testskip = 8 94 | use_temporal_latent_codes = False 95 | use_time_conditioning = False 96 | use_nerfies_se3 = False 97 | no_ndc = False 98 | disparity_sampling = False 99 | spherify = False 100 | llffhold = 8 101 | 
i_print = 100 102 | i_img = 2500 103 | i_testset = 200000000 104 | 105 | always_load_full_dataset = False 106 | no_reload = False 107 | allow_scratch_datadir_copy = False 108 | 109 | save_temporary_checkpoint_every = 2500 110 | save_checkpoint_every = 10000 111 | save_intermediate_checkpoint_every = 1000000 112 | save_per_timestep = True 113 | save_per_timestep_in_scratch = False 114 | 115 | i_video = 100000000 116 | -------------------------------------------------------------------------------- /data_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import numpy as np 9 | import torch 10 | from batch_builder import BatchBuilder 11 | from batch_composer import BatchComposer 12 | from data_loader import get_data_loader 13 | from ray_builder import RayBuilder 14 | 15 | LOGGER = logging.getLogger(__name__) 16 | 17 | 18 | class DataHandler: 19 | def __init__(self, settings, rank=0, world_size=1, precomputation_mode=False): 20 | 21 | self.rank = rank 22 | self.world_size = world_size 23 | self.optimization_mode = settings.optimization_mode 24 | self.do_ngp_mip_nerf = settings.do_ngp_mip_nerf 25 | 26 | self.data_loader = get_data_loader(settings, self.rank, self.world_size) 27 | 28 | self.ray_builder = RayBuilder(rank=self.rank, multi_gpu=settings.multi_gpu) 29 | self.ray_builder.use_precomputed_dataset( 30 | self.data_loader.get_dataset_folder(), create=precomputation_mode 31 | ) 32 | 33 | self.batch_builder = BatchBuilder(settings, self.ray_builder) 34 | self.batch_composer = BatchComposer(self.batch_builder) 35 | 36 | def get_batch(self, batch_size=None, subset_name=None): 37 | if subset_name is None: 38 | return self.batch_composer.compose( 39 | batch_size, precomputed=self.get_precomputed(subset_name="main") 40 | ) 41 | else: 42 | # used by scheduler and state_loader_saver 43 | return self.batch_builder.build(precomputed=self.get_precomputed(subset_name)) 44 | 45 | def get_timeline_range(self): 46 | return self.data_loader.get_timeline_range() 47 | 48 | def get_training_set_size(self): 49 | return len(self.data_loader.get_train_imageids()) 50 | 51 | def get_train_imageids(self): 52 | return self.data_loader.get_train_imageids() 53 | 54 | def get_precomputed(self, subset_name=None): 55 | 56 | if subset_name is None: 57 | subset_name = "main" 58 | 59 | needs_special_handling = ["rgb", "rays_origin", "rays_dir", "coordinate_subsets"] 60 | 61 | precomputed = { 62 | key: value 63 | for key, value in self.precomputed.items() 64 | if key not in needs_special_handling 65 | } 66 | 67 | for key in needs_special_handling: 68 | value = self.precomputed[key][subset_name] 69 | # remove the key "coordinate_subsets" if we use all pixels. 70 | # this serves as a signal for functions that they are getting all pixels. 
71 | if value is not None: 72 | precomputed[key] = value 73 | 74 | return precomputed 75 | 76 | def load_training_set( 77 | self, 78 | factor, 79 | num_total_rays_to_precompute=None, 80 | num_pixels_per_image=None, 81 | foreground_focused=False, 82 | imageids=None, 83 | also_load_top_left_corner_and_four_courners=False, 84 | ): 85 | 86 | self.precomputed = {} # pytorch cpu 87 | 88 | training_imageids = torch.from_numpy(self.data_loader.get_train_imageids()) 89 | training_imageids = training_imageids.cuda() 90 | if imageids is None: 91 | imageids = training_imageids.clone() 92 | 93 | # indexing for batch_builder (all training images -> loaded training images) 94 | # imageids: loaded training images -> all images 95 | # training_imageids: all training images -> all images 96 | num_total_images = self.data_loader.num_total_images() 97 | all_to_loaded_train = torch.zeros(num_total_images, dtype=np.long) - 1 98 | all_to_loaded_train[imageids] = torch.arange(len(imageids)) 99 | train_to_loaded_train = all_to_loaded_train[training_imageids.cpu()] 100 | self.precomputed["train_to_loaded_train_ids"] = train_to_loaded_train 101 | 102 | # background 103 | if self.data_loader.has_background_images(): 104 | LOGGER.info("loading background images...") 105 | self.precomputed["background"] = torch.from_numpy( 106 | self.data_loader.load_background_images(factor=factor) 107 | ) 108 | self.precomputed["exintrinids"] = self.data_loader.get_exintrinids(imageids).cpu() 109 | 110 | # decide which pixels to load 111 | if num_total_rays_to_precompute is not None: 112 | if num_pixels_per_image is not None: 113 | raise RuntimeError( 114 | "only provide one of num_total_rays_to_precompute or num_pixels_per_image" 115 | ) 116 | num_pixels_per_image = max(1, num_total_rays_to_precompute // len(imageids)) 117 | 118 | if foreground_focused: 119 | desired_subsets = { 120 | "main": { 121 | "mode": "foreground_focused", 122 | "foreground_fraction": 0.8, 123 | "num_pixels_per_image": num_pixels_per_image, 124 | } 125 | } 126 | else: 127 | if num_pixels_per_image is None: 128 | desired_subsets = {"main": {"mode": "all"}} 129 | else: 130 | desired_subsets = { 131 | "main": {"mode": "random", "num_pixels_per_image": num_pixels_per_image} 132 | } 133 | 134 | if also_load_top_left_corner_and_four_courners: 135 | desired_subsets["top_left_corner"] = { 136 | "mode": "specific", 137 | "y_coordinates": [0], 138 | "x_coordinates": [0], 139 | } 140 | desired_subsets["four_corners"] = { 141 | "mode": "specific", 142 | "y_coordinates": [0, 0, -1, -1], 143 | "x_coordinates": [0, -1, 0, -1], 144 | } 145 | 146 | # rgb 147 | LOGGER.info("loading training rgb...") 148 | rgb_subsets, coordinate_subsets = self.data_loader.load_images( 149 | factor=factor, 150 | imageids=imageids, 151 | desired_subsets=desired_subsets, 152 | background_images_dict=self.precomputed, 153 | ) 154 | self.precomputed["rgb"] = { 155 | subset_name: torch.from_numpy(rgb) for subset_name, rgb in rgb_subsets.items() 156 | } 157 | self.precomputed["coordinate_subsets"] = coordinate_subsets 158 | 159 | # rays_origin, rays_dir 160 | LOGGER.info("loading training rays...") 161 | 162 | extrinids = self.data_loader.get_extrinids(imageids) 163 | extrins = self.data_loader.get_extrinsics(extrinids=extrinids) 164 | if self.do_ngp_mip_nerf: 165 | self.precomputed["extrins"] = {} 166 | self.precomputed["extrins"]["rotation"] = torch.stack( 167 | [extrin["rotation"] for extrin in extrins], dim=0 168 | ) 169 | 170 | intrinids = self.data_loader.get_intrinids(imageids) 171 | 
self.precomputed["intrinids"] = intrinids.cpu() 172 | all_intrins = self.data_loader.get_intrinsics(factor=factor) 173 | self.precomputed["intrins"] = {} 174 | for key in all_intrins[0].keys(): 175 | if key in ["distortion"]: 176 | continue 177 | values = torch.from_numpy(np.array([intrin[key] for intrin in all_intrins])) 178 | if values.dtype == torch.float64: 179 | values = values.float() 180 | self.precomputed["intrins"][key] = values 181 | image_intrins = self.data_loader.get_intrinsics(intrinids=intrinids, factor=factor) 182 | 183 | rays_dict = self.ray_builder.build_multiple( 184 | extrins, image_intrins, coordinate_subsets=coordinate_subsets 185 | ) 186 | self.precomputed["rays_origin"] = rays_dict["rays_origin"] 187 | self.precomputed["rays_dir"] = rays_dict["rays_dir"] 188 | 189 | # timesteps 190 | self.precomputed["timesteps"] = self.data_loader.get_timesteps(imageids).cpu() 191 | 192 | # near, far 193 | self.precomputed["near"] = torch.from_numpy( 194 | np.array([extrin["near"] for extrin in extrins], dtype=np.float32) 195 | ) 196 | self.precomputed["far"] = torch.from_numpy( 197 | np.array([extrin["far"] for extrin in extrins], dtype=np.float32) 198 | ) 199 | 200 | def get_test_cameras_for_rendering(self, factor=None): 201 | 202 | imageids = self.data_loader.get_test_imageids() 203 | 204 | only_render_current_timestep = self.optimization_mode == "per_timestep" 205 | if only_render_current_timestep: 206 | timesteps = self.data_loader.get_timesteps(imageids).cpu() 207 | imageids_with_right_timestep = [] 208 | right_timesteps = list(set(self.precomputed["timesteps"].numpy())) 209 | for imageid, timestep in zip(imageids, timesteps): 210 | if timestep in right_timesteps: 211 | imageids_with_right_timestep.append(imageid) 212 | imageids = np.array(imageids_with_right_timestep) 213 | 214 | # sort by (extrinid, timestep) such that renderings can be concatenated easily 215 | extrinids = self.data_loader.get_extrinids(imageids).cpu() 216 | timesteps = self.data_loader.get_timesteps(imageids).cpu() 217 | image_info = [ 218 | (int(extrinid), float(timestep), int(imageid)) 219 | for extrinid, timestep, imageid in zip(extrinids, timesteps, imageids) 220 | ] 221 | image_info = sorted(image_info, key=lambda x: (x[0], x[1])) 222 | imageids = np.array([image[2] for image in image_info]) 223 | 224 | LOGGER.debug("test set: " + str(imageids) + " " + str(right_timesteps)) 225 | 226 | test_cameras = {} 227 | 228 | test_cameras["timesteps"] = self.data_loader.get_timesteps(imageids).cpu() 229 | 230 | extrinids = self.data_loader.get_extrinids(imageids) 231 | test_cameras["extrins"] = self.data_loader.get_extrinsics(extrinids=extrinids) 232 | 233 | intrinids = self.data_loader.get_intrinids(imageids) 234 | test_cameras["intrins"] = self.data_loader.get_intrinsics( 235 | intrinids=intrinids, factor=factor 236 | ) 237 | 238 | test_cameras["rgb"] = torch.from_numpy( 239 | self.data_loader.load_images(factor=factor, imageids=imageids)[0]["everything"] 240 | ) 241 | 242 | if self.data_loader.has_background_images(): 243 | exintrinids = self.data_loader.get_exintrinids(imageids) 244 | test_cameras["backgrounds"] = torch.from_numpy( 245 | self.data_loader.load_background_images(factor=factor, exintrinids=exintrinids) 246 | ) 247 | 248 | return test_cameras 249 | 250 | def visualize_images_in_3D(self, results_folder): 251 | # results_folder = state_loader_saver.get_results_folder() 252 | origin = self.precomputed["rays_origin"][:, 0, 0] # N x 3 253 | top_left = self.precomputed["rays_dir"][:, 0, 0] # N x 3 
254 | top_right = self.precomputed["rays_dir"][:, 0, -1] # N x 3 255 | bottom_right = self.precomputed["rays_dir"][:, -1, -1] # N x 3 256 | bottom_left = self.precomputed["rays_dir"][:, -1, 0] # N x 3 257 | 258 | rescale_factor = 0.5 259 | 260 | beginning = np.tile(origin, [4, 1]) # 4*N x 3 261 | end = np.concatenate([top_left, top_right, bottom_right, bottom_left], 0) # 4*N x 3 262 | end = beginning + rescale_factor * end 263 | 264 | mesh_string = "" 265 | for x, y, z in beginning: 266 | mesh_string += "v " + str(x) + " " + str(y) + " " + str(z) + " 0.0 1.0 0.0\n" 267 | for x, y, z in end: 268 | mesh_string += "v " + str(x) + " " + str(y) + " " + str(z) + " 1.0 0.0 0.0\n" 269 | for x, y, z in end: 270 | mesh_string += "v " + str(x + 0.00001) + " " + str(y) + " " + str(z) + " 1.0 0.0 0.0\n" 271 | num_vertices = beginning.shape[0] 272 | for i in range(num_vertices): 273 | i += 1 274 | mesh_string += ( 275 | "f " + str(i) + " " + str(i + num_vertices) + " " + str(i + 2 * num_vertices) + "\n" 276 | ) 277 | 278 | import os 279 | 280 | with open(os.path.join(results_folder, "cameras_2.obj"), "w") as mesh_file: 281 | mesh_file.write(mesh_string) 282 | 283 | num_images = 20 284 | total_num_images = self.precomputed["rgb"].shape[0] 285 | step = total_num_images // num_images 286 | from tqdm import trange 287 | 288 | for image in trange(num_images): 289 | image = image * step 290 | rgb = self.precomputed["rgb"][image].numpy() # H x W x 3 291 | origin = self.precomputed["rays_origin"][image].numpy() # H x W x 3 292 | directions = self.precomputed["rays_dir"][image].numpy() # H x W x 3 293 | pos = origin + rescale_factor * directions # H x W x 3 294 | 295 | mesh_string = "" 296 | for (x, y, z), (r, g, b) in zip(pos.reshape(-1, 3), rgb.reshape(-1, 3)): 297 | mesh_string += ( 298 | "v " 299 | + str(x) 300 | + " " 301 | + str(y) 302 | + " " 303 | + str(z) 304 | + " " 305 | + str(r) 306 | + " " 307 | + str(g) 308 | + " " 309 | + str(b) 310 | + "\n" 311 | ) 312 | 313 | with open( 314 | os.path.join(results_folder, "test_" + str(image).zfill(3) + ".obj"), "w" 315 | ) as mesh_file: 316 | mesh_file.write(mesh_string) 317 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | class DataLoader: 7 | def __init__(self): 8 | pass 9 | 10 | 11 | def get_data_loader(settings, rank, world_size): 12 | from data_loader_blender import DataLoaderBlender 13 | 14 | if settings.dataset_type == "blender": 15 | return DataLoaderBlender(settings, rank, world_size) 16 | -------------------------------------------------------------------------------- /deformation_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import torch 7 | 8 | 9 | def get_deformation_model(settings, timeline_range): 10 | from deformation_model_ngp import DeformationModelNGP 11 | 12 | if settings.backbone == "ngp": 13 | return DeformationModelNGP(settings, timeline_range) 14 | 15 | 16 | class DeformationModel(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def viewdirs_via_finite_differences(self, positions, returns): 21 | # positions: num_rays x points_per_ray x 3 22 | num_rays, points_per_ray, _ = positions.shape 23 | 24 | eps = 1e-6 25 | difference_type = "backward" 26 | if difference_type == "central": 27 | # central differences (except for first and last sample since one neighbor is missing for them) 28 | unnormalized_central_differences = ( 29 | positions[:, 2:, :] - positions[:, :-2, :] 30 | ) # rays x (samples-2) x 3 31 | central_differences = unnormalized_central_differences / ( 32 | torch.norm(unnormalized_central_differences, dim=-1, keepdim=True) + eps 33 | ) 34 | # fill in first and last sample by duplicating neighboring direction 35 | view_directions = torch.cat( 36 | [ 37 | central_differences[:, 0, :].view(-1, 1, 3), 38 | central_differences, 39 | central_differences[:, -1, :].view(-1, 1, 3), 40 | ], 41 | axis=1, 42 | ) # rays x samples x 3 43 | elif difference_type == "backward": 44 | unnormalized_backward_differences = ( 45 | positions[:, 1:, :] - positions[:, :-1, :] 46 | ) # rays x (samples-1) x 3. 0-th sample has no direction. 47 | backward_differences = unnormalized_backward_differences / ( 48 | torch.norm(unnormalized_backward_differences, dim=-1, keepdim=True) + eps 49 | ) 50 | # fill in first sample by duplicating neighboring direction 51 | view_directions = torch.cat( 52 | [backward_differences[:, 0, :].view(-1, 1, 3), backward_differences], 53 | axis=1, 54 | ) # rays x samples x 3 55 | 56 | return view_directions 57 | 58 | def _apply_se3(self, undeformed_positions, network_output): 59 | w, v, pivot, translation = torch.split( 60 | network_output, [3, 3, 3, 3], dim=1 61 | ) # all: num_points x 3 62 | eps = 10e-7 63 | theta = torch.norm(w, dim=-1, keepdim=True) + eps # num_points x 1 64 | w = w / theta 65 | v = v / theta 66 | skew = torch.zeros((w.shape[0], 3, 3), device=w.device) 67 | skew[:, 0, 1] = -w[:, 2] 68 | skew[:, 0, 2] = w[:, 1] 69 | skew[:, 1, 0] = w[:, 2] 70 | skew[:, 1, 2] = -w[:, 0] 71 | skew[:, 2, 0] = -w[:, 1] 72 | skew[:, 2, 1] = w[:, 0] 73 | eye = torch.zeros((w.shape[0], 3, 3), device=w.device) 74 | eye[:, 0, 0] = 1.0 75 | eye[:, 1, 1] = 1.0 76 | eye[:, 2, 2] = 1.0 77 | skew_squared = torch.matmul(skew, skew) 78 | exp_so3 = ( 79 | eye 80 | + torch.sin(theta).view(-1, 1, 1) * skew 81 | + (1.0 - torch.cos(theta)).view(-1, 1, 1) * skew_squared 82 | ) # num_points x 3 x 3 83 | p = ( 84 | theta.view(-1, 1, 1) * eye 85 | + (1.0 - torch.cos(theta)).view(-1, 1, 1) * skew 86 | + (theta - torch.sin(theta)).view(-1, 1, 1) * skew_squared 87 | ) # num_points x 3 x 3 88 | p = torch.matmul(p, v.view(-1, 3, 1)) # num_points x 3 x 1 89 | se3_transform = torch.cat([exp_so3, p], -1) # num_points x 3 x 4 90 | se3_transform = torch.cat( 91 | [ 92 | se3_transform, 93 | torch.zeros((se3_transform.shape[0], 1, 4), device=se3_transform.device), 94 | ], 95 | 1, 96 | ) # num_points x 4 x 4 97 | se3_transform[:, 3, 3] = 1.0 98 | warped_pts = undeformed_positions + pivot # num_points x 3 99 | # in homogenuous coordinates 100 | warped_pts = torch.cat( 101 | [ 102 | warped_pts, 103 | torch.ones((warped_pts.shape[0], 1), device=warped_pts.device), 104 | ], 105 | -1, 106 | ) 107 | 
warped_pts = torch.matmul(se3_transform, warped_pts.view(-1, 4, 1)).view( 108 | -1, 4 109 | ) # num_points x 4 110 | warped_pts = warped_pts[:, :3] / warped_pts[:, 3].view(-1, 1) # num_points x 3 111 | 112 | warped_pts = warped_pts - pivot 113 | warped_pts = warped_pts + translation 114 | unmasked_offsets = warped_pts - undeformed_positions 115 | return unmasked_offsets 116 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: snf 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.7.12 6 | - configargparse=1.5.3 7 | - numpy=1.21 8 | - cudatoolkit=11.3.1 9 | - pytorch::pytorch=1.10.1 10 | - pytorch::torchvision=0.11.2 -------------------------------------------------------------------------------- /misc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SceNeRFlow/4c9b6b3e1b83935e43e33768b0b2760405e68641/misc/teaser.png -------------------------------------------------------------------------------- /multi_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | import os 8 | 9 | import torch 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | def multi_gpu_barrier(rank): 15 | torch.distributed.barrier(device_ids=[rank]) 16 | 17 | 18 | def multi_gpu_receive_returns_from_rank_pathrenderer(rank, world_size, counter, returns): 19 | 20 | some_returns = returns.get_returns() 21 | 22 | returns.activate_mode(counter) 23 | 24 | async_ops = [] 25 | new_tensors = [] 26 | for _name, tensor in some_returns.items(): 27 | this_tensor = torch.empty_like(tensor, device=rank) 28 | 29 | async_op = torch.distributed.irecv(this_tensor, src=counter % world_size) 30 | 31 | async_ops.append(async_op) 32 | new_tensors.append(this_tensor) 33 | 34 | for async_op in async_ops: 35 | async_op.wait() 36 | 37 | for name, tensor in zip(some_returns.keys(), new_tensors): 38 | returns.add_return(name, tensor.cpu()) 39 | 40 | 41 | def multi_gpu_send_returns_to_rank_pathrenderer(target_rank, returns): 42 | 43 | async_ops = [] 44 | for tensor in returns.get_returns().values(): 45 | async_op = torch.distributed.isend(tensor.cuda(), dst=target_rank) 46 | async_ops.append(async_op) 47 | 48 | for async_op in async_ops: 49 | async_op.wait() 50 | 51 | 52 | def multi_gpu_sync_gradients(parameters): 53 | async_ops = [] 54 | for param in parameters: 55 | param = param["parameters"] 56 | if param.grad is None: 57 | continue 58 | async_op = torch.distributed.all_reduce( 59 | param.grad, torch.distributed.ReduceOp.SUM, async_op=True 60 | ) 61 | async_ops.append(async_op) 62 | for async_op in async_ops: 63 | async_op.wait() 64 | 65 | 66 | def multi_gpu_setup(rank, world_size, port): 67 | os.environ["MASTER_ADDR"] = "localhost" 68 | os.environ["MASTER_PORT"] = str(port) # "29500" 69 | torch.cuda.set_device(rank) 70 | torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) 71 | 72 | # testing. also avoids a potential issue when barrier() is the first distributed function called. 
73 | a = torch.ones(2) + rank 74 | a = a.to(rank) 75 | torch.distributed.all_reduce(a, op=torch.distributed.ReduceOp.SUM) 76 | b = a == 0 # force to actually sync a 77 | LOGGER.debug("testing parallelization on rank " + str(rank) + ": " + str(a) + " " + str(b)) 78 | 79 | 80 | def multi_gpu_cleanup(rank): 81 | multi_gpu_barrier(rank) 82 | torch.distributed.destroy_process_group() 83 | 84 | 85 | def exception_logging_wrapper(rank, process_function, *args, **kwargs): 86 | try: 87 | process_function(rank, *args, **kwargs) 88 | except Exception as e: 89 | LOGGER.exception("EXCEPTION in rank " + str(rank) + ": " + str(e)) 90 | raise 91 | 92 | 93 | def multi_gpu_train(settings): 94 | 95 | world_size = torch.cuda.device_count() 96 | 97 | LOGGER.info("found " + str(world_size) + " GPUs") 98 | 99 | port = 29500 100 | import random 101 | 102 | port += random.randint(-100, +100) 103 | 104 | from train import train 105 | 106 | process_context = torch.multiprocessing.spawn( 107 | exception_logging_wrapper, 108 | args=( 109 | train, 110 | settings, 111 | world_size, 112 | port, 113 | ), 114 | nprocs=world_size, 115 | join=False, 116 | ) 117 | 118 | try: 119 | process_context.join() 120 | except KeyboardInterrupt: 121 | LOGGER.warning("SHUTTING DOWN! DO NOT INTERRUPT AGAIN!") 122 | for i, process in enumerate(process_context.processes): 123 | if process.is_alive(): 124 | LOGGER.info("terminating process " + str(i) + "...") 125 | process.terminate() 126 | process.join() 127 | LOGGER.info("process " + str(i) + " finished") 128 | raise 129 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import torch 9 | from smart_adam import SmartAdam 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | class Optimizer: 15 | def __init__(self, settings, scene, renderer): 16 | 17 | self.debug = settings.debug 18 | 19 | self.parameter_sources = [scene, renderer] 20 | 21 | self.optimizers_with_information = [] 22 | for param_with_info in self.get_parameters_with_information(): 23 | 24 | # tag:half_precision : we need Adam because it's invariant under global scaling of the loss. 
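# In more detail: Adam's update is lr * m_hat / (sqrt(v_hat) + eps). Scaling the loss by a
# constant c scales every gradient by c, hence m_hat by c and sqrt(v_hat) by c as well, so
# the update is unchanged up to eps. This is the invariance that the half-precision /
# GradScaler path (see step() below, when use_gradient_scaling is set) relies on.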
25 | 26 | if param_with_info["optimizer"] == "Adam": 27 | optimizer = torch.optim.AdamW( 28 | params=param_with_info["parameters"], 29 | lr=param_with_info["learning_rate"], 30 | weight_decay=param_with_info["weight_decay"], 31 | betas=(0.9, 0.99), 32 | eps=1e-15, 33 | ) 34 | elif param_with_info["optimizer"] == "SmartAdam": 35 | optimizer = SmartAdam( 36 | params=param_with_info["parameters"], 37 | lr=param_with_info["learning_rate"], 38 | weight_decay=param_with_info["weight_decay"], 39 | betas=(0.9, 0.99), 40 | eps=1e-15, 41 | ) 42 | 43 | self.optimizers_with_information.append( 44 | { 45 | "name": param_with_info["name"], 46 | "tags": param_with_info["tags"], 47 | "optimizer": optimizer, 48 | "parameters": param_with_info["parameters"], 49 | "initial_learning_rate": param_with_info["learning_rate"], 50 | "decay_steps": param_with_info["decay_steps"], 51 | "decay_rate": param_with_info["decay_rate"], 52 | } 53 | ) 54 | 55 | def get_parameters_with_information(self): 56 | parameters_with_information = [] 57 | for parameter_source in self.parameter_sources: 58 | parameters_with_information += ( 59 | parameter_source.get_parameters_with_optimization_information() 60 | ) 61 | return parameters_with_information 62 | 63 | def zero_grad(self, set_to_none=None): 64 | if set_to_none is None: 65 | set_to_none = True 66 | for optimizer_with_info in self.optimizers_with_information: 67 | optimizer_with_info["optimizer"].zero_grad(set_to_none=set_to_none) 68 | 69 | def scale_gradients(self, factor): 70 | if factor == 1.0: 71 | return 72 | with torch.no_grad(): 73 | for param_with_info in self.get_parameters_with_information(): 74 | for param in param_with_info["parameters"]: 75 | if param.grad is not None: 76 | param.grad *= factor 77 | 78 | def step(self, scaler=None, use_gradient_scaling=False): 79 | for optimizer_with_info in self.optimizers_with_information: 80 | if use_gradient_scaling: 81 | scaler.step(optimizer_with_info["optimizer"]) 82 | else: 83 | optimizer_with_info["optimizer"].step() 84 | 85 | if self.debug and torch.rand(1) < 0.01: 86 | max_abs_grad_value = torch.zeros(1).cuda() 87 | max_abs_grad_name = "-" 88 | 89 | max_abs_value = torch.zeros(1).cuda() 90 | max_abs_name = "-" 91 | 92 | for param_with_info in self.get_parameters_with_information(): 93 | for param in param_with_info["parameters"]: 94 | 95 | if not torch.all(torch.isfinite(param)): 96 | try: 97 | mask = torch.where(~torch.isfinite(param)) 98 | LOGGER.debug( 99 | "Non-finite value in " 100 | + param_with_info["name"] 101 | + ": " 102 | + str(param[mask]) 103 | ) 104 | except Exception: 105 | LOGGER.debug("Non-finite value in " + param_with_info["name"]) 106 | 107 | if param.grad is not None and not torch.all(torch.isfinite(param.grad)): 108 | try: 109 | mask = torch.where(~torch.isfinite(param.grad)) 110 | LOGGER.debug( 111 | "Non-finite value in " 112 | + param_with_info["name"] 113 | + ".grad: " 114 | + str(param.grad[mask]) 115 | ) 116 | except Exception: 117 | LOGGER.debug("Non-finite value in " + param_with_info["name"] + ".grad") 118 | 119 | this_max = torch.max(torch.abs(param)) 120 | if this_max > max_abs_value: 121 | max_abs_value = this_max 122 | max_abs_name = param_with_info["name"] 123 | 124 | if param.grad is not None: 125 | this_max = torch.max(torch.abs(param.grad)) 126 | if this_max > max_abs_grad_value: 127 | max_abs_grad_value = this_max 128 | max_abs_grad_name = param_with_info["name"] 129 | 130 | LOGGER.debug("Parameter max value " + str(max_abs_value.item()) + " in " + max_abs_name) 131 | 
LOGGER.debug( 132 | "Gradient max value " + str(max_abs_grad_value.item()) + " in " + max_abs_grad_name 133 | ) 134 | 135 | def load_state_dict(self, state_dict): 136 | for optimizer_with_info, this_state_dict in zip( 137 | self.optimizers_with_information, state_dict["optimizers"] 138 | ): 139 | optimizer_with_info["optimizer"].load_state_dict(this_state_dict["optimizer"]) 140 | optimizer_with_info["initial_learning_rate"] = this_state_dict["initial_learning_rate"] 141 | optimizer_with_info["decay_steps"] = this_state_dict["decay_steps"] 142 | optimizer_with_info["decay_rate"] = this_state_dict["decay_rate"] 143 | 144 | def state_dict(self): 145 | optimizers = [] 146 | 147 | for optimizer_with_info in self.optimizers_with_information: 148 | optimizers.append( 149 | { 150 | "optimizer": optimizer_with_info["optimizer"].state_dict(), 151 | "initial_learning_rate": optimizer_with_info["initial_learning_rate"], 152 | "decay_steps": optimizer_with_info["decay_steps"], 153 | "decay_rate": optimizer_with_info["decay_rate"], 154 | } 155 | ) 156 | 157 | return {"optimizers": optimizers} 158 | -------------------------------------------------------------------------------- /path_renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | import os 8 | 9 | import imageio 10 | import numpy as np 11 | import torch 12 | from multi_gpu import ( 13 | multi_gpu_receive_returns_from_rank_pathrenderer, 14 | multi_gpu_send_returns_to_rank_pathrenderer, 15 | ) 16 | from tqdm import trange 17 | from utils import Returns 18 | 19 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 20 | LOGGER = logging.getLogger(__name__) 21 | 22 | 23 | class PathRenderer: 24 | def __init__(self, data_handler, rank, world_size): 25 | self.batch_builder = data_handler.batch_builder 26 | 27 | self.rank = rank 28 | self.world_size = world_size 29 | 30 | def render( 31 | self, 32 | extrins, 33 | intrins, 34 | timesteps, 35 | scene, 36 | renderer, 37 | backgrounds=None, 38 | points_per_ray=None, 39 | reduce_memory_for_correspondences=False, 40 | returns=None, 41 | hacky_checkpoint_loading=None, 42 | **kwargs 43 | ): 44 | 45 | if returns is None: 46 | returns = Returns(restricted=["corrected_rgb"]) 47 | 48 | for counter in trange(len(extrins)): 49 | 50 | returns.activate_mode(counter) 51 | 52 | if counter % self.world_size != self.rank: 53 | if self.rank == 0: 54 | multi_gpu_receive_returns_from_rank_pathrenderer( 55 | self.rank, self.world_size, counter, returns 56 | ) 57 | continue 58 | 59 | extrin = extrins[counter] 60 | intrin = intrins[counter] 61 | timestep = timesteps[counter] 62 | if backgrounds is not None: 63 | background = backgrounds[counter] 64 | 65 | float_timestep = float(timestep.numpy()) 66 | 67 | if hacky_checkpoint_loading is not None: 68 | hacky_checkpoint_loading.load(float_timestep) 69 | 70 | uses_pruning = ( 71 | scene.pruning.do_use_pruning and scene.pruning.current_timestep != timestep 72 | ) 73 | if uses_pruning: 74 | original_timestep = scene.pruning.current_timestep 75 | scene.pruning.load_voxel_grid( 76 | float_timestep if "only_canonical" not in kwargs else "all" 77 | ) 78 | 79 | single_image = { 80 | "extrin": extrin, 81 | "intrin": intrin, 82 | "timestep": timestep, 83 | } 84 | if backgrounds is not None: 85 | 
single_image["background"] = background 86 | 87 | with torch.no_grad(): 88 | 89 | batch = self.batch_builder.build(single_image=single_image) 90 | 91 | rendering_generator_wrapper = renderer.render( 92 | batch=batch, 93 | scene=scene, 94 | points_per_ray=points_per_ray, 95 | returns=returns, 96 | pull_to_cpu=True, 97 | iterator_yield=True, 98 | **kwargs 99 | ) 100 | 101 | for subreturns in rendering_generator_wrapper(): 102 | 103 | if ( 104 | reduce_memory_for_correspondences 105 | and "deformed_positions" in subreturns 106 | and "weights" in subreturns 107 | ): 108 | correspondences_rgb = self._visualize_correspondences( 109 | subreturns, 110 | number_of_small_rgb_voxels=kwargs["number_of_small_rgb_voxels"] 111 | if "number_of_small_rgb_voxels" in kwargs 112 | else 30, 113 | blend_with_corrected_rgb=kwargs["blend_with_corrected_rgb"] 114 | if "blend_with_corrected_rgb" in kwargs 115 | else 0.0, 116 | ) 117 | subreturns.delete_return("deformed_positions") 118 | subreturns.delete_return("weights") 119 | subreturns.add_return("correspondences_rgb", correspondences_rgb) 120 | 121 | returns.add_returns(subreturns.concatenate_returns()) 122 | 123 | returns.reshape_returns(width=intrin["width"], height=intrin["height"]) 124 | 125 | if uses_pruning: 126 | scene.pruning.load_voxel_grid(original_timestep) 127 | 128 | if self.rank != 0: 129 | multi_gpu_send_returns_to_rank_pathrenderer(target_rank=0, returns=returns) 130 | 131 | return returns.get_returns()["corrected_rgb"], returns 132 | 133 | def _visualize_correspondences( 134 | self, 135 | returns, 136 | number_of_small_rgb_voxels=30, 137 | background_threshold=0.4, 138 | blend_with_corrected_rgb=0.0, 139 | ): 140 | 141 | device = "cpu" # self.rank 142 | 143 | deformed_positions = returns.get_returns()["deformed_positions"].to( 144 | device 145 | ) # num_rays x num_points x 3 146 | weights = returns.get_returns()["weights"].to(device) # num_rays x num_points 147 | 148 | # visibility_weight is the weight of the influence that each sample has on the final rgb value. so they sum to at most 1. 149 | accumulated_visibility = torch.cumsum(weights, dim=-1) # num_rays x num_points 150 | background_mask = accumulated_visibility[:, -1] < background_threshold # num_rays 151 | median_indices = torch.min(torch.abs(accumulated_visibility - 0.5), dim=-1)[ 152 | 1 153 | ] # num_rays. visibility goes from 0 to 1. 0.5 is the median, so treat it as "most likely to be on the actually visible surface" 154 | num_rays = median_indices.shape[0] 155 | # median_indices contains the index of one ray sample for each pixel. 156 | # this ray sample is selected in this line of code. 157 | surface_pixels = deformed_positions[ 158 | torch.arange(num_rays, device=device), median_indices, : 159 | ] # num_rays x 3 160 | correspondences_rgb = surface_pixels 161 | 162 | # break the canonical space into smaller voxels. 163 | # each voxel covers the entire RGB space [0,1]^3. 164 | # makes it easier to visualize small changes. leads to a 3D checkerboard pattern. 
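# Worked example: with number_of_small_rgb_voxels = 30, a canonical coordinate of 0.517 is
# scaled to 15.51 and, after subtracting its integer part below, becomes 0.51. Moving a
# surface point by only 1/30 in canonical space therefore sweeps that color channel through
# the full [0, 1) range once, so even small differences between correspondences become
# clearly visible color changes (at the cost of the wrap-around checkerboard pattern).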
165 | if number_of_small_rgb_voxels > 1: 166 | correspondences_rgb *= number_of_small_rgb_voxels 167 | correspondences_rgb = correspondences_rgb - correspondences_rgb.long() 168 | 169 | # correspondences_rgb[background_mask] = 0.0 170 | 171 | corrected_rgb = returns.get_returns()["corrected_rgb"].to(device) # num_rays 172 | correspondences_rgb = ( 173 | 1.0 - blend_with_corrected_rgb 174 | ) * correspondences_rgb + blend_with_corrected_rgb * corrected_rgb 175 | correspondences_rgb[background_mask] = corrected_rgb[background_mask] # modified 176 | 177 | z_vals = returns.get_returns()["z_vals"] 178 | depth = z_vals[torch.arange(num_rays, device=device), median_indices] 179 | depth = torch.clamp(depth, min=0.0, max=1.0) 180 | depth[background_mask] = 1.0 181 | returns.delete_return("depth") 182 | returns.add_return("depth", depth) 183 | min_normalized_depth = 0.1 184 | disparity = 1.0 / torch.max(min_normalized_depth * torch.ones_like(depth), depth) 185 | disparity *= min_normalized_depth 186 | returns.delete_return("disparity") 187 | returns.add_return("disparity", disparity) 188 | 189 | return correspondences_rgb.cpu() 190 | 191 | def render_and_store( 192 | self, 193 | state_loader_saver, 194 | output_name, 195 | returns=None, 196 | rgb=None, 197 | visualize_correspondences=True, 198 | reduce_memory_for_correspondences=True, 199 | also_store_images=False, 200 | output_folder=None, 201 | hacky_checkpoint_loading=None, 202 | only_render_if_file_does_not_exist=True, 203 | **kwargs 204 | ): 205 | 206 | if output_folder is None and self.rank == 0: 207 | output_folder = os.path.join(state_loader_saver.get_results_folder(), "0_renderings") 208 | state_loader_saver.create_folder(output_folder) 209 | 210 | if only_render_if_file_does_not_exist: 211 | check_output_file = os.path.join(output_folder, output_name + "_rgb.mp4") 212 | if os.path.exists(check_output_file): 213 | LOGGER.info("already rendered. 
will not render again: " + output_name) 214 | return 215 | 216 | if returns is None: 217 | if visualize_correspondences: 218 | correspondences = ["deformed_positions", "weights", "correspondences_rgb"] 219 | # deformed_positions and weights get deleted if reduce_memory_for_correspondences==True 220 | else: 221 | correspondences = [] 222 | returns = Returns( 223 | restricted=["corrected_rgb", "disparity", "depth", "z_vals"] + correspondences 224 | ) 225 | else: 226 | reduce_memory_for_correspondences = False 227 | 228 | _, returns = self.render( 229 | returns=returns, 230 | reduce_memory_for_correspondences=reduce_memory_for_correspondences, 231 | hacky_checkpoint_loading=hacky_checkpoint_loading, 232 | **kwargs 233 | ) 234 | 235 | if self.rank != 0: 236 | return 237 | 238 | def store(output_file, images, fps=30, quality=10): 239 | imageio.mimwrite(output_file, images, fps=fps, quality=quality) 240 | if also_store_images: 241 | for counter, image in enumerate(images): 242 | imageio.imsave( 243 | output_file + "_" + str(counter).zfill(5) + ".jpg", image, quality=100 244 | ) 245 | 246 | def saveable(tensor): 247 | try: 248 | tensor = tensor.numpy() 249 | except Exception: 250 | pass 251 | return (255 * np.clip(tensor, 0, 1)).astype(np.uint8) 252 | 253 | def jet_color_scheme(tensor): 254 | # tensor: values in [0,1] 255 | from matplotlib import cm 256 | 257 | tensor = cm.jet(saveable(tensor))[:, :, :, :3] 258 | return tensor # values in [0,1] 259 | 260 | def stack_images_for_video(name): 261 | stacked = [] 262 | for counter in returns.get_modes(): 263 | returns.activate_mode(counter) 264 | this_result = returns.get_returns()[name] 265 | stacked.append(this_result) 266 | return torch.stack(stacked, dim=0) 267 | 268 | # rgb 269 | if "corrected_rgb" in returns.get_returns(): 270 | corrected_rgb = stack_images_for_video("corrected_rgb") 271 | 272 | output_file = os.path.join(output_folder, output_name + "_rgb.mp4") 273 | store(output_file, saveable(corrected_rgb)) 274 | 275 | if rgb is not None: 276 | # save groundtruth only once 277 | save_groundtruth = True 278 | # save_groundtruth = not any("_rgb_gt" in file for file in os.listdir(output_folder)) 279 | if save_groundtruth: 280 | output_file = os.path.join(output_folder, output_name + "_rgb_gt.mp4") 281 | store(output_file, saveable(rgb)) 282 | 283 | error_map = np.linalg.norm(rgb - corrected_rgb, axis=-1) / np.sqrt(3) 284 | error_map = np.clip( 285 | error_map / 0.10, 0.0, 1.0 286 | ) # emphasize small errors, clip at 10% max error 287 | error_map = jet_color_scheme(error_map) 288 | output_file = os.path.join(output_folder, output_name + "_error.mp4") 289 | store(output_file, saveable(error_map), quality=6) 290 | 291 | # depth 292 | if "depth" in returns.get_returns(): 293 | depth = stack_images_for_video("depth") 294 | depth = depth / torch.max(depth) 295 | output_file = os.path.join(output_folder, output_name + "_depth.mp4") 296 | store(output_file, saveable(depth)) 297 | 298 | # disparity 299 | if "disparity" in returns.get_returns(): 300 | disparity = stack_images_for_video("disparity") 301 | disparity = jet_color_scheme(disparity) 302 | output_file = os.path.join(output_folder, output_name + "_disparity.mp4") 303 | store(output_file, saveable(disparity)) 304 | 305 | # correspondences 306 | if "correspondences_rgb" in returns.get_returns(): 307 | correspondences_rgb = stack_images_for_video("correspondences_rgb") 308 | output_file = os.path.join(output_folder, output_name + "_correspondences_rgb.mp4") 309 | store(output_file, 
saveable(correspondences_rgb), quality=5) 310 | 311 | if hacky_checkpoint_loading is not None: 312 | raise AssertionError # make sure that hacky_checkpoint_loading is only used intentionally, so throw an error that needs to be caught 313 | -------------------------------------------------------------------------------- /post_correction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import torch 9 | from utils import project_to_correct_range 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | class ColorCalibration(torch.nn.Module): 15 | def __init__(self, settings): 16 | 17 | super().__init__() 18 | 19 | self.color_calibration_mode = settings.color_calibration_mode 20 | 21 | if self.color_calibration_mode == "none": 22 | pass 23 | 24 | elif self.color_calibration_mode == "full_matrix": 25 | 26 | num_train_images = data_handler.get_training_set_size() 27 | 28 | self.register_parameter( 29 | name="imageids_to_bias", 30 | param=torch.nn.Parameter(torch.zeros(num_train_images, 3, 1)), 31 | ) 32 | self.register_parameter( 33 | name="imageids_to_full_matrix", 34 | param=torch.nn.Parameter(torch.zeros(num_train_images, 3, 3)), 35 | ) 36 | with torch.no_grad(): 37 | self.imageids_to_full_matrix[:, 0, 0] = 1.0 38 | self.imageids_to_full_matrix[:, 1, 1] = 1.0 39 | self.imageids_to_full_matrix[:, 2, 2] = 1.0 40 | 41 | elif self.color_calibration_mode == "neural_volumes": 42 | 43 | num_train_images = data_handler.get_training_set_size() 44 | 45 | self.register_parameter( 46 | name="imageids_to_scalings", 47 | param=torch.nn.Parameter(torch.ones(num_train_images, 3)), 48 | ) 49 | self.register_parameter( 50 | name="imageids_to_biases", 51 | param=torch.nn.Parameter(torch.zeros(num_train_images, 3)), 52 | ) 53 | 54 | def forward(self, rgb, batch, returns=None): 55 | 56 | if self.color_calibration_mode == "none": 57 | return rgb 58 | 59 | elif self.color_calibration_mode == "full_matrix": 60 | 61 | biases = self.imageids_to_bias[batch["image_indices"]] # N x 3 x 1 62 | matrices = self.imageids_to_full_matrix[batch["image_indices"]] # N x 3 x 3 63 | rgb = torch.matmul(matrices, rgb.view(-1, 3, 1)) + biases # N x 3 x 1 64 | rgb = rgb.view(-1, 3) # N x 3 65 | 66 | rgb = project_to_correct_range(rgb, mode="zick_zack") 67 | 68 | return rgb 69 | 70 | elif self.color_calibration_mode == "neural_volumes": 71 | 72 | biases = self.imageids_to_bias[batch["image_indices"]] # N x 3 73 | scalings = self.imageids_to_scalings[batch["image_indices"]] # N x 3 74 | rgb = scalings * rgb + biases # N x 3 75 | 76 | rgb = project_to_correct_range(rgb, mode="zick_zack") 77 | 78 | return rgb 79 | 80 | def get_parameters_with_optimization_information(self): 81 | params = [] 82 | 83 | if self.color_calibration_mode == "full_matrix": 84 | params.append( 85 | { 86 | "name": "color_calibration", 87 | "parameters": self.parameters(), 88 | "optimizer": "SmartAdam", 89 | "learning_rate": 1e-6, 90 | "decay_steps": self.learning_rate_decay_iterations, 91 | "decay_rate": self.learning_rate_decay_fraction, 92 | "weight_decay": 0.0, 93 | } 94 | ) 95 | elif self.color_calibration_mode == "neural_volumes": 96 | params.append( 97 | { 98 | "name": "color_calibration", 99 | "parameters": self.parameters(), 100 | "optimizer": "SmartAdam", 101 | "learning_rate": 
1e-6, 102 | "decay_steps": self.learning_rate_decay_iterations, 103 | "decay_rate": self.learning_rate_decay_fraction, 104 | "weight_decay": 0.0, 105 | } 106 | ) 107 | 108 | return params 109 | 110 | def get_regularization_losses(self): 111 | return {} 112 | 113 | 114 | class Background(torch.nn.Module): 115 | def __init__(self, settings): 116 | super().__init__() 117 | 118 | def forward(self, rgb, batch, accumulated_weights, returns=None): 119 | 120 | if "background" in batch: 121 | rgb = rgb + (1.0 - accumulated_weights).view(-1, 1) * batch["background"] 122 | return rgb 123 | 124 | def get_parameters_with_optimization_information(self): 125 | return [] 126 | 127 | def get_regularization_losses(self): 128 | return {} 129 | 130 | 131 | class VignettingCorrection(torch.nn.Module): 132 | def __init__(self, settings): 133 | super().__init__() 134 | 135 | self.do_vignetting_correction = settings.do_vignetting_correction 136 | 137 | self.learning_rate_decay_autodecoding_iterations = ( 138 | settings.learning_rate_decay_autodecoding_iterations 139 | ) 140 | self.learning_rate_decay_autodecoding_fraction = ( 141 | settings.learning_rate_decay_autodecoding_fraction 142 | ) 143 | 144 | # assume all cameras follow the same vignetting 145 | self.vignetting_parameters = torch.nn.Parameter(torch.zeros((3,), dtype=torch.float32)) 146 | 147 | def forward(self, rgb, batch, returns=None): 148 | 149 | if not self.do_vignetting_correction or "normalized_x_coordinate" not in batch: 150 | return rgb 151 | 152 | k1, k2, k3 = torch.unbind(self.vignetting_parameters) 153 | r = batch["normalized_x_coordinate"] ** 2 + batch["normalized_y_coordinate"] ** 2 154 | 155 | offset = k1 * r + k2 * r**2 + k3 * r**3 156 | 157 | rgb = rgb * (1.0 + offset.view(-1, 1)) 158 | 159 | rgb = project_to_correct_range(rgb, mode="zick_zack") 160 | 161 | return rgb 162 | 163 | def get_parameters_with_optimization_information(self): 164 | params = [] 165 | if self.do_vignetting_correction: 166 | params.append( 167 | { 168 | "name": "vignetting_parameters", 169 | "tags": ["autodecoding", "vignetting"], 170 | "parameters": [self.vignetting_parameters], 171 | "optimizer": "Adam", 172 | "learning_rate": 1e-2, 173 | "decay_steps": self.learning_rate_decay_autodecoding_iterations, 174 | "decay_rate": self.learning_rate_decay_autodecoding_fraction, 175 | "weight_decay": 0.0, 176 | } 177 | ) 178 | return params 179 | 180 | def get_regularization_losses(self): 181 | return {} 182 | 183 | 184 | class PostCorrection(torch.nn.Module): 185 | def __init__(self, settings): 186 | super().__init__() 187 | 188 | self.color_calibration = ColorCalibration(settings) 189 | self.background = Background(settings) 190 | self.vignetting_correction = VignettingCorrection(settings) 191 | 192 | def forward(self, rgb, batch, accumulated_weights, is_training, returns=None): 193 | 194 | rgb = self.vignetting_correction(rgb, batch, returns=returns) 195 | 196 | if is_training: 197 | rgb = self.color_calibration(rgb, batch, returns=returns) 198 | 199 | rgb = self.background(rgb, batch, accumulated_weights, returns=returns) 200 | 201 | returns.add_return("corrected_rgb", rgb) 202 | 203 | return rgb 204 | 205 | def get_parameters_with_optimization_information(self): 206 | color_calibration_parameters = ( 207 | self.color_calibration.get_parameters_with_optimization_information() 208 | ) 209 | background_parameters = self.background.get_parameters_with_optimization_information() 210 | vignetting_correction_parameters = ( 211 | 
self.vignetting_correction.get_parameters_with_optimization_information() 212 | ) 213 | return ( 214 | color_calibration_parameters + background_parameters + vignetting_correction_parameters 215 | ) 216 | 217 | def get_regularization_losses(self): 218 | regularization_losses = {} 219 | regularization_losses[ 220 | "color_calibration" 221 | ] = self.color_calibration.get_regularization_losses() 222 | regularization_losses["background"] = self.background.get_regularization_losses() 223 | regularization_losses[ 224 | "vignetting_correction" 225 | ] = self.vignetting_correction.get_regularization_losses() 226 | return regularization_losses 227 | -------------------------------------------------------------------------------- /pre_correction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | 8 | 9 | class PreCorrection(torch.nn.Module): 10 | def __init__(self, settings): 11 | super().__init__() 12 | 13 | def forward(self, batch, returns=None): 14 | return batch 15 | 16 | def get_parameters_with_optimization_information(self): 17 | return [] 18 | 19 | def get_regularization_losses(self): 20 | return {} 21 | -------------------------------------------------------------------------------- /pruning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
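# Module overview (a summary of the implementation below): Pruning keeps a boolean foreground
# voxel grid over the unit cube, read from the "foreground_voxel_grids" BinaryDataset and
# keyed per timestep via get_pruning_key() (a Szudzik pairing of voxel_grid_size and
# timestep), with an "all" grid as fallback. In forward(), positions already normalized to
# [0,1]^3 are mapped to voxel indices with floor(p * voxel_grid_size); for example, with
# voxel_grid_size = 128 a point at (0.51, 0.02, 0.73) lands in voxel (65, 2, 93). Samples in
# empty voxels are masked out, and during training the mask is skipped with probability
# no_pruning_probability.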
6 | import logging 7 | import os 8 | 9 | import numpy as np 10 | import torch 11 | from binary_dataset import BinaryDataset 12 | from utils import szudzik 13 | 14 | LOGGER = logging.getLogger(__name__) 15 | 16 | 17 | def get_pruning_key(voxel_grid_size, timestep): 18 | key = szudzik(voxel_grid_size, timestep) 19 | return key 20 | 21 | 22 | class Pruning(torch.nn.Module): 23 | def __init__(self, settings, data_handler): 24 | super().__init__() 25 | 26 | self.do_use_pruning = settings.use_pruning 27 | self.voxel_grid_size = settings.voxel_grid_size 28 | 29 | self.no_pruning_probability = settings.no_pruning_probability 30 | 31 | self.voxel_grid = None # register_buffer("voxel_grid", torch.empty((voxel_grid_size, voxel_grid_size, voxel_grid_size), dtype=torch.bool)) # registering as buffer would store this in the checkpoint, which just wastes disk space 32 | 33 | self.current_timestep = None 34 | 35 | if self.do_use_pruning: 36 | self._dataset = BinaryDataset( 37 | data_handler.data_loader.get_dataset_folder(), 38 | name="foreground_voxel_grids", 39 | read_only=True, 40 | ) 41 | 42 | def forward(self, positions, is_training): 43 | # positions are normalized to unit cube 44 | # positions: num_rays x num_points_per_ray x 3 45 | 46 | input_shape = positions.shape[ 47 | :-1 48 | ] # num_rays x num_points_per_ray or num_rays * num_points_per_ray 49 | 50 | if ( 51 | not self.do_use_pruning 52 | or self.voxel_grid is None 53 | or (torch.rand(1) < self.no_pruning_probability and is_training) 54 | ): 55 | mask = torch.ones(input_shape, device=positions.device, dtype=torch.bool) 56 | return mask 57 | 58 | voxel_indices = positions.view(-1, 3) 59 | voxel_indices = torch.floor( 60 | voxel_indices * self.voxel_grid_size 61 | ).long() # num_rays * num_points_per_ray x 3 62 | mask = self.voxel_grid[voxel_indices[:, 0], voxel_indices[:, 1], voxel_indices[:, 2]] 63 | mask = mask.view(input_shape) 64 | 65 | return mask 66 | 67 | def load_voxel_grid(self, timestep): 68 | 69 | self.current_timestep = timestep 70 | 71 | if timestep is None: 72 | self.voxel_grid = None 73 | return 74 | 75 | if timestep == "all": 76 | key = "all" 77 | else: 78 | key = get_pruning_key(self.voxel_grid_size, float(timestep)) 79 | if key not in self._dataset: 80 | LOGGER.warning( 81 | "undesirable flow. got wrong timestep, fall back to 'all' voxel grid pruning" 82 | ) 83 | key = "all" 84 | 85 | from io import BytesIO 86 | 87 | voxel_grid_bytes = BytesIO(self._dataset.get_entry(key)) 88 | voxel_grid_bytes.seek(0) 89 | try: 90 | voxel_grid = np.load(voxel_grid_bytes) 91 | except Exception as exception: 92 | LOGGER.exception( 93 | "failed to load pruning at time: " + str(timestep) + " " + str(self.voxel_grid_size) 94 | ) 95 | raise exception 96 | self.voxel_grid = torch.from_numpy(voxel_grid["foreground_voxel_grid"]).cuda() 97 | 98 | def get_parameters_with_optimization_information(self): 99 | return [] 100 | 101 | def get_regularization_losses(self): 102 | return {} 103 | -------------------------------------------------------------------------------- /ray_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
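# Module overview (a summary of the implementation below): RayBuilder turns camera extrinsics
# and intrinsics into per-pixel ray origins and directions. Lens distortion follows a radial /
# tangential / thin-prism model; for focal-length-normalized coordinates (i, j),
#     r = i^2 + j^2
#     i_distorted = i + (k1*r + k2*r^2 + k3*r^3)*i + 2*p1*i*j + p2*(r + 2*i^2) + s1*r + s2*r^2
# (j is analogous, with the roles of p1/p2 exchanged and s3, s4 instead of s1, s2).
# undistort() inverts this model with a damped per-pixel Gauss-Newton iteration so that the
# recovered (i, j) reproduce the observed pixel grid. Results are cached in RAM and, for
# non-zero distortions, in the on-disk "precomputed_rays" BinaryDataset keyed by a Szudzik
# pairing of intrinid and image height.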
6 | import logging 7 | 8 | import numpy as np 9 | import torch 10 | from binary_dataset import BinaryDataset 11 | from tqdm import trange 12 | from utils import szudzik 13 | from multi_gpu import multi_gpu_barrier 14 | 15 | LOGGER = logging.getLogger(__name__) 16 | 17 | 18 | class RayBuilder: 19 | def __init__(self, rank=0, multi_gpu=False, max_num_precomputed_undistortions=None): 20 | 21 | if max_num_precomputed_undistortions is None: 22 | max_num_precomputed_undistortions = 1000 23 | self.max_num_precomputed_undistortions = max_num_precomputed_undistortions 24 | self.precomputed_undistortion = {} 25 | 26 | self.rank = rank 27 | self.multi_gpu = multi_gpu 28 | 29 | self._precomputed_dataset = None 30 | self.only_use_precomputed_dataset_for_nonzero_distortions = True 31 | 32 | def use_precomputed_dataset(self, dataset_folder, create=False): 33 | 34 | if create: 35 | if self.rank == 0: 36 | self._precomputed_dataset = BinaryDataset( 37 | dataset_folder, name="precomputed_rays", read_only=False 38 | ) 39 | self._precomputed_dataset.flush() 40 | if self.multi_gpu: 41 | multi_gpu_barrier(self.rank) # wait for dataset creation if not existing 42 | if self.rank > 0: 43 | self._precomputed_dataset = BinaryDataset( 44 | dataset_folder, name="precomputed_rays", read_only=True 45 | ) 46 | else: 47 | self._precomputed_dataset = BinaryDataset( 48 | dataset_folder, name="precomputed_rays", read_only=True 49 | ) 50 | 51 | def undistort(self, intrin, i, j): 52 | original_i_shape = i.shape 53 | original_j_shape = j.shape 54 | 55 | i = i.reshape(-1) 56 | j = j.reshape(-1) 57 | 58 | target_i = i.clone() 59 | target_j = j.clone() 60 | 61 | def _get_value(name): 62 | if name in intrin["distortion"]: 63 | return intrin["distortion"][name] 64 | else: 65 | return 0.0 66 | 67 | k1 = _get_value("k1") 68 | k2 = _get_value("k2") 69 | p1 = _get_value("p1") 70 | p2 = _get_value("p2") 71 | k3 = _get_value("k3") 72 | s1 = _get_value("s1") 73 | s2 = _get_value("s2") 74 | s3 = _get_value("s3") 75 | s4 = _get_value("s4") 76 | 77 | if all(x == 0.0 for x in [k1, k2, p1, p2, k3, s1, s2, s3, s4]): 78 | max_num_iterations = 0 79 | mean_error_i = torch.zeros(1) 80 | mean_error_j = torch.zeros(1) 81 | else: 82 | max_num_iterations = 20000 83 | custom_max_update = 2.0 # if above this value, downscale the update values to this value 84 | stability_mask_threshold = 1e-10 85 | convergence_threshold = 1e-8 # stop optimization if mean error below this threshold 86 | acceptance_threshold = ( 87 | 1e-4 # after optimization is over, accept the result if its error is below this value 88 | ) 89 | 90 | for _current_iter in range(max_num_iterations): 91 | 92 | if _current_iter % 10000 == 0 and _current_iter > 0: 93 | custom_max_update /= 5.0 94 | 95 | # components of current position 96 | r = i * i + j * j 97 | radial = k1 * r + k2 * r * r + k3 * r * r * r 98 | tangential_i = 2.0 * p1 * i * j + p2 * (r + 2.0 * i * i) 99 | tangential_j = p1 * (r + 2.0 * j * j) + 2.0 * p2 * i * j 100 | thin_prism_i = s1 * r + s2 * r * r 101 | thin_prism_j = s3 * r + s4 * r * r 102 | 103 | current_i = i + radial * i + tangential_i + thin_prism_i 104 | current_j = j + radial * j + tangential_j + thin_prism_j 105 | 106 | # residual 107 | 108 | error_i = current_i - target_i 109 | error_j = current_j - target_j 110 | 111 | # build 2x2 Jacobi matrix (error_i and error_j wrt. 
i and j) 112 | 113 | d_radial_wrt_i = (k1 + 2.0 * k2 * r + 3 * k3 * r * r) * 2.0 * i 114 | d_radial_wrt_j = (k1 + 2.0 * k2 * r + 3 * k3 * r * r) * 2.0 * j 115 | 116 | # i wrt i 117 | d_tangential_i_wrt_i = 2.0 * p1 * j + p2 * 6.0 * i 118 | d_thin_prism_i_wrt_i = s1 * 2.0 * i + (s2 * 2.0 * r) * 2.0 * i 119 | d_current_i_wrt_i = ( 120 | 1.0 121 | + (d_radial_wrt_i * i + radial * 1.0) 122 | + d_tangential_i_wrt_i 123 | + d_thin_prism_i_wrt_i 124 | ) 125 | 126 | # i wrt j 127 | d_tangential_i_wrt_j = 2.0 * p1 * i + p2 * 2.0 * j 128 | d_thin_prism_i_wrt_j = s1 * 2.0 * j + (s2 * 2.0 * r) * 2.0 * j 129 | d_current_i_wrt_j = d_radial_wrt_j * i + d_tangential_i_wrt_j + d_thin_prism_i_wrt_j 130 | 131 | # j wrt i 132 | d_tangential_j_wrt_i = p1 * 2.0 * i + 2.0 * p2 * j 133 | d_thin_prism_j_wrt_i = s3 * 2.0 * i + (s4 * 2.0 * r) * 2.0 * i 134 | d_current_j_wrt_i = d_radial_wrt_i * j + d_tangential_j_wrt_i + d_thin_prism_j_wrt_i 135 | 136 | # j wrt j 137 | d_tangential_j_wrt_j = p1 * 6.0 * j + 2.0 * p2 * i 138 | d_thin_prism_j_wrt_j = s3 * 2.0 * j + (s4 * 2.0 * r) * 2.0 * j 139 | d_current_j_wrt_j = ( 140 | 1.0 141 | + (d_radial_wrt_j * j + radial * 1.0) 142 | + d_tangential_j_wrt_j 143 | + d_thin_prism_j_wrt_j 144 | ) 145 | 146 | # Gauss-Newton with n=m 147 | denominator = ( 148 | d_current_i_wrt_i * d_current_j_wrt_j - d_current_i_wrt_j * d_current_j_wrt_i 149 | ) 150 | update_i = d_current_j_wrt_j * error_i - d_current_i_wrt_j * error_j 151 | update_j = -d_current_j_wrt_i * error_i + d_current_i_wrt_i * error_j 152 | 153 | update_i /= denominator 154 | update_j /= denominator 155 | 156 | # update 157 | stability_mask = torch.abs(denominator) > stability_mask_threshold 158 | 159 | max_update = torch.max( 160 | torch.abs(torch.cat([update_i[stability_mask], update_j[stability_mask]], 0)) 161 | ) 162 | if max_update > custom_max_update: 163 | update_i[stability_mask] = torch.clamp( 164 | update_i[stability_mask], min=-custom_max_update, max=custom_max_update 165 | ) 166 | update_j[stability_mask] = torch.clamp( 167 | update_j[stability_mask], min=-custom_max_update, max=custom_max_update 168 | ) 169 | 170 | i[stability_mask] -= update_i[stability_mask] 171 | j[stability_mask] -= update_j[stability_mask] 172 | 173 | mean_error_i = torch.mean(torch.abs(error_i)) 174 | mean_error_j = torch.mean(torch.abs(error_j)) 175 | if mean_error_i < convergence_threshold and mean_error_j < convergence_threshold: 176 | break 177 | 178 | LOGGER.debug( 179 | "undistortion error for " 180 | + str(intrin["intrinid"]) 181 | + ": " 182 | + str(mean_error_i.item()) 183 | + " " 184 | + str(mean_error_j.item()) 185 | ) 186 | if ( 187 | not torch.isfinite(mean_error_i) 188 | or mean_error_i > acceptance_threshold 189 | or not torch.isfinite(mean_error_j) 190 | or mean_error_j > acceptance_threshold 191 | ): 192 | LOGGER.warning("did not converge: " + str([k1, k2, p1, p2, k3, s1, s2, s3, s4])) 193 | LOGGER.warning( 194 | "undistortion error for " 195 | + str(intrin["intrinid"]) 196 | + ": " 197 | + str(mean_error_i.item()) 198 | + " " 199 | + str(mean_error_j.item()) 200 | ) 201 | raise RuntimeError("undistortion did not converge") 202 | 203 | i = i.reshape(original_i_shape) 204 | j = j.reshape(original_j_shape) 205 | 206 | return i, j 207 | 208 | def _convert_intrin_to_key(self, intrin): 209 | 210 | a = intrin["intrinid"] 211 | b = intrin["height"] # a proxy for the image rescaling "factor" 212 | 213 | key = szudzik(a, b) 214 | return key 215 | 216 | def maybe_get_precomputed_undistortion(self, intrin, device): 217 | 218 | key = 
self._convert_intrin_to_key(intrin) 219 | 220 | if key in self.precomputed_undistortion: 221 | i, j = self.precomputed_undistortion[key] 222 | return i.clone().to(device), j.clone().to(device) 223 | 224 | if self.only_use_precomputed_dataset_for_nonzero_distortions and ( 225 | "distortion" not in intrin 226 | or ("distortion" in intrin and all(x == 0 for x in intrin["distortion"].values())) 227 | ): 228 | return None, None 229 | 230 | if self._precomputed_dataset is not None: 231 | if key in self._precomputed_dataset: 232 | i_j_bytes = self._precomputed_dataset.get_entry(key) 233 | try: 234 | i_j = np.frombuffer(i_j_bytes, dtype=np.float32) 235 | except Exception as exception: 236 | LOGGER.warning("failed at: " + str(intrin)) 237 | raise exception 238 | i, j = torch.from_numpy(i_j["i"]), torch.from_numpy(i_j["j"]) 239 | self.maybe_store_precomputed_undistortion( 240 | intrin, i, j 241 | ) # maybe load into RAM dictionary (self.precomputed_undistortion) 242 | return i.clone().to(device), j.clone().to( 243 | device 244 | ) # in case it didn't get stored, need to return here 245 | 246 | return None, None 247 | 248 | def maybe_store_precomputed_undistortion(self, intrin, i, j): 249 | 250 | key = self._convert_intrin_to_key(intrin) 251 | 252 | if ( 253 | key not in self.precomputed_undistortion 254 | and len(self.precomputed_undistortion) < self.max_num_precomputed_undistortions 255 | ): 256 | self.precomputed_undistortion[key] = (i.clone().cpu(), j.clone().cpu()) 257 | 258 | if self.only_use_precomputed_dataset_for_nonzero_distortions and ( 259 | "distortion" not in intrin 260 | or ("distortion" in intrin and all(x == 0 for x in intrin["distortion"].values())) 261 | ): 262 | return 263 | 264 | if self._precomputed_dataset is not None: 265 | if key not in self._precomputed_dataset: 266 | from io import BytesIO 267 | 268 | i_j_bytes = BytesIO() 269 | np.savez_compressed( 270 | i_j_bytes, 271 | i=i.cpu().numpy().astype(np.float32), 272 | j=j.cpu().numpy().astype(np.float32), 273 | ) 274 | self._precomputed_dataset.maybe_add_entry(i_j_bytes, key=key) 275 | self._precomputed_dataset.flush() 276 | 277 | def build(self, extrin, intrin): 278 | 279 | device = extrin["rotation"].device 280 | 281 | i, j = self.maybe_get_precomputed_undistortion(intrin, device) 282 | if i is None: 283 | if self.rank == 0: 284 | 285 | # (0, 0) is top left (?) 286 | i, j = torch.meshgrid( 287 | torch.linspace(0, intrin["width"] - 1, intrin["width"], device=device), 288 | torch.linspace(0, intrin["height"] - 1, intrin["height"], device=device), 289 | indexing="ij", 290 | ) # pytorch's meshgrid has indexing='ij' 291 | i = i.t() 292 | j = j.t() 293 | 294 | i = (i - intrin["center_x"]) / intrin["focal_x"] 295 | j = (j - intrin["center_y"]) / intrin["focal_y"] 296 | 297 | if "distortion" in intrin: 298 | i, j = self.undistort(intrin, i, j) 299 | 300 | self.maybe_store_precomputed_undistortion(intrin, i, j) 301 | 302 | if self.multi_gpu: 303 | multi_gpu_barrier(self.rank) 304 | if self.rank > 0: 305 | self._precomputed_dataset.flush() # get updated dataset with current undistortion 306 | i, j = self.maybe_get_precomputed_undistortion(intrin, device) 307 | 308 | dirs = torch.stack([i, -j, -torch.ones_like(i, device=device)], -1) 309 | # Rotate ray directions from camera frame to the world frame 310 | rays_dir = torch.sum( 311 | dirs[..., np.newaxis, :] * extrin["rotation"], -1 312 | ) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 313 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
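# In other words: each pixel's camera-space direction is
# [(u - center_x)/focal_x, -(v - center_y)/focal_y, -1], and the broadcast-sum above applies
# the camera-to-world rotation to it. Assuming extrin["rotation"] really is the 3x3
# camera-to-world matrix (as the comment suggests), an equivalent formulation would be
#     rays_dir = dirs @ extrin["rotation"].transpose(0, 1)
# The origin set below is simply the camera center, broadcast to one copy per pixel.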
314 | rays_origin = extrin["translation"].expand(rays_dir.shape) 315 | 316 | return { 317 | "rays_origin": rays_origin, # pytorch, H x W x 3 318 | "rays_dir": rays_dir, 319 | } 320 | 321 | def build_multiple(self, extrins, intrins, coordinate_subsets=None): 322 | 323 | num_images = len(extrins) 324 | 325 | if coordinate_subsets is None: 326 | coordinate_subsets = {"everything": None} 327 | 328 | rays_origin_subsets = {subset_name: None for subset_name in coordinate_subsets.keys()} 329 | rays_dir_subsets = {subset_name: None for subset_name in coordinate_subsets.keys()} 330 | 331 | for index in trange(len(extrins)): 332 | 333 | extrin = extrins[index] 334 | intrin = intrins[index] 335 | 336 | rays_dict = self.build(extrin, intrin) 337 | rays_origin = rays_dict["rays_origin"] 338 | rays_dir = rays_dict["rays_dir"] 339 | 340 | for subset_name, coordinate_subset in coordinate_subsets.items(): 341 | 342 | if coordinate_subset is None: 343 | this_rays_origin = rays_origin 344 | this_rays_dir = rays_dir 345 | else: 346 | y_coordinates = coordinate_subset[index, :, 0] 347 | x_coordinates = coordinate_subset[index, :, 1] 348 | this_rays_origin = rays_origin[y_coordinates, x_coordinates] 349 | this_rays_dir = rays_dir[y_coordinates, x_coordinates] 350 | 351 | if rays_origin_subsets[subset_name] is None: 352 | rays_origin_subsets[subset_name] = torch.empty( 353 | (num_images,) + this_rays_origin.shape, dtype=torch.float32, device="cpu" 354 | ) 355 | rays_dir_subsets[subset_name] = torch.empty( 356 | (num_images,) + this_rays_dir.shape, dtype=torch.float32, device="cpu" 357 | ) 358 | 359 | rays_origin_subsets[subset_name][index] = this_rays_origin.cpu() 360 | rays_dir_subsets[subset_name][index] = this_rays_dir.cpu() 361 | 362 | return { 363 | "rays_origin": rays_origin_subsets, # torch cpu, N x H x W x 3 364 | "rays_dir": rays_dir_subsets, 365 | } 366 | -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
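# Module overview (a summary of the implementation below): Renderer splits each batch of rays
# into chunks of at most points_per_chunk // num_points_per_ray rays, samples points along
# every ray (stratified during training, optionally spaced linearly in disparity), queries the
# scene, and composites with the usual NeRF quadrature:
#     alpha_i = 1 - exp(-relu(sigma_i) * delta_i * points_per_ray / 1024)
#     w_i     = alpha_i * prod_{j < i} (1 - alpha_j)
#     rgb     = sum_i w_i * c_i,    accumulated_weights = sum_i w_i
# Pre-correction runs on each chunk before sampling and post-correction after compositing,
# and the whole pass optionally runs under torch.autocast when use_half_precision is set.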
6 | import logging 7 | 8 | import torch 9 | from post_correction import PostCorrection 10 | from pre_correction import PreCorrection 11 | from utils import Returns, project_to_correct_range 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | class Renderer(torch.nn.Module): 17 | def __init__(self, settings): 18 | super().__init__() 19 | 20 | self.pre_correction = PreCorrection(settings) 21 | self.post_correction = PostCorrection(settings) 22 | 23 | self.points_per_chunk = settings.points_per_chunk 24 | self.default_num_points_per_ray = settings.num_points_per_ray 25 | self.default_disparity_sampling = settings.disparity_sampling # boolean 26 | self.raw_noise_std = settings.raw_noise_std 27 | self.do_ngp_mip_nerf = settings.do_ngp_mip_nerf 28 | 29 | self.use_half_precision = settings.use_half_precision 30 | 31 | def _generate_points_on_rays( 32 | self, batch, num_points_per_ray, scene, is_training, returns, disparity_sampling=None 33 | ): 34 | 35 | if disparity_sampling is None: 36 | disparity_sampling = self.default_disparity_sampling 37 | 38 | device = batch["rays_origin"].device 39 | num_rays = batch["rays_origin"].shape[0] 40 | 41 | # near/far 42 | t_vals = torch.linspace(0.0, 1.0, steps=num_points_per_ray, device=device) 43 | t_vals = t_vals.expand([num_rays, num_points_per_ray]) # num_rays x num_points_per_ray 44 | if disparity_sampling: # linear in inverse depth 45 | z_vals = 1.0 / ( 46 | 1.0 / batch["near"].view(-1, 1) * (1.0 - t_vals) 47 | + 1.0 / batch["far"].view(-1, 1) * (t_vals) 48 | ) 49 | else: # linear in depth 50 | z_vals = batch["near"].view(-1, 1) * (1.0 - t_vals) + batch["far"].view(-1, 1) * ( 51 | t_vals 52 | ) 53 | 54 | if is_training: 55 | # get intervals between samples 56 | mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) 57 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 58 | lower = torch.cat([z_vals[..., :1], mids], -1) 59 | # stratified samples in those intervals 60 | t_rand = torch.rand(z_vals.shape, device=device) 61 | z_vals = lower + (upper - lower) * t_rand 62 | 63 | positions = ( 64 | batch["rays_origin"][..., None, :] 65 | + batch["rays_dir"][..., None, :] * z_vals[..., :, None] 66 | ) # num_rays x num_points_per_ray x 3 67 | batch["positions"] = positions 68 | returns.add_return("unnormalized_undeformed_positions", positions) 69 | 70 | do_normalize_z_vals = True # relevant for _compose_rays 71 | if do_normalize_z_vals: 72 | scaling = torch.mean(scene.get_pos_max() - scene.get_pos_min()) 73 | z_vals = z_vals / scaling 74 | 75 | batch["timesteps"] = ( 76 | batch["timesteps"].view(-1, 1).tile([1, num_points_per_ray]) 77 | ) # num_rays x num_points_per_ray 78 | 79 | batch["intrinids"] = ( 80 | batch["intrinids"].view(-1, 1).tile([1, num_points_per_ray]) 81 | ) # num_rays x num_points_per_ray 82 | 83 | view_directions = batch["rays_dir"] / torch.norm( 84 | batch["rays_dir"], dim=-1, keepdim=True 85 | ) # num_rays x 3 86 | batch["view_directions"] = view_directions.view(num_rays, 1, 3).tile( 87 | [1, num_points_per_ray, 1] 88 | ) # num_rays x num_points_per_ray x 3 89 | 90 | batch["mip_scale"] = None 91 | 92 | return batch, z_vals 93 | 94 | def _compose_rays( 95 | self, 96 | raw_rgb_per_point, 97 | raw_alpha, 98 | z_vals, 99 | rays_dir, 100 | pruning_mask, 101 | is_training, 102 | points_per_ray, 103 | returns, 104 | ): 105 | 106 | device = raw_rgb_per_point.device 107 | corrective_factor = float(points_per_ray) / 1024.0 108 | raw2alpha = lambda raw_alpha, dists, act_fn=torch.nn.functional.relu: 1.0 - torch.exp( 109 | -act_fn(raw_alpha) * dists * 
corrective_factor 110 | ) 111 | 112 | dists = z_vals[..., 1:] - z_vals[..., :-1] 113 | dists = torch.cat( 114 | [dists, torch.tensor([1e10], device=device).expand(dists[..., :1].shape)], -1 115 | ) # [N_rays, N_samples] 116 | 117 | dists = dists * torch.norm(rays_dir[..., None, :], dim=-1) 118 | 119 | if is_training and self.raw_noise_std > 0.0: 120 | noise = torch.randn(raw_alpha.shape, device=device) * self.raw_noise_std 121 | noise[~pruning_mask] = 0.0 # pruned samples get no noise 122 | else: 123 | noise = 0.0 124 | 125 | # noise is added to alpha during accumulation 126 | alpha = raw2alpha(raw_alpha + noise, dists) # [N_rays, N_samples] 127 | returns.add_return("alpha", alpha) 128 | 129 | weights = ( 130 | alpha 131 | * torch.cumprod( 132 | torch.cat( 133 | [torch.ones((alpha.shape[0], 1), device=device), 1.0 - alpha + 1e-10], -1 134 | ), 135 | -1, 136 | )[:, :-1] 137 | ) 138 | returns.add_return("weights", weights) 139 | 140 | rgb = torch.sum(weights[..., None] * raw_rgb_per_point, -2) # [N_rays, 3] 141 | returns.add_return("uncorrected_rgb", rgb) 142 | 143 | accumulated_weights = torch.sum(weights, -1) 144 | returns.add_return("accumulated_weights", accumulated_weights) 145 | 146 | depth = ( 147 | torch.sum(weights[:, :-1] * z_vals[:, :-1], -1) 148 | + (1.0 - accumulated_weights + weights[:, -1]) * z_vals[:, -1] 149 | ) 150 | returns.add_return("depth", depth) 151 | 152 | returns.add_return("z_vals", z_vals) 153 | 154 | disparity = 1.0 / torch.max(1e-4 * torch.ones_like(depth), depth / torch.sum(weights, -1)) 155 | returns.add_return("disparity", disparity) 156 | 157 | return rgb, accumulated_weights 158 | 159 | def render(self, *args, **kwargs): 160 | if self.use_half_precision: 161 | with torch.autocast("cuda"): # tag:half_precision 162 | return self._render(*args, **kwargs) 163 | else: 164 | return self._render(*args, **kwargs) 165 | 166 | def _render( 167 | self, 168 | batch, 169 | scene, 170 | returns=None, 171 | subreturns=None, 172 | points_per_ray=None, 173 | is_training=False, 174 | pull_to_cpu=False, 175 | iterator_yield=False, 176 | **kwargs 177 | ): 178 | 179 | if points_per_ray is None: 180 | points_per_ray = self.default_num_points_per_ray 181 | 182 | if returns is None: 183 | returns = Returns(restricted=["corrected_rgb"]) 184 | returns.activate_mode("coarse") 185 | subreturns = Returns(restricted=["corrected_rgb"]) 186 | elif subreturns is None: 187 | subreturns = Returns(restricted=returns.get_restricted_list()) 188 | 189 | num_rays = batch["rays_origin"].shape[0] 190 | rays_per_chunk = max(1, self.points_per_chunk // points_per_ray) 191 | 192 | def generator_wrapper(): # hacky way to allow path_renderer to save memory 193 | for chunkid, chunk_start in enumerate(range(0, num_rays, rays_per_chunk)): 194 | 195 | subreturns.activate_mode(chunkid) 196 | 197 | subbatch = { 198 | key: tensor[chunk_start : chunk_start + rays_per_chunk] 199 | for key, tensor in batch.items() 200 | } 201 | 202 | subbatch = self.pre_correction(subbatch, returns=subreturns) 203 | 204 | subbatch, z_vals = self._generate_points_on_rays( 205 | subbatch, points_per_ray, scene, is_training, returns=subreturns 206 | ) 207 | 208 | raw_rgb_per_point, raw_alpha, pruning_mask = scene( 209 | subbatch["positions"], 210 | subbatch["view_directions"], 211 | subbatch["timesteps"], 212 | mip_scales=subbatch["mip_scale"], 213 | is_training=is_training, 214 | returns=subreturns, 215 | **kwargs 216 | ) 217 | 218 | rgb, accumulated_weights = self._compose_rays( 219 | raw_rgb_per_point, 220 | raw_alpha, 221 | 
z_vals, 222 | subbatch["rays_dir"], 223 | pruning_mask, 224 | is_training, 225 | points_per_ray, 226 | returns=subreturns, 227 | ) 228 | 229 | rgb = self.post_correction( 230 | rgb, subbatch, accumulated_weights, is_training, returns=subreturns 231 | ) 232 | 233 | if pull_to_cpu: 234 | subreturns.pull_to_cpu() 235 | 236 | if iterator_yield: # used by path_renderer to save memory 237 | yield subreturns 238 | 239 | if iterator_yield: # only for path_renderer 240 | return generator_wrapper 241 | else: # main branch 242 | for _ in generator_wrapper(): # hacky way to run the generator 243 | pass 244 | returns.add_returns(subreturns.concatenate_returns()) 245 | rgb = returns.get_returns()["corrected_rgb"] 246 | return rgb 247 | 248 | def get_parameters_with_optimization_information(self): 249 | return ( 250 | self.pre_correction.get_parameters_with_optimization_information() 251 | + self.post_correction.get_parameters_with_optimization_information() 252 | ) 253 | 254 | def get_regularization_losses(self): 255 | regularization_losses = {} 256 | regularization_losses["pre_correction"] = self.pre_correction.get_regularization_losses() 257 | regularization_losses["post_correction"] = self.post_correction.get_regularization_losses() 258 | return regularization_losses 259 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.19.2 2 | imageio-ffmpeg==0.4.7 3 | matplotlib==3.5.2 4 | tqdm 5 | scikit-image==0.19.3 6 | opencv-python>=4.8.1.78 7 | coloredlogs==15.0.1 8 | cmake==3.23.* 9 | -------------------------------------------------------------------------------- /scene.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import numpy as np 9 | import torch 10 | from canonical_model import get_canonical_model 11 | from deformation_model import get_deformation_model 12 | from pruning import Pruning 13 | from utils import Returns, infill_masked 14 | 15 | LOGGER = logging.getLogger(__name__) 16 | 17 | 18 | class Scene(torch.nn.Module): 19 | def __init__(self, settings, data_handler): 20 | super().__init__() 21 | self.canonical_model = get_canonical_model(settings) 22 | 23 | self.use_viewdirs = settings.use_viewdirs 24 | 25 | self.use_deformations = True 26 | if not self.use_deformations: 27 | self.deformation_model = None 28 | else: 29 | self.deformation_model = get_deformation_model( 30 | settings, data_handler.get_timeline_range() 31 | ) 32 | 33 | self.pruning = Pruning(settings, data_handler) 34 | 35 | # necessary for determine_nerf_volume_extent() to work. 
will be overwritten with correct values 36 | self.register_buffer( 37 | "pos_max", torch.from_numpy(np.array([1e5, 1e5, 1e5], dtype=np.float32)) 38 | ) 39 | self.register_buffer( 40 | "pos_min", torch.from_numpy(np.array([-1e5, -1e5, -1e5], dtype=np.float32)) 41 | ) 42 | 43 | def forward( 44 | self, positions, view_directions, timesteps, mip_scales, is_training, returns=None, **kwargs 45 | ): 46 | # timesteps in [0,1] 47 | timesteps = torch.clamp(timesteps, min=0.0, max=1.0) 48 | 49 | if returns is None: 50 | returns = Returns(restricted=[]) 51 | returns.activate_mode("coarse") 52 | 53 | # normalize positions into unit cube for NGP 54 | positions = (positions - self.pos_min) / (self.pos_max - self.pos_min) # in [0,1] 55 | positions.requires_grad = True # necessary for gradient-based losses 56 | returns.add_return("normalized_undeformed_positions", positions, clone=False) 57 | 58 | mask = self.pruning(positions, is_training=is_training) 59 | returns.set_mask(mask=mask) # fills in tensors to num_rays x num_points_per_ray 60 | 61 | visualize_pruning = False 62 | if visualize_pruning: 63 | self._visualize(positions, mask) 64 | 65 | # removing all samples in tensors can lead to issues, keep at least one dummy 66 | completely_pruned = not torch.any(mask) 67 | if completely_pruned: 68 | mask_shape = mask.shape 69 | mask = mask.view(-1) 70 | mask[0] = True 71 | mask = mask.view(mask_shape) 72 | 73 | positions = positions[mask].view(-1, 3) 74 | view_directions = view_directions[mask].view(-1, 3) 75 | timesteps = timesteps[mask].view(-1, 1) 76 | 77 | if not self.use_deformations: 78 | pref_timesteps = None 79 | else: 80 | positions, view_directions, pref_timesteps = self.deformation_model( 81 | positions, timesteps, mask, is_training=is_training, returns=returns 82 | ) 83 | 84 | view_directions = view_directions[mask].view(-1, 3) 85 | if self.use_viewdirs: 86 | returns.add_return("view_directions", view_directions) 87 | 88 | rgb, alpha = self.canonical_model( 89 | positions, 90 | view_directions, 91 | timesteps=timesteps, 92 | pref_timesteps=pref_timesteps, 93 | mip_scales=mip_scales, 94 | returns=returns, 95 | **kwargs 96 | ) 97 | 98 | if completely_pruned: 99 | alpha *= 0.0 100 | 101 | returns.set_mask(mask=None) 102 | 103 | # get infilled num_rays x num_points_per_ray tensors 104 | rgb = infill_masked(mask, rgb, infill_value=0) 105 | alpha = infill_masked(mask, alpha, infill_value=0) 106 | 107 | return rgb, alpha, mask 108 | 109 | def _visualize(self, positions, mask): 110 | if self.pruning.voxel_grid is not None: 111 | 112 | from tqdm import tqdm 113 | 114 | LOGGER.info("writing debug point clouds") 115 | 116 | debug_pos = positions * (self.pos_max - self.pos_min) + self.pos_min 117 | prune = debug_pos[~mask].view(-1, 3) 118 | keep = debug_pos[mask].view(-1, 3) 119 | samples_list = [] 120 | for x, y, z in tqdm(prune.cpu().detach().numpy()[::100]): 121 | samples_list.append("v " + str(x) + " " + str(y) + " " + str(z) + " 1 0 0") 122 | for x, y, z in tqdm(keep.cpu().detach().numpy()[::100]): 123 | samples_list.append("v " + str(x) + " " + str(y) + " " + str(z) + " 0 1 0") 124 | samples_string = "\n".join(samples_list) 125 | output_file = "samples.obj" 126 | with open(output_file, "w") as output_file: 127 | output_file.write(samples_string) 128 | 129 | voxel_grid_list = [] 130 | size = self.pruning.voxel_grid_size 131 | for x in tqdm(range(size)): 132 | for y in range(size): 133 | for z in range(size): 134 | occ = self.pruning.voxel_grid[x, y, z] 135 | if occ: 136 | x2 = x + 0.5 137 | y2 = y + 0.5 
138 | z2 = z + 0.5 139 | x2 = (x2 / size) * (self.pos_max[0] - self.pos_min[0]) + self.pos_min[0] 140 | y2 = (y2 / size) * (self.pos_max[1] - self.pos_min[1]) + self.pos_min[1] 141 | z2 = (z2 / size) * (self.pos_max[2] - self.pos_min[2]) + self.pos_min[2] 142 | voxel = ( 143 | "v " 144 | + str(x2.item()) 145 | + " " 146 | + str(y2.item()) 147 | + " " 148 | + str(z2.item()) 149 | + " 0 0 1" 150 | ) 151 | voxel_grid_list.append(voxel) 152 | voxel_grid_string = "\n".join(voxel_grid_list) 153 | output_file = "voxel_grid.obj" 154 | with open(output_file, "w") as output_file: 155 | output_file.write(voxel_grid_string) 156 | 157 | def set_pos_max_min(self, pos_max, pos_min): 158 | self.pos_max = pos_max 159 | self.pos_min = pos_min 160 | 161 | def get_pos_max(self): 162 | return self.pos_max 163 | 164 | def get_pos_min(self): 165 | return self.pos_min 166 | 167 | def state_dict(self): 168 | state_dict = super().state_dict() 169 | state_dict["pos_max"] = self.pos_max 170 | state_dict["pos_min"] = self.pos_min 171 | return state_dict 172 | 173 | def load_state_dict(self, state_dict): 174 | super().load_state_dict(state_dict) 175 | self.pos_max = state_dict["pos_max"] 176 | self.pos_min = state_dict["pos_min"] 177 | 178 | def _add_tags(self, parameters, tags): 179 | for parameter in parameters: 180 | parameter["tags"] += tags 181 | return parameters 182 | 183 | def step(self): 184 | self.canonical_model.step() 185 | if self.use_deformations: 186 | self.deformation_model.step() 187 | 188 | def get_parameters_with_optimization_information(self): 189 | canonical_parameters = self.canonical_model.get_parameters_with_optimization_information() 190 | canonical_parameters = self._add_tags(canonical_parameters, ["canonical"]) 191 | 192 | if not self.use_deformations: 193 | deformation_parameters = [] 194 | else: 195 | deformation_parameters = ( 196 | self.deformation_model.get_parameters_with_optimization_information() 197 | ) 198 | deformation_parameters = self._add_tags(deformation_parameters, ["deformation"]) 199 | 200 | return canonical_parameters + deformation_parameters 201 | 202 | def get_regularization_losses(self): 203 | 204 | regularization_losses = {} 205 | 206 | regularization_losses["canonical_model"] = self.canonical_model.get_regularization_losses() 207 | 208 | if self.use_deformations: 209 | regularization_losses[ 210 | "deformation_model" 211 | ] = self.deformation_model.get_regularization_losses() 212 | 213 | return regularization_losses 214 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
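# The parser below is configargparse-based, so every flag defined in this file can also be
# given in the file passed via --config (default: ./configs/default.txt) as simple
# "key = value" lines. A minimal, purely illustrative sketch (values are made up, not
# recommended defaults):
#     expname = my_scene
#     datadir = ./data/my_scene
#     basedir = ./logs/
#     batch_size = 1024
#     num_points_per_ray = 64
#     use_viewdirs = True
#     variant = snf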
6 | import configargparse 7 | 8 | 9 | def config_parser(config_file=None): 10 | 11 | if config_file is None: 12 | config_file = "./configs/default.txt" 13 | 14 | parser = configargparse.ArgumentParser() 15 | parser.add_argument( 16 | "--config", default=config_file, is_config_file=True, help="config file path" 17 | ) 18 | parser.add_argument("--expname", type=str, help="experiment name") 19 | parser.add_argument( 20 | "--basedir", type=str, default="./logs/", help="where to store ckpts and logs" 21 | ) 22 | parser.add_argument( 23 | "--temporary_basedir", type=str, default=None, help="where to store ckpts and logs" 24 | ) 25 | parser.add_argument("--datadir", type=str, default=None, help="input data directory") 26 | parser.add_argument( 27 | "--allow_scratch_datadir_copy", 28 | action="store_true", 29 | help="only take random rays from 1 image at a time", 30 | ) 31 | 32 | # training options 33 | parser.add_argument("--netdepth", type=int, default=8, help="layers in network") 34 | parser.add_argument("--netwidth", type=int, default=256, help="channels per layer") 35 | parser.add_argument("--netdepth_fine", type=int, default=8, help="layers in fine network") 36 | parser.add_argument( 37 | "--netwidth_fine", type=int, default=256, help="channels per layer in fine network" 38 | ) 39 | parser.add_argument( 40 | "--batch_size", 41 | type=int, 42 | default=32 * 32 * 4, 43 | help="batch size (number of random rays per gradient step)", 44 | ) 45 | parser.add_argument( 46 | "--points_per_chunk", 47 | type=int, 48 | default=1e7, 49 | help="number of pts sent through network in parallel, decrease if running out of memory", 50 | ) 51 | parser.add_argument( 52 | "--no_batching", action="store_true", help="only take random rays from 1 image at a time" 53 | ) 54 | parser.add_argument( 55 | "--no_reload", action="store_true", help="do not reload weights from saved ckpt" 56 | ) 57 | parser.add_argument( 58 | "--ft_path", 59 | type=str, 60 | default=None, 61 | help="specific weights npy file to reload for coarse network", 62 | ) 63 | parser.add_argument("--tracking_mode", type=str, default="plain", help="plain, temporal") 64 | parser.add_argument("--reconstruction_loss_type", type=str, default="L1", help="L1, L2") 65 | parser.add_argument( 66 | "--smooth_deformations_type", 67 | type=str, 68 | default="finite", 69 | help="finite, divergence, jacobian", 70 | ) 71 | parser.add_argument( 72 | "--num_iterations", type=int, default=200000, help="number of training iterations" 73 | ) 74 | parser.add_argument( 75 | "--learning_rate_decay_autodecoding_fraction", 76 | type=float, 77 | default=1e-2, 78 | help="what fraction of the learning rate to reduce to", 79 | ) 80 | parser.add_argument( 81 | "--learning_rate_decay_autodecoding_iterations", 82 | type=int, 83 | default=0, 84 | help='number of iterations to reduce learning rate by "fraction"', 85 | ) 86 | parser.add_argument( 87 | "--learning_rate_decay_mlp_fraction", 88 | type=float, 89 | default=1e-2, 90 | help="what fraction of the learning rate to reduce to", 91 | ) 92 | parser.add_argument( 93 | "--learning_rate_decay_mlp_iterations", 94 | type=int, 95 | default=0, 96 | help="number of iterations to reduce learning rate", 97 | ) 98 | parser.add_argument( 99 | "--activation_function", 100 | type=str, 101 | default="ReLU", 102 | help="ReLU, Exponential, Sine, Squareplus", 103 | ) 104 | 105 | parser.add_argument( 106 | "--use_visualizer", action="store_true", help="auto-decoded latent codes or raw time" 107 | ) 108 | parser.add_argument( 109 | "--test_cameras", 
nargs="+", type=str, help="extrinsic names of test cameras", default=[] 110 | ) 111 | 112 | # rendering options 113 | parser.add_argument( 114 | "--num_points_per_ray", type=int, default=64, help="number of coarse samples per ray" 115 | ) 116 | parser.add_argument( 117 | "--N_importance", type=int, default=0, help="number of additional fine samples per ray" 118 | ) 119 | parser.add_argument( 120 | "--perturb", type=float, default=1.0, help="set to 0. for no jitter, 1. for jitter" 121 | ) 122 | parser.add_argument( 123 | "--use_viewdirs", action="store_true", help="use full 5D input instead of 3D" 124 | ) 125 | parser.add_argument( 126 | "--i_embed", type=int, default=0, help="set 0 for default positional encoding, -1 for none" 127 | ) 128 | parser.add_argument( 129 | "--multires", 130 | type=int, 131 | default=10, 132 | help="log2 of max freq for positional encoding (3D location)", 133 | ) 134 | parser.add_argument( 135 | "--multires_views", 136 | type=int, 137 | default=4, 138 | help="log2 of max freq for positional encoding (2D direction)", 139 | ) 140 | parser.add_argument( 141 | "--raw_noise_std", 142 | type=float, 143 | default=0.0, 144 | help="std dev of noise added to regularize sigma_a output, 1e0 recommended", 145 | ) 146 | parser.add_argument( 147 | "--use_background", 148 | action="store_true", 149 | help="use static background images when composing rays", 150 | ) 151 | parser.add_argument( 152 | "--brightness_variability", 153 | type=float, 154 | default=0.0, 155 | help="maximum allowed change in learned brightness. up to 1.0 makes sense. 0.0 turns it off.", 156 | ) 157 | parser.add_argument( 158 | "--variant", type=str, default="snf", help="options: llff / blender / deepvoxels" 159 | ) 160 | 161 | parser.add_argument( 162 | "--render_only", 163 | action="store_true", 164 | help="do not optimize, reload weights and render out render_poses path", 165 | ) 166 | parser.add_argument( 167 | "--render_test", 168 | action="store_true", 169 | help="render the test set instead of render_poses path", 170 | ) 171 | parser.add_argument( 172 | "--render_factor", 173 | type=int, 174 | default=0, 175 | help="downsampling factor to speed up rendering, set 4 or 8 for fast preview", 176 | ) 177 | 178 | parser.add_argument( 179 | "--color_calibration_mode", 180 | type=str, 181 | default="none", 182 | help="options: none / full_matrix / neural_volumes", 183 | ) 184 | 185 | # training options 186 | parser.add_argument( 187 | "--weight_smooth_deformations", 188 | type=float, 189 | default=0.0, 190 | help="weight for regularization loss", 191 | ) 192 | parser.add_argument( 193 | "--weight_coarse_smooth_deformations", 194 | type=float, 195 | default=0.0, 196 | help="weight for regularization loss", 197 | ) 198 | parser.add_argument( 199 | "--weight_fine_smooth_deformations", 200 | type=float, 201 | default=0.0, 202 | help="weight for regularization loss", 203 | ) 204 | parser.add_argument( 205 | "--weight_parameter_regularization", 206 | type=float, 207 | default=0.0, 208 | help="weight for weight decay in AdamW", 209 | ) 210 | parser.add_argument( 211 | "--weight_background_loss", type=float, default=0.0, help="weight for background loss" 212 | ) 213 | parser.add_argument( 214 | "--weight_brightness_change_regularization", 215 | type=float, 216 | default=0.0, 217 | help="weight for brightness change regularization", 218 | ) 219 | parser.add_argument( 220 | "--weight_hard_surface_loss", 221 | type=float, 222 | default=0.0, 223 | help="weight for brightness change regularization", 224 | ) 225 | 
parser.add_argument( 226 | "--weight_small_fine_offsets_loss", 227 | type=float, 228 | default=0.0, 229 | help="weight for brightness change regularization", 230 | ) 231 | parser.add_argument( 232 | "--weight_similar_coarse_and_total_offsets_loss", 233 | type=float, 234 | default=0.0, 235 | help="weight for brightness change regularization", 236 | ) 237 | parser.add_argument( 238 | "--weight_per_frequency_regularization", 239 | type=float, 240 | default=0.0, 241 | help="weight for brightness change regularization", 242 | ) 243 | 244 | parser.add_argument( 245 | "--coarse_and_fine", action="store_true", help="auto-decoded latent codes or raw time" 246 | ) 247 | parser.add_argument( 248 | "--fine_range", 249 | type=float, 250 | default=0.1, 251 | help="hard restriction on the range of the fine deformation model in normalized space", 252 | ) 253 | parser.add_argument( 254 | "--deformation_per_timestep_decay_rate", 255 | type=float, 256 | default=0.1, 257 | help="weight for brightness change regularization", 258 | ) 259 | parser.add_argument( 260 | "--slow_canonical_per_timestep_learning_rate", 261 | type=float, 262 | default=1e-5, 263 | help="weight for brightness change regularization", 264 | ) 265 | parser.add_argument( 266 | "--fix_coarse_after_a_while", 267 | action="store_true", 268 | help="auto-decoded latent codes or raw time", 269 | ) 270 | parser.add_argument( 271 | "--let_canonical_vary_at_last", 272 | action="store_true", 273 | help="auto-decoded latent codes or raw time", 274 | ) 275 | parser.add_argument( 276 | "--let_only_brightness_vary", 277 | action="store_true", 278 | help="auto-decoded latent codes or raw time", 279 | ) 280 | parser.add_argument( 281 | "--keep_coarse_mlp_constant", 282 | action="store_true", 283 | help="auto-decoded latent codes or raw time", 284 | ) 285 | parser.add_argument( 286 | "--coarse_parametrization", 287 | type=str, 288 | default="MLP", 289 | help="auto-decoded latent codes or raw time", 290 | ) 291 | parser.add_argument( 292 | "--use_global_transform", action="store_true", help="auto-decoded latent codes or raw time" 293 | ) 294 | parser.add_argument( 295 | "--do_vignetting_correction", 296 | action="store_true", 297 | help="auto-decoded latent codes or raw time", 298 | ) 299 | parser.add_argument( 300 | "--coarse_mlp_weight_decay", type=float, default=1e-2, help="weight for background loss" 301 | ) 302 | parser.add_argument( 303 | "--coarse_mlp_skip_connections", 304 | type=int, 305 | default=0, 306 | help="downsample factor for LLFF images", 307 | ) 308 | parser.add_argument( 309 | "--smoothness_robustness_threshold", 310 | type=float, 311 | default=1e-2, 312 | help="weight for background loss", 313 | ) 314 | 315 | parser.add_argument( 316 | "--use_half_precision", 317 | action="store_true", 318 | help="nerfies deformation parametrization instead of offsets field", 319 | ) 320 | parser.add_argument( 321 | "--do_zero_out", 322 | action="store_true", 323 | help="nerfies deformation parametrization instead of offsets field", 324 | ) 325 | 326 | parser.add_argument( 327 | "--use_pruning", 328 | action="store_true", 329 | help="nerfies deformation parametrization instead of offsets field", 330 | ) 331 | parser.add_argument( 332 | "--voxel_grid_size", type=int, default=128, help="downsample factor for LLFF images" 333 | ) 334 | parser.add_argument( 335 | "--no_pruning_probability", 336 | type=float, 337 | default=0.0, 338 | help="hard restriction on the range of the fine deformation model in normalized space", 339 | ) 340 | 341 | parser.add_argument( 
342 | "--per_frequency_training_mode", 343 | type=str, 344 | default="only_last", 345 | help="options: up_to_last / only_last / all_at_once", 346 | ) 347 | parser.add_argument( 348 | "--optimization_mode", type=str, default="per_timestep", help="options: per_timestep / all" 349 | ) 350 | parser.add_argument("--do_ngp_mip_nerf", action="store_true", help="submit to slurm") 351 | parser.add_argument("--do_pref", action="store_true", help="submit to slurm") 352 | parser.add_argument( 353 | "--pref_tau_window", 354 | type=int, 355 | default=3, 356 | help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels", 357 | ) 358 | parser.add_argument( 359 | "--pref_dataset_index", 360 | type=int, 361 | default=-1, 362 | help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels", 363 | ) 364 | parser.add_argument("--do_nrnerf", action="store_true", help="submit to slurm") 365 | parser.add_argument("--do_dnerf", action="store_true", help="submit to slurm") 366 | 367 | # dataset options 368 | parser.add_argument( 369 | "--dataset_type", type=str, default="blender", help="options: llff / blender / deepvoxels" 370 | ) 371 | parser.add_argument( 372 | "--testskip", 373 | type=int, 374 | default=8, 375 | help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels", 376 | ) 377 | 378 | # deformation options 379 | parser.add_argument( 380 | "--use_temporal_latent_codes", 381 | action="store_true", 382 | help="auto-decoded latent codes or raw time", 383 | ) 384 | parser.add_argument("--pure_mlp_bending", action="store_true", help="pure MLP bending network") 385 | parser.add_argument( 386 | "--use_time_conditioning", 387 | action="store_true", 388 | help="time-condition the canonical model for deformations.", 389 | ) 390 | parser.add_argument( 391 | "--use_nerfies_se3", 392 | action="store_true", 393 | help="nerfies deformation parametrization instead of offsets field", 394 | ) 395 | 396 | ## llff flags 397 | parser.add_argument("--factor", type=int, default=8, help="downsample factor for LLFF images") 398 | parser.add_argument( 399 | "--no_ndc", 400 | action="store_true", 401 | help="do not use normalized device coordinates (set for non-forward facing scenes)", 402 | ) 403 | parser.add_argument( 404 | "--disparity_sampling", 405 | action="store_true", 406 | help="sampling linearly in disparity instead of depth", 407 | ) 408 | parser.add_argument("--spherify", action="store_true", help="set for spherical 360 scenes") 409 | parser.add_argument( 410 | "--llffhold", 411 | type=int, 412 | default=8, 413 | help="will take every 1/N images as LLFF test set, paper uses 8", 414 | ) 415 | 416 | # logging/saving options 417 | parser.add_argument( 418 | "--i_print", type=int, default=100, help="frequency of console printout and metric loggin" 419 | ) 420 | parser.add_argument( 421 | "--i_img", type=int, default=2500, help="frequency of tensorboard image logging" 422 | ) 423 | parser.add_argument( 424 | "--save_temporary_checkpoint_every", 425 | type=int, 426 | default=2500, 427 | help="frequency of weight ckpt saving", 428 | ) 429 | parser.add_argument( 430 | "--save_intermediate_checkpoint_every", 431 | type=int, 432 | default=10000, 433 | help="frequency of weight ckpt saving", 434 | ) 435 | parser.add_argument( 436 | "--save_checkpoint_every", type=int, default=2500, help="frequency of weight ckpt saving" 437 | ) 438 | parser.add_argument( 439 | "--save_per_timestep", action="store_true", help="set for spherical 360 scenes" 440 | ) 
441 | parser.add_argument( 442 | "--save_per_timestep_in_scratch", action="store_true", help="set for spherical 360 scenes" 443 | ) 444 | parser.add_argument("--i_testset", type=int, default=50000, help="frequency of testset saving") 445 | parser.add_argument( 446 | "--i_video", type=int, default=50000, help="frequency of render_poses video saving" 447 | ) 448 | 449 | # backbone network 450 | parser.add_argument("--backbone", type=str, default="mlp", help="backbone: mlp / ngp") 451 | parser.add_argument( 452 | "--prefer_cutlass_over_fullyfused_mlp", 453 | action="store_true", 454 | help="cutlass for gpu20, fullyfused for gpu22", 455 | ) 456 | 457 | parser.add_argument("--slurm", action="store_true", help="submit to slurm") 458 | parser.add_argument( 459 | "--time_cpu_ram_cluster_gpu", 460 | type=str, 461 | default="1:00:00 16 200 gpu20 gpu:1", 462 | help="settings for slurm", 463 | ) 464 | 465 | parser.add_argument("--multi_gpu", action="store_true", help="whether to use multiple GPUs") 466 | parser.add_argument("--debug", action="store_true", help="debug flag") 467 | parser.add_argument( 468 | "--always_load_full_dataset", 469 | action="store_true", 470 | help="auto-decoded latent codes or raw time", 471 | ) 472 | 473 | return parser 474 | -------------------------------------------------------------------------------- /smart_adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import math 7 | from typing import List, Optional 8 | 9 | import torch 10 | from torch import Tensor 11 | from torch.optim.optimizer import Optimizer 12 | 13 | 14 | class SmartAdam(Optimizer): 15 | """Implements a modified AdamW. Please refer to the original AdamW Pytorch implementation for additional information.""" 16 | 17 | def __init__( 18 | self, 19 | params, 20 | lr=1e-3, 21 | betas=(0.9, 0.999), 22 | eps=1e-8, 23 | weight_decay=0, 24 | amsgrad=False, 25 | *, 26 | foreach: Optional[bool] = None, 27 | maximize: bool = False 28 | ): 29 | if not 0.0 <= lr: 30 | raise ValueError("Invalid learning rate: {}".format(lr)) 31 | if not 0.0 <= eps: 32 | raise ValueError("Invalid epsilon value: {}".format(eps)) 33 | if not 0.0 <= betas[0] < 1.0: 34 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 35 | if not 0.0 <= betas[1] < 1.0: 36 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 37 | if not 0.0 <= weight_decay: 38 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 39 | defaults = dict( 40 | lr=lr, 41 | betas=betas, 42 | eps=eps, 43 | weight_decay=weight_decay, 44 | amsgrad=amsgrad, 45 | maximize=maximize, 46 | foreach=foreach, 47 | ) 48 | super(SmartAdam, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super().__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault("amsgrad", False) 54 | group.setdefault("maximize", False) 55 | group.setdefault("foreach", None) 56 | state_values = list(self.state.values()) 57 | step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]["step"]) 58 | if not step_is_tensor: 59 | for s in state_values: 60 | s["step"] = torch.tensor(float(s["step"])) 61 | 62 | @torch.no_grad() 63 | def step(self, closure=None): 64 | """Performs a single optimization step. 
65 | Args: 66 | closure (callable, optional): A closure that reevaluates the model 67 | and returns the loss. 68 | """ 69 | loss = None 70 | if closure is not None: 71 | with torch.enable_grad(): 72 | loss = closure() 73 | 74 | for group in self.param_groups: 75 | params_with_grad = [] 76 | grads = [] 77 | exp_avgs = [] 78 | exp_avg_sqs = [] 79 | max_exp_avg_sqs = [] 80 | state_steps = [] 81 | beta1, beta2 = group["betas"] 82 | 83 | for p in group["params"]: 84 | if p.grad is not None: 85 | params_with_grad.append(p) 86 | if p.grad.is_sparse: 87 | raise RuntimeError( 88 | "Adam does not support sparse gradients, please consider SparseAdam instead" 89 | ) 90 | grads.append(p.grad) 91 | 92 | state = self.state[p] 93 | # Lazy state initialization 94 | if len(state) == 0: 95 | state["step"] = torch.zeros_like(p) ###torch.tensor(0.) 96 | # Exponential moving average of gradient values 97 | state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) 98 | # Exponential moving average of squared gradient values 99 | state["exp_avg_sq"] = torch.zeros_like( 100 | p, memory_format=torch.preserve_format 101 | ) 102 | if group["amsgrad"]: 103 | # Maintains max of all exp. moving avg. of sq. grad. values 104 | state["max_exp_avg_sq"] = torch.zeros_like( 105 | p, memory_format=torch.preserve_format 106 | ) 107 | 108 | exp_avgs.append(state["exp_avg"]) 109 | exp_avg_sqs.append(state["exp_avg_sq"]) 110 | 111 | if group["amsgrad"]: 112 | max_exp_avg_sqs.append(state["max_exp_avg_sq"]) 113 | 114 | state_steps.append(state["step"]) 115 | 116 | adam( 117 | params_with_grad, 118 | grads, 119 | exp_avgs, 120 | exp_avg_sqs, 121 | max_exp_avg_sqs, 122 | state_steps, 123 | amsgrad=group["amsgrad"], 124 | beta1=beta1, 125 | beta2=beta2, 126 | lr=group["lr"], 127 | weight_decay=group["weight_decay"], 128 | eps=group["eps"], 129 | maximize=group["maximize"], 130 | foreach=group["foreach"], 131 | ) 132 | 133 | return loss 134 | 135 | 136 | def adam( 137 | params: List[Tensor], 138 | grads: List[Tensor], 139 | exp_avgs: List[Tensor], 140 | exp_avg_sqs: List[Tensor], 141 | max_exp_avg_sqs: List[Tensor], 142 | state_steps: List[Tensor], 143 | # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 144 | # setting this as kwarg for now as functional API is compiled by torch/distributed/optim 145 | foreach: bool = None, 146 | *, 147 | amsgrad: bool, 148 | beta1: float, 149 | beta2: float, 150 | lr: float, 151 | weight_decay: float, 152 | eps: float, 153 | maximize: bool 154 | ): 155 | r"""Functional API that performs Adam algorithm computation. 156 | See :class:`~torch.optim.Adam` for details. 
157 | """ 158 | 159 | if not all([isinstance(t, torch.Tensor) for t in state_steps]): 160 | raise RuntimeError( 161 | "API has changed, `state_steps` argument must contain a list of singleton tensors" 162 | ) 163 | 164 | _single_tensor_adam( 165 | params, 166 | grads, 167 | exp_avgs, 168 | exp_avg_sqs, 169 | max_exp_avg_sqs, 170 | state_steps, 171 | amsgrad=amsgrad, 172 | beta1=beta1, 173 | beta2=beta2, 174 | lr=lr, 175 | weight_decay=weight_decay, 176 | eps=eps, 177 | maximize=maximize, 178 | ) 179 | 180 | 181 | def _single_tensor_adam( 182 | params: List[Tensor], 183 | grads: List[Tensor], 184 | exp_avgs: List[Tensor], 185 | exp_avg_sqs: List[Tensor], 186 | max_exp_avg_sqs: List[Tensor], 187 | state_steps: List[Tensor], 188 | *, 189 | amsgrad: bool, 190 | beta1: float, 191 | beta2: float, 192 | lr: float, 193 | weight_decay: float, 194 | eps: float, 195 | maximize: bool 196 | ): 197 | 198 | for i, full_param in enumerate(params): 199 | 200 | full_grad = grads[i] if not maximize else -grads[i] 201 | 202 | # core modification: use a mask to only consider non-zero gradient elements. 203 | mask = full_grad != 0.0 204 | 205 | grad = full_grad[mask] 206 | param = full_param[mask] 207 | 208 | full_exp_avg = exp_avgs[i] 209 | exp_avg = full_exp_avg[mask] 210 | full_exp_avg_sq = exp_avg_sqs[i] 211 | exp_avg_sq = full_exp_avg_sq[mask] 212 | full_step_t = state_steps[i] 213 | step_t = full_step_t[mask] 214 | # update step 215 | step_t += 1 216 | 217 | # AdamW-style weight decay. Follows Pytorch's official AdamW implementation. 218 | if weight_decay != 0.0: 219 | param.mul_(1 - lr * weight_decay) 220 | 221 | bias_correction1 = 1 - beta1**step_t 222 | bias_correction2 = 1 - beta2**step_t 223 | 224 | # classic Adam with L2 weight regularization. not ideal. Do AdamW instead. 225 | # if weight_decay != 0: 226 | # grad = grad.add(param, alpha=weight_decay) 227 | 228 | # Decay the first and second moment running average coefficient 229 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 230 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 231 | if amsgrad: 232 | full_max_exp_avg_sqs = max_exp_avg_sqs[i] 233 | max_exp_avg_sqs = full_max_exp_avg_sqs[mask] 234 | # Maintains the maximum of all 2nd moment running avg. till now 235 | torch.maximum(max_exp_avg_sqs, exp_avg_sq, out=max_exp_avg_sqs) 236 | full_max_exp_avg_sqs[mask] = max_exp_avg_sqs 237 | # Use the max. for normalizing running avg. of gradient 238 | denom = bias_correction2.sqrt() / max_exp_avg_sqs.sqrt().add_(eps) 239 | else: 240 | denom = bias_correction2.sqrt() / exp_avg_sq.sqrt().add_(eps) 241 | 242 | step_size = -lr / bias_correction1 243 | denom.mul_(exp_avg) 244 | denom.mul_(step_size) 245 | param.add_(denom) 246 | 247 | full_param[mask] = param 248 | full_exp_avg[mask] = exp_avg 249 | full_exp_avg_sq[mask] = exp_avg_sq 250 | full_step_t[mask] = step_t 251 | -------------------------------------------------------------------------------- /state_loader_saver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import logging 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | import os 9 | import pathlib 10 | import shutil 11 | 12 | import torch 13 | from multi_gpu import multi_gpu_barrier 14 | from tqdm import tqdm 15 | from utils import Returns 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | 20 | class StateLoaderSaver: 21 | def __init__(self, settings, rank): 22 | 23 | self.basedir = settings.basedir 24 | if settings.temporary_basedir is None: 25 | self.temporary_basedir = settings.basedir 26 | else: 27 | self.temporary_basedir = settings.temporary_basedir 28 | self.expname = settings.expname 29 | self.reload = not settings.no_reload 30 | 31 | self.save_checkpoint_every = settings.save_checkpoint_every 32 | self.save_intermediate_checkpoint_every = settings.save_intermediate_checkpoint_every 33 | self.save_temporary_checkpoint_every = settings.save_temporary_checkpoint_every 34 | 35 | self.multi_gpu = settings.multi_gpu 36 | self.rank = rank 37 | 38 | @staticmethod 39 | def create_folder(folder): 40 | pathlib.Path(folder).mkdir(parents=True, exist_ok=True) 41 | 42 | def get_experiment_name(self): 43 | return self.expname 44 | 45 | def get_results_folder(self): 46 | return os.path.join(self.basedir, self.get_experiment_name()) 47 | 48 | def get_checkpoint_folder(self): 49 | return os.path.join(self.get_results_folder(), "1_checkpoints/") 50 | 51 | def get_latest_checkpoint_file(self): 52 | return os.path.join(self.get_checkpoint_folder(), "latest.pth") 53 | 54 | def get_temporary_results_folder(self): 55 | return os.path.join(self.temporary_basedir, self.get_experiment_name()) 56 | 57 | def get_temporary_checkpoint_folder(self): 58 | return os.path.join(self.get_temporary_results_folder(), "1_checkpoints/") 59 | 60 | def get_latest_temporary_checkpoint_file(self): 61 | return os.path.join(self.get_temporary_checkpoint_folder(), "latest.pth") 62 | 63 | ### BACKUP 64 | 65 | def backup_files(self, settings): 66 | 67 | if self.rank != 0: 68 | return 69 | 70 | LOGGER.info("backing up... 
") 71 | 72 | results_folder = self.get_results_folder() 73 | temporary_results_folder = self.get_temporary_results_folder() 74 | if not self.reload and os.path.exists(temporary_results_folder): 75 | shutil.rmtree(temporary_results_folder) 76 | if os.path.exists(results_folder): 77 | if self.reload: 78 | LOGGER.info("already exists.") 79 | return 80 | else: 81 | shutil.rmtree(results_folder) 82 | self.create_folder(results_folder) 83 | 84 | self.create_folder(self.get_checkpoint_folder()) 85 | self.create_folder(self.get_temporary_checkpoint_folder()) 86 | 87 | f = os.path.join(results_folder, "settings.txt") 88 | with open(f, "w") as file: 89 | for arg in sorted(vars(settings)): 90 | attr = getattr(settings, arg) 91 | file.write("{} = {}\n".format(arg, attr)) 92 | if settings.config is not None: 93 | f = os.path.join(results_folder, "config.txt") 94 | with open(f, "w") as file: 95 | with open(settings.config, "r") as original_config: 96 | file.write(original_config.read()) 97 | 98 | target_folder = os.path.join(results_folder, "2_backup/") 99 | 100 | special_files_to_copy = [] 101 | filetypes_to_copy = [".py", ".txt"] 102 | subfolders_to_copy = ["", "configs/"] 103 | 104 | this_file = os.path.realpath(__file__) 105 | this_folder = os.path.dirname(this_file) + "/" 106 | self.create_folder(target_folder) 107 | # special files 108 | [ 109 | self.create_folder(os.path.join(target_folder, os.path.split(file)[0])) 110 | for file in special_files_to_copy 111 | ] 112 | [ 113 | shutil.copyfile(os.path.join(this_folder, file), os.path.join(target_folder, file)) 114 | for file in special_files_to_copy 115 | ] 116 | # folders 117 | for subfolder in subfolders_to_copy: 118 | self.create_folder(os.path.join(target_folder, subfolder)) 119 | files = os.listdir(os.path.join(this_folder, subfolder)) 120 | files = [ 121 | file 122 | for file in files 123 | if os.path.isfile(os.path.join(this_folder, subfolder, file)) 124 | and file[file.rfind(".") :] in filetypes_to_copy 125 | ] 126 | [ 127 | shutil.copyfile( 128 | os.path.join(this_folder, subfolder, file), 129 | os.path.join(target_folder, subfolder, file), 130 | ) 131 | for file in files 132 | ] 133 | 134 | ### LOGGING 135 | 136 | def print_log(self, training_iteration, logging): 137 | logging_string = "[TRAIN] Iter: " + str(training_iteration) 138 | if logging is not None and "psnr" in logging: 139 | logging_string += " PSNR: " + str(logging["psnr"]) 140 | LOGGER.info(logging_string) 141 | 142 | ### CHECKPOINTS 143 | 144 | def save( 145 | self, training_iteration, scene, renderer, scheduler, trainer, force_save_in_stable=False 146 | ): 147 | 148 | if self.rank != 0: 149 | return 150 | 151 | checkpoint = { 152 | "training_iteration": training_iteration, 153 | "scene": scene.state_dict(), 154 | "renderer": renderer.state_dict(), 155 | "scheduler": scheduler.state_dict(), 156 | "trainer": trainer.state_dict(), 157 | } 158 | 159 | # save in temporary /scratch storage 160 | checkpoint_file = self.get_latest_temporary_checkpoint_file() 161 | LOGGER.info("saving to scratch storage: " + checkpoint_file) 162 | temporary_file = checkpoint_file + "_TEMP" 163 | torch.save( 164 | checkpoint, temporary_file 165 | ) # try to avoid getting interrupted while writing to checkpoint 166 | os.rename(temporary_file, checkpoint_file) 167 | 168 | # save in stable storage 169 | if force_save_in_stable or (training_iteration % self.save_checkpoint_every == 0): 170 | stable_checkpoint_file = self.get_latest_checkpoint_file() 171 | if checkpoint_file != stable_checkpoint_file: 172 
| LOGGER.info("copying to stable storage: " + stable_checkpoint_file) 173 | shutil.copyfile(checkpoint_file, stable_checkpoint_file) 174 | 175 | # keep copy of this intermediate checkpoint 176 | if training_iteration % self.save_intermediate_checkpoint_every == 0: 177 | intermediate_checkpoint_file = os.path.join( 178 | self.get_checkpoint_folder(), str(training_iteration).zfill(9) + ".pth" 179 | ) 180 | shutil.copyfile(checkpoint_file, intermediate_checkpoint_file) 181 | 182 | def save_for_only_test(self, timestep, scene, renderer, stable_storage=True): 183 | 184 | if self.rank != 0: 185 | return 186 | 187 | checkpoint = { 188 | "timestep": timestep, 189 | "renderer": renderer.state_dict(), 190 | } 191 | 192 | timevariant_canonical_model = ( 193 | scene.canonical_model.variant in ["snfa", "snfag"] 194 | or scene.canonical_model.brightness_variability > 0.0 195 | ) 196 | if timevariant_canonical_model or timestep == 0.0: 197 | checkpoint["scene"] = scene.state_dict() 198 | else: 199 | if scene.deformation_model is not None: 200 | checkpoint["deformation_model"] = scene.deformation_model.state_dict() 201 | 202 | if stable_storage: 203 | # save in stable storage 204 | checkpoint_file = os.path.join( 205 | self.get_checkpoint_folder(), "timestep_" + str(timestep) + ".pth" 206 | ) 207 | LOGGER.info("saving timestep to stable storage: " + checkpoint_file) 208 | temporary_file = checkpoint_file + "_TEMP" 209 | torch.save( 210 | checkpoint, temporary_file 211 | ) # try to avoid getting interrupted while writing to checkpoint 212 | os.rename(temporary_file, checkpoint_file) 213 | else: 214 | # save in temporary /scratch storage 215 | checkpoint_file = os.path.join( 216 | self.get_temporary_checkpoint_folder(), "timestep_" + str(timestep) + ".pth" 217 | ) 218 | LOGGER.info("saving timestep to scratch storage: " + checkpoint_file) 219 | temporary_file = checkpoint_file + "_TEMP" 220 | torch.save( 221 | checkpoint, temporary_file 222 | ) # try to avoid getting interrupted while writing to checkpoint 223 | os.rename(temporary_file, checkpoint_file) 224 | 225 | def get_last_stored_training_iteration(self): 226 | return self.last_stored_training_iteration 227 | 228 | def latest_checkpoint(self): 229 | try: 230 | checkpoint_file = self.get_latest_temporary_checkpoint_file() 231 | if not os.path.exists(checkpoint_file): 232 | checkpoint_file = self.get_latest_checkpoint_file() 233 | 234 | map_location = { 235 | "cuda:0": "cuda:" + str(self.rank) 236 | } # rank=0 stores the checkpoint, but when loading, we want to load into rank=self.rank 237 | LOGGER.info("trying to load from: " + checkpoint_file) 238 | checkpoint = torch.load(checkpoint_file, map_location=map_location) 239 | except FileNotFoundError: 240 | checkpoint = None 241 | return checkpoint 242 | 243 | def initialize_parameters(self, scene, renderer, scheduler, trainer, data_handler): 244 | 245 | checkpoint = self.latest_checkpoint() 246 | if checkpoint is None: 247 | 248 | pos_max, pos_min = self.determine_nerf_volume_extent(scene, renderer, data_handler) 249 | scene.set_pos_max_min(pos_max=pos_max, pos_min=pos_min) 250 | 251 | # share initial parameters from rank 0 to all other ranks by loading and storing a checkpoint 252 | if self.multi_gpu: 253 | multi_gpu_barrier(self.rank) # all processes need to see that checkpoint is None 254 | if self.rank == 0: 255 | self.save(-1, scene, renderer, scheduler, trainer) 256 | if self.multi_gpu: 257 | multi_gpu_barrier(self.rank) 258 | checkpoint = self.latest_checkpoint() 259 | 260 | 
scene.load_state_dict(checkpoint["scene"]) 261 | renderer.load_state_dict(checkpoint["renderer"]) 262 | scheduler.load_state_dict(checkpoint["scheduler"]) 263 | trainer.load_state_dict(checkpoint["trainer"]) 264 | 265 | self.last_stored_training_iteration = checkpoint["training_iteration"] 266 | 267 | ### INITIALIZATION 268 | 269 | def determine_nerf_volume_extent( 270 | self, scene, renderer, data_handler, output_camera_visualization=True 271 | ): 272 | # the nerf volume has some extent, but this extent is not fixed. this function computes (somewhat approximate) minimum and maximum coordinates along each axis. it considers all cameras (their positions and point samples along the rays of their corners). 273 | 274 | batch = data_handler.get_batch(subset_name="four_corners") 275 | 276 | returns = Returns() 277 | returns.activate_mode("extent") 278 | with torch.no_grad(): 279 | renderer.render( 280 | batch, scene=scene, points_per_ray=4, is_training=False, returns=returns 281 | ) 282 | 283 | critical_ray_points = returns.get_returns()["unnormalized_undeformed_positions"].view(-1, 3) 284 | consider_camera_positions = False 285 | if consider_camera_positions: 286 | camera_positions = batch["rays_origin"].view(-1, 3) 287 | critical_points = torch.cat([critical_ray_points, camera_positions], dim=0) 288 | else: 289 | critical_points = critical_ray_points 290 | pos_min = torch.min(critical_points, dim=0)[0] 291 | pos_max = torch.max(critical_points, dim=0)[0] 292 | 293 | # add some extra space around the volume. stretch away from the center of the volume. 294 | stretching_factor = 1.1 295 | center = (pos_min + pos_max) / 2.0 296 | pos_min -= center 297 | pos_max -= center 298 | pos_min *= stretching_factor 299 | pos_max *= stretching_factor 300 | pos_min += center 301 | pos_max += center 302 | 303 | if output_camera_visualization: 304 | camera_positions = batch["rays_origin"].view(-1, 3) 305 | rays_near = returns.get_returns()["unnormalized_undeformed_positions"][:, 0, :] 306 | rays_far = returns.get_returns()["unnormalized_undeformed_positions"][:, -1, :] 307 | self.visualize_cameras(camera_positions, rays_near, rays_far, filename="cameras.obj") 308 | 309 | return pos_max, pos_min 310 | 311 | def visualize_cameras(self, camera_positions, rays_near, rays_far, filename): 312 | 313 | if self.rank != 0: 314 | return 315 | 316 | cameras = camera_positions.detach().cpu().numpy() 317 | beginning = rays_near.detach().cpu().numpy() 318 | end = rays_far.detach().cpu().numpy() 319 | 320 | mesh_string = "" 321 | for x, y, z in beginning: 322 | mesh_string += "v " + str(x) + " " + str(y) + " " + str(z) + " 0.0 1.0 0.0\n" 323 | for x, y, z in end: 324 | mesh_string += "v " + str(x) + " " + str(y) + " " + str(z) + " 1.0 0.0 0.0\n" 325 | for x, y, z in end: 326 | mesh_string += "v " + str(x + 0.00001) + " " + str(y) + " " + str(z) + " 1.0 0.0 0.0\n" 327 | for x, y, z in cameras: 328 | mesh_string += "v " + str(x) + " " + str(y) + " " + str(z) + " 0.0 0.0 1.0\n" 329 | for x, y, z in cameras: 330 | mesh_string += "v " + str(x + 0.00001) + " " + str(y) + " " + str(z) + " 0.0 0.0 1.0\n" 331 | for x, y, z in cameras: 332 | mesh_string += "v " + str(x) + " " + str(y + 0.00001) + " " + str(z) + " 0.0 0.0 1.0\n" 333 | num_vertices = beginning.shape[0] 334 | for i in range(num_vertices): 335 | i += 1 336 | mesh_string += ( 337 | "f " + str(i) + " " + str(i + num_vertices) + " " + str(i + 2 * num_vertices) + "\n" 338 | ) 339 | offset = 3 * num_vertices 340 | num_cameras = cameras.shape[0] 341 | for i in range(num_cameras): 
342 | i += 1 343 | mesh_string += ( 344 | "f " 345 | + str(offset + i) 346 | + " " 347 | + str(offset + i + num_cameras) 348 | + " " 349 | + str(offset + i + 2 * num_cameras) 350 | + "\n" 351 | ) 352 | 353 | with open(os.path.join(self.get_results_folder(), filename), "w") as mesh_file: 354 | mesh_file.write(mesh_string) 355 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import coloredlogs 9 | from data_handler import DataHandler 10 | from multi_gpu import multi_gpu_cleanup, multi_gpu_setup, multi_gpu_train 11 | from path_renderer import PathRenderer 12 | from renderer import Renderer 13 | from scene import Scene 14 | from scheduler import Scheduler 15 | from settings import config_parser 16 | from state_loader_saver import StateLoaderSaver 17 | from trainer import Trainer 18 | from utils import ( 19 | check_for_early_interruption, 20 | fix_random_number_generators, 21 | overwrite_settings_for_dnerf, 22 | overwrite_settings_for_nrnerf, 23 | overwrite_settings_for_pref, 24 | ) 25 | from visualizer import Visualizer 26 | 27 | LOGGER = logging.getLogger(__name__) 28 | 29 | 30 | def get_end_of_timestep(training_iteration, settings, scheduler): 31 | return ( 32 | settings.optimization_mode == "per_timestep" 33 | and settings.tracking_mode == "temporal" 34 | and ( 35 | (training_iteration + 1) == scheduler.timeline.initial_iterations 36 | or ( 37 | training_iteration + 1 > scheduler.timeline.initial_iterations 38 | and (training_iteration + 1 - scheduler.timeline.initial_iterations) 39 | % scheduler.timeline.extend_every_n_iterations 40 | == 0 41 | ) 42 | ) 43 | ) 44 | 45 | 46 | def get_do_render(training_iteration, end_of_timestep, last_iteration, settings): 47 | return ( 48 | end_of_timestep 49 | or last_iteration 50 | or ( 51 | (settings.optimization_mode != "per_timestep" or settings.tracking_mode != "temporal") 52 | and (training_iteration >= 0 and (training_iteration + 1) % settings.i_video == 0) 53 | ) 54 | ) 55 | 56 | 57 | def train(rank=0, settings=None, world_size=1, port=None): 58 | 59 | if settings.multi_gpu: 60 | multi_gpu_setup(rank, world_size, port) 61 | 62 | if settings.debug: 63 | fix_random_number_generators(seed=rank) 64 | import torch 65 | 66 | torch.autograd.set_detect_anomaly(True) 67 | 68 | state_loader_saver = StateLoaderSaver(settings, rank) 69 | state_loader_saver.backup_files(settings) 70 | 71 | try: 72 | data_handler = DataHandler(settings) 73 | except RuntimeError as e: 74 | if "pref index out of bounds" in str(e): 75 | return 76 | else: 77 | raise e 78 | if settings.always_load_full_dataset: 79 | num_total_rays_to_precompute = int(1e9) 80 | data_handler.load_training_set( 81 | factor=settings.factor, 82 | num_total_rays_to_precompute=num_total_rays_to_precompute, 83 | foreground_focused=True, 84 | also_load_top_left_corner_and_four_courners=True, 85 | ) 86 | else: 87 | data_handler.load_training_set( 88 | factor=16, 89 | num_pixels_per_image=5, # some dummy value 90 | also_load_top_left_corner_and_four_courners=True, 91 | ) # the top_left_corner (Timeline in scheduler) and four_corners (determine_nerf_volume_extent in state_loader_saver) subsets need to be loaded 92 | 93 | scene = Scene(settings, 
data_handler).cuda() # incl. time line 94 | renderer = Renderer(settings).cuda() 95 | 96 | trainer = Trainer(settings, scene, renderer, world_size) 97 | 98 | scheduler = Scheduler(settings, scene, renderer, trainer, data_handler) 99 | 100 | state_loader_saver.initialize_parameters( 101 | scene, renderer, scheduler, trainer, data_handler 102 | ) # incl. pos_min/pos_max 103 | 104 | # if settings.debug and rank == 0: 105 | # data_handler.visualize_images_in_3D(state_loader_saver.get_results_folder()) 106 | 107 | log = None 108 | 109 | first_iteration = True 110 | starting_iteration = state_loader_saver.get_last_stored_training_iteration() 111 | num_training_iterations = scheduler.timeline.get_num_training_iterations( 112 | settings.num_iterations 113 | ) 114 | 115 | # take care of potentially interrupted renderings 116 | training_iteration = starting_iteration 117 | last_iteration = training_iteration == num_training_iterations - 1 118 | end_of_timestep = get_end_of_timestep(training_iteration, settings, scheduler) 119 | only_render = get_do_render(training_iteration, end_of_timestep, last_iteration, settings) 120 | if not only_render: 121 | starting_iteration += 1 122 | 123 | from tqdm import trange 124 | 125 | for training_iteration in trange(starting_iteration, num_training_iterations): 126 | 127 | # training 128 | scheduler.schedule(training_iteration, log) 129 | 130 | if not only_render: 131 | batch = data_handler.get_batch(batch_size=settings.batch_size // world_size) 132 | 133 | trainer.zero_grad() 134 | 135 | log = trainer.losses_and_backward_with_virtual_batches( 136 | batch, scene, renderer, scheduler, training_iteration 137 | ) 138 | 139 | trainer.step() 140 | 141 | scheduler.reset_for_rendering() 142 | 143 | last_iteration = training_iteration == num_training_iterations - 1 144 | end_of_timestep = get_end_of_timestep(training_iteration, settings, scheduler) 145 | do_render = get_do_render(training_iteration, end_of_timestep, last_iteration, settings) 146 | 147 | # checkpoint 148 | if ( 149 | (training_iteration + 1) % settings.save_temporary_checkpoint_every == 0 150 | or last_iteration 151 | ) and not first_iteration: 152 | state_loader_saver.save( 153 | training_iteration, 154 | scene, 155 | renderer, 156 | scheduler, 157 | trainer, 158 | force_save_in_stable=last_iteration, 159 | ) 160 | 161 | if ( 162 | settings.save_per_timestep 163 | and (end_of_timestep or last_iteration) 164 | and not first_iteration 165 | ): 166 | timestep = float(data_handler.precomputed["timesteps"].numpy()[0]) 167 | state_loader_saver.save_for_only_test( 168 | timestep, scene, renderer, stable_storage=not settings.save_per_timestep_in_scratch 169 | ) 170 | 171 | # rendering/visualization 172 | if do_render: 173 | path_rendering = PathRenderer(data_handler, rank, world_size) 174 | test_cameras = data_handler.get_test_cameras_for_rendering( 175 | factor=settings.factor if settings.render_factor == 0 else settings.render_factor 176 | ) 177 | 178 | also_render_coarse = True 179 | if also_render_coarse and scene.deformation_model.coarse_and_fine: 180 | scheduler.zero_out_fine_deformations.fine_deformations(zero_out=True) 181 | path_rendering.render_and_store( 182 | state_loader_saver, 183 | also_store_images=True, 184 | output_name=state_loader_saver.get_experiment_name() 185 | + "_" 186 | + str(training_iteration).zfill(8) 187 | + "_coarse", 188 | scene=scene, 189 | renderer=renderer, 190 | **test_cameras 191 | ) 192 | scheduler.zero_out_fine_deformations.fine_deformations(zero_out=False) 193 | 194 | 
path_rendering.render_and_store( 195 | state_loader_saver, 196 | also_store_images=True, 197 | output_name=state_loader_saver.get_experiment_name() 198 | + "_" 199 | + str(training_iteration).zfill(8), 200 | scene=scene, 201 | renderer=renderer, 202 | **test_cameras 203 | ) 204 | 205 | # if "backgrounds" in test_cameras: 206 | # del test_cameras["backgrounds"] 207 | # path_rendering.render_and_store(state_loader_saver, output_name=str(training_iteration).zfill(8) + "_nobackground", \ 208 | # scene=scene, renderer=renderer, **test_cameras) 209 | 210 | if settings.use_visualizer: 211 | visualizer = Visualizer(settings, data_handler, rank, world_size) 212 | test_cameras = data_handler.get_test_cameras_for_rendering(factor=8) 213 | visualizer.render_and_store( 214 | state_loader_saver, 215 | output_name=state_loader_saver.get_experiment_name() 216 | + "_" 217 | + str(training_iteration).zfill(8), 218 | scene=scene, 219 | renderer=renderer, 220 | **test_cameras 221 | ) 222 | 223 | # logging 224 | if (training_iteration + 1) % settings.i_print == 0: 225 | state_loader_saver.print_log(training_iteration, log) 226 | 227 | # interruption 228 | if training_iteration % 3000 == 0: 229 | try: 230 | check_for_early_interruption(state_loader_saver) 231 | except RuntimeError: 232 | break 233 | 234 | first_iteration = False 235 | only_render = False # only the first iteration can be an "only render" iteration 236 | 237 | if settings.multi_gpu: 238 | multi_gpu_cleanup(rank) 239 | 240 | 241 | if __name__ == "__main__": 242 | 243 | settings = config_parser().parse_args() 244 | if settings.do_pref: 245 | settings = overwrite_settings_for_pref(settings) 246 | if settings.do_nrnerf: 247 | settings = overwrite_settings_for_nrnerf(settings) 248 | if settings.do_dnerf: 249 | settings = overwrite_settings_for_dnerf(settings) 250 | 251 | logging_level = logging.DEBUG if settings.debug else logging.INFO 252 | coloredlogs.install(level=logging_level, fmt="%(name)s[%(process)d] %(levelname)s %(message)s") 253 | logging.basicConfig(level=logging_level) 254 | 255 | if settings.multi_gpu: 256 | multi_gpu_train(settings) 257 | else: 258 | train(settings=settings) 259 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
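#
# (Sketch, not part of the original file; loss_fn and params are hypothetical.)
# The Trainer below splits every ray batch into "virtual batches", accumulates
# gradients manually, and increases the number of virtual batches whenever a
# CUDA out-of-memory error is raised. The accumulation pattern, reduced to its
# core:
#
#     import torch
#
#     def accumulate_gradients(batch, params, loss_fn, num_virtual_batches):
#         accumulated = [torch.zeros_like(p) for p in params]
#         chunk = max(1, -(-batch.shape[0] // num_virtual_batches))  # ceil division
#         for start in range(0, batch.shape[0], chunk):
#             loss = loss_fn(batch[start:start + chunk])
#             loss.backward()
#             for acc, p in zip(accumulated, params):
#                 if p.grad is not None:
#                     acc += p.grad
#                     p.grad = None
#         for acc, p in zip(accumulated, params):
#             p.grad = acc / num_virtual_batches   # average before optimizer.step()
#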
6 | import logging 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | from losses import Losses 12 | from multi_gpu import multi_gpu_sync_gradients 13 | from optimizer import Optimizer 14 | 15 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | 19 | class Trainer: 20 | def __init__(self, settings, scene, renderer, world_size): 21 | 22 | super().__init__() 23 | 24 | self.multi_gpu = settings.multi_gpu 25 | 26 | self.use_gradient_scaling = False 27 | self.scaler = torch.cuda.amp.GradScaler() 28 | 29 | self.scene = scene 30 | self.optimizer = Optimizer(settings, scene, renderer) 31 | self.losses = Losses(settings, world_size) 32 | 33 | # virtual batches 34 | self.num_virtual_batches = 1 # init 35 | self.iterations_since_last_change = 0 # init 36 | self.automatically_adjust_num_virtual_batches = True 37 | self.try_fewer_virtual_batches_every = 100 38 | self.max_num_virtual_batches = 16 39 | 40 | def get_optimizer(self): 41 | return self.optimizer 42 | 43 | def zero_grad(self, set_to_none=None): 44 | self.optimizer.zero_grad(set_to_none=set_to_none) 45 | 46 | def backward(self, loss, multi_gpu_sync=True): 47 | 48 | if self.use_gradient_scaling: 49 | self.scaler.scale(loss).backward() 50 | else: 51 | loss.backward() 52 | 53 | if self.multi_gpu and multi_gpu_sync: 54 | multi_gpu_sync_gradients(self.optimizer.get_parameters()) 55 | 56 | def losses_and_backward_with_virtual_batches(self, *args): 57 | 58 | while True: 59 | log = self._losses_and_backward_with_virtual_batches( 60 | *args 61 | ) # need to call this from here instead of from within itself, such that the raised exception is out of scope 62 | if ( 63 | log is None 64 | ): # out of memory happened. can only happen until num_virtual_batches > max_num_virtual_batches. 
65 | # cleanup, free memory 66 | self.zero_grad(set_to_none=True) # sets gradients to None 67 | torch.cuda.empty_cache() 68 | else: 69 | return log 70 | 71 | def _losses_and_backward_with_virtual_batches( 72 | self, batch, scene, renderer, scheduler, training_iteration 73 | ): 74 | 75 | manually_accumulate_gradients = True 76 | if manually_accumulate_gradients: 77 | accumulated_gradients = {} 78 | for param_with_info in self.optimizer.get_parameters_with_information(): 79 | accumulated_gradients[param_with_info["name"]] = { 80 | param_index: None for param_index in range(len(param_with_info["parameters"])) 81 | } 82 | 83 | num_rays = batch["rays_origin"].shape[0] 84 | rays_per_virtual_batch = max(1, int(np.ceil(num_rays / self.num_virtual_batches))) 85 | 86 | for virtual_batch_index, virtual_batch_start in enumerate( 87 | range(0, num_rays, rays_per_virtual_batch) 88 | ): 89 | 90 | try: 91 | virtual_batch = { 92 | key: tensor[virtual_batch_start : virtual_batch_start + rays_per_virtual_batch] 93 | for key, tensor in batch.items() 94 | } 95 | 96 | LOGGER.debug( 97 | "virtual_batch_index: " 98 | + str(virtual_batch_index) 99 | + " | virtual batch start: " 100 | + str(virtual_batch_start) 101 | + " | virtual batch size: " 102 | + str(virtual_batch["rays_origin"].shape) 103 | ) 104 | 105 | training_loss, loss_scaling_factor, log = self.losses.compute( 106 | virtual_batch, scene, renderer, scheduler, training_iteration 107 | ) 108 | 109 | self.backward(training_loss, multi_gpu_sync=False) 110 | 111 | if manually_accumulate_gradients: 112 | with torch.no_grad(): 113 | for param_with_info in self.optimizer.get_parameters_with_information(): 114 | this_acc_grad = accumulated_gradients[param_with_info["name"]] 115 | for param_index, this_param in enumerate(param_with_info["parameters"]): 116 | if this_param.grad is not None: 117 | if this_acc_grad[param_index] is None: 118 | this_acc_grad[param_index] = this_param.grad.clone() 119 | else: 120 | this_acc_grad[param_index] += this_param.grad 121 | 122 | self.zero_grad() 123 | 124 | except RuntimeError as exception: # handle out of memory 125 | if self.num_virtual_batches > self.max_num_virtual_batches: 126 | 127 | LOGGER.warning( 128 | "trying to exceed maximum number of virtual batches: " 129 | + str(self.num_virtual_batches) 130 | + " / " 131 | + str(self.max_num_virtual_batches) 132 | ) 133 | raise exception 134 | 135 | elif ( 136 | any(oom in str(exception) for oom in ["out of memory", "OUT_OF_MEMORY"]) 137 | and self.automatically_adjust_num_virtual_batches 138 | ): 139 | 140 | sys.stderr.flush() 141 | 142 | LOGGER.info( 143 | "virtual batch too large, need more than " 144 | + str(self.num_virtual_batches) 145 | + " virtual batches" 146 | ) 147 | 148 | self.num_virtual_batches += 1 149 | self.iterations_since_last_change = 0 150 | 151 | return None 152 | 153 | else: 154 | raise exception 155 | 156 | if manually_accumulate_gradients: 157 | with torch.no_grad(): 158 | for param_with_info in self.optimizer.get_parameters_with_information(): 159 | this_acc_grad = accumulated_gradients[param_with_info["name"]] 160 | for param_index, this_param in enumerate(param_with_info["parameters"]): 161 | this_param.grad = this_acc_grad[param_index] 162 | del accumulated_gradients 163 | 164 | self.optimizer.scale_gradients( 165 | factor=1.0 / (float(self.num_virtual_batches) * loss_scaling_factor) 166 | ) 167 | 168 | if self.multi_gpu: 169 | multi_gpu_sync_gradients(self.optimizer.get_parameters()) 170 | 171 | self.iterations_since_last_change += 1 172 | if ( 
173 | self.iterations_since_last_change >= self.try_fewer_virtual_batches_every 174 | and self.num_virtual_batches > 1 175 | and self.automatically_adjust_num_virtual_batches 176 | ): 177 | self.num_virtual_batches -= 1 178 | self.iterations_since_last_change = 0 # not really necessary but might reduce the number of OOM exceptions in some weird edge cases 179 | 180 | return log 181 | 182 | def step(self): 183 | 184 | self.optimizer.step(self.scaler, self.use_gradient_scaling) 185 | if self.use_gradient_scaling: 186 | self.scaler.update() 187 | 188 | self.scene.step() # e.g. for enforcing hard constraints via projection 189 | 190 | def state_dict(self): 191 | return {"optimizer": self.optimizer.state_dict(), "scaler": self.scaler.state_dict()} 192 | 193 | def load_state_dict(self, state_dict): 194 | self.optimizer.load_state_dict(state_dict["optimizer"]) 195 | self.scaler.load_state_dict(state_dict["scaler"]) 196 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import torch 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | 13 | def fix_random_number_generators(seed=None): 14 | if seed is None: 15 | seed = 0 16 | import os 17 | 18 | os.environ["PYTHONHASHSEED"] = str(seed) 19 | import random 20 | 21 | random.seed(seed) 22 | import numpy as np 23 | 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | 27 | 28 | def check_for_early_interruption(state_loader_saver): 29 | 30 | import os 31 | 32 | results_folder = state_loader_saver.get_results_folder() 33 | 34 | # subfolders 35 | folders_to_check = [ 36 | os.path.join(results_folder, folder) for folder in os.listdir(results_folder) 37 | ] 38 | folders_to_check = [folder for folder in folders_to_check if os.path.isdir(folder)] 39 | # root folder 40 | folders_to_check.append(results_folder) 41 | 42 | terminating_files = ["stop", "New Text Document.txt"] 43 | terminating_files = [file.casefold() for file in terminating_files] # lower case normalization 44 | 45 | for folder in folders_to_check: 46 | files = os.listdir(folder) 47 | files = [file.casefold() for file in files] 48 | 49 | if any(terminating_file in files for terminating_file in terminating_files): 50 | LOGGER.warning("Shutting down early.") 51 | raise RuntimeError 52 | 53 | 54 | def project_to_correct_range(values, mode, min_=None, max_=None): 55 | if min_ is None: 56 | min_ = 0.0 57 | if max_ is None: 58 | max_ = 1.0 59 | 60 | if mode == "clamp": 61 | # -1 0 1 2 62 | # 1 /---- 63 | # / 64 | # 0 ____/ 65 | return torch.clamp(values, min=min_, max=max_) 66 | 67 | elif mode == "sine": 68 | # sine, but shifted and scaled such that it's close to the identity on [0,1] 69 | # -1 0 1 2 70 | # 1 - -- 71 | # \ / \ 72 | # 0 -- - 73 | if min_ != 0.0 or max_ != 1.0: 74 | values = (values - min_) / (max_ - min_) 75 | # values = (torch.sin(values * np.pi - np.pi/2.0) + 1.0) / 2.0 76 | values = (1.0 - torch.cos(values * torch.pi)) / 2.0 77 | if min_ != 0.0 or max_ != 1.0: 78 | values = (max_ - min_) * values + min_ 79 | return values 80 | 81 | elif mode == "zick_zack": 82 | # identity function on [0,1]. reflecting boundary. linear. 
83 | # -1 0 1 2 84 | # 1 \ /\ 85 | # \ / \ 86 | # 0 \/ \ 87 | if min_ != 0.0 or max_ != 1.0: 88 | values = (values - min_) / (max_ - min_) 89 | floor = torch.floor(values) 90 | fraction = values - floor # in [0,1] 91 | odd_mask = floor.long() % 2 == 1 92 | fraction[odd_mask] = 1.0 - fraction[odd_mask] 93 | if min_ != 0.0 or max_ != 1.0: 94 | fraction = (max_ - min_) * fraction + min_ 95 | return fraction 96 | 97 | 98 | def szudzik(a, b): 99 | if a >= b: 100 | key = a * a + a + b 101 | else: 102 | key = a + b * b 103 | return key 104 | 105 | 106 | # from FFJORD github code 107 | def get_minibatch_jacobian(y, x): 108 | """Computes the Jacobian of y wrt x assuming minibatch-mode. 109 | Args: 110 | y: (N, ..., D_y) 111 | x: (N, ..., D_x) 112 | Returns: 113 | The minibatch Jacobian matrix of shape (N, ..., D_y, D_x) 114 | """ 115 | assert y.shape[:-1] == x.shape[:-1] 116 | prefix_shape = y.shape[:-1] 117 | y = y.view(-1, y.shape[-1]) 118 | 119 | # Compute Jacobian row by row. 120 | jac = [] 121 | for j in range(y.shape[1]): 122 | dy_j_dx = torch.autograd.grad( 123 | y[:, j], 124 | x, 125 | torch.ones_like(y[:, j], device=y.get_device()), 126 | retain_graph=True, 127 | create_graph=True, 128 | )[0] 129 | dy_j_dx = dy_j_dx.view(-1, x.shape[-1]) 130 | jac.append(torch.unsqueeze(dy_j_dx, 1)) 131 | jac = torch.cat(jac, 1) 132 | jac = jac.view(prefix_shape + jac.shape[-2:]) 133 | return jac 134 | 135 | 136 | # from FFJORD github code 137 | def divergence_exact(inputs, outputs): 138 | # requires three backward passes instead of one like divergence_approx 139 | prefix_shape = outputs.shape[:-1] 140 | jac = get_minibatch_jacobian(outputs, inputs) 141 | diagonal = jac.view(-1, jac.shape[-1] * jac.shape[-2])[:, :: (jac.shape[-1] + 1)] 142 | divergence = torch.sum(diagonal, 1) 143 | 144 | divergence = divergence.view(prefix_shape) 145 | return divergence 146 | 147 | 148 | # from FFJORD github code 149 | def divergence_approx(inputs, outputs): 150 | # avoids explicitly computing the Jacobian 151 | e = torch.randn_like(outputs, device=outputs.get_device()) 152 | e_dydx = torch.autograd.grad(outputs, inputs, e, create_graph=True)[0] 153 | e_dydx_e = e_dydx * e 154 | approx_tr_dydx = e_dydx_e.sum(dim=-1) 155 | return approx_tr_dydx 156 | 157 | 158 | def positional_encoding(x, freq_bands=None, num_frequencies=None, include_input=None): 159 | 160 | # shape of x : .... 
x k 161 | # output dimensions: k + num_frequencies * 2 * k 162 | 163 | encoded_input = [] 164 | 165 | if include_input is None: 166 | include_input = num_frequencies is not None 167 | 168 | if include_input: 169 | encoded_input.append(x) 170 | 171 | if num_frequencies is None: 172 | freq_bands = freq_bands 173 | else: 174 | freq_bands = 2.0 ** torch.linspace(0.0, num_frequencies - 1, steps=num_frequencies) 175 | for frequency in freq_bands: 176 | encoded_input.append(torch.sin(x * frequency)) 177 | encoded_input.append(torch.cos(x * frequency)) 178 | 179 | encoded_input = torch.cat(encoded_input, dim=-1) 180 | 181 | return encoded_input 182 | 183 | 184 | class Squareplus(torch.nn.Module): 185 | def __init__(self): 186 | super().__init__() 187 | 188 | def forward(self, x): 189 | return 0.5 * (x + torch.sqrt(x * x + 4)) 190 | 191 | 192 | class Sine(torch.nn.Module): 193 | def __init__(self): 194 | super().__init__() 195 | 196 | def forward(self, x): 197 | return torch.sin(x) 198 | 199 | 200 | class Scaling(torch.nn.Module): 201 | def __init__(self, factor): 202 | super().__init__() 203 | self.factor = factor 204 | 205 | def forward(self, x): 206 | return self.factor * x 207 | 208 | 209 | def default_relu_initialization(layer): 210 | with torch.no_grad(): 211 | torch.nn.init.kaiming_uniform_(layer.weight, a=0, mode="fan_in", nonlinearity="relu") 212 | torch.nn.init.zeros_(layer.bias) 213 | 214 | 215 | def default_sine_initialization(layer, first=False): 216 | import numpy as np 217 | with torch.no_grad(): 218 | a = 30.0 / layer.in_features if first else np.sqrt(6.0 / layer.in_features) 219 | layer.weight.uniform_(-a, a) 220 | torch.nn.init.zeros_(layer.bias) 221 | 222 | 223 | def zero_initialization(layer): 224 | with torch.no_grad(): 225 | torch.nn.init.zeros_(layer.weight) 226 | torch.nn.init.zeros_(layer.bias) 227 | 228 | 229 | def build_pytorch_mlp_from_tinycudann(mlp_dict, half_precision=False, last_layer_zero_init=False): 230 | sequential = [] 231 | 232 | activation_name = mlp_dict["network_config"]["activation"] 233 | 234 | def get_activation(activation_name): 235 | if activation_name == "ReLU": 236 | activation = torch.nn.ReLU() 237 | elif activation_name == "Exponential": 238 | activation = torch.nn.ELU() 239 | elif activation_name == "Sine": 240 | activation = Sine() 241 | elif activation_name == "Squareplus": 242 | activation = Squareplus() 243 | elif activation_name == "LeakyReLU": 244 | activation = torch.nn.LeakyReLU() 245 | else: 246 | raise NotImplementedError 247 | return activation 248 | 249 | activation = get_activation(activation_name) 250 | 251 | input_dims = mlp_dict["n_input_dims"] 252 | if "encoding_config" in mlp_dict: 253 | if mlp_dict["encoding_config"]["otype"] == "Frequency": 254 | input_dims = input_dims + mlp_dict["encoding_config"]["n_frequencies"] * 2 * input_dims 255 | elif mlp_dict["encoding_config"]["otype"] == "Limited_Frequency": 256 | input_dims = len(mlp_dict["encoding_config"]["frequencies"]) * 2 * input_dims 257 | if mlp_dict["encoding_config"]["include_input"]: 258 | input_dims += 3 259 | elif mlp_dict["encoding_config"]["otype"] == "Composite": 260 | if len(mlp_dict["encoding_config"]["nested"]) != 2: 261 | raise NotImplementedError 262 | if ( 263 | mlp_dict["encoding_config"]["nested"][0]["otype"] == "Frequency" 264 | ): # latent code positional encoding 265 | pos_enc_input_dimensions = mlp_dict["encoding_config"]["nested"][0][ 266 | "n_dims_to_encode" 267 | ] 268 | input_dims = (input_dims - pos_enc_input_dimensions) + ( 269 | 
pos_enc_input_dimensions 270 | + mlp_dict["encoding_config"]["nested"][0]["n_frequencies"] 271 | * 2 272 | * pos_enc_input_dimensions 273 | ) 274 | if mlp_dict["encoding_config"]["nested"][1]["otype"] == "Frequency": 275 | pos_enc_input_dimensions = mlp_dict["encoding_config"]["nested"][1][ 276 | "n_dims_to_encode" 277 | ] 278 | input_dims = (input_dims - pos_enc_input_dimensions) + ( 279 | pos_enc_input_dimensions 280 | + mlp_dict["encoding_config"]["nested"][1]["n_frequencies"] 281 | * 2 282 | * pos_enc_input_dimensions 283 | ) 284 | elif mlp_dict["encoding_config"]["nested"][1]["otype"] != "Identity": 285 | raise NotImplementedError 286 | else: 287 | raise NotImplementedError 288 | first_layer = torch.nn.Linear(input_dims, mlp_dict["network_config"]["n_neurons"]) 289 | if activation_name == "Sine": 290 | default_sine_initialization(first_layer, first=True) 291 | elif activation_name in ["ReLU", "LeakyReLU"]: 292 | default_relu_initialization(first_layer) 293 | sequential.append(first_layer) 294 | sequential.append(activation) 295 | 296 | for layer in range(mlp_dict["network_config"]["n_hidden_layers"] - 1): 297 | layer = torch.nn.Linear( 298 | mlp_dict["network_config"]["n_neurons"], mlp_dict["network_config"]["n_neurons"] 299 | ) 300 | if activation_name == "Sine": 301 | default_sine_initialization(layer) 302 | elif activation_name in ["ReLU", "LeakyReLU"]: 303 | default_relu_initialization(layer) 304 | sequential.append(layer) 305 | sequential.append(activation) 306 | 307 | last_layer = torch.nn.Linear(mlp_dict["network_config"]["n_neurons"], mlp_dict["n_output_dims"]) 308 | if last_layer_zero_init: 309 | zero_initialization(last_layer) # note: different from tiny cuda nn 310 | elif activation_name == "Sine": 311 | default_sine_initialization(last_layer) 312 | elif activation_name in ["ReLU", "LeakyReLU"]: 313 | default_relu_initialization(last_layer) 314 | sequential.append(last_layer) 315 | if mlp_dict["network_config"]["output_activation"] == "None": 316 | pass 317 | else: 318 | sequential.append(get_activation(mlp_dict["network_config"]["output_activation"])) 319 | 320 | mlp = torch.nn.Sequential(*sequential) 321 | if half_precision: # tag:half_precision 322 | mlp = mlp.half() # do not use if autocast is used instead. 
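    # Note that the returned torch.nn.Sequential contains only the linear layers and
    # activations; the input encoding described by "encoding_config" (e.g. Frequency)
    # is not part of the module and is expected to be applied to the input beforehand,
    # which is why input_dims above accounts for the already-encoded size.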
323 | return mlp 324 | 325 | 326 | def infill_masked(mask, masked_tensor, infill_value=None): 327 | if infill_value is None: 328 | infill_value = 0 329 | infilled = infill_value * torch.ones( 330 | mask.shape + masked_tensor.shape[1:], dtype=masked_tensor.dtype, device=masked_tensor.device 331 | ) 332 | infilled[mask] = masked_tensor 333 | return infilled 334 | 335 | 336 | def get_scratch_scene_folder(datadir): 337 | import os 338 | 339 | if datadir[-1] == "/": # remove trailing slash 340 | datadir = datadir[:-1] 341 | scene_type, scene_name = datadir.split("/")[-2:] 342 | scratch_root_folder = "/scratch/inf0/user/tretschk/data/" 343 | scratch_scene_folder = os.path.join(scratch_root_folder, scene_type, scene_name) 344 | return scratch_scene_folder 345 | 346 | 347 | def get_scratch_scene_folder_valid_file(datadir): 348 | import os 349 | 350 | scratch_scene_folder = get_scratch_scene_folder(datadir) 351 | valid_file = os.path.join(scratch_scene_folder, "VALID_SCRATCH") 352 | return valid_file 353 | 354 | 355 | def scratch_scene_folder_is_valid(datadir): 356 | import os 357 | 358 | valid_file = get_scratch_scene_folder_valid_file(datadir) 359 | return os.path.exists(valid_file) 360 | 361 | 362 | def check_scratch_for_dataset_copy(datadir): 363 | if scratch_scene_folder_is_valid(datadir): 364 | return get_scratch_scene_folder(datadir) 365 | else: 366 | return datadir 367 | 368 | 369 | def overwrite_settings_for_pref(settings): 370 | 371 | settings.optimization_mode = "all" 372 | settings.pure_mlp_bending = True 373 | settings.use_temporal_latent_codes = True 374 | settings.tracking_mode = "plain" 375 | settings.weight_background_loss = 0.0 376 | settings.weight_hard_surface_loss = 0.0 377 | settings.weight_coarse_smooth_deformations = 0.0 378 | settings.weight_fine_smooth_deformations = 0.0 379 | settings.activation_function = "ReLU" 380 | settings.do_zero_out = False 381 | settings.coarse_and_fine = False 382 | settings.always_load_full_dataset = True 383 | settings.num_iterations = 50000 384 | settings.reconstruction_loss_type = "L2" 385 | 386 | return settings 387 | 388 | 389 | def overwrite_settings_for_nrnerf(settings): 390 | 391 | settings.optimization_mode = "all" 392 | settings.pure_mlp_bending = True 393 | settings.use_temporal_latent_codes = True 394 | settings.tracking_mode = "plain" 395 | settings.smooth_deformations_type = "divergence" 396 | settings.weight_smooth_deformations = 3.0 397 | settings.weight_background_loss = 0.0 398 | settings.weight_hard_surface_loss = 0.0 399 | settings.weight_coarse_smooth_deformations = 0.0 400 | settings.weight_fine_smooth_deformations = 0.0 401 | settings.activation_function = "ReLU" 402 | settings.do_zero_out = False 403 | settings.coarse_and_fine = False 404 | settings.always_load_full_dataset = True 405 | settings.num_iterations = 0 406 | settings.reconstruction_loss_type = "L2" 407 | 408 | return settings 409 | 410 | 411 | def overwrite_settings_for_dnerf(settings): 412 | 413 | settings.optimization_mode = "dnerf" 414 | settings.pure_mlp_bending = True 415 | settings.use_temporal_latent_codes = False 416 | settings.tracking_mode = "temporal" 417 | settings.smooth_deformations_type = "divergence" # won't be used anyway 418 | settings.weight_smooth_deformations = 0.0 419 | settings.weight_background_loss = 0.0 420 | settings.weight_hard_surface_loss = 0.0 421 | settings.weight_coarse_smooth_deformations = 0.0 422 | settings.weight_fine_smooth_deformations = 0.0 423 | settings.activation_function = "ReLU" 424 | settings.do_zero_out = False 
425 | settings.coarse_and_fine = False 426 | settings.fix_coarse_after_a_while = False 427 | settings.always_load_full_dataset = True 428 | settings.num_iterations = 800000 429 | settings.reconstruction_loss_type = "L2" 430 | settings.use_viewdirs = True 431 | 432 | return settings 433 | 434 | 435 | class Returns: 436 | def __init__(self, restricted=None): 437 | 438 | self.mode_dict = {} 439 | self.mode = None 440 | self.restricted = restricted # None takes everything. restricted only what's in restricted. 441 | 442 | self.mask = None 443 | 444 | # more internal stuff 445 | 446 | def get_restricted_list(self): 447 | return self.restricted 448 | 449 | def activate_mode(self, mode): 450 | if mode not in self.mode_dict: 451 | self.mode_dict[mode] = {} 452 | self.mode = mode 453 | 454 | def get_modes(self): 455 | return list(self.mode_dict.keys()) 456 | 457 | # add, delete 458 | 459 | def set_mask(self, mask): 460 | self.mask = mask 461 | 462 | def get_mask(self): 463 | return self.mask 464 | 465 | def add_return(self, name, returns, clone=True, infill=0): 466 | if self.restricted is None or name in self.restricted: 467 | if clone and self.mask is None: 468 | returns = returns.clone() 469 | if self.mask is not None: 470 | if not clone: 471 | raise RuntimeError 472 | returns = infill_masked(self.mask, returns, infill_value=infill) 473 | self.mode_dict[self.mode][name] = returns 474 | successful = True 475 | return successful 476 | else: 477 | successful = False 478 | return successful 479 | 480 | def add_returns(self, returns_dict): 481 | for name, returns in returns_dict.items(): 482 | self.add_return(name, returns) 483 | 484 | def delete_return(self, name): 485 | del self.mode_dict[self.mode][name] 486 | 487 | # contains, get 488 | 489 | def __contains__(self, name): 490 | return name in self.mode_dict[self.mode] 491 | 492 | def get_returns(self, mode=None): 493 | if mode is None: 494 | mode = self.mode 495 | return self.mode_dict[mode] 496 | 497 | # modify 498 | 499 | def concatenate_returns(self, modes=None): 500 | if modes is None: 501 | modes = self.get_modes() 502 | 503 | returns = { 504 | key: torch.cat([self.mode_dict[mode][key] for mode in modes], axis=0) 505 | for key in self.mode_dict[modes[0]].keys() 506 | } 507 | return returns 508 | 509 | def pull_to_cpu(self): 510 | self.mode_dict[self.mode] = { 511 | key: tensor.cpu() for key, tensor in self.mode_dict[self.mode].items() 512 | } 513 | 514 | def push_to_gpu(self): 515 | self.mode_dict[self.mode] = { 516 | key: tensor.cuda() for key, tensor in self.mode_dict[self.mode].items() 517 | } 518 | 519 | def reshape_returns(self, height, width): 520 | self.mode_dict[self.mode] = { 521 | key: tensor.view(size=(height, width) + tensor.shape[1:]) 522 | for key, tensor in self.mode_dict[self.mode].items() 523 | } 524 | -------------------------------------------------------------------------------- /visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
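# Visualizer re-renders the scene in chunks, collects per-sample deformation offsets
# together with the smooth-deformation regularization loss, and exports them as
# color-coded offset meshes for inspection (see render_and_store below).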
6 | import logging 7 | import os 8 | 9 | import numpy as np 10 | import torch 11 | from multi_gpu import ( 12 | multi_gpu_receive_returns_from_rank_pathrenderer, 13 | multi_gpu_send_returns_to_rank_pathrenderer, 14 | ) 15 | from tqdm import trange 16 | from utils import Returns 17 | 18 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 19 | LOGGER = logging.getLogger(__name__) 20 | 21 | 22 | class Visualizer: 23 | def __init__(self, settings, data_handler, rank, world_size): 24 | self.batch_builder = data_handler.batch_builder 25 | 26 | self.smooth_deformations_type = settings.smooth_deformations_type 27 | self.weight_coarse_smooth_deformations = settings.weight_coarse_smooth_deformations 28 | 29 | self.points_per_chunk = settings.points_per_chunk // 8 30 | self.default_num_points_per_ray = settings.num_points_per_ray 31 | 32 | self.rank = rank 33 | self.world_size = world_size 34 | 35 | def render( 36 | self, 37 | only_first_frame, 38 | extrins, 39 | intrins, 40 | timesteps, 41 | scene, 42 | renderer, 43 | backgrounds=None, 44 | points_per_ray=None, 45 | **kwargs 46 | ): 47 | 48 | if points_per_ray is None: 49 | points_per_ray = self.default_num_points_per_ray 50 | 51 | all_relevant_results = [] 52 | for counter in trange(len(extrins)): 53 | 54 | relevant_results = {} 55 | 56 | extrin = extrins[counter] 57 | intrin = intrins[counter] 58 | timestep = timesteps[counter] 59 | if backgrounds is not None: 60 | background = backgrounds[counter] 61 | 62 | single_image = { 63 | "extrin": extrin, 64 | "intrin": intrin, 65 | "timestep": timestep, 66 | } 67 | if backgrounds is not None: 68 | single_image["background"] = background 69 | 70 | batch = self.batch_builder.build(single_image=single_image) 71 | num_rays = batch["rays_origin"].shape[0] 72 | rays_per_chunk = max(1, self.points_per_chunk // points_per_ray) 73 | 74 | for chunk_start in range(0, num_rays, rays_per_chunk): 75 | 76 | # render 77 | 78 | returns = Returns() # dummy 79 | returns.activate_mode("coarse") 80 | subreturns = Returns() 81 | 82 | subbatch = { 83 | key: tensor[chunk_start : chunk_start + rays_per_chunk] 84 | for key, tensor in batch.items() 85 | } 86 | 87 | renderer.render( 88 | batch=subbatch, 89 | scene=scene, 90 | points_per_ray=points_per_ray, 91 | returns=returns, 92 | subreturns=subreturns, 93 | ) 94 | 95 | del returns 96 | 97 | # compute losses 98 | 99 | if self.smooth_deformations_type not in [ 100 | "finite", 101 | "jacobian", 102 | "divergence", 103 | "nerfies", 104 | ]: 105 | raise NotImplementedError 106 | 107 | for mode in subreturns.get_modes(): 108 | 109 | def wrapper_to_free_memory_quickly(relevant_results, mode, subreturns): 110 | subreturns.activate_mode(mode) 111 | position_offsets = subreturns.get_returns()["coarse_position_offsets"] 112 | normalized_undeformed_positions = subreturns.get_returns()[ 113 | "normalized_undeformed_positions" 114 | ] 115 | 116 | if "coarse_position_offsets" not in relevant_results: 117 | relevant_results["coarse_position_offsets"] = [] 118 | relevant_results["coarse_position_offsets"].append( 119 | position_offsets.detach().cpu() 120 | ) 121 | if "normalized_undeformed_positions" not in relevant_results: 122 | relevant_results["normalized_undeformed_positions"] = [] 123 | relevant_results["normalized_undeformed_positions"].append( 124 | normalized_undeformed_positions.detach().cpu() 125 | ) 126 | if "opacity" not in relevant_results: 127 | relevant_results["opacity"] = [] 128 | relevant_results["opacity"].append( 129 | subreturns.get_returns()["alpha"].detach().cpu() 
130 | ) 131 | 132 | if self.smooth_deformations_type in ["jacobian", "nerfies"]: 133 | from utils import get_minibatch_jacobian 134 | 135 | jacobian = get_minibatch_jacobian( 136 | position_offsets, normalized_undeformed_positions 137 | ) # num_points x 3 x 3 138 | if self.smooth_deformations_type == "jacobian_broken": 139 | this_loss = jacobian**2 # if using position_offsets 140 | this_loss = this_loss.view(position_offsets.shape[:-1] + (-1,)) 141 | this_loss = this_loss.mean( 142 | dim=-1 143 | ) # num_rays_in_chunk x num_points_per_ray 144 | eps = 1e-6 145 | this_loss = torch.sqrt(this_loss + eps) 146 | elif self.smooth_deformations_type == "jacobian": 147 | R_times_Rt = torch.matmul( 148 | jacobian, torch.transpose(jacobian, -1, -2) 149 | ) 150 | this_loss = torch.abs( 151 | R_times_Rt - torch.eye(3, device=jacobian.device).view(-1, 3, 3) 152 | ) 153 | this_loss = this_loss.view(position_offsets.shape[:-1] + (-1,)) 154 | this_loss = this_loss.mean( 155 | dim=-1 156 | ) # num_rays_in_chunk x num_points_per_ray 157 | else: 158 | singular_values = torch.linalg.svdvals(jacobian) # num_points x 3 159 | eps = 1e-6 160 | stable_singular_values = torch.maximum( 161 | singular_values, eps * torch.ones_like(singular_values) 162 | ) 163 | log_singular_values = torch.log(stable_singular_values) 164 | this_loss = torch.mean( 165 | log_singular_values**2, dim=-1 166 | ) # num_points 167 | 168 | else: # divergence 169 | from utils import divergence_exact, divergence_approx 170 | 171 | exact = False 172 | divergence_fn = divergence_exact if exact else divergence_approx 173 | divergence = divergence_fn( 174 | inputs=normalized_undeformed_positions, outputs=position_offsets 175 | ) 176 | this_loss = torch.abs( 177 | divergence 178 | ) # num_rays_in_chunk x num_points_per_ray 179 | 180 | weigh_by_opacity = False 181 | if weigh_by_opacity: 182 | opacity = subreturns.get_returns()["alpha"] 183 | max_windowed = True 184 | if max_windowed: 185 | window_fraction = 0.01 186 | points_per_ray = opacity.shape[1] 187 | kernel_size = max(1, int(window_fraction * points_per_ray)) 188 | if kernel_size % 2 == 0: 189 | kernel_size += 1 # needed for integer padding 190 | padding = (kernel_size - 1) // 2 191 | opacity = torch.nn.functional.max_pool1d( 192 | opacity, 193 | kernel_size=kernel_size, 194 | stride=1, 195 | padding=padding, 196 | dilation=1, 197 | ceil_mode=True, 198 | return_indices=False, 199 | ) 200 | this_loss = opacity.detach() * this_loss 201 | 202 | only_for_large_deformations = 0.0001 203 | if only_for_large_deformations is not None: 204 | offset_magnitude = torch.linalg.norm(position_offsets.detach(), dim=-1) 205 | mode = "sigmoid" 206 | if mode == "binary": 207 | offset_weights = offset_magnitude > only_for_large_deformations 208 | elif mode == "sigmoid": 209 | offset_weights = torch.sigmoid( 210 | (4.0 * offset_magnitude / only_for_large_deformations) - 2.0 211 | ) 212 | this_loss = this_loss * offset_weights.detach() 213 | 214 | if "smooth_deformations_loss" not in relevant_results: 215 | relevant_results["smooth_deformations_loss"] = [] 216 | relevant_results["smooth_deformations_loss"].append( 217 | this_loss.detach().cpu() 218 | ) 219 | 220 | wrapper_to_free_memory_quickly(relevant_results, mode, subreturns) 221 | 222 | del subreturns 223 | 224 | for key in relevant_results.keys(): 225 | relevant_results[key] = torch.cat(relevant_results[key], dim=0) 226 | 227 | all_relevant_results.append(relevant_results) 228 | 229 | if only_first_frame: 230 | break 231 | 232 | return all_relevant_results 233 
| 234 | def render_and_store(self, state_loader_saver, output_name, only_first_frame=True, **kwargs): 235 | 236 | if self.rank != 0: 237 | return 238 | 239 | all_relevant_results = self.render(only_first_frame=only_first_frame, **kwargs) 240 | 241 | output_folder = os.path.join(state_loader_saver.get_results_folder(), "3_visualization") 242 | state_loader_saver.create_folder(output_folder) 243 | 244 | for counter, relevant_results in enumerate(all_relevant_results): 245 | 246 | loss = relevant_results["smooth_deformations_loss"] 247 | opacity = relevant_results["opacity"] 248 | undeformed_positions = relevant_results["normalized_undeformed_positions"] 249 | offsets = relevant_results["coarse_position_offsets"] 250 | 251 | # flatten 252 | loss = loss.reshape(-1) 253 | opacity = opacity.reshape(-1) 254 | undeformed_positions = undeformed_positions.reshape(-1, 3) 255 | offsets = offsets.reshape(-1, 3) 256 | 257 | # opacity filtering 258 | use_opacity_thresholding = False 259 | opacity_threshold = 0.01 260 | if use_opacity_thresholding: 261 | mask = opacity > opacity_threshold 262 | loss = loss[mask] 263 | opacity = opacity[mask] 264 | undeformed_positions = undeformed_positions[mask, :] 265 | offsets = offsets[mask, :] 266 | 267 | # weigh the loss 268 | weigh_by_loss_weight = False 269 | if weigh_by_loss_weight: 270 | loss_weight = self.weight_coarse_smooth_deformations 271 | loss = loss_weight * loss 272 | 273 | # random subsampling 274 | only_keep_n_points = 10000 275 | if only_keep_n_points is not None: 276 | random_indices = torch.randperm(n=loss.shape[0])[:only_keep_n_points] 277 | loss = loss[random_indices] 278 | opacity = opacity[random_indices] 279 | undeformed_positions = undeformed_positions[random_indices, :] 280 | offsets = offsets[random_indices, :] 281 | 282 | # convert loss to color 283 | max_loss = 10.0 284 | loss[loss > max_loss] = max_loss 285 | loss = loss / max_loss 286 | loss = (255 * loss.numpy()).astype(np.uint8) 287 | from matplotlib import cm 288 | 289 | colors = cm.jet(loss)[:, :3] # num_points x 3 290 | 291 | # mesh generation 292 | mesh_lines = [] 293 | for (x, y, z), (dx, dy, dz), (r, g, b) in zip( 294 | undeformed_positions.numpy(), offsets.numpy(), colors 295 | ): 296 | mesh_lines.append( 297 | "v " 298 | + str(x) 299 | + " " 300 | + str(y) 301 | + " " 302 | + str(z) 303 | + " " 304 | + str(r) 305 | + " " 306 | + str(g) 307 | + " " 308 | + str(b) 309 | ) 310 | mesh_lines.append( 311 | "v " 312 | + str(x + 0.000001) 313 | + " " 314 | + str(y) 315 | + " " 316 | + str(z) 317 | + " " 318 | + str(r) 319 | + " " 320 | + str(g) 321 | + " " 322 | + str(b) 323 | ) 324 | mesh_lines.append( 325 | "v " 326 | + str(x + dx) 327 | + " " 328 | + str(y + dy) 329 | + " " 330 | + str(z + dz) 331 | + " " 332 | + str(r) 333 | + " " 334 | + str(g) 335 | + " " 336 | + str(b) 337 | ) 338 | 339 | for i in range(offsets.shape[0]): 340 | # faces are 1-indexed 341 | mesh_lines.append( 342 | "f " + str(1 + 3 * i) + " " + str(1 + 3 * i + 1) + " " + str(1 + 3 * i + 2) 343 | ) 344 | 345 | mesh_lines = "\n".join(mesh_lines) 346 | 347 | # write out mesh 348 | with open( 349 | os.path.join(output_folder, output_name + "_" + str(counter).zfill(6)) + ".obj", "w" 350 | ) as mesh_file: 351 | mesh_file.write(mesh_lines) 352 | --------------------------------------------------------------------------------
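For reference, render_and_store above encodes each visualized sample as a degenerate triangle in the exported .obj: two nearly coincident vertices at the undeformed position and a third vertex displaced by the predicted offset, all carrying the loss color, so a standard mesh viewer shows the offsets as colored needles. Below is a minimal sketch of reading such a file back into arrays; the helper name and file path are illustrative, only the three-vertices-per-point layout is taken from the code above.

import numpy as np

def load_offset_obj(path):
    # Parse an .obj written by Visualizer.render_and_store. Every three consecutive
    # "v x y z r g b" lines describe one sample: the undeformed position, the same
    # position shifted by a tiny epsilon, and the position plus the predicted offset.
    positions, colors = [], []
    with open(path) as obj_file:
        for line in obj_file:
            if line.startswith("v "):
                x, y, z, r, g, b = map(float, line.split()[1:7])
                positions.append((x, y, z))
                colors.append((r, g, b))
    positions = np.asarray(positions).reshape(-1, 3, 3)  # num_points x 3 vertices x xyz
    colors = np.asarray(colors).reshape(-1, 3, 3)
    undeformed = positions[:, 0]                  # first vertex of each sliver triangle
    offsets = positions[:, 2] - positions[:, 0]   # third vertex minus first
    loss_colors = colors[:, 0]                    # identical for all three vertices
    return undeformed, offsets, loss_colors

# Hypothetical usage:
# undeformed, offsets, loss_colors = load_offset_obj("3_visualization/out_000000.obj")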