├── .clang-format ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── benchmark.py ├── configs ├── __init__.py ├── db.yaml ├── mipnerf360_indoor.yaml └── mipnerf360_outdoor.yaml ├── data_loader ├── __init__.py ├── blender.py └── colmap.py ├── external ├── CMakeLists.txt ├── gl3w │ └── gl3w.c └── include │ ├── GL │ ├── gl3w.h │ └── glcorearb.h │ └── KHR │ └── khrplatform.h ├── prepare_colmap_data.py ├── pyproject.toml ├── radfoam_model ├── __init__.py ├── render.py ├── scene.py └── utils.py ├── requirements.txt ├── scripts ├── cmake_clean.sh └── torch_info.py ├── setup.cfg ├── setup.py ├── src ├── CMakeLists.txt ├── aabb_tree │ ├── aabb_tree.cu │ ├── aabb_tree.cuh │ └── aabb_tree.h ├── delaunay │ ├── delaunay.cu │ ├── delaunay.cuh │ ├── delaunay.h │ ├── delete_violations.cu │ ├── exact_tree_ops.cuh │ ├── growth_iteration.cu │ ├── predicate.cuh │ ├── sample_initial_tets.cu │ ├── shewchuk.cuh │ ├── sorted_map.cuh │ ├── triangulation_ops.cu │ └── triangulation_ops.h ├── tracing │ ├── camera.h │ ├── pipeline.cu │ ├── pipeline.h │ ├── sh_utils.cuh │ └── tracing_utils.cuh ├── utils │ ├── batch_fetcher.cpp │ ├── batch_fetcher.h │ ├── common_kernels.cuh │ ├── cuda_array.h │ ├── cuda_helpers.h │ ├── geometry.h │ ├── random.h │ ├── typing.h │ └── unenumerate_iterator.cuh └── viewer │ ├── viewer.cpp │ └── viewer.h ├── teaser.jpg ├── test.py ├── torch_bindings ├── CMakeLists.txt ├── bindings.h ├── pipeline_bindings.cpp ├── pipeline_bindings.h ├── radfoam │ └── __init__.py.in ├── torch_bindings.cpp ├── triangulation_bindings.cpp └── triangulation_bindings.h ├── train.py └── viewer.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | UseTab: Never 3 | TabWidth: 4 4 | IndentWidth: 4 5 | ColumnLimit: 80 6 | AlwaysBreakTemplateDeclarations: Yes 7 | BinPackArguments: false 8 | BinPackParameters: false 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | /build*/ 3 | /dist/ 4 | /data 5 | /radfoam/ 6 | /*output* 7 | 8 | *.egg-info 9 | __pycache__ 10 | /*nsight 11 | *.mp4b 12 | *ipynb 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/submodules/glfw"] 2 | path = external/submodules/glfw 3 | url = https://github.com/glfw/glfw 4 | [submodule "external/submodules/eigen"] 5 | path = external/submodules/eigen 6 | url = https://gitlab.com/libeigen/eigen.git 7 | [submodule "external/submodules/imgui"] 8 | path = external/submodules/imgui 9 | url = https://github.com/ocornut/imgui 10 | [submodule "external/submodules/atomic_queue"] 11 | path = external/submodules/atomic_queue 12 | url = https://github.com/max0x7ba/atomic_queue 13 | [submodule "external/submodules/mesa"] 14 | path = external/submodules/mesa 15 | url = https://gitlab.freedesktop.org/mesa/mesa.git 16 | [submodule "external/submodules/tbb"] 17 | path = external/submodules/tbb 18 | url = https://github.com/uxlfoundation/oneTBB 19 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27) 2 | project(RadFoam VERSION 1.0.0) 3 | 4 | cmake_policy(SET CMP0060 NEW) 5 | 6 | enable_language(CUDA) 7 | 
set(CMAKE_CXX_STANDARD 17) 8 | set(CMAKE_CUDA_STANDARD 17) 9 | set(CMAKE_C_EXTENSIONS OFF) 10 | set(CMAKE_CXX_EXTENSIONS OFF) 11 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 12 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 13 | 14 | set(GPU_DEBUG 15 | ON 16 | CACHE BOOL "Enable GPU debug features") 17 | add_definitions(-DGPU_DEBUG=$<BOOL:${GPU_DEBUG}>) 18 | if(NOT CMAKE_BUILD_TYPE) 19 | set(CMAKE_BUILD_TYPE 20 | "Release" 21 | CACHE STRING "Build type") 22 | endif() 23 | 24 | set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) 25 | list(APPEND CMAKE_PREFIX_PATH "${CMAKE_SOURCE_DIR}/external" 26 | "${CMAKE_SOURCE_DIR}/external/submodules") 27 | 28 | find_package( 29 | Python3 30 | COMPONENTS Interpreter Development.Module 31 | REQUIRED) 32 | 33 | find_package(pybind11 REQUIRED) 34 | 35 | if(NOT Torch_DIR) 36 | set(Torch_DIR ${Python3_SITELIB}/torch/share/cmake/Torch) 37 | endif() 38 | 39 | find_package(Torch REQUIRED) 40 | find_library(TORCH_PYTHON_LIBRARY torch_python PATHS 41 | "${TORCH_INSTALL_PREFIX}/lib") 42 | 43 | add_subdirectory(external) 44 | include_directories(${RADFOAM_EXTERNAL_INCLUDES}) 45 | 46 | if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) 47 | set(CMAKE_INSTALL_PREFIX 48 | "${CMAKE_SOURCE_DIR}/radfoam" 49 | CACHE PATH "..." FORCE) 50 | endif() 51 | 52 | set(RADFOAM_INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX}) 53 | 54 | add_subdirectory(src) 55 | add_subdirectory(torch_bindings) 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below).
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2025 The Radiant Foam Authors 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Radiant Foam: Real-Time Differentiable Ray Tracing 2 | 3 | ![](teaser.jpg) 4 | 5 | ## Shrisudhan Govindarajan, Daniel Rebain, Kwang Moo Yi, Andrea Tagliasacchi 6 | 7 | This repository contains the official implementation of [Radiant Foam: Real-Time Differentiable Ray Tracing](https://radfoam.github.io). 
8 | 9 | The code includes scripts for training and evaluation, as well as a real-time viewer that can be used to visualize trained models, or optionally to observe the progression of models as they train. Everything in this repository is non-final and subject to change as the project is still being actively developed. **We encourage anyone citing our results to do so as RadFoam (vx), where x is the version specified for those metrics in the paper or tagged to a commit on GitHub.** This should hopefully reduce confusion. 10 | 11 | Warning: this is an organic, free-range research codebase, and should be treated with the appropriate care when integrating it into any other software. 12 | 13 | ## Known issues 14 | - GPU memory usage can be high for scenes with many points. You may need to reduce the `final_points` setting to train outdoor scenes on a 24GB GPU. This will hopefully be improved in the future. 15 | - Best PSNR is achieved with the default softplus density activation, but it also causes an increase in volumetric artifacts. Using the exponential activation may result in qualitatively better renders. We are planning to add configuration options for this. 16 | - The Delaunay triangulation is not perfectly robust, and relies on random perturbation of points and iterative retries to attempt to recover from failures. Training may stall for long periods when this occurs. 17 | 18 | ## Getting started 19 | 20 | Start by cloning the repository and submodules: 21 | 22 | git clone --recursive https://github.com/theialab/radfoam 23 | 24 | You will need a Linux environment with Python 3.10 or newer, as well as version 12.x of the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) and a CUDA-compatible GPU of Compute Capability 7.0 or higher. Please ensure that your installation method for CUDA places `nvcc` in your `PATH`. The following instructions were tested with Ubuntu 24.04. 25 | 26 | After installing the CUDA Toolkit and initializing your Python virtual environment, install PyTorch 2.3 or newer. For example, with CUDA 12.1: 27 | 28 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121 29 | 30 | From here, there are two options: 31 | 32 | ### Option 1: build with `pip install` 33 | 34 | Choose this option if you want to run the code as-is, and do not need to make modifications to the CUDA/C++ code. 35 | 36 | Simply run `pip install .` in the repository root. This will build the CUDA kernels and install them along with the Python bindings into your Python environment. This may take some time to complete, but once finished, you should be able to run the code without further setup. 37 | 38 | Optionally, if you want to install with frozen versions of the required packages, you can do so by running `pip install -r requirements.txt` before running `pip install .` 39 | 40 | ### Option 2: build with CMake 41 | 42 | Choose this option if you intend to modify the CUDA/C++ code. Using CMake directly will allow you to quickly recompile the kernels as needed. 43 | 44 | First install the Python dependencies: 45 | 46 | pip install -r requirements.txt 47 | 48 | 49 | Then, create a `build` directory in the repository root and run the following commands from it to initialize CMake and build the bindings library: 50 | 51 | cmake .. 52 | make install 53 | 54 | This will install to a local `radfoam` directory in the repository root. Recompilation can be performed by re-running `make install` in the build directory.
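If CMake does not pick up the PyTorch installation from your virtual environment on its own, you can point it at the right copy by passing `Torch_DIR` explicitly when configuring (a minimal sketch, assuming `torch` is importable from the active interpreter; `torch.utils.cmake_prefix_path` reports where PyTorch's CMake config files live):

    cmake -DTorch_DIR=$(python -c "import torch; print(torch.utils.cmake_prefix_path)")/Torch ..
    make install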
55 | 56 | ### Training 57 | 58 | Place the [Mip-NeRF 360](https://jonbarron.info/mipnerf360) and [Deep Blending](https://github.com/Phog/DeepBlending) datasets in `data/mipnerf360` and `data/db`. 59 | 60 | Training can then be launched with: 61 | 62 | python train.py -c configs/<config>.yaml 63 | 64 | Where `<config>` is either one of the supplied files in the `configs` directory or your own. 65 | You can optionally include the `--viewer` flag to train interactively, or use the `viewer.py` script to view saved checkpoints. 66 | 67 | ### Evaluation 68 | 69 | The standard test metrics can be computed with: 70 | 71 | python test.py -c outputs/<experiment>/config.yaml 72 | 73 | Rendering speed can be computed with: 74 | 75 | python benchmark.py -c outputs/<experiment>/config.yaml 76 | 77 | ### Checkpoints 78 | 79 | You can find trained checkpoints, as well as COLMAP output for some scenes [here](https://drive.google.com/drive/folders/1o8ulZORogwjrfsz3E-QY3f-oPjVFrEVI?usp=drive_link). 80 | 81 | ## BibTeX 82 | 83 | @article{govindarajan2025radfoam, 84 | author = {Govindarajan, Shrisudhan and Rebain, Daniel and Yi, Kwang Moo and Tagliasacchi, Andrea}, 85 | title = {Radiant Foam: Real-Time Differentiable Ray Tracing}, 86 | journal = {arXiv:2502.01157}, 87 | year = {2025}, 88 | } 89 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import configargparse 5 | import warnings 6 | 7 | warnings.filterwarnings("ignore") 8 | 9 | import torch 10 | 11 | from data_loader import DataHandler 12 | from configs import * 13 | from radfoam_model.scene import RadFoamScene 14 | import radfoam 15 | 16 | 17 | seed = 42 18 | torch.random.manual_seed(seed) 19 | np.random.seed(seed) 20 | 21 | 22 | def benchmark(args, pipeline_args, model_args, optimizer_args, dataset_args): 23 | checkpoint = args.config.replace("/config.yaml", "") 24 | os.makedirs(os.path.join(checkpoint, "test"), exist_ok=True) 25 | device = torch.device(args.device) 26 | 27 | test_data_handler = DataHandler( 28 | dataset_args, rays_per_batch=0, device=device 29 | ) 30 | test_data_handler.reload( 31 | split="test", downsample=min(dataset_args.downsample) 32 | ) 33 | 34 | # Setting up model 35 | model = RadFoamScene( 36 | args=model_args, device=device, attr_dtype=torch.float16 37 | ) 38 | 39 | model.load_pt(f"{checkpoint}/model.pt") 40 | 41 | points, attributes, point_adjacency, point_adjacency_offsets = ( 42 | model.get_trace_data() 43 | ) 44 | self_point_inds = torch.zeros_like(point_adjacency.long()) 45 | scatter_inds = point_adjacency_offsets[1:-1].long() 46 | self_point_inds.scatter_add_(0, scatter_inds, torch.ones_like(scatter_inds)) 47 | self_point_inds = torch.cumsum(self_point_inds, dim=0) 48 | self_points = points[self_point_inds] 49 | 50 | adjacent_points = points[point_adjacency.long()] 51 | adjacent_offsets = adjacent_points - self_points 52 | adjacent_offsets = torch.cat( 53 | [adjacent_offsets, torch.zeros_like(adjacent_offsets[:, :1])], dim=1 54 | ).to(torch.half) 55 | 56 | c2w = test_data_handler.c2ws 57 | width, height = test_data_handler.img_wh 58 | fy = test_data_handler.fy 59 | 60 | cameras = [] 61 | positions = [] 62 | 63 | for i in range(c2w.shape[0]): 64 | if i % 8 == 0: 65 | position = c2w[i, :3, 3].contiguous() 66 | fov = float(2 * np.arctan(height / (2 * fy))) 67 | 68 | right = c2w[i, :3, 0].contiguous() 69 | up = -c2w[i, :3, 1].contiguous() 70 | forward = c2w[i, :3,
2].contiguous() 71 | 72 | positions.append(position) 73 | 74 | camera = { 75 | "position": position, 76 | "forward": forward, 77 | "right": right, 78 | "up": up, 79 | "fov": fov, 80 | "width": width, 81 | "height": height, 82 | "model": "pinhole", 83 | } 84 | cameras.append(camera) 85 | 86 | n_frames = len(cameras) 87 | 88 | positions = torch.stack(positions, dim=0).to(device) 89 | start_points = radfoam.nn(points, model.aabb_tree, positions) 90 | 91 | output = torch.zeros( 92 | (n_frames, height, width), dtype=torch.uint32, device=device 93 | ) 94 | 95 | torch.cuda.synchronize() 96 | 97 | # warmup 98 | for i in range(n_frames): 99 | model.pipeline.trace_benchmark( 100 | points, 101 | attributes, 102 | point_adjacency, 103 | point_adjacency_offsets, 104 | adjacent_offsets, 105 | cameras[i], 106 | start_points[i], 107 | output[i], 108 | weight_threshold=0.05, 109 | ) 110 | 111 | torch.cuda.synchronize() 112 | n_reps = 5 113 | start_event = torch.cuda.Event(enable_timing=True) 114 | start_event.record() 115 | 116 | for _ in range(n_reps): 117 | for i in range(n_frames): 118 | model.pipeline.trace_benchmark( 119 | points, 120 | attributes, 121 | point_adjacency, 122 | point_adjacency_offsets, 123 | adjacent_offsets, 124 | cameras[i], 125 | start_points[i], 126 | output[i], 127 | weight_threshold=0.05, 128 | ) 129 | 130 | end_event = torch.cuda.Event(enable_timing=True) 131 | end_event.record() 132 | 133 | torch.cuda.synchronize() 134 | 135 | total_time = start_event.elapsed_time(end_event) 136 | framerate = n_reps * n_frames / (total_time / 1000.0) 137 | 138 | print(f"Total time: {total_time} ms") 139 | print(f"FPS: {framerate}") 140 | 141 | 142 | def main(): 143 | parser = configargparse.ArgParser() 144 | 145 | model_params = ModelParams(parser) 146 | dataset_params = DatasetParams(parser) 147 | pipeline_params = PipelineParams(parser) 148 | optimization_params = OptimizationParams(parser) 149 | 150 | # Add argument to specify a custom config file 151 | parser.add_argument( 152 | "-c", "--config", is_config_file=True, help="Path to config file" 153 | ) 154 | 155 | # Parse arguments 156 | args = parser.parse_args() 157 | 158 | benchmark( 159 | args, 160 | pipeline_params.extract(args), 161 | model_params.extract(args), 162 | optimization_params.extract(args), 163 | dataset_params.extract(args), 164 | ) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import os 3 | from argparse import Namespace 4 | 5 | 6 | class GroupParams: 7 | pass 8 | 9 | 10 | class ParamGroup: 11 | def __init__( 12 | self, parser: configargparse.ArgParser, name: str, fill_none=False 13 | ): 14 | group = parser.add_argument_group(name) 15 | for key, value in vars(self).items(): 16 | t = type(value) 17 | value = value if not fill_none else None 18 | if t == bool: 19 | group.add_argument( 20 | "--" + key, default=value, action="store_true" 21 | ) 22 | elif t == list: 23 | group.add_argument( 24 | "--" + key, 25 | nargs="+", 26 | type=type(value[0]), 27 | default=value, 28 | help=f"List of {type(value[0]).__name__}", 29 | ) 30 | else: 31 | group.add_argument("--" + key, default=value, type=t) 32 | 33 | def extract(self, args): 34 | group = GroupParams() 35 | for arg in vars(args).items(): 36 | if arg[0] in vars(self): 37 | setattr(group, arg[0], arg[1]) 38 | return group 39 | 40 | 41 | class 
PipelineParams(ParamGroup): 42 | 43 | def __init__(self, parser): 44 | self.iterations = 20_000 45 | self.densify_from = 2_000 46 | self.densify_until = 11_000 47 | self.densify_factor = 1.15 48 | self.white_background = True 49 | self.quantile_weight = 1e-4 50 | self.experiment_name = "" 51 | self.debug = False 52 | self.viewer = False 53 | super().__init__(parser, "Setting Pipeline parameters") 54 | 55 | 56 | class ModelParams(ParamGroup): 57 | 58 | def __init__(self, parser): 59 | self.sh_degree = 3 60 | self.init_points = 131_072 61 | self.final_points = 2_097_152 62 | self.activation_scale = 1.0 63 | self.device = "cuda" 64 | super().__init__(parser, "Setting Model parameters") 65 | 66 | 67 | class OptimizationParams(ParamGroup): 68 | 69 | def __init__(self, parser): 70 | self.points_lr_init = 2e-4 71 | self.points_lr_final = 5e-6 72 | self.density_lr_init = 1e-1 73 | self.density_lr_final = 1e-2 74 | self.attributes_lr_init = 5e-3 75 | self.attributes_lr_final = 5e-4 76 | self.sh_factor = 0.1 77 | self.freeze_points = 18_000 78 | super().__init__(parser, "Setting Optimization parameters") 79 | 80 | 81 | class DatasetParams(ParamGroup): 82 | 83 | def __init__(self, parser): 84 | self.dataset = "colmap" 85 | self.data_path = "data/mipnerf360" 86 | self.scene = "bonsai" 87 | self.patch_based = False 88 | self.downsample = [4, 2, 1] 89 | self.downsample_iterations = [0, 150, 500] 90 | super().__init__(parser, "Setting Dataset parameters") 91 | -------------------------------------------------------------------------------- /configs/db.yaml: -------------------------------------------------------------------------------- 1 | # Model Parameters 2 | sh_degree: 3 3 | init_points: 131_072 4 | final_points: 3_145_728 5 | activation_scale: 1 6 | device: cuda 7 | 8 | # Pipeline Parameters 9 | iterations: 20_000 10 | densify_from: 2_000 11 | densify_until: 11_000 12 | densify_factor: 1.15 13 | white_background: true 14 | quantile_weight: 0 15 | viewer: false # Flag to use viewer 16 | debug: false # Flag to not use tensorboard 17 | 18 | # Optimization Parameters 19 | points_lr_init: 2e-4 20 | points_lr_final: 5e-6 21 | density_lr_init: 1e-1 22 | density_lr_final: 1e-2 23 | attributes_lr_init: 5e-3 24 | attributes_lr_final: 5e-4 25 | sh_factor: 0.01 26 | freeze_points: 18_000 # Points are frozen after this cycle 27 | 28 | # Dataset Parameters 29 | dataset: "colmap" 30 | data_path: "data/db" 31 | scene: "playroom" 32 | patch_based: false 33 | downsample: [2, 1] # Image downsample factors 34 | downsample_iterations: [0, 2_000] 35 | -------------------------------------------------------------------------------- /configs/mipnerf360_indoor.yaml: -------------------------------------------------------------------------------- 1 | # Model Parameters 2 | sh_degree: 3 3 | init_points: 131_072 4 | final_points: 2_097_152 5 | activation_scale: 1 6 | device: cuda 7 | 8 | # Pipeline Parameters 9 | iterations: 20_000 10 | densify_from: 2_000 11 | densify_until: 11_000 12 | densify_factor: 1.15 13 | white_background: true 14 | quantile_weight: 1e-4 15 | viewer: false # Flag to use viewer 16 | debug: false # Flag to not use tensorboard 17 | 18 | # Optimization Parameters 19 | points_lr_init: 2e-4 20 | points_lr_final: 5e-6 21 | density_lr_init: 1e-1 22 | density_lr_final: 1e-2 23 | attributes_lr_init: 5e-3 24 | attributes_lr_final: 5e-4 25 | sh_factor: 0.1 26 | freeze_points: 18_000 # Points are frozen after this cycle 27 | 28 | # Dataset Parameters 29 | dataset: "colmap" 30 | data_path: "data/mipnerf360" 31 | 
scene: "bonsai" 32 | patch_based: false 33 | downsample: [4, 2] # Image downsample factors 34 | downsample_iterations: [0, 5_000] 35 | -------------------------------------------------------------------------------- /configs/mipnerf360_outdoor.yaml: -------------------------------------------------------------------------------- 1 | # Model Parameters 2 | sh_degree: 3 3 | init_points: 131_072 4 | final_points: 4_194_304 5 | activation_scale: 1 6 | device: cuda 7 | 8 | # Pipeline Parameters 9 | iterations: 20_000 10 | densify_from: 2_000 11 | densify_until: 11_000 12 | densify_factor: 1.15 13 | white_background: true 14 | quantile_weight: 1e-4 15 | viewer: false # Flag to use viewer 16 | debug: false # Flag to not use tensorboard 17 | 18 | # Optimization Parameters 19 | points_lr_init: 2e-4 20 | points_lr_final: 5e-6 21 | density_lr_init: 1e-1 22 | density_lr_final: 1e-2 23 | attributes_lr_init: 5e-3 24 | attributes_lr_final: 5e-4 25 | sh_factor: 0.02 26 | freeze_points: 18_000 # Points are frozen after this cycle 27 | 28 | # Dataset Parameters 29 | dataset: "colmap" 30 | data_path: "data/mipnerf360" 31 | scene: "garden" 32 | patch_based: false 33 | downsample: [8, 4] # Image downsample factors 34 | downsample_iterations: [0, 5_000] 35 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import einops 5 | import torch 6 | 7 | import radfoam 8 | 9 | from .colmap import COLMAPDataset 10 | from .blender import BlenderDataset 11 | 12 | 13 | dataset_dict = { 14 | "colmap": COLMAPDataset, 15 | "blender": BlenderDataset, 16 | } 17 | 18 | 19 | def get_up(c2ws): 20 | right = c2ws[:, :3, 0] 21 | down = c2ws[:, :3, 1] 22 | forward = c2ws[:, :3, 2] 23 | 24 | A = torch.einsum("bi,bj->bij", right, right).sum(dim=0) 25 | A += torch.einsum("bi,bj->bij", forward, forward).sum(dim=0) * 0.02 26 | 27 | l, V = torch.linalg.eig(A) 28 | 29 | min_idx = torch.argmin(l.real) 30 | global_up = V[:, min_idx].real 31 | global_up *= torch.einsum("bi,i->b", -down, global_up).sum().sign() 32 | 33 | return global_up 34 | 35 | 36 | class DataHandler: 37 | def __init__(self, dataset_args, rays_per_batch, device="cuda"): 38 | self.args = dataset_args 39 | self.rays_per_batch = rays_per_batch 40 | self.device = torch.device(device) 41 | self.img_wh = None 42 | self.patch_size = 8 43 | 44 | def reload(self, split, downsample=None): 45 | data_dir = os.path.join(self.args.data_path, self.args.scene) 46 | dataset = dataset_dict[self.args.dataset] 47 | if downsample is not None: 48 | split_dataset = dataset( 49 | data_dir, split=split, downsample=downsample 50 | ) 51 | else: 52 | split_dataset = dataset(data_dir, split=split) 53 | self.img_wh = split_dataset.img_wh 54 | self.fx = split_dataset.fx 55 | self.fy = split_dataset.fy 56 | self.c2ws = split_dataset.poses 57 | self.rays, self.rgbs = split_dataset.all_rays, split_dataset.all_rgbs 58 | self.alphas = getattr( 59 | split_dataset, "all_alphas", torch.ones_like(self.rgbs[..., 0:1]) 60 | ) 61 | 62 | self.viewer_up = get_up(self.c2ws) 63 | self.viewer_pos = self.c2ws[0, :3, 3] 64 | self.viewer_forward = self.c2ws[0, :3, 2] 65 | 66 | try: 67 | self.points3D = split_dataset.points3D 68 | self.points3D_colors = split_dataset.points3D_color 69 | except: 70 | self.points3D = None 71 | self.points3D_colors = None 72 | 73 | if split == "train": 74 | if self.args.patch_based: 75 | dw = self.img_wh[0] - 
(self.img_wh[0] % self.patch_size) 76 | dh = self.img_wh[1] - (self.img_wh[1] % self.patch_size) 77 | w_inds = np.linspace(0, self.img_wh[0] - 1, dw, dtype=int) 78 | h_inds = np.linspace(0, self.img_wh[1] - 1, dh, dtype=int) 79 | 80 | self.train_rays = self.rays[:, h_inds, :, :] 81 | self.train_rays = self.train_rays[:, :, w_inds, :] 82 | self.train_rgbs = self.rgbs[:, h_inds, :, :] 83 | self.train_rgbs = self.train_rgbs[:, :, w_inds, :] 84 | 85 | self.train_rays = einops.rearrange( 86 | self.train_rays, 87 | "n (x ph) (y pw) r -> (n x y) ph pw r", 88 | ph=self.patch_size, 89 | pw=self.patch_size, 90 | ) 91 | self.train_rgbs = einops.rearrange( 92 | self.train_rgbs, 93 | "n (x ph) (y pw) c -> (n x y) ph pw c", 94 | ph=self.patch_size, 95 | pw=self.patch_size, 96 | ) 97 | 98 | self.batch_size = self.rays_per_batch // (self.patch_size**2) 99 | else: 100 | self.train_rays = einops.rearrange( 101 | self.rays, "n h w r -> (n h w) r" 102 | ) 103 | self.train_rgbs = einops.rearrange( 104 | self.rgbs, "n h w c -> (n h w) c" 105 | ) 106 | self.train_alphas = einops.rearrange( 107 | self.alphas, "n h w 1 -> (n h w) 1" 108 | ) 109 | 110 | self.batch_size = self.rays_per_batch 111 | 112 | def get_iter(self): 113 | ray_batch_fetcher = radfoam.BatchFetcher( 114 | self.train_rays, self.batch_size, shuffle=True 115 | ) 116 | rgb_batch_fetcher = radfoam.BatchFetcher( 117 | self.train_rgbs, self.batch_size, shuffle=True 118 | ) 119 | alpha_batch_fetcher = radfoam.BatchFetcher( 120 | self.train_alphas, self.batch_size, shuffle=True 121 | ) 122 | 123 | while True: 124 | ray_batch = ray_batch_fetcher.next() 125 | rgb_batch = rgb_batch_fetcher.next() 126 | alpha_batch = alpha_batch_fetcher.next() 127 | 128 | yield ray_batch, rgb_batch, alpha_batch 129 | -------------------------------------------------------------------------------- /data_loader/blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import torch 5 | from torch.utils.data import Dataset 6 | import json 7 | import math 8 | 9 | 10 | def get_ray_directions(H, W, focal, center=None): 11 | x = np.arange(W, dtype=np.float32) + 0.5 12 | y = np.arange(H, dtype=np.float32) + 0.5 13 | x, y = np.meshgrid(x, y) 14 | pix_coords = np.stack([x, y], axis=-1).reshape(-1, 2) 15 | i, j = pix_coords[..., 0:1], pix_coords[..., 1:] 16 | 17 | cent = center if center is not None else [W / 2, H / 2] 18 | directions = np.concatenate( 19 | [ 20 | (i - cent[0]) / focal[0], 21 | (j - cent[1]) / focal[1], 22 | np.ones_like(i), 23 | ], 24 | axis=-1, 25 | ) 26 | ray_dirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True) 27 | return torch.tensor(ray_dirs, dtype=torch.float32) 28 | 29 | 30 | class BlenderDataset(Dataset): 31 | def __init__(self, datadir, split="train", downsample=1): 32 | 33 | self.root_dir = datadir 34 | self.split = split 35 | self.downsample = downsample 36 | 37 | self.blender2opencv = np.array( 38 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 39 | ) 40 | self.points3D = None 41 | self.points3D_color = None 42 | 43 | with open( 44 | os.path.join(self.root_dir, f"transforms_{self.split}.json"), "r" 45 | ) as f: 46 | meta = json.load(f) 47 | if "w" in meta and "h" in meta: 48 | W, H = int(meta["w"]), int(meta["h"]) 49 | else: 50 | W, H = 800, 800 51 | 52 | self.img_wh = (int(W / self.downsample), int(H / self.downsample)) 53 | w, h = self.img_wh 54 | 55 | focal = ( 56 | 0.5 * w / math.tan(0.5 * meta["camera_angle_x"]) 57 | ) # 
scaled focal length 58 | 59 | self.fx, self.fy = focal, focal 60 | self.intrinsics = torch.tensor( 61 | [[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]] 62 | ) 63 | 64 | cam_ray_dirs = get_ray_directions( 65 | h, w, [self.intrinsics[0, 0], self.intrinsics[1, 1]] 66 | ) 67 | 68 | self.poses = [] 69 | self.cameras = [] 70 | self.all_rays = [] 71 | self.all_rgbs = [] 72 | self.all_alphas = [] 73 | for i, frame in enumerate(meta["frames"]): 74 | pose = np.array(frame["transform_matrix"]) @ self.blender2opencv 75 | c2w = torch.FloatTensor(pose) 76 | self.poses.append(c2w) 77 | world_ray_dirs = torch.einsum( 78 | "ij,kj->ik", 79 | cam_ray_dirs, 80 | c2w[:3, :3], 81 | ) 82 | world_ray_origins = c2w[:3, 3] + torch.zeros_like(cam_ray_dirs) 83 | world_rays = torch.cat([world_ray_origins, world_ray_dirs], dim=-1) 84 | world_rays = world_rays.reshape(self.img_wh[1], self.img_wh[0], 6) 85 | 86 | img_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 87 | img = Image.open(img_path) 88 | if self.downsample != 1.0: 89 | img = img.resize(self.img_wh, Image.LANCZOS) 90 | img = img.convert("RGBA") 91 | rgbas = torch.tensor(np.array(img), dtype=torch.float32) / 255.0 92 | rgbs = rgbas[..., :3] * rgbas[..., 3:4] + ( 93 | 1 - rgbas[..., 3:4] 94 | ) # white bg 95 | img.close() 96 | 97 | self.all_rays.append(world_rays) 98 | self.all_rgbs.append(rgbs) 99 | self.all_alphas.append(rgbas[..., -1:]) 100 | 101 | self.poses = torch.stack(self.poses) 102 | self.all_rays = torch.stack(self.all_rays) 103 | self.all_rgbs = torch.stack(self.all_rgbs) 104 | self.all_alphas = torch.stack(self.all_alphas) 105 | 106 | def __len__(self): 107 | return len(self.all_rgbs) 108 | 109 | def __getitem__(self, idx): 110 | 111 | if self.split == "train": # use data in the buffers 112 | sample = { 113 | "rays": self.all_rays[idx], 114 | "rgbs": self.all_rgbs[idx], 115 | "alphas": self.all_alphas[idx], 116 | } 117 | 118 | else: # create data for each image separately 119 | 120 | img = self.all_rgbs[idx] 121 | rays = self.all_rays[idx] 122 | alphas = self.all_alphas[idx] 123 | 124 | sample = {"rays": rays, "rgbs": img, "alphas": alphas} 125 | return sample 126 | 127 | 128 | if __name__ == "__main__": 129 | pass 130 | -------------------------------------------------------------------------------- /data_loader/colmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import torch 7 | import pycolmap 8 | 9 | 10 | def get_cam_ray_dirs(camera): 11 | x = np.arange(camera.width, dtype=np.float32) + 0.5 12 | y = np.arange(camera.height, dtype=np.float32) + 0.5 13 | x, y = np.meshgrid(x, y) 14 | pix_coords = np.stack([x, y], axis=-1).reshape(-1, 2) 15 | ip_coords = camera.cam_from_img(pix_coords) 16 | ip_coords = np.concatenate( 17 | [ip_coords, np.ones_like(ip_coords[:, :1])], axis=-1 18 | ) 19 | ray_dirs = ip_coords / np.linalg.norm(ip_coords, axis=-1, keepdims=True) 20 | return torch.tensor(ray_dirs, dtype=torch.float32) 21 | 22 | 23 | class COLMAPDataset: 24 | def __init__(self, datadir, split, downsample): 25 | assert downsample in [1, 2, 4, 8] 26 | 27 | self.root_dir = datadir 28 | self.colmap_dir = os.path.join(datadir, "sparse/0/") 29 | self.split = split 30 | self.downsample = downsample 31 | 32 | if downsample == 1: 33 | images_dir = os.path.join(datadir, "images") 34 | else: 35 | images_dir = os.path.join(datadir, f"images_{downsample}") 36 | 37 | if not os.path.exists(images_dir): 38 | raise 
ValueError(f"Images directory {images_dir} not found") 39 | 40 | self.reconstruction = pycolmap.Reconstruction() 41 | self.reconstruction.read(self.colmap_dir) 42 | 43 | if len(self.reconstruction.cameras) > 1: 44 | raise ValueError("Multiple cameras are not supported") 45 | 46 | names = sorted(im.name for im in self.reconstruction.images.values()) 47 | indices = np.arange(len(names)) 48 | 49 | if split == "train": 50 | names = list(np.array(names)[indices % 8 != 0]) 51 | elif split == "test": 52 | names = list(np.array(names)[indices % 8 == 0]) 53 | else: 54 | raise ValueError(f"Invalid split: {split}") 55 | 56 | names = list(str(name) for name in names) 57 | 58 | im = Image.open(os.path.join(images_dir, names[0])) 59 | self.img_wh = im.size 60 | im.close() 61 | 62 | self.camera = list(self.reconstruction.cameras.values())[0] 63 | self.camera.rescale(self.img_wh[0], self.img_wh[1]) 64 | 65 | self.fx = self.camera.focal_length_x 66 | self.fy = self.camera.focal_length_y 67 | 68 | cam_ray_dirs = get_cam_ray_dirs(self.camera) 69 | 70 | self.images = [] 71 | for name in names: 72 | image = None 73 | for image_id in self.reconstruction.images: 74 | image = self.reconstruction.images[image_id] 75 | if image.name == name: 76 | break 77 | 78 | if image is None: 79 | raise ValueError( 80 | f"Image {name} not found in COLMAP reconstruction" 81 | ) 82 | 83 | self.images.append(image) 84 | 85 | self.poses = [] 86 | self.all_rays = [] 87 | self.all_rgbs = [] 88 | for image in tqdm(self.images): 89 | c2w = torch.tensor( 90 | image.cam_from_world.inverse().matrix(), dtype=torch.float32 91 | ) 92 | self.poses.append(c2w) 93 | world_ray_dirs = torch.einsum( 94 | "ij,kj->ik", 95 | cam_ray_dirs, 96 | c2w[:, :3], 97 | ) 98 | world_ray_origins = c2w[:, 3] + torch.zeros_like(cam_ray_dirs) 99 | world_rays = torch.cat([world_ray_origins, world_ray_dirs], dim=-1) 100 | world_rays = world_rays.reshape(self.img_wh[1], self.img_wh[0], 6) 101 | 102 | im = Image.open(os.path.join(images_dir, image.name)) 103 | im = im.convert("RGB") 104 | rgbs = torch.tensor(np.array(im), dtype=torch.float32) / 255.0 105 | im.close() 106 | 107 | self.all_rays.append(world_rays) 108 | self.all_rgbs.append(rgbs) 109 | 110 | self.poses = torch.stack(self.poses) 111 | self.all_rays = torch.stack(self.all_rays) 112 | self.all_rgbs = torch.stack(self.all_rgbs) 113 | 114 | self.points3D = [] 115 | self.points3D_color = [] 116 | for point in self.reconstruction.points3D.values(): 117 | self.points3D.append(point.xyz) 118 | self.points3D_color.append(point.color) 119 | 120 | self.points3D = torch.tensor( 121 | np.array(self.points3D), dtype=torch.float32 122 | ) 123 | self.points3D_color = torch.tensor( 124 | np.array(self.points3D_color), dtype=torch.float32 125 | ) 126 | self.points3D_color = self.points3D_color / 255.0 127 | -------------------------------------------------------------------------------- /external/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(PIP_GLFW) 2 | set(USE_PIP_GLFW 3 | True 4 | PARENT_SCOPE) 5 | set(GLFW_LIBRARY 6 | "" 7 | PARENT_SCOPE) 8 | set(GLFW_INCLUDES 9 | ${CMAKE_SOURCE_DIR}/external/submodules/glfw/include 10 | PARENT_SCOPE) 11 | else() 12 | set(USE_PIP_GLFW 13 | False 14 | PARENT_SCOPE) 15 | set(GLFW_BUILD_EXAMPLES 16 | OFF 17 | CACHE BOOL "" FORCE) 18 | set(GLFW_BUILD_TESTS 19 | OFF 20 | CACHE BOOL "" FORCE) 21 | set(GLFW_BUILD_DOCS 22 | OFF 23 | CACHE BOOL "" FORCE) 24 | add_subdirectory(submodules/glfw) 25 | set(GLFW_LIBRARY 26 | glfw 27 | 
PARENT_SCOPE) 28 | set(GLFW_INCLUDES 29 | "" 30 | PARENT_SCOPE) 31 | message(STATUS "GLFW not found from pip, building from source") 32 | endif() 33 | 34 | add_library(gl3w STATIC gl3w/gl3w.c) 35 | target_include_directories(gl3w PUBLIC "include") 36 | 37 | add_library( 38 | imgui STATIC 39 | submodules/imgui/imgui.cpp 40 | submodules/imgui/imgui_draw.cpp 41 | submodules/imgui/imgui_demo.cpp 42 | submodules/imgui/imgui_widgets.cpp 43 | submodules/imgui/imgui_tables.cpp 44 | submodules/imgui/backends/imgui_impl_glfw.cpp 45 | submodules/imgui/backends/imgui_impl_opengl3.cpp) 46 | target_include_directories( 47 | imgui PUBLIC "submodules/imgui" "submodules/glfw/include" 48 | "submodules/mesa/include") 49 | 50 | find_package(TBB GLOBAL) 51 | if(NOT TBB_FOUND) 52 | add_subdirectory(submodules/tbb) 53 | endif() 54 | 55 | set(RADFOAM_EXTERNAL_INCLUDES 56 | "${CMAKE_SOURCE_DIR}/external/include" 57 | "${CMAKE_SOURCE_DIR}/external/submodules/imgui" 58 | "${CMAKE_SOURCE_DIR}/external/submodules/imgui/backends" 59 | PARENT_SCOPE) 60 | -------------------------------------------------------------------------------- /external/include/KHR/khrplatform.h: -------------------------------------------------------------------------------- 1 | #ifndef __khrplatform_h_ 2 | #define __khrplatform_h_ 3 | 4 | /* 5 | ** Copyright (c) 2008-2018 The Khronos Group Inc. 6 | ** 7 | ** Permission is hereby granted, free of charge, to any person obtaining a 8 | ** copy of this software and/or associated documentation files (the 9 | ** "Materials"), to deal in the Materials without restriction, including 10 | ** without limitation the rights to use, copy, modify, merge, publish, 11 | ** distribute, sublicense, and/or sell copies of the Materials, and to 12 | ** permit persons to whom the Materials are furnished to do so, subject to 13 | ** the following conditions: 14 | ** 15 | ** The above copyright notice and this permission notice shall be included 16 | ** in all copies or substantial portions of the Materials. 17 | ** 18 | ** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 19 | ** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 20 | ** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 21 | ** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | ** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 23 | ** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 24 | ** MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. 25 | */ 26 | 27 | /* Khronos platform-specific types and definitions. 28 | * 29 | * The master copy of khrplatform.h is maintained in the Khronos EGL 30 | * Registry repository at https://github.com/KhronosGroup/EGL-Registry 31 | * The last semantic modification to khrplatform.h was at commit ID: 32 | * 67a3e0864c2d75ea5287b9f3d2eb74a745936692 33 | * 34 | * Adopters may modify this file to suit their platform. Adopters are 35 | * encouraged to submit platform specific modifications to the Khronos 36 | * group so that they can be included in future versions of this file. 37 | * Please submit changes by filing pull requests or issues on 38 | * the EGL Registry repository linked above. 
39 | * 40 | * 41 | * See the Implementer's Guidelines for information about where this file 42 | * should be located on your system and for more details of its use: 43 | * http://www.khronos.org/registry/implementers_guide.pdf 44 | * 45 | * This file should be included as 46 | * #include <KHR/khrplatform.h> 47 | * by Khronos client API header files that use its types and defines. 48 | * 49 | * The types in khrplatform.h should only be used to define API-specific types. 50 | * 51 | * Types defined in khrplatform.h: 52 | * khronos_int8_t signed 8 bit 53 | * khronos_uint8_t unsigned 8 bit 54 | * khronos_int16_t signed 16 bit 55 | * khronos_uint16_t unsigned 16 bit 56 | * khronos_int32_t signed 32 bit 57 | * khronos_uint32_t unsigned 32 bit 58 | * khronos_int64_t signed 64 bit 59 | * khronos_uint64_t unsigned 64 bit 60 | * khronos_intptr_t signed same number of bits as a pointer 61 | * khronos_uintptr_t unsigned same number of bits as a pointer 62 | * khronos_ssize_t signed size 63 | * khronos_usize_t unsigned size 64 | * khronos_float_t signed 32 bit floating point 65 | * khronos_time_ns_t unsigned 64 bit time in nanoseconds 66 | * khronos_utime_nanoseconds_t unsigned time interval or absolute time in 67 | * nanoseconds 68 | * khronos_stime_nanoseconds_t signed time interval in nanoseconds 69 | * khronos_boolean_enum_t enumerated boolean type. This should 70 | * only be used as a base type when a client API's boolean type is 71 | * an enum. Client APIs which use an integer or other type for 72 | * booleans cannot use this as the base type for their boolean. 73 | * 74 | * Tokens defined in khrplatform.h: 75 | * 76 | * KHRONOS_FALSE, KHRONOS_TRUE Enumerated boolean false/true values. 77 | * 78 | * KHRONOS_SUPPORT_INT64 is 1 if 64 bit integers are supported; otherwise 0. 79 | * KHRONOS_SUPPORT_FLOAT is 1 if floats are supported; otherwise 0. 80 | * 81 | * Calling convention macros defined in this file: 82 | * KHRONOS_APICALL 83 | * KHRONOS_APIENTRY 84 | * KHRONOS_APIATTRIBUTES 85 | * 86 | * These may be used in function prototypes as: 87 | * 88 | * KHRONOS_APICALL void KHRONOS_APIENTRY funcname( 89 | * int arg1, 90 | * int arg2) KHRONOS_APIATTRIBUTES; 91 | */ 92 | 93 | #if defined(__SCITECH_SNAP__) && !defined(KHRONOS_STATIC) 94 | # define KHRONOS_STATIC 1 95 | #endif 96 | 97 | /*------------------------------------------------------------------------- 98 | * Definition of KHRONOS_APICALL 99 | *------------------------------------------------------------------------- 100 | * This precedes the return type of the function in the function prototype. 101 | */ 102 | #if defined(KHRONOS_STATIC) 103 | /* If the preprocessor constant KHRONOS_STATIC is defined, make the 104 | * header compatible with static linking. */ 105 | # define KHRONOS_APICALL 106 | #elif defined(_WIN32) 107 | # define KHRONOS_APICALL __declspec(dllimport) 108 | #elif defined (__SYMBIAN32__) 109 | # define KHRONOS_APICALL IMPORT_C 110 | #elif defined(__ANDROID__) 111 | # define KHRONOS_APICALL __attribute__((visibility("default"))) 112 | #else 113 | # define KHRONOS_APICALL 114 | #endif 115 | 116 | /*------------------------------------------------------------------------- 117 | * Definition of KHRONOS_APIENTRY 118 | *------------------------------------------------------------------------- 119 | * This follows the return type of the function and precedes the function 120 | * name in the function prototype.
121 | */ 122 | #if defined(_WIN32) && !defined(_WIN32_WCE) && !defined(__SCITECH_SNAP__) 123 | /* Win32 but not WinCE */ 124 | # define KHRONOS_APIENTRY __stdcall 125 | #else 126 | # define KHRONOS_APIENTRY 127 | #endif 128 | 129 | /*------------------------------------------------------------------------- 130 | * Definition of KHRONOS_APIATTRIBUTES 131 | *------------------------------------------------------------------------- 132 | * This follows the closing parenthesis of the function prototype arguments. 133 | */ 134 | #if defined (__ARMCC_2__) 135 | #define KHRONOS_APIATTRIBUTES __softfp 136 | #else 137 | #define KHRONOS_APIATTRIBUTES 138 | #endif 139 | 140 | /*------------------------------------------------------------------------- 141 | * basic type definitions 142 | *-----------------------------------------------------------------------*/ 143 | #if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L) || defined(__GNUC__) || defined(__SCO__) || defined(__USLC__) 144 | 145 | 146 | /* 147 | * Using <stdint.h> 148 | */ 149 | #include <stdint.h> 150 | typedef int32_t khronos_int32_t; 151 | typedef uint32_t khronos_uint32_t; 152 | typedef int64_t khronos_int64_t; 153 | typedef uint64_t khronos_uint64_t; 154 | #define KHRONOS_SUPPORT_INT64 1 155 | #define KHRONOS_SUPPORT_FLOAT 1 156 | /* 157 | * To support platform where unsigned long cannot be used interchangeably with 158 | * inptr_t (e.g. CHERI-extended ISAs), we can use the stdint.h intptr_t. 159 | * Ideally, we could just use (u)intptr_t everywhere, but this could result in 160 | * ABI breakage if khronos_uintptr_t is changed from unsigned long to 161 | * unsigned long long or similar (this results in different C++ name mangling). 162 | * To avoid changes for existing platforms, we restrict usage of intptr_t to 163 | * platforms where the size of a pointer is larger than the size of long.
164 | */ 165 | #if defined(__SIZEOF_LONG__) && defined(__SIZEOF_POINTER__) 166 | #if __SIZEOF_POINTER__ > __SIZEOF_LONG__ 167 | #define KHRONOS_USE_INTPTR_T 168 | #endif 169 | #endif 170 | 171 | #elif defined(__VMS ) || defined(__sgi) 172 | 173 | /* 174 | * Using <inttypes.h> 175 | */ 176 | #include <inttypes.h> 177 | typedef int32_t khronos_int32_t; 178 | typedef uint32_t khronos_uint32_t; 179 | typedef int64_t khronos_int64_t; 180 | typedef uint64_t khronos_uint64_t; 181 | #define KHRONOS_SUPPORT_INT64 1 182 | #define KHRONOS_SUPPORT_FLOAT 1 183 | 184 | #elif defined(_WIN32) && !defined(__SCITECH_SNAP__) 185 | 186 | /* 187 | * Win32 188 | */ 189 | typedef __int32 khronos_int32_t; 190 | typedef unsigned __int32 khronos_uint32_t; 191 | typedef __int64 khronos_int64_t; 192 | typedef unsigned __int64 khronos_uint64_t; 193 | #define KHRONOS_SUPPORT_INT64 1 194 | #define KHRONOS_SUPPORT_FLOAT 1 195 | 196 | #elif defined(__sun__) || defined(__digital__) 197 | 198 | /* 199 | * Sun or Digital 200 | */ 201 | typedef int khronos_int32_t; 202 | typedef unsigned int khronos_uint32_t; 203 | #if defined(__arch64__) || defined(_LP64) 204 | typedef long int khronos_int64_t; 205 | typedef unsigned long int khronos_uint64_t; 206 | #else 207 | typedef long long int khronos_int64_t; 208 | typedef unsigned long long int khronos_uint64_t; 209 | #endif /* __arch64__ */ 210 | #define KHRONOS_SUPPORT_INT64 1 211 | #define KHRONOS_SUPPORT_FLOAT 1 212 | 213 | #elif 0 214 | 215 | /* 216 | * Hypothetical platform with no float or int64 support 217 | */ 218 | typedef int khronos_int32_t; 219 | typedef unsigned int khronos_uint32_t; 220 | #define KHRONOS_SUPPORT_INT64 0 221 | #define KHRONOS_SUPPORT_FLOAT 0 222 | 223 | #else 224 | 225 | /* 226 | * Generic fallback 227 | */ 228 | #include <stdint.h> 229 | typedef int32_t khronos_int32_t; 230 | typedef uint32_t khronos_uint32_t; 231 | typedef int64_t khronos_int64_t; 232 | typedef uint64_t khronos_uint64_t; 233 | #define KHRONOS_SUPPORT_INT64 1 234 | #define KHRONOS_SUPPORT_FLOAT 1 235 | 236 | #endif 237 | 238 | 239 | /* 240 | * Types that are (so far) the same on all platforms 241 | */ 242 | typedef signed char khronos_int8_t; 243 | typedef unsigned char khronos_uint8_t; 244 | typedef signed short int khronos_int16_t; 245 | typedef unsigned short int khronos_uint16_t; 246 | 247 | /* 248 | * Types that differ between LLP64 and LP64 architectures - in LLP64, 249 | * pointers are 64 bits, but 'long' is still 32 bits. Win64 appears 250 | * to be the only LLP64 architecture in current use. 251 | */ 252 | #ifdef KHRONOS_USE_INTPTR_T 253 | typedef intptr_t khronos_intptr_t; 254 | typedef uintptr_t khronos_uintptr_t; 255 | #elif defined(_WIN64) 256 | typedef signed long long int khronos_intptr_t; 257 | typedef unsigned long long int khronos_uintptr_t; 258 | #else 259 | typedef signed long int khronos_intptr_t; 260 | typedef unsigned long int khronos_uintptr_t; 261 | #endif 262 | 263 | #if defined(_WIN64) 264 | typedef signed long long int khronos_ssize_t; 265 | typedef unsigned long long int khronos_usize_t; 266 | #else 267 | typedef signed long int khronos_ssize_t; 268 | typedef unsigned long int khronos_usize_t; 269 | #endif 270 | 271 | #if KHRONOS_SUPPORT_FLOAT 272 | /* 273 | * Float type 274 | */ 275 | typedef float khronos_float_t; 276 | #endif 277 | 278 | #if KHRONOS_SUPPORT_INT64 279 | /* Time types 280 | * 281 | * These types can be used to represent a time interval in nanoseconds or 282 | * an absolute Unadjusted System Time.
Unadjusted System Time is the number 283 | * of nanoseconds since some arbitrary system event (e.g. since the last 284 | * time the system booted). The Unadjusted System Time is an unsigned 285 | * 64 bit value that wraps back to 0 every 584 years. Time intervals 286 | * may be either signed or unsigned. 287 | */ 288 | typedef khronos_uint64_t khronos_utime_nanoseconds_t; 289 | typedef khronos_int64_t khronos_stime_nanoseconds_t; 290 | #endif 291 | 292 | /* 293 | * Dummy value used to pad enum types to 32 bits. 294 | */ 295 | #ifndef KHRONOS_MAX_ENUM 296 | #define KHRONOS_MAX_ENUM 0x7FFFFFFF 297 | #endif 298 | 299 | /* 300 | * Enumerated boolean type 301 | * 302 | * Values other than zero should be considered to be true. Therefore 303 | * comparisons should not be made against KHRONOS_TRUE. 304 | */ 305 | typedef enum { 306 | KHRONOS_FALSE = 0, 307 | KHRONOS_TRUE = 1, 308 | KHRONOS_BOOLEAN_ENUM_FORCE_SIZE = KHRONOS_MAX_ENUM 309 | } khronos_boolean_enum_t; 310 | 311 | #endif /* __khrplatform_h_ */ 312 | -------------------------------------------------------------------------------- /prepare_colmap_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | import pycolmap 6 | from PIL import Image 7 | import tqdm 8 | 9 | 10 | def main(args): 11 | """Minimal script to run COLMAP on a directory of images.""" 12 | 13 | data_dir = args.data_dir 14 | reconstruction_dir = os.path.join(data_dir, "sparse") 15 | 16 | if os.path.exists(reconstruction_dir): 17 | raise ValueError("Reconstruction directory already exists") 18 | 19 | images_dir = os.path.join(data_dir, "images") 20 | if not os.path.exists(images_dir): 21 | raise ValueError("data_dir must contain an 'images' directory") 22 | 23 | database_path = os.path.join(data_dir, "database.db") 24 | if os.path.exists(database_path): 25 | raise ValueError("Database file already exists") 26 | 27 | database = pycolmap.Database(database_path) 28 | 29 | pycolmap.extract_features( 30 | database_path, 31 | images_dir, 32 | camera_mode=pycolmap.CameraMode.SINGLE, 33 | camera_model=args.camera_model, 34 | ) 35 | 36 | print(f"Imported {database.num_images} images to {database_path}") 37 | 38 | pycolmap.match_exhaustive(database_path) 39 | 40 | print(f"Feature matching completed") 41 | 42 | os.makedirs(reconstruction_dir) 43 | 44 | reconstructions = pycolmap.incremental_mapping( 45 | database_path, 46 | image_path=images_dir, 47 | output_path=reconstruction_dir, 48 | ) 49 | 50 | if len(reconstructions) > 1: 51 | warnings.warn("Multiple reconstructions found") 52 | 53 | reconstruction = reconstructions[0] 54 | 55 | os.makedirs(os.path.join(data_dir, "images_2"), exist_ok=True) 56 | os.makedirs(os.path.join(data_dir, "images_4"), exist_ok=True) 57 | os.makedirs(os.path.join(data_dir, "images_8"), exist_ok=True) 58 | 59 | print("Downsampling images") 60 | 61 | for image in tqdm.tqdm(list(reconstruction.images.values())): 62 | image_1_path = os.path.join(images_dir, image.name) 63 | image_2_path = os.path.join(data_dir, "images_2", image.name) 64 | image_4_path = os.path.join(data_dir, "images_4", image.name) 65 | image_8_path = os.path.join(data_dir, "images_8", image.name) 66 | 67 | pil_image = Image.open(image_1_path) 68 | pil_image_2 = pil_image.resize( 69 | (pil_image.width // 2, pil_image.height // 2), 70 | resample=Image.LANCZOS, 71 | ) 72 | pil_image_4 = pil_image.resize( 73 | (pil_image.width // 4, pil_image.height // 4), 74 | resample=Image.LANCZOS, 75 | ) 76 | 
pil_image_8 = pil_image.resize( 77 | (pil_image.width // 8, pil_image.height // 8), 78 | resample=Image.LANCZOS, 79 | ) 80 | 81 | pil_image_2.save(image_2_path) 82 | pil_image_4.save(image_4_path) 83 | pil_image_8.save(image_8_path) 84 | 85 | pil_image.close() 86 | pil_image_2.close() 87 | pil_image_4.close() 88 | pil_image_8.close() 89 | 90 | print("Exporting point cloud") 91 | 92 | reconstruction.export_PLY(os.path.join(data_dir, "point_cloud.ply")) 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("data_dir", type=str) 98 | parser.add_argument("--camera_model", type=str, default="OPENCV") 99 | args = parser.parse_args() 100 | main(args) 101 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=64", 4 | "pybind11", 5 | "glfw==2.6.5", 6 | "cmake_build_extension", 7 | ] 8 | build-backend = "setuptools.build_meta" 9 | 10 | [tool.black] 11 | line-length = 80 12 | -------------------------------------------------------------------------------- /radfoam_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/radfoam/3e7b52cf74e37ab2ab5e695f53570f515f537e3d/radfoam_model/__init__.py -------------------------------------------------------------------------------- /radfoam_model/render.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ErrorBox: 5 | def __init__(self): 6 | self.ray_error = None 7 | self.point_error = None 8 | 9 | 10 | class TraceRays(torch.autograd.Function): 11 | @staticmethod 12 | def forward( 13 | ctx, 14 | pipeline, 15 | _points, 16 | _attributes, 17 | _point_adjacency, 18 | _point_adjacency_offsets, 19 | rays, 20 | start_point, 21 | depth_quantiles, 22 | return_contribution, 23 | ): 24 | ctx.rays = rays 25 | ctx.start_point = start_point 26 | ctx.depth_quantiles = depth_quantiles 27 | ctx.pipeline = pipeline 28 | ctx.points = _points 29 | ctx.attributes = _attributes 30 | ctx.point_adjacency = _point_adjacency 31 | ctx.point_adjacency_offsets = _point_adjacency_offsets 32 | 33 | results = pipeline.trace_forward( 34 | _points, 35 | _attributes, 36 | _point_adjacency, 37 | _point_adjacency_offsets, 38 | rays, 39 | start_point, 40 | depth_quantiles=depth_quantiles, 41 | return_contribution=return_contribution, 42 | ) 43 | ctx.rgba = results["rgba"] 44 | ctx.depth_indices = results.get("depth_indices", None) 45 | 46 | errbox = ErrorBox() 47 | ctx.errbox = errbox 48 | 49 | return ( 50 | results["rgba"], 51 | results.get("depth", None), 52 | results.get("contribution", None), 53 | results["num_intersections"], 54 | errbox, 55 | ) 56 | 57 | @staticmethod 58 | def backward( 59 | ctx, 60 | grad_rgba, 61 | grad_depth, 62 | grad_contribution, 63 | grad_num_intersections, 64 | errbox_grad, 65 | ): 66 | del grad_contribution 67 | del grad_num_intersections 68 | del errbox_grad 69 | 70 | rays = ctx.rays 71 | start_point = ctx.start_point 72 | pipeline = ctx.pipeline 73 | rgba = ctx.rgba 74 | _points = ctx.points 75 | _attributes = ctx.attributes 76 | _point_adjacency = ctx.point_adjacency 77 | _point_adjacency_offsets = ctx.point_adjacency_offsets 78 | depth_quantiles = ctx.depth_quantiles 79 | 80 | results = pipeline.trace_backward( 81 | _points, 82 | _attributes, 83 | _point_adjacency, 84 | 
_point_adjacency_offsets, 85 | rays, 86 | start_point, 87 | rgba, 88 | grad_rgba, 89 | depth_quantiles, 90 | ctx.depth_indices, 91 | grad_depth, 92 | ctx.errbox.ray_error, 93 | ) 94 | points_grad = results["points_grad"] 95 | attr_grad = results["attr_grad"] 96 | ctx.errbox.point_error = results.get("point_error", None) 97 | 98 | points_grad[~points_grad.isfinite()] = 0 99 | attr_grad[~attr_grad.isfinite()] = 0 100 | 101 | del ( 102 | ctx.rays, 103 | ctx.start_point, 104 | ctx.pipeline, 105 | ctx.rgba, 106 | ctx.points, 107 | ctx.attributes, 108 | ctx.point_adjacency, 109 | ctx.point_adjacency_offsets, 110 | ctx.depth_quantiles, 111 | ) 112 | return ( 113 | None, # pipeline 114 | points_grad, # _points 115 | attr_grad, # _attributes 116 | None, # _point_adjacency 117 | None, # _point_adjacency_offsets 118 | None, # rays 119 | None, # start_point 120 | None, # depth_quantiles 121 | None, # return_contribution 122 | ) 123 | -------------------------------------------------------------------------------- /radfoam_model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def inverse_softplus(x, beta, scale=1): 6 | # inverse of softplus: log(exp(beta*x) - 1 + eps)/beta, with the identity used for large x, where softplus is effectively linear 7 | out = x / scale 8 | mask = x * beta < 20 * scale 9 | out[mask] = torch.log(torch.exp(beta * out[mask]) - 1 + 1e-10) / beta 10 | return out 11 | 12 | 13 | def psnr(img1, img2): 14 | mse = ((img1 - img2) ** 2).view(-1, img1.shape[-1]).mean(0, keepdim=True) 15 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 16 | 17 | 18 | def get_expon_lr_func( 19 | lr_init, 20 | lr_final, 21 | warmup_steps=0, 22 | max_steps=1_000, 23 | ): 24 | """ 25 | Copied from Plenoxels 26 | 27 | Continuous learning rate decay function. Adapted from JaxNeRF 28 | The returned rate is lr_init when step=warmup_steps and lr_final when step=max_steps, and 29 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 30 | If warmup_steps>0 then the learning rate is instead ramped up linearly 31 | from zero to lr_init over the first warmup_steps steps, before the 32 | log-linear decay begins; after max_steps the returned 33 | rate is zero. 34 | :param warmup_steps: int, the number of linear warmup steps. 35 | :param max_steps: int, the number of steps during optimization. 36 | :return HoF which takes step as input 37 | """ 38 | 39 | def helper(step): 40 | if warmup_steps and step < warmup_steps: 41 | return lr_init * step / warmup_steps 42 | elif step > max_steps: 43 | return 0 44 | t = np.clip((step - warmup_steps) / (max_steps - warmup_steps), 0, 1) 45 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 46 | return log_lerp 47 | 48 | return helper 49 | 50 | 51 | def get_cosine_lr_func( 52 | lr_init, 53 | lr_final, 54 | warmup_steps=0, 55 | max_steps=10_000, 56 | ): 57 | """ 58 | Copied from Plenoxels 59 | 60 | Continuous learning rate decay function. Adapted from JaxNeRF 61 | The returned rate is lr_init when step=warmup_steps and lr_final when 62 | step=max_steps, and follows a half-cosine between the two (cosine annealing). 63 | If warmup_steps>0 then the learning rate is instead ramped up linearly 64 | from zero to lr_init over the first warmup_steps steps, before the 65 | cosine decay begins; after max_steps the returned 66 | rate is zero.
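    As a concrete check of the cosine branch below (illustrative numbers,
    not from the original docstring): with lr_init=1e-3, lr_final=1e-5,
    warmup_steps=0 and max_steps=10_000, step 5_000 gives cos(pi/2) = 0,
    so lr = 1e-5 + 0.5 * (1e-3 - 1e-5) = 5.05e-4, the midpoint of the two
    rates; step 0 returns lr_init and step 10_000 returns lr_final.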
67 | :param conf: config subtree 'lr' or similar 68 | :param max_steps: int, the number of steps during optimization. 69 | :return HoF which takes step as input 70 | """ 71 | 72 | def helper(step): 73 | if warmup_steps and step < warmup_steps: 74 | return lr_init * step / warmup_steps 75 | elif step > max_steps: 76 | return 0.0 77 | lr_cos = lr_final + 0.5 * (lr_init - lr_final) * ( 78 | 1 79 | + np.cos(np.pi * (step - warmup_steps) / (max_steps - warmup_steps)) 80 | ) 81 | return lr_cos 82 | 83 | return helper 84 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cmake==3.29.2 2 | cmake-format==0.6.13 3 | cmake_build_extension==0.6.1 4 | ConfigArgParse==1.7 5 | einops==0.8.1 6 | glfw==2.6.5 7 | pycolmap==3.11.1 8 | opencv-python==4.11.0.86 9 | pillow==11.0.0 10 | plyfile==1.0.3 11 | pybind11[global]==2.13.6 12 | pyyaml==6.0.2 13 | scipy==1.15.1 14 | tensorboard==2.19.0 15 | tqdm==4.67.1 -------------------------------------------------------------------------------- /scripts/cmake_clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | cmake-format -i CMakeLists.txt 4 | cmake-format -i src/CMakeLists.txt 5 | cmake-format -i external/CMakeLists.txt 6 | cmake-format -i torch_bindings/CMakeLists.txt 7 | -------------------------------------------------------------------------------- /scripts/torch_info.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import sysconfig 3 | import importlib.util 4 | 5 | lib_path = sysconfig.get_path("purelib") 6 | 7 | def import_module_from_path(module_name, file_path): 8 | spec = importlib.util.spec_from_file_location(module_name, file_path) 9 | module = importlib.util.module_from_spec(spec) 10 | spec.loader.exec_module(module) 11 | return module 12 | 13 | file_path = f"{lib_path}/torch/version.py" 14 | module = import_module_from_path("version", file_path) 15 | 16 | if sys.argv[1] == "torch": 17 | print(module.__version__.split("+")[0]) 18 | elif sys.argv[1] == "cuda": 19 | print(module.cuda) 20 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = radfoam 3 | description = C++ backend for RadFoam 4 | long_description = file: README.md 5 | long_description_content_type = text/markdown 6 | author = Daniel Rebain 7 | author_email = drebain@cs.ubc.ca 8 | license = Apache-2.0 9 | platforms = any 10 | url = https://radfoam.github.io 11 | project_urls = 12 | Source = https://github.com/theialab/radfoam 13 | keywords = ray tracing radiance field 14 | classifiers = 15 | 16 | [options] 17 | zip_safe = False 18 | packages = find: 19 | package_dir = 20 | =src 21 | python_requires = >=3.8 22 | 23 | [options.packages.find] 24 | where = src 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | from pathlib import Path 5 | 6 | import cmake_build_extension 7 | import setuptools 8 | import sysconfig 9 | import subprocess 10 | 11 | lib_path = sysconfig.get_path("purelib") 12 | assert os.path.exists( 13 | f"{lib_path}/torch" 14 | ), "Could not find PyTorch; please make sure it is installed in your environment before 
installing radfoam." 15 | 16 | cmake_options = [] 17 | 18 | if "CUDA_HOME" in os.environ: 19 | cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={os.environ['CUDA_HOME']}") 20 | 21 | source_dir = Path(__file__).parent.absolute() 22 | cmake = (source_dir / "CMakeLists.txt").read_text() 23 | version = re.search(r"project\(\S+ VERSION (\S+)\)", cmake).group(1) 24 | 25 | install_requirements = [ 26 | "cmake==3.29.2", 27 | "cmake-format", 28 | "cmake_build_extension", 29 | "ConfigArgParse", 30 | "einops", 31 | "glfw==2.6.5", 32 | "pycolmap", 33 | "opencv-python", 34 | "pillow", 35 | "plyfile", 36 | "pybind11[global]", 37 | "pyyaml", 38 | "scipy", 39 | "tensorboard", 40 | "tqdm", 41 | ] 42 | 43 | 44 | setuptools.setup( 45 | version=version, 46 | install_requires=install_requirements, 47 | ext_modules=[ 48 | cmake_build_extension.CMakeExtension( 49 | name="RadFoamBindings", 50 | install_prefix="radfoam", 51 | cmake_depends_on=["pybind11"], 52 | write_top_level_init=None, 53 | source_dir=str(source_dir), 54 | cmake_configure_options=[ 55 | f"-DPython3_ROOT_DIR={Path(sys.prefix)}", 56 | "-DCALL_FROM_SETUP_PY:BOOL=ON", 57 | "-DBUILD_SHARED_LIBS:BOOL=OFF", 58 | "-DGPU_DEBUG:BOOL=OFF", 59 | "-DEXAMPLE_WITH_PYBIND11:BOOL=ON", 60 | f"-DTorch_DIR={lib_path}/torch/share/cmake/Torch", 61 | "-DPIP_GLFW:BOOL=ON", 62 | ] 63 | + cmake_options, 64 | ), 65 | ], 66 | cmdclass=dict( 67 | build_ext=cmake_build_extension.BuildExtension, 68 | ), 69 | ) 70 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(CUDAToolkit REQUIRED) 2 | 3 | add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0 -D_USE_MATH_DEFINES) 4 | 5 | list( 6 | APPEND 7 | RADFOAM_INCLUDES 8 | ${CMAKE_SOURCE_DIR}/external/submodules/termcolor/include 9 | ${CMAKE_SOURCE_DIR}/external/submodules/eigen 10 | ${CMAKE_SOURCE_DIR}/external/submodules/atomic_queue/include 11 | ${GLFW_INCLUDES}) 12 | 13 | list(APPEND RADFOAM_CXX_SOURCES viewer/viewer.cpp utils/batch_fetcher.cpp) 14 | list( 15 | APPEND 16 | RADFOAM_CUDA_SOURCES 17 | aabb_tree/aabb_tree.cu 18 | delaunay/delaunay.cu 19 | delaunay/sample_initial_tets.cu 20 | delaunay/growth_iteration.cu 21 | delaunay/delete_violations.cu 22 | delaunay/triangulation_ops.cu 23 | tracing/pipeline.cu) 24 | 25 | add_library(radfoam STATIC ${RADFOAM_CXX_SOURCES} ${RADFOAM_CUDA_SOURCES}) 26 | target_include_directories(radfoam PUBLIC ${RADFOAM_INCLUDES}) 27 | target_link_libraries( 28 | radfoam PUBLIC CUDA::cudart CUDA::cuda_driver ${GLFW_LIBRARY} gl3w imgui TBB::tbb) 29 | target_compile_options(radfoam PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-lineinfo -O3 30 | -Wall -Wno-unknown-pragmas> -O3) 31 | -------------------------------------------------------------------------------- /src/aabb_tree/aabb_tree.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cub/cub.cuh> 4 | 5 | #include "../utils/cuda_array.h" 6 | #include "../utils/geometry.h" 7 | #include "../utils/random.h" 8 | #include "aabb_tree.h" 9 | 10 | namespace radfoam { 11 | 12 | /// @brief Sort the point set into an order appropriate for the leaves of an 13 | /// AABB tree 14 | void sort_points(CUDAArray<Vec3f> &points_buffer, 15 | uint32_t num_points, 16 | CUDAArray<uint32_t> &permutation); 17 | 18 | /// @brief Query a node from the specified level of the AABB tree 19 | template <typename scalar> 20 | __forceinline__ __host__ __device__ AABB<scalar> 21 | get_node(const AABB<scalar> *aabb_tree, 22 | uint32_t tree_depth, 23 | uint32_t node_depth,
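/* Layout note (inferred from the offset arithmetic below): the implicit
   tree stores the deepest level first and the root last, so level d
   starts at (1 << tree_depth) - (1 << (d + 1)). E.g. with tree_depth = 3:
   level 2 (4 nodes) starts at index 0, level 1 (2 nodes) at index 4, and
   the root at index 6, for 2^3 - 1 = 7 nodes in total. */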
uint32_t node_idx) { 25 | auto *level_start = aabb_tree + ((1 << tree_depth) - (1 << node_depth + 1)); 26 | return *(level_start + node_idx); 27 | } 28 | 29 | enum TraversalAction { 30 | Continue, 31 | SkipSubtree, 32 | Terminate, 33 | }; 34 | 35 | /// @brief Traverse the AABB tree in a depth-first manner 36 | template 37 | __forceinline__ __host__ __device__ void 38 | traverse(uint32_t num_points, uint32_t max_depth, Functor functor) { 39 | 40 | uint32_t current_depth = 0; 41 | uint32_t current_node = 0; 42 | uint32_t tree_depth = log2(pow2_round_up(num_points)); 43 | 44 | for (;;) { 45 | auto action = functor(current_depth, current_node); 46 | 47 | if (action == TraversalAction::Terminate) { 48 | break; 49 | } else if (action == TraversalAction::Continue && 50 | current_depth != max_depth) { 51 | current_node = 2 * current_node; 52 | current_depth++; 53 | continue; 54 | } 55 | 56 | current_node++; 57 | #ifdef __CUDA_ARCH__ 58 | uint32_t step_up_amount = 59 | min(__ffs(current_node) - 1, (int)current_depth); 60 | #else 61 | uint32_t step_up_amount = 62 | min(__builtin_ctz(current_node), (int)current_depth); 63 | #endif 64 | current_depth -= step_up_amount; 65 | current_node = current_node >> step_up_amount; 66 | 67 | uint32_t div = tree_depth - current_depth; 68 | uint32_t current_width = (num_points + (1 << div) - 1) >> div; 69 | 70 | if (current_node >= current_width) 71 | break; 72 | } 73 | } 74 | 75 | /// @brief Traverse the AABB tree cooperatively within a warp 76 | template 77 | __forceinline__ __device__ void warp_traverse(uint32_t num_points, 78 | NodeFunctor node_functor, 79 | LeafFunctor leaf_functor) { 80 | 81 | uint32_t base_width = pow2_round_up(num_points); 82 | uint32_t tree_depth = log2(base_width); 83 | 84 | uint32_t warp_idx = threadIdx.x / 32; 85 | uint8_t idx_in_warp = threadIdx.x % 32; 86 | uint32_t current_depth = tree_depth % 5; 87 | uint32_t current_node = 0; 88 | 89 | __shared__ uint32_t mask_stack_array[5 * (block_size / 32)]; 90 | __shared__ uint8_t offset_stack_array[5 * (block_size / 32)]; 91 | uint32_t *mask_stack = mask_stack_array + 5 * warp_idx; 92 | uint8_t *offset_stack = offset_stack_array + 5 * warp_idx; 93 | int8_t stack_idx = -1; 94 | 95 | uint8_t current_offset = 0; 96 | uint32_t current_mask = 0xffffffff; 97 | 98 | for (;;) { 99 | uint32_t div = tree_depth - current_depth; 100 | uint32_t current_width = (num_points + (1 << div) - 1) >> div; 101 | 102 | if (current_node >= current_width) { 103 | break; 104 | } 105 | 106 | bool maskbit = (current_mask >> current_offset) & 1; 107 | 108 | uint32_t thread_node = current_node + idx_in_warp; 109 | 110 | if (maskbit && current_depth == tree_depth) { 111 | leaf_functor(thread_node); 112 | } else if (maskbit) { 113 | TraversalAction action = node_functor(current_depth, thread_node); 114 | 115 | if (__any_sync(0xffffffff, action == TraversalAction::Terminate)) { 116 | break; 117 | } 118 | __syncwarp(0xffffffff); 119 | if (idx_in_warp == 0 && stack_idx >= 0) { 120 | mask_stack[stack_idx] = current_mask; 121 | offset_stack[stack_idx] = current_offset; 122 | } 123 | __syncwarp(0xffffffff); 124 | stack_idx++; 125 | current_offset = 0; 126 | current_mask = 127 | __ballot_sync(0xffffffff, action == TraversalAction::Continue); 128 | 129 | current_node = 32 * current_node; 130 | current_depth += 5; 131 | continue; 132 | } 133 | 134 | current_node += 32; 135 | current_offset++; 136 | 137 | while (current_offset == 32) { 138 | if (stack_idx < 0) { 139 | break; 140 | } 141 | 142 | stack_idx--; 143 | current_offset = 
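/* Note: this pop path restores the ballot mask and sibling offset that
   lane 0 saved when the warp descended, then advances to the next
   sibling; e.g. a warp that descended from offset 3 resumes at offset 4
   of the parent level once all 32 children have been visited. */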
offset_stack[stack_idx]; 144 | current_mask = mask_stack[stack_idx]; 145 | 146 | current_offset++; 147 | current_node = current_node >> 5; 148 | current_depth -= 5; 149 | } 150 | } 151 | } 152 | 153 | /// @brief Perform a warp-cooperative k-nearest neighbor search 154 | template 155 | __forceinline__ __device__ void warp_knn(const Vec3 *points, 156 | const AABB *aabb_tree, 157 | uint32_t num_points, 158 | const Vec3 &query, 159 | uint32_t k, 160 | scalar *distance_out, 161 | uint32_t *index_out) { 162 | uint32_t base_width = pow2_round_up(num_points); 163 | uint32_t tree_depth = log2(base_width); 164 | 165 | uint32_t warp_idx = threadIdx.x / 32; 166 | uint8_t idx_in_warp = threadIdx.x % 32; 167 | 168 | using Sort = cub::WarpMergeSort; 169 | __shared__ typename Sort::TempStorage temp_storage[block_size / 32]; 170 | 171 | auto compare = [](auto a, auto b) { return a < b; }; 172 | 173 | scalar distance = std::numeric_limits::max(); 174 | uint32_t index = 0; 175 | scalar max_distance = std::numeric_limits::max(); 176 | 177 | scalar local_keys[2]; 178 | uint32_t local_values[2]; 179 | 180 | uint32_t current_depth = 6; 181 | uint32_t current_node = 0; 182 | 183 | uint32_t steps = 0; 184 | 185 | auto leaf_functor = [&](uint32_t point_idx) { 186 | scalar new_dist = std::numeric_limits::max(); 187 | 188 | if (point_idx < num_points) { 189 | Vec3 point = points[point_idx]; 190 | new_dist = (point - query).norm(); 191 | } 192 | 193 | local_keys[0] = distance; 194 | local_keys[1] = new_dist; 195 | local_values[0] = index; 196 | local_values[1] = point_idx; 197 | 198 | Sort(temp_storage[warp_idx]).Sort(local_keys, local_values, compare); 199 | 200 | scalar d = __shfl_sync(0xffffffff, local_keys[0], idx_in_warp / 2); 201 | uint32_t i = __shfl_sync(0xffffffff, local_values[0], idx_in_warp / 2); 202 | if (idx_in_warp % 2 == 0) { 203 | distance = d; 204 | index = i; 205 | } 206 | d = __shfl_sync(0xffffffff, local_keys[1], idx_in_warp / 2); 207 | i = __shfl_sync(0xffffffff, local_values[1], idx_in_warp / 2); 208 | if (idx_in_warp % 2 == 1) { 209 | distance = d; 210 | index = i; 211 | } 212 | 213 | scalar new_max_distance = __shfl_sync(0xffffffff, distance, k - 1); 214 | if (new_max_distance < max_distance) { 215 | max_distance = new_max_distance; 216 | } 217 | 218 | steps++; 219 | }; 220 | 221 | while (current_depth < tree_depth) { 222 | AABB node0 = get_node( 223 | aabb_tree, tree_depth, current_depth, current_node + idx_in_warp); 224 | AABB node1 = get_node(aabb_tree, 225 | tree_depth, 226 | current_depth, 227 | current_node + idx_in_warp + 32); 228 | 229 | scalar dist0 = node0.sdf(query); 230 | scalar dist1 = node1.sdf(query); 231 | 232 | local_keys[0] = dist0; 233 | local_keys[1] = dist1; 234 | local_values[0] = current_node + idx_in_warp; 235 | local_values[1] = current_node + idx_in_warp + 32; 236 | 237 | Sort(temp_storage[warp_idx]).Sort(local_keys, local_values, compare); 238 | 239 | uint32_t best_node = __shfl_sync(0xffffffff, local_values[0], 0); 240 | 241 | uint32_t levels_to_step = min(tree_depth - current_depth, 6); 242 | 243 | current_node = best_node << levels_to_step; 244 | current_depth += levels_to_step; 245 | } 246 | 247 | current_node = max(current_node - 16, 0u); 248 | leaf_functor(current_node + idx_in_warp); 249 | 250 | distance = std::numeric_limits::max(); 251 | index = 0; 252 | 253 | auto node_functor = [&](uint32_t current_depth, uint32_t current_node) { 254 | if (current_node >= (1 << current_depth)) 255 | return TraversalAction::SkipSubtree; 256 | 257 | AABB node = 258 | 
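/* Note: the clamped componentwise difference below is the distance from
   the query to the box, which lower-bounds the distance to any point in
   the subtree, so the subtree can be skipped whenever that bound exceeds
   max_distance (the current k-th best distance). */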
get_node(aabb_tree, tree_depth, current_depth, current_node); 259 | 260 | scalar dist = (node.min - query) 261 | .cwiseMax(Vec3::Zero()) 262 | .cwiseMax(query - node.max) 263 | .norm(); 264 | 265 | if (dist > max_distance) { 266 | return TraversalAction::SkipSubtree; 267 | } 268 | 269 | return TraversalAction::Continue; 270 | }; 271 | 272 | warp_traverse(num_points, node_functor, leaf_functor); 273 | 274 | *distance_out = distance; 275 | *index_out = index; 276 | } 277 | 278 | } // namespace radfoam -------------------------------------------------------------------------------- /src/aabb_tree/aabb_tree.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/cuda_helpers.h" 4 | #include "../utils/geometry.h" 5 | 6 | namespace radfoam { 7 | 8 | /// @brief Build an AABB tree from a set of points, assuming that the points are 9 | /// already sorted 10 | void build_aabb_tree(ScalarType scalar_type, 11 | const void *points, 12 | uint32_t num_points, 13 | void *aabb_tree); 14 | 15 | /// @brief Find the nearest neighbor of each query point 16 | void nn(ScalarType coord_scalar_type, 17 | const void *coords, 18 | const void *aabb_tree, 19 | const void *query_points, 20 | uint32_t num_points, 21 | uint32_t num_queries, 22 | uint32_t *indices, 23 | const void *stream = nullptr); 24 | 25 | /// @brief Find the nearest neighbor of a single query point 26 | uint32_t nn_cpu(ScalarType coord_scalar_type, 27 | const void *coords, 28 | const void *aabb_tree, 29 | const Vec3f &query, 30 | uint32_t num_points); 31 | 32 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/delaunay.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "../utils/cuda_helpers.h" 13 | #include "../utils/geometry.h" 14 | #include "delaunay.h" 15 | 16 | #include "../utils/common_kernels.cuh" 17 | #include "exact_tree_ops.cuh" 18 | #include "sorted_map.cuh" 19 | 20 | namespace radfoam { 21 | 22 | /// @brief Sample num_samples Delaunay tets randomly from the point set 23 | void sample_initial_tets(const Vec3f *points, 24 | uint32_t num_points, 25 | const AABB *aabb_tree, 26 | SortedMap &tets_table, 27 | SortedMap &faces_table, 28 | uint32_t num_samples); 29 | 30 | /// @brief Grow the Delaunay mesh by finding tets adjacent to the frontier 31 | uint32_t growth_iteration(const Vec3f *points, 32 | uint32_t num_points, 33 | const AABB *aabb_tree, 34 | SortedMap &tets_table, 35 | SortedMap &faces_table, 36 | CUDAArray &frontier, 37 | uint32_t num_frontier); 38 | 39 | /// @brief Delete tets that violate the Delaunay condition 40 | uint32_t 41 | delete_delaunay_violations(const Vec3f *points, 42 | uint32_t num_points, 43 | const AABB *aabb_tree, 44 | SortedMap &tets_table, 45 | SortedMap &faces_table, 46 | CUDAArray &frontier, 47 | const uint32_t *face_to_tet); 48 | 49 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/delaunay.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../utils/geometry.h" 6 | 7 | namespace radfoam { 8 | 9 | class TriangulationFailedError : public std::runtime_error { 10 | public: 11 | TriangulationFailedError(const std::string &message) 12 | : 
std::runtime_error(message) {} 13 | }; 14 | 15 | class Triangulation { 16 | public: 17 | virtual ~Triangulation() = default; 18 | 19 | virtual const uint32_t *permutation() const = 0; 20 | 21 | virtual uint32_t num_points() const = 0; 22 | 23 | virtual const IndexedTet *tets() const = 0; 24 | 25 | virtual uint32_t num_tets() const = 0; 26 | 27 | virtual uint32_t num_faces() const = 0; 28 | 29 | virtual const uint32_t *tet_adjacency() const = 0; 30 | 31 | virtual const uint32_t *point_adjacency() const = 0; 32 | 33 | virtual uint32_t point_adjacency_size() const = 0; 34 | 35 | virtual const uint32_t *point_adjacency_offsets() const = 0; 36 | 37 | virtual const uint32_t *vert_to_tet() const = 0; 38 | 39 | virtual bool 40 | rebuild(const void *points, uint32_t num_points, bool incremental) = 0; 41 | 42 | static std::unique_ptr 43 | create_triangulation(const void *points, uint32_t num_points); 44 | }; 45 | 46 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/delete_violations.cu: -------------------------------------------------------------------------------- 1 | #include "delaunay.cuh" 2 | 3 | namespace radfoam { 4 | 5 | constexpr int DELAUNAY_VIOLATIONS_BLOCK_SIZE = 128; 6 | 7 | __global__ void 8 | delaunay_violations_kernel(const Vec3f *__restrict__ points, 9 | uint32_t num_points, 10 | const AABB *__restrict__ aabb_tree, 11 | IndexedTet *__restrict__ tets, 12 | uint32_t num_tets, 13 | bool *__restrict__ conditions) { 14 | uint32_t i = (blockIdx.x * blockDim.x + threadIdx.x) / 32; 15 | uint32_t lane = threadIdx.x % 32; 16 | 17 | if (i >= num_tets) { 18 | return; 19 | } 20 | 21 | uint32_t steps = 0; 22 | 23 | IndexedTet tet = tets[i]; 24 | swap(tet.vertices[0], tet.vertices[1]); 25 | 26 | bool condition = check_delaunay_warp( 27 | points, aabb_tree, num_points, tet, &steps); 28 | 29 | if (lane == 0) { 30 | conditions[i] = condition; 31 | } 32 | } 33 | 34 | uint32_t 35 | delete_delaunay_violations(const Vec3f *points, 36 | uint32_t num_points, 37 | const AABB *aabb_tree, 38 | SortedMap &tets_map, 39 | SortedMap &faces_map, 40 | CUDAArray &frontier, 41 | const uint32_t *face_to_tet) { 42 | auto tets_begin = tets_map.unique_keys.begin(); 43 | auto num_tets = tets_map.num_unique_keys; 44 | 45 | CUDAArray conditions(num_tets); 46 | auto conditions_begin = conditions.begin(); 47 | 48 | launch_kernel_1d(delaunay_violations_kernel, 49 | num_tets * 32, 50 | nullptr, 51 | points, 52 | num_points, 53 | aabb_tree, 54 | tets_begin, 55 | num_tets, 56 | conditions_begin); 57 | 58 | auto faces_begin = faces_map.unique_keys.begin(); 59 | auto num_faces = faces_map.num_unique_keys; 60 | 61 | CUDAArray face_status(num_faces); 62 | auto face_status_begin = face_status.begin(); 63 | 64 | constexpr uint8_t FACE_UNCHANGED = 0; 65 | constexpr uint8_t FACE_NOW_FRONTIER = 1; 66 | constexpr uint8_t FACE_DELETED = 2; 67 | 68 | auto check_face_status = [=] __device__(uint32_t i) { 69 | uint32_t tet_0 = face_to_tet[2 * i + 0] / 4; 70 | uint32_t tet_1 = face_to_tet[2 * i + 1] / 4; 71 | 72 | IndexedTriangle face = faces_begin[i]; 73 | uint32_t adjacent_tet; 74 | 75 | uint8_t status; 76 | if (tet_1 == UINT32_MAX / 4) { 77 | status = conditions_begin[tet_0] ? 
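/* Note: tet_1 == UINT32_MAX / 4 marks a boundary face with a single
   incident tet: the face stays on the frontier if that tet passed the
   Delaunay check, and is deleted along with it otherwise. */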
FACE_NOW_FRONTIER : FACE_DELETED; 78 | adjacent_tet = tet_0; 79 | } else { 80 | bool cond_0 = conditions_begin[tet_0]; 81 | bool cond_1 = conditions_begin[tet_1]; 82 | 83 | if (cond_0 && cond_1) { 84 | status = FACE_UNCHANGED; 85 | } else if (cond_0) { 86 | status = FACE_NOW_FRONTIER; 87 | adjacent_tet = tet_0; 88 | } else if (cond_1) { 89 | status = FACE_NOW_FRONTIER; 90 | adjacent_tet = tet_1; 91 | } else { 92 | status = FACE_DELETED; 93 | } 94 | } 95 | 96 | if (status == FACE_NOW_FRONTIER) { 97 | IndexedTet adjacent = tets_begin[adjacent_tet]; 98 | 99 | for (uint32_t j = 0; j < 4; ++j) { 100 | IndexedTriangle adjacent_face = adjacent.face(j); 101 | if (face == adjacent_face) { 102 | faces_begin[i] = adjacent_face; 103 | } 104 | } 105 | } 106 | 107 | face_status_begin[i] = status; 108 | }; 109 | 110 | for_n(u32zero(), num_faces, check_face_status); 111 | 112 | auto enumerated_faces = enumerate(faces_begin); 113 | 114 | CUDAArray unchanged_faces_buffer(num_faces); 115 | auto unchanged_faces_out = 116 | unenumerate(unchanged_faces_buffer.begin()); 117 | 118 | auto select_unchanged_and_frontier = 119 | [=] __device__(cub::KeyValuePair kv) { 120 | return face_status_begin[kv.key] == FACE_UNCHANGED || 121 | face_status_begin[kv.key] == FACE_NOW_FRONTIER; 122 | }; 123 | 124 | CUDAArray num_selected_buffer(2); 125 | auto num_selected_begin = num_selected_buffer.begin(); 126 | 127 | CUB_CALL(cub::DevicePartition::If(temp_data, 128 | temp_bytes, 129 | enumerated_faces, 130 | unchanged_faces_out, 131 | num_selected_begin, 132 | num_faces, 133 | select_unchanged_and_frontier)); 134 | 135 | frontier.expand(num_faces); 136 | auto frontier_faces_out = unenumerate(frontier.begin()); 137 | 138 | auto select_frontier = 139 | [=] __device__(cub::KeyValuePair kv) { 140 | return face_status_begin[kv.key] == FACE_NOW_FRONTIER; 141 | }; 142 | 143 | CUB_CALL(cub::DevicePartition::If(temp_data, 144 | temp_bytes, 145 | enumerated_faces, 146 | frontier_faces_out, 147 | num_selected_begin + 1, 148 | num_faces, 149 | select_frontier)); 150 | 151 | face_status.clear(); 152 | 153 | uint32_t num_selected_host[2]; 154 | cuda_check(cuMemcpyDtoH(&num_selected_host, 155 | (CUdeviceptr)num_selected_begin, 156 | 2 * sizeof(uint32_t))); 157 | 158 | uint32_t num_frontier = num_selected_host[1]; 159 | faces_map.insert( 160 | unchanged_faces_buffer.begin(), nullptr, num_selected_host[0]); 161 | 162 | unchanged_faces_buffer.clear(); 163 | 164 | CUDAArray temp_tets_buffer(num_tets); 165 | auto temp_tets_begin = temp_tets_buffer.begin(); 166 | 167 | CUB_CALL(cub::DevicePartition::Flagged(temp_data, 168 | temp_bytes, 169 | tets_begin, 170 | conditions_begin, 171 | temp_tets_begin, 172 | num_selected_begin, 173 | num_tets)); 174 | 175 | cuda_check(cuMemcpyDtoH( 176 | &num_selected_host, (CUdeviceptr)num_selected_begin, sizeof(uint32_t))); 177 | 178 | tets_map.insert(temp_tets_begin, nullptr, num_selected_host[0]); 179 | 180 | return num_frontier; 181 | } 182 | 183 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/growth_iteration.cu: -------------------------------------------------------------------------------- 1 | #include "delaunay.cuh" 2 | 3 | namespace radfoam { 4 | 5 | constexpr int GROWTH_ITERATION_BLOCK_SIZE = 128; 6 | 7 | __global__ void 8 | growth_iteration_kernel(const Vec3f *__restrict__ points, 9 | uint32_t num_points, 10 | const AABB *__restrict__ aabb_tree, 11 | const IndexedTriangle *__restrict__ frontier, 12 | uint32_t num_frontier, 13 | 
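/* Note: one 32-lane warp per frontier face; each warp emits at most one
   new tet plus its three non-seed faces, writing UINT32_MAX sentinels
   when no empty-circumsphere vertex is found. */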
IndexedTet *__restrict__ new_tets, 14 | IndexedTriangle *__restrict__ new_faces) { 15 | uint32_t i = (blockIdx.x * blockDim.x + threadIdx.x) / 32; 16 | uint32_t lane = threadIdx.x % 32; 17 | 18 | if (i >= num_frontier) { 19 | return; 20 | } 21 | 22 | IndexedTriangle seed_face = frontier[i]; 23 | Vec3f v0 = points[seed_face.vertices[0]]; 24 | Vec3f v1 = points[seed_face.vertices[1]]; 25 | Vec3f v2 = points[seed_face.vertices[2]]; 26 | 27 | IndexedTet tet; 28 | tet.vertices[0] = seed_face.vertices[0]; 29 | tet.vertices[1] = seed_face.vertices[1]; 30 | tet.vertices[2] = seed_face.vertices[2]; 31 | 32 | uint32_t steps = 0; 33 | bool found; 34 | tet.vertices[3] = 35 | maximal_empty_sphere(points, 36 | aabb_tree, 37 | num_points, 38 | v0, 39 | v1, 40 | v2, 41 | tet.vertices[0], 42 | tet.vertices[1], 43 | tet.vertices[2], 44 | 3, 45 | &steps, 46 | &found); 47 | 48 | if (lane != 0) { 49 | return; 50 | } 51 | 52 | Vec3f v3 = points[tet.vertices[3]]; 53 | 54 | IndexedTriangle faces[3]; 55 | uint32_t j = 0; 56 | if (found) { 57 | Vec3f n = (v1 - v0).cross((v2 - v1)); 58 | if (n.dot((v0 - v3)) < 0) { 59 | swap(tet.vertices[0], tet.vertices[1]); 60 | } 61 | 62 | for (uint32_t k = 0; k < 4; ++k) { 63 | auto face = tet.face(k); 64 | if (face != seed_face) { 65 | switch (j) { // Avoid spilling to local memory 66 | case 0: 67 | faces[0] = face; 68 | break; 69 | case 1: 70 | faces[1] = face; 71 | break; 72 | case 2: 73 | faces[2] = face; 74 | break; 75 | } 76 | j++; 77 | } 78 | } 79 | } else { 80 | uint32_t max = UINT32_MAX; 81 | 82 | tet.vertices[0] = max; 83 | tet.vertices[1] = max; 84 | tet.vertices[2] = max; 85 | tet.vertices[3] = max; 86 | 87 | faces[0] = IndexedTriangle(max, max, max); 88 | faces[1] = IndexedTriangle(max, max, max); 89 | faces[2] = IndexedTriangle(max, max, max); 90 | } 91 | 92 | new_tets[i] = tet; 93 | new_faces[i * 3 + 0] = faces[0]; 94 | new_faces[i * 3 + 1] = faces[1]; 95 | new_faces[i * 3 + 2] = faces[2]; 96 | } 97 | 98 | uint32_t growth_iteration(const Vec3f *points, 99 | uint32_t num_points, 100 | const AABB *aabb_tree, 101 | SortedMap &tets_map, 102 | SortedMap &faces_map, 103 | CUDAArray &frontier, 104 | uint32_t num_frontier) { 105 | 106 | frontier.expand(num_frontier * 3, true); 107 | auto frontier_begin = frontier.begin(); 108 | 109 | CUDAArray new_tets(num_frontier * 2); 110 | auto new_tets_begin = new_tets.begin(); 111 | 112 | CUDAArray new_faces(num_frontier * 6); 113 | auto new_faces_begin = new_faces.begin(); 114 | 115 | launch_kernel_1d(growth_iteration_kernel, 116 | num_frontier * 32, 117 | nullptr, 118 | points, 119 | num_points, 120 | aabb_tree, 121 | frontier_begin, 122 | num_frontier, 123 | new_tets_begin, 124 | new_faces_begin); 125 | SortedMap new_tets_map; 126 | new_tets_map.num_unique_keys = num_frontier; 127 | new_tets_map.insert(new_tets_begin, nullptr, num_frontier); 128 | auto unique_new_tets_begin = new_tets_map.unique_keys.begin(); 129 | uint32_t num_new_tets = new_tets_map.num_unique_keys; 130 | 131 | new_tets.clear(); 132 | 133 | SortedMap new_faces_map; 134 | new_faces_map.num_unique_keys = num_frontier * 3; 135 | new_faces_map.insert(new_faces_begin, nullptr, num_frontier * 3); 136 | auto unique_new_faces_begin = new_faces_map.unique_keys.begin(); 137 | uint32_t num_new_faces = new_faces_map.num_unique_keys; 138 | 139 | new_faces.clear(); 140 | 141 | CUDAArray num_selected(2); 142 | auto num_selected_begin = num_selected.begin(); 143 | 144 | uint32_t num_tets = tets_map.num_unique_keys; 145 | 146 | CUDAArray tets(num_tets + num_frontier); 147 | 
auto tets_begin = tets.begin(); 148 | IndexedTet *tets_end = tets_begin + num_tets; 149 | 150 | copy_range(tets_map.unique_keys.begin(), 151 | tets_map.unique_keys.begin() + num_tets, 152 | tets_begin); 153 | 154 | CUDAArray flags(3 * num_frontier); 155 | auto flags_begin = flags.begin(); 156 | 157 | auto check_tets = [=] __device__(uint32_t i) { 158 | IndexedTet tet = unique_new_tets_begin[i]; 159 | if (tet.vertices[0] == UINT32_MAX) { 160 | flags_begin[i] = false; 161 | return; 162 | } 163 | 164 | auto it = binary_search(tets_begin, tets_end, tet); 165 | 166 | flags_begin[i] = it == tets_end; 167 | }; 168 | 169 | for_n(u32zero(), num_new_tets, check_tets); 170 | 171 | CUB_CALL(cub::DeviceSelect::Flagged(temp_data, 172 | temp_bytes, 173 | unique_new_tets_begin, 174 | flags_begin, 175 | tets_end, 176 | num_selected_begin, 177 | num_new_tets)); 178 | 179 | cuda_check(cuMemcpyDtoH( 180 | &num_new_tets, (CUdeviceptr)num_selected_begin, sizeof(uint32_t))); 181 | num_tets += num_new_tets; 182 | tets_map.insert(tets_begin, nullptr, num_tets); 183 | 184 | tets.clear(); 185 | 186 | uint32_t num_faces = faces_map.num_unique_keys; 187 | 188 | CUDAArray faces(num_faces + 3 * num_frontier); 189 | auto faces_begin = faces.begin(); 190 | copy_range(faces_map.unique_keys.begin(), 191 | faces_map.unique_keys.begin() + num_faces, 192 | faces_begin); 193 | IndexedTriangle *faces_end = faces_begin + num_faces; 194 | 195 | auto check_faces = [=] __device__(uint32_t i) { 196 | IndexedTriangle face = unique_new_faces_begin[i]; 197 | if (face.vertices[0] == UINT32_MAX) { 198 | flags_begin[i] = false; 199 | return; 200 | } 201 | 202 | auto it = binary_search(faces_begin, faces_end, face); 203 | 204 | flags_begin[i] = it == faces_end; 205 | }; 206 | for_n(u32zero(), num_new_faces, check_faces); 207 | 208 | CUB_CALL(cub::DeviceSelect::Flagged(temp_data, 209 | temp_bytes, 210 | unique_new_faces_begin, 211 | flags_begin, 212 | frontier_begin, 213 | num_selected_begin + 1, 214 | num_new_faces)); 215 | 216 | cuda_check(cuMemcpyDtoH(&num_new_faces, 217 | (CUdeviceptr)(num_selected_begin + 1), 218 | sizeof(uint32_t))); 219 | 220 | num_frontier = num_new_faces; 221 | num_faces += num_frontier; 222 | 223 | copy_range(frontier_begin, frontier_begin + num_frontier, faces_end); 224 | 225 | faces_map.insert(faces_begin, nullptr, num_faces); 226 | 227 | return num_frontier; 228 | } 229 | 230 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/predicate.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/geometry.h" 4 | #include "shewchuk.cuh" 5 | 6 | namespace radfoam { 7 | 8 | struct HalfspacePredicate { 9 | const Vec3f *v0; 10 | const Vec3f *v1; 11 | const Vec3f *v2; 12 | 13 | __forceinline__ __device__ HalfspacePredicate(const Vec3f &v0, 14 | const Vec3f &v1, 15 | const Vec3f &v2) 16 | : v0(&v0), v1(&v1), v2(&v2) {} 17 | 18 | __forceinline__ __device__ bool 19 | check_point(const Vec3f &v3, bool conservative = false) const { 20 | if (conservative) { 21 | return orient3dconservative(*v0, *v1, v3, *v2) == 22 | PredicateResult::Inside; 23 | } else { 24 | return orient3d(*v0, *v1, v3, *v2) == PredicateResult::Inside; 25 | } 26 | } 27 | 28 | __forceinline__ __device__ bool 29 | check_aabb_conservative(const AABB &aabb) const { 30 | bool inside = false; 31 | #pragma unroll 32 | for (int i = 0; i < 8; ++i) { 33 | Vec3f corner((i & 1) ? aabb.min[0] : aabb.max[0], 34 | (i & 2) ? 
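/* Note: bits 0..2 of i choose min (bit set) or max (bit clear) per axis,
   enumerating all 8 corners of the box; e.g. i = 0b011 yields
   (min.x, min.y, max.z). */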
aabb.min[1] : aabb.max[1], 35 | (i & 4) ? aabb.min[2] : aabb.max[2]); 36 | inside |= check_point(corner, true); 37 | } 38 | return inside; 39 | } 40 | }; 41 | 42 | struct EmptyCircumspherePredicate { 43 | const Vec3f *v0; 44 | const Vec3f *v1; 45 | const Vec3f *v2; 46 | Vec3f v3; 47 | Vec3f c; 48 | 49 | __forceinline__ __device__ EmptyCircumspherePredicate(const Vec3f &v0, 50 | const Vec3f &v1, 51 | const Vec3f &v2) 52 | : v0(&v0), v1(&v1), v2(&v2), c(Vec3f::Zero()) { 53 | c = Vec3f(std::numeric_limits::infinity(), 54 | std::numeric_limits::infinity(), 55 | std::numeric_limits::infinity()); 56 | v3 = Vec3f(std::numeric_limits::infinity(), 57 | std::numeric_limits::infinity(), 58 | std::numeric_limits::infinity()); 59 | } 60 | 61 | inline __device__ bool check_point(const Vec3f &v3_new, 62 | bool conservative = false) const { 63 | if (!isfinite(v3[0])) { 64 | return true; 65 | } 66 | if (conservative) { 67 | return insphereconservative(*v0, *v1, *v2, v3_new, v3) == 68 | PredicateResult::Inside; 69 | } else { 70 | return insphere(*v0, *v1, *v2, v3_new, v3) == 71 | PredicateResult::Inside; 72 | } 73 | } 74 | 75 | inline __device__ bool 76 | check_aabb_conservative(const AABB &aabb) const { 77 | if (!isfinite(c[0])) { 78 | return true; 79 | } 80 | Vec3f x = c.cwiseMin(aabb.max).cwiseMax(aabb.min); 81 | return check_point(x, true); 82 | } 83 | 84 | __forceinline__ __device__ bool 85 | warp_update(const Vec3f &v3_new, bool valid, bool &found) { 86 | uint32_t valid_mask = __ballot_sync(0xffffffff, valid); 87 | uint32_t lane_id = threadIdx.x % 32; 88 | bool is_best_lane = false; 89 | for (uint32_t i = 0; i < 32; ++i) { 90 | if (valid_mask & (1 << i)) { 91 | Vec3f v3i; 92 | v3i[0] = __shfl_sync(0xffffffff, v3_new[0], i); 93 | v3i[1] = __shfl_sync(0xffffffff, v3_new[1], i); 94 | v3i[2] = __shfl_sync(0xffffffff, v3_new[2], i); 95 | 96 | if (!found || check_point(v3i)) { 97 | found = true; 98 | update(v3i); 99 | is_best_lane = lane_id == i; 100 | } 101 | } 102 | } 103 | return is_best_lane; 104 | } 105 | 106 | __forceinline__ __device__ void update(const Vec3f &v3_new) { 107 | v3 = v3_new; 108 | 109 | Vec3f u0 = *v1 - *v0; 110 | Vec3f u1 = *v2 - *v0; 111 | Vec3f u2 = v3 - *v0; 112 | 113 | Vec3f w0 = u0.cross(u1); 114 | Vec3f w1 = u1.cross(u2); 115 | Vec3f w2 = u2.cross(u0); 116 | 117 | float vol = u0.dot(w1) / 6; 118 | Vec3f num = u0.squaredNorm() * w1 + u1.squaredNorm() * w2 + 119 | u2.squaredNorm() * w0; 120 | Vec3f x = num / (12 * vol); 121 | 122 | if (isfinite(x[0]) && isfinite(x[1]) && isfinite(x[2])) { 123 | c = x + *v0; 124 | } 125 | } 126 | }; 127 | 128 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/sample_initial_tets.cu: -------------------------------------------------------------------------------- 1 | #include "delaunay.cuh" 2 | 3 | namespace radfoam { 4 | 5 | constexpr int SAMPLE_INITIAL_TETS_BLOCK_SIZE = 128; 6 | 7 | __global__ void sample_initial_tets_kernel(const Vec3f *points, 8 | uint32_t num_points, 9 | const AABB *aabb_tree, 10 | IndexedTet *tets, 11 | IndexedTriangle *faces, 12 | uint32_t num_samples) { 13 | 14 | uint32_t i = (blockIdx.x * blockDim.x + threadIdx.x) / 32; 15 | uint32_t lane = threadIdx.x % 32; 16 | 17 | if (i >= num_samples) { 18 | return; 19 | } 20 | 21 | uint32_t steps = 0; 22 | 23 | IndexedTet tet; 24 | RNGState rng = make_rng(i); 25 | 26 | Vec3f v0, v1, v2, v3; 27 | 28 | bool found = false; 29 | for (;;) { 30 | tet.vertices[0] = i % num_points; 31 | 32 | uint32_t nn_idx; 33 | 34 | if 
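/* Note: only lane 0 performs the nearest-neighbour tree query; the
   result is then broadcast to the rest of the warp with __shfl_sync
   below. */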
(lane == 0) { 35 | nn_idx = vertex_nearest_neighbour( 36 | points, aabb_tree, num_points, tet.vertices[0], &steps); 37 | } 38 | 39 | tet.vertices[1] = __shfl_sync(0xffffffff, nn_idx, 0); 40 | 41 | v0 = points[tet.vertices[0]]; 42 | v1 = points[tet.vertices[1]]; 43 | 44 | Vec3f c = (v0 + v1) / 2; 45 | Vec3f d0 = v1 - v0; 46 | float r_squared = d0.squaredNorm() / 4; 47 | Vec3f dr = randn3(rng); 48 | Vec3f n = d0.cross(dr.normalized()); 49 | Vec3f d1 = n.cross(d0); 50 | d1.normalize(); 51 | v2 = c + d1 * sqrt(r_squared); 52 | 53 | tet.vertices[2] = maximal_empty_sphere( 54 | points, 55 | aabb_tree, 56 | num_points, 57 | v0, 58 | v1, 59 | v2, 60 | tet.vertices[0], 61 | tet.vertices[1], 62 | UINT32_MAX, 63 | 2, 64 | &steps, 65 | &found); 66 | if (!found) { 67 | break; 68 | } 69 | 70 | v2 = points[tet.vertices[2]]; 71 | 72 | tet.vertices[3] = maximal_empty_sphere( 73 | points, 74 | aabb_tree, 75 | num_points, 76 | v0, 77 | v1, 78 | v2, 79 | tet.vertices[0], 80 | tet.vertices[1], 81 | tet.vertices[2], 82 | 3, 83 | &steps, 84 | &found); 85 | 86 | if (!found) { 87 | break; 88 | } 89 | 90 | v3 = points[tet.vertices[3]]; 91 | 92 | if (!found) { 93 | break; 94 | } 95 | 96 | found = check_delaunay( 97 | points, aabb_tree, num_points, tet, nullptr); 98 | 99 | break; 100 | } 101 | 102 | if (lane != 0) { 103 | return; 104 | } 105 | 106 | IndexedTriangle _faces[4]; 107 | 108 | if (found) { 109 | Vec3f n = tet.face(0).normal(v0, v1, v2); 110 | if (n.dot(v0 - v3) < 0) { 111 | swap(tet.vertices[0], tet.vertices[1]); 112 | } 113 | 114 | _faces[0] = tet.face(0); 115 | _faces[1] = tet.face(1); 116 | _faces[2] = tet.face(2); 117 | _faces[3] = tet.face(3); 118 | } else { 119 | uint32_t max = UINT32_MAX; 120 | 121 | tet.vertices[0] = max; 122 | tet.vertices[1] = max; 123 | tet.vertices[2] = max; 124 | tet.vertices[3] = max; 125 | 126 | _faces[0] = IndexedTriangle(max, max, max); 127 | _faces[1] = IndexedTriangle(max, max, max); 128 | _faces[2] = IndexedTriangle(max, max, max); 129 | _faces[3] = IndexedTriangle(max, max, max); 130 | } 131 | tets[i] = tet; 132 | faces[i * 4 + 0] = _faces[0]; 133 | faces[i * 4 + 1] = _faces[1]; 134 | faces[i * 4 + 2] = _faces[2]; 135 | faces[i * 4 + 3] = _faces[3]; 136 | } 137 | 138 | void sample_initial_tets(const Vec3f *points, 139 | uint32_t num_points, 140 | const AABB *aabb_tree, 141 | SortedMap &tets_map, 142 | SortedMap &faces_map, 143 | uint32_t num_samples) { 144 | 145 | CUDAArray temp_tets(num_samples); 146 | IndexedTet *temp_tets_begin = temp_tets.begin(); 147 | 148 | CUDAArray temp_faces(num_samples * 4); 149 | IndexedTriangle *temp_faces_begin = temp_faces.begin(); 150 | 151 | launch_kernel_1d(sample_initial_tets_kernel, 152 | num_samples * 32, 153 | nullptr, 154 | points, 155 | num_points, 156 | aabb_tree, 157 | temp_tets_begin, 158 | temp_faces_begin, 159 | num_samples); 160 | 161 | tets_map.insert(temp_tets_begin, nullptr, num_samples); 162 | faces_map.insert(temp_faces_begin, nullptr, num_samples * 4); 163 | } 164 | 165 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/sorted_map.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/cuda_array.h" 4 | 5 | namespace radfoam { 6 | 7 | template 8 | struct SortedMap { 9 | CUDAArray unique_keys; 10 | CUDAArray unique_values; 11 | uint32_t num_unique_keys; 12 | 13 | SortedMap() { num_unique_keys = 0; } 14 | 15 | SortedMap(K *keys, V *values, uint32_t num_keys) { 16 | insert(keys, 
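/* Note: the delegated insert() below sorts the keys, drops invalid and
   duplicate entries with cub::DeviceSelect, and leaves the survivors in
   unique_keys with their count in num_unique_keys. */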
values, num_keys); 17 | } 18 | 19 | void insert(K *keys, V *values, uint32_t num_keys) { 20 | 21 | if (values != nullptr) { 22 | unique_values.resize(num_keys); 23 | 24 | CUB_CALL(cub::DeviceMergeSort::SortPairs( 25 | temp_data, temp_bytes, keys, values, num_keys, std::less())); 26 | } else { 27 | CUB_CALL(cub::DeviceMergeSort::SortKeys( 28 | temp_data, temp_bytes, keys, num_keys, std::less())); 29 | } 30 | 31 | auto sorted_keys_begin = keys; 32 | 33 | auto is_valid_unique_key = 34 | [=] __device__(cub::KeyValuePair pair) { 35 | auto i = pair.key; 36 | K key = sorted_keys_begin[i]; 37 | if (!key.is_valid()) { 38 | return false; 39 | } 40 | if (i > 0) { 41 | K prev_key = sorted_keys_begin[i - 1]; 42 | if (key == prev_key) { 43 | return false; 44 | } 45 | } 46 | return true; 47 | }; 48 | 49 | auto is_valid_unique_value = 50 | [=] __device__(cub::KeyValuePair pair) { 51 | auto i = pair.key; 52 | K key = sorted_keys_begin[i]; 53 | if (!key.is_valid()) { 54 | return false; 55 | } 56 | if (i > 0) { 57 | K prev_key = sorted_keys_begin[i - 1]; 58 | if (key == prev_key) { 59 | return false; 60 | } 61 | } 62 | return true; 63 | }; 64 | 65 | size_t *num_unique_keys_device; 66 | cuda_check( 67 | cuMemAlloc(reinterpret_cast(&num_unique_keys_device), 68 | sizeof(size_t))); 69 | 70 | unique_keys.resize(num_keys); 71 | 72 | auto enumerated_sorted_keys = enumerate(sorted_keys_begin); 73 | auto unenumerated_unique_keys = unenumerate(unique_keys.begin()); 74 | 75 | CUB_CALL(cub::DeviceSelect::If(temp_data, 76 | temp_bytes, 77 | enumerated_sorted_keys, 78 | unenumerated_unique_keys, 79 | num_unique_keys_device, 80 | num_keys, 81 | is_valid_unique_key)); 82 | 83 | if (values != nullptr) { 84 | unique_values.resize(num_keys); 85 | 86 | auto enumerated_sorted_values = enumerate(values); 87 | auto unenumerated_unique_values = 88 | unenumerate(unique_values.begin()); 89 | 90 | CUB_CALL(cub::DeviceSelect::If(temp_data, 91 | temp_bytes, 92 | enumerated_sorted_values, 93 | unenumerated_unique_values, 94 | num_unique_keys_device, 95 | num_keys, 96 | is_valid_unique_value)); 97 | } 98 | 99 | cuda_check( 100 | cuMemcpyDtoH(&num_unique_keys, 101 | reinterpret_cast(num_unique_keys_device), 102 | sizeof(size_t))); 103 | cuda_check( 104 | cuMemFree(reinterpret_cast(num_unique_keys_device))); 105 | } 106 | 107 | void clear() { 108 | num_unique_keys = 0; 109 | unique_keys.clear(); 110 | unique_values.clear(); 111 | } 112 | }; 113 | 114 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/triangulation_ops.cu: -------------------------------------------------------------------------------- 1 | #include "triangulation_ops.h" 2 | 3 | #include "../utils/common_kernels.cuh" 4 | #include "exact_tree_ops.cuh" 5 | 6 | namespace radfoam { 7 | 8 | template 9 | __global__ void 10 | farthest_neighbor_kernel(const Vec3 *__restrict__ points, 11 | const uint32_t *point_adjacency, 12 | const uint32_t *point_adjacency_offsets, 13 | uint32_t num_points, 14 | uint32_t *__restrict__ indices, 15 | float *__restrict__ cell_radius) { 16 | uint32_t i = (blockIdx.x * blockDim.x + threadIdx.x); 17 | if (i >= num_points) { 18 | return; 19 | } 20 | 21 | Vec3f primal_point = points[i]; 22 | uint32_t point_adjacency_begin = point_adjacency_offsets[i]; 23 | uint32_t point_adjacency_end = point_adjacency_offsets[i + 1]; 24 | uint32_t num_faces = point_adjacency_end - point_adjacency_begin; 25 | uint32_t farthest_idx = UINT32_MAX; 26 | float sum_distance = 0.0f; 27 | float max_distance = 
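/* Note: the loop below tracks the farthest adjacent site and accumulates
   half of each neighbour distance, so cell_radius ends up as the mean
   half-distance to the Delaunay neighbours, a cheap proxy for the size
   of the point's Voronoi cell. */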
0.0f; 28 | 29 | for (uint32_t i = 0; i < num_faces; ++i) { 30 | uint32_t opposite_point_idx = 31 | point_adjacency[point_adjacency_begin + i]; 32 | Vec3f opposite_point = points[opposite_point_idx]; 33 | 34 | float distance = (opposite_point - primal_point).norm(); 35 | sum_distance += 0.5 * distance; 36 | if (distance > max_distance) { 37 | max_distance = distance; 38 | farthest_idx = opposite_point_idx; 39 | } 40 | } 41 | 42 | indices[i] = farthest_idx; 43 | cell_radius[i] = sum_distance / num_faces; 44 | } 45 | 46 | template 47 | void farthest_neighbor(const Vec3 *points, 48 | const uint32_t *point_adjacency, 49 | const uint32_t *point_adjacency_offsets, 50 | uint32_t num_points, 51 | uint32_t *indices, 52 | float *cell_radius, 53 | const void *stream) { 54 | launch_kernel_1d<1024>(farthest_neighbor_kernel, 55 | num_points, 56 | stream, 57 | points, 58 | point_adjacency, 59 | point_adjacency_offsets, 60 | num_points, 61 | indices, 62 | cell_radius); 63 | } 64 | 65 | void farthest_neighbor(ScalarType coord_scalar_type, 66 | const void *points, 67 | const void *point_adjacency, 68 | const void *point_adjacency_offsets, 69 | uint32_t num_points, 70 | void *indices, 71 | void *cell_radius, 72 | const void *stream) { 73 | 74 | if (coord_scalar_type == ScalarType::Float32) { 75 | farthest_neighbor( 76 | static_cast *>(points), 77 | static_cast(point_adjacency), 78 | static_cast(point_adjacency_offsets), 79 | num_points, 80 | static_cast(indices), 81 | static_cast(cell_radius), 82 | stream); 83 | } else { 84 | throw std::runtime_error("unsupported scalar type"); 85 | } 86 | } 87 | 88 | } // namespace radfoam -------------------------------------------------------------------------------- /src/delaunay/triangulation_ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/geometry.h" 4 | 5 | namespace radfoam { 6 | 7 | /// @brief Find the farthest neighbor of each point 8 | void farthest_neighbor(ScalarType coord_scalar_type, 9 | const void *points, 10 | const void *point_adjacency, 11 | const void *point_adjacency_offsets, 12 | uint32_t num_points, 13 | void *indices, 14 | void *cell_radius, 15 | const void *stream = nullptr); 16 | 17 | } // namespace radfoam -------------------------------------------------------------------------------- /src/tracing/camera.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/geometry.h" 4 | 5 | namespace radfoam { 6 | 7 | struct Ray { 8 | Vec3f origin; 9 | Vec3f direction; 10 | }; 11 | 12 | enum CameraModel { 13 | Pinhole, 14 | Fisheye, 15 | }; 16 | 17 | struct Camera { 18 | CVec3f position; 19 | CVec3f forward; 20 | CVec3f right; 21 | CVec3f up; 22 | float fov; 23 | uint32_t width; 24 | uint32_t height; 25 | CameraModel model; 26 | 27 | RADFOAM_HD void rotate(Vec3f axis, float angle) { 28 | auto rotation = Eigen::AngleAxisf(angle, axis); 29 | forward = rotation * *forward; 30 | right = rotation * *right; 31 | up = rotation * *up; 32 | } 33 | }; 34 | 35 | /// @brief Create a camera pointing from position to target 36 | inline RADFOAM_HD Camera look_at(const Vec3f &position, 37 | const Vec3f &target, 38 | const Vec3f &up, 39 | float fov, 40 | uint32_t width, 41 | uint32_t height, 42 | CameraModel model = Pinhole) { 43 | Camera camera; 44 | camera.position = position; 45 | camera.forward = (target - position).normalized(); 46 | camera.right = camera.forward->cross(up).normalized(); 47 | camera.up = 
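/* Note: the three assignments build an orthonormal camera frame: right
   is normalize(forward x up_in), and the expression below re-derives up
   as normalize(right x forward), correcting a non-perpendicular input
   up vector. */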
camera.right->cross(*camera.forward).normalized(); 48 | camera.fov = fov; 49 | camera.width = width; 50 | camera.height = height; 51 | camera.model = model; 52 | return camera; 53 | } 54 | 55 | /// @brief Create a ray from the camera through pixel (i, j) 56 | inline RADFOAM_HD Ray cast_ray(const Camera &camera, int i, int j) { 57 | Ray ray; 58 | ray.origin = *camera.position; 59 | float aspect_ratio = static_cast(camera.width) / camera.height; 60 | float x = static_cast(i) / camera.width; 61 | float y = static_cast(j) / camera.height; 62 | 63 | float u = (2.0f * x - 1.0f) * aspect_ratio; 64 | float v = (1.0f - 2.0f * y); 65 | float mask = 1.0f; 66 | 67 | if (camera.model == Pinhole) { 68 | float w = 1.0f / tanf(camera.fov * 0.5f); 69 | ray.direction = 70 | w * *camera.forward + u * *camera.right + v * *camera.up; 71 | } else if (camera.model == Fisheye) { 72 | float theta = atan2f(v, u); 73 | float phi = camera.fov * sqrtf(u * u + v * v); 74 | if (phi >= M_PIf) { 75 | phi = M_PIf - 1e-6f; 76 | mask = 0.0f; 77 | } 78 | ray.direction = sinf(phi) * cosf(theta) * *camera.right + 79 | sinf(phi) * sinf(theta) * *camera.up + 80 | cosf(phi) * *camera.forward; 81 | } 82 | ray.direction = ray.direction.normalized() * mask; 83 | 84 | return ray; 85 | } 86 | 87 | } // namespace radfoam 88 | -------------------------------------------------------------------------------- /src/tracing/pipeline.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../utils/typing.h" 6 | #include "camera.h" 7 | 8 | namespace radfoam { 9 | 10 | struct TraceSettings { 11 | float weight_threshold; 12 | uint32_t max_intersections; 13 | }; 14 | 15 | inline TraceSettings default_trace_settings() { 16 | TraceSettings settings; 17 | settings.weight_threshold = 0.001f; 18 | settings.max_intersections = 1024; 19 | return settings; 20 | } 21 | 22 | enum VisualizationMode { 23 | RGB = 0, 24 | Depth = 1, 25 | Alpha = 2, 26 | Intersections = 3, 27 | }; 28 | 29 | struct VisualizationSettings { 30 | VisualizationMode mode; 31 | ColorMap color_map; 32 | CVec3f bg_color; 33 | bool checker_bg; 34 | float max_depth; 35 | float depth_quantile; 36 | }; 37 | 38 | inline VisualizationSettings default_visualization_settings() { 39 | VisualizationSettings settings; 40 | settings.mode = RGB; 41 | settings.color_map = Turbo; 42 | settings.bg_color = Vec3f(1.0f, 1.0f, 1.0f); 43 | settings.checker_bg = false; 44 | settings.max_depth = 10.0f; 45 | settings.depth_quantile = 0.5f; 46 | return settings; 47 | } 48 | 49 | /// @brief Prefetch offset for each edge in the adjacency matrix 50 | void prefetch_adjacent_diff(const Vec3f *points, 51 | uint32_t num_points, 52 | uint32_t point_adjacency_size, 53 | const uint32_t *point_adjacency, 54 | const uint32_t *point_adjacency_offsets, 55 | Vec4h *adjacent_diff, 56 | const void *stream); 57 | 58 | class Pipeline { 59 | public: 60 | virtual ~Pipeline() = default; 61 | 62 | virtual void trace_forward(const TraceSettings &settings, 63 | uint32_t num_points, 64 | const Vec3f *points, 65 | const void *attributes, 66 | uint32_t point_adjacency_size, 67 | const uint32_t *point_adjacency, 68 | const uint32_t *point_adjacency_offsets, 69 | uint32_t num_rays, 70 | const Ray *rays, 71 | const uint32_t *start_point_index, 72 | uint32_t num_depth_quantiles, 73 | const float *depth_quantiles, 74 | void *ray_rgba, 75 | float *quantile_dpeths, 76 | uint32_t *quantile_point_indices, 77 | uint32_t *num_intersections, 78 | void *point_contribution) 
-------------------------------------------------------------------------------- /src/tracing/pipeline.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <memory> 4 | 5 | #include "../utils/typing.h" 6 | #include "camera.h" 7 | 8 | namespace radfoam { 9 | 10 | struct TraceSettings { 11 | float weight_threshold; 12 | uint32_t max_intersections; 13 | }; 14 | 15 | inline TraceSettings default_trace_settings() { 16 | TraceSettings settings; 17 | settings.weight_threshold = 0.001f; 18 | settings.max_intersections = 1024; 19 | return settings; 20 | } 21 | 22 | enum VisualizationMode { 23 | RGB = 0, 24 | Depth = 1, 25 | Alpha = 2, 26 | Intersections = 3, 27 | }; 28 | 29 | struct VisualizationSettings { 30 | VisualizationMode mode; 31 | ColorMap color_map; 32 | CVec3f bg_color; 33 | bool checker_bg; 34 | float max_depth; 35 | float depth_quantile; 36 | }; 37 | 38 | inline VisualizationSettings default_visualization_settings() { 39 | VisualizationSettings settings; 40 | settings.mode = RGB; 41 | settings.color_map = Turbo; 42 | settings.bg_color = Vec3f(1.0f, 1.0f, 1.0f); 43 | settings.checker_bg = false; 44 | settings.max_depth = 10.0f; 45 | settings.depth_quantile = 0.5f; 46 | return settings; 47 | } 48 | 49 | /// @brief Prefetch offset for each edge in the adjacency matrix 50 | void prefetch_adjacent_diff(const Vec3f *points, 51 | uint32_t num_points, 52 | uint32_t point_adjacency_size, 53 | const uint32_t *point_adjacency, 54 | const uint32_t *point_adjacency_offsets, 55 | Vec4h *adjacent_diff, 56 | const void *stream); 57 | 58 | class Pipeline { 59 | public: 60 | virtual ~Pipeline() = default; 61 | 62 | virtual void trace_forward(const TraceSettings &settings, 63 | uint32_t num_points, 64 | const Vec3f *points, 65 | const void *attributes, 66 | uint32_t point_adjacency_size, 67 | const uint32_t *point_adjacency, 68 | const uint32_t *point_adjacency_offsets, 69 | uint32_t num_rays, 70 | const Ray *rays, 71 | const uint32_t *start_point_index, 72 | uint32_t num_depth_quantiles, 73 | const float *depth_quantiles, 74 | void *ray_rgba, 75 | float *quantile_depths, 76 | uint32_t *quantile_point_indices, 77 | uint32_t *num_intersections, 78 | void *point_contribution) = 0; 79 | 80 | virtual void trace_backward(const TraceSettings &settings, 81 | uint32_t num_points, 82 | const Vec3f *points, 83 | const void *attributes, 84 | uint32_t point_adjacency_size, 85 | const uint32_t *point_adjacency, 86 | const uint32_t *point_adjacency_offsets, 87 | uint32_t num_rays, 88 | const Ray *rays, 89 | const uint32_t *start_point_index, 90 | uint32_t num_depth_quantiles, 91 | const float *depth_quantiles, 92 | const uint32_t *quantile_point_indices, 93 | const void *ray_rgba, 94 | const void *ray_rgba_grad, 95 | const float *depth_grad, 96 | const void *ray_error, 97 | Ray *ray_grad, 98 | Vec3f *points_grad, 99 | void *attribute_grad, 100 | void *point_error) = 0; 101 | 102 | virtual void trace_visualization(const TraceSettings &settings, 103 | const VisualizationSettings &vis_settings, 104 | const Camera &camera, 105 | CMapTable cmap_table, 106 | uint32_t num_points, 107 | uint32_t num_tets, 108 | const void *points, 109 | const void *attributes, 110 | const void *point_adjacency, 111 | const void *point_adjacency_offsets, 112 | const void *adjacent_points, 113 | uint32_t start_index, 114 | uint64_t output_surface, 115 | const void *stream = nullptr) = 0; 116 | 117 | virtual void trace_benchmark(const TraceSettings &settings, 118 | uint32_t num_points, 119 | const Vec3f *points, 120 | const void *attributes, 121 | const uint32_t *point_adjacency, 122 | const uint32_t *point_adjacency_offsets, 123 | const Vec4h *adjacent_diff, 124 | Camera camera, 125 | const uint32_t *start_point_index, 126 | uint32_t *ray_rgba) = 0; 127 | 128 | virtual uint32_t attribute_dim() const = 0; 129 | 130 | virtual ScalarType attribute_type() const = 0; 131 | }; 132 | 133 | std::shared_ptr<Pipeline> create_pipeline(int sh_degree, ScalarType attr_type); 134 | 135 | } // namespace radfoam
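Pipeline is the abstract boundary between the Python training code and the CUDA tracer: create_pipeline selects a concrete implementation from the SH degree and attribute scalar type, and TraceSettings bounds how far each ray is walked. A hedged sketch of how a caller might set one up (all pointer arguments to trace_forward are device pointers; their allocation is omitted here):

// Sketch, assuming degree-3 SH with half-precision attributes.
auto pipeline = create_pipeline(/*sh_degree=*/3, Float16);
TraceSettings settings = default_trace_settings();
settings.weight_threshold = 0.01f; // terminate rays earlier than the default
settings.max_intersections = 512;  // and cap the cell count per ray
// pipeline->trace_forward(settings, num_points, points, attributes, ...);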
-------------------------------------------------------------------------------- /src/tracing/sh_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/geometry.h" 4 | #include "camera.h" 5 | 6 | namespace radfoam { 7 | 8 | __constant__ float C0 = 0.28209479177387814f; 9 | __constant__ float C1 = 0.4886025119029199f; 10 | __constant__ float C2[5] = {1.0925484305920792f, 11 | -1.0925484305920792f, 12 | 0.31539156525252005f, 13 | -1.0925484305920792f, 14 | 0.5462742152960396f}; 15 | __constant__ float C3[7] = {-0.5900435899266435f, 16 | 2.890611442640554f, 17 | -0.4570457994644658f, 18 | 0.3731763325901154f, 19 | -0.4570457994644658f, 20 | 1.445305721320277f, 21 | -0.5900435899266435f}; 22 | __constant__ float C4[9] = {2.5033429417967046f, 23 | -1.7701307697799304f, 24 | 0.9461746957575601f, 25 | -0.6690465435572892f, 26 | 0.10578554691520431f, 27 | -0.6690465435572892f, 28 | 0.47308734787878004f, 29 | -1.7701307697799304f, 30 | 0.6258357354491761f}; 31 | 32 | constexpr int sh_dimension(int degree) { return (degree + 1) * (degree + 1); } 33 | 34 | template <int degree> 35 | __device__ Vecf<sh_dimension(degree)> sh_coefficients(const Vec3f &dir) { 36 | float x = dir[0]; 37 | float y = dir[1]; 38 | float z = dir[2]; 39 | 40 | Vecf<sh_dimension(degree)> sh = Vecf<sh_dimension(degree)>::Zero(); 41 | 42 | sh[0] = C0; 43 | 44 | if (degree > 0) { 45 | sh[1] = -C1 * y; 46 | sh[2] = C1 * z; 47 | sh[3] = -C1 * x; 48 | } 49 | float xx = x * x, yy = y * y, zz = z * z; 50 | float xy = x * y, yz = y * z, xz = x * z; 51 | if (degree > 1) { 52 | 53 | sh[4] = C2[0] * xy; 54 | sh[5] = C2[1] * yz; 55 | sh[6] = C2[2] * (2.0f * zz - xx - yy); 56 | sh[7] = C2[3] * xz; 57 | sh[8] = C2[4] * (xx - yy); 58 | } 59 | if (degree > 2) { 60 | sh[9] = C3[0] * y * (3.0f * xx - yy); 61 | sh[10] = C3[1] * xy * z; 62 | sh[11] = C3[2] * y * (4.0f * zz - xx - yy); 63 | sh[12] = C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy); 64 | sh[13] = C3[4] * x * (4.0f * zz - xx - yy); 65 | sh[14] = C3[5] * z * (xx - yy); 66 | sh[15] = C3[6] * x * (xx - 3.0f * yy); 67 | } 68 | 69 | return sh; 70 | } 71 | 72 | template <int degree, typename scalar> 73 | __device__ Vec3f load_sh_as_rgb(const Vecf<sh_dimension(degree)> &coeffs, 74 | const scalar *sh_rgb_vals) { 75 | Vec3f rgb = Vec3f(0.5f, 0.5f, 0.5f); 76 | 77 | #pragma unroll 78 | for (uint32_t i = 0; i < 3 * sh_dimension(degree); ++i) { 79 | rgb[i % 3] += coeffs[i / 3] * (float)sh_rgb_vals[i]; 80 | } 81 | 82 | return rgb.cwiseMax(0.0f); 83 | } 84 | 85 | template <int degree, typename scalar> 86 | __device__ void write_rgb_grad_to_sh(const Vecf<sh_dimension(degree)> &coeffs, 87 | Vec3f grad_rgb, 88 | scalar *sh_rgb_grad) { 89 | for (uint32_t i = 0; i < 3 * sh_dimension(degree); ++i) { 90 | atomicAdd(sh_rgb_grad + i, (scalar)(coeffs[i / 3] * grad_rgb[i % 3])); 91 | } 92 | } 93 | 94 | template <typename scalar, int sh_dim> 95 | __device__ Vec3<scalar> 96 | forward_sh(uint32_t deg, Vec<scalar, sh_dim> sh_vec, Vec3f dirs) { 97 | float x = dirs[0]; 98 | float y = dirs[1]; 99 | float z = dirs[2]; 100 | 101 | constexpr int sh_vars = int(sh_dim / 3); 102 | Eigen::Map<Eigen::Matrix<scalar, 3, sh_vars>> sh_mat(sh_vec.data()); 103 | 104 | Vec3<scalar> result = C0 * sh_mat.col(0); 105 | if (deg > 0) { 106 | result = result - C1 * y * sh_mat.col(1) + C1 * z * sh_mat.col(2) - 107 | C1 * x * sh_mat.col(3); 108 | } 109 | 110 | float xx = x * x, yy = y * y, zz = z * z; 111 | float xy = x * y, yz = y * z, xz = x * z; 112 | if (deg > 1) { 113 | result = result + C2[0] * xy * sh_mat.col(4) + 114 | C2[1] * yz * sh_mat.col(5) + 115 | C2[2] * (2.0f * zz - xx - yy) * sh_mat.col(6) + 116 | C2[3] * xz * sh_mat.col(7) + C2[4] * (xx - yy) * sh_mat.col(8); 117 | } 118 | if (deg > 2) { 119 | result = 120 | result + C3[0] * y * (3.0f * xx - yy) * sh_mat.col(9) + 121 | C3[1] * xy * z * sh_mat.col(10) + 122 | C3[2] * y * (4.0f * zz - xx - yy) * sh_mat.col(11) + 123 | C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh_mat.col(12) + 124 | C3[4] * x * (4.0f * zz - xx - yy) * sh_mat.col(13) + 125 | C3[5] * z * (xx - yy) * sh_mat.col(14) + 126 | C3[6] * x * (xx - 3.0f * yy) * sh_mat.col(15); 127 | } 128 | 129 | result = result.array() + 0.5f; 130 | return result; 131 | } 132 | 133 | template <typename scalar, int sh_dim> 134 | __device__ Vec<scalar, sh_dim> 135 | backward_sh(uint32_t deg, Vec3<scalar> pd_color, Vec3f dirs) { 136 | float x = dirs[0]; 137 | float y = dirs[1]; 138 | float z = dirs[2]; 139 | 140 | constexpr int sh_vars = int(sh_dim / 3); 141 | Eigen::Matrix<scalar, 3, sh_vars> pd_sh; 142 | 143 | pd_sh.col(0) = C0 * pd_color; 144 | if (deg > 0) { 145 | pd_sh.col(1) = -C1 * y * pd_color; 146 | pd_sh.col(2) = C1 * z * pd_color; 147 | pd_sh.col(3) = -C1 * x * pd_color; 148 | 149 | if (deg > 1) { 150 | float xx = x * x, yy = y * y, zz = z * z; 151 | float xy = x * y, yz = y * z, xz = x * z; 152 | 153 | pd_sh.col(4) = C2[0] * xy * pd_color; 154 | pd_sh.col(5) = C2[1] * yz * pd_color; 155 | pd_sh.col(6) = C2[2] * (2.0f * zz - xx - yy) * pd_color; 156 | pd_sh.col(7) = C2[3] * xz * pd_color; 157 | pd_sh.col(8) = C2[4] * (xx - yy) * pd_color; 158 | 159 | if (deg > 2) { 160 | pd_sh.col(9) = C3[0] * y * (3.0f * xx - yy) * pd_color; 161 | pd_sh.col(10) = C3[1] * xy * z * pd_color; 162 | pd_sh.col(11) = C3[2] * y * (4.0f * zz - xx - yy) * pd_color; 163 | pd_sh.col(12) = 164 | C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * pd_color; 165 | pd_sh.col(13) = C3[4] * x * (4.0f * zz - xx - yy) * pd_color; 166 | pd_sh.col(14) = C3[5] * z * (xx - yy) * pd_color; 167 | pd_sh.col(15) = C3[6] * x * (xx - 3.0f * yy) * pd_color; 168 | } 169 | } 170 | } 171 | 172 | Eigen::Map<Vec<scalar, sh_dim>> pd_sh_vector(pd_sh.data()); 173 | return pd_sh_vector; 174 | } 175 | 176 | } // namespace radfoam 177 |
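sh_dimension(degree) counts the (degree + 1)^2 spherical-harmonic basis functions, so a degree-3 pipeline stores 16 coefficients per color channel, 48 scalars per cell. One worked data point as a sanity check, assuming the reconstructed signatures above: at degree 0 only the constant band survives, and load_sh_as_rgb reduces to 0.5 + C0 * sh_rgb_vals[c] per channel.

// Degree-to-size bookkeeping used throughout the tracer.
static_assert(sh_dimension(0) == 1, "constant band only");
static_assert(sh_dimension(3) == 16, "16 basis functions per channel");
// A DC coefficient of (0.5f / C0) therefore decodes to RGB = (1, 1, 1).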
-------------------------------------------------------------------------------- /src/tracing/tracing_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../utils/geometry.h" 4 | #include "camera.h" 5 | 6 | namespace radfoam { 7 | 8 | template <int chunk_size, typename CellFunctor> 9 | __forceinline__ __device__ uint32_t 10 | trace(const Ray &ray, 11 | const Vec3f *__restrict__ points, 12 | const uint32_t *__restrict__ point_adjacency, 13 | const uint32_t *__restrict__ point_adjacency_offsets, 14 | const Vec4h *__restrict__ adjacent_points, 15 | uint32_t start_point, 16 | uint32_t max_steps, 17 | CellFunctor cell_functor) { 18 | float t_0 = 0.0f; 19 | uint32_t n = 0; 20 | 21 | uint32_t current_point_idx = start_point; 22 | Vec3f primal_point = points[current_point_idx]; 23 | 24 | for (;;) { 25 | n++; 26 | if (n > max_steps) { 27 | break; 28 | } 29 | 30 | // Outer loop iterates through Voronoi cells 31 | uint32_t point_adjacency_begin = 32 | point_adjacency_offsets[current_point_idx]; 33 | uint32_t point_adjacency_end = 34 | point_adjacency_offsets[current_point_idx + 1]; 35 | 36 | uint32_t num_faces = point_adjacency_end - point_adjacency_begin; 37 | float t_1 = std::numeric_limits<float>::infinity(); 38 | 39 | uint32_t next_face = UINT32_MAX; 40 | Vec3f next_point = Vec3f::Zero(); 41 | 42 | half2 chunk[chunk_size * 2]; 43 | for (uint32_t i = 0; i < num_faces; i += chunk_size) { 44 | #pragma unroll 45 | for (uint32_t j = 0; j < chunk_size; ++j) { 46 | chunk[2 * j] = reinterpret_cast<const half2 *>( 47 | adjacent_points + point_adjacency_begin + i + j)[0]; 48 | chunk[2 * j + 1] = reinterpret_cast<const half2 *>( 49 | adjacent_points + point_adjacency_begin + i + j)[1]; 50 | } 51 | 52 | #pragma unroll 53 | for (uint32_t j = 0; j < chunk_size; ++j) { 54 | Vec3f offset(__half2float(chunk[2 * j].x), 55 | __half2float(chunk[2 * j].y), 56 | __half2float(chunk[2 * j + 1].x)); 57 | Vec3f face_origin = primal_point + offset / 2.0f; 58 | Vec3f face_normal = offset; 59 | float dp = face_normal.dot(ray.direction); 60 | float t = (face_origin - ray.origin).dot(face_normal) / dp; 61 | 62 | if (dp > 0.0f && t < t_1 && (i + j) < num_faces) { 63 | t_1 = t; 64 | next_face = i + j; 65 | } 66 | } 67 | } 68 | 69 | if (next_face == UINT32_MAX) { 70 | break; 71 | } 72 | 73 | uint32_t next_point_idx = 74 | point_adjacency[point_adjacency_begin + next_face]; 75 | next_point = points[next_point_idx]; 76 | 77 | if (t_1 > t_0) { 78 | if (!cell_functor( 79 | current_point_idx, t_0, t_1, primal_point, next_point)) { 80 | break; 81 | } 82 | } 83 | t_0 = fmaxf(t_0, t_1); 84 | current_point_idx = next_point_idx; 85 | primal_point = next_point; 86 | } 87 | 88 | return n; 89 | } 90 | 91 | __forceinline__ __device__ Vec3f cell_intersection_grad( 92 | const Vec3f &primal_point, const Vec3f &opposite_point, const Ray &ray) { 93 | Vec3f face_origin = (primal_point + opposite_point) / 2.0f; 94 | Vec3f face_normal = (opposite_point - primal_point); 95 | 96 | float num = (face_origin - ray.origin).dot(face_normal); 97 | float dp = face_normal.dot(ray.direction); 98 | 99 | Vec3f grad = num * ray.direction + dp * (ray.origin - primal_point); 100 | grad /= dp * dp; 101 | 102 | return grad; 103 | } 104 | 105 | inline RADFOAM_HD uint32_t make_rgba8(float r, float g, float b, float a) { 106 | r = std::max(0.0f, std::min(1.0f, r)); 107 | g = std::max(0.0f, std::min(1.0f, g)); 108 | b = std::max(0.0f, std::min(1.0f, b)); 109 | a = std::max(0.0f, std::min(1.0f, a)); 110 | int ri = static_cast<int>(r * 255.0f); 111 | int gi = static_cast<int>(g * 255.0f); 112 | int bi = static_cast<int>(b * 255.0f); 113 | int ai = static_cast<int>(a * 255.0f); 114 | return (ai << 24) | (bi << 16) | (gi << 8) | ri; 115 | } 116 | 117 | inline __device__ Vec3f colormap(float v, 118 | ColorMap map, 119 | const CMapTable &cmap_table) { 120 | int map_len = cmap_table.sizes[map]; 121 | const Vec3f *map_vals = 122 | reinterpret_cast<const Vec3f *>(cmap_table.data[map]); 123 | 124 | int i0 = static_cast<int>(v * (map_len - 1)); 125 | int i1 = i0 + 1; 126 | float t = v * (map_len - 1) - i0; 127 | i0 = max(0, min(i0, map_len - 1)); 128 | i1 = max(0, min(i1, map_len - 1)); 129 | return map_vals[i0] * (1.0f - t) + map_vals[i1] * t; 130 | } 131 | 132 | } // namespace radfoam
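trace is the core traversal primitive: starting from the Voronoi cell of start_point, it finds the nearest exit face along the ray (smallest positive t among faces whose normals point away from the cell), reports the interval [t_0, t_1] to the functor, and hops to the adjacent cell; returning false from the functor terminates the ray. A hedged sketch of the smallest useful functor (hypothetical; the renderer's real functors composite attributes over each interval instead):

// Hypothetical functor: measures how much of the ray lies inside the foam.
struct AccumulateLength {
    float *total; // device-visible accumulator, provided by the caller
    __device__ bool operator()(uint32_t cell_idx,
                               float t_0,
                               float t_1,
                               Vec3f primal_point,
                               Vec3f next_point) {
        *total += t_1 - t_0; // length of the ray segment inside this cell
        return true;         // keep walking until the boundary or max_steps
    }
};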
-------------------------------------------------------------------------------- /src/utils/batch_fetcher.cpp: -------------------------------------------------------------------------------- 1 | #include <atomic> 2 | #include <cstring> 3 | #include <exception> 4 | #include <memory> 5 | #include <thread> 6 | #include <vector> 7 | 8 | #include <atomic_queue/atomic_queue.h> 9 | 10 | #define THRUST_HOST_SYSTEM THRUST_HOST_SYSTEM_TBB 11 | #include <thrust/execution_policy.h> 12 | #include <thrust/for_each.h> 13 | #include <thrust/iterator/counting_iterator.h> 14 | 15 | #include "batch_fetcher.h" 16 | #include "cuda_helpers.h" 17 | #include "random.h" 18 | 19 | namespace radfoam { 20 | 21 | constexpr int buffer_size = 4; 22 | 23 | struct Batch { 24 | CUdeviceptr data; 25 | CUevent data_ready_event; 26 | }; 27 | 28 | class BatchFetcherImpl : public BatchFetcher { 29 | public: 30 | BatchFetcherImpl(const uint8_t *data, 31 | size_t num_bytes, 32 | size_t stride, 33 | size_t batch_size, 34 | bool shuffle) 35 | : worker_exception(nullptr), done(false) { 36 | 37 | CUcontext context; 38 | cuda_check(cuCtxGetCurrent(&context)); 39 | 40 | if (context == nullptr) { 41 | throw std::runtime_error("No CUDA context found"); 42 | } 43 | 44 | worker = std::thread([=] { 45 | try { 46 | size_t num_elements = num_bytes / stride; 47 | if (num_elements > (size_t)__UINT32_MAX__) { 48 | throw std::runtime_error("Too many elements"); 49 | } 50 | uint32_t batch_idx = 0; 51 | 52 | cuda_check(cuCtxSetCurrent(context)); 53 | 54 | CUstream stream; 55 | 56 | std::vector<uint8_t> cpu_batch_buffer[buffer_size]; 57 | CUdeviceptr gpu_batch_buffer[buffer_size]; 58 | CUevent events[buffer_size]; 59 | 60 | cuda_check(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); 61 | 62 | auto upload_batch = [&](int i) { 63 | auto copy_element = [&](int j) { 64 | size_t idx; 65 | if (shuffle) { 66 | auto rng = make_rng(batch_idx * batch_size + j); 67 | idx = randint(rng, 0, num_elements); 68 | } else { 69 | idx = (batch_idx * batch_size + j) % num_elements; 70 | } 71 | memcpy(cpu_batch_buffer[i].data() + j * stride, 72 | data + idx * stride, 73 | stride); 74 | }; 75 | thrust::for_each(thrust::host, 76 | thrust::counting_iterator<int>(0), 77 | thrust::counting_iterator<int>(batch_size), 78 | copy_element); 79 | batch_idx += 1; 80 | cuda_check(cuMemcpyHtoDAsync(gpu_batch_buffer[i], 81 | cpu_batch_buffer[i].data(), 82 | batch_size * stride, 83 | stream)); 84 | cuda_check(cuEventRecord(events[i], stream)); 85 | 86 | return Batch{gpu_batch_buffer[i], events[i]}; 87 | }; 88 | 89 | for (int i = 0; i < buffer_size; ++i) { 90 | cpu_batch_buffer[i].resize(batch_size * stride); 91 | cuda_check( 92 | cuMemAlloc(&gpu_batch_buffer[i], batch_size * stride)); 93 | cuda_check( 94 | cuEventCreate(&events[i], CU_EVENT_BLOCKING_SYNC)); 95 | } 96 | 97 | int i = 0; 98 | while (!this->done) { 99 | auto batch = upload_batch(i); 100 | while (!queue.try_push(batch) && !this->done) { 101 | std::this_thread::yield(); 102 | } 103 | i = (i + 1) % buffer_size; 104 | } 105 | 106 | // Free resources 107 | cuda_check(cuStreamSynchronize(stream)); 108 | for (int i = 0; i < buffer_size; i++) { 109 | cuda_check(cuMemFree(gpu_batch_buffer[i])); 110 | cuda_check(cuEventDestroy(events[i])); 111 | } 112 | cuda_check(cuStreamDestroy(stream)); 113 | } catch (...) { 114 | this->worker_exception = std::current_exception(); 115 | this->done = true; 116 | } 117 | }); 118 | } 119 | 120 | ~BatchFetcherImpl() { 121 | done = true; 122 | worker.join(); 123 | } 124 | 125 | void *next() override { 126 | Batch batch = {}; 127 | while (!done && !queue.try_pop(batch)) { 128 | std::this_thread::yield(); 129 | } 130 | if (done) { 131 | worker.join(); 132 | std::rethrow_exception(worker_exception); 133 | } 134 | cuda_check(cuEventSynchronize(batch.data_ready_event)); 135 | return reinterpret_cast<void *>(batch.data); 136 | } 137 | 138 | private: 139 | std::exception_ptr worker_exception; 140 | std::thread worker; 141 | std::atomic_bool done; 142 | atomic_queue::AtomicQueue2<Batch, buffer_size> queue; 143 | }; 144 | 145 | std::unique_ptr<BatchFetcher> create_batch_fetcher(const void *data, 146 | size_t num_bytes, 147 | size_t stride, 148 | size_t batch_size, 149 | bool shuffle) { 150 | return std::make_unique<BatchFetcherImpl>( 151 | static_cast<const uint8_t *>(data), 152 | num_bytes, 153 | stride, 154 | batch_size, 155 | shuffle); 156 | } 157 | 158 | } // namespace radfoam
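Worth spelling out the lifetime contract: next() returns a raw device pointer into one of the buffer_size staging buffers, so it stays valid only until the worker cycles back around to that buffer; callers must launch their work on a batch before requesting roughly buffer_size more. A hedged consumer sketch (the Ray record type and loop are hypothetical):

// Sketch: stream shuffled batches of host-resident rays to the GPU.
auto fetcher = create_batch_fetcher(host_rays, num_bytes,
                                    /*stride=*/sizeof(Ray),
                                    /*batch_size=*/65536,
                                    /*shuffle=*/true);
for (int step = 0; step < num_steps; ++step) {
    const Ray *batch = static_cast<const Ray *>(fetcher->next());
    // enqueue kernels consuming 'batch' before its buffer is recycled
}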
-------------------------------------------------------------------------------- /src/utils/batch_fetcher.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <memory> 4 | 5 | namespace radfoam { 6 | 7 | class BatchFetcher { 8 | public: 9 | virtual ~BatchFetcher() = default; 10 | 11 | virtual void *next() = 0; 12 | }; 13 | 14 | std::unique_ptr<BatchFetcher> create_batch_fetcher(const void *data, 15 | size_t num_bytes, 16 | size_t stride, 17 | size_t batch_size, 18 | bool shuffle); 19 | 20 | } // namespace radfoam
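This header is the entire public surface: data is a host pointer viewed as num_bytes / stride fixed-size records, and batches are assembled record by record, so stride is simply the byte size of one record. For example, treating rows of a dense feature matrix as records (shapes hypothetical):

// 2D float array: num_rows rows of width 8 -> one record = 32 bytes.
size_t stride = 8 * sizeof(float);
size_t num_bytes = num_rows * stride;
auto fetcher = create_batch_fetcher(rows, num_bytes, stride,
                                    /*batch_size=*/4096, /*shuffle=*/false);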
-------------------------------------------------------------------------------- /src/utils/common_kernels.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstddef> 4 | 5 | #include "cuda_helpers.h" 6 | 7 | #include <cub/iterator/arg_index_input_iterator.cuh> 8 | #include <cub/iterator/counting_input_iterator.cuh> 9 | #include <cub/iterator/discard_output_iterator.cuh> 10 | 11 | #include "unenumerate_iterator.cuh" 12 | 13 | namespace radfoam { 14 | 15 | template <typename InputIterator, typename UnaryFunction> 16 | __global__ void for_n_kernel(InputIterator begin, size_t n, UnaryFunction f) { 17 | size_t i = blockIdx.x * blockDim.x + threadIdx.x; 18 | size_t stride = gridDim.x * blockDim.x; 19 | while (i < n) { 20 | f(begin[i]); 21 | i += stride; 22 | } 23 | } 24 | 25 | template <size_t block_size, typename Kernel, typename... Args> 26 | void launch_kernel_1d(Kernel kernel, 27 | size_t n, 28 | const void *stream, 29 | Args... args) { 30 | if (n == 0) { 31 | return; 32 | } 33 | size_t num_blocks = (n + block_size - 1) / block_size; 34 | if (stream) { 35 | cudaStream_t s = *reinterpret_cast<const cudaStream_t *>(stream); 36 | kernel<<<num_blocks, block_size, 0, s>>>(args...); 37 | } else { 38 | kernel<<<num_blocks, block_size>>>(args...); 39 | } 40 | cuda_check(cudaGetLastError()); 41 | } 42 | 43 | template <size_t block_size, typename InputIterator, typename UnaryFunction> 44 | void for_n_b(InputIterator begin, 45 | size_t n, 46 | UnaryFunction f, 47 | bool strided = false, 48 | const void *stream = nullptr) { 49 | size_t num_threads = n; 50 | if (strided) { 51 | int mpc; 52 | cuda_check( 53 | cudaDeviceGetAttribute(&mpc, cudaDevAttrMultiProcessorCount, 0)); 54 | num_threads = block_size * mpc; 55 | } 56 | 57 | launch_kernel_1d<block_size>(for_n_kernel<InputIterator, UnaryFunction>, 58 | num_threads, 59 | stream, 60 | begin, 61 | n, 62 | f); 63 | } 64 | 65 | template <typename InputIterator, typename UnaryFunction> 66 | void for_n(InputIterator begin, 67 | size_t n, 68 | UnaryFunction f, 69 | bool strided = false, 70 | const void *stream = nullptr) { 71 | for_n_b<256>(begin, n, f, strided, stream); 72 | } 73 | 74 | template <typename InputIterator, typename UnaryFunction> 75 | void for_range(InputIterator begin, 76 | InputIterator end, 77 | UnaryFunction f, 78 | bool strided = false, 79 | const void *stream = nullptr) { 80 | size_t n = end - begin; 81 | for_n(begin, n, f, strided, stream); 82 | } 83 | 84 | template <typename InputIterator, 85 | typename OutputIterator, 86 | typename UnaryFunction> 87 | struct TransformFunctor { 88 | InputIterator begin; 89 | OutputIterator result; 90 | UnaryFunction f; 91 | 92 | __device__ void operator()(decltype(*begin) x) { 93 | size_t i = blockIdx.x * blockDim.x + threadIdx.x; 94 | result[i] = f(x); 95 | } 96 | }; 97 | 98 | template <typename InputIterator, 99 | typename OutputIterator, 100 | typename UnaryFunction> 101 | void transform_range(InputIterator begin, 102 | InputIterator end, 103 | OutputIterator result, 104 | UnaryFunction f, 105 | bool strided = false, 106 | const void *stream = nullptr) { 107 | size_t n = end - begin; 108 | TransformFunctor<InputIterator, OutputIterator, UnaryFunction> func = { 109 | begin, result, f}; 110 | for_n(begin, n, func, strided, stream); 111 | } 112 | 113 | template <typename InputIterator, typename OutputIterator> 114 | void copy_range(InputIterator begin, 115 | InputIterator end, 116 | OutputIterator result, 117 | bool strided = false, 118 | const void *stream = nullptr) { 119 | transform_range( 120 | begin, 121 | end, 122 | result, 123 | [] __device__(auto x) { return x; }, 124 | strided, 125 | stream); 126 | } 127 | 128 | inline cub::CountingInputIterator<uint32_t> u32zero() { 129 | return cub::CountingInputIterator<uint32_t>(0); 130 | } 131 | 132 | inline cub::CountingInputIterator<uint64_t> u64zero() { 133 | return cub::CountingInputIterator<uint64_t>(0); 134 | } 135 | 136 | template <typename T> 137 | inline cub::ArgIndexInputIterator<T *> enumerate(T *begin) { 138 | return cub::ArgIndexInputIterator<T *>(begin); 139 | } 140 | 141 | template <typename T> 142 | inline UnenumerateIterator<T> unenumerate(T *begin) { 143 | return UnenumerateIterator<T>(begin); 144 | } 145 | 146 | inline cub::DiscardOutputIterator<> discard() { 147 | return cub::DiscardOutputIterator<>(); 148 | } 149 | 150 | } // namespace radfoam
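for_n is the workhorse of this header: with strided = false it launches one thread per element, and with strided = true it launches block_size threads per SM and lets the grid-stride loop in for_n_kernel cover the range. Combined with u32zero() it doubles as a parallel-for over indices. A hedged sketch (device lambdas require nvcc's --extended-lambda; the data pointer is hypothetical):

// Scale a device array in place: f receives begin[i], which is just i here.
float *data = /* device pointer with n elements */;
for_n(u32zero(), n, [=] __device__(uint32_t i) { data[i] *= 2.0f; });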
-------------------------------------------------------------------------------- /src/utils/cuda_array.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstddef> 4 | #include <cuda.h> 5 | #include <memory> 6 | #include <utility> 7 | 8 | #include "cuda_helpers.h" 9 | #include "typing.h" 10 | 11 | #define CUB_CALL(call) \ 12 | { \ 13 | size_t temp_bytes = 0; \ 14 | void *temp_data; \ 15 | CUDAArray<uint8_t> cub_temp_buffer; \ 16 | for (temp_data = nullptr;;) { \ 17 | cuda_check(call); \ 18 | if (temp_data) \ 19 | break; \ 20 | else { \ 21 | cub_temp_buffer.resize(temp_bytes); \ 22 | temp_data = cub_temp_buffer.begin(); \ 23 | } \ 24 | } \ 25 | } 26 | 27 | namespace radfoam { 28 | 29 | struct OpaqueBuffer { 30 | virtual ~OpaqueBuffer() = default; 31 | 32 | virtual void *data() = 0; 33 | }; 34 | 35 | std::unique_ptr<OpaqueBuffer> allocate_buffer(size_t bytes); 36 | 37 | /// @brief RAII wrapper for CUDA array 38 | template <typename T> 39 | class CUDAArray { 40 | template <typename U> 41 | friend class CUDAArray; 42 | 43 | private: 44 | CUdeviceptr begin_ptr; 45 | CUdeviceptr end_ptr; 46 | std::unique_ptr<OpaqueBuffer> buffer; 47 | 48 | public: 49 | /// @brief Construct an empty CUDA array 50 | CUDAArray() : begin_ptr(0), end_ptr(0), buffer(nullptr) {} 51 | 52 | /// @brief Construct a CUDA array of a given size 53 | CUDAArray(size_t size) { 54 | if (size == 0) { 55 | begin_ptr = 0; 56 | end_ptr = 0; 57 | buffer = nullptr; 58 | return; 59 | } 60 | buffer = allocate_buffer(size * sizeof(T)); 61 | begin_ptr = reinterpret_cast<CUdeviceptr>(buffer->data()); 62 | end_ptr = begin_ptr + (size * sizeof(T)); 63 | } 64 | 65 | CUDAArray(const CUDAArray &other) = delete; 66 | 67 | CUDAArray(CUDAArray &&other) { 68 | begin_ptr = other.begin_ptr; 69 | end_ptr = other.end_ptr; 70 | other.begin_ptr = 0; 71 | other.end_ptr = 0; 72 | buffer = std::move(other.buffer); 73 | } 74 | 75 | ~CUDAArray() = default; 76 | 77 | CUDAArray &operator=(const CUDAArray &other) = delete; 78 | 79 | CUDAArray &operator=(CUDAArray &&other) { 80 | begin_ptr = other.begin_ptr; 81 | end_ptr = other.end_ptr; 82 | other.begin_ptr = 0; 83 | other.end_ptr = 0; 84 | buffer = std::move(other.buffer); 85 | return *this; 86 | } 87 | 88 | template <typename U> 89 | CUDAArray(CUDAArray<U> &&other) { 90 | begin_ptr = other.begin_ptr; 91 | end_ptr = other.end_ptr; 92 | other.begin_ptr = 0; 93 | other.end_ptr = 0; 94 | buffer = std::move(other.buffer); 95 | } 96 | 97 | /// @brief Get a pointer to the beginning of the array 98 | T *begin() { return reinterpret_cast<T *>(begin_ptr); } 99 | 100 | /// @brief Get a pointer to the end of the array 101 | T *end() { 102 | size_t size_bytes = end_ptr - begin_ptr; 103 | size_t elements = size_bytes / sizeof(T); 104 | return reinterpret_cast<T *>(begin_ptr) + elements; 105 | } 106 | 107 | /// @brief Get a pointer to the beginning of the array 108 | const T *begin() const { return reinterpret_cast<const T *>(begin_ptr); } 109 | 110 | /// @brief Get a pointer to the end of the array 111 | const T *end() const { 112 | size_t size_bytes = end_ptr - begin_ptr; 113 | size_t elements = size_bytes / sizeof(T); 114 | return reinterpret_cast<const T *>(begin_ptr) + elements; 115 | } 116 | 117 | /// @brief Get a pointer to the beginning of the array as a different type 118 | template <typename U> 119 | U *begin_as() { 120 | return reinterpret_cast<U *>(begin_ptr); 121 | } 122 | 123 | /// @brief Get a pointer to the end of the array as a different type 124 | template <typename U> 125 | U *end_as() { 126 | size_t size_bytes = end_ptr - begin_ptr; 127 | size_t elements = size_bytes / sizeof(T); 128 | return reinterpret_cast<U *>(reinterpret_cast<T *>(begin_ptr) + 129 | elements); 130 | } 131 | 132 | /// @brief Get a pointer to the beginning of the array as a different type 133 | template <typename U> 134 | const U *begin_as() const { 135 | return reinterpret_cast<const U *>(begin_ptr); 136 | } 137 | 138 | /// @brief Get a pointer to the end of the array as a different type 139 | template <typename U> 140 | const U *end_as() const { 141 | size_t size_bytes = end_ptr - begin_ptr; 142 | size_t elements = size_bytes / sizeof(T); 143 | return reinterpret_cast<const U *>( 144 | reinterpret_cast<const T *>(begin_ptr) + elements); 145 | } 146 | 147 | /// @brief Get the number of elements in the array 148 | size_t size() const { return (end_ptr - begin_ptr) / sizeof(T); } 149 | 150 | /// @brief Set the number of elements in the array 151 | /// @param size The new size of the array 152 | /// @param preserve_content If true, the elements currently in the array 153 | /// will be copied to the beginning of the new array, up to the minimum of 154 | /// the old and new sizes 155 | void resize(size_t size, bool preserve_content = false) { 156 | if (size == this->size()) { 157 | return; 158 | } 159 | if (begin_ptr) { 160 | auto new_buffer = allocate_buffer(size * sizeof(T)); 161 | CUdeviceptr new_begin_ptr = 162 | reinterpret_cast<CUdeviceptr>(new_buffer->data()); 163 | if (preserve_content) { 164 | if (size > this->size()) { 165 | cuda_check(cuMemcpyDtoD( 166 | new_begin_ptr, begin_ptr, this->size() * sizeof(T))); 167 | } else { 168 | cuda_check(cuMemcpyDtoD( 169 | new_begin_ptr, begin_ptr, size * sizeof(T))); 170 | } 171 | } 172 | buffer = std::move(new_buffer); 173 | begin_ptr = new_begin_ptr; 174 | end_ptr = begin_ptr + (size * sizeof(T)); 175 | } else { 176 | *this = CUDAArray(size); 177 | } 178 | } 179 | 180 | /// @brief Expand the array to at least a given size 181 | /// @param size The minimum size of the array 182 | /// @param preserve_content If true, the elements currently in the array 183 | /// will be copied to the beginning of the new array, up to the minimum of 184 | /// the old and new sizes 185 | /// @param round_up If true, the size will be rounded up to the nearest 186 | /// power of 2 187 | void 188 | expand(size_t size, bool preserve_content = false, bool round_up = true) { 189 | if (size > this->size()) { 190 | if (round_up) 191 | size = pow2_round_up(size); 192 | resize(size, preserve_content); 193 | } 194 | } 195 | 196 | void clear() { 197 | begin_ptr = 0; 198 | end_ptr = 0; 199 | buffer = nullptr; 200 | } 201 | }; 202 | 203 | } // namespace radfoam 204 |
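CUDAArray deliberately separates the element type from ownership: storage comes from allocate_buffer, whose only implementation in this tree is the torch-tensor-backed TorchBuffer in torch_bindings/torch_bindings.cpp, so device memory is shared with PyTorch's caching allocator. Note that expand rounds requests up to a power of two to amortize regrowth. A small usage sketch:

CUDAArray<float> scratch(1024);                  // 1024 floats on device
scratch.expand(4000, /*preserve_content=*/true); // reallocates to 4096
float *ptr = scratch.begin();                    // raw device pointer
size_t n = scratch.size();                       // == 4096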
-------------------------------------------------------------------------------- /src/utils/cuda_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdexcept> 4 | #include <string> 5 | 6 | #include <GL/gl3w.h> 7 | #include <cuda.h> 8 | #include <cuda_runtime.h> 9 | 10 | namespace radfoam { 11 | 12 | inline void cuda_check_fn(cudaError_t err, int line, const char *file) { 13 | if (err != cudaSuccess) { 14 | std::string msg = "CUDA call at " + std::string(file) + ":" + 15 | std::to_string(line) + " failed: "; 16 | msg = msg + cudaGetErrorString(err); 17 | throw std::runtime_error(msg); 18 | } 19 | } 20 | 21 | inline void cuda_check_fn(CUresult err, int line, const char *file) { 22 | if (err != CUDA_SUCCESS) { 23 | const char *msg; 24 | cuGetErrorString(err, &msg); 25 | throw std::runtime_error(std::string("CUDA call at ") + file + ":" + 26 | std::to_string(line) + " failed: " + msg); 27 | } 28 | } 29 | 30 | #define cuda_check(call) radfoam::cuda_check_fn(call, __LINE__, __FILE__) 31 | 32 | inline void gl_check_fn(GLenum err, int line, const char *file) { 33 | if (err != GL_NO_ERROR) { 34 | throw std::runtime_error("OpenGL call at " + std::string(file) + ":" + 35 | std::to_string(line) + " failed"); 36 | } 37 | } 38 | 39 | #define gl_check(call) \ 40 | call; \ 41 | radfoam::gl_check_fn(glGetError(), __LINE__, __FILE__) 42 | 43 | inline void global_cuda_init() { 44 | cuda_check(cuInit(0)); 45 | cuda_check(cudaDeviceSetLimit(cudaLimitMallocHeapSize, 1ul << 29ul)); 46 | } 47 | 48 | } // namespace radfoam 49 |
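Both cuda_check overloads turn status codes into C++ exceptions tagged with file and line, which is how CUDA failures eventually surface as ordinary Python exceptions once they cross the pybind11 boundary. Usage is uniform across the runtime and driver APIs:

cuda_check(cudaMemsetAsync(ptr, 0, bytes, stream)); // cudaError_t overload
cuda_check(cuStreamSynchronize(cu_stream));         // CUresult overload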
-------------------------------------------------------------------------------- /src/utils/random.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <algorithm> 4 | #include <limits> 5 | 6 | #include "typing.h" 7 | 8 | namespace radfoam { 9 | 10 | // https://github.com/skeeto/hash-prospector 11 | inline RADFOAM_HD uint32_t mix(uint32_t x) { 12 | x ^= x >> 17; 13 | x *= 0xed5ad4bb; 14 | x ^= x >> 11; 15 | x *= 0xac4c1b51; 16 | x ^= x >> 15; 17 | x *= 0x31848bab; 18 | x ^= x >> 14; 19 | return x; 20 | } 21 | 22 | struct RNGState { 23 | uint32_t bits; 24 | }; 25 | 26 | /// @brief Create a new RNG state with the given seed 27 | inline RADFOAM_HD RNGState make_rng(uint32_t seed) { 28 | return RNGState{mix(seed ^ 0x2815db5b)}; 29 | } 30 | 31 | #ifdef __CUDACC__ 32 | /// @brief Create an RNG with state unique to the current thread 33 | inline __device__ RNGState thread_rng() { 34 | uint32_t seed = 35 | threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) + 36 | blockDim.x * blockDim.y * blockDim.z * 37 | (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)); 38 | return make_rng(seed); 39 | } 40 | #endif 41 | 42 | /// @brief Generate a random integer in the range [0, 0xffffffff] 43 | inline RADFOAM_HD uint32_t randint(RNGState &rngstate) { 44 | uint32_t x = rngstate.bits; 45 | rngstate.bits = mix(rngstate.bits + 1); 46 | return x; 47 | } 48 | 49 | /// @brief Generate a random integer in the range [min, max) 50 | inline RADFOAM_HD uint32_t randint(RNGState &rngstate, 51 | uint32_t min, 52 | uint32_t max) { 53 | uint32_t diff = max - min; 54 | uint32_t x = randint(rngstate); 55 | x /= (0xffffffff / diff); 56 | return std::min(x, diff - 1) + min; 57 | } 58 | 59 | /// @brief Generate a random float in the range [0, 1] 60 | inline RADFOAM_HD float rand(RNGState &rngstate) { 61 | return float(randint(rngstate)) / float(0xffffffff); 62 | } 63 | 64 | /// @brief Generate a random float from a unit normal distribution 65 | inline RADFOAM_HD float randn(RNGState &rngstate) { 66 | // sample normal distribution using Box-Muller transform 67 | float u1 = std::max(rand(rngstate), std::numeric_limits<float>::min()); 68 | float u2 = rand(rngstate); 69 | #ifdef __CUDA_ARCH__ 70 | float result = sqrtf(-2 * logf(u1)) * cosf(2 * M_PIf * u2); 71 | #else 72 | float result = std::sqrt(-2 * std::log(u1)) * std::cos(2 * M_PIf * u2); 73 | #endif 74 | return float(result); 75 | } 76 | 77 | } // namespace radfoam
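The generator is counter-based: mix is a bijective avalanche hash over the state, so independent streams can be forked just by seeding with distinct integers (batch_fetcher.cpp seeds one stream per batch element with batch_idx * batch_size + j). A short usage sketch:

RNGState rng = make_rng(/*seed=*/1234u);
uint32_t roll = randint(rng, 1, 7); // uniform integer in [1, 7)
float u = rand(rng);                // uniform float in [0, 1]
float g = randn(rng);               // standard normal via Box-Muller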
-------------------------------------------------------------------------------- /src/utils/typing.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <string> 4 | 5 | #include <cuda_fp16.h> 6 | 7 | #ifdef __CUDACC__ 8 | #define RADFOAM_HD __host__ __device__ 9 | #else 10 | #define RADFOAM_HD 11 | #endif 12 | 13 | #ifndef M_PI 14 | #define M_PI 3.14159265358979323846 15 | #endif 16 | 17 | #ifndef M_PIf 18 | #define M_PIf 3.14159265358979323846f 19 | #endif 20 | 21 | namespace radfoam { 22 | 23 | enum ScalarType { 24 | Float16, 25 | Float32, 26 | Float64, 27 | UInt32, 28 | Int32, 29 | Int64, 30 | }; 31 | 32 | inline std::string scalar_to_string(ScalarType type) { 33 | switch (type) { 34 | case Float16: 35 | return "float16"; 36 | case Float32: 37 | return "float32"; 38 | case Float64: 39 | return "float64"; 40 | case UInt32: 41 | return "uint32"; 42 | case Int32: 43 | return "int32"; 44 | case Int64: 45 | return "int64"; 46 | default: 47 | return "unknown"; 48 | } 49 | } 50 | 51 | inline size_t scalar_size(ScalarType type) { 52 | switch (type) { 53 | case Float16: 54 | return 2; 55 | case Float32: 56 | return 4; 57 | case Float64: 58 | return 8; 59 | case UInt32: 60 | return 4; 61 | case Int32: 62 | return 4; 63 | case Int64: 64 | return 8; 65 | default: 66 | return 0; 67 | } 68 | } 69 | 70 | template <typename scalar> 71 | constexpr ScalarType scalar_code() = delete; 72 | 73 | template <> 74 | constexpr ScalarType scalar_code<__half>() { 75 | return Float16; 76 | } 77 | 78 | template <> 79 | constexpr ScalarType scalar_code<float>() { 80 | return Float32; 81 | } 82 | 83 | template <> 84 | constexpr ScalarType scalar_code<double>() { 85 | return Float64; 86 | } 87 | 88 | template <> 89 | constexpr ScalarType scalar_code<uint32_t>() { 90 | return UInt32; 91 | } 92 | 93 | template <> 94 | constexpr ScalarType scalar_code<int32_t>() { 95 | return Int32; 96 | } 97 | 98 | template <> 99 | constexpr ScalarType scalar_code<int64_t>() { 100 | return Int64; 101 | } 102 | 103 | template <typename scalar> 104 | constexpr const char *scalar_cxx_name() = delete; 105 | 106 | template <> 107 | constexpr const char *scalar_cxx_name<uint32_t>() { 108 | return "uint32_t"; 109 | } 110 | 111 | template <> 112 | constexpr const char *scalar_cxx_name<__half>() { 113 | return "__half"; 114 | } 115 | 116 | template <> 117 | constexpr const char *scalar_cxx_name<float>() { 118 | return "float"; 119 | } 120 | 121 | template <> 122 | constexpr const char *scalar_cxx_name<double>() { 123 | return "double"; 124 | } 125 | 126 | template <> 127 | constexpr const char *scalar_cxx_name<int32_t>() { 128 | return "int32_t"; 129 | } 130 | 131 | template <> 132 | constexpr const char *scalar_cxx_name<int64_t>() { 133 | return "int64_t"; 134 | } 135 | 136 | enum ColorMap { 137 | Gray = 0, 138 | Viridis = 1, 139 | Inferno = 2, 140 | Turbo = 3, 141 | }; 142 | 143 | struct CMapTable { 144 | const float *const *data; 145 | const int *sizes; 146 | }; 147 | 148 | template <typename T> 149 | RADFOAM_HD void swap(T &a, T &b) { 150 | typename std::decay<T>::type tmp = a; 151 | a = b; 152 | b = tmp; 153 | } 154 | 155 | /// @brief Compute the base-2 logarithm of an integer 156 | inline RADFOAM_HD uint32_t log2(uint32_t x) { 157 | #if defined(__CUDA_ARCH__) 158 | return (x > 0) ? 31 - __clz(x) : 0; 159 | #else 160 | uint32_t result = 0; 161 | while (x >>= 1) { 162 | result++; 163 | } 164 | return result; 165 | #endif 166 | } 167 | 168 | /// @brief Compute the smallest power of 2 greater than or equal to x 169 | inline RADFOAM_HD uint32_t pow2_round_up(uint32_t x) { 170 | return (x > 1) ?
1 << (log2(x - 1) + 1) : 1; 171 | } 172 | 173 | } // namespace radfoam -------------------------------------------------------------------------------- /src/utils/unenumerate_iterator.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace radfoam { 6 | 7 | /// @brief An output iterator that removes the key from a key-value pair 8 | template 9 | struct UnenumerateIterator { 10 | using value_type = void; 11 | using reference = void; 12 | using pointer = void; 13 | using difference_type = ptrdiff_t; 14 | using iterator_category = std::output_iterator_tag; 15 | 16 | T *ptr; 17 | 18 | __host__ __device__ __forceinline__ UnenumerateIterator() : ptr(nullptr) {} 19 | 20 | __host__ __device__ __forceinline__ UnenumerateIterator(T *ptr) 21 | : ptr(ptr) {} 22 | 23 | __host__ __device__ __forceinline__ UnenumerateIterator operator++() { 24 | return UnenumerateIterator(ptr + 1); 25 | } 26 | 27 | __host__ __device__ __forceinline__ UnenumerateIterator operator++(int) { 28 | UnenumerateIterator retval = *this; 29 | ptr++; 30 | return retval; 31 | } 32 | 33 | __host__ __device__ __forceinline__ UnenumerateIterator &operator*() { 34 | return *this; 35 | } 36 | 37 | template 38 | __host__ __device__ __forceinline__ UnenumerateIterator 39 | operator+(Distance n) const { 40 | return UnenumerateIterator(ptr + n); 41 | } 42 | 43 | template 44 | __host__ __device__ __forceinline__ UnenumerateIterator & 45 | operator+=(Distance n) { 46 | ptr += n; 47 | return *this; 48 | } 49 | 50 | template 51 | __host__ __device__ __forceinline__ UnenumerateIterator 52 | operator-(Distance n) const { 53 | return UnenumerateIterator(ptr - n); 54 | } 55 | 56 | template 57 | __host__ __device__ __forceinline__ UnenumerateIterator & 58 | operator-=(Distance n) { 59 | ptr -= n; 60 | return *this; 61 | } 62 | 63 | __host__ __device__ __forceinline__ ptrdiff_t 64 | operator-(UnenumerateIterator other) const { 65 | return ptr - other.ptr; 66 | } 67 | 68 | template 69 | __host__ __device__ __forceinline__ UnenumerateIterator 70 | operator[](Distance n) { 71 | return UnenumerateIterator(ptr + n); 72 | } 73 | 74 | __host__ __device__ __forceinline__ void 75 | operator=(const cub::KeyValuePair &x) { 76 | *ptr = x.value; 77 | } 78 | }; 79 | 80 | } // namespace radfoam 81 | -------------------------------------------------------------------------------- /src/viewer/viewer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "../tracing/pipeline.h" 7 | 8 | namespace radfoam { 9 | 10 | struct ViewerOptions { 11 | bool limit_framerate; 12 | int max_framerate; 13 | int total_iterations; 14 | Vec3f camera_pos; 15 | Vec3f camera_forward; 16 | Vec3f camera_up; 17 | }; 18 | 19 | inline ViewerOptions default_viewer_options() { 20 | ViewerOptions options; 21 | options.limit_framerate = true; 22 | options.max_framerate = 20; 23 | options.total_iterations = 0; 24 | options.camera_pos = Vec3f(2.5f, 2.5f, 2.5f); 25 | options.camera_forward = Vec3f(-1.0f, -1.0f, -1.0f).normalized(); 26 | options.camera_up = Vec3f(0.0f, 0.0f, 1.0f); 27 | return options; 28 | } 29 | 30 | class Viewer { 31 | public: 32 | ~Viewer() = default; 33 | 34 | virtual void update_scene(uint32_t num_points, 35 | uint32_t num_attrs, 36 | uint32_t num_point_adjacency, 37 | const void *coords, 38 | const void *attributes, 39 | const void *point_adjacency, 40 | const void *point_adjacency_offsets, 41 | const void 
*aabb_tree) = 0; 42 | 43 | virtual void step(int iteration) = 0; 44 | 45 | virtual bool is_closed() const = 0; 46 | 47 | virtual const Pipeline &get_pipeline() const = 0; 48 | }; 49 | 50 | void run_with_viewer(std::shared_ptr pipeline, 51 | std::function)> callback, 52 | ViewerOptions options); 53 | 54 | } // namespace radfoam -------------------------------------------------------------------------------- /teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/radfoam/3e7b52cf74e37ab2ab5e695f53570f515f537e3d/teaser.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import configargparse 4 | import warnings 5 | 6 | warnings.filterwarnings("ignore") 7 | 8 | import torch 9 | 10 | from data_loader import DataHandler 11 | from configs import * 12 | from radfoam_model.scene import RadFoamScene 13 | from radfoam_model.utils import psnr 14 | import radfoam 15 | 16 | 17 | seed = 42 18 | torch.random.manual_seed(seed) 19 | np.random.seed(seed) 20 | 21 | 22 | def test(args, pipeline_args, model_args, optimizer_args, dataset_args): 23 | checkpoint = args.config.replace("/config.yaml", "") 24 | os.makedirs(os.path.join(checkpoint, "test"), exist_ok=True) 25 | device = torch.device(args.device) 26 | 27 | test_data_handler = DataHandler( 28 | dataset_args, rays_per_batch=0, device=device 29 | ) 30 | test_data_handler.reload( 31 | split="test", downsample=min(dataset_args.downsample) 32 | ) 33 | test_ray_batch_fetcher = radfoam.BatchFetcher( 34 | test_data_handler.rays, batch_size=1, shuffle=False 35 | ) 36 | test_rgb_batch_fetcher = radfoam.BatchFetcher( 37 | test_data_handler.rgbs, batch_size=1, shuffle=False 38 | ) 39 | 40 | # Setting up model 41 | model = RadFoamScene(args=model_args, device=device) 42 | 43 | model.load_pt(f"{checkpoint}/model.pt") 44 | 45 | def test_render( 46 | test_data_handler, ray_batch_fetcher, rgb_batch_fetcher 47 | ): 48 | rays = test_data_handler.rays 49 | points, _, _, _ = model.get_trace_data() 50 | start_points = model.get_starting_point( 51 | rays[:, 0, 0].cuda(), points, model.aabb_tree 52 | ) 53 | 54 | psnr_list = [] 55 | with torch.no_grad(): 56 | for i in range(rays.shape[0]): 57 | ray_batch = ray_batch_fetcher.next()[0] 58 | rgb_batch = rgb_batch_fetcher.next()[0] 59 | output, _, _, _, _ = model(ray_batch, start_points[i]) 60 | 61 | # White background 62 | opacity = output[..., -1:] 63 | rgb_output = output[..., :3] + (1 - opacity) 64 | rgb_output = rgb_output.reshape(*rgb_batch.shape).clip(0, 1) 65 | 66 | img_psnr = psnr(rgb_output, rgb_batch).mean() 67 | psnr_list.append(img_psnr) 68 | torch.cuda.synchronize() 69 | 70 | error = np.uint8((rgb_output - rgb_batch).cpu().abs() * 255) 71 | rgb_output = np.uint8(rgb_output.cpu() * 255) 72 | rgb_batch = np.uint8(rgb_batch.cpu() * 255) 73 | 74 | im = Image.fromarray( 75 | np.concatenate([rgb_output, rgb_batch, error], axis=1) 76 | ) 77 | im.save( 78 | f"{checkpoint}/test/rgb_{i:03d}_psnr_{img_psnr:.3f}.png" 79 | ) 80 | 81 | average_psnr = sum(psnr_list) / len(psnr_list) 82 | 83 | f = open(f"{checkpoint}/metrics.txt", "w") 84 | f.write(f"Average PSNR: {average_psnr}") 85 | f.close() 86 | 87 | return average_psnr 88 | 89 | test_render( 90 | test_data_handler, test_ray_batch_fetcher, test_rgb_batch_fetcher 91 | ) 92 | 93 | 94 | def main(): 95 | parser = 
configargparse.ArgParser() 96 | 97 | model_params = ModelParams(parser) 98 | dataset_params = DatasetParams(parser) 99 | pipeline_params = PipelineParams(parser) 100 | optimization_params = OptimizationParams(parser) 101 | 102 | # Add argument to specify a custom config file 103 | parser.add_argument( 104 | "-c", "--config", is_config_file=True, help="Path to config file" 105 | ) 106 | 107 | # Parse arguments 108 | args = parser.parse_args() 109 | 110 | test( 111 | args, 112 | pipeline_params.extract(args), 113 | model_params.extract(args), 114 | optimization_params.extract(args), 115 | dataset_params.extract(args), 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /torch_bindings/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | execute_process( 2 | COMMAND python ${CMAKE_SOURCE_DIR}/scripts/torch_info.py torch 3 | OUTPUT_VARIABLE TORCH_VERSION 4 | OUTPUT_STRIP_TRAILING_WHITESPACE) 5 | 6 | execute_process( 7 | COMMAND python ${CMAKE_SOURCE_DIR}/scripts/torch_info.py cuda 8 | OUTPUT_VARIABLE TORCH_CUDA_VERSION 9 | OUTPUT_STRIP_TRAILING_WHITESPACE) 10 | 11 | if(TORCH_CUDA_VERSION VERSION_EQUAL CUDA_VERSION_STRING) 12 | message(STATUS "CUDA version is: ${CUDA_VERSION_STRING}") 13 | else() 14 | message( 15 | FATAL_ERROR 16 | "CUDA version found (${CUDA_VERSION_STRING}) does not match the version required by PyTorch (${TORCH_CUDA_VERSION})." 17 | ) 18 | endif() 19 | 20 | pybind11_add_module(torch_bindings MODULE torch_bindings.cpp 21 | pipeline_bindings.cpp triangulation_bindings.cpp) 22 | 23 | target_include_directories(torch_bindings PUBLIC ${CMAKE_SOURCE_DIR}/src) 24 | 25 | target_link_libraries(torch_bindings PRIVATE torch ${TORCH_PYTHON_LIBRARY} 26 | radfoam ${GLFW_LIBRARY}) 27 | 28 | install( 29 | TARGETS torch_bindings 30 | COMPONENT torch_bindings 31 | LIBRARY DESTINATION ${RADFOAM_INSTALL_PREFIX} 32 | ARCHIVE DESTINATION ${RADFOAM_INSTALL_PREFIX} 33 | RUNTIME DESTINATION ${RADFOAM_INSTALL_PREFIX}) 34 | 35 | configure_file(${CMAKE_SOURCE_DIR}/torch_bindings/radfoam/__init__.py.in 36 | ${CMAKE_BINARY_DIR}/__init__.py @ONLY) 37 | 38 | install(FILES ${CMAKE_BINARY_DIR}/__init__.py 39 | DESTINATION ${RADFOAM_INSTALL_PREFIX}) 40 | -------------------------------------------------------------------------------- /torch_bindings/bindings.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "utils/geometry.h" 13 | 14 | namespace radfoam_bindings { 15 | 16 | namespace py = pybind11; 17 | using namespace radfoam; 18 | 19 | inline void set_default_stream() { 20 | auto stream = at::cuda::getCurrentCUDAStream(); 21 | at::cuda::setCurrentCUDAStream(stream); 22 | } 23 | 24 | inline ScalarType dtype_to_scalar_type(py::object dtype) { 25 | std::string dtype_str = py::str(dtype).cast(); 26 | 27 | if (dtype_str == "float32") { 28 | return ScalarType::Float32; 29 | } else if (dtype_str == "torch.float32") { 30 | return ScalarType::Float32; 31 | } else if (dtype_str == "float64") { 32 | return ScalarType::Float64; 33 | } else if (dtype_str == "torch.float64") { 34 | return ScalarType::Float64; 35 | } else if (dtype_str == "float16") { 36 | return ScalarType::Float16; 37 | } else if (dtype_str == "torch.float16") { 38 | return ScalarType::Float16; 39 | } else { 40 | throw 
std::runtime_error("unsupported dtype '" + dtype_str + "'"); 41 | } 42 | } 43 | 44 | inline ScalarType dtype_to_scalar_type(at::ScalarType dtype) { 45 | if (dtype == at::kFloat) { 46 | return ScalarType::Float32; 47 | } else if (dtype == at::kDouble) { 48 | return ScalarType::Float64; 49 | } else if (dtype == at::kHalf) { 50 | return ScalarType::Float16; 51 | } else { 52 | throw std::runtime_error("unsupported dtype '" + 53 | std::string(c10::toString(dtype)) + "'"); 54 | } 55 | } 56 | 57 | inline caffe2::TypeMeta scalar_to_type_meta(ScalarType scalar) { 58 | switch (scalar) { 59 | case ScalarType::Float32: 60 | return caffe2::TypeMeta::Make<float>(); 61 | case ScalarType::Float64: 62 | return caffe2::TypeMeta::Make<double>(); 63 | case ScalarType::Float16: 64 | return caffe2::TypeMeta::Make<at::Half>(); 65 | case ScalarType::UInt32: 66 | return caffe2::TypeMeta::Make<uint32_t>(); 67 | default: 68 | throw std::runtime_error("unsupported scalar type"); 69 | } 70 | } 71 | 72 | inline std::array<uint32_t, 3> get_3d_shape(const torch::Tensor &tensor, 73 | int feature_dims) { 74 | std::array<uint32_t, 3> shape = {1, 1, 1}; 75 | uint32_t product = 1; 76 | for (int i = 0; i < feature_dims; i++) { 77 | product *= tensor.size(-(i + 1)); 78 | } 79 | for (int i = 0; i < 2; i++) { 80 | if (i + feature_dims + 1 > tensor.dim()) { 81 | break; 82 | } 83 | product *= tensor.size(-(i + feature_dims + 1)); 84 | shape[i] = tensor.size(-(i + feature_dims + 1)); 85 | } 86 | shape[2] = tensor.numel() / product; 87 | return shape; 88 | } 89 | 90 | } // namespace radfoam_bindings -------------------------------------------------------------------------------- /torch_bindings/pipeline_bindings.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "bindings.h" 4 | 5 | namespace radfoam_bindings { 6 | 7 | void init_pipeline_bindings(py::module &module); 8 | 9 | }
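get_3d_shape peels feature_dims trailing axes into product, fills shape[0] and shape[1] with the next two axes (innermost first), and folds any remaining leading dimensions into shape[2]; downstream code reads the result as a (width, height, batch)-style triple. A worked example with a hypothetical tensor:

// rays: float tensor of shape [B=4, H=600, W=800, 3], feature_dims = 1
// -> the trailing 3 is consumed, and shape = {800, 600, 4}.
auto shape = get_3d_shape(rays, /*feature_dims=*/1);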
-------------------------------------------------------------------------------- /torch_bindings/radfoam/__init__.py.in: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import ctypes 4 | 5 | import torch 6 | 7 | torch_version_at_compile = "@TORCH_VERSION@" 8 | torch_version_at_runtime = torch.__version__.split("+")[0] 9 | 10 | if torch_version_at_compile != torch_version_at_runtime: 11 | warnings.warn( 12 | f"RadFoam was compiled with torch version {torch_version_at_compile}, but " 13 | f"the current torch version is {torch_version_at_runtime}. This might lead to " 14 | "unexpected behavior or crashes." 15 | ) 16 | 17 | if @USE_PIP_GLFW@: 18 | import glfw 19 | glfw_path = os.path.split(glfw.__file__)[0] 20 | libglfw_path = os.path.join(glfw_path, "x11", "libglfw.so") 21 | libdl = ctypes.CDLL("libdl.so.2") 22 | RTLD_NOW = 0x00002 23 | RTLD_GLOBAL = 0x00100 24 | handle = libdl.dlopen(libglfw_path.encode("utf-8"), RTLD_NOW | RTLD_GLOBAL) 25 | 26 | if not handle: 27 | raise ImportError("failed to load libglfw.so") 28 | 29 | from .torch_bindings import * 30 | -------------------------------------------------------------------------------- /torch_bindings/torch_bindings.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include <cstdint> 3 | #include <memory> 4 | #include <vector> 5 | #include <torch/extension.h> 6 | 7 | #include "pipeline_bindings.h" 8 | #include "triangulation_bindings.h" 9 | #include "utils/batch_fetcher.h" 10 | #include "utils/cuda_array.h" 11 | 12 | namespace radfoam { 13 | 14 | struct TorchBuffer : public OpaqueBuffer { 15 | torch::Tensor tensor; 16 | 17 | TorchBuffer(size_t bytes) { 18 | // allocate on CUDA device 19 | // int64 dtype for alignment 20 | size_t num_words = (bytes + sizeof(int64_t) - 1) / sizeof(int64_t); 21 | tensor = torch::empty({(int64_t)num_words}, 22 | torch::dtype(torch::kInt64).device(torch::kCUDA)); 23 | } 24 | 25 | void *data() override { return tensor.data_ptr(); } 26 | }; 27 | 28 | std::unique_ptr<OpaqueBuffer> allocate_buffer(size_t bytes) { 29 | return std::make_unique<TorchBuffer>(bytes); 30 | } 31 | 32 | struct TorchBatchFetcher { 33 | std::unique_ptr<BatchFetcher> fetcher; 34 | torch::Tensor data; 35 | size_t batch_size; 36 | 37 | TorchBatchFetcher(torch::Tensor _data, size_t _batch_size, bool shuffle) 38 | : data(_data), batch_size(_batch_size) { 39 | size_t num_bytes = data.numel() * data.element_size(); 40 | size_t num_elems = data.size(0); 41 | size_t stride = num_bytes / num_elems; 42 | fetcher = create_batch_fetcher( 43 | data.data_ptr(), num_bytes, stride, batch_size, shuffle); 44 | } 45 | 46 | torch::Tensor next() { 47 | void *batch = fetcher->next(); 48 | std::vector<int64_t> shape; 49 | shape.push_back(batch_size); 50 | for (int i = 1; i < data.dim(); i++) { 51 | shape.push_back(data.size(i)); 52 | } 53 | return torch::from_blob(batch, 54 | shape, 55 | torch::dtype(data.dtype()).device(torch::kCUDA)) 56 | .clone(); 57 | } 58 | }; 59 | 60 | std::unique_ptr<TorchBatchFetcher> create_torch_batch_fetcher( 61 | torch::Tensor data, size_t batch_size, bool shuffle) { 62 | return std::make_unique<TorchBatchFetcher>(data, batch_size, shuffle); 63 | } 64 | 65 | } // namespace radfoam 66 | 67 | namespace radfoam_bindings { 68 | 69 | PYBIND11_MODULE(torch_bindings, module) { 70 | using namespace radfoam_bindings; 71 | 72 | module.doc() = "radfoam pytorch bindings module"; 73 | 74 | init_pipeline_bindings(module); 75 | init_triangulation_bindings(module); 76 | 77 | py::class_<TorchBatchFetcher, std::shared_ptr<TorchBatchFetcher>>( 78 | module, "BatchFetcher") 79 | .def(py::init(&create_torch_batch_fetcher), 80 | py::arg("data"), 81 | py::arg("batch_size"), 82 | py::arg("shuffle")) 83 | .def("next", &TorchBatchFetcher::next); 84 | } 85 | 86 | } // namespace radfoam_bindings 87 | -------------------------------------------------------------------------------- /torch_bindings/triangulation_bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <memory> 2 | 3 | #include "triangulation_bindings.h" 4 | 5 | #include "aabb_tree/aabb_tree.h" 6 | #include "delaunay/delaunay.h" 7 | #include "delaunay/triangulation_ops.h" 8 | 9 | namespace radfoam_bindings { 10 | 11 | std::unique_ptr<Triangulation> create_triangulation(torch::Tensor points) { 12 | if
(points.size(-1) != 3) { 13 | throw std::runtime_error("points must have 3 as the last dimension"); 14 | } 15 | if (points.device().type() != at::kCUDA) { 16 | throw std::runtime_error("points must be on CUDA device"); 17 | } 18 | if (points.scalar_type() != torch::kFloat32) { 19 | throw std::runtime_error("points must have float32 dtype"); 20 | } 21 | 22 | uint32_t num_points = points.numel() / 3; 23 | 24 | set_default_stream(); 25 | 26 | return Triangulation::create_triangulation(points.data_ptr(), num_points); 27 | } 28 | 29 | bool rebuild(Triangulation &triangulation, 30 | torch::Tensor points, 31 | bool incremental) { 32 | if (points.size(-1) != 3) { 33 | throw std::runtime_error("points must have 3 as the last dimension"); 34 | } 35 | if (points.device().type() != at::kCUDA) { 36 | throw std::runtime_error("points must be on CUDA device"); 37 | } 38 | if (points.scalar_type() != torch::kFloat32) { 39 | throw std::runtime_error("points must have float32 dtype"); 40 | } 41 | 42 | set_default_stream(); 43 | 44 | return triangulation.rebuild( 45 | points.data_ptr(), points.numel() / 3, incremental); 46 | } 47 | 48 | torch::Tensor permutation(const Triangulation &triangulation) { 49 | const uint32_t *permutation = triangulation.permutation(); 50 | uint32_t num_points = triangulation.num_points(); 51 | 52 | at::TensorOptions options = 53 | at::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA); 54 | 55 | return torch::from_blob( 56 | const_cast(permutation), {num_points}, options); 57 | } 58 | 59 | torch::Tensor get_tets(const Triangulation &triangulation) { 60 | const IndexedTet *tets = triangulation.tets(); 61 | uint32_t num_tets = triangulation.num_tets(); 62 | 63 | at::TensorOptions options = 64 | at::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA); 65 | 66 | return torch::from_blob( 67 | const_cast(tets), {num_tets, 4}, options); 68 | } 69 | 70 | torch::Tensor get_tet_adjacency(const Triangulation &triangulation) { 71 | const uint32_t *tet_adjacency = triangulation.tet_adjacency(); 72 | uint32_t num_tets = triangulation.num_tets(); 73 | 74 | at::TensorOptions options = 75 | at::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA); 76 | 77 | return torch::from_blob( 78 | const_cast(tet_adjacency), {num_tets, 4}, options); 79 | } 80 | 81 | torch::Tensor get_point_adjacency(const Triangulation &triangulation) { 82 | const uint32_t *point_adjacency = triangulation.point_adjacency(); 83 | uint32_t point_adjacency_size = triangulation.point_adjacency_size(); 84 | 85 | at::TensorOptions options = 86 | at::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA); 87 | 88 | return torch::from_blob(const_cast(point_adjacency), 89 | {point_adjacency_size}, 90 | options); 91 | } 92 | 93 | torch::Tensor get_point_adjacency_offsets(const Triangulation &triangulation) { 94 | const uint32_t *point_adjacency_offsets = 95 | triangulation.point_adjacency_offsets(); 96 | uint32_t num_points = triangulation.num_points(); 97 | 98 | at::TensorOptions options = 99 | at::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA); 100 | 101 | return torch::from_blob(const_cast(point_adjacency_offsets), 102 | {num_points + 1}, 103 | options); 104 | } 105 | 106 | torch::Tensor get_vert_to_tet(const Triangulation &triangulation) { 107 | const uint32_t *vert_to_tet = triangulation.vert_to_tet(); 108 | uint32_t num_points = triangulation.num_points(); 109 | 110 | at::TensorOptions options = 111 | at::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA); 112 | 113 | return 
torch::from_blob( 114 | const_cast(vert_to_tet), {num_points}, options); 115 | } 116 | 117 | torch::Tensor build_aabb_tree(torch::Tensor points) { 118 | if (points.size(-1) != 3) { 119 | throw std::runtime_error("points must have 3 as the last dimension"); 120 | } 121 | if (points.dim() != 2) { 122 | throw std::runtime_error("points must have 2 dimensions"); 123 | } 124 | if (points.device().type() != at::kCUDA) { 125 | throw std::runtime_error("points must be on CUDA device"); 126 | } 127 | 128 | ScalarType scalar_type = dtype_to_scalar_type(points.scalar_type()); 129 | 130 | uint32_t num_points = points.numel() / 3; 131 | 132 | torch::Tensor aabb_tree = torch::empty( 133 | {pow2_round_up(num_points), 2, 3}, 134 | torch::TensorOptions().dtype(points.dtype()).device(points.device())); 135 | 136 | radfoam::build_aabb_tree( 137 | scalar_type, points.data_ptr(), num_points, aabb_tree.data_ptr()); 138 | 139 | return aabb_tree; 140 | } 141 | 142 | torch::Tensor 143 | nn(torch::Tensor points, torch::Tensor tree, torch::Tensor queries) { 144 | uint32_t num_points = points.numel() / 3; 145 | uint32_t num_queries = queries.numel() / 3; 146 | 147 | if (points.scalar_type() != queries.scalar_type()) { 148 | throw std::runtime_error("points and queries must have the same dtype"); 149 | } 150 | if (points.scalar_type() != tree.scalar_type()) { 151 | throw std::runtime_error("points and tree must have the same dtype"); 152 | } 153 | if (points.device().type() != at::kCUDA) { 154 | throw std::runtime_error("points must be on CUDA device"); 155 | } 156 | if (tree.device().type() != at::kCUDA) { 157 | throw std::runtime_error("tree must be on CUDA device"); 158 | } 159 | if (queries.device().type() != at::kCUDA) { 160 | throw std::runtime_error("queries must be on CUDA device"); 161 | } 162 | 163 | std::vector indices_shape; 164 | 165 | for (int64_t i = 0; i < queries.dim() - 1; i++) { 166 | indices_shape.push_back(queries.size(i)); 167 | } 168 | 169 | torch::Tensor indices = torch::zeros( 170 | indices_shape, torch::dtype(torch::kUInt32).device(queries.device())); 171 | 172 | radfoam::nn(dtype_to_scalar_type(points.scalar_type()), 173 | points.data_ptr(), 174 | tree.data_ptr(), 175 | queries.data_ptr(), 176 | num_points, 177 | num_queries, 178 | static_cast(indices.data_ptr())); 179 | 180 | return indices; 181 | } 182 | 183 | std::tuple 184 | farthest_neighbor(torch::Tensor points_in, 185 | torch::Tensor point_adjacency_in, 186 | torch::Tensor point_adjacency_offsets_in) { 187 | uint32_t num_points = points_in.size(0); 188 | torch::Tensor points = points_in.contiguous(); 189 | torch::Tensor point_adjacency = point_adjacency_in.contiguous(); 190 | torch::Tensor point_adjacency_offsets = 191 | point_adjacency_offsets_in.contiguous(); 192 | 193 | if (points.device().type() != at::kCUDA) { 194 | throw std::runtime_error("points must be on CUDA device"); 195 | } 196 | 197 | std::vector indices_shape; 198 | 199 | for (int64_t i = 0; i < points.dim() - 1; i++) { 200 | indices_shape.push_back(points.size(i)); 201 | } 202 | 203 | torch::Tensor indices = torch::zeros( 204 | indices_shape, torch::dtype(torch::kUInt32).device(points.device())); 205 | torch::Tensor cell_radius = torch::zeros( 206 | indices_shape, torch::dtype(torch::kFloat32).device(points.device())); 207 | 208 | radfoam::farthest_neighbor(dtype_to_scalar_type(points.scalar_type()), 209 | points.data_ptr(), 210 | point_adjacency.data_ptr(), 211 | point_adjacency_offsets.data_ptr(), 212 | num_points, 213 | static_cast(indices.data_ptr()), 214 | 
214 |                                static_cast<float *>(cell_radius.data_ptr()));
215 | 
216 |     return std::make_tuple(indices, cell_radius);
217 | }
218 | 
219 | void init_triangulation_bindings(py::module &module) {
220 |     radfoam::global_cuda_init();
221 | 
222 |     py::register_exception<TriangulationFailedError>(
223 |         module, "TriangulationFailedError");
224 | 
225 |     py::class_<Triangulation, std::shared_ptr<Triangulation>>(module,
226 |                                                               "Triangulation")
227 |         .def(py::init(&create_triangulation), py::arg("points"))
228 |         .def("tets", &get_tets)
229 |         .def("tet_adjacency", &get_tet_adjacency)
230 |         .def("point_adjacency", &get_point_adjacency)
231 |         .def("point_adjacency_offsets", &get_point_adjacency_offsets)
232 |         .def("vert_to_tet", &get_vert_to_tet)
233 |         .def("rebuild",
234 |              &rebuild,
235 |              py::arg("points"),
236 |              py::arg("incremental") = false)
237 |         .def("permutation", &permutation);
238 | 
239 |     module.def("build_aabb_tree", &build_aabb_tree, py::arg("points"));
240 | 
241 |     module.def(
242 |         "nn", &nn, py::arg("points"), py::arg("tree"), py::arg("queries"));
243 | 
244 |     module.def("farthest_neighbor",
245 |                &farthest_neighbor,
246 |                py::arg("points"),
247 |                py::arg("point_adjacency"),
248 |                py::arg("point_adjacency_offsets"));
249 | }
250 | 
251 | } // namespace radfoam_bindings
--------------------------------------------------------------------------------
/torch_bindings/triangulation_bindings.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include "bindings.h"
4 | 
5 | namespace radfoam_bindings {
6 | 
7 | void init_triangulation_bindings(py::module &module);
8 | 
9 | }
--------------------------------------------------------------------------------
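Taken together, these bindings expose the CUDA Delaunay triangulation and the AABB-tree nearest-neighbor kernels to Python. The following is a minimal usage sketch, not part of the repository; it assumes the compiled extension re-exports these names from the `radfoam` package (the `import radfoam` in train.py below relies on the same packaging), and the point cloud is fabricated.

```python
# Hypothetical usage of the bindings above; assumes `radfoam` re-exports them.
import torch
import radfoam

# create_triangulation() requires float32 CUDA points with 3 as the last dim,
# or it raises the runtime errors defined in the bindings.
points = torch.rand(10_000, 3, dtype=torch.float32, device="cuda")

dt = radfoam.Triangulation(points)
tets = dt.tets()  # (num_tets, 4) uint32 vertex indices

# Per-point neighbor lists are flattened; offsets are prefix sums, so the
# neighbors of point i live in adj[offsets[i]:offsets[i + 1]].
adj = dt.point_adjacency().long()  # cast for convenient indexing
offsets = dt.point_adjacency_offsets().long()
neighbors_of_0 = adj[offsets[0]:offsets[1]]

# After points move, rebuild in place; incremental=True requests an
# incremental update rather than a from-scratch rebuild.
points = points + 1e-3 * torch.randn_like(points)
dt.rebuild(points, incremental=True)

# Nearest-neighbor queries go through the AABB tree helpers.
tree = radfoam.build_aabb_tree(points)
queries = torch.rand(128, 3, dtype=torch.float32, device="cuda")
nearest = radfoam.nn(points, tree, queries)  # (128,) uint32 indices
```

Note that the getters wrap internal device memory with `torch::from_blob` rather than copying, so a returned tensor should be treated as invalidated once `rebuild()` runs.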
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | import yaml
4 | import gc
5 | import numpy as np
6 | from PIL import Image
7 | import configargparse
8 | import tqdm
9 | import warnings
10 | 
11 | warnings.filterwarnings("ignore")
12 | 
13 | import torch
14 | from torch import nn
15 | from torch.utils.tensorboard import SummaryWriter
16 | 
17 | from data_loader import DataHandler
18 | from configs import *
19 | from radfoam_model.scene import RadFoamScene
20 | from radfoam_model.utils import psnr
21 | import radfoam
22 | 
23 | 
24 | seed = 42
25 | torch.random.manual_seed(seed)
26 | np.random.seed(seed)
27 | 
28 | 
29 | def train(args, pipeline_args, model_args, optimizer_args, dataset_args):
30 |     device = torch.device(model_args.device)
31 |     # Setting up output directory
32 |     if not pipeline_args.debug:
33 |         if len(pipeline_args.experiment_name) == 0:
34 |             unique_str = str(uuid.uuid4())[:8]
35 |             experiment_name = f"{dataset_args.scene}@{unique_str}"
36 |         else:
37 |             experiment_name = pipeline_args.experiment_name
38 |         out_dir = f"output/{experiment_name}"
39 |         writer = SummaryWriter(out_dir, purge_step=0)
40 |         os.makedirs(f"{out_dir}/test", exist_ok=True)
41 | 
42 |         def represent_list_inline(dumper, data):
43 |             return dumper.represent_sequence(
44 |                 "tag:yaml.org,2002:seq", data, flow_style=True
45 |             )
46 | 
47 |         yaml.add_representer(list, represent_list_inline)
48 | 
49 |         # Save the arguments to a YAML file
50 |         with open(f"{out_dir}/config.yaml", "w") as yaml_file:
51 |             yaml.dump(vars(args), yaml_file, default_flow_style=False)
52 | 
53 |     # Setting up dataset
54 |     iter2downsample = dict(
55 |         zip(
56 |             dataset_args.downsample_iterations,
57 |             dataset_args.downsample,
58 |         )
59 |     )
60 |     train_data_handler = DataHandler(
61 |         dataset_args, rays_per_batch=1_000_000, device=device
62 |     )
63 |     downsample = iter2downsample[0]
64 |     train_data_handler.reload(split="train", downsample=downsample)
65 | 
66 |     test_data_handler = DataHandler(
67 |         dataset_args, rays_per_batch=0, device=device
68 |     )
69 |     test_data_handler.reload(
70 |         split="test", downsample=min(dataset_args.downsample)
71 |     )
72 |     test_ray_batch_fetcher = radfoam.BatchFetcher(
73 |         test_data_handler.rays, batch_size=1, shuffle=False
74 |     )
75 |     test_rgb_batch_fetcher = radfoam.BatchFetcher(
76 |         test_data_handler.rgbs, batch_size=1, shuffle=False
77 |     )
78 | 
79 |     # Define viewer settings
80 |     viewer_options = {
81 |         "camera_pos": train_data_handler.viewer_pos,
82 |         "camera_up": train_data_handler.viewer_up,
83 |         "camera_forward": train_data_handler.viewer_forward,
84 |     }
85 | 
86 |     # Setting up pipeline
87 |     rgb_loss = nn.SmoothL1Loss(reduction="none")
88 | 
89 |     # Setting up model
90 |     model = RadFoamScene(
91 |         args=model_args,
92 |         device=device,
93 |         points=train_data_handler.points3D,
94 |         points_colors=train_data_handler.points3D_colors,
95 |     )
96 | 
97 |     # Setting up optimizer
98 |     model.declare_optimizer(
99 |         args=optimizer_args,
100 |         warmup=pipeline_args.densify_from,
101 |         max_iterations=pipeline_args.iterations,
102 |     )
103 | 
104 |     def test_render(
105 |         test_data_handler, ray_batch_fetcher, rgb_batch_fetcher, debug=False
106 |     ):
107 |         rays = test_data_handler.rays
108 |         points, _, _, _ = model.get_trace_data()
109 |         start_points = model.get_starting_point(
110 |             rays[:, 0, 0].cuda(), points, model.aabb_tree
111 |         )
112 | 
113 |         psnr_list = []
114 |         with torch.no_grad():
115 |             for i in range(rays.shape[0]):
116 |                 ray_batch = ray_batch_fetcher.next()[0]
117 |                 rgb_batch = rgb_batch_fetcher.next()[0]
118 |                 output, _, _, _, _ = model(ray_batch, start_points[i])
119 | 
120 |                 # White background
121 |                 opacity = output[..., -1:]
122 |                 rgb_output = output[..., :3] + (1 - opacity)
123 |                 rgb_output = rgb_output.reshape(*rgb_batch.shape).clip(0, 1)
124 | 
125 |                 img_psnr = psnr(rgb_output, rgb_batch).mean()
126 |                 psnr_list.append(img_psnr)
127 |                 torch.cuda.synchronize()
128 | 
129 |                 if not debug:
130 |                     error = np.uint8((rgb_output - rgb_batch).cpu().abs() * 255)
131 |                     rgb_output = np.uint8(rgb_output.cpu() * 255)
132 |                     rgb_batch = np.uint8(rgb_batch.cpu() * 255)
133 | 
134 |                     im = Image.fromarray(
135 |                         np.concatenate([rgb_output, rgb_batch, error], axis=1)
136 |                     )
137 |                     im.save(
138 |                         f"{out_dir}/test/rgb_{i:03d}_psnr_{img_psnr:.3f}.png"
139 |                     )
140 | 
141 |         average_psnr = sum(psnr_list) / len(psnr_list)
142 |         if not debug:
143 |             with open(f"{out_dir}/metrics.txt", "w") as f:
144 |                 f.write(f"Average PSNR: {average_psnr}")
145 | 
146 | 
147 |         return average_psnr
148 | 
149 |     def train_loop(viewer):
150 |         print("Training")
151 | 
152 |         torch.cuda.synchronize()
153 | 
154 |         data_iterator = train_data_handler.get_iter()
155 |         ray_batch, rgb_batch, alpha_batch = next(data_iterator)
156 | 
157 |         triangulation_update_period = 1
158 |         iters_since_update = 1
159 |         iters_since_densification = 0
160 |         next_densification_after = 1
161 | 
162 |         with tqdm.trange(pipeline_args.iterations) as train:
163 |             for i in train:
164 |                 if viewer is not None:
165 |                     model.update_viewer(viewer)
166 |                     viewer.step(i)
167 | 
168 |                 if i in iter2downsample and i:  # "and i" skips iteration 0
169 |                     downsample = iter2downsample[i]
170 |                     train_data_handler.reload(
171 |                         split="train", downsample=downsample
172 |                     )
173 |                     data_iterator = train_data_handler.get_iter()
174 |                     ray_batch, rgb_batch, alpha_batch = next(data_iterator)
175 | 
176 |                 depth_quantiles = (
177 |                     torch.rand(*ray_batch.shape[:-1], 2, device=device)
178 |                     .sort(dim=-1, descending=True)
179 |                     .values
180 |                 )
181 | 
182 |                 rgba_output, depth, _, _, _ = model(
183 |                     ray_batch,
184 |                     depth_quantiles=depth_quantiles,
185 |                 )
186 | 
187 |                 # White background
188 |                 opacity = rgba_output[..., -1:]
189 |                 if pipeline_args.white_background:
190 |                     rgb_output = rgba_output[..., :3] + (1 - opacity)
191 |                 else:
192 |                     rgb_output = rgba_output[..., :3]
193 | 
194 |                 color_loss = rgb_loss(rgb_batch, rgb_output)
195 |                 opacity_loss = ((alpha_batch - opacity) ** 2).mean()
196 | 
197 |                 valid_depth_mask = (depth > 0).all(dim=-1)
198 |                 quant_loss = (depth[..., 0] - depth[..., 1]).abs()
199 |                 quant_loss = (quant_loss * valid_depth_mask).mean()
200 |                 w_depth = pipeline_args.quantile_weight * min(
201 |                     2 * i / pipeline_args.iterations, 1
202 |                 )
203 | 
204 |                 loss = color_loss.mean() + opacity_loss + w_depth * quant_loss
205 | 
206 |                 model.optimizer.zero_grad(set_to_none=True)
207 | 
208 |                 # Hide latency of data loading behind the backward pass
209 |                 event = torch.cuda.Event()
210 |                 event.record()
211 |                 loss.backward()
212 |                 event.synchronize()
213 |                 ray_batch, rgb_batch, alpha_batch = next(data_iterator)
214 | 
215 |                 model.optimizer.step()
216 |                 model.update_learning_rate(i)
217 | 
218 |                 train.set_postfix(color_loss=f"{color_loss.mean().item():.5f}")
219 | 
220 |                 if i % 100 == 99 and not pipeline_args.debug:
221 |                     writer.add_scalar("train/rgb_loss", color_loss.mean(), i)
222 |                     num_points = model.primal_points.shape[0]
223 |                     writer.add_scalar("test/num_points", num_points, i)
224 | 
225 |                     test_psnr = test_render(
226 |                         test_data_handler,
227 |                         test_ray_batch_fetcher,
228 |                         test_rgb_batch_fetcher,
229 |                         True,
230 |                     )
231 |                     writer.add_scalar("test/psnr", test_psnr, i)
232 | 
233 |                     writer.add_scalar(
234 |                         "lr/points_lr", model.xyz_scheduler_args(i), i
235 |                     )
236 |                     writer.add_scalar(
237 |                         "lr/density_lr", model.den_scheduler_args(i), i
238 |                     )
239 |                     writer.add_scalar(
240 |                         "lr/attr_lr", model.attr_dc_scheduler_args(i), i
241 |                     )
242 | 
243 |                 if iters_since_update >= triangulation_update_period:
244 |                     model.update_triangulation(incremental=True)
245 |                     iters_since_update = 0
246 | 
247 |                     if triangulation_update_period < 100:
248 |                         triangulation_update_period += 2
249 | 
250 |                 iters_since_update += 1
251 |                 if i + 1 >= pipeline_args.densify_from:
252 |                     iters_since_densification += 1
253 | 
254 |                 if (
255 |                     iters_since_densification == next_densification_after
256 |                     and model.primal_points.shape[0]
257 |                     < 0.9 * model.num_final_points
258 |                 ):
259 |                     point_error, point_contribution = model.collect_error_map(
260 |                         train_data_handler, pipeline_args.white_background
261 |                     )
262 |                     model.prune_and_densify(
263 |                         point_error,
264 |                         point_contribution,
265 |                         pipeline_args.densify_factor,
266 |                     )
267 | 
268 |                     model.update_triangulation(incremental=False)
269 |                     triangulation_update_period = 1
270 |                     gc.collect()
271 | 
272 |                     # Linear growth
273 |                     iters_since_densification = 0
274 |                     next_densification_after = int(
275 |                         (
276 |                             (pipeline_args.densify_factor - 1)
277 |                             * model.primal_points.shape[0]
278 |                             * (
279 |                                 pipeline_args.densify_until
280 |                                 - pipeline_args.densify_from
281 |                             )
282 |                         )
283 |                         / (model.num_final_points - model.num_init_points)
284 |                     )
285 |                     next_densification_after = max(
286 |                         next_densification_after, 100
287 |                     )
288 | 
289 |                 if i == optimizer_args.freeze_points:
290 |                     model.update_triangulation(incremental=False)
291 | 
292 |                 if viewer is not None and viewer.is_closed():
293 |                     break
294 | 
295 |         model.save_ply(f"{out_dir}/scene.ply")
296 |         model.save_pt(f"{out_dir}/model.pt")
297 |         del data_iterator
298 | 
299 |     if pipeline_args.viewer:
300 |         model.show(
301 |             train_loop, iterations=pipeline_args.iterations, **viewer_options
302 |         )
303 |     else:
304 |         train_loop(viewer=None)
305 |     if not pipeline_args.debug:
306 |         writer.close()
307 | 
308 |     test_render(
309 |         test_data_handler,
310 |         test_ray_batch_fetcher,
311 |         test_rgb_batch_fetcher,
312 |         pipeline_args.debug,
313 |     )
314 | 
315 | 
316 | def main():
317 |     parser = configargparse.ArgParser(
318 |         default_config_files=["configs/mipnerf360_outdoor.yaml"]
319 |     )
320 | 
321 |     model_params = ModelParams(parser)
322 |     pipeline_params = PipelineParams(parser)
323 |     optimization_params = OptimizationParams(parser)
324 |     dataset_params = DatasetParams(parser)
325 | 
326 |     # Add argument to specify a custom config file
327 |     parser.add_argument(
328 |         "-c", "--config", is_config_file=True, help="Path to config file"
329 |     )
330 | 
331 |     # Parse arguments
332 |     args = parser.parse_args()
333 | 
334 |     train(
335 |         args,
336 |         pipeline_params.extract(args),
337 |         model_params.extract(args),
338 |         optimization_params.extract(args),
339 |         dataset_params.extract(args),
340 |     )
341 | 
342 | 
343 | if __name__ == "__main__":
344 |     main()
345 | 
--------------------------------------------------------------------------------
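The depth-quantile regularizer in train_loop above is compact enough to miss: two random quantiles are drawn per ray and passed to the model, and the gap between the two depths the model returns is penalized with a weight that ramps linearly over the first half of training. A standalone sketch follows (not part of the repository); the depths are fabricated here, and the exact semantics of the model's quantile depths are an assumption.

```python
# Standalone sketch of the quantile-depth term from train_loop above.
import torch

iterations = 20_000     # stand-in for pipeline_args.iterations
quantile_weight = 0.1   # stand-in for pipeline_args.quantile_weight
i = 5_000               # current training iteration
num_rays = 4_096

# Two random quantiles per ray, sorted descending, exactly as in train.py;
# in training this tensor is passed to the model as depth_quantiles=...
depth_quantiles = torch.rand(num_rays, 2).sort(dim=-1, descending=True).values

# Fabricated stand-in for the model's returned depths at those quantiles;
# rays with no valid depth are marked with 0, as the mask below expects.
depth = 3.0 * depth_quantiles
depth[torch.rand(num_rays) < 0.1] = 0.0

# Penalize the gap between the two quantile depths on valid rays only,
# with a weight that reaches full strength halfway through training.
valid_depth_mask = (depth > 0).all(dim=-1)
quant_loss = (depth[..., 0] - depth[..., 1]).abs()
quant_loss = (quant_loss * valid_depth_mask).mean()
w_depth = quantile_weight * min(2 * i / iterations, 1)
print(f"weighted quantile loss: {w_depth * quant_loss:.5f}")
```

Shrinking the distance between two random depth quantiles pushes each ray's opacity distribution toward a sharp, surface-like transition, which is why the weight is ramped in only after the geometry has had time to settle.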
/viewer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import configargparse
4 | import warnings
5 | 
6 | warnings.filterwarnings("ignore")
7 | 
8 | import torch
9 | 
10 | from data_loader import DataHandler
11 | from configs import *
12 | from radfoam_model.scene import RadFoamScene
13 | 
14 | 
15 | seed = 42
16 | torch.random.manual_seed(seed)
17 | np.random.seed(seed)
18 | 
19 | 
20 | def viewer(args, pipeline_args, model_args, optimizer_args, dataset_args):
21 |     checkpoint = args.config.replace("/config.yaml", "")
22 |     device = torch.device(args.device)
23 | 
24 |     test_data_handler = DataHandler(
25 |         dataset_args, rays_per_batch=0, device=device
26 |     )
27 |     test_data_handler.reload(split="test", downsample=min(dataset_args.downsample))
28 | 
29 |     # Define viewer settings
30 |     viewer_options = {
31 |         "camera_pos": test_data_handler.viewer_pos,
32 |         "camera_up": test_data_handler.viewer_up,
33 |         "camera_forward": test_data_handler.viewer_forward,
34 |     }
35 | 
36 |     # Setting up model
37 |     model = RadFoamScene(
38 |         args=model_args, device=device, attr_dtype=torch.float16
39 |     )
40 | 
41 |     model.load_pt(f"{checkpoint}/model.pt")
42 | 
43 |     def viewer_init(viewer):
44 |         model.update_viewer(viewer)
45 | 
46 |     model.show(viewer_init, **viewer_options)
47 | 
48 | 
49 | def main():
50 |     parser = configargparse.ArgParser()
51 | 
52 |     model_params = ModelParams(parser)
53 |     dataset_params = DatasetParams(parser)
54 |     pipeline_params = PipelineParams(parser)
55 |     optimization_params = OptimizationParams(parser)
56 | 
57 |     # Add argument to specify a custom config file
58 |     parser.add_argument(
59 |         "-c", "--config", is_config_file=True, help="Path to config file"
60 |     )
61 | 
62 |     # Parse arguments
63 |     args = parser.parse_args()
64 | 
65 |     viewer(
66 |         args,
67 |         pipeline_params.extract(args),
68 |         model_params.extract(args),
69 |         optimization_params.extract(args),
70 |         dataset_params.extract(args),
71 |     )
72 | 
73 | 
74 | if __name__ == "__main__":
75 |     main()
76 | 
--------------------------------------------------------------------------------
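Both scripts render against a white background by treating the model's RGBA output as opacity-premultiplied color, so compositing reduces to adding (1 - alpha) per channel. A self-contained sketch follows (not part of the repository); the tensors are fabricated, the premultiplied-alpha reading is an inference from the compositing line in test_render, and psnr() is written out as the standard definition (radfoam_model.utils.psnr is assumed to match).

```python
# Sketch of the white-background compositing + PSNR used by test_render.
import torch

def psnr(img, ref):
    # Standard PSNR for images in [0, 1] (assumed to match the repo's psnr).
    mse = ((img - ref) ** 2).mean()
    return -10 * torch.log10(mse)

h, w = 64, 64
output = torch.rand(h * w, 4)    # fabricated per-ray RGBA from the model
rgb_batch = torch.rand(h, w, 3)  # fabricated ground-truth image

# RGB is treated as premultiplied by opacity, so compositing over a white
# background adds the remaining (1 - alpha) in every channel.
opacity = output[..., -1:]
rgb_output = output[..., :3] + (1 - opacity)
rgb_output = rgb_output.reshape(*rgb_batch.shape).clip(0, 1)

print(f"PSNR: {psnr(rgb_output, rgb_batch):.3f} dB")
```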