├── .gitmodules ├── LICENSE.md ├── LICENSE_BBSplat.txt ├── README.md ├── arguments └── __init__.py ├── assets ├── alpha_init_gaussian.png ├── alpha_init_gaussian_small.png ├── control_panel.png └── readme_images │ ├── blender_preset.jpg │ ├── scull.gif │ ├── teaser.png │ ├── train.gif │ └── visualizer.png ├── bbsplat_install.sh ├── convert.py ├── docker ├── Dockerfile ├── build.sh ├── environment.yml ├── push.sh ├── run.sh └── source.sh ├── docker_colmap ├── run.sh └── source.sh ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── average_error.py ├── colmap_all.sh ├── dtu_eval.py ├── eval_dtu │ ├── eval.py │ ├── evaluate_single_scene.py │ └── render_utils.py ├── metrics_all.sh ├── render_all.sh └── train_all.sh ├── train.py ├── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── mcube_utils.py ├── mesh_utils.py ├── point_utils.py ├── reconstruction_utils.py ├── render_utils.py ├── sh_utils.py └── system_utils.py └── visualize.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-bbsplat-rasterization"] 5 | path = submodules/diff-bbsplat-rasterization 6 | url = https://github.com/david-svitov/diff-bbsplat-rasterization.git 7 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 
38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /LICENSE_BBSplat.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, David Svitov 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BillBoard Splatting (BBSplat): Learnable Textured Primitives for Novel View Synthesis 2 | 3 | [Project page](https://david-svitov.github.io/BBSplat_project_page/) | [Paper](https://arxiv.org/pdf/2411.08508) | [Video](https://youtu.be/ZnIOZHBJ4wM) | [BBSplat Rasterizer (CUDA)](https://github.com/david-svitov/diff-bbsplat-rasterization/) | [Scenes example (1.5GB)](https://drive.google.com/file/d/1gu_bDFXx38KJtwIrXo8lMVtuY-P2PFXX/view?usp=sharing) |
4 |
5 | ![Teaser image](assets/readme_images/teaser.png)
6 |
7 | ## Abstract
8 | We present BillBoard Splatting (BBSplat) - a novel approach for novel view synthesis based on textured geometric primitives.
9 | BBSplat represents the scene as a set of optimizable textured planar primitives with learnable RGB textures and alpha-maps to
10 | control their shape. BBSplat primitives can be used in any Gaussian Splatting pipeline as drop-in replacements for Gaussians.
11 | The proposed primitives close the rendering quality gap between 2D and 3D Gaussian Splatting (GS), enabling the accurate extraction
12 | of 3D meshes as in the 2DGS framework. Additionally, the explicit nature of planar primitives enables the use of ray-tracing effects in rasterization.
13 | Our novel regularization term encourages textures to have a sparser structure, enabling an efficient compression that leads to a reduction in the storage
14 | space of the model by up to $\times17$ compared to 3DGS. Our experiments show the efficiency of BBSplat on standard datasets of real indoor and outdoor
15 | scenes such as Tanks\&Temples, DTU, and Mip-NeRF-360. Namely, we achieve a state-of-the-art PSNR of 29.72 for DTU at Full HD resolution.
16 |
17 | ## Updates
18 |
19 | * 10/02/2025 - We fixed a bug in the FPS measurement function and updated the preprint accordingly.
20 | * 13/03/2025 - We released the mesh extraction code.
21 |
22 | ## Repository structure
23 |
24 | Here, we briefly describe the key elements of the project. All main Python scripts are in the ```./``` directory,
25 | bash scripts to reproduce the experiments are in the ```scripts``` folder, and for a quick start please use the
26 | Docker images provided in the ```docker``` folder.
27 |
28 | ```bash
29 | .
30 | ├── scripts # Bash scripts to process datasets
31 | │ ├── colmap_all.sh # > Extract point clouds with COLMAP
32 | │ ├── dtu_eval.py # Script to run DTU Chamfer distance evaluation
33 | │ ├── train_all.sh # > Fit all scenes
34 | │ ├── render_all.sh # > Render all scenes
35 | │ └── metrics_all.sh # > Calculate metrics for all scenes
36 | ├── submodules
37 | │ ├── diff-bbsplat-rasterization # CUDA implementation of the BBSplat rasterizer
38 | │ └── simple-knn # CUDA implementation of KNN
39 | ├── docker # Scripts to build and run the Docker image
40 | ├── docker_colmap # Scripts to download and run the Docker image for COLMAP
41 | ├── bbsplat_install.sh # Build and install submodules
42 | ├── convert.py # Extract point cloud with COLMAP
43 | ├── train.py # Train BBSplat scene representation
44 | ├── render.py # Novel view synthesis
45 | ├── metrics.py # Metrics calculation
46 | └── visualize.py # Interactive scene visualizer
47 | ```
48 |
49 |
50 | ## Installation
51 |
52 | We provide a Docker image for quick and easy installation. Please follow these steps:
53 |
54 | ```bash
55 | # Download
56 | git clone https://github.com/david-svitov/BBSplat.git --recursive
57 | # Go to the "docker" subfolder
58 | cd BBSplat/docker
59 |
60 | # Build Docker image
61 | bash build.sh
62 | # Optionally adjust mounting folder paths in source.sh
63 | # Run Docker container
64 | bash run.sh
65 |
66 | # Inside the container, install the submodules
67 | bash bbsplat_install.sh
68 | ```
69 |
70 |
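Optionally, you can sanity-check the build before training. This is a minimal sketch, assuming the container's Python environment is active; the module name matches the import used in ```gaussian_renderer/__init__.py```:

```bash
# Optional check inside the container: the CUDA rasterizer extension should import
# cleanly and CUDA should be visible to PyTorch.
python -c "import torch, diff_bbsplat_rasterization; print('CUDA available:', torch.cuda.is_available())"
```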
71 | ### Docker container for COLMAP
72 |
73 | To run COLMAP you can also use the Docker image provided in ```docker_colmap```, as follows:
74 |
75 | ```bash
76 | cd BBSplat/docker_colmap
77 | # Optionally adjust mounting folder paths in source.sh
78 | # Run Docker container
79 | bash run.sh
80 |
81 | # You have to install OpenCV inside this container because we use the "jsantisi/colmap-gpu" image
82 | add-apt-repository universe
83 | apt-get update
84 | apt install python3-pip
85 | python3 -m pip install opencv-python
86 | ```
87 |
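Inside this container you can then run the point-cloud extraction described in the next section. A minimal sketch of a single-scene invocation, with the scene path as a placeholder:

```bash
# Sketch: run COLMAP feature extraction, matching, mapping, and undistortion.
# <path to scene>/input must contain the source images.
python3 convert.py -s <path to scene>
# Optionally add --resize to also produce the downsampled images_2/4/8 folders (requires ImageMagick).
```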
88 |
89 | ## Data preprocessing
90 |
91 | An example of using ```convert.py``` can be found in ```scripts/colmap_all.sh```.
92 | Please note that for different datasets in the paper we used different ```images_N``` folders from the COLMAP output folder.
93 | The instructions on how to install COLMAP can be found above.
94 |
95 | We use the same COLMAP loader as 3DGS and 2DGS; you can find a detailed description of it [here](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes).
96 |
97 |
98 | ## Training
99 | To train a scene, please use the following command:
100 | ```bash
101 | python train.py -s <path to COLMAP dataset> --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval
102 | ```
103 | Command-line arguments description:
104 | ```bash
105 | --cap_max # maximum number of Billboards
106 | --max_read_points # maximum number of SfM points for initialization
107 | --add_sky_box # flag to create additional points for far objects
108 | --eval # hold out every N-th image for evaluation
109 |
110 | # 2DGS normal-depth regularization can be beneficial for some datasets
111 | --lambda_normal # hyperparameter for normal consistency
112 | --lambda_dist # hyperparameter for depth distortion
113 | ```
114 |
115 | Examples of training commands for different datasets can be found in ```scripts/train_all.sh```.
116 |
117 | ## Testing
118 | ### Novel view synthesis evaluation
119 | For novel view synthesis use:
120 | ```bash
121 | python render.py -m <path to trained model> -s <path to COLMAP dataset>
122 | ```
123 |
124 | Command-line arguments description:
125 | ```bash
126 | --skip_mesh # flag to disable mesh extraction to accelerate NVS evaluation
127 | --save_planes # flag to save BBSplat as a set of textured planes
128 | ```
129 |
130 | To calculate metrics values use:
131 | ```bash
132 | python metrics.py -m <path to trained model>
133 | ```
134 | Examples for the datasets used in the paper can be found in ```scripts/render_all.sh``` and ```scripts/metrics_all.sh```.
135 |
136 | ---
137 | ❗ **Faster inference**
138 |
139 | There is an option to accelerate inference by using tighter bounding boxes in the rasterization. To do this, follow these steps:
140 | * Open ```submodules/diff-bbsplat-rasterization/cuda_rasterizer/auxiliary.h```
141 | * Modify ```#define FAST_INFERENCE 0``` to be ```#define FAST_INFERENCE 1```
142 | * Rebuild the code with ```./bbsplat_install.sh```
143 |
144 | This will give you up to $\times 2$ acceleration at the cost of a slight degradation in metrics.
145 |
146 | ---
147 |
148 | ### DTU Chamfer distance evaluation
149 |
150 | To calculate the Chamfer distance metric for the DTU dataset, simply run ```scripts/dtu_eval.py``` as follows:
151 | ```bash
152 | python scripts/dtu_eval.py --dtu=<path to DTU dataset> --output_path=<path for results> --DTU_Official=<path to official DTU evaluation data>
153 | ```
154 |
155 | ## Exporting to Blender
156 |
157 | The newest feature of the code is the conversion of BBSplat into a set of textured planes for rasterization in Blender:
158 |
159 |

160 | ![Skull scene](assets/readme_images/scull.gif)
161 |
162 | ![Train scene](assets/readme_images/train.gif)

163 |
164 | To do this, follow these instructions. First, you have to enable StopThePop sorting of billboards:
165 | * Open ```submodules/diff-bbsplat-rasterization/cuda_rasterizer/auxiliary.h```
166 | * Switch ```#define TILE_SORTING 0``` to ```#define TILE_SORTING 1```
167 | * Switch ```#define PIXEL_RESORTING 0``` to ```#define PIXEL_RESORTING 1```
168 | * Make sure that ```FAST_INFERENCE``` is set to 0, i.e. ```#define FAST_INFERENCE 0```
169 | * Rebuild the code with ```./bbsplat_install.sh```
170 |
171 | Next, simply run ```render.py``` with the ```--save_planes``` flag. In the output folder you will find ```planes_mesh.obj```. Import it into Blender.
172 |
173 | As the final step, use the alpha textures for the alpha channel in the Blender shader settings. Enable Raytracing and adjust the number of samples for the EEVEE renderer:
174 |
175 | ![Blender preset](assets/readme_images/blender_preset.jpg)
176 | ## Interactive visualization
177 |
178 | ![Visualizer](assets/readme_images/visualizer.png)
179 |
180 | To dynamically control the camera position, use ```visualize.py``` with the same ```-m``` and ```-s``` parameters as ```render.py```.
181 |
182 | ## Citation
183 | If you find our code or paper helpful, please consider citing:
184 | ```bibtex
185 | @article{svitov2024billboard,
186 |   title={BillBoard Splatting (BBSplat): Learnable Textured Primitives for Novel View Synthesis},
187 |   author={Svitov, David and Morerio, Pietro and Agapito, Lourdes and Del Bue, Alessio},
188 |   journal={arXiv preprint arXiv:2411.08508},
189 |   year={2024}
190 | }
191 | ```
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = False 55 | self.data_device = "cuda" 56 | self.eval = False 57 | super().__init__(parser, "Loading Parameters", sentinel) 58 | 59 | def extract(self, args): 60 | g = super().extract(args) 61 | g.source_path = os.path.abspath(g.source_path) 62 | return g 63 | 64 | class PipelineParams(ParamGroup): 65 | def __init__(self, parser): 66 | self.convert_SHs_python = False 67 | self.compute_cov3D_python = False 68 | self.depth_ratio = 0.0 69 | self.debug = False 70 | super().__init__(parser, "Pipeline Parameters") 71 | 72 | class OptimizationParams(ParamGroup): 73 | def __init__(self, parser): 74 | self.iterations = 32_000 75 | self.position_lr_init = 0.00016 76 | self.position_lr_final = 0.0000016 77 | self.position_lr_delay_mult = 0.01 78 | self.position_lr_max_steps = 30_000 79 | self.feature_lr = 0.005 80 | self.scaling_lr = 0.005 81 | self.rotation_lr = 0.001 82 | self.texture_opacity_lr = 0.001 83 | self.texture_color_lr = 0.0025 84 | self.percent_dense = 0.1 85 | self.lambda_dssim = 0.2 86 | self.lambda_dist = 0.0 87 | self.lambda_normal = 0.0 88 | self.lambda_texture_value = 0.0001 89 | self.lambda_alpha_value = 0.0001 90 | self.max_impact_threshold = 100 91 | self.sphere_point = 10000 92 | 93 | # Densification policy 94 | self.texture_from_iter = 500 95 | self.texture_to_iter = 30000 96 | self.densification_interval = 100 97 | self.densify_from_iter = 500 98 | self.densify_until_iter = 25000 99 | self.dead_opacity = 0.005 100 | 101 | # MCMC 102 | self.noise_lr = 5e5 103 | self.opacity_reg = 0.01 104 | self.cap_max = 160_000 105 | 106 | # Data 107 | self.max_read_points = self.cap_max - 20_000 108 | self.add_sky_box = False 109 | 110 | super().__init__(parser, "Optimization Parameters") 111 | 112 | def get_combined_args(parser : ArgumentParser): 113 | cmdlne_string = sys.argv[1:] 114 | cfgfile_string = "Namespace()" 115 | args_cmdline = parser.parse_args(cmdlne_string) 116 | 117 | try: 118 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 119 | print("Looking for config file in", cfgfilepath) 120 | with open(cfgfilepath) as cfg_file: 121 | print("Config file found: {}".format(cfgfilepath)) 122 
| cfgfile_string = cfg_file.read() 123 | except TypeError: 124 | print("Config file not found at") 125 | pass 126 | args_cfgfile = eval(cfgfile_string) 127 | 128 | merged_dict = vars(args_cfgfile).copy() 129 | for k,v in vars(args_cmdline).items(): 130 | if v != None: 131 | merged_dict[k] = v 132 | return Namespace(**merged_dict) 133 | -------------------------------------------------------------------------------- /assets/alpha_init_gaussian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/alpha_init_gaussian.png -------------------------------------------------------------------------------- /assets/alpha_init_gaussian_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/alpha_init_gaussian_small.png -------------------------------------------------------------------------------- /assets/control_panel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/control_panel.png -------------------------------------------------------------------------------- /assets/readme_images/blender_preset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/blender_preset.jpg -------------------------------------------------------------------------------- /assets/readme_images/scull.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/scull.gif -------------------------------------------------------------------------------- /assets/readme_images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/teaser.png -------------------------------------------------------------------------------- /assets/readme_images/train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/train.gif -------------------------------------------------------------------------------- /assets/readme_images/visualizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/visualizer.png -------------------------------------------------------------------------------- /bbsplat_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install ./submodules/diff-bbsplat-rasterization 4 | pip install ./submodules/simple-knn 5 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | input_folder = "/input" 32 | if not args.skip_matching: 33 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 34 | 35 | ## Feature extraction 36 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 37 | "--database_path " + args.source_path + "/distorted/database.db \ 38 | --image_path " + args.source_path + input_folder + " \ 39 | --ImageReader.single_camera 1 \ 40 | --ImageReader.camera_model " + args.camera + " \ 41 | --SiftExtraction.use_gpu " + str(use_gpu) 42 | exit_code = os.system(feat_extracton_cmd) 43 | if exit_code != 0: 44 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 45 | exit(exit_code) 46 | 47 | ## Feature matching 48 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 49 | --database_path " + args.source_path + "/distorted/database.db \ 50 | --SiftMatching.use_gpu " + str(use_gpu) 51 | exit_code = os.system(feat_matching_cmd) 52 | if exit_code != 0: 53 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 54 | exit(exit_code) 55 | 56 | ### Bundle adjustment 57 | # The default Mapper tolerance is unnecessarily large, 58 | # decreasing it speeds up bundle adjustment steps. 59 | mapper_cmd = (colmap_command + " mapper \ 60 | --database_path " + args.source_path + "/distorted/database.db \ 61 | --image_path " + args.source_path + input_folder + " \ 62 | --output_path " + args.source_path + "/distorted/sparse \ 63 | --Mapper.ba_global_function_tolerance=0.000001") 64 | exit_code = os.system(mapper_cmd) 65 | if exit_code != 0: 66 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 67 | exit(exit_code) 68 | 69 | ### Image undistortion 70 | ## We need to undistort our images into ideal pinhole intrinsics. 71 | img_undist_cmd = (colmap_command + " image_undistorter \ 72 | --image_path " + args.source_path + input_folder + " \ 73 | --input_path " + args.source_path + "/distorted/sparse/0 \ 74 | --output_path " + args.source_path + "\ 75 | --output_type COLMAP") 76 | exit_code = os.system(img_undist_cmd) 77 | if exit_code != 0: 78 | logging.error(f"Mapper failed with code {exit_code}. 
Exiting.") 79 | exit(exit_code) 80 | 81 | files = os.listdir(args.source_path + "/sparse") 82 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 83 | # Copy each file from the source directory to the destination directory 84 | for file in files: 85 | if file == '0': 86 | continue 87 | source_file = os.path.join(args.source_path, "sparse", file) 88 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 89 | shutil.move(source_file, destination_file) 90 | 91 | if(args.resize): 92 | print("Copying and resizing...") 93 | 94 | # Resize images. 95 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 97 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 98 | # Get the list of files in the source directory 99 | files = os.listdir(args.source_path + "/images") 100 | # Copy each file from the source directory to the destination directory 101 | for file in files: 102 | source_file = os.path.join(args.source_path, "images", file) 103 | 104 | destination_file = os.path.join(args.source_path, "images_2", file) 105 | shutil.copy2(source_file, destination_file) 106 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 107 | if exit_code != 0: 108 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 109 | exit(exit_code) 110 | 111 | destination_file = os.path.join(args.source_path, "images_4", file) 112 | shutil.copy2(source_file, destination_file) 113 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 114 | if exit_code != 0: 115 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 116 | exit(exit_code) 117 | 118 | destination_file = os.path.join(args.source_path, "images_8", file) 119 | shutil.copy2(source_file, destination_file) 120 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 121 | if exit_code != 0: 122 | logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") 123 | exit(exit_code) 124 | 125 | print("Done.") 126 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 2 | 3 | ENV TZ=Europe/Rome 4 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 5 | 6 | SHELL ["/bin/bash", "--login", "-c"] 7 | 8 | RUN apt-get update && apt-get install -y \ 9 | wget \ 10 | htop \ 11 | git \ 12 | nano \ 13 | cmake \ 14 | unzip \ 15 | zip \ 16 | vim \ 17 | libglu1-mesa-dev freeglut3-dev mesa-common-dev \ 18 | libopencv-dev \ 19 | libglew-dev \ 20 | assimp-utils libassimp-dev \ 21 | libboost-all-dev \ 22 | libglfw3-dev \ 23 | libgtk-3-dev \ 24 | ffmpeg libavcodec-dev libavdevice-dev libavfilter-dev libavformat-dev libavutil-dev \ 25 | libeigen3-dev \ 26 | libgl1-mesa-dev xorg-dev \ 27 | libembree-dev 28 | 29 | RUN ln -s /lib/x86_64-linux-gnu/libembree3.so /lib/x86_64-linux-gnu/libembree.so 30 | 31 | ENV PYTHONDONTWRITEBYTECODE=1 32 | ENV PYTHONUNBUFFERED=1 33 | 34 | ENV LD_LIBRARY_PATH /usr/lib64:$LD_LIBRARY_PATH 35 | 36 | ENV NVIDIA_VISIBLE_DEVICES all 37 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics 38 | 39 | ENV PYOPENGL_PLATFORM egl 40 | 41 | #RUN ls /usr/share/glvnd/egl_vendor.d/ 42 | #COPY docker/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json 43 | 44 | # fixuid 45 | ARG USERNAME=user 46 | RUN apt-get update && apt-get install -y sudo curl && \ 47 | addgroup --gid 1000 $USERNAME && \ 48 | adduser --uid 1000 --gid 1000 --disabled-password --gecos '' $USERNAME && \ 49 | adduser $USERNAME sudo && \ 50 | echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \ 51 | USER=$USERNAME && \ 52 | GROUP=$USERNAME && \ 53 | curl -SsL https://github.com/boxboat/fixuid/releases/download/v0.4/fixuid-0.4-linux-amd64.tar.gz | tar -C /usr/local/bin -xzf - && \ 54 | chown root:root /usr/local/bin/fixuid && \ 55 | chmod 4755 /usr/local/bin/fixuid && \ 56 | mkdir -p /etc/fixuid && \ 57 | printf "user: $USER\ngroup: $GROUP\n" > /etc/fixuid/config.yml 58 | USER $USERNAME:$USERNAME 59 | 60 | # miniforge 61 | WORKDIR /home/$USERNAME 62 | ENV CONDA_AUTO_UPDATE_CONDA=false 63 | ENV PATH=/home/$USERNAME/miniforge/bin:$PATH 64 | 65 | RUN wget --quiet https://github.com/conda-forge/miniforge/releases/download/24.11.3-2/Miniforge3-24.11.3-2-Linux-x86_64.sh -O ~/miniforge.sh && \ 66 | chmod +x ~/miniforge.sh && \ 67 | ~/miniforge.sh -b -p ~/miniforge 68 | 69 | #RUN echo 112 70 | COPY docker/environment.yml /home/$USERNAME/environment.yml 71 | RUN conda env create -f /home/$USERNAME/environment.yml 72 | ENV PATH=/home/$USERNAME/miniforge/envs/bbsplat/bin:$PATH 73 | 74 | RUN echo "source activate bbsplat" > ~/.bashrc 75 | ENV PATH /opt/conda/envs/bbsplat/bin:$PATH 76 | 77 | # python libs 78 | RUN pip install --upgrade pip 79 | 80 | 81 | # docker setup 82 | WORKDIR / 83 | ENTRYPOINT ["fixuid", "-q"] 84 | CMD ["fixuid", "-q", "bash"] 85 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | source ${CURRENT_DIR}/source.sh 5 | 6 | DOCKER_BUILDKIT=0 docker build -t $NAME --build-arg ssh_prv_key="$(cat ~/.ssh/id_rsa)" --build-arg ssh_pub_key="$(cat ~/.ssh/id_rsa.pub)" -f ${CURRENT_DIR}/Dockerfile ${CURRENT_DIR}/.. 
7 | -------------------------------------------------------------------------------- /docker/environment.yml: -------------------------------------------------------------------------------- 1 | name: bbsplat 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - nvidia 6 | - defaults 7 | - open3d-admin 8 | - anaconda 9 | dependencies: 10 | - pip=23.3.1 11 | - ffmpeg=6.1.1=h4c62175_0 12 | - jpeg=9e=h5eee18b_3 13 | - ncurses=6.4=h6a678d5_0 14 | - networkx=3.1=py38h06a4308_0 15 | - numpy=1.24.3=py38h14f4228_0 16 | - numpy-base=1.24.3=py38h31eccc5_0 17 | - openh264=2.1.1=h4ff587b_0 18 | - openjpeg=2.5.2=he7f1fd0_0 19 | - pillow=10.2.0=py38h5eee18b_0 20 | - plyfile=1.0.3=pyhd8ed1ab_0 21 | - python=3.8.18=h955ad1f_0 22 | - pytorch=2.0.0=py3.8_cuda11.8_cudnn8.7.0_0 23 | - pytorch-cuda=11.8=h7e8668a_5 24 | - torchaudio=2.0.0=py38_cu118 25 | - torchtriton=2.0.0=py38 26 | - torchvision=0.15.0=py38_cu118 27 | - typing_extensions=4.9.0=py38h06a4308_1 28 | - open3d=0.11.2 29 | - scikit-learn=1.3.0 30 | - addict=2.4.0 31 | - pandas=2.0.3 32 | - ninja=1.12.1 33 | - pip: 34 | - mediapy==1.1.2 35 | - opencv-python==4.9.0.80 36 | - scikit-image==0.21.0 37 | - tqdm==4.66.2 38 | - trimesh 39 | - xatlas 40 | - git+https://github.com/facebookresearch/pytorch3d.git 41 | - git+https://github.com/NVlabs/nvdiffrast.git 42 | -------------------------------------------------------------------------------- /docker/push.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | source ${CURRENT_DIR}/source.sh 5 | 6 | docker tag $NAME $HEAD_NAME 7 | docker push $HEAD_NAME -------------------------------------------------------------------------------- /docker/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | source ${CURRENT_DIR}/source.sh 5 | 6 | docker run -e DISPLAY=unix$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -ti --gpus all $VOLUMES $NAME $@ 7 | 8 | -------------------------------------------------------------------------------- /docker/source.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | NAME="bbsplat" 4 | VOLUMES="-v ./..:/home/bbsplat -v /media/dsvitov/DATA:/media/dsvitov/DATA -w /home/bbsplat" 5 | -------------------------------------------------------------------------------- /docker_colmap/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | source ${CURRENT_DIR}/source.sh 5 | 6 | docker run -e DISPLAY=unix$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -ti --gpus all $VOLUMES $NAME $@ 7 | 8 | -------------------------------------------------------------------------------- /docker_colmap/source.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | NAME="jsantisi/colmap-gpu" 4 | VOLUMES="-v /home/dsvitov/Code/textured-splatting:/home/dsvitov/Code/textured-splatting -v /home/dsvitov/Datasets:/home/dsvitov/Datasets -v /media/dsvitov/DATA1:/media/dsvitov/DATA -w /home/dsvitov/Code/textured-splatting" 5 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_bbsplat_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | from utils.point_utils import depth_to_normal 18 | 19 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, additional_return=True): 20 | """ 21 | Render the scene. 22 | 23 | Background tensor (bg_color) must be on GPU! 24 | """ 25 | 26 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 27 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 28 | try: 29 | screenspace_points.retain_grad() 30 | except: 31 | pass 32 | 33 | # Set up rasterization configuration 34 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 35 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 36 | 37 | raster_settings = GaussianRasterizationSettings( 38 | image_height=int(viewpoint_camera.image_height), 39 | image_width=int(viewpoint_camera.image_width), 40 | tanfovx=tanfovx, 41 | tanfovy=tanfovy, 42 | bg=bg_color, 43 | scale_modifier=scaling_modifier, 44 | viewmatrix=viewpoint_camera.world_view_transform, 45 | projmatrix=viewpoint_camera.full_proj_transform, 46 | sh_degree=pc.active_sh_degree, 47 | campos=viewpoint_camera.camera_center, 48 | prefiltered=False, 49 | debug=False, 50 | # pipe.debug 51 | ) 52 | 53 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 54 | 55 | means3D = pc.get_xyz 56 | means2D = screenspace_points 57 | 58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 59 | # scaling / rotation by the rasterizer. 60 | scales = None 61 | rotations = None 62 | cov3D_precomp = None 63 | if pipe.compute_cov3D_python: 64 | cov3D_precomp = pc.get_covariance(scaling_modifier) 65 | else: 66 | scales = pc.get_scaling 67 | rotations = pc.get_rotation 68 | 69 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 70 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
71 | pipe.convert_SHs_python = False 72 | shs = None 73 | colors_precomp = None 74 | if override_color is None: 75 | if pipe.convert_SHs_python: 76 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 77 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 78 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 79 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 80 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 81 | else: 82 | shs = pc.get_features 83 | else: 84 | colors_precomp = override_color 85 | 86 | try: 87 | means3D.retain_grad() 88 | except: 89 | pass 90 | 91 | texture_alpha = pc.get_texture_alpha 92 | texture_color = pc.get_texture_color 93 | 94 | start_timer = torch.cuda.Event(enable_timing=True) 95 | end_timer = torch.cuda.Event(enable_timing=True) 96 | start_timer.record() 97 | 98 | rendered_image, radii, impact, allmap = rasterizer( 99 | means3D = means3D, 100 | means2D = means2D, 101 | shs = shs, 102 | colors_precomp = colors_precomp, 103 | texture_alpha = texture_alpha, 104 | texture_color = texture_color, 105 | scales = scales, 106 | rotations = rotations, 107 | cov3D_precomp = cov3D_precomp, 108 | ) 109 | 110 | end_timer.record() 111 | torch.cuda.synchronize() 112 | start_timer.elapsed_time(end_timer) 113 | fps = 1000 / start_timer.elapsed_time(end_timer) 114 | 115 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 116 | # They will be excluded from value updates used in the splitting criteria. 117 | rets = {"render": rendered_image, 118 | "viewspace_points": means2D, 119 | "visibility_filter" : impact > 0, 120 | "radii": radii, 121 | "impact": impact, 122 | "fps": fps, 123 | } 124 | 125 | if additional_return: 126 | # additional regularizations 127 | render_alpha = allmap[1:2] 128 | 129 | # get normal map 130 | render_normal = allmap[2:5] 131 | render_normal = (render_normal.permute(1,2,0) @ (viewpoint_camera.world_view_transform[:3,:3].T)).permute(2,0,1) 132 | 133 | # get median depth map 134 | render_depth_median = allmap[5:6] 135 | render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) 136 | 137 | # get expected depth map 138 | render_depth_expected = allmap[0:1] 139 | render_depth_expected = (render_depth_expected / render_alpha) 140 | render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) 141 | 142 | # get depth distortion map 143 | render_dist = allmap[6:7] 144 | 145 | # psedo surface attributes 146 | # surf depth is either median or expected by setting depth_ratio to 1 or 0 147 | # for bounded scene, use median depth, i.e., depth_ratio = 1; 148 | # for unbounded scene, use expected depth, i.e., depth_ration = 0, to reduce disk anliasing. 149 | surf_depth = render_depth_expected * (1-pipe.depth_ratio) + (pipe.depth_ratio) * render_depth_median 150 | 151 | # assume the depth points form the 'surface' and generate psudo surface normal for regularizations. 152 | surf_normal = depth_to_normal(viewpoint_camera, surf_depth) 153 | surf_normal = surf_normal.permute(2,0,1) 154 | # remember to multiply with accum_alpha since render_normal is unnormalized. 
155 | surf_normal = surf_normal * (render_alpha).detach() 156 | 157 | 158 | rets.update({ 159 | 'rend_alpha': render_alpha, 160 | 'rend_normal': render_normal, 161 | 'rend_dist': render_dist, 162 | 'surf_depth': surf_depth, 163 | 'surf_normal': surf_normal, 164 | }) 165 | 166 | return rets 167 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 
'alex', 9 | version: str = '0.1', 10 | size_average: bool = True): 11 | r"""Function that measures 12 | Learned Perceptual Image Patch Similarity (LPIPS). 13 | 14 | Arguments: 15 | x, y (torch.Tensor): the input tensors to compare. 16 | net_type (str): the network type to compare the features: 17 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 18 | version (str): the version of LPIPS. Default: 0.1. 19 | """ 20 | device = x.device 21 | criterion = LPIPS(net_type, version).to(device) 22 | return criterion(x, y, size_average) 23 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as nnf 4 | 5 | from .networks import get_network, LinLayers 6 | from .utils import get_state_dict 7 | 8 | 9 | class LPIPS(nn.Module): 10 | r"""Creates a criterion that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | net_type (str): the network type to compare the features: 15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 16 | version (str): the version of LPIPS. Default: 0.1. 17 | """ 18 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 19 | 20 | assert version in ['0.1'], 'v0.1 is only supported now' 21 | 22 | super(LPIPS, self).__init__() 23 | 24 | # pretrained network 25 | self.net = get_network(net_type) 26 | 27 | # linear layers 28 | self.lin = LinLayers(self.net.n_channels_list) 29 | self.lin.load_state_dict(get_state_dict(net_type, version)) 30 | 31 | def forward(self, x: torch.Tensor, y: torch.Tensor, size_average: bool = True): 32 | _, _, H, W = x.shape 33 | feat_x, feat_y = self.net(x), self.net(y) 34 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 35 | 36 | if size_average: 37 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 38 | return torch.sum(torch.cat(res, 0), 0, True) 39 | else: 40 | res = [l(d) for d, l in zip(diff, self.lin)] 41 | res = [nnf.interpolate(f, size=(H, W), mode='bicubic', align_corners=False) for f in res] 42 | return torch.sum(torch.cat(res, 0), 0) 43 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 
| 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | 43 | for scene_dir in model_paths: 44 | try: 45 | print("Scene:", scene_dir) 46 | full_dict[scene_dir] = {} 47 | per_view_dict[scene_dir] = {} 48 | full_dict_polytopeonly[scene_dir] = {} 49 | per_view_dict_polytopeonly[scene_dir] = {} 50 | 51 | test_dir = Path(scene_dir) / "test" 52 | 53 | for method in os.listdir(test_dir): 54 | print("Method:", method) 55 | 56 | full_dict[scene_dir][method] = {} 57 | per_view_dict[scene_dir][method] = {} 58 | full_dict_polytopeonly[scene_dir][method] = {} 59 | per_view_dict_polytopeonly[scene_dir][method] = {} 60 | 61 | method_dir = test_dir / method 62 | gt_dir = method_dir/ "gt" 63 | renders_dir = method_dir / "renders" 64 | renders, gts, image_names = readImages(renders_dir, gt_dir) 65 | 66 | ssims = [] 67 | psnrs = [] 68 | lpipss = [] 69 | 70 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 71 | ssims.append(ssim(renders[idx], gts[idx])) 72 | psnrs.append(psnr(renders[idx], gts[idx])) 73 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 74 | 75 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 76 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 77 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 78 | print("") 79 | 80 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 81 | "PSNR": torch.tensor(psnrs).mean().item(), 82 | "LPIPS": torch.tensor(lpipss).mean().item()}) 83 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 84 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 85 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 86 | 87 | with open(scene_dir + "/results.json", 'w') as fp: 88 | json.dump(full_dict[scene_dir], fp, indent=True) 89 | with open(scene_dir + "/per_view.json", 'w') as fp: 90 | json.dump(per_view_dict[scene_dir], fp, indent=True) 91 | except: 92 | print("Unable to compute metrics for model", scene_dir) 93 | 94 | if __name__ == "__main__": 95 | device = torch.device("cuda:0") 96 | torch.cuda.set_device(device) 97 | 98 | # Set up command line argument parser 99 | parser = ArgumentParser(description="Training script parameters") 100 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 101 | args = parser.parse_args() 102 | evaluate(args.model_paths) 103 | 
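# Note: a quick reference for the on-disk layout that readImages()/evaluate() above
# assume for every scene passed via --model_paths/-m (typically produced by render.py):
#   <model_path>/test/<method>/renders/*.png   (rendered test views)
#   <model_path>/test/<method>/gt/*.png        (matching ground-truth images)
# Results are written to <model_path>/results.json and <model_path>/per_view.json.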
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import json 12 | import math 13 | import os 14 | from argparse import ArgumentParser 15 | 16 | import cv2 17 | import numpy as np 18 | import nvdiffrast.torch as dr 19 | import open3d as o3d 20 | import torch 21 | import torch.nn.functional as F 22 | import xatlas 23 | from pytorch3d.io import save_obj 24 | from tqdm import tqdm 25 | 26 | from arguments import ModelParams, PipelineParams, get_combined_args 27 | from gaussian_renderer import GaussianModel 28 | from gaussian_renderer import render 29 | from scene import Scene 30 | from utils.general_utils import build_scaling_rotation 31 | from utils.loss_utils import l1_loss 32 | from utils.mesh_utils import GaussianExtractor, post_process_mesh 33 | from utils.render_utils import generate_path, create_videos, save_img_u8 34 | from utils.sh_utils import SH2RGB 35 | 36 | 37 | def unwrap_uvmap(mesh, device="cuda"): 38 | v_np = np.asarray(mesh.vertices) # [N, 3] 39 | f_np = np.asarray(mesh.triangles) # [M, 3] 40 | 41 | print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}') 42 | 43 | # unwrap uv in contracted space 44 | atlas = xatlas.Atlas() 45 | atlas.add_mesh(v_np, f_np) 46 | chart_options = xatlas.ChartOptions() 47 | chart_options.max_iterations = 0 # disable merge_chart for faster unwrap... 
48 | pack_options = xatlas.PackOptions() 49 | # pack_options.blockAlign = True 50 | # pack_options.bruteForce = False 51 | atlas.generate(chart_options=chart_options, pack_options=pack_options) 52 | vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] 53 | 54 | vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device) 55 | ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device) 56 | 57 | print("UV shape:", vt.shape) 58 | 59 | v_torch = torch.from_numpy(v_np.astype(np.float32)).to(device) 60 | f_torch = torch.from_numpy(f_np).to(device) 61 | 62 | return v_torch, f_torch, vt, ft 63 | 64 | def render_mesh(v_torch, f_torch, uv, uv_idx, cudactx, texture, cam): 65 | mvp = cam.full_proj_transform 66 | vertices_clip = torch.matmul(F.pad(v_torch, pad=(0, 1), mode='constant', value=1.0), mvp).float().unsqueeze(0) 67 | rast, _ = dr.rasterize(cudactx, vertices_clip, f_torch, resolution=[cam.image_height, cam.image_width]) 68 | texc, _ = dr.interpolate(uv[None, ...], rast, uv_idx) 69 | color = dr.texture(texture[None, ...], texc, filter_mode='linear')[0] 70 | return color 71 | 72 | def train_texture(v_torch, f_torch, uv, uv_idx, cudactx, texture, scene): 73 | optimizer = torch.optim.Adam([texture], lr=0.01) 74 | for epoch in tqdm(range(300)): 75 | for cam in scene.getTrainCameras(): 76 | optimizer.zero_grad() 77 | color = render_mesh(v_torch, f_torch, uv, uv_idx, cudactx, F.sigmoid(texture), cam) 78 | 79 | gt = torch.permute(cam.original_image.cuda(), (1, 2, 0)) 80 | Ll1 = l1_loss(color, gt) 81 | #ssim_map = ssim(color, gt, size_average=False).mean() 82 | loss = Ll1 # * 0.8 + ssim_map * 0.2 83 | loss.backward() 84 | optimizer.step() 85 | 86 | def billboard_to_plane(xyz, transform, rgb, alpha, texture_size, num_textures_x, vertices, faces, stitched_texture, uv, uv_idx): 87 | vertices_local = torch.tensor([[-1, -1, 0], [1, 1, 0], [1, -1, 0], [-1, 1, 0]], dtype=torch.float32).cuda() 88 | faces_local = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int32).cuda() 89 | 90 | # Scaling + Rotation 91 | vertices_local = vertices_local @ transform.T 92 | # Offset 93 | vertices_local += xyz 94 | 95 | # Add to the "mesh" 96 | faces_local += 4 * len(faces) 97 | faces.append(faces_local) 98 | vertices.append(vertices_local) 99 | 100 | # Add tile to the texture 101 | num = len(vertices) - 1 102 | y = num // num_textures_x 103 | x = num % num_textures_x 104 | h, w = alpha.shape 105 | stitched_texture[:3, y*texture_size: y*texture_size + h, x*texture_size: x*texture_size + w] = rgb 106 | stitched_texture[3:, y*texture_size: y*texture_size + h, x*texture_size: x*texture_size + w] = alpha[None] 107 | 108 | u = x*texture_size / stitched_texture.shape[2] 109 | v = y*texture_size / stitched_texture.shape[1] 110 | offset_u = h / stitched_texture.shape[2] 111 | offset_v = w / stitched_texture.shape[1] 112 | uv_local = torch.tensor([[u, v], [u + offset_u, v + offset_v], [u + offset_u, v], [u, v + offset_v]], dtype=torch.float32).cuda() 113 | uv.append(uv_local) 114 | uv_idx.append(faces_local) 115 | 116 | def billboards_to_mesh(gaussians, save_folder): 117 | num_points = len(gaussians.get_xyz) 118 | gaps = 2 119 | texture_size = gaussians.get_texture_alpha.shape[-1] + gaps 120 | num_textures_x = int(math.sqrt(num_points)) 121 | globa_texture_size = num_textures_x * texture_size 122 | global_rgba = torch.zeros([4, globa_texture_size + texture_size*2, globa_texture_size]).cuda() 123 | 124 | transform = build_scaling_rotation(gaussians.get_scaling, gaussians.get_rotation) 125 | 126 | vertices = [] 127 | faces 
= [] 128 | uv = [] 129 | uv_idx = [] 130 | for i in tqdm(range(num_points)): 131 | #if gaussians.get_scaling[i].min() > 1: 132 | # continue 133 | billboard_to_plane( 134 | gaussians.get_xyz[i], transform[i], gaussians.get_texture_color[i] + SH2RGB(gaussians.get_features_first[i])[0, :, None, None], 135 | gaussians.get_texture_alpha[i], texture_size, num_textures_x, 136 | vertices, faces, global_rgba, uv, uv_idx, 137 | ) 138 | vertices = torch.concat(vertices) 139 | faces = torch.concat(faces) 140 | uv = torch.concat(uv) 141 | uv_idx = torch.concat(uv_idx) 142 | 143 | print(vertices.shape, faces.shape) 144 | 145 | global_rgba = torch.permute(global_rgba, (1, 2, 0)) 146 | global_rgba = torch.flip(global_rgba, [0]) 147 | save_obj( 148 | os.path.join(save_folder, "planes_mesh.obj"), 149 | verts=vertices, 150 | faces=faces, 151 | verts_uvs=uv, 152 | faces_uvs=uv_idx, 153 | texture_map=global_rgba[..., :3], 154 | ) 155 | print(global_rgba.shape) 156 | global_rgba = global_rgba.detach().cpu().numpy() 157 | global_rgba[..., :3] = cv2.cvtColor(global_rgba[..., :3], cv2.COLOR_BGR2RGB) 158 | cv2.imwrite(os.path.join(save_folder, "planes_mesh.png"), global_rgba * 255) 159 | 160 | def prune_based_on_visibility(scene, gaussians, pipe, background): 161 | with torch.no_grad(): 162 | # Calculate impact 163 | acc_impact = None 164 | for camera in scene.getTrainCameras(): 165 | render_pkg = render(camera, gaussians, pipe, background) 166 | impact = render_pkg["impact"] 167 | if acc_impact is None: 168 | acc_impact = impact 169 | else: 170 | acc_impact += impact 171 | 172 | prob = acc_impact / acc_impact.sum() 173 | mask = prob > 1e-6 174 | 175 | mask = mask & (torch.amax(gaussians.get_texture_alpha, dim=(1, 2)) > 0.2) 176 | gaussians.prune_postproc(mask) 177 | 178 | if __name__ == "__main__": 179 | # Set up command line argument parser 180 | parser = ArgumentParser(description="Testing script parameters") 181 | model = ModelParams(parser, sentinel=True) 182 | pipeline = PipelineParams(parser) 183 | parser.add_argument("--iteration", default=-1, type=int) 184 | parser.add_argument("--skip_train", action="store_true") 185 | parser.add_argument("--skip_test", action="store_true") 186 | parser.add_argument("--skip_mesh", action="store_true") 187 | parser.add_argument("--save_planes", action="store_true") 188 | parser.add_argument("--quiet", action="store_true") 189 | parser.add_argument("--render_path", action="store_true") 190 | parser.add_argument("--voxel_size", default=0.004, type=float, help='Mesh: voxel size for TSDF') 191 | parser.add_argument("--depth_trunc", default=3.0, type=float, help='Mesh: Max depth range for TSDF') 192 | parser.add_argument("--sdf_trunc", default=-1.0, type=float, help='Mesh: truncation value for TSDF') 193 | parser.add_argument("--num_cluster", default=1000, type=int, help='Mesh: number of connected clusters to export') 194 | parser.add_argument("--unbounded", action="store_true", help='Mesh: using unbounded mode for meshing') 195 | parser.add_argument("--mesh_res", default=1024, type=int, help='Mesh: resolution for unbounded mesh extraction') 196 | args = get_combined_args(parser) 197 | print("Rendering " + args.model_path) 198 | 199 | 200 | dataset, iteration, pipe = model.extract(args), args.iteration, pipeline.extract(args) 201 | gaussians = GaussianModel(dataset.sh_degree, texture_preproc=True) 202 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 203 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 204 | background = 
torch.tensor(bg_color, dtype=torch.float32, device="cuda") 205 | 206 | train_dir = os.path.join(args.model_path, 'train', "ours_{}".format(scene.loaded_iter)) 207 | test_dir = os.path.join(args.model_path, 'test', "ours_{}".format(scene.loaded_iter)) 208 | gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color, additional_return=True) 209 | 210 | speed_data = {"points": len(gaussians.get_xyz)} 211 | 212 | if not args.skip_train: 213 | print("export training images ...") 214 | os.makedirs(train_dir, exist_ok=True) 215 | mean_time, std_time = gaussExtractor.reconstruction(scene.getTrainCameras()) 216 | speed_data["train_time"] = mean_time 217 | speed_data["train_time_std"] = std_time 218 | gaussExtractor.export_image(train_dir) 219 | 220 | 221 | if (not args.skip_test) and (len(scene.getTestCameras()) > 0): 222 | print("export rendered testing images ...") 223 | os.makedirs(test_dir, exist_ok=True) 224 | mean_time, std_time = gaussExtractor.reconstruction(scene.getTestCameras()) 225 | speed_data["test_time"] = mean_time 226 | speed_data["test_time_std"] = std_time 227 | gaussExtractor.export_image(test_dir) 228 | 229 | with open(os.path.join(args.model_path, "speed.json"), "w") as f: 230 | json.dump(speed_data, f) 231 | 232 | if args.render_path: 233 | print("render videos ...") 234 | traj_dir = os.path.join(args.model_path, 'traj', "ours_{}".format(scene.loaded_iter)) 235 | os.makedirs(traj_dir, exist_ok=True) 236 | n_fames = 480 237 | cam_traj = generate_path(scene.getTrainCameras(), n_frames=n_fames) 238 | gaussExtractor.reconstruction(cam_traj) 239 | gaussExtractor.export_image(traj_dir, export_gt=False) #, print_fps=True 240 | create_videos(base_dir=traj_dir, 241 | input_dir=traj_dir, 242 | out_name='render_traj', 243 | num_frames=n_fames) 244 | 245 | if args.save_planes: 246 | # CONVERT TO SET OF PLANES 247 | prune_based_on_visibility(scene, gaussians, pipe, background) 248 | billboards_to_mesh(gaussians, args.model_path) 249 | 250 | if not args.skip_mesh: 251 | print("export mesh ...") 252 | os.makedirs(train_dir, exist_ok=True) 253 | # set the active_sh to 0 to export only diffuse texture 254 | gaussExtractor.gaussians.active_sh_degree = 0 255 | gaussExtractor.reconstruction(scene.getTrainCameras()) 256 | print("ckpt 1 ...") 257 | # extract the mesh and save 258 | if args.unbounded: 259 | name = 'fuse_unbounded.ply' 260 | mesh = gaussExtractor.extract_mesh_unbounded(resolution=args.mesh_res) 261 | else: 262 | name = 'fuse.ply' 263 | #mesh = gaussExtractor.extract_mesh_bounded(voxel_size=args.voxel_size, sdf_trunc=5*args.voxel_size, depth_trunc=args.depth_trunc) 264 | depth_trunc = (gaussExtractor.radius * 2.0) if args.depth_trunc < 0 else args.depth_trunc 265 | voxel_size = (depth_trunc / args.mesh_res) if args.voxel_size < 0 else args.voxel_size 266 | sdf_trunc = 5.0 * voxel_size if args.sdf_trunc < 0 else args.sdf_trunc 267 | mesh = gaussExtractor.extract_mesh_bounded(voxel_size=voxel_size, sdf_trunc=sdf_trunc, depth_trunc=depth_trunc) 268 | 269 | print("ckpt 2 ...") 270 | o3d.io.write_triangle_mesh(os.path.join(train_dir, name), mesh) 271 | print("mesh saved at {}".format(os.path.join(train_dir, name))) 272 | # post-process the mesh and save, saving the largest N clusters 273 | mesh_post = post_process_mesh(mesh, cluster_to_keep=args.num_cluster) 274 | o3d.io.write_triangle_mesh(os.path.join(train_dir, name.replace('.ply', '_post.ply')), mesh_post) 275 | print("mesh post processed saved at {}".format(os.path.join(train_dir, name.replace('.ply', 
'_post.ply')))) 276 | 277 | # TEXTURE EXTRACTION 278 | device = "cuda" 279 | # Unwrap the uv-map for the mesh 280 | v_cuda, f_cuda, uv, uv_idx = unwrap_uvmap(mesh, device) 281 | 282 | texture = 0.5 + torch.randn((1024, 1024, 3), dtype=torch.float32, device=device) * 0.001 283 | texture = torch.nn.Parameter(texture, requires_grad=True) 284 | 285 | cudactx = dr.RasterizeCudaContext() 286 | 287 | # Train texture from input images 288 | train_texture(v_cuda, f_cuda, uv, uv_idx, cudactx, texture, scene) 289 | texture = F.sigmoid(texture) 290 | 291 | # Render textured mesh to the folder 292 | mesh_path = os.path.join(train_dir, "mesh") 293 | os.makedirs(mesh_path, exist_ok=True) 294 | for idx, cam in enumerate(scene.getTrainCameras()): 295 | mvp = cam.full_proj_transform 296 | color = render_mesh(v_cuda, f_cuda, uv, uv_idx, cudactx, texture, cam) 297 | color = torch.permute(color, (2, 0, 1)) 298 | save_img_u8(color, os.path.join(mesh_path, '{0:05d}'.format(idx) + ".png")) 299 | 300 | save_obj( 301 | os.path.join(args.model_path, "textured_mesh.obj"), 302 | verts=v_cuda, 303 | faces=f_cuda, 304 | verts_uvs=uv, 305 | faces_uvs=uv_idx, 306 | texture_map=torch.flip(texture, [0]), 307 | ) 308 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | import numpy as np 21 | import cv2 22 | 23 | class Scene: 24 | 25 | gaussians : GaussianModel 26 | 27 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], 28 | add_sky_box=False, max_read_points=60_000, sphere_point=10_000): 29 | """b 30 | :param path: Path to colmap scene main folder. 
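        :param args: ModelParams providing source_path, model_path, images, eval and white_background used below.
        :param gaussians: GaussianModel that is either loaded from a saved iteration or created from the initial point cloud.
        :param load_iteration: iteration whose saved point cloud and textures should be loaded; -1 selects the latest saved iteration.
        :param shuffle: whether to randomly shuffle the train and test camera lists after loading.
        :param resolution_scales: resolution scales at which the camera lists are built.
        :param add_sky_box: forwarded to GaussianModel.create_from_pcd when initializing from a point cloud.
        :param max_read_points: passed as max_points to the Colmap/Blender loaders to cap the initial point cloud size.
        :param sphere_point: forwarded to create_from_pcd together with add_sky_box.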
31 | """ 32 | self.model_path = args.model_path 33 | self.loaded_iter = None 34 | self.gaussians = gaussians 35 | 36 | if load_iteration: 37 | if load_iteration == -1: 38 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 39 | else: 40 | self.loaded_iter = load_iteration 41 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 42 | 43 | self.train_cameras = {} 44 | self.test_cameras = {} 45 | 46 | if os.path.exists(os.path.join(args.source_path, "sparse")): 47 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, max_points=max_read_points) 48 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 49 | print("Found transforms_train.json file, assuming Blender data set!") 50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, max_points=max_read_points) 51 | elif os.path.exists(os.path.join(args.source_path, "inputs/sfm_scene.json")): 52 | print("Found sfm_scene.json file, assuming NeILF data set!") 53 | scene_info = sceneLoadTypeCallbacks["NeILF"](args.source_path, args.white_background, args.eval) 54 | else: 55 | assert False, "Could not recognize scene type!" 56 | 57 | if not self.loaded_iter: 58 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 59 | dest_file.write(src_file.read()) 60 | json_cams = [] 61 | camlist = [] 62 | if scene_info.test_cameras: 63 | camlist.extend(scene_info.test_cameras) 64 | if scene_info.train_cameras: 65 | camlist.extend(scene_info.train_cameras) 66 | for id, cam in enumerate(camlist): 67 | json_cams.append(camera_to_JSON(id, cam)) 68 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 69 | json.dump(json_cams, file) 70 | 71 | if shuffle: 72 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 73 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 74 | 75 | self.cameras_extent = scene_info.nerf_normalization["radius"] 76 | 77 | for resolution_scale in resolution_scales: 78 | print("Loading Training Cameras") 79 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 80 | print("Loading Test Cameras") 81 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 82 | 83 | if self.loaded_iter: 84 | folder_path = os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter)) 85 | self.gaussians.load_ply(os.path.join(folder_path, "point_cloud.ply")) 86 | self.gaussians.load_texture(folder_path) 87 | else: 88 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, add_sky_box=add_sky_box, sphere_point=sphere_point) 89 | 90 | def save(self, iteration): 91 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 92 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 93 | self.gaussians.save_texture(point_cloud_path) 94 | 95 | def getTrainCameras(self, scale=1.0): 96 | return self.train_cameras[scale] 97 | 98 | def getTestCameras(self, scale=1.0): 99 | return self.test_cameras[scale] 100 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research 
group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | 32 | try: 33 | self.data_device = torch.device(data_device) 34 | except Exception as e: 35 | print(e) 36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 37 | self.data_device = torch.device("cuda") 38 | 39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | 43 | if gt_alpha_mask is not None: 44 | # self.original_image *= gt_alpha_mask.to(self.data_device) 45 | self.gt_alpha_mask = gt_alpha_mask.to(self.data_device) 46 | else: 47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 48 | self.gt_alpha_mask = None 49 | 50 | self.zfar = 100.0 51 | self.znear = 0.01 52 | 53 | self.trans = trans 54 | self.scale = scale 55 | 56 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 57 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 58 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 59 | self.camera_center = self.world_view_transform.inverse()[3, :3] 60 | 61 | def update_proj_matrix(self): 62 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 63 | self.camera_center = self.world_view_transform.inverse()[3, :3] 64 | 65 | class MiniCam: 66 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 67 | self.image_width = width 68 | self.image_height = height 69 | self.FoVy = fovy 70 | self.FoVx = fovx 71 | self.znear = znear 72 | self.zfar = zfar 73 | self.world_view_transform = world_view_transform 74 | self.full_proj_transform = full_proj_transform 75 | view_inv = torch.inverse(self.world_view_transform) 76 | self.camera_center = view_inv[3][:3] 77 | 78 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import json 13 | import os 14 | import sys 15 | from pathlib import Path 16 | from typing import NamedTuple 17 | 18 | import numpy as np 19 | import torch 20 | from PIL import Image 21 | from plyfile import PlyData, PlyElement 22 | from pytorch3d.ops import sample_farthest_points 23 | import imageio.v2 as imageio 24 | import glob 25 | import re 26 | 27 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 28 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 29 | from scene.gaussian_model import BasicPointCloud 30 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 31 | from utils.sh_utils import SH2RGB 32 | 33 | 34 | class CameraInfo(NamedTuple): 35 | uid: int 36 | R: np.array 37 | T: np.array 38 | FovY: np.array 39 | FovX: np.array 40 | image: np.array 41 | image_path: str 42 | image_name: str 43 | width: int 44 | height: int 45 | image_id: int = None 46 | normal: Image.Image = None 47 | alpha: Image.Image = None 48 | depth: np.array = None 49 | 50 | class SceneInfo(NamedTuple): 51 | point_cloud: BasicPointCloud 52 | train_cameras: list 53 | test_cameras: list 54 | nerf_normalization: dict 55 | ply_path: str 56 | 57 | def getNerfppNorm(cam_info): 58 | def get_center_and_diag(cam_centers): 59 | cam_centers = np.hstack(cam_centers) 60 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 61 | center = avg_cam_center 62 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 63 | diagonal = np.max(dist) 64 | return center.flatten(), diagonal 65 | 66 | cam_centers = [] 67 | 68 | for cam in cam_info: 69 | W2C = getWorld2View2(cam.R, cam.T) 70 | C2W = np.linalg.inv(W2C) 71 | cam_centers.append(C2W[:3, 3:4]) 72 | 73 | center, diagonal = get_center_and_diag(cam_centers) 74 | radius = diagonal * 1.1 75 | 76 | translate = -center 77 | 78 | return {"translate": translate, "radius": radius} 79 | 80 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 81 | cam_infos = [] 82 | for idx, key in enumerate(cam_extrinsics): 83 | sys.stdout.write('\r') 84 | # the exact output you're looking for: 85 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 86 | sys.stdout.flush() 87 | 88 | extr = cam_extrinsics[key] 89 | intr = cam_intrinsics[extr.camera_id] 90 | height = intr.height 91 | width = intr.width 92 | 93 | uid = intr.id 94 | R = 
np.transpose(qvec2rotmat(extr.qvec)) 95 | T = np.array(extr.tvec) 96 | 97 | if intr.model=="SIMPLE_PINHOLE": 98 | focal_length_x = intr.params[0] 99 | FovY = focal2fov(focal_length_x, height) 100 | FovX = focal2fov(focal_length_x, width) 101 | elif intr.model=="PINHOLE": 102 | focal_length_x = intr.params[0] 103 | focal_length_y = intr.params[1] 104 | FovY = focal2fov(focal_length_y, height) 105 | FovX = focal2fov(focal_length_x, width) 106 | else: 107 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 108 | 109 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 110 | image_name = os.path.basename(image_path).split(".")[0] 111 | image = Image.open(image_path) 112 | 113 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 114 | image_path=image_path, image_name=image_name, width=width, height=height) 115 | cam_infos.append(cam_info) 116 | sys.stdout.write('\n') 117 | return cam_infos 118 | 119 | def fetchPly(path, max_points=60_000): 120 | plydata = PlyData.read(path) 121 | vertices = plydata['vertex'] 122 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 123 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 124 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 125 | 126 | if len(positions) >= max_points: 127 | #indices = np.random.randint(0, len(positions), size=max_points 128 | _, indices = sample_farthest_points(torch.tensor(positions[None]), K=max_points) 129 | indices = indices[0] 130 | positions = positions[indices] 131 | colors = colors[indices] 132 | normals = normals[indices] 133 | 134 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 135 | 136 | def storePly(path, xyz, rgb, normals=None): 137 | # Define the dtype for the structured array 138 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 139 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 140 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 141 | 142 | if normals is None: 143 | normals = np.zeros_like(xyz) 144 | 145 | elements = np.empty(xyz.shape[0], dtype=dtype) 146 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 147 | elements[:] = list(map(tuple, attributes)) 148 | 149 | # Create the PlyData object and write to file 150 | vertex_element = PlyElement.describe(elements, 'vertex') 151 | ply_data = PlyData([vertex_element]) 152 | ply_data.write(path) 153 | 154 | def readColmapSceneInfo(path, images, eval, llffhold=8, max_points=60_000): 155 | try: 156 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 157 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 158 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 159 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 160 | except: 161 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 162 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 163 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 164 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 165 | 166 | reading_dir = "images" if images == None else images 167 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 168 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 169 | 170 | if eval: 171 | train_cam_infos = [c for idx, c in 
enumerate(cam_infos) if idx % llffhold != 0] 172 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 173 | else: 174 | train_cam_infos = cam_infos 175 | test_cam_infos = [] 176 | 177 | nerf_normalization = getNerfppNorm(train_cam_infos) 178 | 179 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 180 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 181 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 182 | if not os.path.exists(ply_path): 183 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 184 | try: 185 | xyz, rgb, _ = read_points3D_binary(bin_path) 186 | except: 187 | xyz, rgb, _ = read_points3D_text(txt_path) 188 | storePly(ply_path, xyz, rgb) 189 | try: 190 | pcd = fetchPly(ply_path, max_points) 191 | except: 192 | pcd = None 193 | 194 | scene_info = SceneInfo(point_cloud=pcd, 195 | train_cameras=train_cam_infos, 196 | test_cameras=test_cam_infos, 197 | nerf_normalization=nerf_normalization, 198 | ply_path=ply_path) 199 | return scene_info 200 | 201 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 202 | cam_infos = [] 203 | 204 | with open(os.path.join(path, transformsfile)) as json_file: 205 | contents = json.load(json_file) 206 | fovx = contents["camera_angle_x"] 207 | 208 | frames = contents["frames"] 209 | for idx, frame in enumerate(frames): 210 | cam_name = os.path.join(path, frame["file_path"] + extension) 211 | 212 | # NeRF 'transform_matrix' is a camera-to-world transform 213 | c2w = np.array(frame["transform_matrix"]) 214 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 215 | c2w[:3, 1:3] *= -1 216 | 217 | # get the world-to-camera transform and set R, T 218 | w2c = np.linalg.inv(c2w) 219 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 220 | T = w2c[:3, 3] 221 | 222 | image_path = os.path.join(path, cam_name) 223 | image_name = Path(cam_name).stem 224 | image = Image.open(image_path) 225 | 226 | im_data = np.array(image.convert("RGBA")) 227 | 228 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 229 | 230 | norm_data = im_data / 255.0 231 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 232 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 233 | 234 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 235 | FovY = fovy 236 | FovX = fovx 237 | 238 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 239 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 240 | 241 | return cam_infos 242 | 243 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png", max_points=60_000): 244 | print("Reading Training Transforms") 245 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 246 | print("Reading Test Transforms") 247 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 248 | 249 | if not eval: 250 | train_cam_infos.extend(test_cam_infos) 251 | test_cam_infos = [] 252 | 253 | nerf_normalization = getNerfppNorm(train_cam_infos) 254 | 255 | ply_path = os.path.join(path, "points3d.ply") 256 | if not os.path.exists(ply_path): 257 | # Since this data set has no colmap data, we start with random points 258 | num_pts = 100_000 259 | print(f"Generating random point cloud ({num_pts})...") 260 | 261 | # We create 
random points inside the bounds of the synthetic Blender scenes 262 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 263 | shs = np.random.random((num_pts, 3)) / 255.0 264 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 265 | 266 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 267 | try: 268 | pcd = fetchPly(ply_path, max_points) 269 | except: 270 | pcd = None 271 | 272 | scene_info = SceneInfo(point_cloud=pcd, 273 | train_cameras=train_cam_infos, 274 | test_cameras=test_cam_infos, 275 | nerf_normalization=nerf_normalization, 276 | ply_path=ply_path) 277 | return scene_info 278 | 279 | 280 | def load_img(path): 281 | if not "." in os.path.basename(path): 282 | files = glob.glob(path + '.*') 283 | assert len(files) > 0, "Tried to find image file for: %s, but found 0 files" % (path) 284 | path = files[0] 285 | if path.endswith(".exr"): 286 | assert False 287 | if pyexr is not None: 288 | exr_file = pyexr.open(path) 289 | # print(exr_file.channels) 290 | all_data = exr_file.get() 291 | img = all_data[..., 0:3] 292 | if "A" in exr_file.channels: 293 | mask = np.clip(all_data[..., 3:4], 0, 1) 294 | img = img * mask 295 | else: 296 | img = imageio.imread(path) 297 | import pdb; 298 | pdb.set_trace() 299 | img = np.nan_to_num(img) 300 | hdr = True 301 | else: # LDR image 302 | img = imageio.imread(path) 303 | img = img / 255 304 | # img[..., 0:3] = srgb_to_rgb_np(img[..., 0:3]) 305 | hdr = False 306 | return img, hdr 307 | 308 | 309 | def load_pfm(file: str): 310 | color = None 311 | width = None 312 | height = None 313 | scale = None 314 | endian = None 315 | with open(file, 'rb') as f: 316 | header = f.readline().rstrip() 317 | if header == b'PF': 318 | color = True 319 | elif header == b'Pf': 320 | color = False 321 | else: 322 | raise Exception('Not a PFM file.') 323 | dim_match = re.match(br'^(\d+)\s(\d+)\s$', f.readline()) 324 | if dim_match: 325 | width, height = map(int, dim_match.groups()) 326 | else: 327 | raise Exception('Malformed PFM header.') 328 | scale = float(f.readline().rstrip()) 329 | if scale < 0: # little-endian 330 | endian = '<' 331 | scale = -scale 332 | else: 333 | endian = '>' # big-endian 334 | data = np.fromfile(f, endian + 'f') 335 | shape = (height, width, 3) if color else (height, width) 336 | data = np.reshape(data, shape) 337 | data = data[::-1, ...] 
# cv2.flip(data, 0) 338 | 339 | return np.ascontiguousarray(data) 340 | 341 | 342 | def load_depth(tiff_path): 343 | return imageio.imread(tiff_path) 344 | 345 | 346 | def load_mask(mask_file): 347 | mask = imageio.imread(mask_file, mode='L') 348 | mask = mask.astype(np.float32) 349 | mask[mask > 0.5] = 1.0 350 | 351 | return mask 352 | 353 | 354 | def loadCamsFromScene(path, valid_list, background, debug): 355 | with open(f'{path}/sfm_scene.json') as f: 356 | sfm_scene = json.load(f) 357 | 358 | # load bbox transform 359 | bbox_transform = np.array(sfm_scene['bbox']['transform']).reshape(4, 4) 360 | bbox_transform = bbox_transform.copy() 361 | bbox_transform[[0, 1, 2], [0, 1, 2]] = bbox_transform[[0, 1, 2], [0, 1, 2]].max() / 2 362 | bbox_inv = np.linalg.inv(bbox_transform) 363 | 364 | # meta info 365 | image_list = sfm_scene['image_path']['file_paths'] 366 | 367 | # camera parameters 368 | train_cam_infos = [] 369 | test_cam_infos = [] 370 | camera_info_list = sfm_scene['camera_track_map']['images'] 371 | for i, (index, camera_info) in enumerate(camera_info_list.items()): 372 | if debug and i >= 5: break 373 | if camera_info['flg'] == 2: 374 | intrinsic = np.zeros((4, 4)) 375 | intrinsic[0, 0] = camera_info['camera']['intrinsic']['focal'][0] 376 | intrinsic[1, 1] = camera_info['camera']['intrinsic']['focal'][1] 377 | intrinsic[0, 2] = camera_info['camera']['intrinsic']['ppt'][0] 378 | intrinsic[1, 2] = camera_info['camera']['intrinsic']['ppt'][1] 379 | intrinsic[2, 2] = intrinsic[3, 3] = 1 380 | 381 | extrinsic = np.array(camera_info['camera']['extrinsic']).reshape(4, 4) 382 | c2w = np.linalg.inv(extrinsic) 383 | c2w[:3, 3] = (c2w[:4, 3] @ bbox_inv.T)[:3] 384 | extrinsic = np.linalg.inv(c2w) 385 | 386 | R = np.transpose(extrinsic[:3, :3]) 387 | T = extrinsic[:3, 3] 388 | 389 | focal_length_x = camera_info['camera']['intrinsic']['focal'][0] 390 | focal_length_y = camera_info['camera']['intrinsic']['focal'][1] 391 | ppx = camera_info['camera']['intrinsic']['ppt'][0] 392 | ppy = camera_info['camera']['intrinsic']['ppt'][1] 393 | 394 | image_path = os.path.join(path, image_list[index]) 395 | image_name = Path(image_path).stem 396 | 397 | image, is_hdr = load_img(image_path) 398 | 399 | depth_path = os.path.join(path + "/depths/", os.path.basename( 400 | image_list[index]).replace(os.path.splitext(image_list[index])[-1], ".tiff")) 401 | 402 | if os.path.exists(depth_path): 403 | depth = load_depth(depth_path) 404 | depth *= bbox_inv[0, 0] 405 | else: 406 | print("No depth map for test view.") 407 | depth = None 408 | 409 | normal_path = os.path.join(path + "/normals/", os.path.basename( 410 | image_list[index]).replace(os.path.splitext(image_list[index])[-1], ".pfm")) 411 | if os.path.exists(normal_path): 412 | normal = load_pfm(normal_path) 413 | else: 414 | print("No normal map for test view.") 415 | normal = None 416 | 417 | mask_path = os.path.join(path + "/pmasks/", os.path.basename( 418 | image_list[index]).replace(os.path.splitext(image_list[index])[-1], ".png")) 419 | if os.path.exists(mask_path): 420 | img_mask = (imageio.imread(mask_path, pilmode='L') > 0.1).astype(np.float32) 421 | # if pmask is available, mask the image for PSNR 422 | image *= img_mask[..., np.newaxis] 423 | else: 424 | img_mask = np.ones_like(image[:, :, 0]) 425 | 426 | fovx = focal2fov(focal_length_x, image.shape[1]) 427 | fovy = focal2fov(focal_length_y, image.shape[0]) 428 | if int(index) in valid_list: 429 | image *= img_mask[..., np.newaxis] 430 | image = Image.fromarray(np.array(image * 255.0, 
dtype=np.byte), "RGB") 431 | alpha = Image.fromarray(np.array(np.tile(img_mask[..., np.newaxis], (1, 1, 3)) * 255.0, dtype=np.byte), 432 | "RGB") 433 | if normal is not None: 434 | normal = Image.fromarray(np.array((normal + 1) / 2 * 255.0, dtype=np.byte), "RGB") 435 | test_cam_infos.append(CameraInfo( 436 | uid=index, R=R, T=T, FovY=fovy, FovX=fovx, image=image, 437 | image_path=image_path, image_name=image_name, 438 | alpha=alpha, normal=normal, depth=depth, 439 | width=image.size[0], height=image.size[1])) 440 | else: 441 | image *= img_mask[..., np.newaxis] 442 | depth *= img_mask 443 | normal *= img_mask[..., np.newaxis] 444 | image = Image.fromarray(np.array(image * 255.0, dtype=np.byte), "RGB") 445 | alpha = Image.fromarray(np.array(np.tile(img_mask[..., np.newaxis], (1, 1, 3)) * 255.0, dtype=np.byte), 446 | "RGB") 447 | if normal is not None: 448 | normal = Image.fromarray(np.array((normal + 1) / 2 * 255.0, dtype=np.byte), "RGB") 449 | train_cam_infos.append(CameraInfo( 450 | uid=index, R=R, T=T, FovY=fovy, FovX=fovx, image=image, 451 | image_path=image_path, image_name=image_name, 452 | alpha=alpha, normal=normal, depth=depth, 453 | width=image.size[0], height=image.size[1])) 454 | 455 | return train_cam_infos, test_cam_infos, bbox_transform 456 | 457 | 458 | def readNeILFInfo(path, background, eval, log=None, debug=False): 459 | validation_indexes = [] 460 | if eval: 461 | if "dtu" in path.lower(): 462 | validation_indexes = [6, 13, 30, 35] # same as neuTex 463 | else: 464 | raise NotImplementedError 465 | 466 | train_cam_infos, test_cam_infos, bbx_trans = loadCamsFromScene( 467 | f'{path}/inputs', validation_indexes, background, debug) 468 | 469 | nerf_normalization = getNerfppNorm(train_cam_infos) 470 | 471 | ply_path = f'{path}/inputs/model/sparse_bbx_scale.ply' 472 | if not os.path.exists(ply_path): 473 | org_ply_path = f'{path}/inputs/model/sparse.ply' 474 | 475 | # scale sparse.ply 476 | pcd = fetchPly(org_ply_path) 477 | inv_scale_mat = np.linalg.inv(bbx_trans) # [4, 4] 478 | points = pcd.points 479 | xyz = (np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) @ inv_scale_mat.T)[:, :3] 480 | normals = pcd.normals 481 | colors = pcd.colors 482 | 483 | storePly(ply_path, xyz, colors * 255, normals) 484 | 485 | try: 486 | pcd = fetchPly(ply_path) 487 | except: 488 | pcd = None 489 | 490 | scene_info = SceneInfo(point_cloud=pcd, 491 | train_cameras=train_cam_infos, 492 | test_cameras=test_cam_infos, 493 | nerf_normalization=nerf_normalization, 494 | ply_path=ply_path) 495 | return scene_info 496 | 497 | sceneLoadTypeCallbacks = { 498 | "Colmap": readColmapSceneInfo, 499 | "Blender" : readNerfSyntheticInfo, 500 | "NeILF": readNeILFInfo, 501 | } -------------------------------------------------------------------------------- /scripts/average_error.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict 4 | from glob import glob 5 | from statistics import mean 6 | from argparse import ArgumentParser 7 | 8 | if __name__ == '__main__': 9 | parser = ArgumentParser(description='Script to averaging metrics values for the dataset') 10 | parser.add_argument('-f', '--folder', help='Path to the target folder', type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | #====================================================== 14 | results = glob(os.path.join(args.folder, "*/results.json")) 15 | 16 | metrics_statistic = defaultdict(list) 17 | for path in results: 18 | with 
open(path, "r") as f: 19 | metrics = json.load(f) 20 | metrics_statistic["PSNR"].append(metrics["ours_32000"]["PSNR"]) 21 | metrics_statistic["SSIM"].append(metrics["ours_32000"]["SSIM"]) 22 | metrics_statistic["LPIPS"].append(metrics["ours_32000"]["LPIPS"]) 23 | 24 | print("PSNR:", mean(metrics_statistic["PSNR"])) 25 | print("SSIM:", mean(metrics_statistic["SSIM"])) 26 | print("LPIPS:", mean(metrics_statistic["LPIPS"])) 27 | 28 | #====================================================== 29 | results = glob(os.path.join(args.folder, "*/speed.json")) 30 | 31 | metrics_statistic = defaultdict(list) 32 | for path in results: 33 | with open(path, "r") as f: 34 | metrics = json.load(f) 35 | metrics_statistic["points"].append(metrics["points"]) 36 | metrics_statistic["train_time"].append(metrics["train_time"]) 37 | metrics_statistic["train_time_std"].append(metrics["train_time_std"]) 38 | 39 | print("Points:", mean(metrics_statistic["points"])) 40 | print("FPS:", mean(metrics_statistic["train_time"]), "±", mean(metrics_statistic["train_time_std"])) 41 | 42 | #====================================================== 43 | files_1 = glob(os.path.join(args.folder, "*/point_cloud/iteration_32000/texture_color.npz")) 44 | files_2 = glob(os.path.join(args.folder, "*/point_cloud/iteration_32000/texture_alpha.npz")) 45 | files_3 = glob(os.path.join(args.folder, "*/point_cloud/iteration_32000/point_cloud.ply")) 46 | 47 | size_statistic = [] 48 | for f1, f2, f3 in zip(files_1, files_2, files_3): 49 | total_size = 0 50 | file_stats = os.stat(f1) 51 | total_size += file_stats.st_size / (1024 * 1024) 52 | file_stats = os.stat(f2) 53 | total_size += file_stats.st_size / (1024 * 1024) 54 | file_stats = os.stat(f3) 55 | total_size += file_stats.st_size / (1024 * 1024) 56 | size_statistic.append(total_size) 57 | 58 | print("Size:", mean(size_statistic), " MB") 59 | -------------------------------------------------------------------------------- /scripts/colmap_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd .. 
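# Note: convert.py (the conversion script inherited from 3DGS) is expected to drive the
# COLMAP pipeline (feature extraction, matching, mapping, undistortion) for each scene,
# so COLMAP must be installed and on PATH. Point DATA_FOLDER below at your local dataset
# root; each scene directory must already contain the raw images in an input/ subfolder,
# as noted for the individual datasets below.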
3 | DATA_FOLDER=/media/dsvitov/DATA/ 4 | 5 | # Process Tanks&Temples 6 | # First put images in /*_COLMAP/input folder 7 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Train_COLMAP/ 8 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Training/Truck_COLMAP/ 9 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Francis_COLMAP/ 10 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Horse_COLMAP/ 11 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Lighthouse_COLMAP/ 12 | 13 | # Mip-NeRF-360 14 | # First put images in /*/COLMAP/input folder 15 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bonsai/COLMAP 16 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/counter/COLMAP 17 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/kitchen/COLMAP 18 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/room/COLMAP 19 | 20 | # Process DTU 21 | # First put images in /*_COLMAP/input folder 22 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan24_COLMAP 23 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan37_COLMAP 24 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan40_COLMAP 25 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan55_COLMAP 26 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan63_COLMAP 27 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan65_COLMAP 28 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan69_COLMAP 29 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan83_COLMAP 30 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan97_COLMAP 31 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan105_COLMAP 32 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan106_COLMAP 33 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan110_COLMAP 34 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan114_COLMAP 35 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan118_COLMAP 36 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan122_COLMAP 37 | 38 | -------------------------------------------------------------------------------- /scripts/dtu_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | from glob import glob 5 | from statistics import mean 6 | 7 | dtu_scenes = ['scan24', 'scan37', 'scan40', 'scan55', 'scan63', 'scan65', 'scan69', 'scan83', 'scan97', 'scan105', 'scan106', 'scan110', 'scan114', 'scan118', 'scan122'] 8 | 9 | points = { 10 | 'scan24': 30_000, 11 | 'scan37': 30_000, 12 | 'scan40': 30_000, 13 | 'scan55': 60_000, 14 | 'scan63': 60_000, 15 | 'scan65': 60_000, 16 | 'scan69': 60_000, 17 | 'scan83': 60_000, 18 | 'scan97': 60_000, 19 | 'scan105': 30_000, 20 | 'scan106': 60_000, 21 | 'scan110': 60_000, 22 | 'scan114': 60_000, 23 | 'scan118': 60_000, 24 | 'scan122': 60_000, 25 | } 26 | 27 | parser = ArgumentParser(description="Full evaluation script parameters") 28 | parser.add_argument("--skip_training", action="store_true") 29 | parser.add_argument("--skip_rendering", action="store_true") 30 | parser.add_argument("--skip_metrics", action="store_true") 31 | parser.add_argument("--output_path", default="./eval/dtu") 32 | parser.add_argument('--dtu', "-dtu", required=True, type=str) 33 | args, _ = parser.parse_known_args() 34 | 35 | all_scenes = [] 36 | all_scenes.extend(dtu_scenes) 37 | 38 | if not args.skip_metrics: 39 | parser.add_argument('--DTU_Official', "-DTU", required=True, type=str) 40 | 
args = parser.parse_args() 41 | 42 | 43 | if not args.skip_training: 44 | for scene in dtu_scenes: 45 | common_args = " --quiet --test_iterations -1 --depth_ratio 1.0 -r 2 --lambda_dist 1000 --lambda_normal=0.05 --cap_max=" + str(points[scene]) + " --max_read_points=" + str(points[scene]) 46 | source = args.dtu + "/" + scene 47 | print("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 48 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 49 | 50 | 51 | if not args.skip_rendering: 52 | all_sources = [] 53 | common_args = " --quiet --depth_ratio 1.0 --num_cluster 1 --voxel_size 0.004 --sdf_trunc 0.016 --depth_trunc 3.0" 54 | for scene in dtu_scenes: 55 | source = args.dtu + "/" + scene 56 | print("python render.py --iteration 32000 -s " + source + " -m" + args.output_path + "/" + scene + common_args) 57 | os.system("python render.py --iteration 32000 -s " + source + " -m" + args.output_path + "/" + scene + common_args) 58 | 59 | 60 | if not args.skip_metrics: 61 | script_dir = os.path.dirname(os.path.abspath(__file__)) 62 | for scene in dtu_scenes: 63 | scan_id = scene[4:] 64 | ply_file = f"{args.output_path}/{scene}/train/ours_32000/" 65 | iteration = 32000 66 | string = f"python {script_dir}/eval_dtu/evaluate_single_scene.py " + \ 67 | f"--input_mesh {args.output_path}/{scene}/train/ours_32000/fuse_post.ply " + \ 68 | f"--scan_id {scan_id} --output_dir {script_dir}/tmp/scan{scan_id} " + \ 69 | f"--mask_dir {args.dtu} " + \ 70 | f"--DTU {args.DTU_Official}" 71 | 72 | os.system(string) 73 | 74 | results = glob(f"{script_dir}/tmp/*/results.json") 75 | overall = [] 76 | for path in results: 77 | with open(path, "r") as f: 78 | metrics = json.load(f) 79 | overall.append(metrics["overall"]) 80 | 81 | print("Mean CD:", mean(overall)) -------------------------------------------------------------------------------- /scripts/eval_dtu/eval.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/jzhangbs/DTUeval-python 2 | import numpy as np 3 | import open3d as o3d 4 | import sklearn.neighbors as skln 5 | from tqdm import tqdm 6 | from scipy.io import loadmat 7 | import multiprocessing as mp 8 | import argparse 9 | 10 | def sample_single_tri(input_): 11 | n1, n2, v1, v2, tri_vert = input_ 12 | c = np.mgrid[:n1+1, :n2+1] 13 | c += 0.5 14 | c[0] /= max(n1, 1e-7) 15 | c[1] /= max(n2, 1e-7) 16 | c = np.transpose(c, (1,2,0)) 17 | k = c[c.sum(axis=-1) < 1] # m2 18 | q = v1 * k[:,:1] + v2 * k[:,1:] + tri_vert 19 | return q 20 | 21 | def write_vis_pcd(file, points, colors): 22 | pcd = o3d.geometry.PointCloud() 23 | pcd.points = o3d.utility.Vector3dVector(points) 24 | pcd.colors = o3d.utility.Vector3dVector(colors) 25 | o3d.io.write_point_cloud(file, pcd) 26 | 27 | if __name__ == '__main__': 28 | mp.freeze_support() 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--data', type=str, default='data_in.ply') 32 | parser.add_argument('--scan', type=int, default=1) 33 | parser.add_argument('--mode', type=str, default='mesh', choices=['mesh', 'pcd']) 34 | parser.add_argument('--dataset_dir', type=str, default='.') 35 | parser.add_argument('--vis_out_dir', type=str, default='.') 36 | parser.add_argument('--downsample_density', type=float, default=0.2) 37 | parser.add_argument('--patch_size', type=float, default=60) 38 | parser.add_argument('--max_dist', type=float, default=20) 39 | parser.add_argument('--visualize_threshold', type=float, 
default=10) 40 | args = parser.parse_args() 41 | 42 | thresh = args.downsample_density 43 | if args.mode == 'mesh': 44 | pbar = tqdm(total=9) 45 | pbar.set_description('read data mesh') 46 | data_mesh = o3d.io.read_triangle_mesh(args.data) 47 | 48 | vertices = np.asarray(data_mesh.vertices) 49 | triangles = np.asarray(data_mesh.triangles) 50 | tri_vert = vertices[triangles] 51 | 52 | pbar.update(1) 53 | pbar.set_description('sample pcd from mesh') 54 | v1 = tri_vert[:,1] - tri_vert[:,0] 55 | v2 = tri_vert[:,2] - tri_vert[:,0] 56 | l1 = np.linalg.norm(v1, axis=-1, keepdims=True) 57 | l2 = np.linalg.norm(v2, axis=-1, keepdims=True) 58 | area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True) 59 | non_zero_area = (area2 > 0)[:,0] 60 | l1, l2, area2, v1, v2, tri_vert = [ 61 | arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert] 62 | ] 63 | thr = thresh * np.sqrt(l1 * l2 / area2) 64 | n1 = np.floor(l1 / thr) 65 | n2 = np.floor(l2 / thr) 66 | 67 | with mp.Pool() as mp_pool: 68 | new_pts = mp_pool.map(sample_single_tri, ((n1[i,0], n2[i,0], v1[i:i+1], v2[i:i+1], tri_vert[i:i+1,0]) for i in range(len(n1))), chunksize=1024) 69 | 70 | new_pts = np.concatenate(new_pts, axis=0) 71 | data_pcd = np.concatenate([vertices, new_pts], axis=0) 72 | 73 | elif args.mode == 'pcd': 74 | pbar = tqdm(total=8) 75 | pbar.set_description('read data pcd') 76 | data_pcd_o3d = o3d.io.read_point_cloud(args.data) 77 | data_pcd = np.asarray(data_pcd_o3d.points) 78 | 79 | pbar.update(1) 80 | pbar.set_description('random shuffle pcd index') 81 | shuffle_rng = np.random.default_rng() 82 | shuffle_rng.shuffle(data_pcd, axis=0) 83 | 84 | pbar.update(1) 85 | pbar.set_description('downsample pcd') 86 | nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1) 87 | nn_engine.fit(data_pcd) 88 | rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False) 89 | mask = np.ones(data_pcd.shape[0], dtype=np.bool_) 90 | for curr, idxs in enumerate(rnn_idxs): 91 | if mask[curr]: 92 | mask[idxs] = 0 93 | mask[curr] = 1 94 | data_down = data_pcd[mask] 95 | 96 | pbar.update(1) 97 | pbar.set_description('masking data pcd') 98 | obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat') 99 | ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']] 100 | BB = BB.astype(np.float32) 101 | 102 | patch = args.patch_size 103 | inbound = ((data_down >= BB[:1]-patch) & (data_down < BB[1:]+patch*2)).sum(axis=-1) ==3 104 | data_in = data_down[inbound] 105 | 106 | data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32) 107 | grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) ==3 108 | data_grid_in = data_grid[grid_inbound] 109 | in_obs = ObsMask[data_grid_in[:,0], data_grid_in[:,1], data_grid_in[:,2]].astype(np.bool_) 110 | data_in_obs = data_in[grid_inbound][in_obs] 111 | 112 | pbar.update(1) 113 | pbar.set_description('read STL pcd') 114 | stl_pcd = o3d.io.read_point_cloud(f'{args.dataset_dir}/Points/stl/stl{args.scan:03}_total.ply') 115 | stl = np.asarray(stl_pcd.points) 116 | 117 | pbar.update(1) 118 | pbar.set_description('compute data2stl') 119 | nn_engine.fit(stl) 120 | dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True) 121 | max_dist = args.max_dist 122 | mean_d2s = dist_d2s[dist_d2s < max_dist].mean() 123 | 124 | pbar.update(1) 125 | pbar.set_description('compute stl2data') 126 | ground_plane = 
loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P'] 127 | 128 | stl_hom = np.concatenate([stl, np.ones_like(stl[:,:1])], -1) 129 | above = (ground_plane.reshape((1,4)) * stl_hom).sum(-1) > 0 130 | stl_above = stl[above] 131 | 132 | nn_engine.fit(data_in) 133 | dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True) 134 | mean_s2d = dist_s2d[dist_s2d < max_dist].mean() 135 | 136 | pbar.update(1) 137 | pbar.set_description('visualize error') 138 | vis_dist = args.visualize_threshold 139 | R = np.array([[1,0,0]], dtype=np.float64) 140 | G = np.array([[0,1,0]], dtype=np.float64) 141 | B = np.array([[0,0,1]], dtype=np.float64) 142 | W = np.array([[1,1,1]], dtype=np.float64) 143 | data_color = np.tile(B, (data_down.shape[0], 1)) 144 | data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist 145 | data_color[ np.where(inbound)[0][grid_inbound][in_obs] ] = R * data_alpha + W * (1-data_alpha) 146 | data_color[ np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:,0] >= max_dist] ] = G 147 | write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2s.ply', data_down, data_color) 148 | stl_color = np.tile(B, (stl.shape[0], 1)) 149 | stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist 150 | stl_color[ np.where(above)[0] ] = R * stl_alpha + W * (1-stl_alpha) 151 | stl_color[ np.where(above)[0][dist_s2d[:,0] >= max_dist] ] = G 152 | write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_s2d.ply', stl, stl_color) 153 | 154 | pbar.update(1) 155 | pbar.set_description('done') 156 | pbar.close() 157 | over_all = (mean_d2s + mean_s2d) / 2 158 | print(mean_d2s, mean_s2d, over_all) 159 | 160 | import json 161 | with open(f'{args.vis_out_dir}/results.json', 'w') as fp: 162 | json.dump({ 163 | 'mean_d2s': mean_d2s, 164 | 'mean_s2d': mean_s2d, 165 | 'overall': over_all, 166 | }, fp, indent=True) -------------------------------------------------------------------------------- /scripts/eval_dtu/evaluate_single_scene.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import cv2 5 | import numpy as np 6 | import os 7 | import glob 8 | from skimage.morphology import binary_dilation, disk 9 | import argparse 10 | 11 | import trimesh 12 | from pathlib import Path 13 | import subprocess 14 | 15 | import sys 16 | import render_utils as rend_util 17 | from tqdm import tqdm 18 | 19 | def cull_scan(scan, mesh_path, result_mesh_file, instance_dir): 20 | 21 | # load poses 22 | image_dir = '{0}/images'.format(instance_dir) 23 | image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png"))) 24 | n_images = len(image_paths) 25 | cam_file = '{0}/cameras.npz'.format(instance_dir) 26 | camera_dict = np.load(cam_file) 27 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 28 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 29 | 30 | intrinsics_all = [] 31 | pose_all = [] 32 | for scale_mat, world_mat in zip(scale_mats, world_mats): 33 | P = world_mat @ scale_mat 34 | P = P[:3, :4] 35 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) 36 | intrinsics_all.append(torch.from_numpy(intrinsics).float()) 37 | pose_all.append(torch.from_numpy(pose).float()) 38 | 39 | # load mask 40 | mask_dir = '{0}/mask'.format(instance_dir) 41 | mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.png"))) 42 | masks = [] 43 | for p in mask_paths: 44 | mask = cv2.imread(p) 45 | masks.append(mask) 46 | 
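# Note on the culling below: every mesh vertex is projected into each training view
# using the intrinsics/poses decomposed above, and the object mask (dilated with a
# disk of radius 24, as in the UniSurf-style culling noted further down) is sampled
# at the projected pixel. A vertex is kept only if, in every view, it either lands
# inside the dilated mask or projects outside the image; a face is dropped as soon
# as any of its vertices is culled.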
47 | # hard-coded image shape 48 | W, H = 1600, 1200 49 | 50 | # load mesh 51 | mesh = trimesh.load(mesh_path) 52 | 53 | # load transformation matrix 54 | 55 | vertices = mesh.vertices 56 | 57 | # project and filter 58 | vertices = torch.from_numpy(vertices).cuda() 59 | vertices = torch.cat((vertices, torch.ones_like(vertices[:, :1])), dim=-1) 60 | vertices = vertices.permute(1, 0) 61 | vertices = vertices.float() 62 | 63 | sampled_masks = [] 64 | for i in tqdm(range(n_images), desc="Culling mesh given masks"): 65 | pose = pose_all[i] 66 | w2c = torch.inverse(pose).cuda() 67 | intrinsic = intrinsics_all[i].cuda() 68 | 69 | with torch.no_grad(): 70 | # transform and project 71 | cam_points = intrinsic @ w2c @ vertices 72 | pix_coords = cam_points[:2, :] / (cam_points[2, :].unsqueeze(0) + 1e-6) 73 | pix_coords = pix_coords.permute(1, 0) 74 | pix_coords[..., 0] /= W - 1 75 | pix_coords[..., 1] /= H - 1 76 | pix_coords = (pix_coords - 0.5) * 2 77 | valid = ((pix_coords > -1. ) & (pix_coords < 1.)).all(dim=-1).float() 78 | 79 | # dialate mask similar to unisurf 80 | maski = masks[i][:, :, 0].astype(np.float32) / 256. 81 | maski = torch.from_numpy(binary_dilation(maski, disk(24))).float()[None, None].cuda() 82 | 83 | sampled_mask = F.grid_sample(maski, pix_coords[None, None], mode='nearest', padding_mode='zeros', align_corners=True)[0, -1, 0] 84 | 85 | sampled_mask = sampled_mask + (1. - valid) 86 | sampled_masks.append(sampled_mask) 87 | 88 | sampled_masks = torch.stack(sampled_masks, -1) 89 | # filter 90 | 91 | mask = (sampled_masks > 0.).all(dim=-1).cpu().numpy() 92 | face_mask = mask[mesh.faces].all(axis=1) 93 | 94 | mesh.update_vertices(mask) 95 | mesh.update_faces(face_mask) 96 | 97 | # transform vertices to world 98 | scale_mat = scale_mats[0] 99 | mesh.vertices = mesh.vertices * scale_mat[0, 0] + scale_mat[:3, 3][None] 100 | mesh.export(result_mesh_file) 101 | del mesh 102 | 103 | 104 | if __name__ == "__main__": 105 | 106 | parser = argparse.ArgumentParser( 107 | description='Arguments to evaluate the mesh.' 
108 | ) 109 | 110 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be evaluated') 111 | parser.add_argument('--scan_id', type=str, help='scan id of the input mesh') 112 | parser.add_argument('--output_dir', type=str, default='evaluation_results_single', help='path to the output folder') 113 | parser.add_argument('--mask_dir', type=str, default='mask', help='path to uncropped mask') 114 | parser.add_argument('--DTU', type=str, default='Offical_DTU_Dataset', help='path to the GT DTU point clouds') 115 | args = parser.parse_args() 116 | 117 | Offical_DTU_Dataset = args.DTU 118 | out_dir = args.output_dir 119 | Path(out_dir).mkdir(parents=True, exist_ok=True) 120 | 121 | scan = args.scan_id 122 | ply_file = args.input_mesh 123 | print("cull mesh ....") 124 | result_mesh_file = os.path.join(out_dir, "culled_mesh.ply") 125 | cull_scan(scan, ply_file, result_mesh_file, instance_dir=os.path.join(args.mask_dir, f'scan{args.scan_id}')) 126 | 127 | script_dir = os.path.dirname(os.path.abspath(__file__)) 128 | cmd = f"python {script_dir}/eval.py --data {result_mesh_file} --scan {scan} --mode mesh --dataset_dir {Offical_DTU_Dataset} --vis_out_dir {out_dir}" 129 | os.system(cmd) -------------------------------------------------------------------------------- /scripts/eval_dtu/render_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import skimage 4 | import cv2 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | 9 | def get_psnr(img1, img2, normalize_rgb=False): 10 | if normalize_rgb: # [-1,1] --> [0,1] 11 | img1 = (img1 + 1.) / 2. 12 | img2 = (img2 + 1. ) / 2. 13 | 14 | mse = torch.mean((img1 - img2) ** 2) 15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda()) 16 | 17 | return psnr 18 | 19 | 20 | def load_rgb(path, normalize_rgb = False): 21 | img = imageio.imread(path) 22 | img = skimage.img_as_float32(img) 23 | 24 | if normalize_rgb: # [-1,1] --> [0,1] 25 | img -= 0.5 26 | img *= 2. 
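# reorder from HWC (as returned by imageio/skimage) to CHW for the torch-based pipeline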
27 | img = img.transpose(2, 0, 1) 28 | return img 29 | 30 | 31 | def load_K_Rt_from_P(filename, P=None): 32 | if P is None: 33 | lines = open(filename).read().splitlines() 34 | if len(lines) == 4: 35 | lines = lines[1:] 36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 37 | P = np.asarray(lines).astype(np.float32).squeeze() 38 | 39 | out = cv2.decomposeProjectionMatrix(P) 40 | K = out[0] 41 | R = out[1] 42 | t = out[2] 43 | 44 | K = K/K[2,2] 45 | intrinsics = np.eye(4) 46 | intrinsics[:3, :3] = K 47 | 48 | pose = np.eye(4, dtype=np.float32) 49 | pose[:3, :3] = R.transpose() 50 | pose[:3,3] = (t[:3] / t[3])[:,0] 51 | 52 | return intrinsics, pose 53 | 54 | 55 | def get_camera_params(uv, pose, intrinsics): 56 | if pose.shape[1] == 7: #In case of quaternion vector representation 57 | cam_loc = pose[:, 4:] 58 | R = quat_to_rot(pose[:,:4]) 59 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() 60 | p[:, :3, :3] = R 61 | p[:, :3, 3] = cam_loc 62 | else: # In case of pose matrix representation 63 | cam_loc = pose[:, :3, 3] 64 | p = pose 65 | 66 | batch_size, num_samples, _ = uv.shape 67 | 68 | depth = torch.ones((batch_size, num_samples)).cuda() 69 | x_cam = uv[:, :, 0].view(batch_size, -1) 70 | y_cam = uv[:, :, 1].view(batch_size, -1) 71 | z_cam = depth.view(batch_size, -1) 72 | 73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics) 74 | 75 | # permute for batch matrix product 76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1) 77 | 78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3] 79 | ray_dirs = world_coords - cam_loc[:, None, :] 80 | ray_dirs = F.normalize(ray_dirs, dim=2) 81 | 82 | return ray_dirs, cam_loc 83 | 84 | 85 | def get_camera_for_plot(pose): 86 | if pose.shape[1] == 7: #In case of quaternion vector representation 87 | cam_loc = pose[:, 4:].detach() 88 | R = quat_to_rot(pose[:,:4].detach()) 89 | else: # In case of pose matrix representation 90 | cam_loc = pose[:, :3, 3] 91 | R = pose[:, :3, :3] 92 | cam_dir = R[:, :3, 2] 93 | return cam_loc, cam_dir 94 | 95 | 96 | def lift(x, y, z, intrinsics): 97 | # parse intrinsics 98 | intrinsics = intrinsics.cuda() 99 | fx = intrinsics[:, 0, 0] 100 | fy = intrinsics[:, 1, 1] 101 | cx = intrinsics[:, 0, 2] 102 | cy = intrinsics[:, 1, 2] 103 | sk = intrinsics[:, 0, 1] 104 | 105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z 106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 107 | 108 | # homogeneous 109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1) 110 | 111 | 112 | def quat_to_rot(q): 113 | batch_size, _ = q.shape 114 | q = F.normalize(q, dim=1) 115 | R = torch.ones((batch_size, 3,3)).cuda() 116 | qr=q[:,0] 117 | qi = q[:, 1] 118 | qj = q[:, 2] 119 | qk = q[:, 3] 120 | R[:, 0, 0]=1-2 * (qj**2 + qk**2) 121 | R[:, 0, 1] = 2 * (qj *qi -qk*qr) 122 | R[:, 0, 2] = 2 * (qi * qk + qr * qj) 123 | R[:, 1, 0] = 2 * (qj * qi + qk * qr) 124 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2) 125 | R[:, 1, 2] = 2*(qj*qk - qi*qr) 126 | R[:, 2, 0] = 2 * (qk * qi-qj * qr) 127 | R[:, 2, 1] = 2 * (qj*qk + qi*qr) 128 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2) 129 | return R 130 | 131 | 132 | def rot_to_quat(R): 133 | batch_size, _,_ = R.shape 134 | q = torch.ones((batch_size, 4)).cuda() 135 | 136 | R00 = R[:, 0,0] 137 | R01 = R[:, 0, 1] 138 | R02 = R[:, 0, 2] 139 | R10 = R[:, 1, 0] 140 | R11 = R[:, 1, 1] 141 | R12 = R[:, 1, 2] 142 | R20 = R[:, 2, 0] 143 | R21 
= R[:, 2, 1] 144 | R22 = R[:, 2, 2] 145 | 146 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2 147 | q[:, 1]=(R21-R12)/(4*q[:,0]) 148 | q[:, 2] = (R02 - R20) / (4 * q[:, 0]) 149 | q[:, 3] = (R10 - R01) / (4 * q[:, 0]) 150 | return q 151 | 152 | 153 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0): 154 | # Input: n_rays x 3 ; n_rays x 3 155 | # Output: n_rays x 1, n_rays x 1 (close and far) 156 | 157 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3), 158 | cam_loc.view(-1, 3, 1)).squeeze(-1) 159 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2) 160 | 161 | # sanity check 162 | if (under_sqrt <= 0).sum() > 0: 163 | print('BOUNDING SPHERE PROBLEM!') 164 | exit() 165 | 166 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot 167 | sphere_intersections = sphere_intersections.clamp_min(0.0) 168 | 169 | return sphere_intersections -------------------------------------------------------------------------------- /scripts/metrics_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd .. 3 | OUTPUT_FOLDER=/media/dsvitov/DATA/output/Ours 4 | 5 | # Process Tanks&Temples 6 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Train 7 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Truck 8 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Francis 9 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Horse 10 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Lighthouse 11 | 12 | # Mip-NeRF-360 13 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Bonsai 14 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Counter 15 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Kitchen 16 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Room 17 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Bicycle 18 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Stump 19 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Garden 20 | 21 | # Process DTU 22 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan24 23 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan37 24 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan40 25 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan55 26 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan63 27 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan65 28 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan69 29 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan83 30 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan97 31 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan105 32 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan106 33 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan110 34 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan114 35 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan118 36 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan122 37 | 38 | # Average metrics for each dataset 39 | python scripts/average_error.py --folder ${OUTPUT_FOLDER}/TnT 40 | python scripts/average_error.py --folder ${OUTPUT_FOLDER}/MipNerf 41 | python scripts/average_error.py --folder ${OUTPUT_FOLDER}/DTU 42 | -------------------------------------------------------------------------------- /scripts/render_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd .. 
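# DATA_FOLDER and OUTPUT_FOLDER are machine-specific; the --model_path values below
# must match the checkpoints produced by scripts/train_all.sh. For a single scene,
# the equivalent manual invocation looks like, e.g.:
#   python render.py -s <scene_source_dir> --model_path <trained_model_dir> --skip_mesh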
3 | DATA_FOLDER=/media/dsvitov/DATA/ 4 | OUTPUT_FOLDER=/media/dsvitov/DATA/output/Ours 5 | 6 | # Process Tanks&Temples 7 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Train_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Train --skip_mesh 8 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Training/Truck_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Truck --skip_mesh 9 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Francis_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Francis --skip_mesh 10 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Horse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Horse --skip_mesh 11 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Lighthouse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Lighthouse --skip_mesh 12 | 13 | # Mip-NeRF-360 14 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bonsai --model_path=${OUTPUT_FOLDER}/MipNerf/Bonsai --skip_mesh 15 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/counter --model_path=${OUTPUT_FOLDER}/MipNerf/Counter --skip_mesh 16 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/kitchen --model_path=${OUTPUT_FOLDER}/MipNerf/Kitchen --skip_mesh 17 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/room --model_path=${OUTPUT_FOLDER}/MipNerf/Room --skip_mesh 18 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bicycle --model_path=${OUTPUT_FOLDER}/MipNerf/Bicycle --skip_mesh 19 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/stump --model_path=${OUTPUT_FOLDER}/MipNerf/Stump --skip_mesh 20 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/garden --model_path=${OUTPUT_FOLDER}/MipNerf/Garden --skip_mesh 21 | 22 | # Process DTU 23 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan24 --model_path=${OUTPUT_FOLDER}/DTU/scan24 --skip_mesh 24 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan37 --model_path=${OUTPUT_FOLDER}/DTU/scan37 --skip_mesh 25 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan40 --model_path=${OUTPUT_FOLDER}/DTU/scan40 --skip_mesh 26 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan55 --model_path=${OUTPUT_FOLDER}/DTU/scan55 --skip_mesh 27 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan63 --model_path=${OUTPUT_FOLDER}/DTU/scan63 --skip_mesh 28 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan65 --model_path=${OUTPUT_FOLDER}/DTU/scan65 --skip_mesh 29 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan69 --model_path=${OUTPUT_FOLDER}/DTU/scan69 --skip_mesh 30 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan83 --model_path=${OUTPUT_FOLDER}/DTU/scan83 --skip_mesh 31 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan97 --model_path=${OUTPUT_FOLDER}/DTU/scan97 --skip_mesh 32 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan105 --model_path=${OUTPUT_FOLDER}/DTU/scan105 --skip_mesh 33 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan106 --model_path=${OUTPUT_FOLDER}/DTU/scan106 --skip_mesh 34 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan110 --model_path=${OUTPUT_FOLDER}/DTU/scan110 --skip_mesh 35 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan114 --model_path=${OUTPUT_FOLDER}/DTU/scan114 --skip_mesh 36 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan118 --model_path=${OUTPUT_FOLDER}/DTU/scan118 --skip_mesh 37 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan122 --model_path=${OUTPUT_FOLDER}/DTU/scan122 --skip_mesh 38 | -------------------------------------------------------------------------------- /scripts/train_all.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd .. 3 | DATA_FOLDER=/media/dsvitov/DATA/ 4 | OUTPUT_FOLDER=/media/dsvitov/DATA/output/Ours 5 | 6 | # Process Tanks&Temples 7 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Train_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Train --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 8 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Training/Truck_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Truck --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 9 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Francis_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Francis --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 10 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Horse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Horse --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 11 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Lighthouse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Lighthouse --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 12 | 13 | # Mip-NeRF-360 14 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bonsai --model_path=${OUTPUT_FOLDER}/MipNerf/Bonsai --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval 15 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/counter --model_path=${OUTPUT_FOLDER}/MipNerf/Counter --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval 16 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/kitchen --model_path=${OUTPUT_FOLDER}/MipNerf/Kitchen --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval 17 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/room --model_path=${OUTPUT_FOLDER}/MipNerf/Room --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval 18 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bicycle --model_path=${OUTPUT_FOLDER}/MipNerf/Bicycle --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 19 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/stump --model_path=${OUTPUT_FOLDER}/MipNerf/Stump --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 20 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/garden --model_path=${OUTPUT_FOLDER}/MipNerf/Garden --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval 21 | 22 | # Process DTU 23 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan24 --model_path=${OUTPUT_FOLDER}/DTU/scan24 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 24 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan37 --model_path=${OUTPUT_FOLDER}/DTU/scan37 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 25 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan40 --model_path=${OUTPUT_FOLDER}/DTU/scan40 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 26 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan55 --model_path=${OUTPUT_FOLDER}/DTU/scan55 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 27 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan63 --model_path=${OUTPUT_FOLDER}/DTU/scan63 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 28 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan65 --model_path=${OUTPUT_FOLDER}/DTU/scan65 --cap_max=60_000 --max_read_points=60_000 
--lambda_normal=0.05 --lambda_dist 100 --eval 29 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan69 --model_path=${OUTPUT_FOLDER}/DTU/scan69 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 30 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan83 --model_path=${OUTPUT_FOLDER}/DTU/scan83 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 31 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan97 --model_path=${OUTPUT_FOLDER}/DTU/scan97 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 32 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan105 --model_path=${OUTPUT_FOLDER}/DTU/scan105 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 33 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan106 --model_path=${OUTPUT_FOLDER}/DTU/scan106 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 34 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan110 --model_path=${OUTPUT_FOLDER}/DTU/scan110 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 35 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan114 --model_path=${OUTPUT_FOLDER}/DTU/scan114 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 36 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan118 --model_path=${OUTPUT_FOLDER}/DTU/scan118 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 37 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan122 --model_path=${OUTPUT_FOLDER}/DTU/scan122 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval 38 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" 14 | 15 | import torch 16 | from random import randint 17 | from utils.loss_utils import l1_loss, ssim 18 | from gaussian_renderer import render 19 | import sys 20 | from scene import Scene, GaussianModel 21 | from utils.general_utils import safe_state, build_scaling_rotation 22 | import uuid 23 | from tqdm import tqdm 24 | from utils.image_utils import psnr 25 | from argparse import ArgumentParser, Namespace 26 | from arguments import ModelParams, PipelineParams, OptimizationParams 27 | 28 | try: 29 | from torch.utils.tensorboard import SummaryWriter 30 | TENSORBOARD_FOUND = True 31 | except ImportError: 32 | TENSORBOARD_FOUND = False 33 | 34 | 35 | def total_variation_loss(img): 36 | bs_img, c_img, h_img, w_img = img.size() 37 | tv_h = torch.pow(img[:, :, 1:, :] - img[:, :, :-1, :], 2).sum() 38 | tv_w = torch.pow(img[:, :, :, 1:] - img[:, :, :, :-1], 2).sum() 39 | return (tv_h + tv_w) / (bs_img * c_img * h_img * w_img) 40 | 41 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint): 42 | first_iter = 0 43 | tb_writer = prepare_output_and_logger(dataset) 44 | gaussians = GaussianModel(dataset.sh_degree) 45 | scene = Scene(dataset, gaussians, add_sky_box=opt.add_sky_box, max_read_points=opt.max_read_points, sphere_point=opt.sphere_point) 46 | gaussians.training_setup(opt) 47 | if checkpoint: 48 | (model_params, first_iter) = torch.load(checkpoint) 49 | gaussians.restore(model_params, opt) 50 | 51 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 52 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 53 | 54 | iter_start = torch.cuda.Event(enable_timing = True) 55 | iter_end = torch.cuda.Event(enable_timing = True) 56 | 57 | viewpoint_stack = None 58 | ema_loss_for_log = 0.0 59 | ema_dist_for_log = 0.0 60 | ema_normal_for_log = 0.0 61 | ema_texture_for_log = 0.0 62 | 63 | initial_texture_alpha = gaussians.get_texture_alpha[0:1].detach().clone() 64 | 65 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 66 | first_iter += 1 67 | for iteration in range(first_iter, opt.iterations + 1): 68 | 69 | iter_start.record() 70 | 71 | xyz_lr = gaussians.update_learning_rate(iteration) 72 | 73 | # Every 1000 its we increase the levels of SH up to a maximum degree 74 | if iteration % 1000 == 0: 75 | gaussians.oneupSHdegree() 76 | 77 | # Pick a random Camera 78 | if not viewpoint_stack: 79 | viewpoint_stack = scene.getTrainCameras().copy() 80 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 81 | 82 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 83 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 84 | impact = render_pkg["impact"] 85 | 86 | gt_image = viewpoint_cam.original_image.cuda() 87 | 88 | Ll1 = l1_loss(image, gt_image) 89 | ssim_map = ssim(image, gt_image, size_average=False) 90 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_map.mean()) 91 | 92 | # regularization 93 | lambda_normal = opt.lambda_normal if iteration > 7000 else 0.0 94 | lambda_dist = opt.lambda_dist if iteration > 3000 else 0.0 95 | 96 | rend_dist = render_pkg["rend_dist"] 97 | rend_normal = render_pkg['rend_normal'] 98 | surf_normal = render_pkg['surf_normal'] 99 | 
normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None] 100 | normal_loss = lambda_normal * (normal_error).mean() 101 | dist_loss = lambda_dist * (rend_dist).mean() 102 | 103 | weights = opt.max_impact_threshold - torch.clamp(impact[visibility_filter], 0, opt.max_impact_threshold) 104 | textures_reg = (gaussians.get_texture_color[visibility_filter].mean(dim=[1, 2, 3]) * weights).mean() * opt.lambda_texture_value 105 | textures_reg += torch.abs((gaussians.get_texture_alpha[visibility_filter] - initial_texture_alpha).mean(dim=[1, 2]) * weights).mean() * opt.lambda_alpha_value 106 | 107 | # loss 108 | total_loss = loss + dist_loss + normal_loss + textures_reg 109 | # For MCMC sampler 110 | total_loss += opt.opacity_reg * gaussians.get_texture_alpha.mean() 111 | total_loss.backward() 112 | 113 | iter_end.record() 114 | 115 | with torch.no_grad(): 116 | # Progress bar 117 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 118 | ema_dist_for_log = 0.4 * dist_loss.item() + 0.6 * ema_dist_for_log 119 | ema_normal_for_log = 0.4 * normal_loss.item() + 0.6 * ema_normal_for_log 120 | ema_texture_for_log = 0.4 * textures_reg.item() + 0.6 * ema_texture_for_log 121 | 122 | 123 | if iteration % 10 == 0: 124 | loss_dict = { 125 | "Loss": f"{ema_loss_for_log:.{5}f}", 126 | "distort": f"{ema_dist_for_log:.{5}f}", 127 | "normal": f"{ema_normal_for_log:.{5}f}", 128 | "texture": f"{ema_texture_for_log:.{5}f}", 129 | "Points": f"{len(gaussians.get_xyz)}" 130 | } 131 | progress_bar.set_postfix(loss_dict) 132 | 133 | progress_bar.update(10) 134 | if iteration == opt.iterations: 135 | progress_bar.close() 136 | 137 | # Log and save 138 | if tb_writer is not None: 139 | tb_writer.add_scalar('train_loss_patches/dist_loss', ema_dist_for_log, iteration) 140 | tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration) 141 | 142 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 143 | if (iteration in saving_iterations): 144 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 145 | scene.save(iteration) 146 | 147 | if opt.texture_from_iter <= iteration < opt.texture_to_iter: 148 | gaussians.activate_texture_training() 149 | 150 | if iteration >= opt.texture_to_iter: 151 | gaussians.deactivate_texture_training() 152 | 153 | if iteration > opt.position_lr_max_steps: 154 | gaussians.deactivate_gaussians_training() 155 | 156 | # Densification 157 | if iteration < opt.densify_until_iter and iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 158 | size = len(gaussians.get_texture_alpha) 159 | dead_mask = (gaussians.get_texture_alpha.view(size, -1).mean(1) <= opt.dead_opacity).squeeze(-1) 160 | 161 | gaussians.relocate_gs(dead_mask=dead_mask) 162 | gaussians.add_new_gs(cap_max=opt.cap_max) 163 | 164 | # Optimizer step 165 | if iteration < opt.iterations: 166 | gaussians.optimizer.step() 167 | gaussians.optimizer.zero_grad(set_to_none = True) 168 | 169 | L = build_scaling_rotation(gaussians.get_scaling, gaussians.get_rotation) 170 | actual_covariance = L @ L.transpose(1, 2) 171 | 172 | def op_sigmoid(x, k=100, x0=0.995): 173 | return 1 / (1 + torch.exp(-k * (x - x0))) 174 | 175 | #size = len(gaussians.get_texture_alpha) 176 | #opacity = gaussians.get_texture_alpha.view(size, -1).mean(1, keepdim=True) * 10 # Rescale to get maximum = 1 177 | opacity = torch.ones([gaussians.get_texture_alpha.shape[0], 1], dtype=torch.float32, device="cuda") # Fix 
opacity to 1 (results in the paper obtained this way) 178 | noise = torch.randn_like(gaussians._xyz) * (op_sigmoid(1 - opacity)) * opt.noise_lr * xyz_lr 179 | noise = torch.bmm(actual_covariance, noise.unsqueeze(-1)).squeeze(-1) 180 | gaussians._xyz.add_(noise) 181 | 182 | if (iteration in checkpoint_iterations): 183 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 184 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 185 | 186 | def prepare_output_and_logger(args): 187 | if not args.model_path: 188 | if os.getenv('OAR_JOB_ID'): 189 | unique_str=os.getenv('OAR_JOB_ID') 190 | else: 191 | unique_str = str(uuid.uuid4()) 192 | args.model_path = os.path.join("./output/", unique_str[0:10]) 193 | 194 | # Set up output folder 195 | print("Output folder: {}".format(args.model_path)) 196 | os.makedirs(args.model_path, exist_ok = True) 197 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 198 | cfg_log_f.write(str(Namespace(**vars(args)))) 199 | 200 | # Create Tensorboard writer 201 | tb_writer = None 202 | if TENSORBOARD_FOUND: 203 | tb_writer = SummaryWriter(args.model_path) 204 | else: 205 | print("Tensorboard not available: not logging progress") 206 | return tb_writer 207 | 208 | @torch.no_grad() 209 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 210 | if tb_writer: 211 | tb_writer.add_scalar('train_loss_patches/reg_loss', Ll1.item(), iteration) 212 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 213 | tb_writer.add_scalar('iter_time', elapsed, iteration) 214 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 215 | 216 | # Report test and samples of training set 217 | if iteration in testing_iterations: 218 | torch.cuda.empty_cache() 219 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 220 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 221 | 222 | for config in validation_configs: 223 | if config['cameras'] and len(config['cameras']) > 0: 224 | l1_test = 0.0 225 | psnr_test = 0.0 226 | for idx, viewpoint in enumerate(config['cameras']): 227 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) 228 | image = torch.clamp(render_pkg["render"], 0.0, 1.0) 229 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 230 | if tb_writer and (idx < 5): 231 | from utils.general_utils import colormap 232 | depth = render_pkg["surf_depth"] 233 | norm = depth.max() 234 | depth = depth / norm 235 | depth = colormap(depth.cpu().numpy()[0], cmap='turbo') 236 | tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth[None], global_step=iteration) 237 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 238 | 239 | try: 240 | rend_alpha = render_pkg['rend_alpha'] 241 | rend_normal = render_pkg["rend_normal"] * 0.5 + 0.5 242 | surf_normal = render_pkg["surf_normal"] * 0.5 + 0.5 243 | tb_writer.add_images(config['name'] + "_view_{}/rend_normal".format(viewpoint.image_name), rend_normal[None], global_step=iteration) 244 | tb_writer.add_images(config['name'] + "_view_{}/surf_normal".format(viewpoint.image_name), surf_normal[None], global_step=iteration) 245 | tb_writer.add_images(config['name'] + "_view_{}/rend_alpha".format(viewpoint.image_name), 
rend_alpha[None], global_step=iteration) 246 | 247 | rend_dist = render_pkg["rend_dist"] 248 | rend_dist = colormap(rend_dist.cpu().numpy()[0]) 249 | tb_writer.add_images(config['name'] + "_view_{}/rend_dist".format(viewpoint.image_name), rend_dist[None], global_step=iteration) 250 | except: 251 | pass 252 | 253 | if iteration == testing_iterations[0]: 254 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 255 | 256 | l1_test += l1_loss(image, gt_image).mean().double() 257 | psnr_test += psnr(image, gt_image).mean().double() 258 | 259 | psnr_test /= len(config['cameras']) 260 | l1_test /= len(config['cameras']) 261 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 262 | if tb_writer: 263 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 264 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 265 | 266 | torch.cuda.empty_cache() 267 | 268 | if __name__ == "__main__": 269 | # Set up command line argument parser 270 | parser = ArgumentParser(description="Training script parameters") 271 | lp = ModelParams(parser) 272 | op = OptimizationParams(parser) 273 | pp = PipelineParams(parser) 274 | parser.add_argument('--ip', type=str, default="127.0.0.1") 275 | parser.add_argument('--port', type=int, default=6009) 276 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 277 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1_000, 7_000, 10_000, 15_000, 20_000, 25_000, 30_000, 32_000]) 278 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1_000, 7_000, 30_000, 32_000]) 279 | parser.add_argument("--quiet", action="store_true") 280 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 281 | parser.add_argument("--start_checkpoint", type=str, default = None) 282 | args = parser.parse_args(sys.argv[1:]) 283 | args.save_iterations.append(args.iterations) 284 | 285 | print("Optimizing " + args.model_path) 286 | 287 | # Initialize system state (RNG) 288 | safe_state(args.quiet) 289 | 290 | # Start GUI server, configure and run training 291 | # network_gui.init(args.ip, args.port) 292 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 293 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint) 294 | 295 | # All done 296 | print("\nTraining complete.") 297 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import torch 14 | 15 | from scene.cameras import Camera 16 | from utils.general_utils import PILtoTorch 17 | from utils.graphics_utils import fov2focal 18 | 19 | WARNED = False 20 | 21 | def loadCam(args, id, cam_info, resolution_scale): 22 | orig_w, orig_h = cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 26 | else: # should be a type that converts to float 27 | if args.resolution == -1: 28 | if orig_w > 1600: 29 | global WARNED 30 | if not WARNED: 31 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 32 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 33 | WARNED = True 34 | global_down = orig_w / 1600 35 | else: 36 | global_down = 1 37 | else: 38 | global_down = orig_w / args.resolution 39 | 40 | scale = float(global_down) * float(resolution_scale) 41 | resolution = (int(orig_w / scale), int(orig_h / scale)) 42 | 43 | if len(cam_info.image.split()) > 3: 44 | resized_image_rgb = torch.cat([PILtoTorch(im, resolution) for im in cam_info.image.split()[:3]], dim=0) 45 | loaded_mask = PILtoTorch(cam_info.image.split()[3], resolution) 46 | gt_image = resized_image_rgb 47 | else: 48 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 49 | loaded_mask = None 50 | gt_image = resized_image_rgb 51 | 52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 54 | image=gt_image, gt_alpha_mask=loaded_mask, 55 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 56 | 57 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 58 | camera_list = [] 59 | 60 | for id, c in enumerate(cam_infos): 61 | camera_list.append(loadCam(args, id, c, resolution_scale)) 62 | 63 | return camera_list 64 | 65 | def camera_to_JSON(id, camera : Camera): 66 | Rt = np.zeros((4, 4)) 67 | Rt[:3, :3] = camera.R.transpose() 68 | Rt[:3, 3] = camera.T 69 | Rt[3, 3] = 1.0 70 | 71 | W2C = np.linalg.inv(Rt) 72 | pos = W2C[:3, 3] 73 | rot = W2C[:3, :3] 74 | serializable_array_2d = [x.tolist() for x in rot] 75 | camera_entry = { 76 | 'id' : id, 77 | 'img_name' : camera.image_name, 78 | 'width' : camera.width, 79 | 'height' : camera.height, 80 | 'position': pos.tolist(), 81 | 'rotation': serializable_array_2d, 82 | 'fy' : fov2focal(camera.FovY, camera.height), 83 | 'fx' : fov2focal(camera.FovX, camera.width) 84 | } 85 | return camera_entry -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = 0 #s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " 
[{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | 135 | 136 | 137 | 138 | def create_rotation_matrix_from_direction_vector_batch(direction_vectors): 139 | # Normalize the batch of direction vectors 140 | direction_vectors = direction_vectors / torch.norm(direction_vectors, dim=-1, keepdim=True) 141 | # Create a batch of arbitrary vectors that are not collinear with the direction vectors 142 | v1 = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32).to(direction_vectors.device).expand(direction_vectors.shape[0], -1).clone() 143 | is_collinear = torch.all(torch.abs(direction_vectors - v1) < 1e-5, dim=-1) 144 | v1[is_collinear] = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32).to(direction_vectors.device) 145 | 146 | # Calculate the first orthogonal vectors 147 | v1 = torch.cross(direction_vectors, v1) 148 | v1 = v1 / (torch.norm(v1, dim=-1, keepdim=True)) 149 | # Calculate the second orthogonal vectors by taking the cross product 150 | v2 = torch.cross(direction_vectors, v1) 151 | v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True)) 152 | # Create the batch of rotation matrices with the direction vectors as the last columns 153 | rotation_matrices = torch.stack((v1, v2, direction_vectors), dim=-1) 154 | return rotation_matrices 155 | 156 | # from kornia.geometry import conversions 157 | # def normal_to_rotation(normals): 158 | # rotations = create_rotation_matrix_from_direction_vector_batch(normals) 159 | # rotations = conversions.rotation_matrix_to_quaternion(rotations,eps=1e-5, order=conversions.QuaternionCoeffOrder.WXYZ) 160 | # return rotations 161 | 162 | 163 | def colormap(img, cmap='jet'): 164 | import matplotlib.pyplot as plt 165 | W, H = img.shape[:2] 166 | dpi = 300 167 | fig, ax = plt.subplots(1, figsize=(H/dpi, W/dpi), dpi=dpi) 168 | im = ax.imshow(img, cmap=cmap) 169 | ax.set_axis_off() 170 | fig.colorbar(im, ax=ax) 171 | fig.tight_layout() 172 | fig.canvas.draw() 173 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 174 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 175 | img = torch.from_numpy(data / 255.).float().permute(2,0,1) 176 | plt.close() 177 | return img -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | 28 | def smooth_loss(disp, img): 29 | grad_disp_x = torch.abs(disp[:,1:-1, :-2] + disp[:,1:-1,2:] - 2 * disp[:,1:-1,1:-1]) 30 | grad_disp_y = torch.abs(disp[:,:-2, 1:-1] + disp[:,2:,1:-1] - 2 * disp[:,1:-1,1:-1]) 31 | grad_img_x = torch.mean(torch.abs(img[:, 1:-1, :-2] - img[:, 1:-1, 2:]), 0, keepdim=True) * 0.5 32 | grad_img_y = torch.mean(torch.abs(img[:, :-2, 1:-1] - img[:, 2:, 1:-1]), 0, keepdim=True) * 0.5 33 | grad_disp_x *= torch.exp(-grad_img_x) 34 | grad_disp_y *= torch.exp(-grad_img_y) 35 | return grad_disp_x.mean() + grad_disp_y.mean() 36 | 37 | def create_window(window_size, channel): 38 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 39 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 40 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 41 | return window 42 | 43 | def ssim(img1, img2, window_size=11, size_average=True): 44 | channel = img1.size(-3) 45 | window = create_window(window_size, channel) 46 | 47 | if img1.is_cuda: 48 | window = window.cuda(img1.get_device()) 49 | window = window.type_as(img1) 50 | 51 | return _ssim(img1, img2, window, window_size, channel, size_average) 52 | 53 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 54 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 55 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 56 | 57 | mu1_sq = mu1.pow(2) 58 | mu2_sq = mu2.pow(2) 59 | mu1_mu2 = mu1 * mu2 60 | 61 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 62 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 63 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 64 | 65 | C1 = 0.01 ** 2 66 | C2 = 0.03 ** 2 67 | 68 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 69 | 70 | if size_average: 71 | return ssim_map.mean() 72 | else: 73 | return ssim_map.mean(0) 74 | 75 | -------------------------------------------------------------------------------- /utils/mcube_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, ShanghaiTech 3 | # SVIP research group, https://github.com/svip-lab 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact huangbb@shanghaitech.edu.cn 10 | # 11 | 12 | import numpy as np 13 | import torch 14 | import trimesh 15 | from skimage import measure 16 | # modified from here https://github.com/autonomousvision/sdfstudio/blob/370902a10dbef08cb3fe4391bd3ed1e227b5c165/nerfstudio/utils/marching_cubes.py#L201 17 | def marching_cubes_with_contraction( 18 | sdf, 19 | resolution=512, 20 | bounding_box_min=(-1.0, -1.0, -1.0), 21 | bounding_box_max=(1.0, 1.0, 1.0), 22 | return_mesh=False, 23 | level=0, 24 | simplify_mesh=True, 25 | inv_contraction=None, 26 | max_range=32.0, 27 | ): 28 | assert resolution % 512 == 0 29 | 30 | resN = resolution 31 | cropN = 512 32 | level = 0 33 | N = resN // cropN 34 | 35 | grid_min = bounding_box_min 36 | grid_max = bounding_box_max 37 | xs = np.linspace(grid_min[0], grid_max[0], N + 1) 38 | ys = np.linspace(grid_min[1], grid_max[1], N + 1) 39 | zs = np.linspace(grid_min[2], grid_max[2], N + 1) 40 | 41 | meshes = [] 42 | for i in range(N): 43 | for j in range(N): 44 | for k in range(N): 45 | print(i, j, k) 46 | x_min, x_max = xs[i], xs[i + 1] 47 | y_min, y_max = ys[j], ys[j + 1] 48 | z_min, z_max = zs[k], zs[k + 1] 49 | 50 | x = np.linspace(x_min, x_max, cropN) 51 | y = np.linspace(y_min, y_max, cropN) 52 | z = np.linspace(z_min, z_max, cropN) 53 | 54 | xx, yy, zz = np.meshgrid(x, y, z, indexing="ij") 55 | points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 56 | 57 | @torch.no_grad() 58 | def evaluate(points): 59 | z = [] 60 | for _, pnts in enumerate(torch.split(points, 256**3, dim=0)): 61 | z.append(sdf(pnts)) 62 | z = torch.cat(z, axis=0) 63 | return z 64 | 65 | # construct point pyramids 66 | points = points.reshape(cropN, cropN, cropN, 3) 67 | points = points.reshape(-1, 3) 68 | pts_sdf = evaluate(points.contiguous()) 69 | z = pts_sdf.detach().cpu().numpy() 70 | if not (np.min(z) > level or np.max(z) < level): 71 | z = z.astype(np.float32) 72 | verts, faces, normals, _ = measure.marching_cubes( 73 | volume=z.reshape(cropN, cropN, cropN), 74 | level=level, 75 | spacing=( 76 | (x_max - x_min) / (cropN - 1), 77 | (y_max - y_min) / (cropN - 1), 78 | (z_max - z_min) / (cropN - 1), 79 | ), 80 | ) 81 | verts = verts + np.array([x_min, y_min, z_min]) 82 | meshcrop = trimesh.Trimesh(verts, faces, normals) 83 | meshes.append(meshcrop) 84 | 85 | print("finished one block") 86 | 87 | combined = trimesh.util.concatenate(meshes) 88 | combined.merge_vertices(digits_vertex=6) 89 | 90 | # inverse contraction and clipping the points range 91 | if inv_contraction is not None: 92 | combined.vertices = inv_contraction(torch.from_numpy(combined.vertices).float().cuda()).cpu().numpy() 93 | combined.vertices = np.clip(combined.vertices, -max_range, max_range) 94 | 95 | return combined -------------------------------------------------------------------------------- /utils/mesh_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, ShanghaiTech 3 | # SVIP research group, https://github.com/svip-lab 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact huangbb@shanghaitech.edu.cn 10 | # 11 | import os 12 | from functools import partial 13 | from statistics import mean, stdev 14 | 15 | import cv2 16 | import numpy as np 17 | import open3d as o3d 18 | import torch 19 | from tqdm import tqdm 20 | 21 | from utils.render_utils import save_img_u8 22 | 23 | 24 | def post_process_mesh(mesh, cluster_to_keep=1000): 25 | """ 26 | Post-process a mesh to filter out floaters and disconnected parts 27 | """ 28 | import copy 29 | print("post processing the mesh to have {} clusters".format(cluster_to_keep)) 30 | mesh_0 = copy.deepcopy(mesh) 31 | with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm: 32 | triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles()) 33 | 34 | triangle_clusters = np.asarray(triangle_clusters) 35 | cluster_n_triangles = np.asarray(cluster_n_triangles) 36 | cluster_area = np.asarray(cluster_area) 37 | n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep] 38 | n_cluster = max(n_cluster, 50) # filter out clusters smaller than 50 triangles 39 | triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster 40 | mesh_0.remove_triangles_by_mask(triangles_to_remove) 41 | mesh_0.remove_unreferenced_vertices() 42 | mesh_0.remove_degenerate_triangles() 43 | print("num vertices raw {}".format(len(mesh.vertices))) 44 | print("num vertices post {}".format(len(mesh_0.vertices))) 45 | return mesh_0 46 | 47 | 48 | def to_cam_open3d(viewpoint_stack): 49 | camera_traj = [] 50 | for i, viewpoint_cam in enumerate(viewpoint_stack): 51 | W = viewpoint_cam.image_width 52 | H = viewpoint_cam.image_height 53 | ndc2pix = torch.tensor([ 54 | [W / 2, 0, 0, (W - 1) / 2], 55 | [0, H / 2, 0, (H - 1) / 2], 56 | [0, 0, 0, 1]]).float().cuda().T 57 | intrins = (viewpoint_cam.projection_matrix @ ndc2pix)[:3, :3].T 58 | intrinsic = o3d.camera.PinholeCameraIntrinsic( 59 | width=viewpoint_cam.image_width, 60 | height=viewpoint_cam.image_height, 61 | cx=intrins[0, 2].item(), 62 | cy=intrins[1, 2].item(), 63 | fx=intrins[0, 0].item(), 64 | fy=intrins[1, 1].item() 65 | ) 66 | 67 | extrinsic = np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy()) 68 | camera = o3d.camera.PinholeCameraParameters() 69 | camera.extrinsic = extrinsic 70 | camera.intrinsic = intrinsic 71 | camera_traj.append(camera) 72 | 73 | return camera_traj 74 | 75 | 76 | class GaussianExtractor(object): 77 | def __init__(self, gaussians, render, pipe, bg_color=None, additional_return=True): 78 | """ 79 | a class that extracts attributes of a scene represented by 2DGS 80 | 81 | Usage example: 82 | >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe) 83 | >>> gaussExtractor.reconstruction(view_points) 84 | >>> mesh = gaussExtractor.extract_mesh_bounded(...)
85 | """ 86 | if bg_color is None: 87 | bg_color = [0, 0, 0] 88 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 89 | self.gaussians = gaussians 90 | self.render = partial(render, pipe=pipe, bg_color=background, additional_return=additional_return) 91 | self._additional_return = additional_return 92 | self.clean() 93 | 94 | @torch.no_grad() 95 | def clean(self): 96 | self.depthmaps = [] 97 | # self.alphamaps = [] 98 | self.rgbmaps = [] 99 | # self.normals = [] 100 | # self.depth_normals = [] 101 | self.viewpoint_stack = [] 102 | 103 | @torch.no_grad() 104 | def reconstruction(self, viewpoint_stack): 105 | """ 106 | reconstruct radiance field given cameras 107 | """ 108 | self.clean() 109 | self.viewpoint_stack = viewpoint_stack 110 | times = [] 111 | if len(self.viewpoint_stack) > 1: 112 | iterator = tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields") 113 | else: 114 | iterator = enumerate(self.viewpoint_stack) 115 | 116 | for i, viewpoint_cam in iterator: 117 | render_pkg = self.render(viewpoint_cam, self.gaussians) 118 | times.append(render_pkg['fps']) 119 | rgb = render_pkg['render'] 120 | self.rgbmaps.append(rgb.cpu()) 121 | if self._additional_return: 122 | alpha = render_pkg['rend_alpha'] 123 | normal = torch.nn.functional.normalize(render_pkg['rend_normal'], dim=0) 124 | depth = render_pkg['surf_depth'] 125 | depth_normal = render_pkg['surf_normal'] 126 | self.depthmaps.append(depth.cpu()) 127 | # self.alphamaps.append(alpha.cpu()) 128 | # self.normals.append(normal.cpu()) 129 | # self.depth_normals.append(depth_normal.cpu()) 130 | 131 | self.times = times 132 | mean_time = mean(times) 133 | std_time = 0 134 | if len(times) > 1: 135 | std_time = stdev(times) 136 | print("FPS:", mean_time, " std:", std_time) 137 | # self.rgbmaps = torch.stack(self.rgbmaps, dim=0) 138 | # self.depthmaps = torch.stack(self.depthmaps, dim=0) 139 | # self.alphamaps = torch.stack(self.alphamaps, dim=0) 140 | # self.depth_normals = torch.stack(self.depth_normals, dim=0) 141 | self.estimate_bounding_sphere() 142 | 143 | return mean_time, std_time 144 | 145 | def estimate_bounding_sphere(self): 146 | """ 147 | Estimate the bounding sphere given camera pose 148 | """ 149 | from utils.render_utils import focus_point_fn 150 | torch.cuda.empty_cache() 151 | c2ws = np.array( 152 | [np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack]) 153 | poses = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1]) 154 | center = (focus_point_fn(poses)) 155 | self.radius = np.linalg.norm(c2ws[:, :3, 3] - center, axis=-1).min() 156 | self.center = torch.from_numpy(center).float().cuda() 157 | print(f"The estimated bounding radius is {self.radius:.2f}") 158 | print(f"Use at least {2.0 * self.radius:.2f} for depth_trunc") 159 | 160 | @torch.no_grad() 161 | def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True): 162 | """ 163 | Perform TSDF fusion given a fixed depth range, used in the paper. 
164 | 165 | voxel_size: the voxel size of the volume 166 | sdf_trunc: truncation value 167 | depth_trunc: maximum depth range, should depend on the scene's scale 168 | mask_backgrond: whether to mask the background, only works when the dataset has masks 169 | 170 | return o3d.mesh 171 | """ 172 | print("Running tsdf volume integration ...") 173 | print(f'voxel_size: {voxel_size}') 174 | print(f'sdf_trunc: {sdf_trunc}') 175 | print(f'depth_trunc: {depth_trunc}') 176 | 177 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 178 | voxel_length=voxel_size, 179 | sdf_trunc=sdf_trunc, 180 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 181 | ) 182 | 183 | for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"): 184 | rgb = self.rgbmaps[i] 185 | depth = self.depthmaps[i] 186 | 187 | # if we have mask provided, use it 188 | if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None): 189 | depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0 190 | 191 | # make open3d rgbd 192 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 193 | o3d.geometry.Image( 194 | np.asarray(np.clip(rgb.permute(1, 2, 0).cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)), 195 | o3d.geometry.Image(np.asarray(depth.permute(1, 2, 0).cpu().numpy(), order="C")), 196 | depth_trunc=depth_trunc, convert_rgb_to_intensity=False, 197 | depth_scale=1.0 198 | ) 199 | 200 | volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic) 201 | 202 | mesh = volume.extract_triangle_mesh() 203 | return mesh 204 | 205 | @torch.no_grad() 206 | def extract_mesh_unbounded(self, resolution=1024): 207 | """ 208 | Experimental feature: extracts meshes from unbounded scenes, not fully tested across datasets. 209 | return o3d.mesh 210 | """ 211 | 212 | def contract(x): 213 | mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None] 214 | return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag)) 215 | 216 | def uncontract(y): 217 | mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None] 218 | return torch.where(mag < 1, y, (1 / (2 - mag) * (y / mag))) 219 | 220 | def compute_sdf_perframe(i, points, depthmap, rgbmap, viewpoint_cam): 221 | """ 222 | compute per frame sdf 223 | """ 224 | new_points = torch.cat([points, torch.ones_like(points[..., :1])], 225 | dim=-1) @ viewpoint_cam.full_proj_transform 226 | z = new_points[..., -1:] 227 | pix_coords = (new_points[..., :2] / new_points[..., -1:]) 228 | mask_proj = ((pix_coords > -1.) & (pix_coords < 1.) & (z > 0)).all(dim=-1) 229 | sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None], 230 | mode='bilinear', padding_mode='border', 231 | align_corners=True).reshape(-1, 1) 232 | sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear', 233 | padding_mode='border', align_corners=True).reshape(3, -1).T 234 | sdf = (sampled_depth - z) 235 | return sdf, sampled_rgb, mask_proj 236 | 237 | def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False): 238 | """ 239 | Fuse all frames, performing adaptive SDF truncation in the contracted space.
240 | """ 241 | if inv_contraction is not None: 242 | mask = torch.linalg.norm(samples, dim=-1) > 1 243 | # adaptive sdf_truncation 244 | sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0]) 245 | sdf_trunc[mask] *= 1 / (2 - torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9)) 246 | samples = inv_contraction(samples) 247 | else: 248 | sdf_trunc = 5 * voxel_size 249 | 250 | tsdfs = torch.ones_like(samples[:, 0]) * 1 251 | rgbs = torch.zeros((samples.shape[0], 3)).cuda() 252 | 253 | weights = torch.ones_like(samples[:, 0]) 254 | for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="TSDF integration progress"): 255 | sdf, rgb, mask_proj = compute_sdf_perframe(i, samples, 256 | depthmap=self.depthmaps[i], 257 | rgbmap=self.rgbmaps[i], 258 | viewpoint_cam=self.viewpoint_stack[i], 259 | ) 260 | 261 | # volume integration 262 | sdf = sdf.flatten() 263 | mask_proj = mask_proj & (sdf > -sdf_trunc) 264 | sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj] 265 | w = weights[mask_proj] 266 | wp = w + 1 267 | tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp 268 | rgbs[mask_proj] = (rgbs[mask_proj] * w[:, None] + rgb[mask_proj]) / wp[:, None] 269 | # update weight 270 | weights[mask_proj] = wp 271 | 272 | if return_rgb: 273 | return tsdfs, rgbs 274 | 275 | return tsdfs 276 | 277 | normalize = lambda x: (x - self.center) / self.radius 278 | unnormalize = lambda x: (x * self.radius) + self.center 279 | inv_contraction = lambda x: unnormalize(uncontract(x)) 280 | 281 | N = resolution 282 | voxel_size = (self.radius * 2 / N) 283 | print(f"Computing sdf gird resolution {N} x {N} x {N}") 284 | print(f"Define the voxel_size as {voxel_size}") 285 | sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size) 286 | from utils.mcube_utils import marching_cubes_with_contraction 287 | R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy() 288 | R = np.quantile(R, q=0.95) 289 | R = min(R + 0.01, 1.9) 290 | 291 | mesh = marching_cubes_with_contraction( 292 | sdf=sdf_function, 293 | bounding_box_min=(-R, -R, -R), 294 | bounding_box_max=(R, R, R), 295 | level=0, 296 | resolution=N, 297 | inv_contraction=inv_contraction, 298 | ) 299 | 300 | # coloring the mesh 301 | torch.cuda.empty_cache() 302 | mesh = mesh.as_open3d 303 | print("texturing mesh ... 
") 304 | _, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None, 305 | voxel_size=voxel_size, return_rgb=True) 306 | mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy()) 307 | return mesh 308 | 309 | @torch.no_grad() 310 | def export_image(self, path, export_gt=True, print_fps=False): 311 | render_path = os.path.join(path, "renders") 312 | vis_path = os.path.join(path, "vis") 313 | os.makedirs(render_path, exist_ok=True) 314 | os.makedirs(vis_path, exist_ok=True) 315 | if export_gt: 316 | gts_path = os.path.join(path, "gt") 317 | os.makedirs(gts_path, exist_ok=True) 318 | 319 | for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"): 320 | if export_gt: 321 | gt = viewpoint_cam.original_image[0:3, :, :] 322 | save_img_u8(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 323 | 324 | image = self.rgbmaps[idx] 325 | if print_fps: 326 | fps = '{:4d}'.format(int(self.times[idx])) 327 | image = image.numpy() 328 | image = np.transpose(image, (1, 2, 0)).copy() 329 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 330 | 1, (0, 0, 0), 3, 2) 331 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 332 | 1, (1, 1, 1), 1, 2) 333 | image = np.transpose(image, (2, 0, 1)) 334 | image = torch.tensor(image) 335 | save_img_u8(image, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 336 | if self._additional_return: 337 | depth = self.depthmaps[idx][0] 338 | #save_img_f32(depth, os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff")) 339 | depth[depth > 30] = 30 340 | depth = depth**0.1 341 | depth = 1 - (depth - depth.min()) / (depth.max() - depth.min()) 342 | save_img_u8(depth, os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".png")) 343 | #save_img_u8(self.normals[idx] * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png")) 344 | #save_img_u8(self.depth_normals[idx] * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png")) 345 | -------------------------------------------------------------------------------- /utils/point_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os, cv2 6 | import matplotlib.pyplot as plt 7 | import math 8 | 9 | def depths_to_points(view, depthmap): 10 | c2w = (view.world_view_transform.T).inverse() 11 | W, H = view.image_width, view.image_height 12 | fx = W / (2 * math.tan(view.FoVx / 2.)) 13 | fy = H / (2 * math.tan(view.FoVy / 2.)) 14 | intrins = torch.tensor( 15 | [[fx, 0., W/2.], 16 | [0., fy, H/2.], 17 | [0., 0., 1.0]] 18 | ).float().cuda() 19 | grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy') 20 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3) 21 | rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T 22 | rays_o = c2w[:3,3] 23 | points = depthmap.reshape(-1, 1) * rays_d + rays_o 24 | return points 25 | 26 | def depth_to_normal(view, depth): 27 | """ 28 | view: view camera 29 | depth: depthmap 30 | """ 31 | points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3) 32 | output = torch.zeros_like(points) 33 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 34 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 35 | normal_map = 
torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 36 | output[1:-1, 1:-1, :] = normal_map 37 | return output -------------------------------------------------------------------------------- /utils/reconstruction_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, ShanghaiTech 3 | # SVIP research group, https://github.com/svip-lab 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact huangbb@shanghaitech.edu.cn 10 | # 11 | import os 12 | from functools import partial 13 | from statistics import mean, stdev 14 | 15 | import cv2 16 | import numpy as np 17 | import torch 18 | from tqdm import tqdm 19 | 20 | from utils.render_utils import save_img_u8 21 | 22 | 23 | class GaussianExtractor(object): 24 | def __init__(self, gaussians, render, pipe, bg_color=None, additional_return=True): 25 | """ 26 | a class that extracts attributes of a scene represented by 2DGS 27 | 28 | Usage example: 29 | >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe) 30 | >>> gaussExtractor.reconstruction(view_points) 31 | >>> gaussExtractor.export_image(...) 32 | """ 33 | if bg_color is None: 34 | bg_color = [0, 0, 0] 35 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 36 | self.gaussians = gaussians 37 | self.render = partial(render, pipe=pipe, bg_color=background, additional_return=additional_return) 38 | self._additional_return = additional_return 39 | self.clean() 40 | 41 | @torch.no_grad() 42 | def clean(self): 43 | self.depthmaps = [] 44 | self.alphamaps = [] 45 | self.rgbmaps = [] 46 | self.normals = [] 47 | self.depth_normals = [] 48 | self.viewpoint_stack = [] 49 | self.times = [] 50 | 51 | @torch.no_grad() 52 | def reconstruction(self, viewpoint_stack): 53 | """ 54 | reconstruct radiance field given cameras 55 | """ 56 | self.clean() 57 | self.viewpoint_stack = viewpoint_stack 58 | times = [] 59 | if len(self.viewpoint_stack) > 1: 60 | iterator = tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields") 61 | else: 62 | iterator = enumerate(self.viewpoint_stack) 63 | 64 | for i, viewpoint_cam in iterator: 65 | render_pkg = self.render(viewpoint_cam, self.gaussians) 66 | times.append(render_pkg['fps']) 67 | rgb = render_pkg['render'] 68 | self.rgbmaps.append(rgb.cpu()) 69 | if self._additional_return: 70 | alpha = render_pkg['rend_alpha'] 71 | normal = torch.nn.functional.normalize(render_pkg['rend_normal'], dim=0) 72 | depth = render_pkg['surf_depth'] 73 | depth_normal = render_pkg['surf_normal'] 74 | self.depthmaps.append(depth.cpu()) 75 | self.alphamaps.append(alpha.cpu()) 76 | self.normals.append(normal.cpu()) 77 | self.depth_normals.append(depth_normal.cpu()) 78 | 79 | self.times = times 80 | mean_time = mean(times) 81 | std_time = 0 82 | if len(times) > 1: 83 | std_time = stdev(times) 84 | print("FPS:", mean_time, " std:", std_time) 85 | #self.rgbmaps = torch.stack(self.rgbmaps, dim=0) 86 | if self._additional_return: 87 | self.depthmaps = torch.stack(self.depthmaps, dim=0) 88 | self.alphamaps = torch.stack(self.alphamaps, dim=0) 89 | self.depth_normals = torch.stack(self.depth_normals, dim=0) 90 | 91 | return mean_time, std_time 92 | 93 | @torch.no_grad() 94 | def export_image(self, path, export_gt=True, print_fps=False): 95 | render_path = os.path.join(path, "renders") 96 | os.makedirs(render_path, exist_ok=True) 97 | if 
export_gt: 98 | gts_path = os.path.join(path, "gt") 99 | os.makedirs(gts_path, exist_ok=True) 100 | 101 | for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"): 102 | if export_gt: 103 | gt = viewpoint_cam.original_image[0:3, :, :] 104 | save_img_u8(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 105 | 106 | image = self.rgbmaps[idx] 107 | if print_fps: 108 | fps = '{:4d}'.format(int(self.times[idx])) 109 | image = image.numpy() 110 | image = np.transpose(image, (1, 2, 0)).copy() 111 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 112 | 1, (0, 0, 0), 3, 2) 113 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 114 | 1, (1, 1, 1), 1, 2) 115 | image = np.transpose(image, (2, 0, 1)) 116 | image = torch.tensor(image) 117 | save_img_u8(image, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 118 | -------------------------------------------------------------------------------- /utils/render_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import copy 16 | import os 17 | from typing import Tuple 18 | 19 | import mediapy as media 20 | import numpy as np 21 | import torch 22 | import torchvision 23 | from PIL import Image 24 | from tqdm import tqdm 25 | 26 | 27 | def normalize(x: np.ndarray) -> np.ndarray: 28 | """Normalization helper function.""" 29 | return x / np.linalg.norm(x) 30 | 31 | def pad_poses(p: np.ndarray) -> np.ndarray: 32 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" 33 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) 34 | return np.concatenate([p[..., :3, :4], bottom], axis=-2) 35 | 36 | 37 | def unpad_poses(p: np.ndarray) -> np.ndarray: 38 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" 39 | return p[..., :3, :4] 40 | 41 | 42 | def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 43 | """Recenter poses around the origin.""" 44 | cam2world = average_pose(poses) 45 | transform = np.linalg.inv(pad_poses(cam2world)) 46 | poses = transform @ pad_poses(poses) 47 | return unpad_poses(poses), transform 48 | 49 | 50 | def average_pose(poses: np.ndarray) -> np.ndarray: 51 | """New pose using average position, z-axis, and up vector of input poses.""" 52 | position = poses[:, :3, 3].mean(0) 53 | z_axis = poses[:, :3, 2].mean(0) 54 | up = poses[:, :3, 1].mean(0) 55 | cam2world = viewmatrix(z_axis, up, position) 56 | return cam2world 57 | 58 | def viewmatrix(lookdir: np.ndarray, up: np.ndarray, 59 | position: np.ndarray) -> np.ndarray: 60 | """Construct lookat view matrix.""" 61 | vec2 = normalize(lookdir) 62 | vec0 = normalize(np.cross(up, vec2)) 63 | vec1 = normalize(np.cross(vec2, vec0)) 64 | m = np.stack([vec0, vec1, vec2, position], axis=1) 65 | return m 66 | 67 | def focus_point_fn(poses: np.ndarray) -> np.ndarray: 68 | 
"""Calculate nearest point to all focal axes in poses.""" 69 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 70 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 71 | mt_m = np.transpose(m, [0, 2, 1]) @ m 72 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 73 | return focus_pt 74 | 75 | def transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 76 | """Transforms poses so principal components lie on XYZ axes. 77 | 78 | Args: 79 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 80 | 81 | Returns: 82 | A tuple (poses, transform), with the transformed poses and the applied 83 | camera_to_world transforms. 84 | """ 85 | t = poses[:, :3, 3] 86 | t_mean = t.mean(axis=0) 87 | t = t - t_mean 88 | 89 | eigval, eigvec = np.linalg.eig(t.T @ t) 90 | # Sort eigenvectors in order of largest to smallest eigenvalue. 91 | inds = np.argsort(eigval)[::-1] 92 | eigvec = eigvec[:, inds] 93 | rot = eigvec.T 94 | if np.linalg.det(rot) < 0: 95 | rot = np.diag(np.array([1, 1, -1])) @ rot 96 | 97 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 98 | poses_recentered = unpad_poses(transform @ pad_poses(poses)) 99 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) 100 | 101 | # Flip coordinate system if z component of y-axis is negative 102 | if poses_recentered.mean(axis=0)[2, 1] < 0: 103 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 104 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 105 | 106 | return poses_recentered, transform 107 | # points = np.random.rand(3,100) 108 | # points_h = np.concatenate((points,np.ones_like(points[:1])), axis=0) 109 | # (poses_recentered @ points_h)[0] 110 | # (transform @ pad_poses(poses) @ points_h)[0,:3] 111 | # import pdb; pdb.set_trace() 112 | 113 | # # Just make sure it's it in the [-1, 1]^3 cube 114 | # scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3])) 115 | # poses_recentered[:, :3, 3] *= scale_factor 116 | # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform 117 | 118 | # return poses_recentered, transform 119 | 120 | def generate_ellipse_path(poses: np.ndarray, 121 | n_frames: int = 120, 122 | const_speed: bool = True, 123 | z_variation: float = 0., 124 | z_phase: float = 0.) -> np.ndarray: 125 | """Generate an elliptical render path based on the given poses.""" 126 | # Calculate the focal point for the path (cameras point toward this). 127 | center = focus_point_fn(poses) 128 | # Path height sits at z=0 (in middle of zero-mean capture pattern). 129 | offset = np.array([center[0], center[1], 0]) 130 | 131 | # Calculate scaling for ellipse axes based on input camera positions. 132 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 133 | # Use ellipse that is symmetric about the focal point in xy. 134 | low = -sc + offset 135 | high = sc + offset 136 | # Optional height variation need not be symmetric 137 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 138 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 139 | 140 | def get_positions(theta): 141 | # Interpolate between bounds with trig functions to get ellipse in x-y. 142 | # Optionally also interpolate in z to change camera height along path. 
143 | return np.stack([ 144 | low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5), 145 | low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5), 146 | z_variation * (z_low[2] + (z_high - z_low)[2] * 147 | (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)), 148 | ], -1) 149 | 150 | theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True) 151 | positions = get_positions(theta) 152 | 153 | #if const_speed: 154 | 155 | # # Resample theta angles so that the velocity is closer to constant. 156 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 157 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) 158 | # positions = get_positions(theta) 159 | 160 | # Throw away duplicated last position. 161 | positions = positions[:-1] 162 | 163 | # Set path's up vector to axis closest to average of input pose up vectors. 164 | avg_up = poses[:, :3, 1].mean(0) 165 | avg_up = avg_up / np.linalg.norm(avg_up) 166 | ind_up = np.argmax(np.abs(avg_up)) 167 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 168 | 169 | return np.stack([viewmatrix(p - center, up, p) for p in positions]) 170 | 171 | 172 | def generate_path(viewpoint_cameras, n_frames=480): 173 | c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras]) 174 | pose = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1]) 175 | pose_recenter, colmap_to_world_transform = transform_poses_pca(pose) 176 | 177 | # generate new poses 178 | new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames) 179 | # warp back to original scale 180 | new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses) 181 | 182 | traj = [] 183 | for c2w in new_poses: 184 | c2w = c2w @ np.diag([1, -1, -1, 1]) 185 | cam = copy.deepcopy(viewpoint_cameras[0]) 186 | cam.image_height = int(cam.image_height / 2) * 2 187 | cam.image_width = int(cam.image_width / 2) * 2 188 | cam.world_view_transform = torch.from_numpy(np.linalg.inv(c2w).T).float().cuda() 189 | cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0) 190 | cam.camera_center = cam.world_view_transform.inverse()[3, :3] 191 | traj.append(cam) 192 | 193 | return traj 194 | 195 | def load_img(pth: str) -> np.ndarray: 196 | """Load an image and cast to float32.""" 197 | with open(pth, 'rb') as f: 198 | image = np.array(Image.open(f), dtype=np.float32) 199 | return image 200 | 201 | 202 | def create_videos(base_dir, input_dir, out_name, num_frames=480): 203 | """Creates videos out of the images saved to disk.""" 204 | # Prefix for the output video files. 205 | video_prefix = f'{out_name}' 206 | 207 | zpad = max(5, len(str(num_frames - 1))) 208 | idx_to_str = lambda idx: str(idx).zfill(zpad) 209 | 210 | os.makedirs(base_dir, exist_ok=True) 211 | 212 | # Load one example frame to get image shape and depth range. 
213 | rgb_file = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.png') 214 | rgb_frame = load_img(rgb_file) 215 | shape = rgb_frame.shape 216 | print(f'Video shape is {shape[:2]}') 217 | 218 | video_kwargs = { 219 | 'shape': shape[:2], 220 | 'codec': 'h264', 221 | 'fps': 30, 222 | 'crf': 1, 223 | } 224 | 225 | for k in ['color']: 226 | video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4') 227 | input_format = 'rgb' 228 | file_ext = 'png' 229 | 230 | if k == 'color': 231 | file0 = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.{file_ext}') 232 | 233 | if not os.path.exists(file0): 234 | print(f'Images missing for tag {k}') 235 | continue 236 | print(f'Making video {video_file}...') 237 | with media.VideoWriter( 238 | video_file, **video_kwargs, input_format=input_format) as writer: 239 | for idx in tqdm(range(num_frames)): 240 | img_file = os.path.join(input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}') 241 | 242 | if not os.path.exists(img_file): 243 | raise ValueError(f'Image file {img_file} does not exist.') 244 | img = load_img(img_file) 245 | img = img / 255. 246 | 247 | frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8) 248 | writer.add_image(frame) 249 | idx += 1 250 | 251 | def save_img_u8(img, pth): 252 | """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" 253 | torchvision.utils.save_image(img, pth) 254 | 255 | def save_img_f32(depthmap, pth): 256 | """Save an image (probably a depthmap) to disk as a float32 TIFF.""" 257 | with open(pth, 'wb') as f: 258 | Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF') -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE.
23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. 
equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from arguments import ModelParams, PipelineParams, get_combined_args 9 | from gaussian_renderer import GaussianModel 10 | from gaussian_renderer import render 11 | from scene import Scene 12 | from utils.general_utils import build_rotation 13 | from utils.reconstruction_utils import GaussianExtractor 14 | 15 | if __name__ == "__main__": 16 | # Set up command line argument parser 17 | parser = ArgumentParser(description="Testing script parameters") 18 | model = ModelParams(parser, sentinel=True) 19 | pipeline = PipelineParams(parser) 20 | parser.add_argument("--iteration", default=-1, type=int) 21 | args = get_combined_args(parser) 22 | print("Rendering " + args.model_path) 23 | 24 | control_panel = cv2.imread("assets/control_panel.png")[..., ::-1].astype(np.float32) / 255. 25 | 26 | dataset, iteration, pipe = model.extract(args), args.iteration, pipeline.extract(args) 27 | gaussians = GaussianModel(dataset.sh_degree, texture_preproc=True) 28 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 29 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 30 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 31 | 32 | train_dir = os.path.join(args.model_path, 'train', "ours_{}".format(scene.loaded_iter)) 33 | test_dir = os.path.join(args.model_path, 'test', "ours_{}".format(scene.loaded_iter)) 34 | gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color, additional_return=False) 35 | 36 | speed_data = {"points": len(gaussians.get_xyz)} 37 | 38 | idx = 0 39 | cameras = scene.getTestCameras()[idx: idx+1].copy() 40 | frame_num = 0 41 | while True: 42 | mean_time, std_time = gaussExtractor.reconstruction(cameras) 43 | render = gaussExtractor.rgbmaps[0].detach().cpu().numpy() 44 | render = np.transpose(render, (1, 2, 0)).copy() 45 | if frame_num == 0: 46 | scale = render.shape[1] / control_panel.shape[1] 47 | control_panel = cv2.resize(control_panel, None, fx=scale, fy=scale) 48 | 49 | if frame_num > 5: 50 | mean_time = int(mean_time) 51 | cv2.putText(render, 'FPS: ' + str(mean_time), (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 52 | 1, (0, 0, 0), 3, 2) 53 | cv2.putText(render, 'FPS: ' + str(mean_time),(10, 50), cv2.FONT_HERSHEY_SIMPLEX, 54 | 1,(255, 255, 255),1,2) 55 | 56 | render = cv2.vconcat([render, control_panel]) 57 | cv2.imshow("Render", render[..., ::-1]) 58 | key = cv2.waitKey(-1) & 0b11111111 59 | 60 | speed_t = 0.03 61 | speed_r = speed_t / 2.0 62 | if key == ord("q"): 63 | break 64 | if key == ord("a"): 65 | cameras[0].world_view_transform[3, 0] += speed_t 66 | if key == ord("d"): 67 | cameras[0].world_view_transform[3, 0] -= speed_t 68 | if key == ord("w"): 69 | cameras[0].world_view_transform[3, 2] -= speed_t 70 | if key == ord("s"): 71 | cameras[0].world_view_transform[3, 2] += speed_t 72 | if key == ord("e"): 73 | cameras[0].world_view_transform[3, 1] += 
speed_t 74 | if key == ord("f"): 75 | cameras[0].world_view_transform[3, 1] -= speed_t 76 | 77 | if key == ord("j"): 78 | R = build_rotation(torch.tensor([[1-speed_r, -speed_r, 0, 0]]).cuda())[0] 79 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R) 80 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R) 81 | if key == ord("u"): 82 | R = build_rotation(torch.tensor([[1-speed_r, speed_r, 0, 0]]).cuda())[0] 83 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R) 84 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R) 85 | if key == ord("k"): 86 | R = build_rotation(torch.tensor([[1-speed_r, 0, speed_r, 0]]).cuda())[0] 87 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R) 88 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R) 89 | if key == ord("h"): 90 | R = build_rotation(torch.tensor([[1-speed_r, 0, -speed_r, 0]]).cuda())[0] 91 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R) 92 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R) 93 | if key == ord("l"): 94 | R = build_rotation(torch.tensor([[1-speed_r, 0, 0, speed_r]]).cuda())[0] 95 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R) 96 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R) 97 | if key == ord("i"): 98 | R = build_rotation(torch.tensor([[1-speed_r, 0, 0, -speed_r]]).cuda())[0] 99 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R) 100 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R) 101 | 102 | if key == 32: 103 | idx += 1 104 | if idx >= len(scene.getTestCameras()): 105 | idx = 0 106 | cameras = scene.getTestCameras()[idx: idx+1].copy() 107 | 108 | cameras[0].update_proj_matrix() 109 | frame_num += 1 110 | 111 | --------------------------------------------------------------------------------
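For orientation, the mesh-extraction utilities in utils/mesh_utils.py are meant to be driven from a short script. The snippet below is a minimal sketch of such a driver, not the repository's own pipeline: it assumes a model already trained with train.py, reuses the argument handling seen in visualize.py, uses the default TSDF parameters of extract_mesh_bounded, and the choice of test cameras and the output filename fused_mesh.ply are purely illustrative.

import os
from argparse import ArgumentParser

import open3d as o3d

from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel, render
from scene import Scene
from utils.mesh_utils import GaussianExtractor, post_process_mesh

if __name__ == "__main__":
    parser = ArgumentParser(description="Mesh extraction sketch")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    args = get_combined_args(parser)

    dataset, pipe = model.extract(args), pipeline.extract(args)
    gaussians = GaussianModel(dataset.sh_degree, texture_preproc=True)  # mirrors visualize.py
    scene = Scene(dataset, gaussians, load_iteration=args.iteration, shuffle=False)
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]

    # Render the test cameras once; additional_return left at its default of True
    # so the per-view depth maps needed for TSDF fusion are kept.
    gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color)
    gaussExtractor.reconstruction(scene.getTestCameras())

    # Fuse the depth maps into a TSDF volume, extract a mesh, and drop small floaters.
    mesh = gaussExtractor.extract_mesh_bounded(voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3.0)
    mesh = post_process_mesh(mesh, cluster_to_keep=1000)

    out_path = os.path.join(args.model_path, "fused_mesh.ply")
    o3d.io.write_triangle_mesh(out_path, mesh)
    print("Mesh written to", out_path)

During reconstruction() the extractor also prints an estimated bounding radius and a suggested minimum depth_trunc, which is usually a better starting point than the hard-coded 3.0 above.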