├── .DS_Store ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs └── figs │ ├── framework.png │ ├── results.gif │ └── snapshot_meshlab.png ├── examples ├── data │ ├── bunny01.blend │ ├── nerf_render_ori.py │ └── render_example.bat ├── render │ ├── render_jade.sh │ └── render_wineholder.sh ├── train │ ├── train_family.sh │ ├── train_jade.sh │ ├── train_wineholder.sh │ └── train_wineholder_with_slurm.sh └── valid │ └── valid_wineholder.sh ├── extract.py ├── fairnr ├── __init__.py ├── clib │ ├── __init__.py │ ├── include │ │ ├── cuda_utils.h │ │ ├── cutil_math.h │ │ ├── intersect.h │ │ ├── octree.h │ │ ├── sample.h │ │ └── utils.h │ └── src │ │ ├── binding.cpp │ │ ├── intersect.cpp │ │ ├── intersect_gpu.cu │ │ ├── octree.cpp │ │ ├── sample.cpp │ │ └── sample_gpu.cu ├── criterions │ ├── __init__.py │ ├── perceptual_loss.py │ ├── rendering_loss.py │ └── utils.py ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── geometry.py │ ├── shape_dataset.py │ └── trajectory.py ├── models │ ├── __init__.py │ ├── fairnr_model.py │ ├── multi_nsvf.py │ ├── nerf.py │ ├── nmf.py │ ├── nsvf.py │ └── nsvf_bg.py ├── modules │ ├── __init__.py │ ├── encoder.py │ ├── field.py │ ├── hyper.py │ ├── implicit.py │ ├── module_utils.py │ ├── reader.py │ └── renderer.py ├── options.py ├── renderer.py └── tasks │ ├── __init__.py │ └── neural_rendering.py ├── fairnr_cli ├── __init__.py ├── extract.py ├── launch_slurm.py ├── render.py ├── render_multigpu.py ├── train.py └── validate.py ├── render.py ├── requirements.txt ├── setup.py ├── train.py └── validate.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NSVF/ecd08088cf61498d7a8ba155cf5383d87cf8ff4d/.DS_Store -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? 
Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | *.pkl 4 | .DS_Store/ 5 | __pycache__/ 6 | *.pt 7 | .idea/ 8 | .vscode/ 9 | build/ 10 | *.egg-info/ 11 | images/ 12 | *.blend1 13 | .vscode/ 14 | .history/ 15 | tools/ 16 | fb_sweep/ 17 | fb_scripts/ 18 | code_backup/ 19 | results/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
Representation of a project may be further defined and clarified by project maintainers. 33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Neural Sparse Voxel Fields (NSVF) 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## License 26 | By contributing to Neural Sparse Voxel Fields, 27 | you agree that your contributions will be licensed under the LICENSE file in 28 | the root directory of this -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Sparse Voxel Fields (NSVF) 2 | 3 | ### [Project Page](https://lingjie0206.github.io/papers/NSVF/) | [Video](https://www.youtube.com/watch?v=RFqPwH7QFEI) | [Paper](https://arxiv.org/abs/2007.11571) | [Data](#dataset) 4 | 5 | 6 | 7 | Photo-realistic free-viewpoint rendering of real-world scenes using classical computer graphics techniques is a challenging problem because it requires the difficult step of capturing detailed appearance and geometry models. 8 | Neural rendering is an emerging field that employs deep neural networks to implicitly learn scene representations encapsulating both geometry and appearance from 2D observations with or without a coarse geometry. 9 | However, existing approaches in this field often show blurry renderings or suffer from slow rendering process. We propose [Neural Sparse Voxel Fields (NSVF)](https://arxiv.org/abs/2007.11571), a new neural scene representation for fast and high-quality free-viewpoint rendering. 10 | 11 | Here is the official repo for the paper: 12 | 13 | * [Neural Sparse Voxel Fields (Liu et al., 2020, NeurIPS 2020 Spotlight)](https://arxiv.org/abs/2007.11571). 14 | 15 | We also provide our unofficial implementation for: 16 | * [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis (Mildenhall et al., 2020)](https://arxiv.org/pdf/2003.08934.pdf). 17 | 18 | 19 | ## Table of contents 20 | ----- 21 | * [Installation](#requirements-and-installation) 22 | * [Dataset](#dataset) 23 | * [Usage](#train-a-new-model) 24 | + [Training](#train-a-new-model) 25 | + [Evaluation](#evaluation) 26 | + [Free-view Rendering](#free-viewpoint-rendering) 27 | + [Extracting Geometry](#extract-the-geometry) 28 | * [License](#license) 29 | * [Citation](#citation) 30 | ------ 31 | 32 | ## Requirements and Installation 33 | 34 | This code is implemented in PyTorch using [fairseq framework](https://github.com/pytorch/fairseq). 35 | 36 | The code has been tested on the following system: 37 | 38 | * Python 3.7 39 | * PyTorch 1.4.0 40 | * [Nvidia apex library](https://github.com/NVIDIA/apex) (optional) 41 | * Nvidia GPU (Tesla V100 32GB) CUDA 10.1 42 | 43 | Only learning and rendering on GPUs are supported. 44 | 45 | To install, first clone this repo and install all dependencies: 46 | 47 | ```bash 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | Then, run 52 | 53 | ```bash 54 | pip install --editable ./ 55 | ``` 56 | 57 | Or if you want to install the code locally, run: 58 | 59 | ```bash 60 | python setup.py build_ext --inplace 61 | ``` 62 | 63 | ## Dataset 64 | 65 | You can download the pre-processed synthetic and real datasets used in our paper. 66 | Please also cite the original papers if you use any of them in your work. 
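Each dataset in the table below is distributed as a single ``.zip`` archive. As a minimal sketch (assuming ``wget`` and ``unzip`` are available; the Synthetic-NSVF link below is used as the example, and the other archives work the same way), a dataset can be fetched and unpacked like this:

```bash
# Hypothetical target directory; any location works as long as ${DATASET} later
# points at an unpacked scene folder (e.g. something like Synthetic_NSVF/Wineholder;
# the exact layout inside each archive may differ).
mkdir -p datasets && cd datasets
wget https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip
unzip Synthetic_NSVF.zip
```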
67 | 68 | Dataset | Download Link | Notes on Dataset Split 69 | ---|---|--- 70 | Synthetic-NSVF | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip) | 0_\* (training) 1_\* (validation) 2_\* (testing) 71 | [Synthetic-NeRF](https://github.com/bmild/nerf) | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NeRF.zip) | 0_\* (training) 1_\* (validation) 2_\* (testing) 72 | [BlendedMVS](https://github.com/YoYo000/BlendedMVS) | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/BlendedMVS.zip) | 0_\* (training) 1_\* (testing) 73 | [Tanks&Temples](https://www.tanksandtemples.org/) | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) | 0_\* (training) 1_\* (testing) 74 | 75 | ### Prepare your own dataset 76 | 77 | To prepare a new dataset of a single scene for training and testing, please follow the data structure: 78 | 79 | ```bash 80 | 81 | |-- bbox.txt # bounding-box file 82 | |-- intrinsics.txt # 4x4 camera intrinsics 83 | |-- rgb 84 | |-- 0.png # target image for each view 85 | |-- 1.png 86 | ... 87 | |-- pose 88 | |-- 0.txt # camera pose for each view (4x4 matrices) 89 | |-- 1.txt 90 | ... 91 | [optional] 92 | |-- test_traj.txt # camera pose for free-view rendering demonstration (4N x 4) 93 | ``` 94 | 95 | where the ``bbox.txt`` file contains a line describing the initial bounding box and voxel size: 96 | 97 | ```bash 98 | x_min y_min z_min x_max y_max z_max initial_voxel_size 99 | ``` 100 | 101 | Note that the file names of target images and those of the corresponding camera pose files are not required to be exactly the same. However, the orders of these two kinds of files (sorted by string) must match. The datasets are split with view indices. 102 | For example, "``train (0..100)``, ``valid (100..200)`` and ``test (200..400)``" mean the first 100 views for training, 100-199th views for validation, and 200-399th views for testing. 103 | 104 | ## Train a new model 105 | 106 | Given the dataset of a single scene (``{DATASET}``), we use the following command for training an NSVF model to synthesize novel views at ``800x800`` pixels, with a batch size of ``4`` images per GPU and ``2048`` rays per image. By default, the code will automatically detect all available GPUs. 107 | 108 | In the following example, we use a pre-defined architecture ``nsvf_base`` with specific arguments: 109 | 110 | * By setting ``--no-sampling-at-reader``, the model only samples pixels in the projected image region of sparse voxels for training. 111 | * By default, we set the ray-marching step size to be the ratio ``1/8 (0.125)`` of the voxel size which is typically described in the ``bbox.txt`` file. 112 | * It is optional to turn on ``--use-octree``. It will build a sparse voxel octree to speed-up the ray-voxel intersection especially when the number of voxels is larger than ``10000``. 113 | * By setting ``--pruning-every-steps`` as ``2500``, the model performs self-pruning at every ``2500`` steps. 114 | * By setting ``--half-voxel-size-at`` and ``--reduce-step-size-at`` as ``5000,25000,75000``, the voxel size and step size are halved at ``5k``, ``25k`` and ``75k``, respectively. 115 | 116 | Note that, although above parameter settings are used for most of the experiments in the paper, it is possible to tune these parameters to achieve better quality. Besides the above parameters, other parameters can also use default settings. 
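To make the voxel and step-size schedule above concrete, here is a small shell sketch (not part of the codebase; it assumes ``${DATASET}`` points at a scene folder with a ``bbox.txt`` in the format described above, and that ``awk`` is available). It reads the initial voxel size and prints the ray-marching step size implied by ``--raymarching-stepsize-ratio 0.125``, together with the values after each halving at 5k/25k/75k updates:

```bash
# bbox.txt holds a single line: x_min y_min z_min x_max y_max z_max initial_voxel_size
read XMIN YMIN ZMIN XMAX YMAX ZMAX VOXEL < ${DATASET}/bbox.txt
awk -v v="$VOXEL" 'BEGIN {
    s = 0.125 * v;                                   # initial ray-marching step size
    printf "initial:           voxel %g, step %g\n", v, s;
    printf "after 5k updates:  voxel %g, step %g\n", v/2, s/2;
    printf "after 25k updates: voxel %g, step %g\n", v/4, s/4;
    printf "after 75k updates: voxel %g, step %g\n", v/8, s/8;
}'
```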
117 | 118 | Besides the architecture ``nsvf_base``, you may check other architectures or define your own architectures in the file ``fairnr/models/nsvf.py``. 119 | 120 | ```bash 121 | python -u train.py ${DATASET} \ 122 | --user-dir fairnr \ 123 | --task single_object_rendering \ 124 | --train-views "0..100" --view-resolution "800x800" \ 125 | --max-sentences 1 --view-per-batch 4 --pixel-per-view 2048 \ 126 | --no-preload \ 127 | --sampling-on-mask 1.0 --no-sampling-at-reader \ 128 | --valid-views "100..200" --valid-view-resolution "400x400" \ 129 | --valid-view-per-batch 1 \ 130 | --transparent-background "1.0,1.0,1.0" --background-stop-gradient \ 131 | --arch nsvf_base \ 132 | --initial-boundingbox ${DATASET}/bbox.txt \ 133 | --use-octree \ 134 | --raymarching-stepsize-ratio 0.125 \ 135 | --discrete-regularization \ 136 | --color-weight 128.0 --alpha-weight 1.0 \ 137 | --optimizer "adam" --adam-betas "(0.9, 0.999)" \ 138 | --lr 0.001 --lr-scheduler "polynomial_decay" --total-num-update 150000 \ 139 | --criterion "srn_loss" --clip-norm 0.0 \ 140 | --num-workers 0 \ 141 | --seed 2 \ 142 | --save-interval-updates 500 --max-update 150000 \ 143 | --virtual-epoch-steps 5000 --save-interval 1 \ 144 | --half-voxel-size-at "5000,25000,75000" \ 145 | --reduce-step-size-at "5000,25000,75000" \ 146 | --pruning-every-steps 2500 \ 147 | --keep-interval-updates 5 --keep-last-epochs 5 \ 148 | --log-format simple --log-interval 1 \ 149 | --save-dir ${SAVE} \ 150 | --tensorboard-logdir ${SAVE}/tensorboard \ 151 | | tee -a $SAVE/train.log 152 | ``` 153 | 154 | The checkpoints are saved in ``{SAVE}``. You can launch tensorboard to check training progress: 155 | 156 | ```bash 157 | tensorboard --logdir=${SAVE}/tensorboard --port=10000 158 | ``` 159 | 160 | There are more examples of training scripts to reproduce the results of our paper under [examples](./examples/train/). 161 | 162 | ## Evaluation 163 | 164 | Once the model is trained, the following command is used to evaluate rendering quality on the test views given the ``{MODEL_PATH}``. 165 | 166 | ```bash 167 | python validate.py ${DATASET} \ 168 | --user-dir fairnr \ 169 | --valid-views "200..400" \ 170 | --valid-view-resolution "800x800" \ 171 | --no-preload \ 172 | --task single_object_rendering \ 173 | --max-sentences 1 \ 174 | --valid-view-per-batch 1 \ 175 | --path ${MODEL_PATH} \ 176 | --model-overrides '{"chunk_size":512,"raymarching_tolerance":0.01,"tensorboard_logdir":"","eval_lpips":True}' \ 177 | ``` 178 | 179 | Note that we override the ``raymarching_tolerance`` to ``0.01`` to enable early termination for rendering speed-up. 180 | 181 | ## Free Viewpoint Rendering 182 | 183 | Free-viewpoint rendering can be achieved once a model is trained and a rendering trajectory is specified. For example, the following command is for rendering with a circle trajectory (angular speed 3 degree/frame, 15 frames per GPU). This outputs per-view rendered images and merge the images into a ``.mp4`` video in ``${SAVE}/output`` as follows: 184 | 185 | 186 | 187 | By default, the code can detect all available GPUs. 
188 | 189 | ```bash 190 | python render.py ${DATASET} \ 191 | --user-dir fairnr \ 192 | --task single_object_rendering \ 193 | --path ${MODEL_PATH} \ 194 | --model-overrides '{"chunk_size":512,"raymarching_tolerance":0.01}' \ 195 | --render-beam 1 --render-angular-speed 3 --render-num-frames 15 \ 196 | --render-save-fps 24 \ 197 | --render-resolution "800x800" \ 198 | --render-path-style "circle" \ 199 | --render-path-args "{'radius': 3, 'h': 2, 'axis': 'z', 't0': -2, 'r':-1}" \ 200 | --render-output ${SAVE}/output \ 201 | --render-output-types "color" "depth" "voxel" "normal" --render-combine-output \ 202 | --log-format "simple" 203 | ``` 204 | 205 | Our code also supports rendering for given camera poses. 206 | For instance, the following command is for rendering with the camera poses defined in the 200-399th files under folder ``${DATASET}/pose``: 207 | 208 | ```bash 209 | python render.py ${DATASET} \ 210 | --user-dir fairnr \ 211 | --task single_object_rendering \ 212 | --path ${MODEL_PATH} \ 213 | --model-overrides '{"chunk_size":512,"raymarching_tolerance":0.01}' \ 214 | --render-save-fps 24 \ 215 | --render-resolution "800x800" \ 216 | --render-camera-poses ${DATASET}/pose \ 217 | --render-views "200..400" \ 218 | --render-output ${SAVE}/output \ 219 | --render-output-types "color" "depth" "voxel" "normal" --render-combine-output \ 220 | --log-format "simple" 221 | ``` 222 | 223 | The code also supports rendering with camera poses defined in a ``.txt`` file. Please refer to this [example](./examples/render/render_jade.sh). 224 | 225 | ## Extract the Geometry 226 | 227 | We also support running marching cubes to extract the iso-surfaces as triangle meshes from a trained NSVF model and saved as ``{SAVE}/{NAME}.ply``. 228 | ```bash 229 | python extract.py \ 230 | --user-dir fairnr \ 231 | --path ${MODEL_PATH} \ 232 | --output ${SAVE} \ 233 | --name ${NAME} \ 234 | --format 'mc_mesh' \ 235 | --mc-threshold 0.5 \ 236 | --mc-num-samples-per-halfvoxel 5 237 | ``` 238 | It is also possible to export the learned sparse voxels by setting ``--format 'voxel_mesh'``. 239 | The output ``.ply`` file can be opened with any 3D viewers such as [MeshLab](https://www.meshlab.net/). 240 | 241 | 242 | 243 | ## License 244 | 245 | NSVF is MIT-licensed. 246 | The license applies to the pre-trained models as well. 
247 | 248 | ## Citation 249 | 250 | Please cite as 251 | ```bibtex 252 | @article{liu2020neural, 253 | title={Neural Sparse Voxel Fields}, 254 | author={Liu, Lingjie and Gu, Jiatao and Lin, Kyaw Zaw and Chua, Tat-Seng and Theobalt, Christian}, 255 | journal={NeurIPS}, 256 | year={2020} 257 | } 258 | ``` 259 | -------------------------------------------------------------------------------- /docs/figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NSVF/ecd08088cf61498d7a8ba155cf5383d87cf8ff4d/docs/figs/framework.png -------------------------------------------------------------------------------- /docs/figs/results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NSVF/ecd08088cf61498d7a8ba155cf5383d87cf8ff4d/docs/figs/results.gif -------------------------------------------------------------------------------- /docs/figs/snapshot_meshlab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NSVF/ecd08088cf61498d7a8ba155cf5383d87cf8ff4d/docs/figs/snapshot_meshlab.png -------------------------------------------------------------------------------- /examples/data/bunny01.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NSVF/ecd08088cf61498d7a8ba155cf5383d87cf8ff4d/examples/data/bunny01.blend -------------------------------------------------------------------------------- /examples/data/nerf_render_ori.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import sys, os, argparse 7 | import json 8 | import bpy 9 | import mathutils 10 | from mathutils import Vector 11 | import numpy as np 12 | 13 | np.random.seed(2) # fixed seed 14 | 15 | DEBUG = False 16 | VOXEL_NUMS = 512 17 | VIEWS = 200 18 | RESOLUTION = 800 19 | RESULTS_PATH = 'rgb' 20 | DEPTH_SCALE = 1.4 21 | COLOR_DEPTH = 8 22 | FORMAT = 'PNG' 23 | RANDOM_VIEWS = True 24 | UPPER_VIEWS = True 25 | CIRCLE_FIXED_START = (.3,0,0) 26 | 27 | parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.') 28 | parser.add_argument('output', type=str, help='path where files will be saved') 29 | 30 | argv = sys.argv 31 | argv = argv[argv.index("--") + 1:] 32 | args = parser.parse_args(argv) 33 | 34 | homedir = args.output 35 | fp = bpy.path.abspath(f"{homedir}/{RESULTS_PATH}") 36 | 37 | def listify_matrix(matrix): 38 | matrix_list = [] 39 | for row in matrix: 40 | matrix_list.append(list(row)) 41 | return matrix_list 42 | 43 | if not os.path.exists(fp): 44 | os.makedirs(fp) 45 | if not os.path.exists(os.path.join(homedir, "pose")): 46 | os.mkdir(os.path.join(homedir, "pose")) 47 | 48 | # Data to store in JSON file 49 | out_data = { 50 | 'camera_angle_x': bpy.data.objects['Camera'].data.angle_x, 51 | } 52 | 53 | # Render Optimizations 54 | bpy.context.scene.render.use_persistent_data = True 55 | 56 | 57 | # Set up rendering of depth map. 58 | bpy.context.scene.use_nodes = True 59 | tree = bpy.context.scene.node_tree 60 | links = tree.links 61 | 62 | # Add passes for additionally dumping albedo and normals. 
63 | #bpy.context.scene.view_layers["RenderLayer"].use_pass_normal = True 64 | bpy.context.scene.render.image_settings.file_format = str(FORMAT) 65 | bpy.context.scene.render.image_settings.color_depth = str(COLOR_DEPTH) 66 | 67 | if not DEBUG: 68 | # Create input render layer node. 69 | render_layers = tree.nodes.new('CompositorNodeRLayers') 70 | 71 | depth_file_output = tree.nodes.new(type="CompositorNodeOutputFile") 72 | depth_file_output.label = 'Depth Output' 73 | if FORMAT == 'OPEN_EXR': 74 | links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0]) 75 | else: 76 | # Remap as other types can not represent the full range of depth. 77 | map = tree.nodes.new(type="CompositorNodeMapValue") 78 | # Size is chosen kind of arbitrarily, try out until you're satisfied with resulting depth map. 79 | map.offset = [-0.7] 80 | map.size = [DEPTH_SCALE] 81 | map.use_min = True 82 | map.min = [0] 83 | links.new(render_layers.outputs['Depth'], map.inputs[0]) 84 | 85 | links.new(map.outputs[0], depth_file_output.inputs[0]) 86 | 87 | normal_file_output = tree.nodes.new(type="CompositorNodeOutputFile") 88 | normal_file_output.label = 'Normal Output' 89 | links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0]) 90 | 91 | # Background 92 | bpy.context.scene.render.dither_intensity = 0.0 93 | bpy.context.scene.render.film_transparent = True 94 | 95 | # Create collection for objects not to render with background 96 | objs = [ob for ob in bpy.context.scene.objects if ob.type in ('EMPTY') and 'Empty' in ob.name] 97 | bpy.ops.object.delete({"selected_objects": objs}) 98 | 99 | # bounding box 100 | for obj in bpy.context.scene.objects: 101 | if 'Camera' not in obj.name: 102 | bbox = [obj.matrix_world @ Vector(corner) for corner in obj.bound_box] 103 | bbox = [min([bb[i] for bb in bbox]) for i in range(3)] + \ 104 | [max([bb[i] for bb in bbox]) for i in range(3)] 105 | voxel_size = ((bbox[3]-bbox[0]) * (bbox[4]-bbox[1]) * (bbox[5]-bbox[2]) / VOXEL_NUMS) ** (1/3) 106 | print(" ".join(['{:.5f}'.format(f) for f in bbox + [voxel_size]]), 107 | file=open(os.path.join(homedir, 'bbox.txt'), 'w')) 108 | 109 | def parent_obj_to_camera(b_camera): 110 | origin = (0, 0, 0) 111 | b_empty = bpy.data.objects.new("Empty", None) 112 | b_empty.location = origin 113 | b_camera.parent = b_empty # setup parenting 114 | 115 | scn = bpy.context.scene 116 | scn.collection.objects.link(b_empty) 117 | bpy.context.view_layer.objects.active = b_empty 118 | # scn.objects.active = b_empty 119 | return b_empty 120 | 121 | 122 | scene = bpy.context.scene 123 | scene.render.resolution_x = RESOLUTION 124 | scene.render.resolution_y = RESOLUTION 125 | scene.render.resolution_percentage = 100 126 | 127 | cam = scene.objects['Camera'] 128 | cam.location = (4, -4, 4) 129 | cam_constraint = cam.constraints.new(type='TRACK_TO') 130 | cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' 131 | cam_constraint.up_axis = 'UP_Y' 132 | b_empty = parent_obj_to_camera(cam) 133 | cam_constraint.target = b_empty 134 | 135 | scene.render.image_settings.file_format = 'PNG' # set output format to .png 136 | 137 | from math import radians 138 | 139 | stepsize = 360.0 / VIEWS 140 | rotation_mode = 'XYZ' 141 | 142 | if not DEBUG: 143 | for output_node in [depth_file_output, normal_file_output]: 144 | output_node.base_path = '' 145 | 146 | out_data['frames'] = [] 147 | 148 | if not RANDOM_VIEWS: 149 | b_empty.rotation_euler = CIRCLE_FIXED_START 150 | 151 | for i in range(0, VIEWS): 152 | if DEBUG: 153 | i = np.random.randint(0,VIEWS) 154 | 
b_empty.rotation_euler[2] += radians(stepsize*i) 155 | if RANDOM_VIEWS: 156 | scene.render.filepath = os.path.join(fp, '{:04d}'.format(i)) 157 | if UPPER_VIEWS: 158 | rot = np.random.uniform(0, 1, size=3) * (1,0,2*np.pi) 159 | rot[0] = np.abs(np.arccos(1 - 2 * rot[0]) - np.pi/2) 160 | b_empty.rotation_euler = rot 161 | else: 162 | b_empty.rotation_euler = np.random.uniform(0, 2*np.pi, size=3) 163 | else: 164 | print("Rotation {}, {}".format((stepsize * i), radians(stepsize * i))) 165 | scene.render.filepath = os.path.join(fp, '{:04d}'.format(i)) 166 | 167 | # depth_file_output.file_slots[0].path = scene.render.filepath + "_depth_" 168 | # normal_file_output.file_slots[0].path = scene.render.filepath + "_normal_" 169 | print('BEFORE RENDER') 170 | if DEBUG: 171 | break 172 | else: 173 | bpy.ops.render.render(write_still=True) # render still 174 | print('AFTER RENDER') 175 | 176 | frame_data = { 177 | 'file_path': scene.render.filepath, 178 | 'rotation': radians(stepsize), 179 | 'transform_matrix': listify_matrix(cam.matrix_world) 180 | } 181 | with open(os.path.join(homedir, "pose", '{:04d}.txt'.format(i)), 'w') as fo: 182 | for ii, pose in enumerate(frame_data['transform_matrix']): 183 | print(" ".join([str(-p) if (((j == 2) | (j == 1)) and (ii < 3)) else str(p) 184 | for j, p in enumerate(pose)]), 185 | file=fo) 186 | out_data['frames'].append(frame_data) 187 | 188 | if RANDOM_VIEWS: 189 | if UPPER_VIEWS: 190 | rot = np.random.uniform(0, 1, size=3) * (1,0,2*np.pi) 191 | rot[0] = np.abs(np.arccos(1 - 2 * rot[0]) - np.pi/2) 192 | b_empty.rotation_euler = rot 193 | else: 194 | b_empty.rotation_euler = np.random.uniform(0, 2*np.pi, size=3) 195 | else: 196 | b_empty.rotation_euler[2] += radians(stepsize) 197 | 198 | if not DEBUG: 199 | with open(os.path.join(homedir, 'transforms.json'), 'w') as out_file: 200 | json.dump(out_data, out_file, indent=4) 201 | 202 | 203 | # save camera data 204 | H, W = RESOLUTION, RESOLUTION 205 | f = .5 * W /np.tan(.5 * float(out_data['camera_angle_x'])) 206 | cx = cy = W // 2 207 | 208 | # write intrinsics 209 | with open(os.path.join(homedir, 'intrinsics.txt'), 'w') as fi: 210 | print("{} {} {} 0.".format(f, cx, cy), file=fi) 211 | print("0. 0. 0.", file=fi) 212 | print("0.", file=fi) 213 | print("1.", file=fi) 214 | print("{} {}".format(H, W), file=fi) -------------------------------------------------------------------------------- /examples/data/render_example.bat: -------------------------------------------------------------------------------- 1 | :: Copyright (c) Facebook, Inc. and its affiliates. 2 | :: 3 | :: This source code is licensed under the MIT license found in the 4 | :: LICENSE file in the root directory of this source tree. 5 | 6 | set BLENDER="C:\\Users\\jgu\\Downloads\\bunny01.blend" 7 | set codepath="\\wsl$\\Ubuntu\\home\\jgu\\work\\NSVF" 8 | set OUTPUT="C:\\Users\\jgu\\Downloads\\bunny" 9 | 10 | blender --background %BLENDER% --python %codepath%\examples\data\nerf_render_ori.py -- %OUTPUT% 11 | 12 | pause -------------------------------------------------------------------------------- /examples/render/render_jade.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # just for debugging 7 | DATA="Jade" 8 | RES="576x768" 9 | ARCH="nsvf_base" 10 | SUFFIX="v1" 11 | DATASET=/private/home/jgu/data/shapenet/release/BlendedMVS/${DATA} 12 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 13 | MODEL=$ARCH$SUFFIX 14 | MODEL_PATH=$SAVE/$MODEL/checkpoint_last.pt 15 | 16 | # additional rendering args 17 | MODELTEMP='{"chunk_size":%d,"raymarching_tolerance":%.3f,"use_octree":True}' 18 | MODELARGS=$(printf "$MODELTEMP" 256 0.0) 19 | 20 | # rendering with pre-defined testing trajectory 21 | python render.py ${DATASET} \ 22 | --user-dir fairnr \ 23 | --task single_object_rendering \ 24 | --path ${MODEL_PATH} \ 25 | --render-beam 1 \ 26 | --render-save-fps 24 \ 27 | --render-camera-poses $DATASET/test_traj.txt \ 28 | --model-overrides $MODELARGS \ 29 | --render-resolution $RES \ 30 | --render-output ${SAVE}/$ARCH/output \ 31 | --render-output-types "color" "depth" "voxel" "normal" \ 32 | --render-combine-output --log-format "simple" -------------------------------------------------------------------------------- /examples/render/render_wineholder.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # just for debugging 7 | DATA="Wineholder" 8 | RES="800x800" 9 | ARCH="nsvf_base" 10 | SUFFIX="v1" 11 | DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} 12 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 13 | MODEL=$ARCH$SUFFIX 14 | MODEL_PATH=$SAVE/$MODEL/checkpoint_last.pt 15 | 16 | # CUDA_VISIBLE_DEVICES=0 \ 17 | python render.py ${DATASET} \ 18 | --user-dir fairnr \ 19 | --task single_object_rendering \ 20 | --path ${MODEL_PATH} \ 21 | --render-beam 1 \ 22 | --render-save-fps 24 \ 23 | --render-camera-poses ${DATASET}/pose \ 24 | --render-views "200..400" \ 25 | --model-overrides '{"chunk_size":256,"raymarching_tolerance":0.01}' \ 26 | --render-resolution $RES \ 27 | --render-output ${SAVE}/output \ 28 | --render-output-types "color" "depth" "voxel" "normal" \ 29 | --render-combine-output --log-format "simple" -------------------------------------------------------------------------------- /examples/train/train_family.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
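# NOTE: unlike the other train_*.sh examples, this script forwards ${SLURM_ARGS}
# via --slurm-args below, so export SLURM_ARGS beforehand (see
# examples/train/train_wineholder_with_slurm.sh for an example value).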
5 | 6 | # just for debugging 7 | DATA="Family" 8 | RES="1080x1920" 9 | VALIDRES="540x960" # the original size maybe too slow for evaluation 10 | # we can optionally half the image size only for validation 11 | ARCH="nsvf_base" 12 | SUFFIX="v1" 13 | DATASET=/private/home/jgu/data/shapenet/release/TanksAndTemple/${DATA} 14 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 15 | MODEL=$ARCH$SUFFIX 16 | mkdir -p $SAVE/$MODEL 17 | 18 | # start training locally 19 | python train.py ${DATASET} \ 20 | --slurm-args ${SLURM_ARGS//[[:space:]]/} \ 21 | --user-dir fairnr \ 22 | --task single_object_rendering \ 23 | --train-views "0..133" \ 24 | --view-resolution $RES \ 25 | --max-sentences 1 \ 26 | --view-per-batch 2 \ 27 | --pixel-per-view 2048 \ 28 | --valid-chunk-size 128 \ 29 | --no-preload\ 30 | --sampling-on-mask 1.0 --no-sampling-at-reader \ 31 | --valid-view-resolution $VALIDRES \ 32 | --valid-views "133..152" \ 33 | --valid-view-per-batch 1 \ 34 | --transparent-background "1.0,1.0,1.0" \ 35 | --background-stop-gradient \ 36 | --arch $ARCH \ 37 | --initial-boundingbox ${DATASET}/bbox.txt \ 38 | --raymarching-stepsize-ratio 0.125 \ 39 | --discrete-regularization \ 40 | --color-weight 128.0 \ 41 | --alpha-weight 1.0 \ 42 | --optimizer "adam" \ 43 | --adam-betas "(0.9, 0.999)" \ 44 | --lr-scheduler "polynomial_decay" \ 45 | --total-num-update 150000 \ 46 | --lr 0.001 \ 47 | --clip-norm 0.0 \ 48 | --criterion "srn_loss" \ 49 | --num-workers 0 \ 50 | --seed 2 \ 51 | --save-interval-updates 500 --max-update 150000 \ 52 | --virtual-epoch-steps 5000 --save-interval 1 \ 53 | --half-voxel-size-at "5000,25000,75000" \ 54 | --reduce-step-size-at "5000,25000,75000" \ 55 | --pruning-every-steps 2500 \ 56 | --keep-interval-updates 5 \ 57 | --log-format simple --log-interval 1 \ 58 | --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ 59 | --save-dir ${SAVE}/${MODEL} 60 | -------------------------------------------------------------------------------- /examples/train/train_jade.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # just for debugging 7 | DATA="Jade" 8 | RES="576x768" 9 | ARCH="nsvf_base" 10 | SUFFIX="v1" 11 | DATASET=/private/home/jgu/data/shapenet/release/BlendedMVS/${DATA} 12 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 13 | MODEL=$ARCH$SUFFIX 14 | mkdir -p $SAVE/$MODEL 15 | 16 | # start training locally 17 | python train.py ${DATASET} \ 18 | --user-dir fairnr \ 19 | --task single_object_rendering \ 20 | --train-views "0..50" \ 21 | --view-resolution $RES \ 22 | --max-sentences 1 \ 23 | --view-per-batch 4 \ 24 | --pixel-per-view 2048 \ 25 | --no-preload \ 26 | --sampling-on-mask 1.0 --no-sampling-at-reader \ 27 | --valid-view-resolution $RES \ 28 | --valid-views "50..58" \ 29 | --valid-view-per-batch 1 \ 30 | --transparent-background "0.0,0.0,0.0" \ 31 | --background-stop-gradient \ 32 | --arch $ARCH \ 33 | --initial-boundingbox ${DATASET}/bbox.txt \ 34 | --raymarching-stepsize-ratio 0.125 \ 35 | --use-octree \ 36 | --discrete-regularization \ 37 | --color-weight 128.0 \ 38 | --alpha-weight 1.0 \ 39 | --optimizer "adam" \ 40 | --adam-betas "(0.9, 0.999)" \ 41 | --lr-scheduler "polynomial_decay" \ 42 | --total-num-update 150000 \ 43 | --lr 0.001 \ 44 | --clip-norm 0.0 \ 45 | --criterion "srn_loss" \ 46 | --num-workers 0 \ 47 | --seed 2 \ 48 | --save-interval-updates 500 --max-update 100000 \ 49 | --virtual-epoch-steps 5000 --save-interval 1 \ 50 | --half-voxel-size-at "5000,25000" \ 51 | --reduce-step-size-at "5000,25000" \ 52 | --pruning-every-steps 2500 \ 53 | --keep-interval-updates 5 \ 54 | --log-format simple --log-interval 1 \ 55 | --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ 56 | --save-dir ${SAVE}/${MODEL} 57 | -------------------------------------------------------------------------------- /examples/train/train_wineholder.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # just for debugging 7 | DATA="Wineholder" 8 | RES="800x800" 9 | ARCH="nsvf_base" 10 | SUFFIX="v1" 11 | DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} 12 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 13 | MODEL=$ARCH$SUFFIX 14 | mkdir -p $SAVE/$MODEL 15 | 16 | # start training locally 17 | python train.py ${DATASET} \ 18 | --user-dir fairnr \ 19 | --task single_object_rendering \ 20 | --train-views "0..100" \ 21 | --view-resolution $RES \ 22 | --max-sentences 1 \ 23 | --view-per-batch 2 \ 24 | --pixel-per-view 2048 \ 25 | --no-preload \ 26 | --sampling-on-mask 1.0 --no-sampling-at-reader \ 27 | --valid-view-resolution $RES \ 28 | --valid-views "100..200" \ 29 | --valid-view-per-batch 1 \ 30 | --transparent-background "1.0,1.0,1.0" \ 31 | --background-stop-gradient \ 32 | --arch $ARCH \ 33 | --initial-boundingbox ${DATASET}/bbox.txt \ 34 | --raymarching-stepsize-ratio 0.125 \ 35 | --use-octree \ 36 | --discrete-regularization \ 37 | --color-weight 128.0 \ 38 | --alpha-weight 1.0 \ 39 | --optimizer "adam" \ 40 | --adam-betas "(0.9, 0.999)" \ 41 | --lr-scheduler "polynomial_decay" \ 42 | --total-num-update 150000 \ 43 | --lr 0.001 \ 44 | --clip-norm 0.0 \ 45 | --criterion "srn_loss" \ 46 | --num-workers 0 \ 47 | --seed 2 \ 48 | --save-interval-updates 500 --max-update 150000 \ 49 | --virtual-epoch-steps 5000 --save-interval 1 \ 50 | --half-voxel-size-at "5000,25000,75000" \ 51 | --reduce-step-size-at "5000,25000,75000" \ 52 | --pruning-every-steps 2500 \ 53 | --keep-interval-updates 5 \ 54 | --log-format simple --log-interval 1 \ 55 | --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ 56 | --save-dir ${SAVE}/${MODEL} 57 | -------------------------------------------------------------------------------- /examples/train/train_wineholder_with_slurm.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # just for debugging 7 | DATA="Wineholder" 8 | RES="800x800" 9 | ARCH="nsvf_base" 10 | SUFFIX="v1" 11 | DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} 12 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 13 | MODEL=$ARCH$SUFFIX 14 | mkdir -p $SAVE/$MODEL 15 | 16 | # By defining the following environment variables 17 | # The code will automatically detect it and trying to submit the code in slurm-based clusters 18 | # We don't need to change the main body of the training code. 
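# (For comparison, examples/train/train_family.sh forwards this environment
#  variable explicitly on the command line via --slurm-args.)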
19 | export SLURM_ARGS="""{ 20 | 'job-name': '${DATA}-${MODEL}', 21 | 'partition': 'priority', 22 | 'comment': 'NeurIPS2020 open-source', 23 | 'nodes': 1, 24 | 'gpus': 8, 25 | 'output': '$SAVE/$MODEL/train.out', 26 | 'error': '$SAVE/$MODEL/train.stderr.%j', 27 | 'constraint': 'volta32gb', 28 | 'local': False} 29 | """ 30 | 31 | # start training based on SLURM_ARGS 32 | python train.py ${DATASET} \ 33 | --user-dir fairnr \ 34 | --task single_object_rendering \ 35 | --train-views "0..100" \ 36 | --view-resolution $RES \ 37 | --max-sentences 1 \ 38 | --view-per-batch 2 \ 39 | --pixel-per-view 2048 \ 40 | --no-preload \ 41 | --sampling-on-mask 1.0 --no-sampling-at-reader \ 42 | --valid-view-resolution $RES \ 43 | --valid-views "100..200" \ 44 | --valid-view-per-batch 1 \ 45 | --transparent-background "1.0,1.0,1.0" \ 46 | --background-stop-gradient \ 47 | --arch $ARCH \ 48 | --initial-boundingbox ${DATASET}/bbox.txt \ 49 | --raymarching-stepsize-ratio 0.125 \ 50 | --use-octree \ 51 | --discrete-regularization \ 52 | --color-weight 128.0 \ 53 | --alpha-weight 1.0 \ 54 | --optimizer "adam" \ 55 | --adam-betas "(0.9, 0.999)" \ 56 | --lr-scheduler "polynomial_decay" \ 57 | --total-num-update 150000 \ 58 | --lr 0.001 \ 59 | --clip-norm 0.0 \ 60 | --criterion "srn_loss" \ 61 | --num-workers 0 \ 62 | --seed 2 \ 63 | --save-interval-updates 500 --max-update 150000 \ 64 | --virtual-epoch-steps 5000 --save-interval 1 \ 65 | --half-voxel-size-at "5000,25000,75000" \ 66 | --reduce-step-size-at "5000,25000,75000" \ 67 | --pruning-every-steps 2500 \ 68 | --keep-interval-updates 5 \ 69 | --log-format simple --log-interval 1 \ 70 | --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ 71 | --save-dir ${SAVE}/${MODEL} 72 | -------------------------------------------------------------------------------- /examples/valid/valid_wineholder.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # just for debugging 7 | DATA="Wineholder" 8 | RES="800x800" 9 | ARCH="nsvf_base" 10 | SUFFIX="v1" 11 | DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} 12 | SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA 13 | MODEL=$ARCH$SUFFIX 14 | MODEL_PATH=$SAVE/$MODEL/checkpoint_last.pt 15 | 16 | # start validating a trained model with target images. 17 | # CUDA_VISIBLE_DEVICES=0 \ 18 | python validate.py ${DATASET} \ 19 | --user-dir fairnr \ 20 | --valid-views "200..400" \ 21 | --valid-view-resolution "800x800" \ 22 | --no-preload \ 23 | --task single_object_rendering \ 24 | --max-sentences 1 \ 25 | --valid-view-per-batch 1 \ 26 | --path ${MODEL_PATH} \ 27 | --model-overrides '{"chunk_size":1024,"raymarching_tolerance":0.01,"tensorboard_logdir":"","eval_lpips":True}' \ -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairnr_cli.extract import cli_main 8 | 9 | 10 | if __name__ == '__main__': 11 | cli_main() 12 | -------------------------------------------------------------------------------- /fairnr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | class ResetTrainerException(Exception): 8 | pass 9 | 10 | 11 | from . import data, tasks, models, modules, criterions 12 | -------------------------------------------------------------------------------- /fairnr/clib/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /fairnr/clib/include/intersect.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | std::tuple ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 11 | const float radius, const int n_max); 12 | std::tuple aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 13 | const float voxelsize, const int n_max); 14 | std::tuple svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children, 15 | const float voxelsize, const int n_max); 16 | std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, 17 | const float cagesize, const float blur, const int n_max); 18 | -------------------------------------------------------------------------------- /fairnr/clib/include/octree.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
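// Reading from src/octree.cpp: build_octree returns a pair (centers, children),
// where centers is an (n_nodes, 3) int tensor holding every node's integer center,
// and children is an (n_nodes, 9) int tensor whose columns 0-7 store the child
// node indices (-1 if absent) and column 8 stores the node extent (2^(depth+1)).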
5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | std::tuple build_octree(at::Tensor center, at::Tensor points, int depth); -------------------------------------------------------------------------------- /fairnr/clib/include/sample.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | 11 | std::tuple uniform_ray_sampling( 12 | at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, 13 | const float step_size, const int max_steps); 14 | std::tuple inverse_cdf_sampling( 15 | at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, 16 | at::Tensor probs, at::Tensor steps, float fixed_step_size); -------------------------------------------------------------------------------- /fairnr/clib/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /fairnr/clib/src/binding.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "intersect.h" 7 | #include "octree.h" 8 | #include "sample.h" 9 | 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("ball_intersect", &ball_intersect); 13 | m.def("aabb_intersect", &aabb_intersect); 14 | m.def("svo_intersect", &svo_intersect); 15 | m.def("triangle_intersect", &triangle_intersect); 16 | 17 | m.def("uniform_ray_sampling", &uniform_ray_sampling); 18 | m.def("inverse_cdf_sampling", &inverse_cdf_sampling); 19 | 20 | m.def("build_octree", &build_octree); 21 | } -------------------------------------------------------------------------------- /fairnr/clib/src/intersect.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
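// The ball_intersect / aabb_intersect / svo_intersect wrappers below each return a
// tuple (idx, min_depth, max_depth) of shape (batch, n_rays, n_max): idx holds the
// indices of intersected primitives, and min_depth / max_depth the per-hit entry and
// exit depths along each ray, filled in by the corresponding CUDA kernel wrappers.
// triangle_intersect instead returns (idx, depth, uv).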
5 | 6 | #include "intersect.h" 7 | #include "utils.h" 8 | #include 9 | 10 | void ball_intersect_point_kernel_wrapper( 11 | int b, int n, int m, float radius, int n_max, 12 | const float *ray_start, const float *ray_dir, const float *points, 13 | int *idx, float *min_depth, float *max_depth); 14 | 15 | std::tuple< at::Tensor, at::Tensor, at::Tensor > ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 16 | const float radius, const int n_max){ 17 | CHECK_CONTIGUOUS(ray_start); 18 | CHECK_CONTIGUOUS(ray_dir); 19 | CHECK_CONTIGUOUS(points); 20 | CHECK_IS_FLOAT(ray_start); 21 | CHECK_IS_FLOAT(ray_dir); 22 | CHECK_IS_FLOAT(points); 23 | CHECK_CUDA(ray_start); 24 | CHECK_CUDA(ray_dir); 25 | CHECK_CUDA(points); 26 | 27 | at::Tensor idx = 28 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 29 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 30 | at::Tensor min_depth = 31 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 32 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 33 | at::Tensor max_depth = 34 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 35 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 36 | ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), 37 | radius, n_max, 38 | ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), 39 | idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); 40 | return std::make_tuple(idx, min_depth, max_depth); 41 | } 42 | 43 | 44 | void aabb_intersect_point_kernel_wrapper( 45 | int b, int n, int m, float voxelsize, int n_max, 46 | const float *ray_start, const float *ray_dir, const float *points, 47 | int *idx, float *min_depth, float *max_depth); 48 | 49 | std::tuple< at::Tensor, at::Tensor, at::Tensor > aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 50 | const float voxelsize, const int n_max){ 51 | CHECK_CONTIGUOUS(ray_start); 52 | CHECK_CONTIGUOUS(ray_dir); 53 | CHECK_CONTIGUOUS(points); 54 | CHECK_IS_FLOAT(ray_start); 55 | CHECK_IS_FLOAT(ray_dir); 56 | CHECK_IS_FLOAT(points); 57 | CHECK_CUDA(ray_start); 58 | CHECK_CUDA(ray_dir); 59 | CHECK_CUDA(points); 60 | 61 | at::Tensor idx = 62 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 63 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 64 | at::Tensor min_depth = 65 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 66 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 67 | at::Tensor max_depth = 68 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 69 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 70 | aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), 71 | voxelsize, n_max, 72 | ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), 73 | idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); 74 | return std::make_tuple(idx, min_depth, max_depth); 75 | } 76 | 77 | 78 | void svo_intersect_point_kernel_wrapper( 79 | int b, int n, int m, float voxelsize, int n_max, 80 | const float *ray_start, const float *ray_dir, const float *points, const int *children, 81 | int *idx, float *min_depth, float *max_depth); 82 | 83 | 84 | std::tuple< at::Tensor, at::Tensor, at::Tensor > svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 85 | at::Tensor children, const float voxelsize, const int n_max){ 86 | CHECK_CONTIGUOUS(ray_start); 87 | CHECK_CONTIGUOUS(ray_dir); 88 | 
CHECK_CONTIGUOUS(points); 89 | CHECK_CONTIGUOUS(children); 90 | CHECK_IS_FLOAT(ray_start); 91 | CHECK_IS_FLOAT(ray_dir); 92 | CHECK_IS_FLOAT(points); 93 | CHECK_CUDA(ray_start); 94 | CHECK_CUDA(ray_dir); 95 | CHECK_CUDA(points); 96 | CHECK_CUDA(children); 97 | 98 | at::Tensor idx = 99 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 100 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 101 | at::Tensor min_depth = 102 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 103 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 104 | at::Tensor max_depth = 105 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 106 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 107 | svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), 108 | voxelsize, n_max, 109 | ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), 110 | children.data_ptr (), idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); 111 | return std::make_tuple(idx, min_depth, max_depth); 112 | } 113 | 114 | 115 | void triangle_intersect_point_kernel_wrapper( 116 | int b, int n, int m, float cagesize, float blur, int n_max, 117 | const float *ray_start, const float *ray_dir, const float *face_points, 118 | int *idx, float *depth, float *uv); 119 | 120 | std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, 121 | const float cagesize, const float blur, const int n_max){ 122 | CHECK_CONTIGUOUS(ray_start); 123 | CHECK_CONTIGUOUS(ray_dir); 124 | CHECK_CONTIGUOUS(face_points); 125 | CHECK_IS_FLOAT(ray_start); 126 | CHECK_IS_FLOAT(ray_dir); 127 | CHECK_IS_FLOAT(face_points); 128 | CHECK_CUDA(ray_start); 129 | CHECK_CUDA(ray_dir); 130 | CHECK_CUDA(face_points); 131 | 132 | at::Tensor idx = 133 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, 134 | at::device(ray_start.device()).dtype(at::ScalarType::Int)); 135 | at::Tensor depth = 136 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3}, 137 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 138 | at::Tensor uv = 139 | torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2}, 140 | at::device(ray_start.device()).dtype(at::ScalarType::Float)); 141 | triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1), 142 | cagesize, blur, n_max, 143 | ray_start.data_ptr (), ray_dir.data_ptr (), face_points.data_ptr (), 144 | idx.data_ptr (), depth.data_ptr (), uv.data_ptr ()); 145 | return std::make_tuple(idx, depth, uv); 146 | } 147 | -------------------------------------------------------------------------------- /fairnr/clib/src/octree.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "octree.h" 7 | #include "utils.h" 8 | #include 9 | #include 10 | using namespace std::chrono; 11 | 12 | 13 | typedef struct OcTree 14 | { 15 | int depth; 16 | int index; 17 | at::Tensor center; 18 | struct OcTree *children[8]; 19 | void init(at::Tensor center, int d, int i) { 20 | this->center = center; 21 | this->depth = d; 22 | this->index = i; 23 | for (int i=0; i<8; i++) this->children[i] = nullptr; 24 | } 25 | }OcTree; 26 | 27 | class EasyOctree { 28 | public: 29 | OcTree *root; 30 | int total; 31 | int terminal; 32 | 33 | at::Tensor all_centers; 34 | at::Tensor all_children; 35 | 36 | EasyOctree(at::Tensor center, int depth) { 37 | root = new OcTree; 38 | root->init(center, depth, -1); 39 | total = -1; 40 | terminal = -1; 41 | } 42 | ~EasyOctree() { 43 | OcTree *p = root; 44 | destory(p); 45 | } 46 | void destory(OcTree * &p); 47 | void insert(OcTree * &p, at::Tensor point, int index); 48 | void finalize(); 49 | std::pair count(OcTree * &p); 50 | }; 51 | 52 | void EasyOctree::destory(OcTree * &p){ 53 | if (p != nullptr) { 54 | for (int i=0; i<8; i++) { 55 | if (p->children[i] != nullptr) destory(p->children[i]); 56 | } 57 | delete p; 58 | p = nullptr; 59 | } 60 | } 61 | 62 | void EasyOctree::insert(OcTree * &p, at::Tensor point, int index) { 63 | at::Tensor diff = (point > p->center).to(at::kInt); 64 | int idx = diff[0].item() + 2 * diff[1].item() + 4 * diff[2].item(); 65 | if (p->depth == 0) { 66 | p->children[idx] = new OcTree; 67 | p->children[idx]->init(point, -1, index); 68 | } else { 69 | if (p->children[idx] == nullptr) { 70 | int length = 1 << (p->depth - 1); 71 | at::Tensor new_center = p->center + (2 * diff - 1) * length; 72 | p->children[idx] = new OcTree; 73 | p->children[idx]->init(new_center, p->depth-1, -1); 74 | } 75 | insert(p->children[idx], point, index); 76 | } 77 | } 78 | 79 | std::pair EasyOctree::count(OcTree * &p) { 80 | int total = 0, terminal = 0; 81 | for (int i=0; i<8; i++) { 82 | if (p->children[i] != nullptr) { 83 | std::pair sub = count(p->children[i]); 84 | total += sub.first; 85 | terminal += sub.second; 86 | } 87 | } 88 | total += 1; 89 | if (p->depth == -1) terminal += 1; 90 | return std::make_pair(total, terminal); 91 | } 92 | 93 | void EasyOctree::finalize() { 94 | std::pair outs = count(root); 95 | total = outs.first; terminal = outs.second; 96 | 97 | all_centers = 98 | torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int)); 99 | all_children = 100 | -torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int)); 101 | 102 | int node_idx = outs.first - 1; 103 | root->index = node_idx; 104 | 105 | std::queue all_leaves; all_leaves.push(root); 106 | while (!all_leaves.empty()) { 107 | OcTree* node_ptr = all_leaves.front(); 108 | all_leaves.pop(); 109 | for (int i=0; i<8; i++) { 110 | if (node_ptr->children[i] != nullptr) { 111 | if (node_ptr->children[i]->depth > -1) { 112 | node_idx--; 113 | node_ptr->children[i]->index = node_idx; 114 | } 115 | all_leaves.push(node_ptr->children[i]); 116 | all_children[node_ptr->index][i] = node_ptr->children[i]->index; 117 | } 118 | } 119 | all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1); 120 | all_centers[node_ptr->index] = node_ptr->center; 121 | } 122 | assert (node_idx == outs.second); 123 | }; 124 | 125 | std::tuple build_octree(at::Tensor center, at::Tensor points, int depth) { 126 | auto start = high_resolution_clock::now(); 127 | EasyOctree tree(center, depth); 128 | for (int k=0; k(stop - start); 133 | 
printf("Building EasyOctree done. total #nodes = %d, terminal #nodes = %d (time taken %f s)\n", 134 | tree.total, tree.terminal, float(duration.count()) / 1000000.); 135 | return std::make_tuple(tree.all_centers, tree.all_children); 136 | } -------------------------------------------------------------------------------- /fairnr/clib/src/sample.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "sample.h" 7 | #include "utils.h" 8 | #include 9 | 10 | 11 | void uniform_ray_sampling_kernel_wrapper( 12 | int b, int num_rays, int max_hits, int max_steps, float step_size, 13 | const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, 14 | int *sampled_idx, float *sampled_depth, float *sampled_dists); 15 | 16 | void inverse_cdf_sampling_kernel_wrapper( 17 | int b, int num_rays, int max_hits, int max_steps, float fixed_step_size, 18 | const int *pts_idx, const float *min_depth, const float *max_depth, 19 | const float *uniform_noise, const float *probs, const float *steps, 20 | int *sampled_idx, float *sampled_depth, float *sampled_dists); 21 | 22 | 23 | std::tuple< at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling( 24 | at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, 25 | const float step_size, const int max_steps){ 26 | 27 | CHECK_CONTIGUOUS(pts_idx); 28 | CHECK_CONTIGUOUS(min_depth); 29 | CHECK_CONTIGUOUS(max_depth); 30 | CHECK_CONTIGUOUS(uniform_noise); 31 | CHECK_IS_FLOAT(min_depth); 32 | CHECK_IS_FLOAT(max_depth); 33 | CHECK_IS_FLOAT(uniform_noise); 34 | CHECK_IS_INT(pts_idx); 35 | CHECK_CUDA(pts_idx); 36 | CHECK_CUDA(min_depth); 37 | CHECK_CUDA(max_depth); 38 | CHECK_CUDA(uniform_noise); 39 | 40 | at::Tensor sampled_idx = 41 | -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps}, 42 | at::device(pts_idx.device()).dtype(at::ScalarType::Int)); 43 | at::Tensor sampled_depth = 44 | torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, 45 | at::device(min_depth.device()).dtype(at::ScalarType::Float)); 46 | at::Tensor sampled_dists = 47 | torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, 48 | at::device(min_depth.device()).dtype(at::ScalarType::Float)); 49 | uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), 50 | step_size, 51 | pts_idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr (), 52 | uniform_noise.data_ptr (), sampled_idx.data_ptr (), 53 | sampled_depth.data_ptr (), sampled_dists.data_ptr ()); 54 | return std::make_tuple(sampled_idx, sampled_depth, sampled_dists); 55 | } 56 | 57 | 58 | std::tuple inverse_cdf_sampling( 59 | at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, 60 | at::Tensor probs, at::Tensor steps, float fixed_step_size) { 61 | 62 | CHECK_CONTIGUOUS(pts_idx); 63 | CHECK_CONTIGUOUS(min_depth); 64 | CHECK_CONTIGUOUS(max_depth); 65 | CHECK_CONTIGUOUS(probs); 66 | CHECK_CONTIGUOUS(steps); 67 | CHECK_CONTIGUOUS(uniform_noise); 68 | CHECK_IS_FLOAT(min_depth); 69 | CHECK_IS_FLOAT(max_depth); 70 | CHECK_IS_FLOAT(uniform_noise); 71 | CHECK_IS_FLOAT(probs); 72 | CHECK_IS_FLOAT(steps); 73 | CHECK_IS_INT(pts_idx); 74 | CHECK_CUDA(pts_idx); 75 | CHECK_CUDA(min_depth); 76 | CHECK_CUDA(max_depth); 77 | 
CHECK_CUDA(uniform_noise); 78 | CHECK_CUDA(probs); 79 | CHECK_CUDA(steps); 80 | 81 | int max_steps = uniform_noise.size(-1); 82 | at::Tensor sampled_idx = 83 | -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps}, 84 | at::device(pts_idx.device()).dtype(at::ScalarType::Int)); 85 | at::Tensor sampled_depth = 86 | torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, 87 | at::device(min_depth.device()).dtype(at::ScalarType::Float)); 88 | at::Tensor sampled_dists = 89 | torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, 90 | at::device(min_depth.device()).dtype(at::ScalarType::Float)); 91 | inverse_cdf_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), fixed_step_size, 92 | pts_idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr (), 93 | uniform_noise.data_ptr (), probs.data_ptr (), steps.data_ptr (), 94 | sampled_idx.data_ptr (), sampled_depth.data_ptr (), sampled_dists.data_ptr ()); 95 | return std::make_tuple(sampled_idx, sampled_depth, sampled_dists); 96 | } -------------------------------------------------------------------------------- /fairnr/clib/src/sample_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "cuda_utils.h" 12 | #include "cutil_math.h" // required for float3 vector math 13 | 14 | 15 | __global__ void uniform_ray_sampling_kernel( 16 | int b, int num_rays, 17 | int max_hits, 18 | int max_steps, 19 | float step_size, 20 | const int *__restrict__ pts_idx, 21 | const float *__restrict__ min_depth, 22 | const float *__restrict__ max_depth, 23 | const float *__restrict__ uniform_noise, 24 | int *__restrict__ sampled_idx, 25 | float *__restrict__ sampled_depth, 26 | float *__restrict__ sampled_dists) { 27 | 28 | int batch_index = blockIdx.x; 29 | int index = threadIdx.x; 30 | int stride = blockDim.x; 31 | 32 | pts_idx += batch_index * num_rays * max_hits; 33 | min_depth += batch_index * num_rays * max_hits; 34 | max_depth += batch_index * num_rays * max_hits; 35 | 36 | uniform_noise += batch_index * num_rays * max_steps; 37 | sampled_idx += batch_index * num_rays * max_steps; 38 | sampled_depth += batch_index * num_rays * max_steps; 39 | sampled_dists += batch_index * num_rays * max_steps; 40 | 41 | // loop over all rays 42 | for (int j = index; j < num_rays; j += stride) { 43 | int H = j * max_hits, K = j * max_steps; 44 | int s = 0, ucur = 0, umin = 0, umax = 0; 45 | float last_min_depth, last_max_depth, curr_depth; 46 | 47 | // sort all depths 48 | while (true) { 49 | if ((umax == max_hits) || (ucur == max_steps) || (pts_idx[H + umax] == -1)) { 50 | break; // reach the maximum 51 | } 52 | if (umin < max_hits) { 53 | last_min_depth = min_depth[H + umin]; 54 | } else { 55 | last_min_depth = 10000.0; 56 | } 57 | if (umax < max_hits) { 58 | last_max_depth = max_depth[H + umax]; 59 | } else { 60 | last_max_depth = 10000.0; 61 | } 62 | if (ucur < max_steps) { 63 | curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size; 64 | } 65 | 66 | if ((last_max_depth <= curr_depth) && (last_max_depth <= last_min_depth)) { 67 | sampled_depth[K + s] = last_max_depth; 68 | sampled_idx[K + s] = pts_idx[H + umax]; 69 | umax++; s++; continue; 70 | } 71 | if ((curr_depth <= last_min_depth) && 
(curr_depth <= last_max_depth)) { 72 | sampled_depth[K + s] = curr_depth; 73 | sampled_idx[K + s] = pts_idx[H + umin - 1]; 74 | ucur++; s++; continue; 75 | } 76 | if ((last_min_depth <= curr_depth) && (last_min_depth <= last_max_depth)) { 77 | sampled_depth[K + s] = last_min_depth; 78 | sampled_idx[K + s] = pts_idx[H + umin]; 79 | umin++; s++; continue; 80 | } 81 | } 82 | 83 | float l_depth, r_depth; 84 | int step = 0; 85 | for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++) { 86 | if (sampled_idx[K + ucur + 1] == -1) break; 87 | l_depth = sampled_depth[K + ucur]; 88 | r_depth = sampled_depth[K + ucur + 1]; 89 | sampled_depth[K + ucur] = (l_depth + r_depth) * .5; 90 | sampled_dists[K + ucur] = (r_depth - l_depth); 91 | if ((umin < max_hits) && (sampled_depth[K + ucur] >= min_depth[H + umin]) && (pts_idx[H + umin] > -1)) umin++; 92 | if ((umax < max_hits) && (sampled_depth[K + ucur] >= max_depth[H + umax]) && (pts_idx[H + umax] > -1)) umax++; 93 | if ((umax == max_hits) || (pts_idx[H + umax] == -1)) break; 94 | if ((umin - 1 == umax) && (sampled_dists[K + ucur] > 0)) { 95 | sampled_depth[K + step] = sampled_depth[K + ucur]; 96 | sampled_dists[K + step] = sampled_dists[K + ucur]; 97 | sampled_idx[K + step] = sampled_idx[K + ucur]; 98 | step++; 99 | } 100 | } 101 | 102 | for (int s = step; s < max_steps; s++) { 103 | sampled_idx[K + s] = -1; 104 | } 105 | } 106 | } 107 | 108 | __global__ void inverse_cdf_sampling_kernel( 109 | int b, int num_rays, 110 | int max_hits, 111 | int max_steps, 112 | float fixed_step_size, 113 | const int *__restrict__ pts_idx, 114 | const float *__restrict__ min_depth, 115 | const float *__restrict__ max_depth, 116 | const float *__restrict__ uniform_noise, 117 | const float *__restrict__ probs, 118 | const float *__restrict__ steps, 119 | int *__restrict__ sampled_idx, 120 | float *__restrict__ sampled_depth, 121 | float *__restrict__ sampled_dists) { 122 | 123 | int batch_index = blockIdx.x; 124 | int index = threadIdx.x; 125 | int stride = blockDim.x; 126 | 127 | pts_idx += batch_index * num_rays * max_hits; 128 | min_depth += batch_index * num_rays * max_hits; 129 | max_depth += batch_index * num_rays * max_hits; 130 | probs += batch_index * num_rays * max_hits; 131 | steps += batch_index * num_rays; 132 | 133 | uniform_noise += batch_index * num_rays * max_steps; 134 | sampled_idx += batch_index * num_rays * max_steps; 135 | sampled_depth += batch_index * num_rays * max_steps; 136 | sampled_dists += batch_index * num_rays * max_steps; 137 | 138 | // loop over all rays 139 | for (int j = index; j < num_rays; j += stride) { 140 | int H = j * max_hits, K = j * max_steps; 141 | int curr_bin = 0, s = 0; // current index (bin) 142 | 143 | float curr_min_depth = min_depth[H]; // lower depth 144 | float curr_max_depth = max_depth[H]; // upper depth 145 | float curr_min_cdf = 0; 146 | float curr_max_cdf = probs[H]; 147 | float step_size = 1.0 / steps[j]; 148 | float z_low = curr_min_depth; 149 | int total_steps = int(ceil(steps[j])); 150 | bool done = false; 151 | 152 | // optional use a fixed step size 153 | if (fixed_step_size > 0.0) step_size = fixed_step_size; 154 | 155 | // sample points 156 | for (int curr_step = 0; curr_step < total_steps; curr_step++) { 157 | float curr_cdf = (float(curr_step) + uniform_noise[K + curr_step]) * step_size; 158 | while (curr_cdf > curr_max_cdf) { 159 | // first include max cdf 160 | sampled_idx[K + s] = pts_idx[H + curr_bin]; 161 | sampled_dists[K + s] = (curr_max_depth - z_low); 162 | sampled_depth[K + s] = 
(curr_max_depth + z_low) * .5; 163 | 164 | // move to next cdf 165 | curr_bin++; 166 | s++; 167 | if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1)) { 168 | done = true; break; 169 | } 170 | curr_min_depth = min_depth[H + curr_bin]; 171 | curr_max_depth = max_depth[H + curr_bin]; 172 | curr_min_cdf = curr_max_cdf; 173 | curr_max_cdf = curr_max_cdf + probs[H + curr_bin]; 174 | z_low = curr_min_depth; 175 | } 176 | if (done) break; 177 | 178 | // if the sampled cdf is inside bin 179 | float u = (curr_cdf - curr_min_cdf) / (curr_max_cdf - curr_min_cdf); 180 | float z = curr_min_depth + u * (curr_max_depth - curr_min_depth); 181 | sampled_idx[K + s] = pts_idx[H + curr_bin]; 182 | sampled_dists[K + s] = (z - z_low); 183 | sampled_depth[K + s] = (z + z_low) * .5; 184 | z_low = z; s++; 185 | } 186 | 187 | // if there are bins still remained 188 | while ((z_low < curr_max_depth) && (~done)) { 189 | sampled_idx[K + s] = pts_idx[H + curr_bin]; 190 | sampled_dists[K + s] = (curr_max_depth - z_low); 191 | sampled_depth[K + s] = (curr_max_depth + z_low) * .5; 192 | curr_bin++; 193 | s++; 194 | if ((curr_bin >= max_hits) || (pts_idx[curr_bin] == -1)) 195 | break; 196 | 197 | curr_min_depth = min_depth[H + curr_bin]; 198 | curr_max_depth = max_depth[H + curr_bin]; 199 | z_low = curr_min_depth; 200 | } 201 | } 202 | } 203 | 204 | void uniform_ray_sampling_kernel_wrapper( 205 | int b, int num_rays, int max_hits, int max_steps, float step_size, 206 | const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, 207 | int *sampled_idx, float *sampled_depth, float *sampled_dists) { 208 | 209 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 210 | uniform_ray_sampling_kernel<<>>( 211 | b, num_rays, max_hits, max_steps, step_size, pts_idx, 212 | min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists); 213 | 214 | CUDA_CHECK_ERRORS(); 215 | } 216 | 217 | void inverse_cdf_sampling_kernel_wrapper( 218 | int b, int num_rays, int max_hits, int max_steps, float fixed_step_size, 219 | const int *pts_idx, const float *min_depth, const float *max_depth, 220 | const float *uniform_noise, const float *probs, const float *steps, 221 | int *sampled_idx, float *sampled_depth, float *sampled_dists) { 222 | 223 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 224 | inverse_cdf_sampling_kernel<<>>( 225 | b, num_rays, max_hits, max_steps, fixed_step_size, 226 | pts_idx, min_depth, max_depth, uniform_noise, probs, steps, 227 | sampled_idx, sampled_depth, sampled_dists); 228 | 229 | CUDA_CHECK_ERRORS(); 230 | } 231 | -------------------------------------------------------------------------------- /fairnr/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | for file in os.listdir(os.path.dirname(__file__)): 10 | if file.endswith(".py") and not file.startswith("_"): 11 | criterion_name = file[: file.find(".py")] 12 | importlib.import_module( 13 | "fairnr.criterions." + criterion_name 14 | ) 15 | -------------------------------------------------------------------------------- /fairnr/criterions/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
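[annotation] The `inverse_cdf_sampling` kernel above importance-samples depths from the piecewise-constant distribution given by each ray's per-voxel `probs`, walking the running CDF bin by bin. A minimal single-ray PyTorch sketch of that inverse-transform step; the real kernel also emits per-sample distances and can march at a fixed step size, and the helper below with its stratified-noise choice is an assumption for illustration:

import torch

def inverse_cdf_sample_reference(bin_min, bin_max, probs, n_samples):
    # bin_min, bin_max, probs: (H,) entry depth, exit depth and weight of each hit bin
    cdf = torch.cumsum(probs / probs.sum().clamp(min=1e-8), dim=0)
    u = (torch.arange(n_samples, dtype=torch.float32) + torch.rand(n_samples)) / n_samples
    bins = torch.searchsorted(cdf, u).clamp(max=probs.numel() - 1)
    cdf_lo = torch.cat([cdf.new_zeros(1), cdf[:-1]])[bins]
    t = (u - cdf_lo) / (cdf[bins] - cdf_lo).clamp(min=1e-8)   # position of u inside its bin
    return bin_min[bins] + t * (bin_max[bins] - bin_min[bins])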
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torchvision 8 | 9 | class VGGPerceptualLoss(torch.nn.Module): 10 | def __init__(self, resize=False): 11 | super(VGGPerceptualLoss, self).__init__() 12 | blocks = [] 13 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) 14 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) 15 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) 16 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) 17 | self.blocks = torch.nn.ModuleList(blocks) 18 | self.transform = torch.nn.functional.interpolate 19 | self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) 20 | self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) 21 | self.resize = resize 22 | 23 | # NO GRADIENT! 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def forward(self, input, target, level=2): 28 | # print(input.device, input.dtype, self.mean.device, self.mean.dtype, self.std, self.std.dtype) 29 | if input.shape[1] != 3: 30 | input = input.repeat(1, 3, 1, 1) 31 | target = target.repeat(1, 3, 1, 1) 32 | input = (input-self.mean) / self.std 33 | target = (target-self.mean) / self.std 34 | 35 | if self.resize: 36 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) 37 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) 38 | 39 | loss = 0.0 40 | x = input 41 | y = target 42 | for i, block in enumerate(self.blocks): 43 | x = block(x) 44 | y = block(y) 45 | if i < level: 46 | loss += torch.nn.functional.mse_loss(x, y) 47 | else: 48 | break 49 | return loss 50 | -------------------------------------------------------------------------------- /fairnr/criterions/rendering_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import math 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | from torch import Tensor 12 | 13 | from fairseq import metrics 14 | from fairseq.utils import item 15 | from fairseq.criterions import FairseqCriterion, register_criterion 16 | import fairnr.criterions.utils as utils 17 | 18 | class RenderingCriterion(FairseqCriterion): 19 | 20 | def __init__(self, args, task): 21 | super().__init__(task) 22 | self.args = args 23 | self.hierarchical = getattr(args, 'hierarchical_loss', False) 24 | 25 | @classmethod 26 | def build_criterion(cls, args, task): 27 | """Construct a criterion from command-line args.""" 28 | return cls(args, task) 29 | 30 | @staticmethod 31 | def add_args(parser): 32 | """Add criterion-specific arguments to the parser.""" 33 | parser.add_argument('--hierarchical-loss', action='store_true', 34 | help='if set, it computes both the coarse and fine-level losses in hierarchical sampling.') 35 | 36 | 37 | def forward(self, model, sample, reduce=True): 38 | """Compute the loss for the given sample. 
39 | 40 | Returns a tuple with three elements: 41 | 1) the loss 42 | 2) the sample size, which is used as the denominator for the gradient 43 | 3) logging outputs to display while training 44 | """ 45 | net_output = model(**sample) 46 | sample.update(net_output['samples']) 47 | 48 | loss, loss_output = self.compute_loss(model, net_output, sample, reduce=reduce) 49 | if self.hierarchical: 50 | assert net_output.get('coarse', None) is not None, "missing coarse level outputs." 51 | loss0, loss_output0 = self.compute_loss(model, net_output['coarse'], sample, reduce=reduce) 52 | loss = loss + loss0 53 | loss_output.update({'cor-' + key: loss_output0[key] for key in loss_output0}) 54 | 55 | sample_size = 1 56 | 57 | logging_output = { 58 | 'loss': loss.data.item() if reduce else loss.data, 59 | 'nsentences': sample['alpha'].size(0), 60 | 'ntokens': sample['alpha'].size(1), 61 | 'npixels': sample['alpha'].size(2), 62 | 'sample_size': sample_size, 63 | } 64 | for w in loss_output: 65 | logging_output[w] = loss_output[w] 66 | 67 | return loss, sample_size, logging_output 68 | 69 | def compute_loss(self, model, net_output, sample, reduce=True): 70 | raise NotImplementedError 71 | 72 | @staticmethod 73 | def reduce_metrics(logging_outputs) -> None: 74 | """Aggregate logging outputs from data parallel training.""" 75 | 76 | summed_logging_outputs = { 77 | w: sum(log.get(w, 0) for log in logging_outputs) 78 | for w in logging_outputs[0] 79 | } 80 | sample_size = summed_logging_outputs['sample_size'] 81 | 82 | for w in summed_logging_outputs: 83 | if '_loss' in w: 84 | metrics.log_scalar(w.split('_')[0], summed_logging_outputs[w] / sample_size, sample_size, round=3) 85 | elif '_weight' in w: 86 | metrics.log_scalar('w_' + w[:3], summed_logging_outputs[w] / sample_size, sample_size, round=3) 87 | elif '_acc' in w: 88 | metrics.log_scalar('a_' + w[:3], summed_logging_outputs[w] / sample_size, sample_size, round=3) 89 | elif w == 'loss': 90 | metrics.log_scalar('loss', summed_logging_outputs['loss'] / sample_size, sample_size, priority=0, round=3) 91 | elif '_log' in w: 92 | metrics.log_scalar(w[:3], summed_logging_outputs[w] / sample_size, sample_size, priority=1, round=3) 93 | 94 | @staticmethod 95 | def logging_outputs_can_be_summed() -> bool: 96 | """ 97 | Whether the logging outputs returned by `forward` can be summed 98 | across workers prior to calling `reduce_metrics`. Setting this 99 | to True will improves distributed training speed. 
100 | """ 101 | return True 102 | 103 | 104 | @register_criterion('srn_loss') 105 | class SRNLossCriterion(RenderingCriterion): 106 | 107 | def __init__(self, args, task): 108 | super().__init__(args, task) 109 | # HACK: to avoid warnings in c10d 110 | self.dummy_loss = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32), requires_grad=True) 111 | if args.vgg_weight > 0: 112 | from fairnr.criterions.perceptual_loss import VGGPerceptualLoss 113 | self.vgg = VGGPerceptualLoss(resize=False) 114 | 115 | if args.eval_lpips: 116 | from lpips_pytorch import LPIPS 117 | self.lpips = LPIPS(net_type='alex', version='0.1') 118 | 119 | @staticmethod 120 | def add_args(parser): 121 | RenderingCriterion.add_args(parser) 122 | parser.add_argument('--L1', action='store_true', 123 | help='if enabled, use L1 instead of L2 for RGB loss') 124 | parser.add_argument('--color-weight', type=float, default=256.0) 125 | parser.add_argument('--depth-weight', type=float, default=0.0) 126 | parser.add_argument('--depth-weight-decay', type=str, default=None, 127 | help="""if set, use tuple to set (final_ratio, steps). 128 | For instance, (0, 30000) 129 | """) 130 | parser.add_argument('--alpha-weight', type=float, default=0.0) 131 | parser.add_argument('--vgg-weight', type=float, default=0.0) 132 | parser.add_argument('--eikonal-weight', type=float, default=0.0) 133 | parser.add_argument('--regz-weight', type=float, default=0.0) 134 | parser.add_argument('--vgg-level', type=int, choices=[1,2,3,4], default=2) 135 | parser.add_argument('--eval-lpips', action='store_true', 136 | help="evaluate LPIPS scores in validation") 137 | parser.add_argument('--no-background-loss', action='store_true') 138 | 139 | def compute_loss(self, model, net_output, sample, reduce=True): 140 | losses, other_logs = {}, {} 141 | 142 | # prepare data before computing loss 143 | sampled_uv = sample['sampled_uv'] # S, V, 2, N, P, P (patch-size) 144 | S, V, _, N, P1, P2 = sampled_uv.size() 145 | H, W, h, w = sample['size'][0, 0].long().cpu().tolist() 146 | L = N * P1 * P2 147 | flatten_uv = sampled_uv.view(S, V, 2, L) 148 | flatten_index = (flatten_uv[:,:,0] // h + flatten_uv[:,:,1] // w * W).long() 149 | 150 | assert 'colors' in sample and sample['colors'] is not None, "ground-truth colors not provided" 151 | target_colors = sample['colors'] 152 | masks = (sample['alpha'] > 0) if self.args.no_background_loss else None 153 | if L < target_colors.size(2): 154 | target_colors = target_colors.gather(2, flatten_index.unsqueeze(-1).repeat(1,1,1,3)) 155 | masks = masks.gather(2, flatten_uv) if masks is not None else None 156 | 157 | if 'other_logs' in net_output: 158 | other_logs.update(net_output['other_logs']) 159 | 160 | # computing loss 161 | if self.args.color_weight > 0: 162 | color_loss = utils.rgb_loss( 163 | net_output['colors'], target_colors, 164 | masks, self.args.L1) 165 | losses['color_loss'] = (color_loss, self.args.color_weight) 166 | 167 | if self.args.alpha_weight > 0: 168 | _alpha = net_output['missed'].reshape(-1) 169 | alpha_loss = torch.log1p( 170 | 1. 
/ 0.11 * _alpha.float() * (1 - _alpha.float()) 171 | ).mean().type_as(_alpha) 172 | losses['alpha_loss'] = (alpha_loss, self.args.alpha_weight) 173 | 174 | if self.args.depth_weight > 0: 175 | if sample['depths'] is not None: 176 | target_depths = target_depths.gather(2, flatten_index) 177 | depth_mask = masks & (target_depths > 0) 178 | depth_loss = utils.depth_loss(net_output['depths'], target_depths, depth_mask) 179 | 180 | else: 181 | # no depth map is provided, depth loss only applied on background based on masks 182 | max_depth_target = self.args.max_depth * torch.ones_like(net_output['depths']) 183 | if sample['mask'] is not None: 184 | depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, (1 - sample['mask']).bool()) 185 | else: 186 | depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, ~masks) 187 | 188 | depth_weight = self.args.depth_weight 189 | if self.args.depth_weight_decay is not None: 190 | final_factor, final_steps = eval(self.args.depth_weight_decay) 191 | depth_weight *= max(0, 1 - (1 - final_factor) * self.task._num_updates / final_steps) 192 | other_logs['depth_weight'] = depth_weight 193 | 194 | losses['depth_loss'] = (depth_loss, depth_weight) 195 | 196 | 197 | if self.args.vgg_weight > 0: 198 | assert P1 * P2 > 1, "we have to use a patch-based sampling for VGG loss" 199 | target_colors = target_colors.reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5 200 | output_colors = net_output['colors'].reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5 201 | vgg_loss = self.vgg(output_colors, target_colors) 202 | losses['vgg_loss'] = (vgg_loss, self.args.vgg_weight) 203 | 204 | if self.args.eikonal_weight > 0: 205 | losses['eik_loss'] = (net_output['eikonal-term'].mean(), self.args.eikonal_weight) 206 | 207 | # if self.args.regz_weight > 0: 208 | losses['reg_loss'] = (net_output['regz-term'].mean(), self.args.regz_weight) 209 | loss = sum(losses[key][0] * losses[key][1] for key in losses) 210 | 211 | # add a dummy loss 212 | loss = loss + model.dummy_loss + self.dummy_loss * 0. 213 | logging_outputs = {key: item(losses[key][0]) for key in losses} 214 | logging_outputs.update(other_logs) 215 | return loss, logging_outputs 216 | -------------------------------------------------------------------------------- /fairnr/criterions/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
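[annotation] For reference, the `--depth-weight-decay` handling in `SRNLossCriterion.compute_loss` above decays the depth-loss weight linearly with the number of updates. A small standalone sketch of that schedule; the function name and the example numbers are mine:

def decayed_depth_weight(base_weight, num_updates, final_factor, final_steps):
    # linear decay from base_weight towards base_weight * final_factor over final_steps updates,
    # clamped at zero afterwards
    return base_weight * max(0.0, 1.0 - (1.0 - final_factor) * num_updates / final_steps)

# e.g. --depth-weight 10 --depth-weight-decay "(0, 30000)":
# 10.0 at update 0, 5.0 at update 15000, 0.0 from update 30000 on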
5 | 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | TINY = 1e-7 11 | 12 | 13 | def rgb_loss(predicts, rgbs, masks=None, L1=False, sum=False): 14 | if masks is not None: 15 | if masks.sum() == 0: 16 | return predicts.new_zeros(1).mean() 17 | predicts = predicts[masks] 18 | rgbs = rgbs[masks] 19 | 20 | if L1: 21 | loss = torch.abs(predicts - rgbs).sum(-1) 22 | else: 23 | loss = ((predicts - rgbs) ** 2).sum(-1) 24 | 25 | return loss.mean() if not sum else loss.sum() 26 | 27 | 28 | def depth_loss(depths, depth_gt, masks=None, sum=False): 29 | if masks is not None: 30 | if masks.sum() == 0: 31 | return depths.new_zeros(1).mean() 32 | depth_gt = depth_gt[masks] 33 | depths = depths[masks] 34 | 35 | loss = (depths[masks] - depth_gt[masks]) ** 2 36 | return loss.mean() if not sum else loss.sum() -------------------------------------------------------------------------------- /fairnr/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .shape_dataset import ( 7 | ShapeDataset, ShapeViewDataset, ShapeViewStreamDataset, 8 | SampledPixelDataset, WorldCoordDataset, 9 | InfiniteDataset 10 | ) 11 | 12 | __all__ = [ 13 | 'ShapeDataset', 14 | 'ShapeViewDataset', 15 | 'ShapeViewStreamDataset', 16 | 'SampledPixelDataset', 17 | 'WorldCoordDataset', 18 | ] 19 | -------------------------------------------------------------------------------- /fairnr/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import functools 8 | import cv2 9 | import math 10 | import numpy as np 11 | import imageio 12 | from glob import glob 13 | import os 14 | import copy 15 | import shutil 16 | import skimage.metrics 17 | import pandas as pd 18 | import pylab as plt 19 | import fairseq.distributed_utils as du 20 | 21 | from plyfile import PlyData, PlyElement 22 | from fairseq.meters import StopwatchMeter 23 | 24 | def get_rank(): 25 | try: 26 | return du.get_rank() 27 | except AssertionError: 28 | return 0 29 | 30 | 31 | def get_world_size(): 32 | try: 33 | return du.get_world_size() 34 | except AssertionError: 35 | return 1 36 | 37 | 38 | def parse_views(view_args): 39 | output = [] 40 | try: 41 | xx = view_args.split(':') 42 | ids = xx[0].split(',') 43 | for id in ids: 44 | if '..' 
in id: 45 | a, b = id.split('..') 46 | output += list(range(int(a), int(b))) 47 | else: 48 | output += [int(id)] 49 | if len(xx) > 1: 50 | output = output[::int(xx[-1])] 51 | except Exception as e: 52 | raise Exception("parse view args error: {}".format(e)) 53 | 54 | return output 55 | 56 | 57 | def get_uv(H, W, h, w): 58 | """ 59 | H, W: real image (intrinsics) 60 | h, w: resized image 61 | """ 62 | uv = np.flip(np.mgrid[0: h, 0: w], axis=0).astype(np.float32) 63 | uv[0] = uv[0] * float(W / w) 64 | uv[1] = uv[1] * float(H / h) 65 | return uv, [float(H / h), float(W / w)] 66 | 67 | 68 | def load_rgb( 69 | path, 70 | resolution=None, 71 | with_alpha=True, 72 | bg_color=[1.0, 1.0, 1.0], 73 | min_rgb=-1, 74 | interpolation='AREA'): 75 | if with_alpha: 76 | img = imageio.imread(path) # RGB-ALPHA 77 | else: 78 | img = imageio.imread(path)[:, :, :3] 79 | 80 | img = skimage.img_as_float32(img).astype('float32') 81 | H, W, D = img.shape 82 | h, w = resolution 83 | 84 | if D == 3: 85 | img = np.concatenate([img, np.ones((img.shape[0], img.shape[1], 1))], -1).astype('float32') 86 | 87 | uv, ratio = get_uv(H, W, h, w) 88 | if (h < H) or (w < W): 89 | # img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') 90 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA).astype('float32') 91 | 92 | if min_rgb == -1: # 0, 1 --> -1, 1 93 | img[:, :, :3] -= 0.5 94 | img[:, :, :3] *= 2. 95 | 96 | img[:, :, :3] = img[:, :, :3] * img[:, :, 3:] + np.asarray(bg_color)[None, None, :] * (1 - img[:, :, 3:]) 97 | img[:, :, 3] = img[:, :, 3] * (img[:, :, :3] != np.asarray(bg_color)[None, None, :]).any(-1) 98 | img = img.transpose(2, 0, 1) 99 | 100 | return img, uv, ratio 101 | 102 | 103 | def load_depth(path, resolution=None, depth_plane=5): 104 | if path is None: 105 | return None 106 | 107 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) 108 | # ret, img = cv2.threshold(img, depth_plane, depth_plane, cv2.THRESH_TRUNC) 109 | 110 | H, W = img.shape[:2] 111 | h, w = resolution 112 | if (h < H) or (w < W): 113 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') 114 | #img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR) 115 | 116 | if len(img.shape) ==3: 117 | img = img[:,:,:1] 118 | img = img.transpose(2,0,1) 119 | else: 120 | img = img[None,:,:] 121 | return img 122 | 123 | 124 | def load_mask(path, resolution=None): 125 | if path is None: 126 | return None 127 | 128 | img = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) 129 | h, w = resolution 130 | H, W = img.shape[:2] 131 | if (h < H) or (w < W): 132 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') 133 | img = img / (img.max() + 1e-7) 134 | return img 135 | 136 | 137 | def load_matrix(path): 138 | lines = [[float(w) for w in line.strip().split()] for line in open(path)] 139 | if len(lines[0]) == 2: 140 | lines = lines[1:] 141 | if len(lines[-1]) == 2: 142 | lines = lines[:-1] 143 | return np.array(lines).astype(np.float32) 144 | 145 | 146 | def load_intrinsics(filepath, resized_width=None, invert_y=False): 147 | try: 148 | intrinsics = load_matrix(filepath) 149 | if intrinsics.shape[0] == 3 and intrinsics.shape[1] == 3: 150 | _intrinsics = np.zeros((4, 4), np.float32) 151 | _intrinsics[:3, :3] = intrinsics 152 | _intrinsics[3, 3] = 1 153 | intrinsics = _intrinsics 154 | if intrinsics.shape[0] == 1 and intrinsics.shape[1] == 16: 155 | intrinsics = intrinsics.reshape(4, 4) 156 | return intrinsics 157 | except ValueError: 158 | pass 159 | 
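[annotation] `parse_views` above accepts comma-separated ids, `..` ranges (half-open) and an optional `:stride` suffix. A quick usage example; the import path follows the file layout and the argument strings are made up:

from fairnr.data.data_utils import parse_views

parse_views("0,5,7")      # -> [0, 5, 7]
parse_views("0..8")       # -> [0, 1, 2, 3, 4, 5, 6, 7]
parse_views("0..8:2")     # -> [0, 2, 4, 6]   (every 2nd view)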
160 | # Get camera intrinsics 161 | with open(filepath, 'r') as file: 162 | 163 | f, cx, cy, _ = map(float, file.readline().split()) 164 | fx = f 165 | if invert_y: 166 | fy = -f 167 | else: 168 | fy = f 169 | 170 | # Build the intrinsic matrices 171 | full_intrinsic = np.array([[fx, 0., cx, 0.], 172 | [0., fy, cy, 0], 173 | [0., 0, 1, 0], 174 | [0, 0, 0, 1]]) 175 | return full_intrinsic 176 | 177 | 178 | def unflatten_img(img, width=512): 179 | sizes = img.size() 180 | height = sizes[-1] // width 181 | return img.reshape(*sizes[:-1], height, width) 182 | 183 | 184 | def square_crop_img(img): 185 | if img.shape[0] == img.shape[1]: 186 | return img # already square 187 | 188 | min_dim = np.amin(img.shape[:2]) 189 | center_coord = np.array(img.shape[:2]) // 2 190 | img = img[center_coord[0] - min_dim // 2:center_coord[0] + min_dim // 2, 191 | center_coord[1] - min_dim // 2:center_coord[1] + min_dim // 2] 192 | return img 193 | 194 | 195 | def sample_pixel_from_image( 196 | num_pixel, num_sample, 197 | mask=None, ratio=1.0, 198 | use_bbox=False, 199 | center_ratio=1.0, 200 | width=512, 201 | patch_size=1): 202 | 203 | if patch_size > 1: 204 | assert (num_pixel % (patch_size * patch_size) == 0) \ 205 | and (num_sample % (patch_size * patch_size) == 0), "size must match" 206 | _num_pixel = num_pixel // (patch_size * patch_size) 207 | _num_sample = num_sample // (patch_size * patch_size) 208 | height = num_pixel // width 209 | 210 | _mask = None if mask is None else \ 211 | mask.reshape(height, width).reshape( 212 | height//patch_size, patch_size, width//patch_size, patch_size 213 | ).any(1).any(-1).reshape(-1) 214 | _width = width // patch_size 215 | _out = sample_pixel_from_image(_num_pixel, _num_sample, _mask, ratio, use_bbox, _width) 216 | _x, _y = _out % _width, _out // _width 217 | x, y = _x * patch_size, _y * patch_size 218 | x = x[:, None, None] + np.arange(patch_size)[None, :, None] 219 | y = y[:, None, None] + np.arange(patch_size)[None, None, :] 220 | out = x + y * width 221 | return out.reshape(-1) 222 | 223 | if center_ratio < 1.0: 224 | r = (1 - center_ratio) / 2.0 225 | H, W = num_pixel // width, width 226 | mask0 = np.zeros((H, W)) 227 | mask0[int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1 228 | mask0 = mask0.reshape(-1) 229 | 230 | if mask is None: 231 | mask = mask0 232 | else: 233 | mask = mask * mask0 234 | 235 | if mask is not None: 236 | mask = (mask > 0.0).astype('float32') 237 | 238 | if (mask is None) or \ 239 | (ratio <= 0.0) or \ 240 | (mask.sum() == 0) or \ 241 | ((1 - mask).sum() == 0): 242 | return np.random.choice(num_pixel, num_sample) 243 | 244 | if use_bbox: 245 | mask = mask.reshape(-1, width) 246 | x, y = np.where(mask == 1) 247 | mask = np.zeros_like(mask) 248 | mask[x.min(): x.max()+1, y.min(): y.max()+1] = 1.0 249 | mask = mask.reshape(-1) 250 | 251 | try: 252 | probs = mask * ratio / (mask.sum()) + (1 - mask) / (num_pixel - mask.sum()) * (1 - ratio) 253 | # x = np.random.choice(num_pixel, num_sample, True, p=probs) 254 | return np.random.choice(num_pixel, num_sample, True, p=probs) 255 | 256 | except Exception: 257 | return np.random.choice(num_pixel, num_sample) 258 | 259 | 260 | def colormap(dz): 261 | return plt.cm.jet(dz) 262 | # return plt.cm.viridis(dz) 263 | # return plt.cm.gray(dz) 264 | 265 | 266 | def recover_image(img, min_val=-1, max_val=1, width=512, bg=None, weight=None, raw=False): 267 | if raw: return img 268 | 269 | sizes = img.size() 270 | height = sizes[0] // width 271 | img = img.float().to('cpu') 272 | 273 | if len(sizes) 
== 1 and (bg is not None): 274 | bg_mask = img.eq(bg)[:, None].type_as(img) 275 | 276 | img = ((img - min_val) / (max_val - min_val)).clamp(min=0, max=1) 277 | if len(sizes) == 1: 278 | img = torch.from_numpy(colormap(img.numpy())[:, :3]) 279 | if weight is not None: 280 | weight = weight.float().to('cpu') 281 | img = img * weight[:, None] 282 | 283 | if bg is not None: 284 | img = img * (1 - bg_mask) + bg_mask 285 | img = img.reshape(height, width, -1) 286 | return img 287 | 288 | 289 | def write_images(writer, images, updates): 290 | for tag in images: 291 | img = images[tag] 292 | tag, dataform = tag.split(':') 293 | writer.add_image(tag, img, updates, dataformats=dataform) 294 | 295 | 296 | def compute_psnr(p, t): 297 | """Compute PSNR of model image predictions. 298 | :param prediction: Return value of forward pass. 299 | :param ground_truth: Ground truth. 300 | :return: (psnr, ssim): tuple of floats 301 | """ 302 | ssim = skimage.metrics.structural_similarity(p, t, multichannel=True, data_range=1) 303 | psnr = skimage.metrics.peak_signal_noise_ratio(p, t, data_range=1) 304 | return ssim, psnr 305 | 306 | 307 | def save_point_cloud(filename, xyz, rgb=None): 308 | if rgb is None: 309 | vertex = np.array([(xyz[k, 0], xyz[k, 1], xyz[k, 2]) for k in range(xyz.shape[0])], 310 | dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 311 | else: 312 | vertex = np.array([(xyz[k, 0], xyz[k, 1], xyz[k, 2], rgb[k, 0], rgb[k, 1], rgb[k, 2]) for k in range(xyz.shape[0])], 313 | dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 314 | # PlyData([PlyElement.describe(vertex, 'vertex')], text=True).write(filename) 315 | # from fairseq import pdb; pdb.set_trace() 316 | PlyData([PlyElement.describe(vertex, 'vertex')]).write(open(filename, 'wb')) 317 | 318 | 319 | class InfIndex(object): 320 | 321 | def __init__(self, index_list, shuffle=False): 322 | self.index_list = index_list 323 | self.size = len(index_list) 324 | self.shuffle = shuffle 325 | self.reset_permutation() 326 | 327 | def reset_permutation(self): 328 | if self.shuffle: 329 | self._perm = np.random.permutation(self.index_list).tolist() 330 | else: 331 | self._perm = copy.deepcopy(self.index_list) 332 | 333 | def __iter__(self): 334 | return self 335 | 336 | def __next__(self): 337 | if len(self._perm) == 0: 338 | self.reset_permutation() 339 | return self._perm.pop() 340 | 341 | def __len__(self): 342 | return self.size 343 | 344 | 345 | class Timer(StopwatchMeter): 346 | def __enter__(self): 347 | """Start a new timer as a context manager""" 348 | self.start() 349 | return self 350 | 351 | def __exit__(self, *exc_info): 352 | """Stop the context manager timer""" 353 | self.stop() 354 | 355 | 356 | class GPUTimer(object): 357 | def __enter__(self): 358 | """Start a new timer as a context manager""" 359 | self.start = torch.cuda.Event(enable_timing=True) 360 | self.end = torch.cuda.Event(enable_timing=True) 361 | self.start.record() 362 | self.sum = 0 363 | return self 364 | 365 | def __exit__(self, *exc_info): 366 | """Stop the context manager timer""" 367 | self.end.record() 368 | torch.cuda.synchronize() 369 | self.sum = self.start.elapsed_time(self.end) / 1000. 370 | -------------------------------------------------------------------------------- /fairnr/data/geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
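[annotation] `sample_pixel_from_image` above biases ray sampling towards foreground pixels by giving the mask a `ratio` share of the total probability mass. A minimal NumPy sketch of just that weighting; the helper name and example values are mine:

import numpy as np

def biased_pixel_probs(mask, ratio):
    # `ratio` of the probability mass goes to mask==1 pixels, the rest to the background
    mask = (mask > 0).astype(np.float32).reshape(-1)
    return mask * ratio / mask.sum() + (1 - mask) * (1 - ratio) / (1 - mask).sum()

probs = biased_pixel_probs(np.eye(4), ratio=0.75)
pixels = np.random.choice(probs.size, 8, replace=True, p=probs)   # ~75% of draws land on the diagonal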
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from fairnr.data import data_utils as D 11 | try: 12 | from fairnr.clib._ext import build_octree 13 | except ImportError: 14 | pass 15 | 16 | INF = 1000.0 17 | 18 | 19 | def ones_like(x): 20 | T = torch if isinstance(x, torch.Tensor) else np 21 | return T.ones_like(x) 22 | 23 | 24 | def stack(x): 25 | T = torch if isinstance(x[0], torch.Tensor) else np 26 | return T.stack(x) 27 | 28 | 29 | def matmul(x, y): 30 | T = torch if isinstance(x, torch.Tensor) else np 31 | return T.matmul(x, y) 32 | 33 | 34 | def cross(x, y, axis=0): 35 | T = torch if isinstance(x, torch.Tensor) else np 36 | return T.cross(x, y, axis) 37 | 38 | 39 | def cat(x, axis=1): 40 | if isinstance(x[0], torch.Tensor): 41 | return torch.cat(x, dim=axis) 42 | return np.concatenate(x, axis=axis) 43 | 44 | 45 | def normalize(x, axis=-1, order=2): 46 | if isinstance(x, torch.Tensor): 47 | l2 = x.norm(p=order, dim=axis, keepdim=True) 48 | return x / (l2 + 1e-8), l2 49 | 50 | else: 51 | l2 = np.linalg.norm(x, order, axis) 52 | l2 = np.expand_dims(l2, axis) 53 | l2[l2==0] = 1 54 | return x / l2, l2 55 | 56 | 57 | def parse_extrinsics(extrinsics, world2camera=True): 58 | """ this function is only for numpy for now""" 59 | if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4: 60 | extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])]) 61 | if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16: 62 | extrinsics = extrinsics.reshape(4, 4) 63 | if world2camera: 64 | extrinsics = np.linalg.inv(extrinsics).astype(np.float32) 65 | return extrinsics 66 | 67 | 68 | def parse_intrinsics(intrinsics): 69 | fx = intrinsics[0, 0] 70 | fy = intrinsics[1, 1] 71 | cx = intrinsics[0, 2] 72 | cy = intrinsics[1, 2] 73 | return fx, fy, cx, cy 74 | 75 | 76 | def uv2cam(uv, z, intrinsics, homogeneous=False): 77 | fx, fy, cx, cy = parse_intrinsics(intrinsics) 78 | x_lift = (uv[0] - cx) / fx * z 79 | y_lift = (uv[1] - cy) / fy * z 80 | z_lift = ones_like(x_lift) * z 81 | 82 | if homogeneous: 83 | return stack([x_lift, y_lift, z_lift, ones_like(z_lift)]) 84 | else: 85 | return stack([x_lift, y_lift, z_lift]) 86 | 87 | 88 | def cam2world(xyz_cam, inv_RT): 89 | return matmul(inv_RT, xyz_cam)[:3] 90 | 91 | 92 | def r6d2mat(d6: torch.Tensor) -> torch.Tensor: 93 | """ 94 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 95 | using Gram--Schmidt orthogonalisation per Section B of [1]. 96 | Args: 97 | d6: 6D rotation representation, of size (*, 6) 98 | 99 | Returns: 100 | batch of rotation matrices of size (*, 3, 3) 101 | 102 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 103 | On the Continuity of Rotation Representations in Neural Networks. 104 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
105 | Retrieved from http://arxiv.org/abs/1812.07035 106 | """ 107 | 108 | a1, a2 = d6[..., :3], d6[..., 3:] 109 | b1 = F.normalize(a1, dim=-1) 110 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 111 | b2 = F.normalize(b2, dim=-1) 112 | b3 = torch.cross(b1, b2, dim=-1) 113 | return torch.stack((b1, b2, b3), dim=-2) 114 | 115 | 116 | def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None): 117 | if depths is None: 118 | depths = 1 119 | rt_cam = uv2cam(uv, depths, intrinsics, True) 120 | rt = cam2world(rt_cam, inv_RT) 121 | ray_dir, _ = normalize(rt - ray_start[:, None], axis=0) 122 | return ray_dir 123 | 124 | 125 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): 126 | """ 127 | This function takes a vector 'camera_position' which specifies the location 128 | of the camera in world coordinates and two vectors `at` and `up` which 129 | indicate the position of the object and the up directions of the world 130 | coordinate system respectively. The object is assumed to be centered at 131 | the origin. 132 | 133 | The output is a rotation matrix representing the transformation 134 | from world coordinates -> view coordinates. 135 | 136 | Input: 137 | camera_position: 3 138 | at: 1 x 3 or N x 3 (0, 0, 0) in default 139 | up: 1 x 3 or N x 3 (0, 1, 0) in default 140 | """ 141 | 142 | if at is None: 143 | at = torch.zeros_like(camera_position) 144 | else: 145 | at = torch.tensor(at).type_as(camera_position) 146 | if up is None: 147 | up = torch.zeros_like(camera_position) 148 | up[2] = -1 149 | else: 150 | up = torch.tensor(up).type_as(camera_position) 151 | 152 | z_axis = normalize(at - camera_position)[0] 153 | x_axis = normalize(cross(up, z_axis))[0] 154 | y_axis = normalize(cross(z_axis, x_axis))[0] 155 | 156 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) 157 | return R 158 | 159 | 160 | def ray(ray_start, ray_dir, depths): 161 | return ray_start + ray_dir * depths 162 | 163 | 164 | def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False): 165 | # TODO: 166 | # this function is pytorch-only (for not) 167 | wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1) 168 | cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1) 169 | cam_coords = D.unflatten_img(cam_coords, width) 170 | 171 | # estimate local normal 172 | shift_l = cam_coords[:, 2:, :] 173 | shift_r = cam_coords[:, :-2, :] 174 | shift_u = cam_coords[:, :, 2: ] 175 | shift_d = cam_coords[:, :, :-2] 176 | diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1] 177 | diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :] 178 | normal = cross(diff_hor, diff_ver) 179 | _normal = normal.new_zeros(*cam_coords.size()) 180 | _normal[:, 1:-1, 1:-1] = normal 181 | _normal = _normal.reshape(3, -1).transpose(0, 1) 182 | 183 | # compute the projected color 184 | if proj: 185 | _normal = normalize(_normal, axis=1)[0] 186 | wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1) 187 | cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1) 188 | cam_coords0 = D.unflatten_img(cam_coords0, width) 189 | cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1) 190 | proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2 191 | return proj_factor 192 | return _normal 193 | 194 | 195 | def trilinear_interp(p, q, point_feats): 196 | weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True) 197 | if point_feats.dim() == 2: 198 | point_feats = 
point_feats.view(point_feats.size(0), 8, -1) 199 | point_feats = (weights * point_feats).sum(1) 200 | return point_feats 201 | 202 | 203 | # helper functions for encoder 204 | 205 | def padding_points(xs, pad): 206 | if len(xs) == 1: 207 | return xs[0].unsqueeze(0) 208 | 209 | maxlen = max([x.size(0) for x in xs]) 210 | xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad) 211 | for i in range(len(xs)): 212 | xt[i, :xs[i].size(0)] = xs[i] 213 | return xt 214 | 215 | 216 | def pruning_points(feats, points, scores, depth=0, th=0.5): 217 | if depth > 0: 218 | g = int(8 ** depth) 219 | scores = scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True) 220 | scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1) 221 | alpha = (1 - torch.exp(-scores)) > th 222 | feats = [feats[i][alpha[i]] for i in range(alpha.size(0))] 223 | points = [points[i][alpha[i]] for i in range(alpha.size(0))] 224 | points = padding_points(points, INF) 225 | feats = padding_points(feats, 0) 226 | return feats, points 227 | 228 | 229 | def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2): 230 | c = torch.arange(1, 2 * bits, 2, device=point_xyz.device) 231 | ox, oy, oz = torch.meshgrid([c, c, c]) 232 | offset = (torch.cat([ 233 | ox.reshape(-1, 1), 234 | oy.reshape(-1, 1), 235 | oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1) 236 | if not offset_only: 237 | return point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel 238 | return offset.type_as(point_xyz) * quarter_voxel 239 | 240 | 241 | def discretize_points(voxel_points, voxel_size): 242 | # this function turns voxel centers/corners into integer indeices 243 | # we assume all points are alreay put as voxels (real numbers) 244 | minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0] 245 | voxel_indices = ((voxel_points - minimal_voxel_point) / voxel_size).round_().long() # float 246 | residual = (voxel_points - voxel_indices.type_as(voxel_points) * voxel_size).mean(0, keepdim=True) 247 | return voxel_indices, residual 248 | 249 | 250 | def splitting_points(point_xyz, point_feats, values, half_voxel): 251 | # generate new centers 252 | quarter_voxel = half_voxel * .5 253 | new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3) 254 | old_coords = discretize_points(point_xyz, quarter_voxel)[0] 255 | new_coords = offset_points(old_coords).reshape(-1, 3) 256 | new_keys0 = offset_points(new_coords).reshape(-1, 3) 257 | 258 | # get unique keys and inverse indices (for original key0, where it maps to in keys) 259 | new_keys, new_feats = torch.unique(new_keys0, dim=0, sorted=True, return_inverse=True) 260 | new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_( 261 | 0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64) 262 | 263 | # recompute key vectors using trilinear interpolation 264 | new_feats = new_feats.reshape(-1, 8) 265 | 266 | if values is not None: 267 | p = (new_keys - old_coords[new_keys_idx]).type_as(point_xyz).unsqueeze(1) * .25 + 0.5 # (1/4 voxel size) 268 | q = offset_points(p, .5, offset_only=True).unsqueeze(0) + 0.5 # BUG? 
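[annotation] `trilinear_interp` above weights the eight corner features of a cell by `(p * q + (1 - p) * (1 - q)).prod(-1)`: per axis the factor is `p` when the corner offset `q` is 1 and `1 - p` otherwise, so each corner receives the volume of the sub-cell opposite it and the weights sum to one. A single-cell sketch of that weighting; the plain unit-cell corner encoding here is a simplification of the `offset_points`-based one used above:

import torch

p = torch.tensor([[0.25, 0.50, 0.75]])                                   # query point inside the unit cell
q = torch.tensor([[float(i), float(j), float(k)]
                  for i in (0, 1) for j in (0, 1) for k in (0, 1)])      # the 8 corners, shape (8, 3)
w = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)               # (8, 1) trilinear weights
print(w.sum())                                                           # sums to 1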
269 | point_feats = point_feats[new_keys_idx] 270 | point_feats = F.embedding(point_feats, values).view(point_feats.size(0), -1) 271 | new_values = trilinear_interp(p, q, point_feats) 272 | else: 273 | new_values = None 274 | return new_points, new_feats, new_values, new_keys 275 | 276 | 277 | def expand_points(voxel_points, voxel_size): 278 | _voxel_size = min([ 279 | torch.sqrt(((voxel_points[j:j+1] - voxel_points[j+1:]) ** 2).sum(-1).min()) 280 | for j in range(100)]) 281 | depth = int(np.round(torch.log2(_voxel_size / voxel_size))) 282 | if depth > 0: 283 | half_voxel = _voxel_size / 2.0 284 | for _ in range(depth): 285 | voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3) 286 | half_voxel = half_voxel / 2.0 287 | 288 | return voxel_points, depth 289 | 290 | 291 | def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05): 292 | voxel_pts = offset_points(voxel_pts, voxel_size / 2.0) 293 | diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2) 294 | ab = diff_pts.sort(dim=1)[0][:, :2] 295 | a, b = ab[:, 0], ab[:, 1] 296 | c = voxel_size 297 | p = (ab.sum(-1) + c) / 2.0 298 | h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c 299 | return h < (th * voxel_size) 300 | 301 | 302 | # fill-in image 303 | def fill_in(shape, hits, input, initial=1.0): 304 | input_sizes = [k for k in input.size()] 305 | if (len(input_sizes) == len(shape)) and \ 306 | all([shape[i] == input_sizes[i] for i in range(len(shape))]): 307 | return input # shape is the same no need to fill 308 | 309 | if isinstance(initial, torch.Tensor): 310 | output = initial.expand(*shape) 311 | else: 312 | output = input.new_ones(*shape) * initial 313 | if input is not None: 314 | if len(shape) == 1: 315 | return output.masked_scatter(hits, input) 316 | return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input) 317 | return output 318 | 319 | 320 | def build_easy_octree(points, half_voxel): 321 | coords, residual = discretize_points(points, half_voxel) 322 | ranges = coords.max(0)[0] - coords.min(0)[0] 323 | depths = torch.log2(ranges.max().float()).ceil_().long() - 1 324 | center = (coords.max(0)[0] + coords.min(0)[0]) / 2 325 | centers, children = build_octree(center, coords, int(depths)) 326 | centers = centers.float() * half_voxel + residual # transform back to float 327 | return centers, children 328 | 329 | 330 | def cartesian_to_spherical(xyz): 331 | """ xyz: batch x 3 332 | """ 333 | r = xyz.norm(p=2, dim=-1) 334 | theta = torch.atan2(xyz[:, :2].norm(p=2, dim=-1), xyz[:, 2]) 335 | phi = torch.atan2(xyz[:, 1], xyz[:, 0]) 336 | return torch.stack((r, theta, phi), 1) 337 | 338 | 339 | def spherical_to_cartesian(rtp): 340 | x = rtp[:, 0] * torch.sin(rtp[:, 1]) * torch.cos(rtp[:, 2]) 341 | y = rtp[:, 0] * torch.sin(rtp[:, 1]) * torch.sin(rtp[:, 2]) 342 | z = rtp[:, 0] * torch.cos(rtp[:, 1]) 343 | return torch.stack((x, y, z), 1) -------------------------------------------------------------------------------- /fairnr/data/trajectory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
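[annotation] `cartesian_to_spherical` and `spherical_to_cartesian` above use the convention r = |xyz|, theta measured from the +z axis and phi = atan2(y, x), so they invert each other. A quick round-trip check, assuming the package is importable as built by setup.py:

import torch
from fairnr.data.geometry import cartesian_to_spherical, spherical_to_cartesian

xyz = torch.tensor([[1.0, 2.0, 2.0], [0.0, 3.0, -4.0]])
rtp = cartesian_to_spherical(xyz)                       # columns: (r, theta, phi)
assert torch.allclose(spherical_to_cartesian(rtp), xyz, atol=1e-5)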
5 | 6 | import torch 7 | import numpy as np 8 | 9 | TRAJECTORY_REGISTRY = {} 10 | 11 | 12 | def register_traj(name): 13 | def register_traj_fn(fn): 14 | if name in TRAJECTORY_REGISTRY: 15 | raise ValueError('Cannot register duplicate trajectory ({})'.format(name)) 16 | TRAJECTORY_REGISTRY[name] = fn 17 | return fn 18 | return register_traj_fn 19 | 20 | 21 | def get_trajectory(name): 22 | return TRAJECTORY_REGISTRY.get(name, None) 23 | 24 | 25 | @register_traj('circle') 26 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 27 | if axis == 'z': 28 | return lambda t: [radius * np.cos(r * t+t0), radius * np.sin(r * t+t0), h] 29 | elif axis == 'y': 30 | return lambda t: [radius * np.cos(r * t+t0), h, radius * np.sin(r * t+t0)] 31 | else: 32 | return lambda t: [h, radius * np.cos(r * t+t0), radius * np.sin(r * t+t0)] 33 | 34 | 35 | @register_traj('zoomin_circle') 36 | def zoomin_circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 37 | ra = lambda t: 0.1 + abs(4.0 - t * 2 / np.pi) 38 | 39 | if axis == 'z': 40 | return lambda t: [radius * ra(t) * np.cos(r * t+t0), radius * ra(t) * np.sin(r * t+t0), h] 41 | elif axis == 'y': 42 | return lambda t: [radius * ra(t) * np.cos(r * t+t0), h, radius * ra(t) * np.sin(r * t+t0)] 43 | else: 44 | return lambda t: [h, radius * (4.2 - t * 2 / np.pi) * np.cos(r * t+t0), radius * (4.2 - t * 2 / np.pi) * np.sin(r * t+t0)] 45 | 46 | 47 | @register_traj('zoomin_line') 48 | def zoomin_line(radius=3.5, h=0.0, axis='z', t0=0, r=1, min_r=0.0001, max_r=10, step_r=10): 49 | ra = lambda t: min_r + (max_r - min_r) * t * 180 / np.pi / step_r 50 | 51 | if axis == 'z': 52 | return lambda t: [radius * ra(t) * np.cos(t0), radius * ra(t) * np.sin(t0), h * ra(t)] 53 | elif axis == 'y': 54 | return lambda t: [radius * ra(t) * np.cos(t0), h, radius * ra(t) * np.sin(t0)] 55 | else: 56 | return lambda t: [h, radius * (4.2 - t * 2 / np.pi) * np.cos(r * t+t0), radius * (4.2 - t * 2 / np.pi) * np.sin(r * t+t0)] 57 | -------------------------------------------------------------------------------- /fairnr/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | # automatically import any Python files in the models/ directory 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('fairnr.models.' + model_name) 16 | -------------------------------------------------------------------------------- /fairnr/models/multi_nsvf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
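[annotation] The trajectory registry above maps a name to a factory that returns a camera-position function of t (in radians). A short usage sketch with made-up parameters, assuming the package is importable:

import numpy as np
from fairnr.data.trajectory import get_trajectory

traj = get_trajectory('circle')(radius=3.5, h=1.0, axis='z')                 # position on a circle of radius 3.5 at height 1
poses = [traj(t) for t in np.linspace(0, 2 * np.pi, 60, endpoint=False)]     # 60 camera positions on one loop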
5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | import torch 10 | 11 | from fairseq.models import ( 12 | register_model, 13 | register_model_architecture 14 | ) 15 | from fairnr.models.nsvf import NSVFModel, base_architecture 16 | 17 | 18 | @register_model('multi_nsvf') 19 | class MultiNSVFModel(NSVFModel): 20 | 21 | ENCODER = 'multi_sparsevoxel_encoder' 22 | 23 | @torch.no_grad() 24 | def split_voxels(self): 25 | logger.info("half the global voxel size {:.4f} -> {:.4f}".format( 26 | self.encoder.all_voxels[0].voxel_size.item(), 27 | self.encoder.all_voxels[0].voxel_size.item() * .5)) 28 | self.encoder.splitting() 29 | for id in range(len(self.encoder.all_voxels)): 30 | self.encoder.all_voxels[id].voxel_size *= .5 31 | self.encoder.all_voxels[id].max_hits *= 1.5 32 | 33 | @torch.no_grad() 34 | def reduce_stepsize(self): 35 | logger.info("reduce the raymarching step size {:.4f} -> {:.4f}".format( 36 | self.encoder.all_voxels[0].step_size.item(), 37 | self.encoder.all_voxels[0].step_size.item() * .5)) 38 | for id in range(len(self.encoder.all_voxels)): 39 | self.encoder.all_voxels[id].step_size *= .5 40 | 41 | 42 | @register_model("shared_nsvf") 43 | class SharedNSVFModel(MultiNSVFModel): 44 | 45 | ENCODER = 'shared_sparsevoxel_encoder' 46 | 47 | 48 | @register_model_architecture('multi_nsvf', "multi_nsvf_base") 49 | def multi_base_architecture(args): 50 | base_architecture(args) 51 | 52 | 53 | @register_model_architecture('shared_nsvf', 'shared_nsvf') 54 | def shared_base_architecture(args): 55 | # encoder 56 | args.context_embed_dim = getattr(args, "context_embed_dim", 96) 57 | 58 | # field 59 | args.inputs_to_density = getattr(args, "inputs_to_density", "emb:6:32, context:0:96") 60 | args.hypernetwork = getattr(args, "hypernetwork", False) 61 | base_architecture(args) -------------------------------------------------------------------------------- /fairnr/models/nerf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | import cv2, math, time 11 | import numpy as np 12 | from collections import defaultdict 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from fairseq.models import ( 19 | register_model, 20 | register_model_architecture 21 | ) 22 | from fairseq.utils import with_torch_seed 23 | 24 | from fairnr.models.fairnr_model import BaseModel 25 | 26 | 27 | @register_model('nerf') 28 | class NeRFModel(BaseModel): 29 | """ This is a simple re-implementation of the vanilla NeRF 30 | """ 31 | ENCODER = 'volume_encoder' 32 | READER = 'image_reader' 33 | FIELD = 'radiance_field' 34 | RAYMARCHER = 'volume_rendering' 35 | 36 | @classmethod 37 | def add_args(cls, parser): 38 | super().add_args(parser) 39 | parser.add_argument('--fixed-num-samples', type=int, 40 | help='number of samples for the first pass along the ray.') 41 | parser.add_argument('--fixed-fine-num-samples', type=int, 42 | help='sample a fixed number of points for each ray in hierarchical sampling, e.g. 
64, 128.') 43 | parser.add_argument('--reduce-fine-for-missed', action='store_true', 44 | help='if set, the number of fine samples is discounted based on foreground probability only.') 45 | 46 | def preprocessing(self, **kwargs): 47 | return self.encoder.precompute(**kwargs) 48 | 49 | def intersecting(self, ray_start, ray_dir, encoder_states, **kwargs): 50 | ray_start, ray_dir, intersection_outputs, hits = \ 51 | self.encoder.ray_intersect(ray_start, ray_dir, encoder_states) 52 | return ray_start, ray_dir, intersection_outputs, hits, None 53 | 54 | def raymarching(self, ray_start, ray_dir, intersection_outputs, encoder_states, fine=False): 55 | # sample points and use middle point approximation 56 | with with_torch_seed(self.unique_seed): # make sure each GPU sample differently. 57 | samples = self.encoder.ray_sample(intersection_outputs) 58 | field = self.field_fine if fine and (self.field_fine is not None) else self.field 59 | all_results = self.raymarcher( 60 | self.encoder, field, ray_start, ray_dir, samples, encoder_states 61 | ) 62 | return samples, all_results 63 | 64 | def prepare_hierarchical_sampling(self, intersection_outputs, samples, all_results): 65 | # this function is basically the same as that in NSVF model. 66 | depth = samples.get('original_point_depth', samples['sampled_point_depth']) 67 | dists = samples.get('original_point_distance', samples['sampled_point_distance']) 68 | intersection_outputs['min_depth'] = depth - dists * .5 69 | intersection_outputs['max_depth'] = depth + dists * .5 70 | intersection_outputs['intersected_voxel_idx'] = samples['sampled_point_voxel_idx'].contiguous() 71 | # safe_probs = all_results['probs'] + 1e-8 # HACK: make a non-zero distribution 72 | safe_probs = all_results['probs'] + 1e-5 # NeRF used 1e-5, will this make a change? 73 | intersection_outputs['probs'] = safe_probs / safe_probs.sum(-1, keepdim=True) 74 | intersection_outputs['steps'] = safe_probs.new_ones(*safe_probs.size()[:-1]) 75 | if getattr(self.args, "fixed_fine_num_samples", 0) > 0: 76 | intersection_outputs['steps'] = intersection_outputs['steps'] * self.args.fixed_fine_num_samples 77 | if getattr(self.args, "reduce_fine_for_missed", False): 78 | intersection_outputs['steps'] = intersection_outputs['steps'] * safe_probs.sum(-1) 79 | return intersection_outputs 80 | 81 | def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes): 82 | # vanilla nerf hits everything. 
so no need to fill_in 83 | S, V, P = sizes 84 | fullsize = S * V * P 85 | 86 | all_results['missed'] = all_results['missed'].view(S, V, P) 87 | all_results['colors'] = all_results['colors'].view(S, V, P, 3) 88 | all_results['depths'] = all_results['depths'].view(S, V, P) 89 | if 'z' in all_results: 90 | all_results['z'] = all_results['z'].view(S, V, P) 91 | 92 | BG_DEPTH = self.field.bg_color.depth 93 | bg_color = self.field.bg_color(all_results['colors']) 94 | all_results['colors'] += all_results['missed'].unsqueeze(-1) * bg_color.reshape(fullsize, 3).view(S, V, P, 3) 95 | all_results['depths'] += all_results['missed'] * BG_DEPTH 96 | if 'normal' in all_results: 97 | all_results['normal'] = all_results['normal'].view(S, V, P, 3) 98 | return all_results 99 | 100 | def add_other_logs(self, all_results): 101 | return {} 102 | 103 | 104 | @register_model_architecture("nerf", "nerf_base") 105 | def base_architecture(args): 106 | # parameter needs to be changed 107 | args.near = getattr(args, "near", 2) 108 | args.far = getattr(args, "far", 4) 109 | args.fixed_num_samples = getattr(args, "fixed_num_samples", 64) 110 | args.fixed_fine_num_samples = getattr(args, "fixed_fine_num_samples", 128) 111 | args.hierarchical_sampling = getattr(args, "hierarchical_sampling", True) 112 | args.use_fine_model = getattr(args, "use_fine_model", True) 113 | 114 | # field 115 | args.inputs_to_density = getattr(args, "inputs_to_density", "pos:10") 116 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4") 117 | args.feature_embed_dim = getattr(args, "feature_embed_dim", 256) 118 | args.density_embed_dim = getattr(args, "density_embed_dim", 128) 119 | args.texture_embed_dim = getattr(args, "texture_embed_dim", 256) 120 | 121 | # API Update: fix the number of layers 122 | args.feature_layers = getattr(args, "feature_layers", 1) 123 | args.texture_layers = getattr(args, "texture_layers", 3) 124 | 125 | args.background_stop_gradient = getattr(args, "background_stop_gradient", False) 126 | args.background_depth = getattr(args, "background_depth", 5.0) 127 | 128 | # raymarcher 129 | args.discrete_regularization = getattr(args, "discrete_regularization", False) 130 | args.deterministic_step = getattr(args, "deterministic_step", False) 131 | args.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0) 132 | 133 | # reader 134 | args.pixel_per_view = getattr(args, "pixel_per_view", 2048) 135 | args.sampling_on_mask = getattr(args, "sampling_on_mask", 0.0) 136 | args.sampling_at_center = getattr(args, "sampling_at_center", 1.0) 137 | args.sampling_on_bbox = getattr(args, "sampling_on_bbox", False) 138 | args.sampling_patch_size = getattr(args, "sampling_patch_size", 1) 139 | args.sampling_skipping_size = getattr(args, "sampling_skipping_size", 1) 140 | 141 | # others 142 | args.chunk_size = getattr(args, "chunk_size", 64) 143 | args.valid_chunk_size = getattr(args, "valid_chunk_size", 64) 144 | 145 | @register_model_architecture("nerf", "nerf_deep") 146 | def nerf_deep_architecture(args): 147 | args.feature_layers = getattr(args, "feature_layers", 6) 148 | args.feature_field_skip_connect = getattr(args, "feature_field_skip_connect", 3) 149 | args.no_layernorm_mlp = getattr(args, "no_layernorm_mlp", True) 150 | base_architecture(args) 151 | 152 | @register_model_architecture("nerf", "nerf_nerf") 153 | def nerf_nerf_architecture(args): 154 | args.feature_layers = getattr(args, "feature_layers", 6) 155 | args.texture_layers = getattr(args, "texture_layers", 0) 156 | 
args.feature_field_skip_connect = getattr(args, "feature_field_skip_connect", 3) 157 | args.no_layernorm_mlp = getattr(args, "no_layernorm_mlp", True) 158 | base_architecture(args) 159 | 160 | @register_model_architecture("nerf", "nerf_xyzn_nope") 161 | def nerf2_architecture(args): 162 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, pos:0:3, normal:0:3, sigma:0:1, ray:4") 163 | base_architecture(args) 164 | 165 | 166 | @register_model('sdf_nerf') 167 | class SDFNeRFModel(NeRFModel): 168 | 169 | FIELD = "sdf_radiance_field" 170 | 171 | 172 | @register_model_architecture("sdf_nerf", "sdf_nerf") 173 | def sdf_nsvf_architecture(args): 174 | args.feature_layers = getattr(args, "feature_layers", 6) 175 | args.feature_field_skip_connect = getattr(args, "feature_field_skip_connect", 3) 176 | args.no_layernorm_mlp = getattr(args, "no_layernorm_mlp", True) 177 | nerf2_architecture(args) 178 | 179 | 180 | @register_model('sg_nerf') 181 | class SGNeRFModel(NeRFModel): 182 | """ This is a simple re-implementation of the vanilla NeRF 183 | """ 184 | ENCODER = 'infinite_volume_encoder' 185 | 186 | def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes): 187 | # vanilla nerf hits everything. so no need to fill_in 188 | S, V, P = sizes 189 | all_results['missed'] = all_results['missed'].view(S, V, P) 190 | all_results['colors'] = all_results['colors'].view(S, V, P, 3) 191 | all_results['depths'] = all_results['depths'].view(S, V, P) 192 | if 'z' in all_results: 193 | all_results['z'] = all_results['z'].view(S, V, P) 194 | if 'normal' in all_results: 195 | all_results['normal'] = all_results['normal'].view(S, V, P, 3) 196 | return all_results 197 | 198 | @register_model_architecture("sg_nerf", "sg_nerf_base") 199 | def sg_nerf_architecture(args): 200 | INF_FAR = 1e6 201 | args.inputs_to_density = getattr(args, "inputs_to_density", "pos:10:4") 202 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4:3:b") 203 | args.near = getattr(args, "near", 2) 204 | args.far = getattr(args, "far", INF_FAR) 205 | base_architecture(args) 206 | 207 | 208 | @register_model_architecture("sg_nerf", "sg_nerf_new") 209 | def sg_nerf2_architecture(args): 210 | args.nerf_style_mlp = getattr(args, "nerf_style_mlp", True) 211 | args.texture_embed_dim = getattr(args, "texture_embed_dim", 128) 212 | sg_nerf_architecture(args) -------------------------------------------------------------------------------- /fairnr/models/nmf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
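# Note on the architecture functions in this file and in nerf.py above (illustration
# only): `getattr(args, name, default)` only fills in attributes that were not set
# earlier, so a derived architecture can pin a value and then call the base function
# without having it overwritten. A minimal sketch with a hypothetical namespace:
#
#     import argparse
#     args = argparse.Namespace(feature_layers=6)                 # pre-set by the derived arch
#     args.feature_layers = getattr(args, "feature_layers", 1)    # stays 6
#     args.texture_layers = getattr(args, "texture_layers", 3)    # falls back to 3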
5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | import torch 10 | from fairseq.models import ( 11 | register_model, 12 | register_model_architecture 13 | ) 14 | from fairnr.models.nsvf import NSVFModel 15 | 16 | 17 | @register_model('nmf') 18 | class NMFModel(NSVFModel): 19 | """ 20 | Experimental code: Neural Mesh Field 21 | """ 22 | ENCODER = 'triangle_mesh_encoder' 23 | 24 | @torch.no_grad() 25 | def prune_voxels(self, *args, **kwargs): 26 | pass 27 | 28 | @torch.no_grad() 29 | def split_voxels(self): 30 | pass 31 | # logger.info("half the global cage size {:.4f} -> {:.4f}".format( 32 | # self.encoder.cage_size.item(), self.encoder.cage_size.item() * .5)) 33 | # self.encoder.cage_size *= .5 34 | 35 | 36 | @register_model_architecture("nmf", "nmf_base") 37 | def base_architecture(args): 38 | # parameter needs to be changed 39 | args.max_hits = getattr(args, "max_hits", 60) 40 | args.raymarching_stepsize = getattr(args, "raymarching_stepsize", 0.01) 41 | 42 | # encoder default parameter 43 | args.voxel_embed_dim = getattr(args, "voxel_embed_dim", 0) 44 | args.voxel_path = getattr(args, "voxel_path", None) 45 | 46 | # field 47 | args.inputs_to_density = getattr(args, "inputs_to_density", "pos:10") 48 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, pos:10, ray:4") 49 | args.feature_embed_dim = getattr(args, "feature_embed_dim", 256) 50 | args.density_embed_dim = getattr(args, "density_embed_dim", 128) 51 | args.texture_embed_dim = getattr(args, "texture_embed_dim", 256) 52 | 53 | args.feature_layers = getattr(args, "feature_layers", 1) 54 | args.texture_layers = getattr(args, "texture_layers", 3) 55 | 56 | args.background_stop_gradient = getattr(args, "background_stop_gradient", False) 57 | args.background_depth = getattr(args, "background_depth", 5.0) 58 | 59 | # raymarcher 60 | args.discrete_regularization = getattr(args, "discrete_regularization", False) 61 | args.deterministic_step = getattr(args, "deterministic_step", False) 62 | args.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0) 63 | 64 | # reader 65 | args.pixel_per_view = getattr(args, "pixel_per_view", 2048) 66 | args.sampling_on_mask = getattr(args, "sampling_on_mask", 0.0) 67 | args.sampling_at_center = getattr(args, "sampling_at_center", 1.0) 68 | args.sampling_on_bbox = getattr(args, "sampling_on_bbox", False) 69 | args.sampling_patch_size = getattr(args, "sampling_patch_size", 1) 70 | args.sampling_skipping_size = getattr(args, "sampling_skipping_size", 1) 71 | 72 | # others 73 | args.chunk_size = getattr(args, "chunk_size", 64) 74 | 75 | 76 | @register_model_architecture("nmf", "nmf_nerf") 77 | def nerf_style_architecture(args): 78 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4") 79 | args.feature_layers = getattr(args, "feature_layers", 6) 80 | args.texture_layers = getattr(args, "texture_layers", 0) 81 | args.feature_field_skip_connect = getattr(args, "feature_field_skip_connect", 3) 82 | args.no_layernorm_mlp = getattr(args, "no_layernorm_mlp", True) 83 | base_architecture(args) -------------------------------------------------------------------------------- /fairnr/models/nsvf_bg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
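# Sketch of the compositing rule implemented in postprocessing() below (illustration
# only, scalar stand-ins for the per-ray tensors): the front NeRF, the sparse-voxel
# field and the back NeRF are blended front-to-back, where each "missed" value is
# the probability mass left over after that segment.
#
#     fg_c, fg_m   = 0.2, 0.5      # front field: color contribution, leftover mass
#     vox_c, vox_m = 0.6, 0.3      # voxel field
#     bg_c, bg_m   = 0.9, 1.0      # back field
#     color  = fg_c + fg_m * (vox_c + vox_m * bg_c)   # = 0.2 + 0.5 * (0.6 + 0.27) = 0.635
#     missed = fg_m * vox_m * bg_m                    # = 0.15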
5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | import cv2, math, time, copy, json 10 | import numpy as np 11 | from collections import defaultdict 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from fairseq.models import ( 18 | register_model, 19 | register_model_architecture 20 | ) 21 | from fairseq.utils import item, with_torch_seed 22 | from fairnr.data.geometry import compute_normal_map, fill_in 23 | from fairnr.models.nsvf import NSVFModel, base_architecture, nerf_style_architecture 24 | from fairnr.models.fairnr_model import get_encoder, get_field, get_reader, get_renderer 25 | 26 | @register_model('nsvf_bg') 27 | class NSVFBGModel(NSVFModel): 28 | 29 | def __init__(self, args, setups): 30 | super().__init__(args, setups) 31 | 32 | args_copy = copy.deepcopy(args) 33 | if getattr(args, "bg_field_args", None) is not None: 34 | args_copy.__dict__.update(json.loads(args.bg_field_args)) 35 | else: 36 | args_copy.inputs_to_density = "pos:10" 37 | args_copy.inputs_to_texture = "feat:0:256, ray:4:3:b" 38 | self.bg_field = get_field("radiance_field")(args_copy) 39 | self.bg_encoder = get_encoder("volume_encoder")(args_copy) 40 | 41 | @classmethod 42 | def add_args(cls, parser): 43 | super().add_args(parser) 44 | parser.add_argument('--near', type=float, help='near distance of the volume') 45 | parser.add_argument('--far', type=float, help='far distance of the volume') 46 | parser.add_argument('--nerf-steps', type=int, help='additional nerf steps') 47 | parser.add_argument('--bg-field-args', type=str, default=None, help='override args for bg field') 48 | 49 | def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes): 50 | # we will trace the background field here 51 | S, V, P = sizes 52 | fullsize = S * V * P 53 | 54 | vox_colors = fill_in((fullsize, 3), hits, all_results['colors'], 0.0) 55 | vox_missed = fill_in((fullsize, ), hits, all_results['missed'], 1.0) 56 | vox_depths = fill_in((fullsize, ), hits, all_results['depths'], 0.0) 57 | 58 | mid_dis = (self.args.near + self.args.far) / 2 59 | n_depth = fill_in((fullsize, ), hits, all_results['min_depths'], mid_dis)[:, None] 60 | f_depth = fill_in((fullsize, ), hits, all_results['max_depths'], mid_dis)[:, None] 61 | 62 | # front field 63 | nerf_step = getattr(self.args, "nerf_steps", 64) 64 | max_depth = n_depth 65 | min_depth = torch.ones_like(max_depth) * self.args.near 66 | intersection_outputs = { 67 | "min_depth": min_depth, "max_depth": max_depth, 68 | "probs": torch.ones_like(max_depth), 69 | "steps": torch.ones_like(max_depth).squeeze(-1) * nerf_step, 70 | "intersected_voxel_idx": torch.zeros_like(min_depth).int()} 71 | with with_torch_seed(self.unique_seed): 72 | fg_samples = self.bg_encoder.ray_sample(intersection_outputs) 73 | fg_results = self.raymarcher( 74 | self.bg_encoder, self.bg_field, ray_start, ray_dir, fg_samples, {}) 75 | 76 | # back field 77 | min_depth = f_depth 78 | max_depth = torch.ones_like(min_depth) * self.args.far 79 | intersection_outputs = { 80 | "min_depth": min_depth, "max_depth": max_depth, 81 | "probs": torch.ones_like(max_depth), 82 | "steps": torch.ones_like(max_depth).squeeze(-1) * nerf_step, 83 | "intersected_voxel_idx": torch.zeros_like(min_depth).int()} 84 | with with_torch_seed(self.unique_seed): 85 | bg_samples = self.bg_encoder.ray_sample(intersection_outputs) 86 | bg_results = self.raymarcher( 87 | self.bg_encoder, self.bg_field, ray_start, ray_dir, bg_samples, {}) 88 | 89 | # merge background to foreground 90 
| all_results['voxcolors'] = vox_colors.view(S, V, P, 3) 91 | all_results['colors'] = fg_results['colors'] + fg_results['missed'][:, None] * (vox_colors + vox_missed[:, None] * bg_results['colors']) 92 | all_results['depths'] = fg_results['depths'] + fg_results['missed'] * (vox_depths + vox_missed * bg_results['depths']) 93 | all_results['missed'] = fg_results['missed'] * vox_missed * bg_results['missed'] 94 | 95 | # apply the NSVF post-processing 96 | return super().postprocessing(ray_start, ray_dir, all_results, hits, sizes) 97 | 98 | def _visualize(self, images, sample, output, state, **kwargs): 99 | img_id, shape, view, width, name = state 100 | images = super()._visualize(images, sample, output, state, **kwargs) 101 | if 'voxcolors' in output and output['voxcolors'] is not None: 102 | images['{}_vcolors/{}:HWC'.format(name, img_id)] ={ 103 | 'img': output['voxcolors'][shape, view], 104 | 'min_val': float(self.args.min_color) 105 | } 106 | return images 107 | 108 | 109 | @register_model_architecture("nsvf_bg", "nsvf_bg") 110 | def base_bg_architecture(args): 111 | base_architecture(args) 112 | 113 | @register_model_architecture("nsvf_bg", "nsvf_bg_xyz") 114 | def base_bg2_architecture(args): 115 | args.nerf_steps = getattr(args, "nerf_steps", 64) 116 | nerf_style_architecture(args) 117 | 118 | 119 | @register_model('shared_nsvf_bg') 120 | class SharedNSVFBGModel(NSVFBGModel): 121 | 122 | ENCODER = 'shared_sparsevoxel_encoder' 123 | 124 | def postprocessing(self, ray_start, ray_dir, all_results, hits, sizes): 125 | # we will trace the background field here 126 | # pass context vector from NSVF to NeRF 127 | self.bg_encoder.precompute(context=self.encoder.contexts(self.encoder.cid).unsqueeze(0)) 128 | return super().postprocessing(ray_start, ray_dir, all_results, hits, sizes) 129 | 130 | @torch.no_grad() 131 | def split_voxels(self): 132 | logger.info("half the global voxel size {:.4f} -> {:.4f}".format( 133 | self.encoder.all_voxels[0].voxel_size.item(), 134 | self.encoder.all_voxels[0].voxel_size.item() * .5)) 135 | self.encoder.splitting() 136 | for id in range(len(self.encoder.all_voxels)): 137 | self.encoder.all_voxels[id].voxel_size *= .5 138 | self.encoder.all_voxels[id].max_hits *= 1.5 139 | self.clean_caches() 140 | 141 | @torch.no_grad() 142 | def reduce_stepsize(self): 143 | logger.info("reduce the raymarching step size {:.4f} -> {:.4f}".format( 144 | self.encoder.all_voxels[0].step_size.item(), 145 | self.encoder.all_voxels[0].step_size.item() * .5)) 146 | for id in range(len(self.encoder.all_voxels)): 147 | self.encoder.all_voxels[id].step_size *= .5 148 | 149 | 150 | @register_model_architecture("shared_nsvf_bg", "shared_nsvf_bg_xyz") 151 | def base_shared_architecture(args): 152 | args.context_embed_dim = getattr(args, "context_embed_dim", 96) 153 | args.hypernetwork = getattr(args, "hypernetwork", False) 154 | args.inputs_to_density = getattr(args, "inputs_to_density", "pos:10, context:0:96") 155 | args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4:3:b") 156 | args.bg_field_args = getattr(args, "bg_field_args", 157 | "{'inputs_to_density': 'pos:10, context:0:96', 'inputs_to_texture': 'feat:0:256, ray:4:3:b}'}") 158 | nerf_style_architecture(args) -------------------------------------------------------------------------------- /fairnr/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | # automatically import any Python files in the models/ directory 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('fairnr.modules.' + model_name) -------------------------------------------------------------------------------- /fairnr/modules/hyper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | ''' 8 | Pytorch implementations of hyper-network modules. 9 | This code is largely adapted from 10 | https://github.com/vsitzmann/scene-representation-networks 11 | ''' 12 | 13 | import torch 14 | import torch.nn as nn 15 | import functools 16 | 17 | from fairnr.modules.module_utils import FCBlock 18 | 19 | 20 | def partialclass(cls, *args, **kwds): 21 | 22 | class NewCls(cls): 23 | __init__ = functools.partialmethod(cls.__init__, *args, **kwds) 24 | 25 | return NewCls 26 | 27 | 28 | class LookupLayer(nn.Module): 29 | def __init__(self, in_ch, out_ch, num_objects): 30 | super().__init__() 31 | 32 | self.out_ch = out_ch 33 | self.lookup_lin = LookupLinear(in_ch, 34 | out_ch, 35 | num_objects=num_objects) 36 | self.norm_nl = nn.Sequential( 37 | nn.LayerNorm([self.out_ch], elementwise_affine=False), 38 | nn.ReLU(inplace=True) 39 | ) 40 | 41 | def forward(self, obj_idx): 42 | net = nn.Sequential( 43 | self.lookup_lin(obj_idx), 44 | self.norm_nl 45 | ) 46 | return net 47 | 48 | 49 | class LookupFC(nn.Module): 50 | def __init__(self, 51 | hidden_ch, 52 | num_hidden_layers, 53 | num_objects, 54 | in_ch, 55 | out_ch, 56 | outermost_linear=False): 57 | super().__init__() 58 | self.layers = nn.ModuleList() 59 | self.layers.append(LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)) 60 | 61 | for i in range(num_hidden_layers): 62 | self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)) 63 | 64 | if outermost_linear: 65 | self.layers.append(LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) 66 | else: 67 | self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) 68 | 69 | def forward(self, obj_idx): 70 | net = [] 71 | for i in range(len(self.layers)): 72 | net.append(self.layers[i](obj_idx)) 73 | 74 | return nn.Sequential(*net) 75 | 76 | 77 | class LookupLinear(nn.Module): 78 | def __init__(self, 79 | in_ch, 80 | out_ch, 81 | num_objects): 82 | super().__init__() 83 | self.in_ch = in_ch 84 | self.out_ch = out_ch 85 | 86 | self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch) 87 | 88 | for i in range(num_objects): 89 | nn.init.kaiming_normal_(self.hypo_params.weight.data[i, :self.in_ch * self.out_ch].view(self.out_ch, self.in_ch), 90 | a=0.0, 91 | nonlinearity='relu', 92 | mode='fan_in') 93 | self.hypo_params.weight.data[i, self.in_ch * self.out_ch:].fill_(0.) 
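        # Each embedding row above packs one full layer: the first in_ch * out_ch entries
        # hold the flattened weight matrix (kaiming-initialized per object) and the
        # remaining out_ch entries hold the bias (zero-initialized); forward() slices
        # the row back apart to build a BatchLinear.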
94 | 95 | def forward(self, obj_idx): 96 | hypo_params = self.hypo_params(obj_idx) 97 | 98 | # Indices explicit to catch erros in shape of output layer 99 | weights = hypo_params[..., :self.in_ch * self.out_ch] 100 | biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] 101 | 102 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 103 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 104 | 105 | return BatchLinear(weights=weights, biases=biases) 106 | 107 | 108 | class HyperLayer(nn.Module): 109 | '''A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU.''' 110 | def __init__(self, 111 | in_ch, 112 | out_ch, 113 | hyper_in_ch, 114 | hyper_num_hidden_layers, 115 | hyper_hidden_ch): 116 | super().__init__() 117 | 118 | self.hyper_linear = HyperLinear(in_ch=in_ch, 119 | out_ch=out_ch, 120 | hyper_in_ch=hyper_in_ch, 121 | hyper_num_hidden_layers=hyper_num_hidden_layers, 122 | hyper_hidden_ch=hyper_hidden_ch) 123 | self.norm_nl = nn.Sequential( 124 | nn.LayerNorm([out_ch], elementwise_affine=False), 125 | nn.ReLU(inplace=True) 126 | ) 127 | 128 | def forward(self, hyper_input): 129 | ''' 130 | :param hyper_input: input to hypernetwork. 131 | :return: nn.Module; predicted fully connected network. 132 | ''' 133 | return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl) 134 | 135 | 136 | class HyperFC(nn.Module): 137 | '''Builds a hypernetwork that predicts a fully connected neural network. 138 | ''' 139 | def __init__(self, 140 | hyper_in_ch, 141 | hyper_num_hidden_layers, 142 | hyper_hidden_ch, 143 | hidden_ch, 144 | num_hidden_layers, 145 | in_ch, 146 | out_ch, 147 | outermost_linear=False): 148 | super().__init__() 149 | 150 | PreconfHyperLinear = partialclass(HyperLinear, 151 | hyper_in_ch=hyper_in_ch, 152 | hyper_num_hidden_layers=hyper_num_hidden_layers, 153 | hyper_hidden_ch=hyper_hidden_ch) 154 | PreconfHyperLayer = partialclass(HyperLayer, 155 | hyper_in_ch=hyper_in_ch, 156 | hyper_num_hidden_layers=hyper_num_hidden_layers, 157 | hyper_hidden_ch=hyper_hidden_ch) 158 | 159 | self.layers = nn.ModuleList() 160 | self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch)) 161 | 162 | for i in range(num_hidden_layers): 163 | self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch)) 164 | 165 | if outermost_linear: 166 | self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch)) 167 | else: 168 | self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch)) 169 | 170 | 171 | def forward(self, hyper_input): 172 | ''' 173 | :param hyper_input: Input to hypernetwork. 174 | :return: nn.Module; Predicted fully connected neural network. 175 | ''' 176 | net = [] 177 | for i in range(len(self.layers)): 178 | net.append(self.layers[i](hyper_input)) 179 | 180 | return nn.Sequential(*net) 181 | 182 | 183 | class BatchLinear(nn.Module): 184 | def __init__(self, 185 | weights, 186 | biases): 187 | '''Implements a batch linear layer. 
188 | 189 | :param weights: Shape: (batch, out_ch, in_ch) 190 | :param biases: Shape: (batch, 1, out_ch) 191 | ''' 192 | super().__init__() 193 | 194 | self.weights = weights 195 | self.biases = biases 196 | 197 | def __repr__(self): 198 | return "BatchLinear(batch=%d, in_ch=%d, out_ch=%d)"%( 199 | self.weights.shape[0], self.weights.shape[-1], self.weights.shape[-2]) 200 | 201 | def forward(self, input): 202 | output = input.matmul(self.weights.permute(*[i for i in range(len(self.weights.shape)-2)], -1, -2)) 203 | output += self.biases 204 | return output 205 | 206 | 207 | def last_hyper_layer_init(m): 208 | if type(m) == nn.Linear: 209 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 210 | m.weight.data *= 1e-1 211 | 212 | 213 | class HyperLinear(nn.Module): 214 | '''A hypernetwork that predicts a single linear layer (weights & biases).''' 215 | def __init__(self, 216 | in_ch, 217 | out_ch, 218 | hyper_in_ch, 219 | hyper_num_hidden_layers, 220 | hyper_hidden_ch): 221 | 222 | super().__init__() 223 | self.in_ch = in_ch 224 | self.out_ch = out_ch 225 | 226 | self.hypo_params = FCBlock( 227 | in_features=hyper_in_ch, 228 | hidden_ch=hyper_hidden_ch, 229 | num_hidden_layers=hyper_num_hidden_layers, 230 | out_features=(in_ch * out_ch) + out_ch, 231 | outermost_linear=True) 232 | self.hypo_params[-1].apply(last_hyper_layer_init) 233 | 234 | def forward(self, hyper_input): 235 | hypo_params = self.hypo_params(hyper_input.cuda()) 236 | 237 | # Indices explicit to catch erros in shape of output layer 238 | weights = hypo_params[..., :self.in_ch * self.out_ch] 239 | biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] 240 | 241 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 242 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 243 | 244 | return BatchLinear(weights=weights, biases=biases) 245 | -------------------------------------------------------------------------------- /fairnr/modules/implicit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
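# Usage sketch for the hypernetwork modules imported below (illustration only, with
# made-up dimensions): calling the hypernetwork on a conditioning code first predicts
# a stack of BatchLinear layers, and that predicted MLP is then applied to the batched
# input, exactly as HyperImplicitField.forward does further down.
#
#     hyper = HyperFC(hyper_in_ch=96, hyper_num_hidden_layers=1, hyper_hidden_ch=256,
#                     hidden_ch=128, num_hidden_layers=2, in_ch=39, out_ch=256,
#                     outermost_linear=True)
#     mlp = hyper(c)                        # c: (1, 96) conditioning code
#     y = mlp(x.unsqueeze(0)).squeeze(0)    # x: (N, 39) -> y: (N, 256)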
5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from fairseq.utils import get_activation_fn 11 | from fairnr.modules.hyper import HyperFC 12 | from fairnr.modules.module_utils import FCLayer 13 | 14 | 15 | class BackgroundField(nn.Module): 16 | """ 17 | Background (we assume a uniform color) 18 | """ 19 | def __init__(self, out_dim=3, bg_color="1.0,1.0,1.0", min_color=-1, stop_grad=False, background_depth=5.0): 20 | super().__init__() 21 | 22 | if out_dim == 3: # directly model RGB 23 | bg_color = [float(b) for b in bg_color.split(',')] if isinstance(bg_color, str) else [bg_color] 24 | if min_color == -1: 25 | bg_color = [b * 2 - 1 for b in bg_color] 26 | if len(bg_color) == 1: 27 | bg_color = bg_color + bg_color + bg_color 28 | bg_color = torch.tensor(bg_color) 29 | else: 30 | bg_color = torch.ones(out_dim).uniform_() 31 | if min_color == -1: 32 | bg_color = bg_color * 2 - 1 33 | self.out_dim = out_dim 34 | self.bg_color = nn.Parameter(bg_color, requires_grad=not stop_grad) 35 | self.depth = background_depth 36 | 37 | def forward(self, x, **kwargs): 38 | return self.bg_color.unsqueeze(0).expand( 39 | *x.size()[:-1], self.out_dim) 40 | 41 | 42 | class ImplicitField(nn.Module): 43 | def __init__(self, in_dim, out_dim, hidden_dim, num_layers, 44 | outmost_linear=False, with_ln=True, skips=None, spec_init=True): 45 | super().__init__() 46 | self.skips = skips 47 | self.net = [] 48 | 49 | prev_dim = in_dim 50 | for i in range(num_layers): 51 | next_dim = out_dim if i == (num_layers - 1) else hidden_dim 52 | if (i == (num_layers - 1)) and outmost_linear: 53 | self.net.append(nn.Linear(prev_dim, next_dim)) 54 | else: 55 | self.net.append(FCLayer(prev_dim, next_dim, with_ln=with_ln)) 56 | prev_dim = next_dim 57 | if (self.skips is not None) and (i in self.skips) and (i != (num_layers - 1)): 58 | prev_dim += in_dim 59 | 60 | if num_layers > 0: 61 | self.net = nn.ModuleList(self.net) 62 | if spec_init: 63 | self.net.apply(self.init_weights) 64 | 65 | def forward(self, x): 66 | y = self.net[0](x) 67 | for i in range(len(self.net) - 1): 68 | if (self.skips is not None) and (i in self.skips): 69 | y = torch.cat((x, y), dim=-1) 70 | y = self.net[i+1](y) 71 | return y 72 | 73 | def init_weights(self, m): 74 | if type(m) == nn.Linear: 75 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 76 | 77 | 78 | class HyperImplicitField(nn.Module): 79 | 80 | def __init__(self, hyper_in_dim, in_dim, out_dim, hidden_dim, num_layers, 81 | outmost_linear=False): 82 | super().__init__() 83 | 84 | self.hyper_in_dim = hyper_in_dim 85 | self.in_dim = in_dim 86 | self.net = HyperFC( 87 | hyper_in_dim, 88 | 1, 256, 89 | hidden_dim, 90 | num_layers, 91 | in_dim, 92 | out_dim, 93 | outermost_linear=outmost_linear 94 | ) 95 | 96 | def forward(self, x, c): 97 | assert (x.size(-1) == self.in_dim) and (c.size(-1) == self.hyper_in_dim) 98 | if self.nerfpos is not None: 99 | x = torch.cat([x, self.nerfpos(x)], -1) 100 | return self.net(c)(x.unsqueeze(0)).squeeze(0) 101 | 102 | 103 | class SignedDistanceField(ImplicitField): 104 | """ 105 | Predictor for density or SDF values. 
106 | """ 107 | def __init__(self, in_dim, hidden_dim, num_layers=1, 108 | recurrent=False, with_ln=True, spec_init=True): 109 | super().__init__(in_dim, in_dim, in_dim, num_layers-1, with_ln=with_ln, spec_init=spec_init) 110 | self.recurrent = recurrent 111 | if recurrent: 112 | assert num_layers > 1 113 | self.hidden_layer = nn.LSTMCell(input_size=in_dim, hidden_size=hidden_dim) 114 | self.hidden_layer.apply(init_recurrent_weights) 115 | lstm_forget_gate_init(self.hidden_layer) 116 | else: 117 | self.hidden_layer = FCLayer(in_dim, hidden_dim, with_ln) \ 118 | if num_layers > 0 else nn.Identity() 119 | prev_dim = hidden_dim if num_layers > 0 else in_dim 120 | self.output_layer = nn.Linear(prev_dim, 1) 121 | 122 | def forward(self, x, state=None): 123 | if self.recurrent: 124 | shape = x.size() 125 | state = self.hidden_layer(x.view(-1, shape[-1]), state) 126 | if state[0].requires_grad: 127 | state[0].register_hook(lambda x: x.clamp(min=-5, max=5)) 128 | return self.output_layer(state[0].view(*shape[:-1], -1)).squeeze(-1), state 129 | else: 130 | return self.output_layer(self.hidden_layer(x)).squeeze(-1), None 131 | 132 | 133 | class TextureField(ImplicitField): 134 | """ 135 | Pixel generator based on 1x1 conv networks 136 | """ 137 | def __init__(self, in_dim, hidden_dim, num_layers, 138 | with_alpha=False, with_ln=True, spec_init=True): 139 | out_dim = 3 if not with_alpha else 4 140 | super().__init__(in_dim, out_dim, hidden_dim, num_layers, 141 | outmost_linear=True, with_ln=with_ln, spec_init=spec_init) 142 | 143 | 144 | # ------------------ # 145 | # helper functions # 146 | # ------------------ # 147 | def init_recurrent_weights(self): 148 | for m in self.modules(): 149 | if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: 150 | for name, param in m.named_parameters(): 151 | if 'weight_ih' in name: 152 | nn.init.kaiming_normal_(param.data) 153 | elif 'weight_hh' in name: 154 | nn.init.orthogonal_(param.data) 155 | elif 'bias' in name: 156 | param.data.fill_(0) 157 | 158 | 159 | def lstm_forget_gate_init(lstm_layer): 160 | for name, parameter in lstm_layer.named_parameters(): 161 | if not "bias" in name: continue 162 | n = parameter.size(0) 163 | start, end = n // 4, n // 2 164 | parameter.data[start:end].fill_(1.) 165 | 166 | 167 | def clip_grad_norm_hook(x, max_norm=10): 168 | total_norm = x.norm() 169 | total_norm = total_norm ** (1 / 2.) 170 | clip_coef = max_norm / (total_norm + 1e-6) 171 | if clip_coef < 1: 172 | return x * clip_coef -------------------------------------------------------------------------------- /fairnr/modules/module_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from fairseq.modules import LayerNorm 12 | from fairseq.utils import get_activation_fn 13 | 14 | 15 | def Linear(in_features, out_features, bias=True): 16 | m = nn.Linear(in_features, out_features, bias) 17 | nn.init.xavier_uniform_(m.weight) 18 | if bias: 19 | nn.init.constant_(m.bias, 0.0) 20 | return m 21 | 22 | 23 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 24 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 25 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 26 | return m 27 | 28 | 29 | class PosEmbLinear(nn.Module): 30 | 31 | def __init__(self, in_dim, out_dim, no_linear=False, scale=1, *args, **kwargs): 32 | super().__init__() 33 | assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" 34 | half_dim = out_dim // 2 // in_dim 35 | emb = math.log(10000) / (half_dim - 1) 36 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 37 | 38 | self.emb = nn.Parameter(emb, requires_grad=False) 39 | self.linear = Linear(out_dim, out_dim) if not no_linear else None 40 | self.scale = scale 41 | self.in_dim = in_dim 42 | self.out_dim = out_dim 43 | self.cat_input = False 44 | 45 | def forward(self, x): 46 | assert x.size(-1) == self.in_dim, "size must match" 47 | sizes = x.size() 48 | x = self.scale * x.unsqueeze(-1) @ self.emb.unsqueeze(0) 49 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 50 | x = x.view(*sizes[:-1], self.out_dim) 51 | if self.linear is not None: 52 | return self.linear(x) 53 | return x 54 | 55 | 56 | class NeRFPosEmbLinear(nn.Module): 57 | 58 | def __init__(self, in_dim, out_dim, angular=False, no_linear=False, cat_input=False): 59 | super().__init__() 60 | assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" 61 | L = out_dim // 2 // in_dim 62 | emb = torch.exp(torch.arange(L, dtype=torch.float) * math.log(2.)) 63 | if not angular: 64 | emb = emb * math.pi 65 | 66 | self.emb = nn.Parameter(emb, requires_grad=False) 67 | self.angular = angular 68 | self.linear = Linear(out_dim, out_dim) if not no_linear else None 69 | self.in_dim = in_dim 70 | self.out_dim = out_dim 71 | self.cat_input = cat_input 72 | 73 | def forward(self, x): 74 | assert x.size(-1) == self.in_dim, "size must match" 75 | sizes = x.size() 76 | inputs = x.clone() 77 | 78 | if self.angular: 79 | x = torch.acos(x.clamp(-1 + 1e-6, 1 - 1e-6)) 80 | x = x.unsqueeze(-1) @ self.emb.unsqueeze(0) 81 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 82 | x = x.view(*sizes[:-1], self.out_dim) 83 | if self.linear is not None: 84 | x = self.linear(x) 85 | if self.cat_input: 86 | x = torch.cat([x, inputs], -1) 87 | return x 88 | 89 | def extra_repr(self) -> str: 90 | outstr = 'Sinusoidal (in={}, out={}, angular={})'.format( 91 | self.in_dim, self.out_dim, self.angular) 92 | if self.cat_input: 93 | outstr = 'Cat({}, {})'.format(outstr, self.in_dim) 94 | return outstr 95 | 96 | 97 | class FCLayer(nn.Module): 98 | """ 99 | Reference: 100 | https://github.com/vsitzmann/pytorch_prototyping/blob/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py 101 | """ 102 | def __init__(self, in_dim, out_dim, with_ln=True): 103 | super().__init__() 104 | self.net = [nn.Linear(in_dim, out_dim)] 105 | if with_ln: 106 | self.net += [nn.LayerNorm([out_dim])] 107 | self.net += [nn.ReLU()] 108 | self.net = nn.Sequential(*self.net) 109 | 110 | def forward(self, x): 111 | return self.net(x) 112 | 113 | 114 | class 
FCBlock(nn.Module): 115 | def __init__(self, 116 | hidden_ch, 117 | num_hidden_layers, 118 | in_features, 119 | out_features, 120 | outermost_linear=False, 121 | with_ln=True): 122 | super().__init__() 123 | 124 | self.net = [] 125 | self.net.append(FCLayer(in_features, hidden_ch, with_ln)) 126 | for i in range(num_hidden_layers): 127 | self.net.append(FCLayer(hidden_ch, hidden_ch, with_ln)) 128 | if outermost_linear: 129 | self.net.append(Linear(hidden_ch, out_features)) 130 | else: 131 | self.net.append(FCLayer(hidden_ch, out_features, with_ln)) 132 | self.net = nn.Sequential(*self.net) 133 | self.net.apply(self.init_weights) 134 | 135 | def __getitem__(self, item): 136 | return self.net[item] 137 | 138 | def init_weights(self, m): 139 | if type(m) == nn.Linear: 140 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 141 | 142 | def forward(self, input): 143 | return self.net(input) 144 | 145 | 146 | class InvertableMapping(nn.Module): 147 | def __init__(self, style='simple'): 148 | super().__init__() 149 | self.style = style 150 | 151 | def f(self, x): # (0, 1) --> (0, +inf) 152 | if self.style == 'simple': 153 | return x / (1 - x + 1e-7) 154 | raise NotImplementedError 155 | 156 | def g(self, y): # (0, +inf) --> (0, 1) 157 | if self.style == 'simple': 158 | return y / (1 + y) 159 | raise NotImplementedError 160 | 161 | def dy(self, x): 162 | if self.style == 'simple': 163 | return 1 / ((1 - x) ** 2 + 1e-7) 164 | raise NotImplementedError -------------------------------------------------------------------------------- /fairnr/modules/reader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
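# Sketch of the Gumbel-top-k trick used by sampling_without_replacement() at the
# bottom of this file (illustration only): adding Gumbel noise to log-probabilities
# and taking the top-k indices draws k distinct samples with probability roughly
# proportional to the given weights.
#
#     probs = torch.tensor([[0.7, 0.2, 0.1]])
#     noise = -torch.log(-torch.log(torch.rand_like(probs) + 1e-9) + 1e-9)
#     idx = (torch.log(probs + 1e-9) + noise).topk(2, dim=-1)[1]   # 2 distinct indices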
5 | 6 | import torch 7 | import torch.nn as nn 8 | import random, os, glob 9 | 10 | from fairnr.data.geometry import get_ray_direction, r6d2mat 11 | 12 | torch.autograd.set_detect_anomaly(True) 13 | TINY = 1e-9 14 | READER_REGISTRY = {} 15 | 16 | def register_reader(name): 17 | def register_reader_cls(cls): 18 | if name in READER_REGISTRY: 19 | raise ValueError('Cannot register duplicate module ({})'.format(name)) 20 | READER_REGISTRY[name] = cls 21 | return cls 22 | return register_reader_cls 23 | 24 | 25 | def get_reader(name): 26 | if name not in READER_REGISTRY: 27 | raise ValueError('Cannot find module {}'.format(name)) 28 | return READER_REGISTRY[name] 29 | 30 | 31 | @register_reader('abstract_reader') 32 | class Reader(nn.Module): 33 | def __init__(self, args): 34 | super().__init__() 35 | self.args = args 36 | 37 | def forward(self, **kwargs): 38 | raise NotImplementedError 39 | 40 | @staticmethod 41 | def add_args(parser): 42 | pass 43 | 44 | 45 | @register_reader('image_reader') 46 | class ImageReader(Reader): 47 | """ 48 | basic image reader 49 | """ 50 | def __init__(self, args): 51 | super().__init__(args) 52 | self.num_pixels = args.pixel_per_view 53 | self.no_sampling = getattr(args, "no_sampling_at_reader", False) 54 | self.deltas = None 55 | self.all_data = self.find_data() 56 | if getattr(args, "trainable_extrinsics", False): 57 | self.all_data_idx = {data_img: (s, v) 58 | for s, data in enumerate(self.all_data) 59 | for v, data_img in enumerate(data)} 60 | self.deltas = nn.ParameterList([ 61 | nn.Parameter(torch.tensor( 62 | [[1., 0., 0., 0., 1., 0., 0., 0., 0.]]).repeat(len(data), 1)) 63 | for data in self.all_data]) 64 | 65 | def find_data(self): 66 | paths = self.args.data 67 | if os.path.isdir(paths): 68 | self.paths = [paths] 69 | else: 70 | self.paths = [line.strip() for line in open(paths)] 71 | return [sorted(glob.glob("{}/rgb/*".format(p))) for p in self.paths] 72 | 73 | @staticmethod 74 | def add_args(parser): 75 | parser.add_argument('--pixel-per-view', type=float, metavar='N', 76 | help='number of pixels sampled for each view') 77 | parser.add_argument("--sampling-on-mask", nargs='?', const=0.9, type=float, 78 | help="this value determined the probability of sampling rays on masks") 79 | parser.add_argument("--sampling-at-center", type=float, 80 | help="only useful for training where we restrict sampling at center of the image") 81 | parser.add_argument("--sampling-on-bbox", action='store_true', 82 | help="sampling points to close to the mask") 83 | parser.add_argument("--sampling-patch-size", type=int, 84 | help="sample pixels based on patches instead of independent pixels") 85 | parser.add_argument("--sampling-skipping-size", type=int, 86 | help="sample pixels if we have skipped pixels") 87 | parser.add_argument("--no-sampling-at-reader", action='store_true', 88 | help="do not perform sampling.") 89 | parser.add_argument("--trainable-extrinsics", action='store_true', 90 | help="if set, we assume extrinsics are trainable. 
We use 6D representations for rotation") 91 | 92 | def forward(self, uv, intrinsics, extrinsics, size, path=None, **kwargs): 93 | S, V = uv.size()[:2] 94 | if (not self.training) or self.no_sampling: 95 | uv = uv.reshape(S, V, 2, -1, 1, 1) 96 | flatten_uv = uv.reshape(S, V, 2, -1) 97 | else: 98 | uv, _ = self.sample_pixels(uv, size, **kwargs) 99 | flatten_uv = uv.reshape(S, V, 2, -1) 100 | 101 | # go over all shapes 102 | ray_start, ray_dir = [[] for _ in range(S)], [[] for _ in range(S)] 103 | for s in range(S): 104 | for v in range(V): 105 | ixt = intrinsics[s] if intrinsics.dim() == 3 else intrinsics[s, v] 106 | ext = extrinsics[s, v] 107 | translation, rotation = ext[:3, 3], ext[:3, :3] 108 | if (self.deltas is not None) and (path is not None): 109 | shape_id, view_id = self.all_data_idx[path[s][v]] 110 | delta = self.deltas[shape_id][view_id] 111 | d_t, d_r = delta[6:], r6d2mat(delta[None, :6]).squeeze(0) 112 | rotation = rotation @ d_r 113 | translation = translation + d_t 114 | ext = torch.cat([torch.cat([rotation, translation[:, None]], 1), ext[3:]], 0) 115 | ray_start[s] += [translation] 116 | ray_dir[s] += [get_ray_direction(translation, flatten_uv[s, v], ixt, ext, 1)] 117 | ray_start = torch.stack([torch.stack(r) for r in ray_start]) 118 | ray_dir = torch.stack([torch.stack(r) for r in ray_dir]) 119 | return ray_start.unsqueeze(-2), ray_dir.transpose(2, 3), uv 120 | 121 | @torch.no_grad() 122 | def sample_pixels(self, uv, size, alpha=None, mask=None, **kwargs): 123 | H, W = int(size[0,0,0]), int(size[0,0,1]) 124 | S, V = uv.size()[:2] 125 | 126 | if mask is None: 127 | if alpha is not None: 128 | mask = (alpha > 0) 129 | else: 130 | mask = uv.new_ones(S, V, uv.size(-1)).bool() 131 | mask = mask.float().reshape(S, V, H, W) 132 | 133 | if self.args.sampling_at_center < 1.0: 134 | r = (1 - self.args.sampling_at_center) / 2.0 135 | mask0 = mask.new_zeros(S, V, H, W) 136 | mask0[:, :, int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1 137 | mask = mask * mask0 138 | 139 | if self.args.sampling_on_bbox: 140 | x_has_points = mask.sum(2, keepdim=True) > 0 141 | y_has_points = mask.sum(3, keepdim=True) > 0 142 | mask = (x_has_points & y_has_points).float() 143 | 144 | probs = mask / (mask.sum() + 1e-8) 145 | if self.args.sampling_on_mask > 0.0: 146 | probs = self.args.sampling_on_mask * probs + (1 - self.args.sampling_on_mask) * 1.0 / (H * W) 147 | 148 | num_pixels = int(self.args.pixel_per_view) 149 | patch_size, skip_size = self.args.sampling_patch_size, self.args.sampling_skipping_size 150 | C = patch_size * skip_size 151 | 152 | if C > 1: 153 | probs = probs.reshape(S, V, H // C, C, W // C, C).sum(3).sum(-1) 154 | num_pixels = num_pixels // patch_size // patch_size 155 | 156 | flatten_probs = probs.reshape(S, V, -1) 157 | sampled_index = sampling_without_replacement(torch.log(flatten_probs+ TINY), num_pixels) 158 | sampled_masks = torch.zeros_like(flatten_probs).scatter_(-1, sampled_index, 1).reshape(S, V, H // C, W // C) 159 | 160 | if C > 1: 161 | sampled_masks = sampled_masks[:, :, :, None, :, None].repeat( 162 | 1, 1, 1, patch_size, 1, patch_size).reshape(S, V, H // skip_size, W // skip_size) 163 | if skip_size > 1: 164 | full_datamask = sampled_masks.new_zeros(S, V, skip_size * skip_size, H // skip_size, W // skip_size) 165 | full_index = torch.randint(skip_size*skip_size, (S, V)) 166 | for i in range(S): 167 | for j in range(V): 168 | full_datamask[i, j, full_index[i, j]] = sampled_masks[i, j] 169 | sampled_masks = full_datamask.reshape( 170 | S, V, skip_size, 
skip_size, H // skip_size, W // skip_size).permute(0, 1, 4, 2, 5, 3).reshape(S, V, H, W) 171 | 172 | X, Y = uv[:,:,0].reshape(S, V, H, W), uv[:,:,1].reshape(S, V, H, W) 173 | X = X[sampled_masks>0].reshape(S, V, 1, -1, patch_size, patch_size) 174 | Y = Y[sampled_masks>0].reshape(S, V, 1, -1, patch_size, patch_size) 175 | return torch.cat([X, Y], 2), sampled_masks 176 | 177 | 178 | def sampling_without_replacement(logp, k): 179 | def gumbel_like(u): 180 | return -torch.log(-torch.log(torch.rand_like(u) + TINY) + TINY) 181 | scores = logp + gumbel_like(logp) 182 | return scores.topk(k, dim=-1)[1] -------------------------------------------------------------------------------- /fairnr/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import sys 8 | import torch 9 | 10 | 11 | from fairseq import options 12 | 13 | 14 | def parse_args_and_arch(*args, **kwargs): 15 | return options.parse_args_and_arch(*args, **kwargs) 16 | 17 | 18 | def get_rendering_parser(default_task="single_object_rendering"): 19 | parser = options.get_parser("Rendering", default_task) 20 | options.add_dataset_args(parser, gen=True) 21 | add_rendering_args(parser) 22 | return parser 23 | 24 | 25 | def add_rendering_args(parser): 26 | group = parser.add_argument_group("Rendering") 27 | options.add_common_eval_args(group) 28 | group.add_argument("--render-beam", default=5, type=int, metavar="N", 29 | help="beam size for parallel rendering") 30 | group.add_argument("--render-resolution", default="512x512", type=str, metavar="N", help='if provide two numbers, means H x W') 31 | group.add_argument("--render-angular-speed", default=1, type=float, metavar="D", 32 | help="angular speed when rendering around the object") 33 | group.add_argument("--render-num-frames", default=500, type=int, metavar="N") 34 | group.add_argument("--render-path-style", default="circle", choices=["circle", "zoomin_circle", "zoomin_line"], type=str) 35 | group.add_argument("--render-path-args", default="{'radius': 2.5, 'h': 0.0}", 36 | help="specialized arguments for rendering paths") 37 | group.add_argument("--render-output", default=None, type=str) 38 | group.add_argument("--render-at-vector", default="(0,0,0)", type=str) 39 | group.add_argument("--render-up-vector", default="(0,0,-1)", type=str) 40 | group.add_argument("--render-output-types", nargs="+", type=str, default=["color"], 41 | choices=["target", "color", "depth", "normal", "voxel", "predn", "point", "featn2", "vcolors"]) 42 | group.add_argument("--render-raymarching-steps", default=None, type=int) 43 | group.add_argument("--render-save-fps", default=24, type=int) 44 | group.add_argument("--render-combine-output", action='store_true', 45 | help="if set, concat the images into one file.") 46 | group.add_argument("--render-camera-poses", default=None, type=str, 47 | help="text file saved for the testing trajectories") 48 | group.add_argument("--render-camera-intrinsics", default=None, type=str) 49 | group.add_argument("--render-views", type=str, default=None, 50 | help="views sampled for rendering, you can set specific view id, or a range") -------------------------------------------------------------------------------- /fairnr/renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | This file is to simulate "generator" in fairseq 8 | """ 9 | 10 | import os, tempfile, shutil, glob 11 | import time 12 | import torch 13 | import numpy as np 14 | import logging 15 | import imageio 16 | 17 | from torchvision.utils import save_image 18 | from fairnr.data import trajectory, geometry, data_utils 19 | from fairseq.meters import StopwatchMeter 20 | from fairnr.data.data_utils import recover_image, get_uv, parse_views 21 | from pathlib import Path 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class NeuralRenderer(object): 27 | 28 | def __init__(self, 29 | resolution="512x512", 30 | frames=501, 31 | speed=5, 32 | raymarching_steps=None, 33 | path_gen=None, 34 | beam=10, 35 | at=(0,0,0), 36 | up=(0,1,0), 37 | output_dir=None, 38 | output_type=None, 39 | fps=24, 40 | test_camera_poses=None, 41 | test_camera_intrinsics=None, 42 | test_camera_views=None): 43 | 44 | self.frames = frames 45 | self.speed = speed 46 | self.raymarching_steps = raymarching_steps 47 | self.path_gen = path_gen 48 | 49 | if isinstance(resolution, str): 50 | self.resolution = [int(r) for r in resolution.split('x')] 51 | else: 52 | self.resolution = [resolution, resolution] 53 | 54 | self.beam = beam 55 | self.output_dir = output_dir 56 | self.output_type = output_type 57 | self.at = at 58 | self.up = up 59 | self.fps = fps 60 | 61 | if self.path_gen is None: 62 | self.path_gen = trajectory.circle() 63 | if self.output_type is None: 64 | self.output_type = ["rgb"] 65 | 66 | if test_camera_intrinsics is not None: 67 | self.test_int = data_utils.load_intrinsics(test_camera_intrinsics) 68 | else: 69 | self.test_int = None 70 | 71 | self.test_frameids = None 72 | if test_camera_poses is not None: 73 | if os.path.isdir(test_camera_poses): 74 | self.test_poses = [ 75 | np.loadtxt(f)[None, :, :] for f in sorted(glob.glob(test_camera_poses + "/*.txt"))] 76 | self.test_poses = np.concatenate(self.test_poses, 0) 77 | else: 78 | self.test_poses = data_utils.load_matrix(test_camera_poses) 79 | if self.test_poses.shape[1] == 17: 80 | self.test_frameids = self.test_poses[:, -1].astype(np.int32) 81 | self.test_poses = self.test_poses[:, :-1] 82 | self.test_poses = self.test_poses.reshape(-1, 4, 4) 83 | 84 | if test_camera_views is not None: 85 | render_views = parse_views(test_camera_views) 86 | self.test_poses = np.stack([self.test_poses[r] for r in render_views]) 87 | 88 | else: 89 | self.test_poses = None 90 | 91 | def generate_rays(self, t, intrinsics, img_size, inv_RT=None, action='none'): 92 | if inv_RT is None: 93 | cam_pos = torch.tensor(self.path_gen(t * self.speed / 180 * np.pi), 94 | device=intrinsics.device, dtype=intrinsics.dtype) 95 | cam_rot = geometry.look_at_rotation(cam_pos, at=self.at, up=self.up, inverse=True, cv=True) 96 | 97 | inv_RT = cam_pos.new_zeros(4, 4) 98 | inv_RT[:3, :3] = cam_rot 99 | inv_RT[:3, 3] = cam_pos 100 | inv_RT[3, 3] = 1 101 | else: 102 | inv_RT = torch.from_numpy(inv_RT).type_as(intrinsics) 103 | 104 | h, w, rh, rw = img_size[0], img_size[1], img_size[2], img_size[3] 105 | if self.test_int is not None: 106 | uv = torch.from_numpy(get_uv(h, w, h, w)[0]).type_as(intrinsics) 107 | intrinsics = self.test_int 108 | else: 109 | uv = torch.from_numpy(get_uv(h * rh, w * rw, h, w)[0]).type_as(intrinsics) 110 | 111 | uv = uv.reshape(2, -1) 112 | return uv, inv_RT 113 | 114 | def 
parse_sample(self,sample): 115 | if len(sample) == 1: 116 | return sample[0], 0, self.frames 117 | elif len(sample) == 2: 118 | return sample[0], sample[1], self.frames 119 | elif len(sample) == 3: 120 | return sample[0], sample[1], sample[2] 121 | else: 122 | raise NotImplementedError 123 | 124 | @torch.no_grad() 125 | def generate(self, models, sample, **kwargs): 126 | model = models[0] 127 | model.eval() 128 | 129 | logger.info("rendering starts. {}".format(model.text)) 130 | output_path = self.output_dir 131 | image_names = [] 132 | sample, step, frames = self.parse_sample(sample) 133 | 134 | # fix the rendering size 135 | a = sample['size'][0,0,0] / self.resolution[0] 136 | b = sample['size'][0,0,1] / self.resolution[1] 137 | sample['size'][:, :, 0] /= a 138 | sample['size'][:, :, 1] /= b 139 | sample['size'][:, :, 2] *= a 140 | sample['size'][:, :, 3] *= b 141 | 142 | for shape in range(sample['shape'].size(0)): 143 | max_step = step + frames 144 | while step < max_step: 145 | next_step = min(step + self.beam, max_step) 146 | uv, inv_RT = zip(*[ 147 | self.generate_rays( 148 | k, 149 | sample['intrinsics'][shape], 150 | sample['size'][shape, 0], 151 | self.test_poses[k] if self.test_poses is not None else None) 152 | for k in range(step, next_step) 153 | ]) 154 | if self.test_frameids is not None: 155 | assert next_step - step == 1 156 | ids = torch.tensor(self.test_frameids[step: next_step]).type_as(sample['id']) 157 | else: 158 | ids = sample['id'][shape:shape+1] 159 | 160 | real_images = sample['full_rgb'] if 'full_rgb' in sample else sample['colors'] 161 | real_images = real_images.transpose(2, 3) if real_images.size(-1) != 3 else real_images 162 | 163 | _sample = { 164 | 'id': ids, 165 | 'colors': torch.cat([real_images[shape:shape+1] for _ in range(step, next_step)], 1), 166 | 'intrinsics': sample['intrinsics'][shape:shape+1], 167 | 'extrinsics': torch.stack(inv_RT, 0).unsqueeze(0), 168 | 'uv': torch.stack(uv, 0).unsqueeze(0), 169 | 'shape': sample['shape'][shape:shape+1], 170 | 'view': torch.arange( 171 | step, next_step, 172 | device=sample['shape'].device).unsqueeze(0), 173 | 'size': torch.cat([sample['size'][shape:shape+1] for _ in range(step, next_step)], 1), 174 | 'step': step 175 | } 176 | with data_utils.GPUTimer() as timer: 177 | outs = model(**_sample) 178 | logger.info("rendering frame={}\ttotal time={:.4f}".format(step, timer.sum)) 179 | 180 | for k in range(step, next_step): 181 | images = model.visualize(_sample, None, 0, k-step) 182 | image_name = "{:04d}".format(k) 183 | 184 | for key in images: 185 | name, type = key.split('/')[0].split('_') 186 | if type in self.output_type: 187 | if name == 'coarse': 188 | type = 'coarse-' + type 189 | if name == 'target': 190 | continue 191 | 192 | prefix = os.path.join(output_path, type) 193 | Path(prefix).mkdir(parents=True, exist_ok=True) 194 | if type == 'point': 195 | data_utils.save_point_cloud( 196 | os.path.join(prefix, image_name + '.ply'), 197 | images[key][:, :3].cpu().numpy(), 198 | (images[key][:, 3:] * 255).cpu().int().numpy()) 199 | # from fairseq import pdb; pdb.set_trace() 200 | 201 | else: 202 | image = images[key].permute(2, 0, 1) \ 203 | if images[key].dim() == 3 else torch.stack(3*[images[key]], 0) 204 | save_image(image, os.path.join(prefix, image_name + '.png'), format=None) 205 | image_names.append(os.path.join(prefix, image_name + '.png')) 206 | 207 | # save pose matrix 208 | prefix = os.path.join(output_path, 'pose') 209 | Path(prefix).mkdir(parents=True, exist_ok=True) 210 | pose = 
self.test_poses[k] if self.test_poses is not None else inv_RT[k-step].cpu().numpy() 211 | np.savetxt(os.path.join(prefix, image_name + '.txt'), pose) 212 | 213 | step = next_step 214 | 215 | logger.info("done") 216 | return step, image_names 217 | 218 | def save_images(self, output_files, steps=None, combine_output=True): 219 | if not os.path.exists(self.output_dir): 220 | os.mkdir(self.output_dir) 221 | timestamp = time.strftime('%Y-%m-%d.%H-%M-%S',time.localtime(time.time())) 222 | if steps is not None: 223 | timestamp = "step_{}.".format(steps) + timestamp 224 | 225 | if not combine_output: 226 | for type in self.output_type: 227 | images = [imageio.imread(file_path) for file_path in output_files if type in file_path] 228 | # imageio.mimsave('{}/{}_{}.gif'.format(self.output_dir, type, timestamp), images, fps=self.fps) 229 | imageio.mimwrite('{}/{}_{}.mp4'.format(self.output_dir, type, timestamp), images, fps=self.fps, quality=8) 230 | else: 231 | images = [[imageio.imread(file_path) for file_path in output_files if type == file_path.split('/')[-2]] for type in self.output_type] 232 | images = [np.concatenate([images[j][i] for j in range(len(images))], 1) for i in range(len(images[0]))] 233 | imageio.mimwrite('{}/{}_{}.mp4'.format(self.output_dir, 'full', timestamp), images, fps=self.fps, quality=8) 234 | 235 | return timestamp 236 | 237 | def merge_videos(self, timestamps): 238 | logger.info("merging mp4 files..") 239 | timestamp = time.strftime('%Y-%m-%d.%H-%M-%S',time.localtime(time.time())) 240 | writer = imageio.get_writer( 241 | os.path.join(self.output_dir, 'full_' + timestamp + '.mp4'), fps=self.fps) 242 | for timestamp in timestamps: 243 | tempfile = os.path.join(self.output_dir, 'full_' + timestamp + '.mp4') 244 | reader = imageio.get_reader(tempfile) 245 | for im in reader: 246 | writer.append_data(im) 247 | writer.close() 248 | for timestamp in timestamps: 249 | tempfile = os.path.join(self.output_dir, 'full_' + timestamp + '.mp4') 250 | os.remove(tempfile) -------------------------------------------------------------------------------- /fairnr/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | for file in os.listdir(os.path.dirname(__file__)): 10 | if file.endswith('.py') and not file.startswith('_'): 11 | task_name = file[:file.find('.py')] 12 | importlib.import_module('fairnr.tasks.' + task_name) 13 | -------------------------------------------------------------------------------- /fairnr_cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /fairnr_cli/extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
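# NOTE (illustrative example, not part of the original file): the flags below are the ones
# defined in cli_main() further down; the checkpoint and output paths are placeholders.
# A hypothetical invocation of this extraction CLI might look like:
#
#   python -m fairnr_cli.extract \
#       --user-dir fairnr \
#       --path /path/to/checkpoint_last.pt \
#       --output /path/to/mesh_output \
#       --name sparsevoxel \
#       --format mc_mesh \
#       --mc-num-samples-per-halfvoxel 8 \
#       --mc-threshold 0.5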
6 | """ 7 | This code is used to extract voxels/meshes from the learned model 8 | """ 9 | import logging 10 | import numpy as np 11 | import torch 12 | import sys, os 13 | import argparse 14 | 15 | from fairseq import options 16 | from fairseq import checkpoint_utils 17 | from plyfile import PlyData, PlyElement 18 | 19 | 20 | def main(args): 21 | logging.basicConfig( 22 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 23 | datefmt='%Y-%m-%d %H:%M:%S', 24 | level=logging.INFO, 25 | stream=sys.stdout, 26 | ) 27 | logger = logging.getLogger('fairnr_cli.extract') 28 | logger.info(args) 29 | 30 | use_cuda = torch.cuda.is_available() and not args.cpu 31 | models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( 32 | [args.path], suffix=getattr(args, "checkpoint_suffix", "")) 33 | model = models[0] 34 | if use_cuda: 35 | model.cuda() 36 | 37 | if args.format == 'mc_mesh': 38 | plydata = model.encoder.export_surfaces( 39 | model.field, th=args.mc_threshold, 40 | bits=2 * args.mc_num_samples_per_halfvoxel) 41 | elif args.format == 'voxel_center': 42 | plydata = model.encoder.export_voxels(False) 43 | elif args.format == 'voxel_mesh': 44 | plydata = model.encoder.export_voxels(True) 45 | else: 46 | raise NotImplementedError 47 | 48 | # write to ply file. 49 | if not os.path.exists(args.output): 50 | os.makedirs(args.output) 51 | plydata.text = args.savetext 52 | plydata.write(open(os.path.join(args.output, args.name + '.ply'), 'wb')) 53 | 54 | 55 | def cli_main(): 56 | parser = argparse.ArgumentParser(description='Extract geometry from a trained model (only for learnable embeddings).') 57 | parser.add_argument('--path', type=str, required=True) 58 | parser.add_argument('--output', type=str, required=True) 59 | parser.add_argument('--name', type=str, default='sparsevoxel') 60 | parser.add_argument('--format', type=str, choices=['voxel_center', 'voxel_mesh', 'mc_mesh']) 61 | parser.add_argument('--savetext', action='store_true', help='save .ply in plain text') 62 | parser.add_argument('--mc-num-samples-per-halfvoxel', type=int, default=8, 63 | help="""the number of point samples per half voxel-size for marching cubes. 64 | For instance, by setting to 8, it will use (8 x 2) ^ 3 = 4096 points to compute density for each voxel. 65 | In practice, the larger this number is, the more accurate the surface you get. 66 | """) 67 | parser.add_argument('--mc-threshold', type=float, default=0.5, 68 | help="""the threshold used to find the isosurface from the learned implicit field. 69 | In our implementation, we define our values as ``1 - exp(-max(0, density))`` 70 | where "0" is empty and "1" is fully occupied. 71 | """) 72 | parser.add_argument('--user-dir', default='fairnr') 73 | parser.add_argument('--cpu', action='store_true') 74 | args = options.parse_args_and_arch(parser) 75 | main(args) 76 | 77 | 78 | if __name__ == '__main__': 79 | cli_main() 80 | -------------------------------------------------------------------------------- /fairnr_cli/launch_slurm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
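# NOTE (illustrative example, not part of the original file): this launcher is normally
# driven by train.py at the repository root, which eval()s the SLURM_ARGS environment
# variable into a dict and forwards it to launch() below together with the remaining
# command-line arguments. A hypothetical dry run, using keys that launch_cluster() reads
# and placeholder paths, might look like:
#
#   SLURM_ARGS="{'job-name': 'nsvf_wineholder', 'partition': 'learnfair', \
#                'nodes': 1, 'gpus': 8, 'output': 'logs/train.log', \
#                'error': 'logs/train.stderr', 'dry-run': True}" \
#   python train.py <usual fairnr training arguments>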
6 | import random, shlex 7 | import os, sys, subprocess 8 | 9 | 10 | def launch_cluster(slurm_args, model_args): 11 | # prepare 12 | jobname = slurm_args.get('job-name', 'test') 13 | train_log = slurm_args.get('output', None) 14 | train_stderr = slurm_args.get('error', None) 15 | nodes, gpus = slurm_args.get('nodes', 1), slurm_args.get('gpus', 8) 16 | if not slurm_args.get('local', False): 17 | assert (train_log is not None) and (train_stderr is not None) 18 | 19 | # parse slurm 20 | train_cmd = ['python', 'train.py', ] 21 | train_cmd.extend(['--distributed-world-size', str(nodes * gpus)]) 22 | if nodes > 1: 23 | train_cmd.extend(['--distributed-port', str(get_random_port())]) 24 | 25 | train_cmd += model_args 26 | 27 | base_srun_cmd = [ 28 | 'srun', 29 | '--job-name', jobname, 30 | '--output', train_log, 31 | '--error', train_stderr, 32 | '--open-mode', 'append', 33 | '--unbuffered', 34 | ] 35 | srun_cmd = base_srun_cmd + train_cmd 36 | srun_cmd_str = ' '.join(map(shlex.quote, srun_cmd)) 37 | srun_cmd_str = srun_cmd_str + ' &' 38 | 39 | sbatch_cmd = [ 40 | 'sbatch', 41 | '--job-name', jobname, 42 | '--partition', slurm_args.get('partition', 'learnfair'), 43 | '--gres', 'gpu:volta:{}'.format(gpus), 44 | '--nodes', str(nodes), 45 | '--ntasks-per-node', '1', 46 | '--cpus-per-task', '48', 47 | '--output', train_log, 48 | '--error', train_stderr, 49 | '--open-mode', 'append', 50 | '--signal', 'B:USR1@180', 51 | '--time', slurm_args.get('time', '4320'), 52 | '--mem', slurm_args.get('mem', '500gb'), 53 | '--exclusive', 54 | ] 55 | if 'constraint' in slurm_args: 56 | sbatch_cmd += ['-C', slurm_args.get('constraint')] 57 | if 'comment' in slurm_args: 58 | sbatch_cmd += ['--comment', slurm_args.get('comment')] 59 | 60 | wrapped_cmd = requeue_support() + '\n' + srun_cmd_str + ' \n wait $! \n sleep 610 & \n wait $!' 
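    # Note on the wrapped command above: it installs the USR1/TERM trap handlers from
    # requeue_support(), runs srun in the background and wait()s on it so the
    # "--signal B:USR1@180" notification (sent to the batch shell about 180 seconds
    # before the time limit) can interrupt the wait and trigger "scontrol requeue";
    # the trailing "sleep 610 & wait" presumably keeps the batch script alive long
    # enough for the requeue to be processed before the allocation ends.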
61 | sbatch_cmd += ['--wrap', wrapped_cmd] 62 | sbatch_cmd_str = ' '.join(map(shlex.quote, sbatch_cmd)) 63 | 64 | # start training 65 | env = os.environ.copy() 66 | env['OMP_NUM_THREADS'] = '2' 67 | if env.get('SLURM_ARGS', None) is not None: 68 | del env['SLURM_ARGS'] 69 | 70 | if nodes > 1: 71 | env['NCCL_SOCKET_IFNAME'] = '^docker0,lo' 72 | env['NCCL_DEBUG'] = 'INFO' 73 | 74 | if slurm_args.get('dry-run', False): 75 | print(sbatch_cmd_str) 76 | 77 | elif slurm_args.get('local', False): 78 | assert nodes == 1, 'distributed training cannot be combined with local' 79 | if 'CUDA_VISIBLE_DEVICES' not in env: 80 | env['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(gpus))) 81 | env['NCCL_DEBUG'] = 'INFO' 82 | 83 | if train_log is not None: 84 | train_proc = subprocess.Popen(train_cmd, env=env, stdout=subprocess.PIPE) 85 | tee_proc = subprocess.Popen(['tee', '-a', train_log], stdin=train_proc.stdout) 86 | train_proc.stdout.close() 87 | train_proc.wait() 88 | tee_proc.wait() 89 | else: 90 | train_proc = subprocess.Popen(train_cmd, env=env) 91 | train_proc.wait() 92 | else: 93 | with open(train_log, 'a') as train_log_h: 94 | print(f'running command: {sbatch_cmd_str}\n') 95 | with subprocess.Popen(sbatch_cmd, stdout=subprocess.PIPE, env=env) as train_proc: 96 | stdout = train_proc.stdout.read().decode('utf-8') 97 | print(stdout, file=train_log_h) 98 | try: 99 | job_id = int(stdout.rstrip().split()[-1]) 100 | return job_id 101 | except IndexError: 102 | return None 103 | 104 | 105 | def launch(slurm_args, model_args): 106 | job_id = launch_cluster(slurm_args, model_args) 107 | if job_id is not None: 108 | print('Launched {}'.format(job_id)) 109 | else: 110 | print('Failed.') 111 | 112 | 113 | def requeue_support(): 114 | return """ 115 | trap_handler () { 116 | echo "Caught signal: " $1 117 | # SIGTERM must be bypassed 118 | if [ "$1" = "TERM" ]; then 119 | echo "bypass sigterm" 120 | else 121 | # Submit a new job to the queue 122 | echo "Requeuing " $SLURM_JOB_ID 123 | scontrol requeue $SLURM_JOB_ID 124 | fi 125 | } 126 | 127 | 128 | # Install signal handler 129 | trap 'trap_handler USR1' USR1 130 | trap 'trap_handler TERM' TERM 131 | """ 132 | 133 | 134 | def get_random_port(): 135 | old_state = random.getstate() 136 | random.seed() 137 | port = random.randint(10000, 20000) 138 | random.setstate(old_state) 139 | return port 140 | -------------------------------------------------------------------------------- /fairnr_cli/render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | This is a copy of fairseq-generate while simpler for other usage. 8 | """ 9 | 10 | 11 | import logging 12 | import math 13 | import os 14 | import sys 15 | import time 16 | import torch 17 | import imageio 18 | import numpy as np 19 | 20 | from fairseq import checkpoint_utils, progress_bar, tasks, utils 21 | from fairseq.meters import StopwatchMeter, TimeMeter 22 | from fairnr import options 23 | 24 | 25 | def main(args): 26 | assert args.path is not None, '--path required for generation!' 
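    # When --results-path is given, the log output of _main() is redirected into
    # generate-<gen_subset>.txt inside that directory; otherwise it goes to stdout.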
27 | 28 | if args.results_path is not None: 29 | os.makedirs(args.results_path, exist_ok=True) 30 | output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset)) 31 | with open(output_path, 'w', buffering=1) as h: 32 | return _main(args, h) 33 | else: 34 | return _main(args, sys.stdout) 35 | 36 | 37 | def _main(args, output_file): 38 | logging.basicConfig( 39 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 40 | datefmt='%Y-%m-%d %H:%M:%S', 41 | level=logging.INFO, 42 | stream=output_file, 43 | ) 44 | logger = logging.getLogger('fairnr_cli.render') 45 | 46 | utils.import_user_module(args) 47 | 48 | if args.max_tokens is None and args.max_sentences is None: 49 | args.max_tokens = 12000 50 | logger.info(args) 51 | 52 | use_cuda = torch.cuda.is_available() and not args.cpu 53 | 54 | # Load dataset splits 55 | task = tasks.setup_task(args) 56 | task.load_dataset(args.gen_subset) 57 | 58 | 59 | # Load ensemble 60 | logger.info('loading model(s) from {}'.format(args.path)) 61 | models, _model_args = checkpoint_utils.load_model_ensemble( 62 | args.path.split(os.pathsep), 63 | arg_overrides=eval(args.model_overrides), 64 | task=task, 65 | ) 66 | 67 | # Optimize ensemble for generation 68 | for model in models: 69 | if args.fp16: 70 | model.half() 71 | if use_cuda: 72 | model.cuda() 73 | 74 | # Load dataset (possibly sharded) 75 | itr = task.get_batch_iterator( 76 | dataset=task.dataset(args.gen_subset), 77 | max_tokens=args.max_tokens, 78 | max_sentences=args.max_sentences, 79 | max_positions=utils.resolve_max_positions( 80 | task.max_positions(), 81 | *[model.max_positions() for model in models] 82 | ), 83 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 84 | required_batch_size_multiple=args.required_batch_size_multiple, 85 | num_shards=args.num_shards, 86 | shard_id=args.shard_id, 87 | num_workers=args.num_workers, 88 | ).next_epoch_itr(shuffle=False) 89 | 90 | # Initialize generator 91 | gen_timer = StopwatchMeter() 92 | generator = task.build_generator(args) 93 | 94 | 95 | output_files, step= [], 0 96 | with progress_bar.build_progress_bar(args, itr) as t: 97 | wps_meter = TimeMeter() 98 | for i, sample in enumerate(t): 99 | sample = utils.move_to_cuda(sample) if use_cuda else sample 100 | gen_timer.start() 101 | 102 | step, _output_files = task.inference_step(generator, models, [sample, step]) 103 | output_files += _output_files 104 | 105 | gen_timer.stop(500) 106 | wps_meter.update(500) 107 | t.log({'wps': round(wps_meter.avg)}) 108 | 109 | break 110 | # if i > 5: 111 | # break 112 | 113 | generator.save_images(output_files, combine_output=args.render_combine_output) 114 | 115 | def cli_main(): 116 | parser = options.get_rendering_parser() 117 | args = options.parse_args_and_arch(parser) 118 | main(args) 119 | 120 | 121 | if __name__ == '__main__': 122 | cli_main() 123 | -------------------------------------------------------------------------------- /fairnr_cli/render_multigpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | This is a copy of fairseq-generate while simpler for other usage. 
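Rendering is sharded across GPUs: each rank renders a contiguous block of frames
derived from its distributed rank and world size, and rank 0 merges the per-shard
videos at the end.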
8 | """ 9 | 10 | 11 | import logging 12 | import math 13 | import os 14 | import sys 15 | import time 16 | import torch 17 | import imageio 18 | import numpy as np 19 | 20 | from fairseq import checkpoint_utils, progress_bar, tasks, utils, distributed_utils 21 | from fairseq.meters import StopwatchMeter, TimeMeter 22 | from fairseq.options import add_distributed_training_args 23 | from fairnr import options 24 | 25 | 26 | def main(args, *kwargs): 27 | assert args.path is not None, '--path required for generation!' 28 | 29 | if args.results_path is not None: 30 | os.makedirs(args.results_path, exist_ok=True) 31 | output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset)) 32 | with open(output_path, 'w', buffering=1) as h: 33 | return _main(args, h) 34 | else: 35 | return _main(args, sys.stdout) 36 | 37 | 38 | def _main(args, output_file): 39 | logging.basicConfig( 40 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 41 | datefmt='%Y-%m-%d %H:%M:%S', 42 | level=logging.INFO, 43 | stream=output_file, 44 | ) 45 | logger = logging.getLogger('fairnr_cli.render') 46 | 47 | utils.import_user_module(args) 48 | 49 | if args.max_tokens is None and args.max_sentences is None: 50 | args.max_tokens = 12000 51 | logger.info(args) 52 | 53 | use_cuda = torch.cuda.is_available() and not args.cpu 54 | 55 | # Load dataset splits 56 | task = tasks.setup_task(args) 57 | task.load_dataset(args.gen_subset) 58 | 59 | 60 | # Load ensemble 61 | logger.info('loading model(s) from {}'.format(args.path)) 62 | models, _model_args = checkpoint_utils.load_model_ensemble( 63 | args.path.split(os.pathsep), 64 | arg_overrides=eval(args.model_overrides), 65 | task=task, 66 | ) 67 | 68 | # Optimize ensemble for generation 69 | for model in models: 70 | if args.fp16: 71 | model.half() 72 | if use_cuda: 73 | model.cuda() 74 | 75 | logging.info(model) 76 | 77 | # Load dataset (possibly sharded) 78 | itr = task.get_batch_iterator( 79 | dataset=task.dataset(args.gen_subset), 80 | max_tokens=args.max_tokens, 81 | max_sentences=args.max_sentences, 82 | max_positions=utils.resolve_max_positions( 83 | task.max_positions(), 84 | *[model.max_positions() for model in models] 85 | ), 86 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 87 | required_batch_size_multiple=args.required_batch_size_multiple, 88 | seed=args.seed, 89 | num_workers=args.num_workers 90 | ).next_epoch_itr(shuffle=False) 91 | 92 | # Initialize generator 93 | gen_timer = StopwatchMeter() 94 | generator = task.build_generator(args) 95 | shard_id, world_size = args.distributed_rank, args.distributed_world_size 96 | output_files = [] 97 | if generator.test_poses is not None: 98 | total_frames = generator.test_poses.shape[0] 99 | _frames = int(np.floor(total_frames / world_size)) 100 | step = shard_id * _frames 101 | frames = _frames if shard_id < (world_size - 1) else total_frames - step 102 | else: 103 | step = shard_id * args.render_num_frames 104 | frames = args.render_num_frames 105 | 106 | with progress_bar.build_progress_bar(args, itr) as t: 107 | wps_meter = TimeMeter() 108 | for i, sample in enumerate(t): 109 | sample = utils.move_to_cuda(sample) if use_cuda else sample 110 | gen_timer.start() 111 | 112 | step, _output_files = task.inference_step( 113 | generator, models, [sample, step, frames]) 114 | output_files += _output_files 115 | 116 | gen_timer.stop(500) 117 | wps_meter.update(500) 118 | t.log({'wps': round(wps_meter.avg)}) 119 | 120 | timestamp = generator.save_images( 121 | output_files, 
steps='shard{}'.format(shard_id), combine_output=args.render_combine_output) 122 | 123 | # join videos from all GPUs and delete temp files 124 | try: 125 | timestamps = distributed_utils.all_gather_list(timestamp) 126 | except: 127 | timestamps = [timestamp] 128 | 129 | if shard_id == 0: 130 | generator.merge_videos(timestamps) 131 | 132 | def cli_main(): 133 | parser = options.get_rendering_parser() 134 | add_distributed_training_args(parser) 135 | args = options.parse_args_and_arch(parser) 136 | 137 | distributed_utils.call_main(args, main) 138 | 139 | 140 | if __name__ == '__main__': 141 | cli_main() 142 | -------------------------------------------------------------------------------- /fairnr_cli/validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import sys 9 | 10 | import numpy as np 11 | import torch 12 | from itertools import chain 13 | from fairseq import checkpoint_utils, distributed_utils, options, utils 14 | from fairseq.logging import metrics, progress_bar 15 | from fairseq.options import add_distributed_training_args 16 | 17 | logging.basicConfig( 18 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 19 | datefmt='%Y-%m-%d %H:%M:%S', 20 | level=logging.INFO, 21 | stream=sys.stdout, 22 | ) 23 | logger = logging.getLogger('fairnr_cli.validate') 24 | 25 | 26 | def main(args, override_args=None): 27 | utils.import_user_module(args) 28 | 29 | assert args.max_tokens is not None or args.max_sentences is not None, \ 30 | 'Must specify batch size either with --max-tokens or --max-sentences' 31 | 32 | use_fp16 = args.fp16 33 | use_cuda = torch.cuda.is_available() and not args.cpu 34 | 35 | if override_args is not None: 36 | try: 37 | override_args = override_args['override_args'] 38 | except TypeError: 39 | override_args = override_args 40 | overrides = vars(override_args) 41 | overrides.update(eval(getattr(override_args, 'model_overrides', '{}'))) 42 | else: 43 | overrides = None 44 | 45 | # Load ensemble 46 | logger.info('loading model(s) from {}'.format(args.path)) 47 | models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( 48 | [args.path], 49 | arg_overrides=overrides, 50 | suffix=getattr(args, "checkpoint_suffix", ""), 51 | ) 52 | model = models[0] 53 | 54 | # Move models to GPU 55 | for model in models: 56 | if use_fp16: 57 | model.half() 58 | if use_cuda: 59 | model.cuda() 60 | 61 | # Print args 62 | logger.info(model_args) 63 | 64 | # Build criterion 65 | criterion = task.build_criterion(model_args) 66 | if use_fp16: 67 | criterion.half() 68 | if use_cuda: 69 | criterion.cuda() 70 | criterion.eval() 71 | 72 | for subset in args.valid_subset.split(','): 73 | try: 74 | task.load_dataset(subset, combine=False, epoch=1) 75 | dataset = task.dataset(subset) 76 | except KeyError: 77 | raise Exception('Cannot find dataset: ' + subset) 78 | 79 | # Initialize data iterator 80 | itr = task.get_batch_iterator( 81 | dataset=dataset, 82 | max_tokens=args.max_tokens, 83 | max_sentences=args.max_sentences, 84 | max_positions=utils.resolve_max_positions( 85 | task.max_positions(), 86 | *[m.max_positions() for m in models], 87 | ), 88 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 89 | required_batch_size_multiple=args.required_batch_size_multiple, 90 | 
seed=args.seed, 91 | num_workers=args.num_workers, 92 | num_shards=args.distributed_world_size, 93 | shard_id=args.distributed_rank 94 | ).next_epoch_itr(shuffle=False) 95 | 96 | progress = progress_bar.progress_bar( 97 | itr, 98 | log_format=args.log_format, 99 | log_interval=args.log_interval, 100 | prefix=f"valid on '{subset}' subset", 101 | default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), 102 | ) 103 | 104 | log_outputs = [] 105 | for i, sample in enumerate(progress): 106 | sample = utils.move_to_cuda(sample) if use_cuda else sample 107 | sample = utils.apply_to_sample( 108 | lambda t: t.half() if t.dtype is torch.float32 else t, sample) if use_fp16 else sample 109 | try: 110 | with torch.no_grad(): # do not save backward passes 111 | max_num_rays = 900 * 900 112 | if sample['uv'].shape[3] > max_num_rays: 113 | sample['ray_split'] = sample['uv'].shape[3] // max_num_rays 114 | _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) 115 | 116 | progress.log(log_output, step=i) 117 | log_outputs.append(log_output) 118 | 119 | except TypeError: 120 | break 121 | 122 | with metrics.aggregate() as agg: 123 | task.reduce_metrics(log_outputs, criterion) 124 | log_output = agg.get_smoothed_values() 125 | 126 | 127 | # summarize all the gpus 128 | if args.distributed_world_size > 1: 129 | all_log_output = list(zip(*distributed_utils.all_gather_list([log_output])))[0] 130 | log_output = { 131 | key: np.mean([log[key] for log in all_log_output]) 132 | for key in all_log_output[0] 133 | } 134 | 135 | progress.print(log_output, tag=subset, step=i) 136 | 137 | 138 | 139 | def cli_main(): 140 | parser = options.get_validation_parser() 141 | args = options.parse_args_and_arch(parser) 142 | 143 | # only override args that are explicitly given on the command line 144 | override_parser = options.get_validation_parser() 145 | override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) 146 | 147 | # support multi-gpu validation, use all available gpus 148 | default_world_size = max(1, torch.cuda.device_count()) 149 | if args.distributed_world_size < default_world_size: 150 | args.distributed_world_size = default_world_size 151 | override_args.distributed_world_size = default_world_size 152 | 153 | distributed_utils.call_main(args, main, override_args=override_args) 154 | 155 | 156 | if __name__ == '__main__': 157 | cli_main() 158 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
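# Thin wrapper: the actual rendering CLI lives in fairnr_cli/render_multigpu.py;
# this script only calls its cli_main().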
6 | 7 | from fairnr_cli.render_multigpu import cli_main 8 | 9 | 10 | if __name__ == '__main__': 11 | cli_main() 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open3d==0.10.0 2 | opencv_python==4.2.0.32 3 | tqdm==4.43.0 4 | pandas==0.25.3 5 | imageio==2.6.1 6 | scikit_image==0.16.2 7 | scipy==1.4.1 8 | plyfile==0.7.1 9 | matplotlib==3.1.2 10 | numpy==1.16.4 11 | mathutils==2.81.2 12 | tensorboardX==2.0 13 | imageio-ffmpeg==0.4.2 14 | git+https://github.com/MultiPath/fairseq-stable.git 15 | git+https://github.com/MultiPath/lpips-pytorch.git -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | 10 | # build clib 11 | # _ext_src_root = "fairnr/clib" 12 | import os 13 | _ext_src_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fairnr/clib") 14 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 15 | "{}/src/*.cu".format(_ext_src_root) 16 | ) 17 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 18 | 19 | setup( 20 | name='fairnr', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='fairnr.clib._ext', 24 | sources=_ext_sources, 25 | extra_compile_args={ 26 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 27 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 28 | }, 29 | ) 30 | ], 31 | cmdclass={ 32 | 'build_ext': BuildExtension 33 | }, 34 | entry_points={ 35 | 'console_scripts': [ 36 | 'fairnr-render = fairnr_cli.render:cli_main', 37 | 'fairnr-train = fairseq_cli.train:cli_main' 38 | ], 39 | }, 40 | ) 41 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import sys, os 7 | from fairnr_cli.train import cli_main 8 | from fairnr_cli.launch_slurm import launch 9 | 10 | if __name__ == '__main__': 11 | if os.getenv('SLURM_ARGS') is not None: 12 | slurm_arg = eval(os.getenv('SLURM_ARGS')) 13 | all_args = sys.argv[1:] 14 | 15 | print(slurm_arg) 16 | print(all_args) 17 | launch(slurm_arg, all_args) 18 | 19 | else: 20 | cli_main() 21 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from fairnr_cli.validate import cli_main 8 | 9 | 10 | if __name__ == '__main__': 11 | cli_main() 12 | --------------------------------------------------------------------------------
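A rough end-to-end workflow using the repository-root wrappers above (a sketch only: dataset and model flags are omitted because they come from the fairseq/fairnr option parsers rather than the files shown here, the checkpoint path is a placeholder, and the shell scripts under examples/ contain concrete argument lists):

    # training, optionally through SLURM by exporting SLURM_ARGS (see fairnr_cli/launch_slurm.py)
    python train.py <dataset-and-model-args>
    # validation of a trained checkpoint
    python validate.py <dataset-args> --path checkpoint_last.pt
    # rendering frames/videos from the checkpoint (multi-GPU aware)
    python render.py <dataset-args> --path checkpoint_last.pt
    # geometry export; assumes the root extract.py dispatches to fairnr_cli.extract like the other wrappers
    python extract.py --path checkpoint_last.pt --output ./outputs --format mc_mesh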