├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── assets ├── Ablation_LOD.jpg ├── datagallery.jpg ├── device.jpg └── teaser.jpg ├── convert.py ├── environment.yml ├── full_eval.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py └── octree_loader.py ├── train.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://github.com/zhaofuq/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/zhaofuq/diff-gaussian-rasterization.git 7 | [submodule "SIBR_viewers"] 8 | path = SIBR_viewers 9 | url = https://github.com/zhaofuq/LOD-SIBR-Viewer.git 10 | [submodule "PotreeConverter"] 11 | path = PotreeConverter 12 | url = https://github.com/zhaofuq/PotreeConverter.git 13 | 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fuqiang Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LetsGo: Large-Scale Garage Modeling and Rendering via LiDAR-Assisted Gaussian Primitives 2 | 3 | 4 | [Project page](https://zhaofuq.github.io/LetsGo/) | [Paper](https://dl.acm.org/doi/pdf/10.1145/3687762) | [Video](https://www.youtube.com/watch?v=fs42UBKvGRw) | [LOD Viewer (SIBR)](https://zhaofuq.github.io/LetsGo/) | [Web Viewer](https://zhaofuq.github.io/LetsGo/)| [GarageWorld Dataset](https://zhaofuq.github.io/LetsGo/)
5 | 6 | ![Teaser image](assets/teaser.jpg) 7 | 8 | This repository contains the official implementation associated with the paper "LetsGo: Large-Scale Garage Modeling and Rendering via LiDAR-Assisted Gaussian Primitives". 9 | 10 | ## Abstract 11 | 12 | Large garages are ubiquitous yet intricate scenes in our daily lives. They pose challenges characterized by monotonous colors, repetitive patterns, reflective surfaces, and transparent vehicle glass. 13 | Conventional Structure from Motion (SfM) methods for camera pose estimation and 3D reconstruction fail in these environments due to poor correspondence construction. To address these challenges, this paper introduces LetsGo, a LiDAR-assisted Gaussian splatting framework for large-scale garage modeling and rendering. 14 | We develop a handheld scanner, Polar, equipped with IMU, LiDAR, and a fisheye camera, to facilitate accurate LiDAR and image data scanning. 15 | With this Polar device, we present a GarageWorld dataset consisting of eight expansive garage scenes with diverse geometric structures, and will release the dataset to the community for further research. 16 | We demonstrate that the LiDAR point cloud collected by the Polar device enhances a suite of 3D Gaussian splatting algorithms for garage scene modeling and rendering. 17 | We also introduce a novel depth regularizer that effectively eliminates floating artifacts in rendered images. 18 | Furthermore, we propose a multi-resolution 3D Gaussian representation designed for Level-of-Detail rendering. We use tailored scaling factors for individual levels and a random-resolution-level training scheme to optimize the Gaussians across different levels. This 3D Gaussian representation enables efficient rendering of large-scale garage scenes on lightweight devices via a web-based renderer. 19 | Experimental results on our dataset, along with ScanNet++ and KITTI-360, demonstrate the superiority of our method in rendering quality and resource efficiency. 20 | 21 | ## Installation 22 | ```bash 23 | # clone repo 24 | git clone https://github.com/zhaofuq/LOD-3DGS.git --recursive 25 | 26 | # create a new python environment 27 | conda env create --file environment.yml 28 | conda activate lod-3dgs 29 | 30 | # build PotreeConverter for multi-resolution point clouds (tested on Windows) 31 | cd PotreeConverter 32 | cmake . -B build 33 | cmake --build build --config Release 34 | ``` 35 | 36 | ## Training 37 | We provide a small-scale garage [sample dataset](https://drive.google.com/drive/folders/1sO8XHeHum3oiKC0sd7ZIm8LUSYf7JuR9). To train a scene from our GarageWorld dataset, simply use 38 | ```bash 39 | python train.py -s <path to your dataset> \ 40 | --use_lod \ 41 | --sh_degree 2 \ 42 | --depths depths \ # enables the depth loss if the dataset contains a depths folder 43 | --densification_interval 10000 \ 44 | --iterations 300000 \ 45 | --scaling_lr 0.0015 \ 46 | --position_lr_init 0.000016 \ 47 | --opacity_reset_interval 300000 \ 48 | --densify_until_iter 200000 \ 49 | --data_device cpu \ 50 | -r 1 51 | ``` 52 | 53 |
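Before the argument reference below, a note on the expected input layout. The released GarageWorld scenes may be organized slightly differently, but based on `scene/__init__.py` and the outputs of `convert.py` in this repository, the trainer looks for a COLMAP-style source directory roughly like this sketch (folder names are the defaults and can be overridden with `--images` / `--depths`; `depths/` is optional and only needed for the depth loss):

```
<source_path>/
├── images/      # RGB images (default for --images)
├── depths/      # optional depth maps used by the depth loss (--depths)
└── sparse/0/    # COLMAP reconstruction (cameras, images, points3D)
```

With `--use_lod`, the presence of `sparse/` still triggers scene loading, but the Octree loader (`scene/octree_loader.py`, not included in this excerpt) is used instead of the plain COLMAP loader; presumably the multi-resolution point cloud produced by PotreeConverter also needs to be present, and its exact expected location is defined by that loader.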
54 | Command Line Arguments for train.py 55 | 56 | #### --source_path / -s 57 | Path to the source directory containing a COLMAP or Synthetic NeRF data set. 58 | #### --model_path / -m 59 | Path where the trained model should be stored (```output/``` by default). 60 | #### --images / -i 61 | Alternative subdirectory for COLMAP images (```images``` by default). 62 | #### --depths / -d 63 | Alternative subdirectory for depth images (```depths``` by default). 64 | #### --eval 65 | Add this flag to use a MipNeRF360-style training/test split for evaluation. 66 | #### --resolution / -r 67 | Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.** 68 | #### --data_device 69 | Specifies where to put the source image data, ```cuda``` by default; recommended to use ```cpu``` when training on large/high-resolution datasets, which reduces VRAM consumption but slightly slows down training. Thanks to [HrsPythonix](https://github.com/HrsPythonix). 70 | #### --white_background / -w 71 | Add this flag to use a white background instead of black (default), e.g., for evaluation of the NeRF Synthetic dataset. 72 | #### --sh_degree 73 | Order of spherical harmonics to be used (no larger than 3). ```3``` by default. 74 | #### --convert_SHs_python 75 | Flag to make pipeline compute forward and backward of SHs with PyTorch instead of ours. 76 | #### --compute_cov3D_python 77 | Flag to make pipeline compute forward and backward of the 3D covariance with PyTorch instead of ours. 78 | #### --debug 79 | Enables debug mode if you experience errors. If the rasterizer fails, a ```dump``` file is created that you may forward to us in an issue so we can take a look. 80 | #### --debug_from 81 | Debugging is **slow**. You may specify an iteration (starting from 0) after which the above debugging becomes active. 82 | #### --iterations 83 | Number of total iterations to train for, ```30_000``` by default. 84 | #### --ip 85 | IP to start GUI server on, ```127.0.0.1``` by default. 86 | #### --port 87 | Port to use for GUI server, ```6009``` by default. 88 | #### --test_iterations 89 | Space-separated iterations at which the training script computes L1 and PSNR over the test set, ```7000 30000``` by default. 90 | #### --save_iterations 91 | Space-separated iterations at which the training script saves the Gaussian model, ```7000 30000``` by default. 92 | #### --checkpoint_iterations 93 | Space-separated iterations at which to store a checkpoint for continuing later, saved in the model directory. 94 | #### --start_checkpoint 95 | Path to a saved checkpoint to continue training from. 96 | #### --quiet 97 | Flag to omit any text written to standard out pipe. 98 | #### --feature_lr 99 | Spherical harmonics features learning rate, ```0.0025``` by default. 100 | #### --opacity_lr 101 | Opacity learning rate, ```0.05``` by default. 102 | #### --scaling_lr 103 | Scaling learning rate, ```0.005``` by default. 104 | #### --rotation_lr 105 | Rotation learning rate, ```0.001``` by default. 106 | #### --position_lr_max_steps 107 | Number of steps (from 0) where position learning rate goes from ```initial``` to ```final```. ```30_000``` by default. 108 | #### --position_lr_init 109 | Initial 3D position learning rate, ```0.00016``` by default.
110 | #### --position_lr_final 111 | Final 3D position learning rate, ```0.0000016``` by default. 112 | #### --position_lr_delay_mult 113 | Position learning rate multiplier (cf. Plenoxels), ```0.01``` by default. 114 | #### --densify_from_iter 115 | Iteration where densification starts, ```500``` by default. 116 | #### --densify_until_iter 117 | Iteration where densification stops, ```15_000``` by default. 118 | #### --densify_grad_threshold 119 | Limit that decides if points should be densified based on 2D position gradient, ```0.0002``` by default. 120 | #### --densification_interval 121 | How frequently to densify, ```100``` (every 100 iterations) by default. 122 | #### --opacity_reset_interval 123 | How frequently to reset opacity, ```3_000``` by default. 124 | #### --lambda_dssim 125 | Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. 126 | #### --percent_dense 127 | Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. 128 | 129 |
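Since `train.py` itself is not included in this excerpt, the exact training objective is not shown here. As a rough orientation for `--lambda_dssim` (and the `--depths` supervision mentioned above), the sketch below combines an L1 and a DSSIM term the way 3DGS-style trainers typically do, plus a hypothetical masked L1 depth term on the rasterizer's depth output; `lambda_depth` and `l1_loss` are assumptions (only `ssim` is confirmed by the imports in `metrics.py`), and the paper's actual depth regularizer may differ.

```python
# A minimal sketch, not the repository's actual training loss.
import torch
from utils.loss_utils import l1_loss, ssim  # l1_loss assumed; ssim confirmed via metrics.py

def total_loss(render_pkg, gt_image, gt_depth=None, lambda_dssim=0.2, lambda_depth=0.1):
    image = render_pkg["render"]
    # Photometric term weighted by --lambda_dssim: (1 - lambda) * L1 + lambda * (1 - SSIM).
    loss = (1.0 - lambda_dssim) * l1_loss(image, gt_image) \
           + lambda_dssim * (1.0 - ssim(image, gt_image))
    if gt_depth is not None:
        # Hypothetical depth regularizer: only supervise pixels with valid LiDAR depth.
        mask = gt_depth > 0
        loss = loss + lambda_depth * torch.abs(render_pkg["depth"][mask] - gt_depth[mask]).mean()
    return loss
```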
130 |
131 | 132 | ## Rendering 133 | To render a trained model, simply use 134 | ```bash 135 | python render.py -m <path to trained model> 136 | ``` 137 |
138 | Command Line Arguments for render.py 139 | 140 | #### --model_path / -m 141 | Path to the trained model directory you want to create renderings for. 142 | #### --skip_train 143 | Flag to skip rendering the training set. 144 | #### --skip_test 145 | Flag to skip rendering the test set. 146 | #### --quiet 147 | Flag to omit any text written to standard out pipe. 148 | 149 | **The below parameters will be read automatically from the model path, based on what was used for training. However, you may override them by providing them explicitly on the command line.** 150 | 151 | #### --source_path / -s 152 | Path to the source directory containing a COLMAP or Synthetic NeRF data set. 153 | #### --images / -i 154 | Alternative subdirectory for COLMAP images (```images``` by default). 155 | #### --eval 156 | Add this flag to use a MipNeRF360-style training/test split for evaluation. 157 | #### --resolution / -r 158 | Changes the resolution of the loaded images before rendering. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. ```1``` by default. 159 | #### --white_background / -w 160 | Add this flag to use a white background instead of black (default), e.g., for evaluation of the NeRF Synthetic dataset. 161 | #### --convert_SHs_python 162 | Flag to make pipeline render with computed SHs from PyTorch instead of ours. 163 | #### --compute_cov3D_python 164 | Flag to make pipeline render with computed 3D covariance from PyTorch instead of ours. 165 | 166 |
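For reference, `full_eval.py` in this repository drives evaluation by calling `render.py` followed by `metrics.py`. If the model was trained with `--eval`, a typical sequence is:

```bash
# Render the held-out test views of a trained model (without --iteration,
# the latest saved iteration is loaded automatically), then compute
# SSIM / PSNR / LPIPS over the renders. output/garage01 is a placeholder path.
python render.py -m output/garage01 --skip_train
python metrics.py -m output/garage01
```

`render.py` writes images to `<model>/test/ours_<iteration>/renders` alongside the ground truth in `gt/`, which is exactly the layout `metrics.py` expects.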
167 | 168 | ## Interactive LOD Viewers 169 | Our viewing solutions are based on the SIBR framework, developed by the GRAPHDECO group for several novel-view synthesis projects. We integrate our LOD rendering technique into the SIBR framework to achieve faster rendering. 170 | 171 | ![image](assets/Ablation_LOD.jpg) 172 | 173 | ### Hardware Requirements 174 | 175 | - CUDA-ready GPU with Compute Capability 7.0+ 176 | - 24 GB VRAM (to train to paper evaluation quality) 177 | - Please see FAQ for smaller VRAM configurations 178 | 179 | ### Software Requirements 180 | - Conda (recommended for easy setup) 181 | - C++ Compiler for PyTorch extensions (we used Visual Studio 2019 for Windows) 182 | - CUDA SDK 11 for PyTorch extensions, install *after* Visual Studio (we used 11.8, **known issues with 11.6**) 183 | - C++ Compiler and CUDA SDK must be compatible 184 | 185 | ### Installation from Source 186 | If you cloned with submodules (e.g., using ```--recursive```), the source code for the viewers is found in ```SIBR_viewers```. The network viewer runs within the SIBR framework for Image-based Rendering applications. 187 | 188 | #### Windows 189 | CMake should take care of your dependencies. 190 | ```shell 191 | cd SIBR_viewers 192 | cmake . -B build 193 | cmake --build build --target install --config RelWithDebInfo 194 | ``` 195 | You may specify a different configuration, e.g. ```Debug``` if you need more control during development. 196 | 197 | #### Ubuntu 22.04 198 | You will need to install a few dependencies before running the project setup. 199 | ```shell 200 | # Dependencies 201 | sudo apt install -y libglew-dev libassimp-dev libboost-all-dev libgtk-3-dev libopencv-dev libglfw3-dev libavdevice-dev libavcodec-dev libeigen3-dev libxxf86vm-dev libembree-dev 202 | # Project setup 203 | cd SIBR_viewers 204 | cmake -Bbuild . -DCMAKE_BUILD_TYPE=Release # add -G Ninja to build faster 205 | cmake --build build -j24 --target install 206 | ``` 207 | 208 | #### Ubuntu 20.04 209 | Backwards compatibility with Focal Fossa is not fully tested, but building SIBR with CMake should still work after invoking 210 | ```shell 211 | git checkout fossa_compatibility 212 | ``` 213 | 214 | ## GarageWorld Dataset 215 | ### Self-designed Polar Scanning System 216 | Using our Polar device, we build GarageWorld, the first large-scale garage dataset. 217 | ![image](assets/device.jpg) 218 | 219 | 220 | ### Dataset Gallery 221 | ![image](assets/datagallery.jpg) 222 | 223 |
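Once `SIBR_viewers` has been built and installed as described above, the viewer binaries end up in its install directory. The exact executable name used by the LOD-SIBR-Viewer fork is not shown in this excerpt; assuming it keeps the upstream 3DGS viewer naming, launching it on a trained model would look roughly like this:

```bash
# Hypothetical invocation: binary name and install path follow the upstream
# SIBR Gaussian viewer convention and may differ in the LOD fork.
cd SIBR_viewers
./install/bin/SIBR_gaussianViewer_app -m <path to trained model>
```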
224 | ## BibTeX 225 | ```bibtex 226 | @article{10.1145/3687762,
227 |         author = {Cui, Jiadi and Cao, Junming and Zhao, Fuqiang and He, Zhipeng and Chen, Yifan and Zhong, Yuhui and Xu, Lan and Shi, Yujiao and Zhang, Yingliang and Yu, Jingyi},
228 |         title = {LetsGo: Large-Scale Garage Modeling and Rendering via LiDAR-Assisted Gaussian Primitives},
229 |         year = {2024},
230 |         issue_date = {December 2024},
231 |         publisher = {Association for Computing Machinery},
232 |         address = {New York, NY, USA},
233 |         volume = {43},
234 |         number = {6},
235 |         issn = {0730-0301},
236 |         url = {https://doi.org/10.1145/3687762},
237 |         doi = {10.1145/3687762},
238 |         journal = {ACM Trans. Graph.},
239 |         month = nov,
240 |         articleno = {172},
241 |         numpages = {18},
242 |         keywords = {neural rendering, large-scale garage modeling, LiDAR scanning, 3D gaussian splatting, garage dataset, level-of-detail rendering}
243 |         }
244 | ``` 245 | 246 |
247 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self.use_lod = False 51 | self.depth_max = 100.0 52 | self._source_path = "" 53 | self._model_path = "" 54 | self._images = "images" 55 | self._depths = "" 56 | self._resolution = -1 57 | self._white_background = False 58 | self.data_device = "cuda" 59 | self.eval = False 60 | super().__init__(parser, "Loading Parameters", sentinel) 61 | 62 | def extract(self, args): 63 | g = super().extract(args) 64 | g.source_path = os.path.abspath(g.source_path) 65 | return g 66 | 67 | class PipelineParams(ParamGroup): 68 | def __init__(self, parser): 69 | self.convert_SHs_python = False 70 | self.compute_cov3D_python = False 71 | self.debug = False 72 | super().__init__(parser, "Pipeline Parameters") 73 | 74 | class OptimizationParams(ParamGroup): 75 | def __init__(self, parser): 76 | self.iterations = 30_000 77 | self.position_lr_init = 0.00016 78 | self.position_lr_final = 0.0000016 79 | self.position_lr_delay_mult = 0.01 80 | self.position_lr_max_steps = 30_000 81 | self.feature_lr = 0.0025 82 | self.opacity_lr = 0.05 83 | self.scaling_lr = 0.005 84 | self.rotation_lr = 0.001 85 | self.percent_dense = 0.01 86 | self.lambda_dssim = 0.2 87 | self.densification_interval = 100 88 | self.opacity_reset_interval = 3000 89 | self.densify_from_iter = 500 90 | self.densify_until_iter = 15_000 91 | self.densify_grad_threshold = 0.0002 92 | super().__init__(parser, "Optimization Parameters") 93 | 94 | def get_combined_args(parser : ArgumentParser): 95 | cmdlne_string = sys.argv[1:] 96 | cfgfile_string = "Namespace()" 97 | args_cmdline = parser.parse_args(cmdlne_string) 98 | 99 | try: 100 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 101 | print("Looking for config file in", cfgfilepath) 102 | with open(cfgfilepath) as cfg_file: 103 | print("Config file found: {}".format(cfgfilepath)) 104 | cfgfile_string 
= cfg_file.read() 105 | except TypeError: 106 | print("Config file not found at") 107 | pass 108 | args_cfgfile = eval(cfgfile_string) 109 | 110 | merged_dict = vars(args_cfgfile).copy() 111 | for k,v in vars(args_cmdline).items(): 112 | if v != None: 113 | merged_dict[k] = v 114 | return Namespace(**merged_dict) 115 | -------------------------------------------------------------------------------- /assets/Ablation_LOD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaofuq/LOD-3DGS/eec40cfdf979add3343eb6fd0508179fc94ecf80/assets/Ablation_LOD.jpg -------------------------------------------------------------------------------- /assets/datagallery.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaofuq/LOD-3DGS/eec40cfdf979add3343eb6fd0508179fc94ecf80/assets/datagallery.jpg -------------------------------------------------------------------------------- /assets/device.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaofuq/LOD-3DGS/eec40cfdf979add3343eb6fd0508179fc94ecf80/assets/device.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaofuq/LOD-3DGS/eec40cfdf979add3343eb6fd0508179fc94ecf80/assets/teaser.jpg -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
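# Example invocation (hypothetical paths):
#   python convert.py -s data/garage01 --resize
# Note on the defaults below: --colmap_executable points at a local Windows COLMAP
# install; pass your own path, or an empty string, to fall back to the `colmap`
# command on PATH (the same fallback applies to --magick_executable and `magick`).
# --resize additionally writes images_2/, images_4/ and images_8/ copies at
# 50%, 25% and 12.5% of the original resolution using ImageMagick.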
18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="E:\\COLMAP-3.8-windows-cuda\\colmap.bat", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/images \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/images \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/images \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lod-3dgs 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - pip: 16 | - submodules/diff-gaussian-rasterization 17 | - submodules/simple-knn -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 44 | for scene in mipnerf360_indoor_scenes: 45 | source = args.mipnerf360 + "/" + scene 46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 47 | for scene in tanks_and_temples_scenes: 48 | source = args.tanksandtemples + "/" + scene 49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 50 | for scene in deep_blending_scenes: 51 | source = args.deepblending + "/" + scene 52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 53 | 54 | if not args.skip_rendering: 55 | all_sources = [] 56 | for scene in mipnerf360_outdoor_scenes: 57 | all_sources.append(args.mipnerf360 + "/" + scene) 58 | for scene in mipnerf360_indoor_scenes: 59 | all_sources.append(args.mipnerf360 + "/" + scene) 60 | for scene in tanks_and_temples_scenes: 61 | all_sources.append(args.tanksandtemples + "/" + scene) 62 | for scene in deep_blending_scenes: 63 | all_sources.append(args.deepblending + "/" + scene) 64 | 65 | common_args = " --quiet --eval --skip_train" 66 | for scene, source in zip(all_scenes, all_sources): 67 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 68 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 69 | 70 | if not args.skip_metrics: 71 | scenes_string = "" 72 | for scene in all_scenes: 73 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 74 | 75 | os.system("python metrics.py -m " + scenes_string) -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights 
reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from utils.sh_utils import eval_sh 16 | 17 | def render(viewpoint_camera, xyz, features, opacity, scales, rotations, active_sh_degree, max_sh_degree, 18 | pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, cov3D_precomp = None, colors_precomp = None ): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = xyz 54 | means2D = screenspace_points 55 | opacity = opacity 56 | 57 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 58 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 59 | shs = None 60 | colors_precomp = None 61 | if override_color is None: 62 | if pipe.convert_SHs_python: 63 | shs_view = features.transpose(1, 2).view(-1, 3, (max_sh_degree+1)**2) 64 | dir_pp = (xyz - viewpoint_camera.camera_center.repeat(features.shape[0], 1)) 65 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 66 | sh2rgb = eval_sh(active_sh_degree, shs_view, dir_pp_normalized) 67 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 68 | else: 69 | shs = features 70 | else: 71 | colors_precomp = override_color 72 | 73 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 74 | # scaling / rotation by the rasterizer. 75 | 76 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 77 | rendered_image, depth, radii = rasterizer( 78 | means3D = means3D, 79 | means2D = means2D, 80 | shs = shs, 81 | colors_precomp = colors_precomp, 82 | opacities = opacity, 83 | scales = scales, 84 | rotations = rotations, 85 | cov3D_precomp = cov3D_precomp) 86 | 87 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 88 | # They will be excluded from value updates used in the splitting criteria. 
89 | return {"render": rendered_image, 90 | "depth": depth, 91 | "viewspace_points": screenspace_points, 92 | "visibility_filter" : radii > 0, 93 | "radii": radii} 94 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"[ Viewer ] Connected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity 
(LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | 
output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | 71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 72 | ssims.append(ssim(renders[idx], gts[idx])) 73 | psnrs.append(psnr(renders[idx], gts[idx])) 74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 75 | 76 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 77 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 78 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 79 | print("") 80 | 81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 82 | "PSNR": torch.tensor(psnrs).mean().item(), 83 | "LPIPS": torch.tensor(lpipss).mean().item()}) 84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 87 | 88 | with open(scene_dir + "/results.json", 'w') as fp: 89 | json.dump(full_dict[scene_dir], fp, indent=True) 90 | with open(scene_dir + "/per_view.json", 'w') as fp: 91 | json.dump(per_view_dict[scene_dir], fp, indent=True) 92 | except: 93 | print("Unable to compute metrics for model", scene_dir) 94 | 95 | if __name__ == "__main__": 96 | device = torch.device("cuda:0") 97 | torch.cuda.set_device(device) 98 | 99 | # Set up command line argument parser 100 | parser = ArgumentParser(description="Training script parameters") 101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 102 | args = parser.parse_args() 103 | evaluate(args.model_paths) 104 | 
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene, GaussianModel 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | 23 | def render_set(model_path, name, iteration, views, scene, pipeline, background): 24 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 25 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 26 | 27 | makedirs(render_path, exist_ok=True) 28 | makedirs(gts_path, exist_ok=True) 29 | 30 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 31 | xyz, features, opacity, scales, rotations, cov3D_precomp, \ 32 | active_sh_degree, max_sh_degree, masks = scene.get_gaussian_parameters(view.world_view_transform, pipeline.compute_cov3D_python) 33 | 34 | rendering = render(view, xyz, features, opacity, scales, rotations, active_sh_degree, max_sh_degree, pipeline, background, cov3D_precomp = cov3D_precomp)["render"] 35 | gt = view.original_image[0:3, :, :] 36 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 37 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 38 | 39 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): 40 | with torch.no_grad(): 41 | # gaussians = GaussianModel(dataset.sh_degree) 42 | # scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 43 | scene = Scene(dataset, load_iteration=iteration, shuffle=False) 44 | 45 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 46 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 47 | 48 | if not skip_train: 49 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), scene, pipeline, background) 50 | 51 | if not skip_test: 52 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), scene, pipeline, background) 53 | 54 | if __name__ == "__main__": 55 | # Set up command line argument parser 56 | parser = ArgumentParser(description="Testing script parameters") 57 | model = ModelParams(parser, sentinel=True) 58 | pipeline = PipelineParams(parser) 59 | parser.add_argument("--iteration", default=-1, type=int) 60 | parser.add_argument("--skip_train", action="store_true") 61 | parser.add_argument("--skip_test", action="store_true") 62 | parser.add_argument("--quiet", action="store_true") 63 | args = get_combined_args(parser) 64 | print("[ INFO ] Rendering " + args.model_path) 65 | 66 | # Initialize system state (RNG) 67 | safe_state(args.quiet) 68 | 69 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) -------------------------------------------------------------------------------- 
/scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | import torch 16 | import numpy as np 17 | from utils.system_utils import searchForMaxIteration 18 | from scene.dataset_readers import sceneLoadTypeCallbacks 19 | from scene.gaussian_model import GaussianModel 20 | from arguments import ModelParams 21 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 22 | from plyfile import PlyData, PlyElement 23 | from utils.system_utils import mkdir_p 24 | 25 | class Scene: 26 | 27 | gaussians : GaussianModel 28 | 29 | def __init__(self, args : ModelParams, load_iteration=-1, shuffle=True, resolution_scales=[1.0]): 30 | """b 31 | :param path: Path to colmap scene main folder. 32 | """ 33 | self.model_path = args.model_path 34 | self.loaded_iter = None 35 | self.level = 0 36 | 37 | if os.path.exists(os.path.join(self.model_path, "point_cloud")): 38 | if load_iteration == -1: 39 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 40 | else: 41 | self.loaded_iter = load_iteration 42 | print("[ Scene ] Loading trained model at iteration {}".format(self.loaded_iter)) 43 | 44 | self.train_cameras = {} 45 | self.test_cameras = {} 46 | if os.path.exists(os.path.join(args.source_path, "sparse")): 47 | if args.use_lod: 48 | print("[ Scene ] Found sparse folder, assuming Octree data set!") 49 | scene_info = sceneLoadTypeCallbacks["Octree"](args.source_path, args.images, args.depths, args.eval) 50 | else: 51 | print("[ Scene ] Found sparse folder, assuming Colmap data set!") 52 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval) 53 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 54 | print("[ Scene ] Found transforms_train.json file, assuming Blender data set!") 55 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 56 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")): 57 | print("[ Scene ] Found transforms.json file, assuming Blender data set!") 58 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 59 | else: 60 | assert False, "Could not recognize scene type!" 
61 | 62 | self.max_level = scene_info.max_level 63 | self.beta = np.log(self.max_level+1) 64 | self.gaussians = [GaussianModel(args.sh_degree, level) for level in range(self.max_level + 1)] 65 | 66 | if not self.loaded_iter: 67 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 68 | dest_file.write(src_file.read()) 69 | json_cams = [] 70 | camlist = [] 71 | if scene_info.test_cameras: 72 | camlist.extend(scene_info.test_cameras) 73 | if scene_info.train_cameras: 74 | camlist.extend(scene_info.train_cameras) 75 | for id, cam in enumerate(camlist): 76 | json_cams.append(camera_to_JSON(id, cam)) 77 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 78 | json.dump(json_cams, file) 79 | 80 | if shuffle: 81 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 82 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 83 | 84 | self.cameras_extent = scene_info.nerf_normalization["radius"] 85 | 86 | for resolution_scale in resolution_scales: 87 | print("[ Scene ] Loading Training Cameras") 88 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 89 | print("[ Scene ] Loading Test Cameras") 90 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 91 | 92 | self.depth_min = torch.tensor(float('inf')) 93 | self.depth_max = torch.tensor(-float('inf')) 94 | # Load Gaussian Model 95 | import time 96 | st = time.time() 97 | for level in range(self.max_level + 1): 98 | if self.loaded_iter: 99 | self.gaussians[level].load_ply(os.path.join(self.model_path, 100 | "point_cloud", 101 | "iteration_" + str(self.loaded_iter), 102 | "level_{}.ply".format(level))) 103 | else: 104 | self.gaussians[level].create_from_pcd(scene_info.point_cloud[level], self.cameras_extent) 105 | 106 | for cam in self.train_cameras[1.0]: 107 | xyz = self.gaussians[level].get_xyz.detach() 108 | depth_z = self.get_z_depth(xyz, cam.world_view_transform) 109 | self.depth_min = torch.min(self.depth_min, torch.max(depth_z.min(), torch.tensor(0.0))) 110 | self.depth_max = torch.max(self.depth_max, depth_z.max()) 111 | self.depth_max = 0.95 * 1.3 * (self.depth_max - self.depth_min) + self.depth_min 112 | self.depth_min = 0.05 * 1.3 * (self.depth_max - self.depth_min) + self.depth_min 113 | print("[ Scene ] Initialize scene depth range at [{:2f}, {:2f}]".format(self.depth_min.cpu(), self.depth_max.cpu())) 114 | et = time.time() 115 | print("[ Scene ] Gaussian Model creation took {} seconds".format(et - st)) 116 | 117 | 118 | def get_z_depth(self, xyz, viewmatrix): 119 | homogeneous_xyz = torch.cat((xyz, torch.ones(xyz.shape[0], 1, dtype=xyz.dtype, device=xyz.device)), dim=1) 120 | projected_xyz= torch.matmul(homogeneous_xyz, viewmatrix) 121 | depth_z = projected_xyz[:,2] 122 | return depth_z 123 | 124 | def save(self, iteration): 125 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 126 | if self.max_level == 0: 127 | self.save_full_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 128 | else: 129 | for level in range(self.max_level+1): 130 | self.gaussians[level].save_ply(os.path.join(point_cloud_path, "level_{}.ply".format(level))) 131 | 132 | 133 | def construct_list_of_attributes(self): 134 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 135 | # All channels except the 3 DC 136 | for i in 
range(self.gaussians[-1]._features_dc.shape[1]*self.gaussians[-1]._features_dc.shape[2]): 137 | l.append('f_dc_{}'.format(i)) 138 | for i in range(self.gaussians[-1]._features_rest.shape[1]*self.gaussians[-1]._features_rest.shape[2]): 139 | l.append('f_rest_{}'.format(i)) 140 | l.append('opacity') 141 | for i in range(self.gaussians[-1]._scaling.shape[1]): 142 | l.append('scale_{}'.format(i)) 143 | for i in range(self.gaussians[-1]._rotation.shape[1]): 144 | l.append('rot_{}'.format(i)) 145 | return l 146 | 147 | def save_full_ply(self, path): 148 | mkdir_p(os.path.dirname(path)) 149 | 150 | xyz, f_dc, f_rest, opacity, scales, rotations = [], [], [], [], [], [] 151 | for level in range(self.max_level + 1): 152 | xyz.append(self.gaussians[level]._xyz) 153 | f_dc.append(self.gaussians[level]._features_dc) 154 | f_rest.append(self.gaussians[level]._features_rest) 155 | opacity.append(self.gaussians[level]._opacity) 156 | scales.append(self.gaussians[level]._scaling) 157 | rotations.append(self.gaussians[level]._rotation) 158 | 159 | xyz = torch.cat(xyz, dim=0).detach().cpu().numpy() 160 | normals = np.zeros_like(xyz) 161 | f_dc = torch.cat(f_dc, dim=0).detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 162 | f_rest = torch.cat(f_rest, dim=0).detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 163 | opacities = torch.cat(opacity, dim=0).detach().cpu().numpy() 164 | scale = torch.cat(scales, dim=0).detach().cpu().numpy() 165 | rotation = torch.cat(rotations, dim=0).detach().cpu().numpy() 166 | 167 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 168 | 169 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 170 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 171 | elements[:] = list(map(tuple, attributes)) 172 | el = PlyElement.describe(elements, 'vertex') 173 | PlyData([el]).write(path) 174 | 175 | def getTrainCameras(self, scale=1.0): 176 | return self.train_cameras[scale] 177 | 178 | def getTestCameras(self, scale=1.0): 179 | return self.test_cameras[scale] 180 | 181 | def getGaussians(self, level=-1): 182 | return self.gaussians[level] 183 | 184 | def getLevels(self): 185 | return self.max_level 186 | 187 | def update_max_radii2D(self, radii, visibility_filter, masks): 188 | level_start = 0 189 | expanded_visibility_filter = torch.zeros(masks.shape[0], dtype=torch.bool, device=visibility_filter.device) 190 | expanded_radii = torch.zeros(masks.shape[0], dtype=radii.dtype, device=radii.device) 191 | expanded_visibility_filter[masks] = visibility_filter 192 | expanded_radii[masks] = radii 193 | for level in range(self.max_level + 1): 194 | level_offset = self.gaussians[level].max_radii2D.shape[0] 195 | level_radii = expanded_radii[level_start:level_start+level_offset] 196 | level_visibility_filter = expanded_visibility_filter[level_start:level_start+level_offset] 197 | self.gaussians[level].max_radii2D[level_visibility_filter] = torch.max(self.gaussians[level].max_radii2D[level_visibility_filter], level_radii[level_visibility_filter]) 198 | level_start += level_offset 199 | 200 | def training_setup(self, args): 201 | for level in range(self.max_level + 1): 202 | self.gaussians[level].training_setup(args) 203 | 204 | def restore(self, params, args): 205 | for level in range(self.max_level + 1): 206 | self.gaussians[level].restore(params, args) 207 | 208 | def update_learning_rate(self, iters): 209 | for level in range(self.max_level + 1): 210 | 
self.gaussians[level].update_learning_rate(iters) 211 | 212 | def oneupSHdegree(self): 213 | for level in range(self.max_level + 1): 214 | self.gaussians[level].oneupSHdegree() 215 | 216 | def add_densification_stats(self, viewspace_point, visibility_filter, masks): 217 | level_start = 0 218 | viewspace_point_grad = viewspace_point.grad 219 | expanded_viewspace_point_grad = torch.zeros(masks.shape[0], 3, dtype=viewspace_point_grad.dtype, device=viewspace_point_grad.device) 220 | expanded_visibility_filter = torch.zeros(masks.shape[0], dtype=torch.bool, device=visibility_filter.device) 221 | expanded_viewspace_point_grad[masks,:] = viewspace_point_grad 222 | expanded_visibility_filter[masks] = visibility_filter 223 | for level in range(self.max_level + 1): 224 | level_offset = self.gaussians[level].get_xyz.shape[0] 225 | level_viewspace_point_grad = expanded_viewspace_point_grad[level_start:level_start + level_offset] 226 | level_visibility_filter = expanded_visibility_filter[level_start:level_start + level_offset] 227 | self.gaussians[level].add_densification_stats(level_viewspace_point_grad, level_visibility_filter) 228 | level_start += self.gaussians[level].get_xyz.shape[0] 229 | 230 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 231 | for level in range(self.max_level + 1): 232 | scale = np.min([np.sqrt(2) ** (self.max_level - level), 4.0]) #np.log(self.max_level - level + 1.0) + 1.0 233 | if max_screen_size: 234 | max_screen_size = max_screen_size * scale 235 | self.gaussians[level].densify_and_prune(max_grad * scale, min_opacity, extent * scale, max_screen_size) 236 | 237 | def reset_opacity(self): 238 | for level in range(self.max_level + 1): 239 | self.gaussians[level].reset_opacity() 240 | 241 | def optimizer_step(self): 242 | for level in range(self.max_level + 1): 243 | self.gaussians[level].optimizer.step() 244 | self.gaussians[level].optimizer.zero_grad(set_to_none = True) 245 | 246 | 247 | def get_gaussian_parameters(self, viewpoint, compute_cov3D_python, scaling_modifier=1.0, random=-1): 248 | 249 | levels = range(self.max_level + 1) 250 | get_attrs = lambda attr: [getattr(self.gaussians[level], attr) for level in levels] 251 | xyz, features, opacity, scales, rotations = map(get_attrs, ['get_xyz', 'get_features', 'get_opacity', 'get_scaling', 'get_rotation']) 252 | 253 | # Compute cov3D_precomp if necessary 254 | cov3D_precomp = [self.gaussians[-1].get_covariance(scaling_modifier)] * len(xyz) if compute_cov3D_python else None 255 | 256 | # Define activation levels based on 'random' parameter 257 | if random < 0: 258 | depths = [self.get_z_depth(xyz_lvl.detach(), viewpoint) for xyz_lvl in xyz] 259 | act_levels = [torch.clamp((self.max_level + 1) * torch.exp(-1.0 * self.beta * torch.abs(depth) / self.depth_max), 0, self.max_level) for depth in depths] 260 | act_levels = [torch.floor(level) for level in act_levels] 261 | filters = [act_level == level for act_level, level in zip(act_levels, levels)] 262 | else: 263 | filters = [torch.full_like(xyz[level][:,0], level == random, dtype=torch.bool) for level in levels] 264 | 265 | # Concatenate all attributes 266 | concat_attrs = lambda attrs: torch.cat(attrs, dim=0) 267 | xyz, features, opacity, scales, rotations, filters = map(concat_attrs, [xyz, features, opacity, scales, rotations, filters]) 268 | 269 | # Apply filters to all attributes 270 | filtered = lambda attr: attr[filters] 271 | xyz, features, opacity, scales, rotations = map(filtered, [xyz, features, opacity, scales, rotations]) 272 | 273 | 
if compute_cov3D_python: 274 | cov3D_precomp = filtered(concat_attrs(cov3D_precomp)) 275 | 276 | # Active and maximum spherical harmonics degrees 277 | active_sh_degree, max_sh_degree = self.gaussians[-1].active_sh_degree, self.gaussians[-1].max_sh_degree 278 | 279 | return xyz, features, opacity, scales, rotations, cov3D_precomp, active_sh_degree, max_sh_degree, filters 280 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, depth, depth_mask, 19 | image_name, uid, cx, cy, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.cx = cx 31 | self.cy = cy 32 | self.image_name = image_name 33 | 34 | try: 35 | self.data_device = torch.device(data_device) 36 | except Exception as e: 37 | print(e) 38 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 39 | self.data_device = torch.device("cuda") 40 | 41 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 42 | self.image_width = self.original_image.shape[2] 43 | self.image_height = self.original_image.shape[1] 44 | 45 | if gt_alpha_mask is not None: 46 | self.original_image *= gt_alpha_mask.to(self.data_device) 47 | else: 48 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 49 | 50 | self.depth = depth 51 | self.depth_mask = depth_mask 52 | 53 | self.zfar = 100.0 54 | self.znear = 0.01 55 | 56 | self.trans = trans 57 | self.scale = scale 58 | 59 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 60 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 61 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 62 | self.camera_center = self.world_view_transform.inverse()[3, :3] 63 | 64 | class MiniCam: 65 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 66 | self.image_width = width 67 | self.image_height = height 68 | self.FoVy = fovy 69 | self.FoVx = fovx 70 | self.znear = znear 71 | self.zfar = zfar 72 | self.world_view_transform = world_view_transform 73 | self.full_proj_transform = full_proj_transform 74 | view_inv = torch.inverse(self.world_view_transform) 75 | self.camera_center = view_inv[3][:3] 76 | 77 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 
4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
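    Example (illustrative): read_next_bytes(fid, num_bytes=24, format_char_sequence="ddd")
    unpacks three little-endian doubles and returns them as a tuple.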
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | try: 263 | elems = fid.readline().split() 264 | xys = np.column_stack([tuple(map(float, elems[0::3])), 265 | tuple(map(float, elems[1::3]))]) 266 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 267 | except: 268 | xys = None 269 | point3D_ids = None 270 | images[image_id] = Image( 271 | id=image_id, qvec=qvec, tvec=tvec, 272 | camera_id=camera_id, name=image_name, 273 | xys=xys, point3D_ids=point3D_ids) 274 | return images 275 | 276 | 277 | def read_colmap_bin_array(path): 278 | """ 279 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 280 | 281 | :param path: path to the colmap binary file. 
282 | :return: nd array with the floating point values in the value 283 | """ 284 | with open(path, "rb") as fid: 285 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 286 | usecols=(0, 1, 2), dtype=int) 287 | fid.seek(0) 288 | num_delimiter = 0 289 | byte = fid.read(1) 290 | while True: 291 | if byte == b"&": 292 | num_delimiter += 1 293 | if num_delimiter >= 3: 294 | break 295 | byte = fid.read(1) 296 | array = np.fromfile(fid, np.float32) 297 | array = array.reshape((width, height, channels), order="F") 298 | return np.transpose(array, (1, 0, 2)).squeeze() 299 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal, OctreeGaussian, OctreeGaussianNode 19 | import numpy as np 20 | import json 21 | import laspy 22 | import subprocess 23 | from tqdm import tqdm 24 | from pathlib import Path 25 | from plyfile import PlyData, PlyElement 26 | from utils.sh_utils import SH2RGB 27 | from utils.system_utils import mkdir_p 28 | from scene.gaussian_model import BasicPointCloud 29 | from scene.octree_loader import loadOctree 30 | import queue 31 | 32 | class CameraInfo(NamedTuple): 33 | uid: int 34 | R: np.array 35 | T: np.array 36 | FovY: np.array 37 | FovX: np.array 38 | image: np.array 39 | image_path: str 40 | image_name: str 41 | depth_path: str 42 | depth: np.array 43 | width: int 44 | height: int 45 | cx: float 46 | cy: float 47 | 48 | class SceneInfo(NamedTuple): 49 | point_cloud: BasicPointCloud 50 | train_cameras: list 51 | test_cameras: list 52 | nerf_normalization: dict 53 | ply_path: str 54 | max_level: int = 16 55 | depth_min: float = 0.01 56 | depth_max: float = 1000 57 | 58 | def getNerfppNorm(cam_info): 59 | def get_center_and_diag(cam_centers): 60 | cam_centers = np.hstack(cam_centers) 61 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 62 | center = avg_cam_center 63 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 64 | diagonal = np.max(dist) 65 | return center.flatten(), diagonal 66 | 67 | cam_centers = [] 68 | 69 | for cam in cam_info: 70 | W2C = getWorld2View2(cam.R, cam.T) 71 | C2W = np.linalg.inv(W2C) 72 | cam_centers.append(C2W[:3, 3:4]) 73 | 74 | center, diagonal = get_center_and_diag(cam_centers) 75 | radius = diagonal * 1.1 76 | 77 | translate = -center 78 | 79 | return {"translate": translate, "radius": radius} 80 | 81 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, depths_folder = None): 82 | cam_infos = [] 83 | for idx, key in enumerate(cam_extrinsics): 84 | sys.stdout.write('\r') 85 | # the exact output you're looking for: 86 | sys.stdout.write("[ Scene ] Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 87 | sys.stdout.flush() 88 | 89 | extr = cam_extrinsics[key] 90 | intr 
= cam_intrinsics[extr.camera_id] 91 | height = intr.height 92 | width = intr.width 93 | cx = width / 2.0 94 | cy = height / 2.0 95 | 96 | uid = intr.id 97 | R = np.transpose(qvec2rotmat(extr.qvec)) 98 | T = np.array(extr.tvec) 99 | if intr.model=="SIMPLE_PINHOLE": 100 | focal_length_x = intr.params[0] 101 | FovY = focal2fov(focal_length_x, height) 102 | FovX = focal2fov(focal_length_x, width) 103 | elif intr.model=="PINHOLE" or intr.model=="SIMPLE_RADIAL": 104 | focal_length_x = intr.params[0] 105 | focal_length_y = intr.params[1] 106 | FovY = focal2fov(focal_length_y, height) 107 | FovX = focal2fov(focal_length_x, width) 108 | else: 109 | assert False, "[ ERROR ] Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 110 | 111 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 112 | image_name = os.path.splitext(os.path.basename(image_path))[0] 113 | image = Image.open(image_path) 114 | 115 | if depths_folder is not None: 116 | depth_path = os.path.join(depths_folder, image_name + ".png") 117 | depth = Image.open(depth_path) 118 | else: 119 | depth_path = None 120 | depth = None 121 | 122 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 123 | image_path=image_path, image_name=image_name, depth_path=depth_path, depth = depth, width=width, height=height, cx = cx, cy = cy) 124 | cam_infos.append(cam_info) 125 | sys.stdout.write('\n') 126 | return cam_infos 127 | 128 | def fetchPly(path): 129 | plydata = PlyData.read(path) 130 | vertices = plydata['vertex'] 131 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 132 | 133 | if vertices.__contains__('red'): 134 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 135 | else: 136 | colors = np.zeros_like(positions) 137 | 138 | if vertices.__contains__('nx'): 139 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 140 | else: 141 | normals = np.zeros_like(positions) 142 | 143 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 144 | 145 | def storePly(path, xyz, rgb): 146 | # Define the dtype for the structured array 147 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 148 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 149 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 150 | 151 | normals = np.zeros_like(xyz) 152 | 153 | elements = np.empty(xyz.shape[0], dtype=dtype) 154 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 155 | elements[:] = list(map(tuple, attributes)) 156 | 157 | # Create the PlyData object and write to file 158 | vertex_element = PlyElement.describe(elements, 'vertex') 159 | ply_data = PlyData([vertex_element]) 160 | ply_data.write(path) 161 | 162 | 163 | def storeLas(path, xyz, rgb): 164 | 165 | # 1. Create a Las 166 | header = laspy.LasHeader(point_format=2, version="1.4") 167 | out_las = laspy.LasData(header) 168 | 169 | # 2. 
Fill the Las 170 | out_las.x = xyz[:, 0] 171 | out_las.y = xyz[:, 1] 172 | out_las.z = xyz[:, 2] 173 | 174 | def normalize_color(color): 175 | return (color * 255).astype(np.uint16) 176 | 177 | out_las.red = normalize_color(rgb[:, 0]) 178 | out_las.green = normalize_color(rgb[:, 1]) 179 | out_las.blue = normalize_color(rgb[:, 2]) 180 | 181 | # Save the LAS file 182 | out_las.write(path) 183 | 184 | def readColmapSceneInfo(path, images, depths, eval, llffhold=8): 185 | try: 186 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 187 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 188 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 189 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 190 | except: 191 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 192 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 193 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 194 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 195 | 196 | images_dir = "images" if images == None else images 197 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, 198 | images_folder=os.path.join(path, images_dir), 199 | depths_folder=os.path.join(path, depths) if depths != "" else None) 200 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 201 | 202 | if eval: 203 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 204 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 205 | else: 206 | train_cam_infos = cam_infos 207 | test_cam_infos = [] 208 | 209 | nerf_normalization = getNerfppNorm(train_cam_infos) 210 | 211 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 212 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 213 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 214 | if not os.path.exists(ply_path): 215 | print("[ Dataloader ] Converting point3d.bin to .ply, will happen only the first time you open the scene.") 216 | try: 217 | xyz, rgb, _ = read_points3D_binary(bin_path) 218 | except: 219 | xyz, rgb, _ = read_points3D_text(txt_path) 220 | storePly(ply_path, xyz, rgb) 221 | else: 222 | print("[ Dataloader ] Found .ply point cloud, skipping conversion.") 223 | pcd = fetchPly(ply_path) 224 | 225 | pcds = [pcd] 226 | scene_info = SceneInfo(point_cloud=pcds, 227 | train_cameras=train_cam_infos, 228 | test_cameras=test_cam_infos, 229 | nerf_normalization=nerf_normalization, 230 | ply_path=ply_path, 231 | max_level=0) 232 | return scene_info 233 | 234 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 235 | cam_infos = [] 236 | 237 | with open(os.path.join(path, transformsfile)) as json_file: 238 | contents = json.load(json_file) 239 | 240 | if "aabb" in contents: 241 | aabb = contents["aabb"] 242 | scale = max(0.000001,max(max(abs(float(aabb[1][0])-float(aabb[0][0])), 243 | abs(float(aabb[1][1])-float(aabb[0][1]))), 244 | abs(float(aabb[1][2])-float(aabb[0][2])))) 245 | 246 | offset = [((float(aabb[1][0]) + float(aabb[0][0])) * 0.5) - 0.5 * scale, 247 | ((float(aabb[1][1]) + float(aabb[0][1])) * 0.5) - 0.5 * scale, 248 | ((float(aabb[1][2]) + float(aabb[0][2])) * 0.5)- 0.5 * scale] 249 | 250 | elif "scale" in contents and "offset" in contents: 251 | scale = 1.0 / contents["scale"] 252 | offset = -np.array(contents["offset"]) * scale 253 | else: 254 | scale = 2.6 # default scale 
for NeRF scenes 255 | offset = -1.3 256 | 257 | frames = contents["frames"] 258 | for idx, frame in enumerate(tqdm(frames, unit=" images", desc=f"Loading Images")): 259 | cam_name = os.path.join(path, frame["file_path"]) 260 | if not os.path.exists(cam_name): 261 | cam_name = os.path.join(path, frame["file_path"] + extension) 262 | 263 | # NeRF 'transform_matrix' is a camera-to-world transform 264 | c2w = np.array(frame["transform_matrix"]) 265 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 266 | c2w[:3, 1:3] *= -1 267 | 268 | # get the world-to-camera transform and set R, T 269 | w2c = np.linalg.inv(c2w) 270 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 271 | T = w2c[:3, 3] 272 | 273 | image_path = os.path.join(path, cam_name) 274 | image_name = Path(cam_name).stem 275 | image = Image.open(image_path) 276 | 277 | im_data = np.array(image.convert("RGBA")) 278 | 279 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 280 | 281 | norm_data = im_data / 255.0 282 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 283 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 284 | 285 | if "cx" in contents and "cy" in contents: 286 | cx = contents["cx"] 287 | cy = contents["cy"] 288 | else: 289 | cx = image.size[0] / 2 290 | cy = image.size[1] / 2 291 | 292 | # Extract focal length the transform matrix 293 | if 'camera_angle_x' in contents or 'camera_angle_y' in contents: 294 | # blender, assert in radians. already downscaled since we use H/W 295 | fl_x = fov2focal(contents["camera_angle_x"], 2 * cx) if 'camera_angle_x' in contents else None 296 | fl_y = fov2focal(contents["camera_angle_y"], 2 * cy) if 'camera_angle_y' in contents else None 297 | if fl_x is None: fl_x = fl_y 298 | if fl_y is None: fl_y = fl_x 299 | FovX = fovx = focal2fov(fl_x, 2 * cx) 300 | FovY = fovy = focal2fov(fl_y, 2 * cy) 301 | elif 'fl_x' in contents or 'fl_y' in contents: 302 | fl_x = (contents['fl_x'] if 'fl_x' in contents else contents['fl_y']) 303 | fl_y = (contents['fl_y'] if 'fl_y' in contents else contents['fl_x']) 304 | FovX = fovx = focal2fov(fl_x, 2 * cx) 305 | FovY = fovy = focal2fov(fl_y, 2 * cy) 306 | elif 'K' in frame or 'intrinsic_matrix' in frame: 307 | K = frame['K'] if 'K' in frame else frame['intrinsic_matrix'] 308 | FovX = fovx = focal2fov(K[0][0], 2.0*K[0][2]) 309 | FovY = fovy = focal2fov(K[1][1], 2.0*K[1][2]) 310 | cx, cy = K[0][2], K[1][2] 311 | elif 'focal_length' in frame: 312 | FovX = fovx = focal2fov(frame['focal_length'], 2 * cx) 313 | FovY = fovy = focal2fov(frame['focal_length'], 2 * cy) 314 | else: 315 | raise Exception("[ ERROR ] No camera intrinsics found in the transforms file.") 316 | 317 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 318 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], cx = cx, cy = cy)) 319 | 320 | return cam_infos, scale, offset 321 | 322 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 323 | # Read Train Cameras 324 | if os.path.exists(os.path.join(path, "transforms_train.json")): 325 | print("[ Dataloader ] Reading Training Transforms From: transforms_train.json") 326 | train_cam_infos, scale, offset = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 327 | elif os.path.exists(os.path.join(path, "transforms.json")): 328 | print("[ Dataloader ] Reading Training Transforms From: 
transforms.json") 329 | train_cam_infos, scale, offset = readCamerasFromTransforms(path, "transforms.json", white_background, extension) 330 | 331 | # Read Test Cameras 332 | if os.path.exists(os.path.join(path, "transforms_test.json")) and eval: 333 | print("[ Dataloader ] Reading Test Transforms From: transforms_test.json") 334 | test_cam_infos, scale, offset = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 335 | else: 336 | test_cam_infos = [] 337 | 338 | nerf_normalization = getNerfppNorm(train_cam_infos) 339 | 340 | # Read Point Cloud 341 | ply_path = os.path.join(path, "points3d.ply") 342 | if not os.path.exists(ply_path): 343 | # Since this data set has no colmap data, we start with random points 344 | num_pts = 100_000 345 | print(f"[ Dataloader ] Generating random point cloud ({num_pts})...") 346 | 347 | # We create random points inside the bounds of the synthetic Blender scenes 348 | xyz = np.random.random((num_pts, 3)) * scale + offset 349 | shs = np.random.random((num_pts, 3)) / 255.0 350 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 351 | 352 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 353 | try: 354 | pcd = fetchPly(ply_path) 355 | except: 356 | pcd = None 357 | 358 | pcds = [pcd] 359 | scene_info = SceneInfo(point_cloud=pcds, 360 | train_cameras=train_cam_infos, 361 | test_cameras=test_cam_infos, 362 | nerf_normalization=nerf_normalization, 363 | ply_path=ply_path, 364 | max_level=0) 365 | return scene_info 366 | 367 | """function to save the octree class into ply file and store into disk""" 368 | 369 | # gaussianmodels need to be replace to real gaussianmodesl 370 | # now the position is aligned with colmap coordinate, do not need to add any offset 371 | 372 | def collect_position_buffers(node, level, max_level=16): 373 | position_buffers = [] 374 | color_buffers = [] 375 | if level <= max_level: 376 | position_buffer = node.pointcloud.points 377 | position_buffers.append(position_buffer) 378 | color_buffer = node.pointcloud.colors 379 | color_buffers.append(color_buffer) 380 | if hasattr(node, 'children'): 381 | for child in node.children: 382 | if child is not None: 383 | result = collect_position_buffers(child, level + 1, max_level) 384 | position_buffers.extend(result[0]) 385 | color_buffers.extend(result[1]) 386 | return position_buffers, color_buffers 387 | 388 | def recover_octree(octree_path, node, level): 389 | if level <= 16: 390 | position_buffer = node.pointcloud.points 391 | color_buffer = node.pointcloud.colors 392 | name = node.name 393 | output_path = os.path.join(octree_path, f"level_{level}_{name}.ply") 394 | vertices = np.array( 395 | [(position[0], position[1], position[2], color[0], color[1], color[2]) for position, color in zip(position_buffer, color_buffer)], 396 | dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 397 | ) 398 | el = PlyElement.describe(vertices, 'vertex') 399 | PlyData([el], text=True).write(output_path) 400 | if hasattr(node, 'children'): 401 | for child in node.children: 402 | if child is not None: 403 | recover_octree(octree_path, child, level + 1) 404 | 405 | def readoctreeColmapInfo(path, images, depths, eval, llffhold=8): 406 | try: 407 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 408 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 409 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 410 | cam_intrinsics = 
read_intrinsics_binary(cameras_intrinsic_file) 411 | except: 412 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 413 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 414 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 415 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 416 | 417 | images_dir = "images" if images == None else images 418 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, 419 | images_folder=os.path.join(path, images_dir), 420 | depths_folder=os.path.join(path, depths) if depths != "" else None) 421 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 422 | 423 | if eval: 424 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 425 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 426 | else: 427 | train_cam_infos = cam_infos 428 | test_cam_infos = [] 429 | 430 | nerf_normalization = getNerfppNorm(train_cam_infos) 431 | 432 | octree_path = os.path.join(path, "octree") 433 | las_path = os.path.join(path, "octree/0/points3D.las") 434 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 435 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 436 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 437 | 438 | pcd = None 439 | if not os.path.exists(las_path): 440 | print("[ Dataloader ] Converting point3d.bin to LOD PCS, will happen only the first time you open the scene.") 441 | if not os.path.exists(ply_path): 442 | try: 443 | xyz, rgb, _ = read_points3D_binary(bin_path) 444 | except: 445 | xyz, rgb, _ = read_points3D_text(txt_path) 446 | storePly(ply_path, xyz, rgb) 447 | else: 448 | print("[ Dataloader ] Found .ply point cloud, skipping conversion.") 449 | pcd = fetchPly(ply_path) 450 | xyz, rgb = pcd.points, pcd.colors 451 | 452 | print("[ Dataloader ] Converting points to LAS format.") 453 | mkdir_p(os.path.dirname(las_path)) 454 | storeLas(las_path, xyz, rgb) 455 | 456 | # Convert to octree 457 | print("[ Dataloader ] Converting LAS to octree for level-of-detail pointclouds.") 458 | command = [os.path.join(os.getcwd(), "PotreeConverter/bin/Release/Converter.exe"), las_path, "-o", octree_path, "--overwrite"] 459 | subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 460 | print("[ Dataloader ] LOD pointclouds generated.") 461 | else: 462 | if not os.path.exists(os.path.join(octree_path, "metadata.json")): 463 | # Convert to octree 464 | print("[ Dataloader ] Converting LAS to octree for level-of-detail pointclouds.") 465 | command = [os.path.join(os.getcwd(), "PotreeConverter/bin/Release/Converter.exe"), las_path, "-o", octree_path, "--overwrite"] 466 | subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 467 | print("[ Dataloader ] LOD pointclouds generated.") 468 | else: 469 | print("[ Dataloader ] Found octree, skipping conversion.") 470 | 471 | # init the octree scene 472 | octreeGaussian = loadOctree(octree_path) 473 | octreeGaussianLoader = octreeGaussian.loader 474 | max_level = 0 475 | q = queue.Queue() 476 | q.put({"node": octreeGaussian.root, "level": 0}) 477 | while q.qsize() > 0: 478 | element = q.get() 479 | node = element["node"] 480 | max_level = level = element["level"] 481 | if level < 16: 482 | octreeGaussianLoader.load(node) 483 | 484 | for cid in range(8): 485 | child = node.children[cid] 486 | if child is not None: 487 | q.put({"node": child, 
"level": level + 1}) 488 | 489 | # concat all the position and color buffers into single ply 490 | # position_buffers, color_buffers = collect_position_buffers(octreeGaussian.root, 0, max_level) 491 | # all_positions = np.concatenate(position_buffers, axis=0) 492 | # all_color = np.concatenate(color_buffers, axis=0) 493 | # vertices = np.array( 494 | # [(position[0], position[1], position[2], color[0], color[1], color[2]) for position, color in zip(all_positions, all_color)], 495 | # dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 496 | # ) 497 | # el = PlyElement.describe(vertices, 'vertex') 498 | # PlyData([el], text=True).write(os.path.join(octree_path, "pcd_octree.ply")) 499 | 500 | # do not concat, save them individually 501 | # recover_octree(octree_path, octreeGaussian.root, 0) 502 | 503 | pcds = [] 504 | for level in range(0, max_level + 1): 505 | position_buffers, color_buffers = collect_position_buffers(octreeGaussian.root, 0, level) 506 | positions = np.concatenate(position_buffers, axis=0) 507 | colors = np.concatenate(color_buffers, axis=0) 508 | normals = np.zeros_like(positions) 509 | pcds.append(BasicPointCloud(positions, colors[:,:3] / 255.0, normals)) 510 | 511 | scene_info = SceneInfo(point_cloud=pcds, 512 | train_cameras=train_cam_infos, 513 | test_cameras=test_cam_infos, 514 | nerf_normalization=nerf_normalization, 515 | ply_path=ply_path, 516 | max_level=max_level) 517 | 518 | return scene_info 519 | 520 | sceneLoadTypeCallbacks = { 521 | "Colmap": readColmapSceneInfo, 522 | "Blender" : readNerfSyntheticInfo, 523 | "Octree" : readoctreeColmapInfo 524 | } -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | import laspy 18 | from utils.system_utils import mkdir_p 19 | from plyfile import PlyData, PlyElement 20 | from utils.sh_utils import RGB2SH 21 | from simple_knn._C import distCUDA2 22 | from utils.graphics_utils import BasicPointCloud 23 | from utils.general_utils import strip_symmetric, build_scaling_rotation 24 | 25 | class GaussianModel: 26 | 27 | def setup_functions(self): 28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 30 | actual_covariance = L @ L.transpose(1, 2) 31 | symm = strip_symmetric(actual_covariance) 32 | return symm 33 | 34 | self.scaling_activation = torch.exp 35 | self.scaling_inverse_activation = torch.log 36 | 37 | self.covariance_activation = build_covariance_from_scaling_rotation 38 | 39 | self.opacity_activation = torch.sigmoid 40 | self.inverse_opacity_activation = inverse_sigmoid 41 | 42 | self.rotation_activation = torch.nn.functional.normalize 43 | 44 | 45 | def __init__(self, sh_degree : int, level:int = 0): 46 | self.active_sh_degree = 0 47 | self.max_sh_degree = sh_degree 48 | self.level = level 49 | self._xyz = torch.empty(0) 50 | self._features_dc = torch.empty(0) 51 | self._features_rest = torch.empty(0) 52 | self._scaling = torch.empty(0) 53 | self._rotation = torch.empty(0) 54 | self._opacity = torch.empty(0) 55 | self.max_radii2D = torch.empty(0) 56 | self.xyz_gradient_accum = torch.empty(0) 57 | self.denom = torch.empty(0) 58 | self.optimizer = None 59 | self.percent_dense = 0 60 | self.spatial_lr_scale = 0 61 | self.setup_functions() 62 | 63 | def capture(self): 64 | return ( 65 | self.active_sh_degree, 66 | self._xyz, 67 | self._features_dc, 68 | self._features_rest, 69 | self._scaling, 70 | self._rotation, 71 | self._opacity, 72 | self.max_radii2D, 73 | self.xyz_gradient_accum, 74 | self.denom, 75 | self.optimizer.state_dict(), 76 | self.spatial_lr_scale, 77 | ) 78 | 79 | def restore(self, model_args, training_args): 80 | (self.active_sh_degree, 81 | self._xyz, 82 | self._features_dc, 83 | self._features_rest, 84 | self._scaling, 85 | self._rotation, 86 | self._opacity, 87 | self.max_radii2D, 88 | xyz_gradient_accum, 89 | denom, 90 | opt_dict, 91 | self.spatial_lr_scale) = model_args 92 | self.training_setup(training_args) 93 | self.xyz_gradient_accum = xyz_gradient_accum 94 | self.denom = denom 95 | self.optimizer.load_state_dict(opt_dict) 96 | 97 | @property 98 | def get_scaling(self): 99 | return self.scaling_activation(self._scaling) 100 | 101 | @property 102 | def get_rotation(self): 103 | return self.rotation_activation(self._rotation) 104 | 105 | @property 106 | def get_xyz(self): 107 | return self._xyz 108 | 109 | @property 110 | def get_features(self): 111 | features_dc = self._features_dc 112 | features_rest = self._features_rest 113 | return torch.cat((features_dc, features_rest), dim=1) 114 | 115 | @property 116 | def get_opacity(self): 117 | return self.opacity_activation(self._opacity) 118 | 119 | def get_covariance(self, scaling_modifier = 1): 120 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 121 | 122 | def oneupSHdegree(self): 123 | if self.active_sh_degree < self.max_sh_degree: 124 | self.active_sh_degree += 1 125 | 126 | def 
create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): 127 | self.spatial_lr_scale = spatial_lr_scale 128 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 129 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 130 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 131 | features[:, :3, 0 ] = fused_color 132 | features[:, 3:, 1:] = 0.0 133 | 134 | print(f"[ Scene ] Number of points at Level {self.level}: ", fused_point_cloud.shape[0]) 135 | 136 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 137 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 138 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 139 | rots[:, 0] = 1 140 | 141 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 142 | 143 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 144 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 145 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) 146 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 147 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 148 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 149 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 150 | 151 | def training_setup(self, training_args): 152 | self.percent_dense = training_args.percent_dense 153 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 154 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 155 | 156 | l = [ 157 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 158 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 159 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 160 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 161 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 162 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 163 | ] 164 | 165 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 166 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, 167 | lr_final=training_args.position_lr_final*self.spatial_lr_scale, 168 | lr_delay_mult=training_args.position_lr_delay_mult, 169 | max_steps=training_args.position_lr_max_steps) 170 | 171 | def update_learning_rate(self, iteration): 172 | ''' Learning rate scheduling per step ''' 173 | for param_group in self.optimizer.param_groups: 174 | if param_group["name"] == "xyz": 175 | lr = self.xyz_scheduler_args(iteration) 176 | param_group['lr'] = lr 177 | return lr 178 | 179 | def construct_list_of_attributes(self): 180 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 181 | # All channels except the 3 DC 182 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 183 | l.append('f_dc_{}'.format(i)) 184 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 185 | l.append('f_rest_{}'.format(i)) 186 | l.append('opacity') 187 | for i in range(self._scaling.shape[1]): 188 | l.append('scale_{}'.format(i)) 189 | for i in range(self._rotation.shape[1]): 190 | 
l.append('rot_{}'.format(i)) 191 | return l 192 | 193 | def save_ply(self, path): 194 | mkdir_p(os.path.dirname(path)) 195 | 196 | xyz = self._xyz.detach().cpu().numpy() 197 | normals = np.zeros_like(xyz) 198 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 199 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 200 | opacities = self._opacity.detach().cpu().numpy() 201 | scale = self._scaling.detach().cpu().numpy() 202 | rotation = self._rotation.detach().cpu().numpy() 203 | 204 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 205 | 206 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 207 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 208 | elements[:] = list(map(tuple, attributes)) 209 | el = PlyElement.describe(elements, 'vertex') 210 | PlyData([el]).write(path) 211 | 212 | def reset_opacity(self): 213 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 214 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 215 | self._opacity = optimizable_tensors["opacity"] 216 | 217 | def load_ply(self, path): 218 | plydata = PlyData.read(path) 219 | 220 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 221 | np.asarray(plydata.elements[0]["y"]), 222 | np.asarray(plydata.elements[0]["z"])), axis=1) 223 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 224 | 225 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 226 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 227 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 228 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 229 | 230 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 231 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 232 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 233 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 234 | for idx, attr_name in enumerate(extra_f_names): 235 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 236 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 237 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 238 | 239 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 240 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 241 | scales = np.zeros((xyz.shape[0], len(scale_names))) 242 | for idx, attr_name in enumerate(scale_names): 243 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 244 | 245 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 246 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 247 | rots = np.zeros((xyz.shape[0], len(rot_names))) 248 | for idx, attr_name in enumerate(rot_names): 249 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 250 | 251 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 252 | self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 253 | self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 
2).contiguous().requires_grad_(True)) 254 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 255 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 256 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 257 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 258 | 259 | self.active_sh_degree = self.max_sh_degree 260 | 261 | def replace_tensor_to_optimizer(self, tensor, name): 262 | optimizable_tensors = {} 263 | for group in self.optimizer.param_groups: 264 | if group["name"] == name: 265 | stored_state = self.optimizer.state.get(group['params'][0], None) 266 | stored_state["exp_avg"] = torch.zeros_like(tensor) 267 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 268 | 269 | del self.optimizer.state[group['params'][0]] 270 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 271 | self.optimizer.state[group['params'][0]] = stored_state 272 | 273 | optimizable_tensors[group["name"]] = group["params"][0] 274 | return optimizable_tensors 275 | 276 | def _prune_optimizer(self, mask): 277 | optimizable_tensors = {} 278 | for group in self.optimizer.param_groups: 279 | stored_state = self.optimizer.state.get(group['params'][0], None) 280 | if stored_state is not None: 281 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 282 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 283 | 284 | del self.optimizer.state[group['params'][0]] 285 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 286 | self.optimizer.state[group['params'][0]] = stored_state 287 | 288 | optimizable_tensors[group["name"]] = group["params"][0] 289 | else: 290 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 291 | optimizable_tensors[group["name"]] = group["params"][0] 292 | return optimizable_tensors 293 | 294 | def prune_points(self, mask): 295 | valid_points_mask = ~mask 296 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 297 | 298 | self._xyz = optimizable_tensors["xyz"] 299 | self._features_dc = optimizable_tensors["f_dc"] 300 | self._features_rest = optimizable_tensors["f_rest"] 301 | self._opacity = optimizable_tensors["opacity"] 302 | self._scaling = optimizable_tensors["scaling"] 303 | self._rotation = optimizable_tensors["rotation"] 304 | 305 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 306 | 307 | self.denom = self.denom[valid_points_mask] 308 | self.max_radii2D = self.max_radii2D[valid_points_mask] 309 | 310 | def cat_tensors_to_optimizer(self, tensors_dict): 311 | optimizable_tensors = {} 312 | for group in self.optimizer.param_groups: 313 | assert len(group["params"]) == 1 314 | extension_tensor = tensors_dict[group["name"]] 315 | stored_state = self.optimizer.state.get(group['params'][0], None) 316 | if stored_state is not None: 317 | 318 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 319 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 320 | 321 | del self.optimizer.state[group['params'][0]] 322 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 323 | self.optimizer.state[group['params'][0]] = stored_state 324 | 325 | optimizable_tensors[group["name"]] = group["params"][0] 326 | else: 
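                # No optimizer state (exp_avg / exp_avg_sq) exists for this parameter group yet,
                # so only the parameter tensor itself is extended with the newly added Gaussians.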
327 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 328 | optimizable_tensors[group["name"]] = group["params"][0] 329 | 330 | return optimizable_tensors 331 | 332 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): 333 | d = {"xyz": new_xyz, 334 | "f_dc": new_features_dc, 335 | "f_rest": new_features_rest, 336 | "opacity": new_opacities, 337 | "scaling" : new_scaling, 338 | "rotation" : new_rotation} 339 | 340 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 341 | self._xyz = optimizable_tensors["xyz"] 342 | self._features_dc = optimizable_tensors["f_dc"] 343 | self._features_rest = optimizable_tensors["f_rest"] 344 | self._opacity = optimizable_tensors["opacity"] 345 | self._scaling = optimizable_tensors["scaling"] 346 | self._rotation = optimizable_tensors["rotation"] 347 | 348 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 349 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 350 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 351 | 352 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 353 | n_init_points = self.get_xyz.shape[0] 354 | # Extract points that satisfy the gradient condition 355 | padded_grad = torch.zeros((n_init_points), device="cuda") 356 | padded_grad[:grads.shape[0]] = grads.squeeze() 357 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 358 | selected_pts_mask = torch.logical_and(selected_pts_mask, 359 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 360 | 361 | stds = self.get_scaling[selected_pts_mask].repeat(N,1) 362 | means =torch.zeros((stds.size(0), 3),device="cuda") 363 | samples = torch.normal(mean=means, std=stds) 364 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 365 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 366 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 367 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 368 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 369 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 370 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 371 | 372 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) 373 | 374 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 375 | self.prune_points(prune_filter) 376 | 377 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 378 | # Extract points that satisfy the gradient condition 379 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 380 | selected_pts_mask = torch.logical_and(selected_pts_mask, 381 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 382 | 383 | new_xyz = self._xyz[selected_pts_mask] 384 | new_features_dc = self._features_dc[selected_pts_mask] 385 | new_features_rest = self._features_rest[selected_pts_mask] 386 | new_opacities = self._opacity[selected_pts_mask] 387 | new_scaling = self._scaling[selected_pts_mask] 388 | new_rotation = self._rotation[selected_pts_mask] 389 | 390 | self.densification_postfix(new_xyz, new_features_dc, 
new_features_rest, new_opacities, new_scaling, new_rotation) 391 | 392 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 393 | grads = self.xyz_gradient_accum / self.denom 394 | grads[grads.isnan()] = 0.0 395 | 396 | self.densify_and_clone(grads, max_grad, extent) 397 | self.densify_and_split(grads, max_grad, extent) 398 | 399 | prune_mask = (self.get_opacity < min_opacity).squeeze() 400 | if max_screen_size: 401 | big_points_vs = self.max_radii2D > max_screen_size 402 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 403 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 404 | self.prune_points(prune_mask) 405 | 406 | torch.cuda.empty_cache() 407 | 408 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 409 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True) 410 | # self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) 411 | self.denom[update_filter] += 1 -------------------------------------------------------------------------------- /scene/octree_loader.py: -------------------------------------------------------------------------------- 1 | # ref: potree\src\modules\loader\2.0\octreeGaussian.js 2 | # create by Penghao Wang 3 | 4 | import os 5 | import numpy as np 6 | import json 7 | 8 | from utils.graphics_utils import OctreeGaussianNode, Vector3, BoundingBox, OctreeGaussian 9 | from scene.gaussian_model import GaussianModel 10 | from utils.graphics_utils import BasicPointCloud 11 | 12 | octreeConst = { 13 | "pointBudget": 1 * 1000 * 1000, 14 | "framenumber" : 0, 15 | "numNodesLoading" : 0, 16 | "maxNodesLoading" : 4 17 | } 18 | 19 | PointAttributeTypesTmp = { 20 | "DATA_TYPE_DOUBLE": {"ordinal": 0, "name": "double", "size": 8}, 21 | "DATA_TYPE_FLOAT": {"ordinal": 1, "name": "float", "size": 4}, 22 | "DATA_TYPE_INT8": {"ordinal": 2, "name": "int8", "size": 1}, 23 | "DATA_TYPE_UINT8": {"ordinal": 3, "name": "uint8", "size": 1}, 24 | "DATA_TYPE_INT16": {"ordinal": 4, "name": "int16", "size": 2}, 25 | "DATA_TYPE_UINT16": {"ordinal": 5, "name": "uint16", "size": 2}, 26 | "DATA_TYPE_INT32": {"ordinal": 6, "name": "int32", "size": 4}, 27 | "DATA_TYPE_UINT32": {"ordinal": 7, "name": "uint32", "size": 4}, 28 | "DATA_TYPE_INT64": {"ordinal": 8, "name": "int64", "size": 8}, 29 | "DATA_TYPE_UINT64": {"ordinal": 9, "name": "uint64", "size": 8} 30 | } 31 | 32 | PointAttributeTypes = PointAttributeTypesTmp.copy() 33 | 34 | i = 0 35 | for obj in PointAttributeTypesTmp: 36 | PointAttributeTypes[str(i)] = PointAttributeTypesTmp[obj] 37 | i += 1 38 | 39 | # print(PointAttributeTypes) 40 | 41 | class PointAttribute: 42 | def __init__(self, name, type, numElements): 43 | self.name = name 44 | self.type = type 45 | self.numElements = numElements 46 | self.byteSize = self.numElements * self.type['size'] 47 | self.description = "" 48 | self.range = [float('inf'), float('-inf')] 49 | 50 | # Defining the static attributes for the PointAttribute class 51 | PointAttribute.POSITION_CARTESIAN = PointAttribute("POSITION_CARTESIAN", PointAttributeTypes["DATA_TYPE_FLOAT"], 3) 52 | PointAttribute.RGBA_PACKED = PointAttribute("COLOR_PACKED", PointAttributeTypes["DATA_TYPE_INT8"], 4) 53 | PointAttribute.COLOR_PACKED = PointAttribute.RGBA_PACKED 54 | PointAttribute.RGB_PACKED = PointAttribute("COLOR_PACKED", PointAttributeTypes["DATA_TYPE_INT8"], 3) 55 | 
PointAttribute.NORMAL_FLOATS = PointAttribute("NORMAL_FLOATS", PointAttributeTypes["DATA_TYPE_FLOAT"], 3) 56 | PointAttribute.INTENSITY = PointAttribute("INTENSITY", PointAttributeTypes["DATA_TYPE_UINT16"], 1) 57 | PointAttribute.CLASSIFICATION = PointAttribute("CLASSIFICATION", PointAttributeTypes["DATA_TYPE_UINT8"], 1) 58 | PointAttribute.NORMAL_SPHEREMAPPED = PointAttribute("NORMAL_SPHEREMAPPED", PointAttributeTypes["DATA_TYPE_UINT8"], 2) 59 | PointAttribute.NORMAL_OCT16 = PointAttribute("NORMAL_OCT16", PointAttributeTypes["DATA_TYPE_UINT8"], 2) 60 | PointAttribute.NORMAL = PointAttribute("NORMAL", PointAttributeTypes["DATA_TYPE_FLOAT"], 3) 61 | PointAttribute.RETURN_NUMBER = PointAttribute("RETURN_NUMBER", PointAttributeTypes["DATA_TYPE_UINT8"], 1) 62 | PointAttribute.NUMBER_OF_RETURNS = PointAttribute("NUMBER_OF_RETURNS", PointAttributeTypes["DATA_TYPE_UINT8"], 1) 63 | PointAttribute.SOURCE_ID = PointAttribute("SOURCE_ID", PointAttributeTypes["DATA_TYPE_UINT16"], 1) 64 | PointAttribute.INDICES = PointAttribute("INDICES", PointAttributeTypes["DATA_TYPE_UINT32"], 1) 65 | PointAttribute.SPACING = PointAttribute("SPACING", PointAttributeTypes["DATA_TYPE_FLOAT"], 1) 66 | PointAttribute.GPS_TIME = PointAttribute("GPS_TIME", PointAttributeTypes["DATA_TYPE_DOUBLE"], 1) 67 | 68 | class PointAttributes: 69 | def __init__(self, pointAttributes=None): 70 | self.attributes = [] 71 | self.byteSize = 0 72 | self.size = 0 73 | self.vectors = [] 74 | 75 | if pointAttributes is not None: 76 | for pointAttributeName in pointAttributes: 77 | pointAttribute = getattr(PointAttribute, pointAttributeName, None) 78 | if pointAttribute: 79 | self.attributes.append(pointAttribute) 80 | self.byteSize += pointAttribute.byteSize 81 | self.size += 1 82 | 83 | def add(self, pointAttribute): 84 | self.attributes.append(pointAttribute) 85 | self.byteSize += pointAttribute.byteSize 86 | self.size += 1 87 | 88 | def addVector(self, vector): 89 | self.vectors.append(vector) 90 | 91 | typename_typeattribute_map = { 92 | "double": PointAttributeTypes["DATA_TYPE_DOUBLE"], 93 | "float": PointAttributeTypes["DATA_TYPE_FLOAT"], 94 | "int8": PointAttributeTypes["DATA_TYPE_INT8"], 95 | "uint8": PointAttributeTypes["DATA_TYPE_UINT8"], 96 | "int16": PointAttributeTypes["DATA_TYPE_INT16"], 97 | "uint16": PointAttributeTypes["DATA_TYPE_UINT16"], 98 | "int32": PointAttributeTypes["DATA_TYPE_INT32"], 99 | "uint32": PointAttributeTypes["DATA_TYPE_UINT32"], 100 | "int64": PointAttributeTypes["DATA_TYPE_INT64"], 101 | "uint64": PointAttributeTypes["DATA_TYPE_UINT64"], 102 | } 103 | 104 | tmpVec3 = Vector3() 105 | 106 | def createChildAABB(aabb: BoundingBox, index: int) -> BoundingBox: 107 | minPoint = Vector3(aabb.min.x, aabb.min.y, aabb.min.z) 108 | maxPoint = Vector3(aabb.max.x, aabb.max.y, aabb.max.z) 109 | size = tmpVec3.subVectors(maxPoint, minPoint) 110 | 111 | if (index & 0b0001) > 0: 112 | minPoint.z += size.z / 2 113 | else: 114 | maxPoint.z -= size.z / 2 115 | 116 | if (index & 0b0010) > 0: 117 | minPoint.y += size.y / 2 118 | else: 119 | maxPoint.y -= size.y / 2 120 | 121 | if (index & 0b0100) > 0: 122 | minPoint.x += size.x / 2 123 | else: 124 | maxPoint.x -= size.x / 2 125 | 126 | return BoundingBox(min_point=minPoint, max_point=maxPoint) 127 | 128 | def loadOctree(path): 129 | if not os.path.exists(path): 130 | return None 131 | 132 | if "metadata.json" not in os.listdir(path): 133 | assert False, "[ Error ] Octree path dir does not contain metadata.json in loadOctree method" 134 | 135 | loadworker = octreeLoader() 136 | 
octree = loadworker.load(path) 137 | return octree 138 | 139 | def toIndex(x, y, z, sizeX, sizeY, sizeZ): 140 | gridSize = 32 141 | dx = gridSize * x / sizeX 142 | dy = gridSize * y / sizeY 143 | dz = gridSize * z / sizeZ 144 | 145 | # print(dx, gridSize) 146 | 147 | ix = min(int(dx), gridSize - 1) 148 | iy = min(int(dy), gridSize - 1) 149 | iz = min(int(dz), gridSize - 1) 150 | 151 | index = ix + iy * gridSize + iz * gridSize * gridSize 152 | return index 153 | 154 | class nodeLoader(): 155 | 156 | def __init__(self, path: str) -> None: 157 | self.path = path 158 | self.metadata = None 159 | self.attributes = None 160 | self.scale = None 161 | self.offset = None 162 | 163 | def loadHierarchy(self, node: OctreeGaussianNode) -> None: 164 | hierarchyByteOffset = node.hierarchyByteOffset 165 | hierarchyByteSize = node.hierarchyByteSize 166 | first = hierarchyByteOffset 167 | last = first + hierarchyByteSize - 1 168 | # load the hierarchy.bin from byte first to last 169 | hierarchyPath = os.path.join(self.path, "hierarchy.bin") 170 | with open(hierarchyPath, "rb") as f: 171 | # load from first to last, which is index of bytes 172 | f.seek(first) 173 | buffer = f.read(last - first + 1) 174 | f.close() 175 | self.parseHierarchy(node, buffer) 176 | 177 | def parseHierarchy(self, node: OctreeGaussianNode, buffer): 178 | bytesPerNode = 22 179 | numNodes = int(len(buffer) / bytesPerNode) 180 | 181 | octree = node.octreeGaussian 182 | nodes = [None for i in range(numNodes)] 183 | nodes[0] = node 184 | nodePos = 1 185 | 186 | for i in range(numNodes): 187 | 188 | start = i * bytesPerNode 189 | # uint8 190 | type = buffer[start] 191 | # uint8 192 | childMask = buffer[start + 1] 193 | # uint32 194 | numPoints = int.from_bytes(buffer[start + 2:start + 6], byteorder='little', signed=False) 195 | # bigint 64 196 | byteOffset = int.from_bytes(buffer[start + 6:start + 14], byteorder='little', signed=True) 197 | # bigint 64 198 | byteSize = int.from_bytes(buffer[start + 14:start + 22], byteorder='little', signed=True) 199 | 200 | # print(f"[ Info ] type: {type}, childMask: {childMask}, numPoints: {numPoints}, byteOffset: {byteOffset}, byteSize: {byteSize}") 201 | 202 | if nodes[i].nodeType == 2: 203 | nodes[i].byteOffset = byteOffset 204 | nodes[i].byteSize = byteSize 205 | nodes[i].numGaussians = numPoints 206 | elif type == 2: 207 | nodes[i].hierarchyByteOffset = byteOffset 208 | nodes[i].hierarchyByteSize = byteSize 209 | nodes[i].numGaussians = numPoints 210 | else: 211 | nodes[i].byteOffset = byteOffset 212 | nodes[i].byteSize = byteSize 213 | nodes[i].numGaussians = numPoints 214 | 215 | if nodes[i].byteSize == 0: 216 | nodes[i].numGaussians = 0 217 | 218 | nodes[i].nodeType = type 219 | 220 | if nodes[i].nodeType == 2: 221 | continue 222 | 223 | for childIndex in range(8): 224 | childExists = ((1 << childIndex) & childMask) != 0 225 | if not childExists: 226 | continue 227 | 228 | childName = nodes[i].name + str(childIndex) 229 | childAABB = createChildAABB(nodes[i].boundingbox, childIndex) 230 | child = OctreeGaussianNode(childName, octree, childAABB) 231 | child.name = childName 232 | child.spacing = nodes[i].spacing / 2 233 | child.level = nodes[i].level + 1 234 | child.parent = nodes[i] 235 | 236 | nodes[i].children[childIndex] = child 237 | nodes[nodePos] = child 238 | nodePos += 1 239 | 240 | def load(self, node: OctreeGaussianNode) -> None: 241 | 242 | # print(node) 243 | 244 | if (node.loaded or node.loading): 245 | return 246 | 247 | node.loading = True 248 | octreeConst["numNodesLoading"] += 
1 249 | 250 | if node.nodeType == 2: 251 | self.loadHierarchy(node=node) 252 | 253 | byteOffset = node.byteOffset 254 | byteSize = node.byteSize 255 | 256 | octreePath = os.path.join(self.path, "octree.bin") 257 | 258 | first = byteOffset 259 | last = first + byteSize - 1 260 | 261 | if byteSize == 0: 262 | assert False, "[ Error ] byteSize is 0 in nodeLoader.load method" 263 | else: 264 | with open(octreePath, "rb") as f: 265 | f.seek(first) 266 | buffer = f.read(last - first + 1) 267 | f.close() 268 | 269 | attributeBuffers = {} 270 | attributeOffset = 0 271 | 272 | bytesPerPoint = 0 273 | for pointAttribute in node.octreeGaussian.pointAttributes.attributes: 274 | bytesPerPoint += pointAttribute.byteSize 275 | 276 | scale = node.octreeGaussian.scale 277 | offset = node.octreeGaussian.loader.offset 278 | 279 | # print(node.octreeGaussian.loader.offset) 280 | # print(node.octreeGaussian.offset) 281 | 282 | for pointAttribute in node.octreeGaussian.pointAttributes.attributes: 283 | if pointAttribute.name in ["POSITION_CARTESIAN", "position"]: 284 | buff = np.zeros(node.numGaussians * 3, dtype=np.float32) 285 | positions = buff 286 | for j in range(node.numGaussians): 287 | pointOffset = j * bytesPerPoint 288 | 289 | # reserve pos aligned with colmap coordinate 290 | x = (int.from_bytes(buffer[pointOffset + attributeOffset + 0:pointOffset + attributeOffset + 4], byteorder="little", signed=True) * scale[0]) + offset[0] 291 | y = (int.from_bytes(buffer[pointOffset + attributeOffset + 4:pointOffset + attributeOffset + 8], byteorder="little", signed=True) * scale[1]) + offset[1] 292 | z = (int.from_bytes(buffer[pointOffset + attributeOffset + 8:pointOffset + attributeOffset + 12], byteorder="little", signed=True) * scale[2]) + offset[2] 293 | positions[3 * j + 0] = x 294 | positions[3 * j + 1] = y 295 | positions[3 * j + 2] = z 296 | 297 | attributeBuffers[pointAttribute.name] = {"buffer": buff, "attribute": pointAttribute} 298 | elif pointAttribute.name in ["RGBA", "rgba"]: 299 | buff = np.zeros(node.numGaussians * 4, dtype = np.uint8) 300 | colors = buff 301 | 302 | for j in range(node.numGaussians): 303 | pointOffset = j * bytesPerPoint 304 | r = np.frombuffer(buffer[pointOffset + attributeOffset + 0:pointOffset + attributeOffset + 2], dtype=np.uint16)[0] 305 | g = np.frombuffer(buffer[pointOffset + attributeOffset + 2:pointOffset + attributeOffset + 4], dtype=np.uint16)[0] 306 | b = np.frombuffer(buffer[pointOffset + attributeOffset + 4:pointOffset + attributeOffset + 6], dtype=np.uint16)[0] 307 | 308 | colors[4 * j + 0] = r / 256 if r > 255 else r 309 | colors[4 * j + 1] = g / 256 if g > 255 else g 310 | colors[4 * j + 2] = b / 256 if b > 255 else b 311 | 312 | attributeBuffers[pointAttribute.name] = {"buffer": buff, "attribute": pointAttribute} 313 | 314 | else: 315 | # other attribute no need 316 | pass 317 | 318 | attributeOffset += pointAttribute.byteSize 319 | 320 | node_position = np.array(attributeBuffers["position"]["buffer"], dtype=np.float32).reshape(-1, 3) 321 | try: 322 | node_colors = np.array(attributeBuffers["rgba"]["buffer"], dtype=np.uint8).reshape(-1, 4) 323 | except: 324 | node_colors = np.zeros(len(attributeBuffers["position"]["buffer"])) 325 | 326 | pcd = BasicPointCloud(node_position, node_colors, None) 327 | 328 | node.pointcloud = pcd 329 | 330 | node.loaded = True 331 | node.loading = False 332 | octreeConst["numNodesLoading"] -= 1 333 | 334 | class octreeLoader(): 335 | 336 | def __init__(self) -> None: 337 | self.metadata = None 338 | 339 | def load(self, 
path_dir): 340 | 341 | # get metadata path 342 | path = os.path.join(path_dir, "metadata.json") 343 | 344 | if not os.path.exists(path): 345 | assert False, "[ Error ] Path does not exist in disk in octreeLoader.load method" 346 | 347 | # load the json file and parse it 348 | with open(path, 'r') as file: 349 | self.metadata = json.load(file) 350 | file.close() 351 | 352 | # parse the attributes 353 | attributes = self.parseAttributes(self.metadata["attributes"]) 354 | 355 | loader = nodeLoader(path_dir) 356 | loader.metadata = self.metadata 357 | loader.attributes = attributes 358 | loader.scale = self.metadata["scale"] 359 | loader.offset = self.metadata["offset"] 360 | 361 | # define octreeGaussian 362 | octree = OctreeGaussian() 363 | octree.spacing = self.metadata["spacing"] 364 | octree.scale = self.metadata["scale"] 365 | 366 | meta_min = self.metadata["boundingBox"]["min"] 367 | meta_max = self.metadata["boundingBox"]["max"] 368 | 369 | min = Vector3(meta_min[0], meta_min[1], meta_min[2]) 370 | max = Vector3(meta_max[0], meta_max[1], meta_max[2]) 371 | boundingBox = BoundingBox(min, max) 372 | 373 | offset = Vector3(meta_min[0], meta_min[1], meta_min[2]) 374 | 375 | boundingBox.min -= offset 376 | boundingBox.max -= offset 377 | 378 | octree.projection = self.metadata["projection"] 379 | octree.boundingBox = boundingBox 380 | octree.offset = offset 381 | octree.pointAttributes = self.parseAttributes(self.metadata["attributes"]) 382 | octree.loader = loader 383 | 384 | root = OctreeGaussianNode("r", octree, boundingBox) 385 | root.level = 0 386 | root.nodeType = 2 387 | root.hierarchyByteOffset = 0 388 | root.hierarchyByteSize = self.metadata["hierarchy"]["firstChunkSize"] 389 | root.spacing = octree.spacing 390 | root.byteOffset = 0 391 | 392 | octree.root = root 393 | 394 | loader.load(root) 395 | 396 | return octree 397 | 398 | def parseAttributes(self, jsonAttributes: list) -> None: 399 | attributes = PointAttributes() 400 | replacements = { 401 | "rgb": "rgba" 402 | } 403 | for jsonAttribute in jsonAttributes: 404 | 405 | name = jsonAttribute["name"] 406 | description = jsonAttribute["description"] 407 | size = jsonAttribute["size"] 408 | numElements = jsonAttribute["numElements"] 409 | elementSize = jsonAttribute["elementSize"] 410 | type = jsonAttribute["type"] 411 | min = jsonAttribute["min"] 412 | max = jsonAttribute["max"] 413 | 414 | type = typename_typeattribute_map[type] 415 | potreeAttributeName = replacements[name] if name in replacements else name 416 | attribute = PointAttribute(potreeAttributeName, type, numElements) 417 | 418 | if numElements == 1: 419 | attribute.range = [min[0], max[0]] 420 | else: 421 | attribute.range = [min, max] 422 | 423 | if name == "gps-time": 424 | if attribute.range[0] == attribute.range[1]: 425 | attribute.range[1] += 1 426 | 427 | attributes.add(attribute) 428 | 429 | # check if it has normals 430 | if any(attr.name == "NormalX" for attr in attributes.attributes) and \ 431 | any(attr.name == "NormalY" for attr in attributes.attributes) and \ 432 | any(attr.name == "NormalZ" for attr in attributes.attributes): 433 | vector = { 434 | "name": "NORMAL", 435 | "attributes": ["NormalX", "NormalY", "NormalZ"] 436 | } 437 | attributes.addVector(vector) 438 | 439 | return attributes -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, 
https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, l2_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr, warpped_depth, unwarpped_depth 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | try: 26 | from torch.utils.tensorboard import SummaryWriter 27 | TENSORBOARD_FOUND = True 28 | except ImportError: 29 | TENSORBOARD_FOUND = False 30 | 31 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 32 | first_iter = 0 33 | 34 | # Set up output folder 35 | if not dataset.model_path: 36 | if os.getenv('OAR_JOB_ID'): 37 | unique_str=os.getenv('OAR_JOB_ID') 38 | else: 39 | unique_str = str(uuid.uuid4()) 40 | dataset.model_path = os.path.join(dataset.source_path, "3D-Gaussian-Splatting", unique_str[0:10]) 41 | print("[ Training ] Output Folder: {}".format(dataset.model_path)) 42 | os.makedirs(dataset.model_path, exist_ok = True) 43 | 44 | # Load dataset 45 | scene = Scene(dataset) 46 | scene.training_setup(opt) 47 | 48 | # extract scene depth max 49 | dataset.depth_max = scene.depth_max.cpu().item() 50 | 51 | # prepare logger and extract parameters 52 | tb_writer = prepare_output_and_logger(dataset) 53 | 54 | if checkpoint: 55 | (model_params, first_iter) = torch.load(checkpoint) 56 | scene.restore(model_params, opt) 57 | 58 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 59 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 60 | 61 | iter_start = torch.cuda.Event(enable_timing = True) 62 | iter_end = torch.cuda.Event(enable_timing = True) 63 | 64 | viewpoint_stack = None 65 | level_stack = None 66 | ema_loss_for_log = 0.0 67 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 68 | first_iter += 1 69 | for iteration in range(first_iter, opt.iterations + 1): 70 | if network_gui.conn == None: 71 | network_gui.try_connect() 72 | while network_gui.conn != None: 73 | try: 74 | net_image_bytes = None 75 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 76 | if custom_cam != None: 77 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 78 | xyz, features, opacity, scales, rotations, cov3D_precomp, \ 79 | active_sh_degree, max_sh_degree, masks = scene.get_gaussian_parameters(viewpoint_cam.world_view_transform, pipe.compute_cov3D_python, scaling_modifer) 80 | net_image = render(custom_cam, xyz, features, opacity, scales, rotations, active_sh_degree, max_sh_degree, pipe, background, scaling_modifer, cov3D_precomp = cov3D_precomp) 81 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 82 | network_gui.send(net_image_bytes, dataset.source_path) 83 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 84 | break 85 | except Exception as e: 86 | network_gui.conn = None 
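# Per-iteration overview of the loop body below:
#   1. record the iteration start event and update the learning-rate schedule
#   2. every 1000 iterations, raise the active SH degree (up to the configured maximum)
#   3. pop a random training camera and a random LOD level from their stacks
#   4. fetch the LOD-filtered Gaussian parameters for that view and render it
#   5. compute the RGB loss (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM),
#      plus an L1 depth term on the warped depth when a LiDAR depth map is available
#   6. backpropagate, densify/prune and reset opacity on schedule, then step the optimizer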
87 | 88 | iter_start.record() 89 | 90 | scene.update_learning_rate(iteration) 91 | 92 | # Every 1000 its we increase the levels of SH up to a maximum degree 93 | if iteration % 1000 == 0: 94 | scene.oneupSHdegree() 95 | 96 | # Pick a random Camera 97 | if not viewpoint_stack: 98 | viewpoint_stack = scene.getTrainCameras().copy() 99 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 100 | 101 | # Pick a random Level 102 | if not level_stack: 103 | level_stack = list(range(-scene.max_level, scene.max_level + 1)) 104 | random_level = level_stack.pop(randint(0, len(level_stack)-1)) 105 | # random_level =-1 106 | 107 | # Render 108 | if (iteration - 1) == debug_from: 109 | pipe.debug = True 110 | 111 | # render_pkg = render(viewpoint_cam, gaussians, pipe, background) 112 | xyz, features, opacity, scales, rotations, cov3D_precomp, \ 113 | active_sh_degree, max_sh_degree, masks = scene.get_gaussian_parameters(viewpoint_cam.world_view_transform, pipe.compute_cov3D_python, random = random_level) 114 | render_pkg = render(viewpoint_cam, xyz, features, opacity, scales, rotations, active_sh_degree, max_sh_degree, pipe, background, cov3D_precomp = cov3D_precomp) 115 | image, viewspace_point_tensor, visibility_filter, radii, depth = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"], render_pkg["depth"] 116 | depth = warpped_depth(depth) 117 | # Loss 118 | gt_image = viewpoint_cam.original_image.cuda() 119 | Ll1 = l1_loss(image, gt_image) 120 | rgb_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 121 | depth_loss = 0.0 122 | 123 | if viewpoint_cam.depth is not None: 124 | gt_depth = viewpoint_cam.depth.cuda() 125 | gt_depth_mask = viewpoint_cam.depth_mask.cuda() 126 | 127 | if iteration > opt.iterations / 3: 128 | gt_depth = gt_depth * gt_depth_mask + (1 - gt_depth_mask) * 0.75 129 | gt_depth_mask = depth < gt_depth 130 | 131 | depth_loss = 2.0 * l1_loss(depth * gt_depth_mask, gt_depth * gt_depth_mask) 132 | 133 | loss = rgb_loss + depth_loss 134 | loss.backward() 135 | 136 | iter_end.record() 137 | 138 | with torch.no_grad(): 139 | # Progress bar 140 | ema_loss_for_log = 0.4 * rgb_loss.item() + 0.6 * ema_loss_for_log 141 | ema_depth_loss_for_log = 0.0 142 | if viewpoint_cam.depth is not None: 143 | ema_depth_loss_for_log = 0.4 * depth_loss.item() + 0.6 * ema_depth_loss_for_log 144 | if iteration % 10 == 0: 145 | progress_bar.set_postfix({"RGB Loss": f"{ema_loss_for_log:.{4}f}", "Depth Loss": f"{ema_depth_loss_for_log:.{4}f}"}) 146 | progress_bar.update(10) 147 | if iteration == opt.iterations: 148 | progress_bar.close() 149 | 150 | # Log and save 151 | training_report(tb_writer, iteration, rgb_loss, depth_loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 152 | if (iteration in saving_iterations): 153 | print("[ Training ] [ITER {}] Saving Gaussians".format(iteration)) 154 | scene.save(iteration) 155 | 156 | # Densification 157 | if iteration < opt.densify_until_iter: 158 | # Keep track of max radii in image-space for pruning 159 | scene.update_max_radii2D(radii, visibility_filter, masks) 160 | scene.add_densification_stats(viewspace_point_tensor, visibility_filter, masks) 161 | 162 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 163 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 164 | scene.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, 
size_threshold) 165 | 166 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 167 | scene.reset_opacity() 168 | 169 | # Optimizer step 170 | if iteration < opt.iterations: 171 | scene.optimizer_step() 172 | 173 | def prepare_output_and_logger(args): 174 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 175 | cfg_log_f.write(str(Namespace(**vars(args)))) 176 | 177 | # Create Tensorboard writer 178 | tb_writer = None 179 | if TENSORBOARD_FOUND: 180 | tb_writer = SummaryWriter(args.model_path) 181 | else: 182 | print("[ Training ] Tensorboard not available: not logging progress") 183 | return tb_writer 184 | 185 | def training_report(tb_writer, iteration, rgb_loss, depth_loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 186 | if tb_writer: 187 | tb_writer.add_scalar('train_loss_patches/rgb_loss', rgb_loss.item(), iteration) 188 | if isinstance(depth_loss, torch.Tensor): 189 | tb_writer.add_scalar('train_loss_patches/depth_loss', depth_loss.item(), iteration) 190 | tb_writer.add_scalar('train_loss_patches/total_loss', rgb_loss.item() + depth_loss.item(), iteration) 191 | else: 192 | tb_writer.add_scalar('train_loss_patches/depth_loss', depth_loss, iteration) 193 | tb_writer.add_scalar('train_loss_patches/total_loss', rgb_loss.item() + depth_loss, iteration) 194 | 195 | tb_writer.add_scalar('iter_time', elapsed, iteration) 196 | tb_writer.add_scalar('memory/memory_allocated', torch.cuda.memory_allocated('cuda') / (1024 ** 3), iteration) 197 | tb_writer.add_scalar('memory/memory_reserved', torch.cuda.memory_reserved('cuda') / (1024 ** 3), iteration) 198 | 199 | # Report test and samples of training set 200 | if iteration in testing_iterations: 201 | torch.cuda.empty_cache() 202 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 203 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 204 | 205 | for config in validation_configs: 206 | if config['cameras'] and len(config['cameras']) > 0: 207 | images = torch.tensor([], device="cuda") 208 | gts = torch.tensor([], device="cuda") 209 | l1_test, psnr_test = [], [] 210 | for idx, viewpoint in enumerate(config['cameras']): 211 | xyz, features, opacity, scales, rotations, cov3D_precomp, \ 212 | active_sh_degree, max_sh_degree, masks = scene.get_gaussian_parameters(viewpoint.world_view_transform, renderArgs[0].compute_cov3D_python) 213 | results = renderFunc(viewpoint, xyz, features, opacity, scales, rotations, active_sh_degree, max_sh_degree, cov3D_precomp = cov3D_precomp, *renderArgs) 214 | image = torch.clamp(results["render"], 0.0, 1.0) 215 | depth = warpped_depth(results["depth"]) 216 | depth = (depth - depth.min()) / (depth.max() - depth.min()) 217 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 218 | if viewpoint.depth is not None: 219 | gt_depth = viewpoint.depth.to("cuda") 220 | gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() - gt_depth.min()) 221 | l1_test.append(l1_loss(image, gt_image)) 222 | psnr_test.append(psnr(image, gt_image).mean()) 223 | if tb_writer and (idx == 0): 224 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image.unsqueeze(0), global_step=iteration) 225 | tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth.unsqueeze(0), global_step=iteration) 226 | if iteration == testing_iterations[0]: 227 | 
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image.unsqueeze(0), global_step=iteration) 228 | if viewpoint.depth is not None: 229 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth_depth".format(viewpoint.image_name), gt_depth.unsqueeze(0), global_step=iteration) 230 | for level in range(scene.max_level + 1): 231 | viewpoint = config['cameras'][0] #[randint(0, len(config['cameras'])-1)] 232 | xyz, features, opacity, scales, rotations, cov3D_precomp, \ 233 | active_sh_degree, max_sh_degree, masks = scene.get_gaussian_parameters(viewpoint.world_view_transform, renderArgs[0].compute_cov3D_python, random=level) 234 | image = torch.clamp(renderFunc(viewpoint, xyz, features, opacity, scales, rotations, active_sh_degree, max_sh_degree, cov3D_precomp = cov3D_precomp, *renderArgs)["render"], 0.0, 1.0) 235 | if tb_writer: 236 | tb_writer.add_images(config['name'] + "_view_{}/level_{}".format(viewpoint.image_name, level), image.unsqueeze(0), global_step=iteration) 237 | 238 | l1_test = sum(l1_test) / len(l1_test) 239 | psnr_test = sum(psnr_test) / len(psnr_test) 240 | print("\n[ Training ] [ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 241 | if tb_writer: 242 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 243 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 244 | 245 | if tb_writer: 246 | tb_writer.add_histogram("scene/opacity_histogram", scene.getGaussians().get_opacity, iteration) 247 | tb_writer.add_scalar('total_points', scene.getGaussians().get_xyz.shape[0], iteration) 248 | torch.cuda.empty_cache() 249 | 250 | if __name__ == "__main__": 251 | # Set up command line argument parser 252 | parser = ArgumentParser(description="Training script parameters") 253 | lp = ModelParams(parser) 254 | op = OptimizationParams(parser) 255 | pp = PipelineParams(parser) 256 | parser.add_argument('--ip', type=str, default="127.0.0.1") 257 | parser.add_argument('--port', type=int, default=6009) 258 | parser.add_argument('--debug_from', type=int, default=-1) 259 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 260 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[50, 100, 500, 1_000, 5_000, 10_000, 20_000, 30_000]) 261 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1_000, 10_000, 20_000, 30_000]) 262 | parser.add_argument("--quiet", action="store_true") 263 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 264 | parser.add_argument("--start_checkpoint", type=str, default = None) 265 | args = parser.parse_args(sys.argv[1:]) 266 | args.save_iterations.append(args.iterations) 267 | 268 | print("[ Training ] Optimizing With Parameters: " + str(vars(args))) 269 | 270 | # Initialize system state (RNG) 271 | safe_state(args.quiet) 272 | 273 | # Start GUI server, configure and run training 274 | network_gui.init(args.ip, args.port) 275 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 276 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 277 | 278 | # All done 279 | print("[ Training ] Training complete.") -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 
| # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 24 | else: # should be a type that converts to float 25 | if args.resolution == -1: 26 | if orig_w > 1600: 27 | global WARNED 28 | if not WARNED: 29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n" 30 | "[ INFO ] If this is not desired, please explicitly specify '--resolution/-r' as 1") 31 | WARNED = True 32 | global_down = orig_w / 1600 33 | else: 34 | global_down = 1 35 | else: 36 | global_down = orig_w / args.resolution 37 | 38 | scale = float(global_down) * float(resolution_scale) 39 | resolution = (int(orig_w / scale), int(orig_h / scale)) 40 | 41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 42 | 43 | gt_image = resized_image_rgb[:3, ...] 44 | loaded_mask = None 45 | 46 | if resized_image_rgb.shape[1] == 4: 47 | loaded_mask = resized_image_rgb[3:4, ...] 48 | 49 | if cam_info.depth is not None: 50 | resized_depth_map = PILtoTorch(cam_info.depth.convert("LA"), resolution) 51 | gt_depth = resized_depth_map[0:1] 52 | gt_depth_mask = resized_depth_map[1:2] 53 | else: 54 | gt_depth = None 55 | gt_depth_mask = None 56 | 57 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 58 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 59 | image=gt_image, gt_alpha_mask=loaded_mask, 60 | depth=gt_depth, depth_mask=gt_depth_mask, 61 | image_name=cam_info.image_name, uid=id, 62 | cx = cam_info.cx, cy = cam_info.cy, data_device=args.data_device) 63 | 64 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 65 | camera_list = [] 66 | 67 | for id, c in enumerate(cam_infos): 68 | camera_list.append(loadCam(args, id, c, resolution_scale)) 69 | 70 | return camera_list 71 | 72 | def camera_to_JSON(id, camera : Camera): 73 | Rt = np.zeros((4, 4)) 74 | Rt[:3, :3] = camera.R.transpose() 75 | Rt[:3, 3] = camera.T 76 | Rt[3, 3] = 1.0 77 | 78 | W2C = np.linalg.inv(Rt) 79 | pos = W2C[:3, 3] 80 | rot = W2C[:3, :3] 81 | serializable_array_2d = [x.tolist() for x in rot] 82 | camera_entry = { 83 | 'id' : id, 84 | 'img_name' : camera.image_name, 85 | 'width' : camera.width, 86 | 'height' : camera.height, 87 | 'position': pos.tolist(), 88 | 'rotation': serializable_array_2d, 89 | 'fy' : fov2focal(camera.FovY, camera.height), 90 | 'fx' : fov2focal(camera.FovX, camera.width), 91 | 'cx' : camera.cx, 92 | 'cy' : camera.cy 93 | } 94 | return camera_entry 95 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 
118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | # simple vector3 18 | class Vector3: 19 | 20 | def __init__(self, x=float(0), y=float(0), z=float(0)): 21 | self.x = x 22 | self.y = y 23 | self.z = z 24 | 25 | def __repr__(self) -> str: 26 | return f"Vector3({self.x}, {self.y}, {self.z})" 27 | 28 | def set(self, x: float, y: float, z: float) -> None: 29 | self.x = x 30 | self.y = y 31 | self.z = z 32 | 33 | def __add__(self, other: 'Vector3') -> 'Vector3': 34 | return Vector3(self.x + other.x, self.y + other.y, self.z + other.z) 35 | 36 | def __sub__(self, other: 'Vector3') -> 'Vector3': 37 | return Vector3(self.x - other.x, self.y - other.y, self.z - other.z) 38 | 39 | def __mul__(self, scalar: float) -> 'Vector3': 40 | return Vector3(self.x * scalar, self.y * scalar, self.z * scalar) 41 | 42 | def __truediv__(self, scalar: float) -> 'Vector3': 43 | return Vector3(self.x / scalar, self.y / scalar, self.z / scalar) 44 | 45 | def addVectors(self, vecA: 'Vector3', vecB: 'Vector3') -> 'Vector3': 46 | self.x = vecA.x + vecB.x 47 | self.y = vecA.y + vecB.y 48 | self.z = vecA.z + vecB.z 49 | return self 50 | 51 | def subVectors(self, vecA: 'Vector3', vecB: 'Vector3') -> 'Vector3': 52 | self.x = vecA.x - vecB.x 53 | self.y = vecA.y - vecB.y 54 | self.z = vecA.z - vecB.z 55 | return self 56 | 57 | def multiplyScalar(self, scalar: float) -> 'Vector3': 58 | self.x *= scalar 59 | self.y *= scalar 60 | self.z *= scalar 61 | return self 62 | 63 | def length(self) -> float: 64 | return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2) 65 | 66 | def norm(self) -> float: 67 | return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2) 68 | 69 | def normalize(self) -> 'Vector3': 70 | l = self.length() 71 | if l > 0: 72 | return self / l 73 | return self 74 | 75 | # simple bbox 76 | class BoundingBox: 77 | 78 | def __init__(self, min_point: Vector3 = Vector3(), max_point: Vector3 = Vector3()) -> None: 79 | self.min = min_point 80 | self.max = max_point 81 | 82 | def __repr__(self) -> str: 83 | return f"BoundingBox({self.min}, {self.max})" 84 | 85 | def isEmpty(self) -> bool: 86 | return (self.max.x < self.min.x) or (self.max.y < self.min.y) or (self.max.z < self.min.z) 87 | 88 | def getCenter(self) -> Vector3: 89 | return Vector3(0, 0, 0) if self.isEmpty() else (self.min + self.max) * 0.5 90 | 91 | def getSize(self) -> Vector3: 92 | return Vector3(0, 0, 0) if self.isEmpty() else (self.max - self.min) 93 | 94 | def getBoundingSphere(self): 95 | sphere = BoundingSphere() 96 | sphere.set(center=self.getCenter(), 
radius=self.getSize().length() * 0.5) 97 | return sphere 98 | 99 | # simple bounding sphere 100 | class BoundingSphere: 101 | 102 | def __init__(self, center=Vector3(), radius=0.0) -> None: 103 | self.center = center 104 | self.radius = radius 105 | 106 | def __repr__(self) -> str: 107 | return f"BoundingSphere({self.center}, {self.radius})" 108 | 109 | def set(self, center: Vector3, radius: float) -> None: 110 | self.center = center 111 | self.radius = radius 112 | 113 | class OctreeGaussian: 114 | def __init__(self): 115 | self.spacing = 0 116 | self.boundingbox = None 117 | self.root = None 118 | self.scale = None 119 | self.pointAttributes = None 120 | self.loader = None 121 | self.maxLevel = 0 122 | 123 | class OctreeGaussianNode: 124 | # Static variables 125 | IDCount = 0 126 | 127 | def __init__(self, name, octreeGaussian, boundingbox: BoundingBox): 128 | self.id = OctreeGaussianNode.IDCount 129 | OctreeGaussianNode.IDCount += 1 130 | self.name = name 131 | self.index = 0 if name == "r" else int(name[-1]) 132 | self.nodeType = 0 133 | self.hierarchyByteOffset = 0 134 | self.hierarchyByteSize = 0 135 | self.byteOffset = 0 136 | self.byteSize = 0 137 | self.spacing = 0 138 | self.projection = 0 139 | self.offset = 0 140 | self.gaussian_model = None 141 | self.octreeGaussian = octreeGaussian 142 | self.pointcloud = None 143 | self.boundingbox = boundingbox 144 | self.children = [None for _ in range(8)] 145 | self.parent = None 146 | self.numGaussians = 0 147 | self.level = None 148 | self.loaded = False 149 | self.loading = False 150 | 151 | def loadGaussianData(self): 152 | pass 153 | 154 | def __repr__(self): 155 | return f"OctreeGaussianNode {self.name}" 156 | 157 | def isGeometryNode(self) -> bool: 158 | return True 159 | 160 | def isTreeNode(self) -> bool: 161 | return False 162 | 163 | def getLevel(self) -> int: 164 | return self.level 165 | 166 | def getBoundingSphere(self): 167 | return self.boundingSphere 168 | 169 | def getBoundingBox(self): 170 | return self.boundingbox 171 | 172 | def getChildren(self): 173 | return self.children 174 | 175 | def getNumGaussians(self): 176 | return self.numGaussians 177 | 178 | class BasicPointCloud(NamedTuple): 179 | points : np.array 180 | colors : np.array 181 | normals : np.array 182 | 183 | def geom_transform_points(points, transf_matrix): 184 | P, _ = points.shape 185 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 186 | points_hom = torch.cat([points, ones], dim=1) 187 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 188 | 189 | denom = points_out[..., 3:] + 0.0000001 190 | return (points_out[..., :3] / denom).squeeze(dim=0) 191 | 192 | def getWorld2View(R, t): 193 | Rt = np.zeros((4, 4)) 194 | Rt[:3, :3] = R.transpose() 195 | Rt[:3, 3] = t 196 | Rt[3, 3] = 1.0 197 | return np.float32(Rt) 198 | 199 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 200 | Rt = np.zeros((4, 4)) 201 | Rt[:3, :3] = R.transpose() 202 | Rt[:3, 3] = t 203 | Rt[3, 3] = 1.0 204 | 205 | C2W = np.linalg.inv(Rt) 206 | cam_center = C2W[:3, 3] 207 | cam_center = (cam_center + translate) * scale 208 | C2W[:3, 3] = cam_center 209 | Rt = np.linalg.inv(C2W) 210 | return np.float32(Rt) 211 | 212 | def getProjectionMatrix(znear, zfar, fovX, fovY): 213 | tanHalfFovY = math.tan((fovY / 2)) 214 | tanHalfFovX = math.tan((fovX / 2)) 215 | 216 | top = tanHalfFovY * znear 217 | bottom = -top 218 | right = tanHalfFovX * znear 219 | left = -right 220 | 221 | P = torch.zeros(4, 4) 222 | 223 | z_sign = 1.0 224 | 225 | 
P[0, 0] = 2.0 * znear / (right - left) 226 | P[1, 1] = 2.0 * znear / (top - bottom) 227 | P[0, 2] = (right + left) / (right - left) 228 | P[1, 2] = (top + bottom) / (top - bottom) 229 | P[3, 2] = z_sign 230 | P[2, 2] = z_sign * zfar / (zfar - znear) 231 | P[2, 3] = -(zfar * znear) / (zfar - znear) 232 | return P 233 | 234 | def fov2focal(fov, pixels): 235 | return pixels / (2 * math.tan(fov / 2)) 236 | 237 | def focal2fov(focal, pixels): 238 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | 21 | def warpped_depth(depth): 22 | return torch.where(depth < 10.0, depth / 10.0, 2.0 - 10 / depth) / 2.0 23 | 24 | def unwarpped_depth(depth): 25 | return torch.where(depth < 0.5, 2 * depth * 10.0, 10 / (1.0 - depth) / 2) 26 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, 
https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
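A closing note on the depth parameterization used by train.py: warpped_depth in utils/image_utils.py maps metric depth into [0, 1) (linear below 10 units, asymptotically approaching 1 beyond that), and unwarpped_depth is its exact inverse. A minimal round-trip sketch, assuming the repository root is on PYTHONPATH (this snippet is illustrative and not part of the repository):

    import torch
    from utils.image_utils import warpped_depth, unwarpped_depth

    d = torch.tensor([0.5, 5.0, 10.0, 50.0, 200.0])   # metric depths
    w = warpped_depth(d)                               # warped values lie in [0, 1)
    assert torch.allclose(unwarpped_depth(w), d)       # the inverse recovers the input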