├── .gitignore ├── .gitmodules ├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── environment.yml ├── gaussian_renderer └── __init__.py ├── scene ├── __init__.py └── gaussian_model.py ├── train.py └── utils ├── general_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization 7 | [submodule "SIBR_viewers"] 8 | path = SIBR_viewers 9 | url = https://gitlab.inria.fr/sibr/sibr_core.git 10 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A minimize training version of 3DGS 2 | 3 | ## Cloning the Repository 4 | 5 | The repository contains submodules, thus please check it out with 6 | ```shell 7 | git clone https://github.com/hugoycj/gaussian-splatting-mini --recursive 8 | ``` 9 | 10 | ### Setup 11 | ```shell 12 | SET DISTUTILS_USE_SDK=1 # Windows only 13 | conda env create --file environment.yml 14 | conda activate gaussian_splatting 15 | ``` 16 | 17 | ### Running 18 | ```shell 19 | python train.py -s 20 | ``` 21 | 22 |
23 | Command Line Arguments for train.py 24 | 25 | #### --source_path / -s 26 | Path to the source directory containing a COLMAP or Synthetic NeRF data set. 27 | #### --model_path / -m 28 | Path where the trained model should be stored (```output/``` by default). 29 | #### --images / -i 30 | Alternative subdirectory for COLMAP images (```images``` by default). 31 | #### --eval 32 | Add this flag to use a MipNeRF360-style training/test split for evaluation. 33 | #### --resolution / -r 34 | Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.** 35 | #### --data_device 36 | Specifies where to put the source image data, ```cuda``` by default, recommended to use ```cpu``` if training on large/high-resolution dataset, will reduce VRAM consumption, but slightly slow down training. Thanks to [HrsPythonix](https://github.com/HrsPythonix). 37 | #### --white_background / -w 38 | Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. 39 | #### --sh_degree 40 | Order of spherical harmonics to be used (no larger than 3). ```3``` by default. 41 | #### --convert_SHs_python 42 | Flag to make pipeline compute forward and backward of SHs with PyTorch instead of ours. 43 | #### --convert_cov3D_python 44 | Flag to make pipeline compute forward and backward of the 3D covariance with PyTorch instead of ours. 45 | #### --debug 46 | Enables debug mode if you experience erros. If the rasterizer fails, a ```dump``` file is created that you may forward to us in an issue so we can take a look. 47 | #### --debug_from 48 | Debugging is **slow**. You may specify an iteration (starting from 0) after which the above debugging becomes active. 49 | #### --iterations 50 | Number of total iterations to train for, ```30_000``` by default. 51 | #### --test_iterations 52 | Space-separated iterations at which the training script computes L1 and PSNR over test set, ```7000 30000``` by default. 53 | #### --save_iterations 54 | Space-separated iterations at which the training script saves the Gaussian model, ```7000 30000 ``` by default. 55 | #### --checkpoint_iterations 56 | Space-separated iterations at which to store a checkpoint for continuing later, saved in the model directory. 57 | #### --start_checkpoint 58 | Path to a saved checkpoint to continue training from. 59 | #### --quiet 60 | Flag to omit any text written to standard out pipe. 61 | #### --feature_lr 62 | Spherical harmonics features learning rate, ```0.0025``` by default. 63 | #### --opacity_lr 64 | Opacity learning rate, ```0.05``` by default. 65 | #### --scaling_lr 66 | Scaling learning rate, ```0.005``` by default. 67 | #### --rotation_lr 68 | Rotation learning rate, ```0.001``` by default. 69 | #### --position_lr_max_steps 70 | Number of steps (from 0) where position learning rate goes from ```initial``` to ```final```. ```30_000``` by default. 71 | #### --position_lr_init 72 | Initial 3D position learning rate, ```0.00016``` by default. 73 | #### --position_lr_final 74 | Final 3D position learning rate, ```0.0000016``` by default. 75 | #### --position_lr_delay_mult 76 | Position learning rate multiplier (cf. Plenoxels), ```0.01``` by default. 77 | #### --densify_from_iter 78 | Iteration where densification starts, ```500``` by default. 79 | #### --densify_until_iter 80 | Iteration where densification stops, ```15_000``` by default. 81 | #### --densify_grad_threshold 82 | Limit that decides if points should be densified based on 2D position gradient, ```0.0002``` by default. 83 | #### --densification_interval 84 | How frequently to densify, ```100``` (every 100 iterations) by default. 85 | #### --opacity_reset_interval 86 | How frequently to reset opacity, ```3_000``` by default. 87 | #### --lambda_dssim 88 | Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. 89 | #### --percent_dense 90 | Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. 91 | 92 |
93 |
94 | 95 | # License 96 | This project is licensed under the Gaussian-Splatting License - see the LICENSE file for details. -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._dataset = "colmap" 54 | self.initializer = "colmap" 55 | self._resolution = -1 56 | self._white_background = False 57 | self.data_device = "cuda" 58 | self.eval = False 59 | super().__init__(parser, "Loading Parameters", sentinel) 60 | 61 | def extract(self, args): 62 | g = super().extract(args) 63 | g.source_path = os.path.abspath(g.source_path) 64 | return g 65 | 66 | class PipelineParams(ParamGroup): 67 | def __init__(self, parser): 68 | self.convert_SHs_python = False 69 | self.compute_cov3D_python = False 70 | self.debug = False 71 | super().__init__(parser, "Pipeline Parameters") 72 | 73 | class OptimizationParams(ParamGroup): 74 | def __init__(self, parser): 75 | self.iterations = 30_000 76 | self.position_lr_init = 0.00016 77 | self.position_lr_final = 0.0000016 78 | self.position_lr_delay_mult = 0.01 79 | self.position_lr_max_steps = 30_000 80 | self.feature_lr = 0.0025 81 | self.opacity_lr = 0.05 82 | self.scaling_lr = 0.005 83 | self.rotation_lr = 0.001 84 | self.percent_dense = 0.01 85 | self.lambda_dssim = 0.2 86 | self.mask_from_iter = 0 87 | self.mask_until_iter = 7000 88 | self.lambda_mask = 0. 89 | self.densification_interval = 100 90 | self.opacity_reset_interval = 3000 91 | self.densify_from_iter = 500 92 | self.densify_until_iter = 15_000 93 | self.densify_grad_threshold = 0.0002 94 | self.random_background = False 95 | super().__init__(parser, "Optimization Parameters") 96 | 97 | def get_combined_args(parser : ArgumentParser): 98 | cmdlne_string = sys.argv[1:] 99 | cfgfile_string = "Namespace()" 100 | args_cmdline = parser.parse_args(cmdlne_string) 101 | 102 | try: 103 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 104 | print("Looking for config file in", cfgfilepath) 105 | with open(cfgfilepath) as cfg_file: 106 | print("Config file found: {}".format(cfgfilepath)) 107 | cfgfile_string = cfg_file.read() 108 | except TypeError: 109 | print("Config file not found at") 110 | pass 111 | args_cfgfile = eval(cfgfile_string) 112 | 113 | merged_dict = vars(args_cfgfile).copy() 114 | for k,v in vars(args_cmdline).items(): 115 | if v != None: 116 | merged_dict[k] = v 117 | return Namespace(**merged_dict) 118 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gaussian_splatting 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - pip: 16 | - git+https://github.com/camenduru/simple-knn.git 17 | - git+https://github.com/GAP-LAB-CUHK-SZ/gaustudio.git 18 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from gaustudio_diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = pc.get_xyz 54 | means2D = screenspace_points 55 | opacity = pc.get_opacity 56 | 57 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 58 | # scaling / rotation by the rasterizer. 59 | scales = None 60 | rotations = None 61 | cov3D_precomp = None 62 | if pipe.compute_cov3D_python: 63 | cov3D_precomp = pc.get_covariance(scaling_modifier) 64 | else: 65 | scales = pc.get_scaling 66 | rotations = pc.get_rotation 67 | 68 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 69 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 70 | shs = None 71 | colors_precomp = None 72 | if override_color is None: 73 | if pipe.convert_SHs_python: 74 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 75 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 76 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 77 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 78 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 79 | else: 80 | shs = pc.get_features 81 | else: 82 | colors_precomp = override_color 83 | 84 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 85 | rendered_image, radii, rendered_depth, rendered_median_depth, rendered_final_opacity = rasterizer( 86 | means3D = means3D, 87 | means2D = means2D, 88 | shs = shs, 89 | colors_precomp = colors_precomp, 90 | opacities = opacity, 91 | scales = scales, 92 | rotations = rotations, 93 | cov3D_precomp = cov3D_precomp) 94 | 95 | rendered_image = rendered_image + (1-rendered_final_opacity).repeat(3, 1, 1) * bg_color 96 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 97 | # They will be excluded from value updates used in the splitting criteria. 98 | return {"render": rendered_image, 99 | "rendered_depth": rendered_depth, 100 | "rendered_median_depth": rendered_median_depth, 101 | "viewspace_points": screenspace_points, 102 | "visibility_filter" : radii > 0, 103 | "alpha": rendered_final_opacity, 104 | "radii": radii} 105 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.gaussian_model import GaussianModel 17 | from arguments import ModelParams 18 | from gaustudio import datasets 19 | 20 | class Scene: 21 | 22 | gaussians : GaussianModel 23 | 24 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 25 | """b 26 | :param path: Path to colmap scene main folder. 27 | """ 28 | self.model_path = args.model_path 29 | self.loaded_iter = None 30 | self.gaussians = gaussians 31 | 32 | if load_iteration: 33 | if load_iteration == -1: 34 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 35 | else: 36 | self.loaded_iter = load_iteration 37 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 38 | 39 | self.train_cameras = {} 40 | self.test_cameras = {} 41 | 42 | # Initialize dataset with gaustudio.datasets 43 | _dataset = datasets.make({"name": args.dataset, "source_path": args.source_path, \ 44 | "images": args.images, "white_background": args.white_background, \ 45 | "resolution":resolution_scales[0], "data_device":"cuda", \ 46 | "eval": False}) 47 | _dataset.export(os.path.join(self.model_path, "cameras.json")) 48 | print("Loading Training Cameras") 49 | self.train_cameras[resolution_scales[0]] = _dataset.all_cameras 50 | print("Loading Test Cameras") 51 | self.test_cameras[resolution_scales[0]] = [] 52 | self.cameras_extent = _dataset.cameras_extent 53 | 54 | if self.loaded_iter: 55 | self.gaussians.load_ply(os.path.join(self.model_path, 56 | "point_cloud", 57 | "iteration_" + str(self.loaded_iter), 58 | "point_cloud.ply")) 59 | else: 60 | # Initialize pcd with gaustudio.initializers 61 | from gaustudio.pipelines import initializers 62 | from gaustudio import models 63 | pcd = models.make("general_pcd") 64 | initializer_name = args.initializer 65 | if args.dataset == "colmap" and initializer_name == 'colmap': 66 | initializer_workspace = args.source_path 67 | else: 68 | initializer_workspace = os.path.join(args.source_path, 'tmp_'+initializer_name) 69 | initializer_config = {"name": initializer_name, "workspace_dir": initializer_workspace} 70 | initializer = initializers.make(initializer_config) 71 | initializer(pcd, _dataset, overwrite=False) 72 | self.gaussians.create_from_pcd(pcd, self.cameras_extent) 73 | 74 | def save(self, iteration): 75 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 76 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 77 | 78 | def getTrainCameras(self, scale=1.0): 79 | return self.train_cameras[scale] 80 | 81 | def getTestCameras(self, scale=1.0): 82 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH 20 | from simple_knn._C import distCUDA2 21 | from utils.general_utils import strip_symmetric, build_scaling_rotation 22 | 23 | class GaussianModel: 24 | 25 | def setup_functions(self): 26 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 27 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 28 | actual_covariance = L @ L.transpose(1, 2) 29 | symm = strip_symmetric(actual_covariance) 30 | return symm 31 | 32 | self.scaling_activation = torch.exp 33 | self.scaling_inverse_activation = torch.log 34 | 35 | self.covariance_activation = build_covariance_from_scaling_rotation 36 | 37 | self.opacity_activation = torch.sigmoid 38 | self.inverse_opacity_activation = inverse_sigmoid 39 | 40 | self.rotation_activation = torch.nn.functional.normalize 41 | 42 | 43 | def __init__(self, sh_degree : int): 44 | self.active_sh_degree = 0 45 | self.max_sh_degree = sh_degree 46 | self._xyz = torch.empty(0) 47 | self._features_dc = torch.empty(0) 48 | self._features_rest = torch.empty(0) 49 | self._scaling = torch.empty(0) 50 | self._rotation = torch.empty(0) 51 | self._opacity = torch.empty(0) 52 | self.max_radii2D = torch.empty(0) 53 | self.xyz_gradient_accum = torch.empty(0) 54 | self.denom = torch.empty(0) 55 | self.optimizer = None 56 | self.percent_dense = 0 57 | self.spatial_lr_scale = 0 58 | self.setup_functions() 59 | 60 | def capture(self): 61 | return ( 62 | self.active_sh_degree, 63 | self._xyz, 64 | self._features_dc, 65 | self._features_rest, 66 | self._scaling, 67 | self._rotation, 68 | self._opacity, 69 | self.max_radii2D, 70 | self.xyz_gradient_accum, 71 | self.denom, 72 | self.optimizer.state_dict(), 73 | self.spatial_lr_scale, 74 | ) 75 | 76 | def restore(self, model_args, training_args): 77 | (self.active_sh_degree, 78 | self._xyz, 79 | self._features_dc, 80 | self._features_rest, 81 | self._scaling, 82 | self._rotation, 83 | self._opacity, 84 | self.max_radii2D, 85 | xyz_gradient_accum, 86 | denom, 87 | opt_dict, 88 | self.spatial_lr_scale) = model_args 89 | self.training_setup(training_args) 90 | self.xyz_gradient_accum = xyz_gradient_accum 91 | self.denom = denom 92 | self.optimizer.load_state_dict(opt_dict) 93 | 94 | @property 95 | def get_scaling(self): 96 | return self.scaling_activation(self._scaling) 97 | 98 | @property 99 | def get_rotation(self): 100 | return self.rotation_activation(self._rotation) 101 | 102 | @property 103 | def get_xyz(self): 104 | return self._xyz 105 | 106 | @property 107 | def get_features(self): 108 | features_dc = self._features_dc 109 | features_rest = self._features_rest 110 | return torch.cat((features_dc, features_rest), dim=1) 111 | 112 | @property 113 | def get_opacity(self): 114 | return self.opacity_activation(self._opacity) 115 | 116 | def get_covariance(self, scaling_modifier = 1): 117 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 118 | 119 | def oneupSHdegree(self): 120 | if self.active_sh_degree < self.max_sh_degree: 121 | self.active_sh_degree += 1 122 | 123 | def create_from_pcd(self, pcd, spatial_lr_scale : float): 124 | self.spatial_lr_scale = spatial_lr_scale 125 | fused_point_cloud = torch.tensor(np.asarray(pcd._xyz)).float().cuda() 126 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd._rgb)).float().cuda()) 127 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 128 | features[:, :3, 0 ] = fused_color 129 | features[:, 3:, 1:] = 0.0 130 | 131 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 132 | 133 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd._xyz)).float().cuda()), 0.0000001) 134 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 135 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 136 | rots[:, 0] = 1 137 | 138 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 139 | 140 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 141 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 142 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) 143 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 144 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 145 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 146 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 147 | 148 | def training_setup(self, training_args): 149 | self.percent_dense = training_args.percent_dense 150 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 151 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 152 | 153 | l = [ 154 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 155 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 156 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 157 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 158 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 159 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 160 | ] 161 | 162 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 163 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, 164 | lr_final=training_args.position_lr_final*self.spatial_lr_scale, 165 | lr_delay_mult=training_args.position_lr_delay_mult, 166 | max_steps=training_args.position_lr_max_steps) 167 | 168 | def update_learning_rate(self, iteration): 169 | ''' Learning rate scheduling per step ''' 170 | for param_group in self.optimizer.param_groups: 171 | if param_group["name"] == "xyz": 172 | lr = self.xyz_scheduler_args(iteration) 173 | param_group['lr'] = lr 174 | return lr 175 | 176 | def construct_list_of_attributes(self): 177 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 178 | # All channels except the 3 DC 179 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 180 | l.append('f_dc_{}'.format(i)) 181 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 182 | l.append('f_rest_{}'.format(i)) 183 | l.append('opacity') 184 | for i in range(self._scaling.shape[1]): 185 | l.append('scale_{}'.format(i)) 186 | for i in range(self._rotation.shape[1]): 187 | l.append('rot_{}'.format(i)) 188 | return l 189 | 190 | def save_ply(self, path): 191 | mkdir_p(os.path.dirname(path)) 192 | 193 | xyz = self._xyz.detach().cpu().numpy() 194 | normals = np.zeros_like(xyz) 195 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 196 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 197 | opacities = self._opacity.detach().cpu().numpy() 198 | scale = self._scaling.detach().cpu().numpy() 199 | rotation = self._rotation.detach().cpu().numpy() 200 | 201 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 202 | 203 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 204 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 205 | elements[:] = list(map(tuple, attributes)) 206 | el = PlyElement.describe(elements, 'vertex') 207 | PlyData([el]).write(path) 208 | 209 | def reset_opacity(self): 210 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 211 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 212 | self._opacity = optimizable_tensors["opacity"] 213 | 214 | def load_ply(self, path): 215 | plydata = PlyData.read(path) 216 | 217 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 218 | np.asarray(plydata.elements[0]["y"]), 219 | np.asarray(plydata.elements[0]["z"])), axis=1) 220 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 221 | 222 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 223 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 224 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 225 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 226 | 227 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 228 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 229 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 230 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 231 | for idx, attr_name in enumerate(extra_f_names): 232 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 233 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 234 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 235 | 236 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 237 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 238 | scales = np.zeros((xyz.shape[0], len(scale_names))) 239 | for idx, attr_name in enumerate(scale_names): 240 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 241 | 242 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 243 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 244 | rots = np.zeros((xyz.shape[0], len(rot_names))) 245 | for idx, attr_name in enumerate(rot_names): 246 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 247 | 248 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 249 | self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 250 | self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 251 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 252 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 253 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 254 | 255 | self.active_sh_degree = self.max_sh_degree 256 | 257 | def replace_tensor_to_optimizer(self, tensor, name): 258 | optimizable_tensors = {} 259 | for group in self.optimizer.param_groups: 260 | if group["name"] == name: 261 | stored_state = self.optimizer.state.get(group['params'][0], None) 262 | stored_state["exp_avg"] = torch.zeros_like(tensor) 263 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 264 | 265 | del self.optimizer.state[group['params'][0]] 266 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 267 | self.optimizer.state[group['params'][0]] = stored_state 268 | 269 | optimizable_tensors[group["name"]] = group["params"][0] 270 | return optimizable_tensors 271 | 272 | def _prune_optimizer(self, mask): 273 | optimizable_tensors = {} 274 | for group in self.optimizer.param_groups: 275 | stored_state = self.optimizer.state.get(group['params'][0], None) 276 | if stored_state is not None: 277 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 278 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 279 | 280 | del self.optimizer.state[group['params'][0]] 281 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 282 | self.optimizer.state[group['params'][0]] = stored_state 283 | 284 | optimizable_tensors[group["name"]] = group["params"][0] 285 | else: 286 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 287 | optimizable_tensors[group["name"]] = group["params"][0] 288 | return optimizable_tensors 289 | 290 | def prune_points(self, mask): 291 | valid_points_mask = ~mask 292 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 293 | 294 | self._xyz = optimizable_tensors["xyz"] 295 | self._features_dc = optimizable_tensors["f_dc"] 296 | self._features_rest = optimizable_tensors["f_rest"] 297 | self._opacity = optimizable_tensors["opacity"] 298 | self._scaling = optimizable_tensors["scaling"] 299 | self._rotation = optimizable_tensors["rotation"] 300 | 301 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 302 | 303 | self.denom = self.denom[valid_points_mask] 304 | self.max_radii2D = self.max_radii2D[valid_points_mask] 305 | 306 | def cat_tensors_to_optimizer(self, tensors_dict): 307 | optimizable_tensors = {} 308 | for group in self.optimizer.param_groups: 309 | assert len(group["params"]) == 1 310 | extension_tensor = tensors_dict[group["name"]] 311 | stored_state = self.optimizer.state.get(group['params'][0], None) 312 | if stored_state is not None: 313 | 314 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 315 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 316 | 317 | del self.optimizer.state[group['params'][0]] 318 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 319 | self.optimizer.state[group['params'][0]] = stored_state 320 | 321 | optimizable_tensors[group["name"]] = group["params"][0] 322 | else: 323 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 324 | optimizable_tensors[group["name"]] = group["params"][0] 325 | 326 | return optimizable_tensors 327 | 328 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): 329 | d = {"xyz": new_xyz, 330 | "f_dc": new_features_dc, 331 | "f_rest": new_features_rest, 332 | "opacity": new_opacities, 333 | "scaling" : new_scaling, 334 | "rotation" : new_rotation} 335 | 336 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 337 | self._xyz = optimizable_tensors["xyz"] 338 | self._features_dc = optimizable_tensors["f_dc"] 339 | self._features_rest = optimizable_tensors["f_rest"] 340 | self._opacity = optimizable_tensors["opacity"] 341 | self._scaling = optimizable_tensors["scaling"] 342 | self._rotation = optimizable_tensors["rotation"] 343 | 344 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 345 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 346 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 347 | 348 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 349 | n_init_points = self.get_xyz.shape[0] 350 | # Extract points that satisfy the gradient condition 351 | padded_grad = torch.zeros((n_init_points), device="cuda") 352 | padded_grad[:grads.shape[0]] = grads.squeeze() 353 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 354 | selected_pts_mask = torch.logical_and(selected_pts_mask, 355 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 356 | 357 | stds = self.get_scaling[selected_pts_mask].repeat(N,1) 358 | means =torch.zeros((stds.size(0), 3),device="cuda") 359 | samples = torch.normal(mean=means, std=stds) 360 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 361 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 362 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 363 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 364 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 365 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 366 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 367 | 368 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) 369 | 370 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 371 | self.prune_points(prune_filter) 372 | 373 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 374 | # Extract points that satisfy the gradient condition 375 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 376 | selected_pts_mask = torch.logical_and(selected_pts_mask, 377 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 378 | 379 | new_xyz = self._xyz[selected_pts_mask] 380 | new_features_dc = self._features_dc[selected_pts_mask] 381 | new_features_rest = self._features_rest[selected_pts_mask] 382 | new_opacities = self._opacity[selected_pts_mask] 383 | new_scaling = self._scaling[selected_pts_mask] 384 | new_rotation = self._rotation[selected_pts_mask] 385 | 386 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) 387 | 388 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 389 | grads = self.xyz_gradient_accum / self.denom 390 | grads[grads.isnan()] = 0.0 391 | 392 | self.densify_and_clone(grads, max_grad, extent) 393 | self.densify_and_split(grads, max_grad, extent) 394 | 395 | prune_mask = (self.get_opacity < min_opacity).squeeze() 396 | if max_screen_size: 397 | big_points_vs = self.max_radii2D > max_screen_size 398 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 399 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 400 | self.prune_points(prune_mask) 401 | 402 | torch.cuda.empty_cache() 403 | 404 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 405 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) 406 | self.denom[update_filter] += 1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from argparse import ArgumentParser, Namespace 23 | from arguments import ModelParams, PipelineParams, OptimizationParams 24 | 25 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 26 | first_iter = 0 27 | prepare_output(dataset) 28 | gaussians = GaussianModel(dataset.sh_degree) 29 | scene = Scene(dataset, gaussians) 30 | gaussians.training_setup(opt) 31 | if checkpoint: 32 | (model_params, first_iter) = torch.load(checkpoint) 33 | gaussians.restore(model_params, opt) 34 | 35 | iter_start = torch.cuda.Event(enable_timing = True) 36 | iter_end = torch.cuda.Event(enable_timing = True) 37 | 38 | viewpoint_stack = None 39 | ema_loss_for_log = 0.0 40 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 41 | first_iter += 1 42 | for iteration in range(first_iter, opt.iterations + 1): 43 | iter_start.record() 44 | 45 | gaussians.update_learning_rate(iteration) 46 | 47 | # Every 1000 its we increase the levels of SH up to a maximum degree 48 | if iteration % 1000 == 0: 49 | gaussians.oneupSHdegree() 50 | 51 | # Pick a random Camera 52 | if not viewpoint_stack: 53 | viewpoint_stack = scene.getTrainCameras().copy() 54 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)).to("cuda") 55 | 56 | # Render 57 | if (iteration - 1) == debug_from: 58 | pipe.debug = True 59 | 60 | if viewpoint_cam.bg_image is not None: 61 | bg = viewpoint_cam.bg_image.to("cuda").permute(2, 0, 1) 62 | else: 63 | if dataset.white_background: 64 | bg = torch.ones((3, viewpoint_cam.image_height, viewpoint_cam.image_width)).to("cuda") 65 | else: 66 | bg = torch.zeros((3, viewpoint_cam.image_height, viewpoint_cam.image_width)).to("cuda") 67 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 68 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 69 | alpha = render_pkg["alpha"][0] 70 | 71 | # Loss 72 | gt_image = viewpoint_cam.image.cuda().permute(2, 0, 1) 73 | Ll1 = l1_loss(image, gt_image) 74 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 75 | 76 | # mask loss 77 | if viewpoint_cam.mask is not None and opt.mask_from_iter < iteration < opt.mask_until_iter: 78 | loss += (~viewpoint_cam.mask * alpha).mean() * opt.lambda_mask 79 | 80 | loss.backward() 81 | 82 | iter_end.record() 83 | 84 | with torch.no_grad(): 85 | # Progress bar 86 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 87 | if iteration % 10 == 0: 88 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 89 | progress_bar.update(10) 90 | if iteration == opt.iterations: 91 | progress_bar.close() 92 | 93 | # Save 94 | if (iteration in saving_iterations): 95 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 96 | scene.save(iteration) 97 | 98 | # Densification 99 | if iteration < opt.densify_until_iter: 100 | # Keep track of max radii in image-space for pruning 101 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 102 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 103 | 104 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 105 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 106 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 107 | 108 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 109 | gaussians.reset_opacity() 110 | 111 | # Optimizer step 112 | if iteration < opt.iterations: 113 | gaussians.optimizer.step() 114 | gaussians.optimizer.zero_grad(set_to_none = True) 115 | 116 | if (iteration in checkpoint_iterations): 117 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 118 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 119 | 120 | def prepare_output(args): 121 | if not args.model_path: 122 | if os.getenv('OAR_JOB_ID'): 123 | unique_str=os.getenv('OAR_JOB_ID') 124 | else: 125 | unique_str = str(uuid.uuid4()) 126 | args.model_path = os.path.join("./output/", unique_str[0:10]) 127 | 128 | # Set up output folder 129 | print("Output folder: {}".format(args.model_path)) 130 | os.makedirs(args.model_path, exist_ok = True) 131 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 132 | cfg_log_f.write(str(Namespace(**vars(args)))) 133 | 134 | if __name__ == "__main__": 135 | # Set up command line argument parser 136 | parser = ArgumentParser(description="Training script parameters") 137 | lp = ModelParams(parser) 138 | op = OptimizationParams(parser) 139 | pp = PipelineParams(parser) 140 | parser.add_argument('--debug_from', type=int, default=-1) 141 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 142 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) 143 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 144 | parser.add_argument("--quiet", action="store_true") 145 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 146 | parser.add_argument("--start_checkpoint", type=str, default = None) 147 | args = parser.parse_args(sys.argv[1:]) 148 | args.save_iterations.append(args.iterations) 149 | 150 | print("Optimizing " + args.model_path) 151 | 152 | # Initialize system state (RNG) 153 | safe_state(args.quiet) 154 | 155 | # Configure and run training 156 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 157 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 158 | 159 | # All done 160 | print("\nTraining complete.") 161 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------