├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── convert.py ├── distill_train.py ├── environment.yml ├── full_eval.py ├── gaussian_renderer ├── __init__.py ├── gaussian_count.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── prune.py ├── prune_finetune.py ├── render.py ├── render_video.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── run_distill_finetune.sh ├── run_prune_finetune.sh ├── run_prune_pt_finetune.sh ├── run_train_densify_prune.sh └── run_vectree_quantize.sh ├── static ├── prune_ratio_vs_ssim.svg └── table5.png ├── submodules └── simple-knn │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn.cu │ ├── simple_knn.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt │ ├── simple_knn.h │ ├── simple_knn │ └── .gitkeep │ ├── spatial.cu │ └── spatial.h ├── train_densify_prune.py ├── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image.py ├── image_utils.py ├── logger_utils.py ├── loss_utils.py ├── pose_utils.py ├── save_imp_score.py ├── sh_utils.py ├── system_utils.py ├── tracker_utils.py └── vgg.py └── vectree ├── utils.py ├── vectree.py └── vq.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | logs_train 10 | vectree/pruned_distilled 11 | vectree/output 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | 5 | 6 | [submodule "submodules/compress-diff-gaussian-rasterization"] 7 | path = submodules/compress-diff-gaussian-rasterization 8 | url = https://github.com/Kevin-2017/compress-diff-gaussian-rasterization.git 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.04-py3 2 | RUN conda env create --file environment.yml 3 | RUN bash -c "conda init bash" -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 
22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. 
UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightGaussian: Unbounded 3D Gaussian Compression with 15x Reduction and 200+ FPS 2 | 3 |

4 | 5 | 6 | 7 | 8 |

9 | 10 | 11 |
12 | 13 |
14 | 15 | ## User Guidance 16 | #### Gaussian Prune Ratio and Vector Quantization Ratio vs. FPS and SSIM 17 |
18 | [figure: static/prune_ratio_vs_ssim.svg] 19 |
20 | 21 | #### Mild Compression Ratio with Minimal Accuracy Degradation 22 |
23 | [table: static/table5.png] 24 |
25 | 26 | 27 | ## Setup 28 | #### Local Setup 29 | The codebase is based on [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting). 30 | 31 | The datasets used, MipNeRF360 and Tanks & Temples, are hosted by the paper authors [here](https://jonbarron.info/mipnerf360/). 32 | 33 | For installation: 34 | ``` 35 | git clone --recursive https://github.com/VITA-Group/LightGaussian.git 36 | cd LightGaussian 37 | # if you have already cloned LightGaussian: 38 | # git submodule update --init --recursive 39 | ``` 40 | ```shell 41 | conda env create --file environment.yml 42 | conda activate lightgaussian 43 | ``` 44 | Note: we modified the "diff-gaussian-rasterization" submodule to compute the Global Significance Score. 45 | 46 | 47 | ## Compress to Compact Representation 48 | 49 | LightGaussian includes **3 ways** to make 3D Gaussians compact: 50 | 51 | 52 | 53 | #### Option 1 Prune & Recovery 54 | Users can directly prune a trained 3D-GS checkpoint using the following command (default setting): 55 | ``` 56 | bash scripts/run_prune_finetune.sh 57 | ``` 58 | 59 | Users can also train from scratch and jointly prune redundant Gaussians during training using the following command (a different setting from the paper): 60 | ``` 61 | bash scripts/run_train_densify_prune.sh 62 | ``` 63 | Note: 3D-GS is trained for 20,000 iterations and then pruned. The resulting ply file is approximately 35% of the size of the original 3D-GS while maintaining comparable quality. 64 | 65 | 66 | #### Option 2 SH distillation 67 | Users can distill a 3D-GS checkpoint using the following command (default setting): 68 | ``` 69 | bash scripts/run_distill_finetune.sh 70 | ``` 71 | 72 | #### Option 3 VecTree Quantization 73 | Users can quantize a pruned and distilled 3D-GS checkpoint using the following command (default setting): 74 | ``` 75 | bash scripts/run_vectree_quantize.sh 76 | ``` 77 | 78 | 79 | ## Render 80 | Render along a camera trajectory. The default is an ellipse; you can switch to a spiral or another trajectory by calling the corresponding function. 81 | ``` 82 | python render_video.py --source_path PATH/TO/DATASET --model_path PATH/TO/MODEL --skip_train --skip_test --video 83 | ``` 84 | To render after the VecTree Quantization stage, run 85 | ``` 86 | python render_video.py --load_vq 87 | ``` 88 | 89 | 90 | ## Example 91 | An example checkpoint for the room scene can be downloaded [here](), which mainly includes the following parts: 92 | 93 | - point_cloud.ply —— Pruned, distilled and quantized 3D-GS checkpoint. 94 | - extreme_saving —— Relevant files obtained after vectree quantization. 95 | - imp_score.npz —— Global significance scores used in vectree quantization. 96 | 97 | 98 | 99 | ## TODO List 100 | - [x] Upload module 1: Prune & recovery 101 | - [x] Upload module 2: SH distillation 102 | - [x] Upload module 3: VecTree Quantization 103 | - [ ] Upload docker image 104 | 105 | ## Acknowledgements 106 | We would like to express our gratitude to [Yueyu Hu](https://huzi96.github.io/) from NYU for the invaluable discussions on our project. 107 | 108 | 109 | ## BibTeX 110 | If you find our work useful for your project, please consider citing the following paper.
111 | 112 | 113 | ``` 114 | @misc{fan2023lightgaussian, 115 | title={LightGaussian: Unbounded 3D Gaussian Compression with 15x Reduction and 200+ FPS}, 116 | author={Zhiwen Fan and Kevin Wang and Kairun Wen and Zehao Zhu and Dejia Xu and Zhangyang Wang}, 117 | year={2023}, 118 | eprint={2311.17245}, 119 | archivePrefix={arXiv}, 120 | primaryClass={cs.CV} } 121 | ``` 122 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument( 34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true" 35 | ) 36 | else: 37 | group.add_argument( 38 | "--" + key, ("-" + key[0:1]), default=value, type=t 39 | ) 40 | else: 41 | if t == bool: 42 | group.add_argument("--" + key, default=value, action="store_true") 43 | else: 44 | group.add_argument("--" + key, default=value, type=t) 45 | 46 | def extract(self, args): 47 | group = GroupParams() 48 | for arg in vars(args).items(): 49 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 50 | setattr(group, arg[0], arg[1]) 51 | return group 52 | 53 | 54 | class ModelParams(ParamGroup): 55 | def __init__(self, parser, sentinel=False): 56 | self.sh_degree = 3 57 | self._source_path = "" 58 | self._model_path = "" 59 | self._images = "images" 60 | self._resolution = -1 61 | self._white_background = False 62 | self.data_device = "cuda" 63 | self.eval = False 64 | super().__init__(parser, "Loading Parameters", sentinel) 65 | 66 | def extract(self, args): 67 | g = super().extract(args) 68 | g.source_path = os.path.abspath(g.source_path) 69 | return g 70 | 71 | 72 | class PipelineParams(ParamGroup): 73 | def __init__(self, parser): 74 | self.convert_SHs_python = False 75 | self.compute_cov3D_python = False 76 | self.debug = False 77 | super().__init__(parser, "Pipeline Parameters") 78 | 79 | 80 | class OptimizationParams(ParamGroup): 81 | def __init__(self, parser): 82 | self.iterations = 30_000 83 | self.position_lr_init = 0.00016 84 | self.position_lr_final = 0.0000016 85 | self.position_lr_delay_mult = 0.01 86 | self.position_lr_max_steps = 30_000 87 | self.feature_lr = 0.0025 88 | self.opacity_lr = 0.05 89 | self.scaling_lr = 0.005 90 | self.rotation_lr = 0.001 91 | self.percent_dense = 0.01 92 | self.lambda_dssim = 0.2 93 | self.densification_interval = 100 94 | self.opacity_reset_interval = 3000 95 | self.densify_from_iter = 500 96 | self.densify_until_iter = 15_000 97 | self.densify_grad_threshold = 0.0002 98 | super().__init__(parser, "Optimization Parameters") 99 | 100 | 101 | def get_combined_args(parser: ArgumentParser): 102 | cmdlne_string = sys.argv[1:] 103 
| cfgfile_string = "Namespace()" 104 | args_cmdline = parser.parse_args(cmdlne_string) 105 | 106 | try: 107 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 108 | print("Looking for config file in", cfgfilepath) 109 | with open(cfgfilepath) as cfg_file: 110 | print("Config file found: {}".format(cfgfilepath)) 111 | cfgfile_string = cfg_file.read() 112 | except TypeError: 113 | print("Config file not found at") 114 | pass 115 | args_cfgfile = eval(cfgfile_string) 116 | 117 | merged_dict = vars(args_cfgfile).copy() 118 | for k, v in vars(args_cmdline).items(): 119 | if v != None: 120 | merged_dict[k] = v 121 | return Namespace(**merged_dict) 122 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 
58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /distill_train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from os import makedirs 22 | from tqdm import tqdm 23 | from utils.image_utils import psnr 24 | from argparse import ArgumentParser, Namespace 25 | from arguments import ModelParams, PipelineParams, OptimizationParams 26 | from utils.graphics_utils import getWorld2View2 27 | from utils.pose_utils import gaussian_poses 28 | from icecream import ic 29 | import random 30 | import copy 31 | import json 32 | import numpy as np 33 | from utils.logger_utils import prepare_output_and_logger, training_report 34 | from torch.optim.lr_scheduler import ExponentialLR 35 | from prune import prune_list, calculate_v_imp_score 36 | 37 | 38 | try: 39 | from torch.utils.tensorboard import SummaryWriter 40 | TENSORBOARD_FOUND = True 41 | except ImportError: 42 | TENSORBOARD_FOUND = False 43 | 44 | class NumpyArrayEncoder(json.JSONEncoder): 45 | def default(self, obj): 46 | if isinstance(obj, np.integer): 47 | return int(obj) 48 | elif isinstance(obj, np.floating): 49 | return float(obj) 50 | elif isinstance(obj, np.ndarray): 51 | return obj.tolist() 52 | else: 53 | return super(NumpyArrayEncoder, self).default(obj) 54 | 55 | 56 | to_tensor = lambda x: x.to("cuda") if isinstance( 57 | x, torch.Tensor) else torch.Tensor(x).to("cuda") 58 | img2mse = lambda x, y: torch.mean((x - y)**2) 59 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(to_tensor([10.])) 60 | 61 | def training(args, dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, new_max_sh): 62 | first_iter = 0 63 | old_sh_degree = dataset.sh_degree 64 | dataset.sh_degree = new_max_sh 65 | tb_writer = prepare_output_and_logger(dataset) 66 | with torch.no_grad(): 67 | teacher_gaussians = GaussianModel(old_sh_degree) 68 | # teacher_gaussians.training_setup(opt) 69 | 70 | student_gaussians = GaussianModel(old_sh_degree) 71 | student_scene = Scene(dataset, student_gaussians) 72 | 73 | if checkpoint: 74 | (teacher_model_params, _) = torch.load(args.teacher_model) 75 | (model_params, first_iter) = torch.load(checkpoint) 76 | teacher_gaussians.restore(teacher_model_params, copy.deepcopy(opt)) 77 | student_gaussians.restore(model_params, opt) 78 | student_gaussians.max_sh_degree = new_max_sh 79 | student_gaussians.onedownSHdegree() 80 | student_gaussians.training_setup(opt) 81 | student_gaussians.scheduler = ExponentialLR(student_gaussians.optimizer, gamma=0.90) 82 | # if !args.enable 83 | if (not args.enable_covariance): 84 | student_gaussians._scaling.requires_grad = False 85 | student_gaussians._rotation.requires_grad = False 86 | if (not args.enable_opacity): 87 | student_gaussians._opacity.requires_grad = False 88 | 89 | teacher_gaussians.optimizer = None 90 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 91 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 92 | iter_start = torch.cuda.Event(enable_timing = True) 93 | iter_end = torch.cuda.Event(enable_timing = True) 94 | viewpoint_stack = None 95 | ema_loss_for_log = 0.0 96 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 97 | first_iter += 1 98 | 99 | # os.makedirs(student_scene.model_path + "/vis_data", 
exist_ok=True) 100 | for iteration in range(first_iter, opt.iterations + 1): 101 | if network_gui.conn == None: 102 | network_gui.try_connect() 103 | while network_gui.conn != None: 104 | try: 105 | net_image_bytes = None 106 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 107 | if custom_cam != None: 108 | net_image = render(custom_cam, student_gaussians, pipe, background, scaling_modifer)["render"] 109 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 110 | network_gui.send(net_image_bytes, dataset.source_path) 111 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 112 | break 113 | except Exception as e: 114 | network_gui.conn = None 115 | 116 | iter_start.record() 117 | student_gaussians.update_learning_rate(iteration) 118 | 119 | # Every 500 iterations step in scheduler 120 | if iteration % 500 == 0: 121 | # student_gaussians.oneupSHdegree() 122 | student_gaussians.scheduler.step() 123 | 124 | if not viewpoint_stack: 125 | viewpoint_stack = student_scene.getTrainCameras().copy() 126 | viewpoint_cam_org = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 127 | viewpoint_cam = copy.deepcopy(viewpoint_cam_org) 128 | 129 | if (iteration - 1) == debug_from: 130 | pipe.debug = True 131 | 132 | if args.augmented_view and iteration%3: 133 | viewpoint_cam = gaussian_poses(viewpoint_cam, mean= 0, std_dev_translation=0.05, std_dev_rotation=0) 134 | student_render_pkg = render(viewpoint_cam, student_gaussians, pipe, background) 135 | student_image = student_render_pkg["render"] 136 | teacher_render_pkg = render(viewpoint_cam, teacher_gaussians, pipe, background) 137 | teacher_image = teacher_render_pkg["render"].detach() 138 | else: 139 | render_pkg = render(viewpoint_cam, student_gaussians, pipe, background) 140 | student_image = render_pkg["render"] 141 | teacher_image = render(viewpoint_cam, teacher_gaussians, pipe, background)["render"].detach() 142 | Ll1 = l1_loss(student_image, teacher_image) 143 | # Ll1 = img2mse(student_image, teacher_image) 144 | 145 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(student_image, teacher_image)) 146 | loss.backward() 147 | iter_end.record() 148 | with torch.no_grad(): 149 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 150 | if iteration % 10 == 0: 151 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 152 | progress_bar.update(10) 153 | if iteration == opt.iterations: 154 | progress_bar.close() 155 | 156 | if (iteration in saving_iterations): 157 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 158 | ic(student_gaussians._features_rest.detach().shape) 159 | student_scene.save(iteration) 160 | 161 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, student_scene, render, (pipe, background)) 162 | 163 | # Optimizer step 164 | if iteration < opt.iterations: 165 | student_gaussians.optimizer.step() 166 | student_gaussians.optimizer.zero_grad(set_to_none = True) 167 | 168 | if (iteration in checkpoint_iterations): 169 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 170 | if not os.path.exists(student_scene.model_path): 171 | os.makedirs(student_scene.model_path) 172 | torch.save((student_gaussians.capture(), iteration), student_scene.model_path + "/chkpnt" + str(iteration) + ".pth") 173 | 174 | if iteration == 
checkpoint_iterations[-1]: 175 | print("Saving Imp_score") 176 | gaussian_list, imp_list = prune_list( 177 | student_gaussians, student_scene, pipe, background 178 | ) 179 | v_list = calculate_v_imp_score(student_gaussians, imp_list, 0.1) 180 | np.savez( 181 | os.path.join(student_scene.model_path, "imp_score"), 182 | v_list.cpu().detach().numpy(), 183 | ) 184 | 185 | 186 | if __name__ == "__main__": 187 | # Set up command line argument parser 188 | parser = ArgumentParser(description="Training script parameters") 189 | lp = ModelParams(parser) 190 | op = OptimizationParams(parser) 191 | pp = PipelineParams(parser) 192 | parser.add_argument('--ip', type=str, default="127.0.0.1") 193 | parser.add_argument('--port', type=int, default=6009) 194 | parser.add_argument('--debug_from', type=int, default=-1) 195 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 196 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[35_001, 40_000]) 197 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[40_000]) 198 | parser.add_argument("--quiet", action="store_true") 199 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[40_000]) 200 | parser.add_argument("--start_checkpoint", type=str, default = None) 201 | parser.add_argument("--new_max_sh", type=int, default = 2) 202 | parser.add_argument("--augmented_view", action="store_true") 203 | parser.add_argument("--enable_covariance", action="store_true") 204 | parser.add_argument("--enable_opacity", action="store_true") 205 | parser.add_argument("--opacity_prune", type=float, default = 0) 206 | parser.add_argument("--teacher_model", type=str) 207 | 208 | args = parser.parse_args(sys.argv[1:]) 209 | args.save_iterations.append(args.iterations) 210 | 211 | print("Optimizing " + args.model_path) 212 | 213 | # Initialize system state (RNG) 214 | safe_state(args.quiet) 215 | 216 | # Start GUI server, configure and run training 217 | network_gui.init(args.ip, args.port) 218 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 219 | training(args, lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.new_max_sh) 220 | 221 | # All done 222 | print("\nTraining complete.") 223 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lightgaussian 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile=0.8.1 9 | - python=3.9 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - setuptools=69.5.1 15 | - tqdm 16 | - icecream 17 | - pip: 18 | - submodules/compress-diff-gaussian-rasterization 19 | - submodules/simple-knn 20 | -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument("--mipnerf360", "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system( 44 | "python train.py -s " 45 | + source 46 | + " -i images_4 -m " 47 | + args.output_path 48 | + "/" 49 | + scene 50 | + common_args 51 | ) 52 | for scene in mipnerf360_indoor_scenes: 53 | source = args.mipnerf360 + "/" + scene 54 | os.system( 55 | "python train.py -s " 56 | + source 57 | + " -i images_2 -m " 58 | + args.output_path 59 | + "/" 60 | + scene 61 | + common_args 62 | ) 63 | for scene in tanks_and_temples_scenes: 64 | source = args.tanksandtemples + "/" + scene 65 | os.system( 66 | "python train.py -s " 67 | + source 68 | + " -m " 69 | + args.output_path 70 | + "/" 71 | + scene 72 | + common_args 73 | ) 74 | for scene in deep_blending_scenes: 75 | source = args.deepblending + "/" + scene 76 | os.system( 77 | "python train.py -s " 78 | + source 79 | + " -m " 80 | + args.output_path 81 | + "/" 82 | + scene 83 | + common_args 84 | ) 85 | 86 | if not args.skip_rendering: 87 | all_sources = [] 88 | for scene in mipnerf360_outdoor_scenes: 89 | all_sources.append(args.mipnerf360 + "/" + scene) 90 | for scene in mipnerf360_indoor_scenes: 91 | all_sources.append(args.mipnerf360 + "/" + scene) 92 | for scene in tanks_and_temples_scenes: 93 | all_sources.append(args.tanksandtemples + "/" + scene) 94 | for scene in deep_blending_scenes: 95 | all_sources.append(args.deepblending + "/" + scene) 96 | 97 | common_args = " --quiet --eval --skip_train" 98 | for scene, source in zip(all_scenes, all_sources): 99 | os.system( 100 | "python render.py --iteration 7000 -s " 101 | + source 102 | + " -m " 103 | + args.output_path 104 | + "/" 105 | + scene 106 | + common_args 107 | ) 108 | os.system( 109 | "python render.py --iteration 30000 -s " 110 | + source 111 | + " -m " 112 | + args.output_path 113 | + "/" 114 | + scene 115 | + common_args 116 | ) 117 | 118 | if not args.skip_metrics: 119 | scenes_string = "" 120 | for scene in all_scenes: 121 | scenes_string += '"' + args.output_path + "/" + scene + '" ' 122 | 123 | os.system("python metrics.py -m " + scenes_string) 124 | 
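A minimal invocation sketch for full_eval.py, matching the argparse flags defined above; the dataset root paths below (/data/...) are placeholders for local copies of MipNeRF360, Tanks & Temples, and Deep Blending, not paths shipped with the repository:
```shell
# Train, render, and compute metrics for every scene (outputs written under ./eval).
python full_eval.py -m360 /data/mipnerf360 -tat /data/tandt -db /data/db --output_path ./eval

# Reuse existing training results and only re-render and re-evaluate.
python full_eval.py --skip_training -m360 /data/mipnerf360 -tat /data/tandt -db /data/db --output_path ./eval
```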
-------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import ( 15 | GaussianRasterizationSettings, 16 | GaussianRasterizer, 17 | ) 18 | from scene.gaussian_model import GaussianModel 19 | from utils.sh_utils import eval_sh 20 | 21 | 22 | def render( 23 | viewpoint_camera, 24 | pc: GaussianModel, 25 | pipe, 26 | bg_color: torch.Tensor, 27 | scaling_modifier=1.0, 28 | override_color=None, 29 | ): 30 | """ 31 | Render the scene. 32 | 33 | Background tensor (bg_color) must be on GPU! 34 | """ 35 | 36 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 37 | screenspace_points = ( 38 | torch.zeros_like( 39 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" 40 | ) 41 | + 0 42 | ) 43 | try: 44 | screenspace_points.retain_grad() 45 | except: 46 | pass 47 | 48 | # Set up rasterization configuration 49 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 50 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 51 | 52 | raster_settings = GaussianRasterizationSettings( 53 | image_height=int(viewpoint_camera.image_height), 54 | image_width=int(viewpoint_camera.image_width), 55 | tanfovx=tanfovx, 56 | tanfovy=tanfovy, 57 | bg=bg_color, 58 | scale_modifier=scaling_modifier, 59 | viewmatrix=viewpoint_camera.world_view_transform, 60 | projmatrix=viewpoint_camera.full_proj_transform, 61 | sh_degree=pc.active_sh_degree, 62 | campos=viewpoint_camera.camera_center, 63 | prefiltered=False, 64 | debug=pipe.debug, 65 | f_count=False, 66 | ) 67 | 68 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 69 | 70 | means3D = pc.get_xyz 71 | means2D = screenspace_points 72 | opacity = pc.get_opacity 73 | 74 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 75 | # scaling / rotation by the rasterizer. 76 | scales = None 77 | rotations = None 78 | cov3D_precomp = None 79 | if pipe.compute_cov3D_python: 80 | cov3D_precomp = pc.get_covariance(scaling_modifier) 81 | else: 82 | scales = pc.get_scaling 83 | rotations = pc.get_rotation 84 | 85 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 86 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 87 | shs = None 88 | colors_precomp = None 89 | if override_color is None: 90 | if pipe.convert_SHs_python: 91 | shs_view = pc.get_features.transpose(1, 2).view( 92 | -1, 3, (pc.max_sh_degree + 1) ** 2 93 | ) 94 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( 95 | pc.get_features.shape[0], 1 96 | ) 97 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 98 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 99 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 100 | else: 101 | shs = pc.get_features 102 | else: 103 | colors_precomp = override_color 104 | 105 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
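# With f_count=False in the raster settings above, the (modified) rasterizer keeps the
# original 3D-GS interface and returns only the rendered image and per-Gaussian screen-space
# radii; count_render() below sets f_count=True to additionally collect per-Gaussian statistics.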
106 | rendered_image, radii = rasterizer( 107 | means3D=means3D, 108 | means2D=means2D, 109 | shs=shs, 110 | colors_precomp=colors_precomp, 111 | opacities=opacity, 112 | scales=scales, 113 | rotations=rotations, 114 | cov3D_precomp=cov3D_precomp, 115 | ) 116 | 117 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 118 | # They will be excluded from value updates used in the splitting criteria. 119 | return { 120 | "render": rendered_image, 121 | "viewspace_points": screenspace_points, 122 | "visibility_filter": radii > 0, 123 | "radii": radii, 124 | } 125 | 126 | 127 | def count_render( 128 | viewpoint_camera, 129 | pc: GaussianModel, 130 | pipe, 131 | bg_color: torch.Tensor, 132 | scaling_modifier=1.0, 133 | override_color=None, 134 | ): 135 | """ 136 | Render the scene. 137 | 138 | Background tensor (bg_color) must be on GPU! 139 | """ 140 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 141 | screenspace_points = ( 142 | torch.zeros_like( 143 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" 144 | ) 145 | + 0 146 | ) 147 | try: 148 | screenspace_points.retain_grad() 149 | except: 150 | pass 151 | 152 | # Set up rasterization configuration 153 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 154 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 155 | 156 | raster_settings = GaussianRasterizationSettings( 157 | image_height=int(viewpoint_camera.image_height), 158 | image_width=int(viewpoint_camera.image_width), 159 | tanfovx=tanfovx, 160 | tanfovy=tanfovy, 161 | bg=bg_color, 162 | scale_modifier=scaling_modifier, 163 | viewmatrix=viewpoint_camera.world_view_transform, 164 | projmatrix=viewpoint_camera.full_proj_transform, 165 | sh_degree=pc.active_sh_degree, 166 | campos=viewpoint_camera.camera_center, 167 | prefiltered=False, 168 | debug=pipe.debug, 169 | f_count=True, 170 | ) 171 | 172 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 173 | means3D = pc.get_xyz 174 | means2D = screenspace_points 175 | opacity = pc.get_opacity 176 | 177 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 178 | # scaling / rotation by the rasterizer. 179 | scales = None 180 | rotations = None 181 | cov3D_precomp = None 182 | if pipe.compute_cov3D_python: 183 | cov3D_precomp = pc.get_covariance(scaling_modifier) 184 | else: 185 | scales = pc.get_scaling 186 | rotations = pc.get_rotation 187 | 188 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 189 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 190 | shs = None 191 | colors_precomp = None 192 | if override_color is None: 193 | if pipe.convert_SHs_python: 194 | shs_view = pc.get_features.transpose(1, 2).view( 195 | -1, 3, (pc.max_sh_degree + 1) ** 2 196 | ) 197 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( 198 | pc.get_features.shape[0], 1 199 | ) 200 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 201 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 202 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 203 | else: 204 | shs = pc.get_features 205 | else: 206 | colors_precomp = override_color 207 | 208 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
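# With f_count=True, the modified rasterizer additionally returns per-Gaussian statistics:
# gaussians_count (how often each Gaussian was counted during rasterization) and
# important_score (an accumulated measure of each Gaussian's contribution to the rendered pixels).
# prune_list() in prune.py sums both over all training views, and calculate_v_imp_score()
# turns the summed scores into the global significance values saved as imp_score.npz.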
209 | gaussians_count, important_score, rendered_image, radii = rasterizer( 210 | means3D=means3D, 211 | means2D=means2D, 212 | shs=shs, 213 | colors_precomp=colors_precomp, 214 | opacities=opacity, 215 | scales=scales, 216 | rotations=rotations, 217 | cov3D_precomp=cov3D_precomp, 218 | ) 219 | 220 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 221 | # They will be excluded from value updates used in the splitting criteria. 222 | return { 223 | "render": rendered_image, 224 | "viewspace_points": screenspace_points, 225 | "visibility_filter": radii > 0, 226 | "radii": radii, 227 | "gaussians_count": gaussians_count, 228 | "important_score": important_score, 229 | } 230 | -------------------------------------------------------------------------------- /gaussian_renderer/gaussian_count.py: -------------------------------------------------------------------------------- 1 | # base on __ini__.render 2 | 3 | # 4 | # Copyright (C) 2023, Inria 5 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 6 | # All rights reserved. 7 | # 8 | # This software is free for non-commercial, research and evaluation use 9 | # under the terms of the LICENSE.md file. 10 | # 11 | # For inquiries contact george.drettakis@inria.fr 12 | # 13 | 14 | import torch 15 | import math 16 | from diff_gaussian_rasterization import ( 17 | GaussianRasterizationSettings, 18 | GaussianRasterizer, 19 | ) 20 | from scene.gaussian_model import GaussianModel 21 | from utils.sh_utils import eval_sh 22 | 23 | 24 | def count_render( 25 | viewpoint_camera, 26 | pc: GaussianModel, 27 | pipe, 28 | bg_color: torch.Tensor, 29 | scaling_modifier=1.0, 30 | override_color=None, 31 | ): 32 | """ 33 | Render the scene. 34 | 35 | Background tensor (bg_color) must be on GPU! 36 | """ 37 | 38 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 39 | screenspace_points = ( 40 | torch.zeros_like( 41 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" 42 | ) 43 | + 0 44 | ) 45 | try: 46 | screenspace_points.retain_grad() 47 | except: 48 | pass 49 | 50 | # Set up rasterization configuration 51 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 52 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 53 | 54 | raster_settings = GaussianRasterizationSettings( 55 | image_height=int(viewpoint_camera.image_height), 56 | image_width=int(viewpoint_camera.image_width), 57 | tanfovx=tanfovx, 58 | tanfovy=tanfovy, 59 | bg=bg_color, 60 | scale_modifier=scaling_modifier, 61 | viewmatrix=viewpoint_camera.world_view_transform, 62 | projmatrix=viewpoint_camera.full_proj_transform, 63 | sh_degree=pc.active_sh_degree, 64 | campos=viewpoint_camera.camera_center, 65 | prefiltered=False, 66 | debug=pipe.debug, 67 | ) 68 | 69 | rasterizer = GaussianRasterizer(raster_settings=raster_settings, f_count=True) 70 | 71 | means3D = pc.get_xyz 72 | means2D = screenspace_points 73 | opacity = pc.get_opacity 74 | 75 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 76 | # scaling / rotation by the rasterizer. 77 | scales = None 78 | rotations = None 79 | cov3D_precomp = None 80 | if pipe.compute_cov3D_python: 81 | cov3D_precomp = pc.get_covariance(scaling_modifier) 82 | else: 83 | scales = pc.get_scaling 84 | rotations = pc.get_rotation 85 | 86 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 87 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
88 | shs = None 89 | colors_precomp = None 90 | if override_color is None: 91 | if pipe.convert_SHs_python: 92 | shs_view = pc.get_features.transpose(1, 2).view( 93 | -1, 3, (pc.max_sh_degree + 1) ** 2 94 | ) 95 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( 96 | pc.get_features.shape[0], 1 97 | ) 98 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 99 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 100 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 101 | else: 102 | shs = pc.get_features 103 | else: 104 | colors_precomp = override_color 105 | 106 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 107 | ( 108 | gaussians_count, 109 | important_score, 110 | rendered_image, 111 | radii, 112 | ) = rasterizer.forward_counter( 113 | means3D=means3D, 114 | means2D=means2D, 115 | shs=shs, 116 | colors_precomp=colors_precomp, 117 | opacities=opacity, 118 | scales=scales, 119 | rotations=rotations, 120 | cov3D_precomp=cov3D_precomp, 121 | ) 122 | 123 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 124 | # They will be excluded from value updates used in the splitting criteria. 125 | return { 126 | "render": rendered_image, 127 | "viewspace_points": screenspace_points, 128 | "visibility_filter": radii > 0, 129 | "radii": radii, 130 | "gaussians_count": gaussians_count, 131 | "important_score": important_score, 132 | } 133 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | 27 | def init(wish_host, wish_port): 28 | global host, port, listener 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | 35 | 36 | def try_connect(): 37 | global conn, addr, listener 38 | try: 39 | conn, addr = listener.accept() 40 | print(f"\nConnected by {addr}") 41 | conn.settimeout(None) 42 | except Exception as inst: 43 | pass 44 | 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, "little") 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | 54 | def send(message_bytes, verify): 55 | global conn 56 | if message_bytes != None: 57 | conn.sendall(message_bytes) 58 | conn.sendall(len(verify).to_bytes(4, "little")) 59 | conn.sendall(bytes(verify, "ascii")) 60 | 61 | 62 | def receive(): 63 | message = read() 64 | 65 | width = message["resolution_x"] 66 | height = message["resolution_y"] 67 | 68 | if width != 0 and height != 0: 69 | try: 70 | do_training = bool(message["train"]) 71 | fovy = message["fov_y"] 72 | fovx = message["fov_x"] 73 | znear = message["z_near"] 74 | zfar = message["z_far"] 75 | do_shs_python = bool(message["shs_python"]) 76 | do_rot_scale_python = bool(message["rot_scale_python"]) 77 | keep_alive = bool(message["keep_alive"]) 78 | scaling_modifier = message["scaling_modifier"] 79 | world_view_transform = torch.reshape( 80 | torch.tensor(message["view_matrix"]), (4, 4) 81 | ).cuda() 82 | world_view_transform[:, 1] = -world_view_transform[:, 1] 83 | world_view_transform[:, 2] = -world_view_transform[:, 2] 84 | full_proj_transform = torch.reshape( 85 | torch.tensor(message["view_projection_matrix"]), (4, 4) 86 | ).cuda() 87 | full_proj_transform[:, 1] = -full_proj_transform[:, 1] 88 | custom_cam = MiniCam( 89 | width, 90 | height, 91 | fovy, 92 | fovx, 93 | znear, 94 | zfar, 95 | world_view_transform, 96 | full_proj_transform, 97 | ) 98 | except Exception as e: 99 | print("") 100 | traceback.print_exc() 101 | raise e 102 | return ( 103 | custom_cam, 104 | do_training, 105 | do_shs_python, 106 | do_rot_scale_python, 107 | keep_alive, 108 | scaling_modifier, 109 | ) 110 | else: 111 | return None, None, None, None, None, None 112 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips( 7 | x: torch.Tensor, y: torch.Tensor, net_type: str = "alex", version: str = "0.1" 8 | ): 9 | r"""Function that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | x, y (torch.Tensor): the input tensors to compare. 14 | net_type (str): the network type to compare the features: 15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 16 | version (str): the version of LPIPS. Default: 0.1. 
17 | """ 18 | device = x.device 19 | criterion = LPIPS(net_type, version).to(device) 20 | return criterion(x, y) 21 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | 18 | def __init__(self, net_type: str = "alex", version: str = "0.1"): 19 | assert version in ["0.1"], "v0.1 is only supported now" 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == "alex": 14 | return AlexNet() 15 | elif net_type == "squeeze": 16 | return SqueezeNet() 17 | elif net_type == "vgg": 18 | return VGG16() 19 | else: 20 | raise NotImplementedError("choose net_type from [alex, squeeze, vgg].") 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__( 26 | [ 27 | nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False)) 28 | for nc in n_channels_list 29 | ] 30 | ) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 43 | ) 44 | self.register_buffer( 45 | "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 46 | ) 47 | 48 | def set_requires_grad(self, state: bool): 49 | for param in chain(self.parameters(), self.buffers()): 50 | param.requires_grad = state 51 | 52 | def z_score(self, x: torch.Tensor): 53 | return (x - self.mean) / self.std 54 | 55 | def forward(self, x: torch.Tensor): 56 | x = self.z_score(x) 57 | 58 | output = [] 59 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 60 | x = layer(x) 61 | if i in self.target_layers: 62 | output.append(normalize_activation(x)) 63 | if len(output) == len(self.target_layers): 64 | break 65 | return output 66 | 67 | 68 | class SqueezeNet(BaseNet): 69 | def __init__(self): 70 | super(SqueezeNet, self).__init__() 71 | 72 | self.layers = 
models.squeezenet1_1(True).features 73 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 74 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 75 | 76 | self.set_requires_grad(False) 77 | 78 | 79 | class AlexNet(BaseNet): 80 | def __init__(self): 81 | super(AlexNet, self).__init__() 82 | 83 | self.layers = models.alexnet(True).features 84 | self.target_layers = [2, 5, 8, 10, 12] 85 | self.n_channels_list = [64, 192, 384, 256, 256] 86 | 87 | self.set_requires_grad(False) 88 | 89 | 90 | class VGG16(BaseNet): 91 | def __init__(self): 92 | super(VGG16, self).__init__() 93 | 94 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 95 | self.target_layers = [4, 9, 16, 23, 30] 96 | self.n_channels_list = [64, 128, 256, 512, 512] 97 | 98 | self.set_requires_grad(False) 99 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = "alex", version: str = "0.1"): 12 | # build url 13 | url = ( 14 | "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/" 15 | + f"master/lpips/weights/v{version}/{net_type}.pth" 16 | ) 17 | 18 | # download 19 | old_state_dict = torch.hub.load_state_dict_from_url( 20 | url, 21 | progress=True, 22 | map_location=None if torch.cuda.is_available() else torch.device("cpu"), 23 | ) 24 | 25 | # rename keys 26 | new_state_dict = OrderedDict() 27 | for key, val in old_state_dict.items(): 28 | new_key = key 29 | new_key = new_key.replace("lin", "") 30 | new_key = new_key.replace("model.", "") 31 | new_state_dict[new_key] = val 32 | 33 | return new_state_dict 34 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | 25 | def readImages(renders_dir, gt_dir): 26 | renders = [] 27 | gts = [] 28 | image_names = [] 29 | for fname in os.listdir(renders_dir): 30 | render = Image.open(renders_dir / fname) 31 | gt = Image.open(gt_dir / fname) 32 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 33 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 34 | image_names.append(fname) 35 | return renders, gts, image_names 36 | 37 | 38 | def evaluate(model_paths): 39 | full_dict = {} 40 | per_view_dict = {} 41 | full_dict_polytopeonly = {} 42 | per_view_dict_polytopeonly = {} 43 | print("") 44 | 45 | for scene_dir in model_paths: 46 | try: 47 | print("Scene:", scene_dir) 48 | full_dict[scene_dir] = {} 49 | per_view_dict[scene_dir] = {} 50 | full_dict_polytopeonly[scene_dir] = {} 51 | per_view_dict_polytopeonly[scene_dir] = {} 52 | 53 | test_dir = Path(scene_dir) / "test" 54 | 55 | for method in os.listdir(test_dir): 56 | print("Method:", method) 57 | 58 | full_dict[scene_dir][method] = {} 59 | per_view_dict[scene_dir][method] = {} 60 | full_dict_polytopeonly[scene_dir][method] = {} 61 | per_view_dict_polytopeonly[scene_dir][method] = {} 62 | 63 | method_dir = test_dir / method 64 | gt_dir = method_dir / "gt" 65 | renders_dir = method_dir / "renders" 66 | renders, gts, image_names = readImages(renders_dir, gt_dir) 67 | 68 | ssims = [] 69 | psnrs = [] 70 | lpipss = [] 71 | 72 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 73 | ssims.append(ssim(renders[idx], gts[idx])) 74 | psnrs.append(psnr(renders[idx], gts[idx])) 75 | lpipss.append(lpips(renders[idx], gts[idx], net_type="vgg")) 76 | 77 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 78 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 79 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 80 | print("") 81 | 82 | full_dict[scene_dir][method].update( 83 | { 84 | "SSIM": torch.tensor(ssims).mean().item(), 85 | "PSNR": torch.tensor(psnrs).mean().item(), 86 | "LPIPS": torch.tensor(lpipss).mean().item(), 87 | } 88 | ) 89 | per_view_dict[scene_dir][method].update( 90 | { 91 | "SSIM": { 92 | name: ssim 93 | for ssim, name in zip( 94 | torch.tensor(ssims).tolist(), image_names 95 | ) 96 | }, 97 | "PSNR": { 98 | name: psnr 99 | for psnr, name in zip( 100 | torch.tensor(psnrs).tolist(), image_names 101 | ) 102 | }, 103 | "LPIPS": { 104 | name: lp 105 | for lp, name in zip( 106 | torch.tensor(lpipss).tolist(), image_names 107 | ) 108 | }, 109 | } 110 | ) 111 | 112 | with open(scene_dir + "/results.json", "w") as fp: 113 | json.dump(full_dict[scene_dir], fp, indent=True) 114 | with open(scene_dir + "/per_view.json", "w") as fp: 115 | json.dump(per_view_dict[scene_dir], fp, indent=True) 116 | except: 117 | print("Unable to compute metrics for model", scene_dir) 118 | 119 | 120 | if __name__ == "__main__": 121 | device = torch.device("cuda:0") 122 | torch.cuda.set_device(device) 123 | 124 | # Set up command line argument parser 125 | parser = ArgumentParser(description="Training script parameters") 126 | parser.add_argument( 127 | 
"--model_paths", "-m", required=True, nargs="+", type=str, default=[] 128 | ) 129 | args = parser.parse_args() 130 | evaluate(args.model_paths) 131 | -------------------------------------------------------------------------------- /prune.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from gaussian_renderer import render, count_render 16 | import sys 17 | from scene import Scene, GaussianModel 18 | from utils.general_utils import safe_state 19 | import uuid 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser, Namespace 23 | from arguments import ModelParams, PipelineParams, OptimizationParams 24 | from utils.graphics_utils import getWorld2View2 25 | from icecream import ic 26 | import random 27 | import copy 28 | import gc 29 | import numpy as np 30 | from collections import defaultdict 31 | 32 | # from cuml.cluster import HDBSCAN 33 | 34 | 35 | # def HDBSCAN_prune(gaussians, score_list, prune_percent): 36 | # # Ensure the tensor is on the GPU and detached from the graph 37 | # s, d = gaussians.get_xyz.shape 38 | # X_gpu = cp.asarray(gaussians.get_xyz.detach().cuda()) 39 | 40 | # scores_gpu = cp.asarray(score_list.detach().cuda()) 41 | # hdbscan = HDBSCAN(min_cluster_size = 100) 42 | # cluster_labels = hdbscan.fit_predict(X_gpu) 43 | # points_by_centroid = {} 44 | # ic("cluster_labels") 45 | # ic(cluster_labels.shape) 46 | # ic(cluster_labels) 47 | # for i, label in enumerate(cluster_labels): 48 | # if label not in points_by_centroid: 49 | # points_by_centroid[label] = [] 50 | # points_by_centroid[label].append(i) 51 | # points_to_prune = [] 52 | 53 | # for centroid_idx, point_indices in points_by_centroid.items(): 54 | # # Skip noise points with label -1 55 | # if centroid_idx == -1: 56 | # continue 57 | # num_to_prune = int(cp.ceil(prune_percent * len(point_indices))) 58 | # if num_to_prune <= 3: 59 | # continue 60 | # point_indices_cp = cp.array(point_indices) 61 | # distances = scores_gpu[point_indices_cp].squeeze() 62 | # indices_to_prune = point_indices_cp[cp.argsort(distances)[:num_to_prune]] 63 | # points_to_prune.extend(indices_to_prune) 64 | # points_to_prune = np.array(points_to_prune) 65 | # mask = np.zeros(s, dtype=bool) 66 | # mask[points_to_prune] = True 67 | # # points_to_prune now contains the indices of the points to be pruned 68 | # return mask 69 | 70 | 71 | # def uniform_prune(gaussians, k, score_list, prune_percent, sample = "k_mean"): 72 | # # get the farthest_point 73 | # D, I = None, None 74 | # s, d = gaussians.get_xyz.shape 75 | 76 | # if sample == "k_mean": 77 | # ic("k_mean") 78 | # n_iter = 200 79 | # verbose = False 80 | # kmeans = faiss.Kmeans(d, k=k, niter=n_iter, verbose=verbose, gpu=True) 81 | # kmeans.train(gaussians.get_xyz.detach().cpu().numpy()) 82 | # # The cluster centroids can be accessed as follows 83 | # centroids = kmeans.centroids 84 | # D, I = kmeans.index.search(gaussians.get_xyz.detach().cpu().numpy(), 1) 85 | # else: 86 | # point_idx = farthest_point_sampler(torch.unsqueeze(gaussians.get_xyz, 0), k) 87 | # centroids = gaussians.get_xyz[point_idx,: ] 88 | # 
centroids = centroids.squeeze(0) 89 | # index = faiss.IndexFlatL2(d) 90 | # index.add(centroids.detach().cpu().numpy()) 91 | # D, I = index.search(gaussians.get_xyz.detach().cpu().numpy(), 1) 92 | # points_to_prune = [] 93 | # points_by_centroid = defaultdict(list) 94 | # for point_idx, centroid_idx in enumerate(I.flatten()): 95 | # points_by_centroid[centroid_idx.item()].append(point_idx) 96 | # for centroid_idx in points_by_centroid: 97 | # points_by_centroid[centroid_idx] = np.array(points_by_centroid[centroid_idx]) 98 | # for centroid_idx, point_indices in points_by_centroid.items(): 99 | # # Find the number of points to prune 100 | # num_to_prune = int(np.ceil(prune_percent * len(point_indices))) 101 | # if num_to_prune <= 3: 102 | # continue 103 | # distances = score_list[point_indices].squeeze().cpu().detach().numpy() 104 | # indices_to_prune = point_indices[np.argsort(distances)[:num_to_prune]] 105 | # points_to_prune.extend(indices_to_prune) 106 | # # Convert the list to an array 107 | # points_to_prune = np.array(points_to_prune) 108 | # mask = np.zeros(s, dtype=bool) 109 | # mask[points_to_prune] = True 110 | # return mask 111 | 112 | def calculate_v_imp_score(gaussians, imp_list, v_pow): 113 | """ 114 | :param gaussians: A data structure containing Gaussian components with a get_scaling method. 115 | :param imp_list: The importance scores for each Gaussian component. 116 | :param v_pow: The power to which the volume ratios are raised. 117 | :return: A list of adjusted values (v_list) used for pruning. 118 | """ 119 | # Calculate the volume of each Gaussian component 120 | volume = torch.prod(gaussians.get_scaling, dim=1) 121 | # Determine the kth_percent_largest value 122 | index = int(len(volume) * 0.9) 123 | sorted_volume, _ = torch.sort(volume, descending=True) 124 | kth_percent_largest = sorted_volume[index] 125 | # Calculate v_list 126 | v_list = torch.pow(volume / kth_percent_largest, v_pow) 127 | v_list = v_list * imp_list 128 | return v_list 129 | 130 | 131 | 132 | 133 | def prune_list(gaussians, scene, pipe, background): 134 | viewpoint_stack = scene.getTrainCameras().copy() 135 | gaussian_list, imp_list = None, None 136 | viewpoint_cam = viewpoint_stack.pop() 137 | render_pkg = count_render(viewpoint_cam, gaussians, pipe, background) 138 | gaussian_list, imp_list = ( 139 | render_pkg["gaussians_count"], 140 | render_pkg["important_score"], 141 | ) 142 | 143 | # ic(dataset.model_path) 144 | for iteration in range(len(viewpoint_stack)): 145 | # Pick a random Camera 146 | # prunning 147 | viewpoint_cam = viewpoint_stack.pop() 148 | render_pkg = count_render(viewpoint_cam, gaussians, pipe, background) 149 | # image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 150 | gaussians_count, important_score = ( 151 | render_pkg["gaussians_count"].detach(), 152 | render_pkg["important_score"].detach(), 153 | ) 154 | gaussian_list += gaussians_count 155 | imp_list += important_score 156 | gc.collect() 157 | return gaussian_list, imp_list 158 | -------------------------------------------------------------------------------- /prune_finetune.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from lpipsPyTorch import lpips 17 | from gaussian_renderer import render, network_gui, count_render 18 | import sys 19 | from scene import Scene, GaussianModel 20 | from utils.general_utils import safe_state 21 | import uuid 22 | from tqdm import tqdm 23 | from utils.image_utils import psnr 24 | from argparse import ArgumentParser, Namespace 25 | from arguments import ModelParams, PipelineParams, OptimizationParams 26 | import numpy as np 27 | 28 | try: 29 | from torch.utils.tensorboard import SummaryWriter 30 | 31 | TENSORBOARD_FOUND = True 32 | except ImportError: 33 | TENSORBOARD_FOUND = False 34 | from icecream import ic 35 | import random 36 | import copy 37 | import gc 38 | from os import makedirs 39 | from prune import prune_list, calculate_v_imp_score 40 | import torchvision 41 | from torch.optim.lr_scheduler import ExponentialLR 42 | import csv 43 | from utils.logger_utils import training_report, prepare_output_and_logger 44 | 45 | 46 | to_tensor = ( 47 | lambda x: x.to("cuda") 48 | if isinstance(x, torch.Tensor) 49 | else torch.Tensor(x).to("cuda") 50 | ) 51 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 52 | mse2psnr = lambda x: -10.0 * torch.log(x) / torch.log(to_tensor([10.0])) 53 | 54 | 55 | def training( 56 | dataset, 57 | opt, 58 | pipe, 59 | testing_iterations, 60 | saving_iterations, 61 | checkpoint_iterations, 62 | checkpoint, 63 | debug_from, 64 | args, 65 | ): 66 | first_iter = 0 67 | tb_writer = prepare_output_and_logger(dataset) 68 | gaussians = GaussianModel(dataset.sh_degree) 69 | scene = Scene(dataset, gaussians) 70 | if checkpoint: 71 | gaussians.training_setup(opt) 72 | (model_params, first_iter) = torch.load(checkpoint) 73 | gaussians.restore(model_params, opt) 74 | elif args.start_pointcloud: 75 | gaussians.load_ply(args.start_pointcloud) 76 | ic(gaussians.get_xyz.shape) 77 | # ic(gaussians.optimizer.param_groups["xyz"].shape) 78 | gaussians.training_setup(opt) 79 | gaussians.max_radii2D = torch.zeros((gaussians.get_xyz.shape[0]), device="cuda") 80 | 81 | else: 82 | raise ValueError("A checkpoint file or a pointcloud is required to proceed.") 83 | 84 | 85 | 86 | 87 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 88 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 89 | 90 | iter_start = torch.cuda.Event(enable_timing=True) 91 | iter_end = torch.cuda.Event(enable_timing=True) 92 | 93 | viewpoint_stack = None 94 | ema_loss_for_log = 0.0 95 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 96 | first_iter += 1 97 | gaussians.scheduler = ExponentialLR(gaussians.optimizer, gamma=0.95) 98 | 99 | for iteration in range(first_iter, opt.iterations + 1): 100 | if network_gui.conn == None: 101 | network_gui.try_connect() 102 | while network_gui.conn != None: 103 | try: 104 | net_image_bytes = None 105 | ( 106 | custom_cam, 107 | do_training, 108 | pipe.convert_SHs_python, 109 | pipe.compute_cov3D_python, 110 | keep_alive, 111 | scaling_modifer, 112 | ) = network_gui.receive() 113 | if custom_cam != None: 114 | net_image = render( 115 | custom_cam, gaussians, pipe, background, scaling_modifer 116 | )["render"] 117 | net_image_bytes = memoryview( 118 | (torch.clamp(net_image, min=0, 
max=1.0) * 255) 119 | .byte() 120 | .permute(1, 2, 0) 121 | .contiguous() 122 | .cpu() 123 | .numpy() 124 | ) 125 | network_gui.send(net_image_bytes, dataset.source_path) 126 | if do_training and ( 127 | (iteration < int(opt.iterations)) or not keep_alive 128 | ): 129 | break 130 | except Exception as e: 131 | network_gui.conn = None 132 | 133 | iter_start.record() 134 | 135 | gaussians.update_learning_rate(iteration) 136 | 137 | # Every 1000 its we increase the levels of SH up to a maximum degree 138 | if iteration % 1000 == 0: 139 | gaussians.oneupSHdegree() 140 | if iteration % 400 == 0: 141 | gaussians.scheduler.step() 142 | 143 | # Pick a random Camera 144 | if not viewpoint_stack: 145 | viewpoint_stack = scene.getTrainCameras().copy() 146 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) 147 | 148 | # Render 149 | if (iteration - 1) == debug_from: 150 | pipe.debug = True 151 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 152 | image, viewspace_point_tensor, visibility_filter, radii = ( 153 | render_pkg["render"], 154 | render_pkg["viewspace_points"], 155 | render_pkg["visibility_filter"], 156 | render_pkg["radii"], 157 | ) 158 | 159 | # Loss 160 | gt_image = viewpoint_cam.original_image.cuda() 161 | Ll1 = l1_loss(image, gt_image) 162 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( 163 | 1.0 - ssim(image, gt_image) 164 | ) 165 | 166 | loss.backward() 167 | 168 | iter_end.record() 169 | 170 | with torch.no_grad(): 171 | # Progress bar 172 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 173 | if iteration % 1000 == 0: 174 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 175 | progress_bar.update(1000) 176 | if iteration == opt.iterations: 177 | progress_bar.close() 178 | 179 | # Log and save 180 | 181 | if iteration in saving_iterations: 182 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 183 | scene.save(iteration) 184 | 185 | if iteration in checkpoint_iterations: 186 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 187 | if not os.path.exists(scene.model_path): 188 | os.makedirs(scene.model_path) 189 | torch.save( 190 | (gaussians.capture(), iteration), 191 | scene.model_path + "/chkpnt" + str(iteration) + ".pth", 192 | ) 193 | 194 | if iteration == checkpoint_iterations[-1]: 195 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background) 196 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow) 197 | np.savez(os.path.join(scene.model_path,"imp_score"), v_list.cpu().detach().numpy()) 198 | 199 | 200 | training_report( 201 | tb_writer, 202 | iteration, 203 | Ll1, 204 | loss, 205 | l1_loss, 206 | iter_start.elapsed_time(iter_end), 207 | testing_iterations, 208 | scene, 209 | render, 210 | (pipe, background), 211 | ) 212 | 213 | if iteration in args.prune_iterations: 214 | ic("Before prune iteration, number of gaussians: " + str(len(gaussians.get_xyz))) 215 | i = args.prune_iterations.index(iteration) 216 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background) 217 | 218 | if args.prune_type == "important_score": 219 | gaussians.prune_gaussians( 220 | (args.prune_decay**i) * args.prune_percent, imp_list 221 | ) 222 | elif args.prune_type == "v_important_score": 223 | # normalize scale 224 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow) 225 | gaussians.prune_gaussians( 226 | (args.prune_decay**i) * args.prune_percent, v_list 227 | ) 228 | elif args.prune_type == "max_v_important_score": 229 | v_list = imp_list * 
torch.max(gaussians.get_scaling, dim=1)[0] 230 | gaussians.prune_gaussians( 231 | (args.prune_decay**i) * args.prune_percent, v_list 232 | ) 233 | elif args.prune_type == "count": 234 | gaussians.prune_gaussians( 235 | (args.prune_decay**i) * args.prune_percent, gaussian_list 236 | ) 237 | elif args.prune_type == "opacity": 238 | gaussians.prune_gaussians( 239 | (args.prune_decay**i) * args.prune_percent, 240 | gaussians.get_opacity.detach(), 241 | ) 242 | # TODO(release different pruning method) 243 | # elif args.prune_type == "HDBSCAN": 244 | # masks = HDBSCAN_prune(gaussians, imp_list, (args.prune_decay**i)*args.prune_percent) 245 | # gaussians.prune_points(masks) 246 | # # elif args.prune_type == "v_important_score": 247 | # # imp_list * 248 | # elif args.prune_type == "two_step": 249 | # if i == 0: 250 | # volume = torch.prod(gaussians.get_scaling, dim = 1) 251 | # index = int(len(volume) * 0.9) 252 | # sorted_volume, sorted_indices = torch.sort(volume, descending=True, dim=0) 253 | # kth_percent_largest = sorted_volume[index] 254 | # v_list = torch.pow(volume/kth_percent_largest, args.v_pow) 255 | # v_list = v_list * imp_list 256 | # gaussians.prune_gaussians((args.prune_decay**i)*args.prune_percent, v_list) 257 | # else: 258 | # k = 5^(1*i) * 100 259 | # masks = uniform_prune(gaussians, k, imp_list, 0.3, "k_mean") 260 | # gaussians.prune_points(masks) 261 | # else: 262 | # k = len(gaussians.get_xyz)//500 * i 263 | # masks = uniform_prune(gaussians, k, imp_list, (args.prune_decay**i)*args.prune_percent, args.prune_type) 264 | # gaussians.prune_points(masks) 265 | # gaussians.prune_gaussians(args.prune_percent, imp_list) 266 | # gaussians.optimizer.zero_grad(set_to_none = True) #hachy way to maintain grad 267 | # if (iteration in args.opacity_prune_iterations): 268 | # gaussians.prune_opacity(0.05) 269 | else: 270 | raise Exception("Unsupportive pruning method") 271 | 272 | ic("After prune iteration, number of gaussians: " + str(len(gaussians.get_xyz))) 273 | 274 | # if iteration in args.densify_iteration: 275 | # gaussians.max_radii2D[visibility_filter] = torch.max( 276 | # gaussians.max_radii2D[visibility_filter], radii[visibility_filter] 277 | # ) 278 | # gaussians.add_densification_stats( 279 | # viewspace_point_tensor, visibility_filter 280 | # ) 281 | # gaussians.densify(opt.densify_grad_threshold, scene.cameras_extent) 282 | 283 | ic("after") 284 | ic(gaussians.get_xyz.shape) 285 | ic(len(gaussians.optimizer.param_groups[0]['params'][0])) 286 | 287 | if iteration < opt.iterations: 288 | gaussians.optimizer.step() 289 | gaussians.optimizer.zero_grad(set_to_none=True) 290 | 291 | 292 | if __name__ == "__main__": 293 | # Set up command line argument parser 294 | parser = ArgumentParser(description="Training script parameters") 295 | lp = ModelParams(parser) 296 | op = OptimizationParams(parser) 297 | pp = PipelineParams(parser) 298 | parser.add_argument("--ip", type=str, default="127.0.0.1") 299 | parser.add_argument("--port", type=int, default=6009) 300 | parser.add_argument("--debug_from", type=int, default=-1) 301 | parser.add_argument("--detect_anomaly", action="store_true", default=False) 302 | parser.add_argument( 303 | "--test_iterations", nargs="+", type=int, default=[30_001, 30_002, 35_000] 304 | ) 305 | parser.add_argument( 306 | "--save_iterations", nargs="+", type=int, default=[35_000] 307 | ) 308 | parser.add_argument("--quiet", action="store_true") 309 | parser.add_argument( 310 | "--checkpoint_iterations", nargs="+", type=int, default=[35_000] 311 | ) 312 | 313 | 
parser.add_argument("--prune_iterations", nargs="+", type=int, default=[30_001]) 314 | parser.add_argument("--start_checkpoint", type=str, default=None) 315 | parser.add_argument("--start_pointcloud", type=str, default=None) 316 | parser.add_argument("--prune_percent", type=float, default=0.1) 317 | parser.add_argument("--prune_decay", type=float, default=1) 318 | parser.add_argument( 319 | "--prune_type", type=str, default="important_score" 320 | ) # k_mean, farther_point_sample, important_score 321 | parser.add_argument("--v_pow", type=float, default=0.1) 322 | parser.add_argument("--densify_iteration", nargs="+", type=int, default=[-1]) 323 | args = parser.parse_args(sys.argv[1:]) 324 | args.save_iterations.append(args.iterations) 325 | 326 | print("Optimizing " + args.model_path) 327 | 328 | # Initialize system state (RNG) 329 | safe_state(args.quiet) 330 | 331 | # Start GUI server, configure and run training 332 | network_gui.init(args.ip, args.port) 333 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 334 | training( 335 | lp.extract(args), 336 | op.extract(args), 337 | pp.extract(args), 338 | args.test_iterations, 339 | args.save_iterations, 340 | args.checkpoint_iterations, 341 | args.start_checkpoint, 342 | args.debug_from, 343 | args, 344 | ) 345 | 346 | # All done 347 | print("\nTraining complete.") 348 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | 24 | 25 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 26 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 27 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 28 | 29 | makedirs(render_path, exist_ok=True) 30 | makedirs(gts_path, exist_ok=True) 31 | 32 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 33 | rendering = render(view, gaussians, pipeline, background)["render"] 34 | gt = view.original_image[0:3, :, :] 35 | torchvision.utils.save_image( 36 | rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png") 37 | ) 38 | torchvision.utils.save_image( 39 | gt, os.path.join(gts_path, "{0:05d}".format(idx) + ".png") 40 | ) 41 | 42 | 43 | def render_sets( 44 | dataset: ModelParams, 45 | iteration: int, 46 | pipeline: PipelineParams, 47 | skip_train: bool, 48 | skip_test: bool, 49 | load_vq: bool, 50 | ): 51 | with torch.no_grad(): 52 | gaussians = GaussianModel(dataset.sh_degree) 53 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False, load_vq= load_vq) 54 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 55 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 56 | 57 | if not skip_train: 58 
| render_set( 59 | dataset.model_path, 60 | "train", 61 | scene.loaded_iter, 62 | scene.getTrainCameras(), 63 | gaussians, 64 | pipeline, 65 | background, 66 | ) 67 | 68 | if not skip_test: 69 | render_set( 70 | dataset.model_path, 71 | "test", 72 | scene.loaded_iter, 73 | scene.getTestCameras(), 74 | gaussians, 75 | pipeline, 76 | background, 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | # Set up command line argument parser 82 | parser = ArgumentParser(description="Testing script parameters") 83 | model = ModelParams(parser, sentinel=True) 84 | pipeline = PipelineParams(parser) 85 | parser.add_argument("--iteration", default=-1, type=int) 86 | parser.add_argument("--skip_train", action="store_true") 87 | parser.add_argument("--skip_test", action="store_true") 88 | parser.add_argument("--load_vq", action="store_true") 89 | parser.add_argument("--quiet", action="store_true") 90 | args = get_combined_args(parser) 91 | print("Rendering " + args.model_path) 92 | 93 | # Initialize system state (RNG) 94 | safe_state(args.quiet) 95 | 96 | render_sets( 97 | model.extract(args), 98 | args.iteration, 99 | pipeline.extract(args), 100 | args.skip_train, 101 | args.skip_test, 102 | args.load_vq 103 | ) 104 | -------------------------------------------------------------------------------- /render_video.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | import numpy as np 17 | from os import makedirs 18 | from gaussian_renderer import render 19 | import torchvision 20 | from utils.general_utils import safe_state 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args 23 | from gaussian_renderer import GaussianModel 24 | from icecream import ic 25 | import copy 26 | 27 | from utils.graphics_utils import getWorld2View2 28 | from utils.pose_utils import generate_ellipse_path, generate_spherical_sample_path, generate_spiral_path, generate_spherify_path, gaussian_poses, circular_poses 29 | # import stepfun 30 | 31 | 32 | 33 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 34 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 35 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 36 | 37 | makedirs(render_path, exist_ok=True) 38 | makedirs(gts_path, exist_ok=True) 39 | 40 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 41 | rendering = render(view, gaussians, pipeline, background)["render"] 42 | gt = view.original_image[0:3, :, :] 43 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 44 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 45 | 46 | 47 | # def normalize(x): 48 | # return x / np.linalg.norm(x) 49 | 50 | # def viewmatrix(z, up, pos): 51 | # vec2 = normalize(z) 52 | # vec0 = normalize(np.cross(up, vec2)) 53 | # vec1 = normalize(np.cross(vec2, vec0)) 54 | # m = np.stack([vec0, vec1, vec2, pos], 1) 55 | # return m 56 | 57 | # def poses_avg(poses): 58 | # hwf = poses[0, :3, -1:] 
59 | # center = poses[:, :3, 3].mean(0) 60 | # vec2 = normalize(poses[:, :3, 2].sum(0)) 61 | # up = poses[:, :3, 1].sum(0) 62 | # c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 63 | # return c2w 64 | 65 | # def get_focal(camera): 66 | # focal = camera.FoVx 67 | # return focal 68 | 69 | # def poses_avg_fixed_center(poses): 70 | # hwf = poses[0, :3, -1:] 71 | # center = poses[:, :3, 3].mean(0) 72 | # vec2 = [1, 0, 0] 73 | # up = [0, 0, 1] 74 | # c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 75 | # return c2w 76 | 77 | # def focus_point_fn(poses): 78 | # """Calculate nearest point to all focal axes in poses.""" 79 | # directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 80 | # m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 81 | # mt_m = np.transpose(m, [0, 2, 1]) @ m 82 | # focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 83 | # return focus_pt 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | # xy circular 93 | def render_circular_video(model_path, iteration, views, gaussians, pipeline, background, radius=0.5, n_frames=240): 94 | render_path = os.path.join(model_path, 'circular', "ours_{}".format(iteration)) 95 | os.makedirs(render_path, exist_ok=True) 96 | makedirs(render_path, exist_ok=True) 97 | # view = views[0] 98 | for idx in range(n_frames): 99 | view = copy.deepcopy(views[13]) 100 | angle = 2 * np.pi * idx / n_frames 101 | cam = circular_poses(view, radius, angle) 102 | rendering = render(cam, gaussians, pipeline, background)["render"] 103 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 104 | 105 | 106 | 107 | def render_video(model_path, iteration, views, gaussians, pipeline, background): 108 | render_path = os.path.join(model_path, 'video', "ours_{}".format(iteration)) 109 | makedirs(render_path, exist_ok=True) 110 | view = views[0] 111 | # render_path_spiral 112 | # render_path_spherical 113 | for idx, pose in enumerate(tqdm(generate_ellipse_path(views,n_frames=600), desc="Rendering progress")): 114 | view.world_view_transform = torch.tensor(getWorld2View2(pose[:3, :3].T, pose[:3, 3], view.trans, view.scale)).transpose(0, 1).cuda() 115 | view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0) 116 | view.camera_center = view.world_view_transform.inverse()[3, :3] 117 | rendering = render(view, gaussians, pipeline, background)["render"] 118 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 119 | 120 | 121 | 122 | 123 | def gaussian_render(model_path, iteration, views, gaussians, pipeline, background, args): 124 | views = views[:10] #take the first 10 views and check gaussian view point 125 | render_path = os.path.join(model_path, 'video', "gaussians_{}_std{}".format(iteration, args.std)) 126 | makedirs(render_path, exist_ok=True) 127 | 128 | for i, view in enumerate(views): 129 | rendering = render(view, gaussians, pipeline, background)["render"] 130 | sub_path = os.path.join(render_path,"view_"+str(i)) 131 | makedirs(sub_path ,exist_ok=True) 132 | torchvision.utils.save_image(rendering, os.path.join(sub_path, "gt"+'{0:05d}'.format(i) + ".png")) 133 | for j in range(10): 134 | n_view = copy.deepcopy(view) 135 | g_view = gaussian_poses(n_view, args.mean, args.std) 136 | rendering = render(g_view, gaussians, pipeline, background)["render"] 137 | torchvision.utils.save_image(rendering, os.path.join(sub_path, '{0:05d}'.format(j) + ".png")) 138 | 139 | 
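Note: Camera (scene/cameras.py) builds world_view_transform, full_proj_transform and camera_center once in its constructor, which is why render_video above refreshes all three every time it swaps a new pose into the view. The same update can be read as a small standalone helper; this is an illustrative sketch only (update_camera_pose is not a function in this repository) and assumes a Camera-like object plus a 3x4 camera-to-world pose such as those returned by generate_ellipse_path.

import torch
from utils.graphics_utils import getWorld2View2

def update_camera_pose(view, pose):
    # world-to-view matrix, stored transposed exactly as scene/cameras.py does
    view.world_view_transform = (
        torch.tensor(getWorld2View2(pose[:3, :3].T, pose[:3, 3], view.trans, view.scale))
        .transpose(0, 1)
        .cuda()
    )
    # combined view-projection matrix in the same (transposed) layout
    view.full_proj_transform = (
        view.world_view_transform.unsqueeze(0)
        .bmm(view.projection_matrix.unsqueeze(0))
        .squeeze(0)
    )
    # camera position in world space: row 3 of the inverted view matrix
    view.camera_center = view.world_view_transform.inverse()[3, :3]
    return view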
140 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, video: bool, circular:bool, radius: float, args): 141 | with torch.no_grad(): 142 | gaussians = GaussianModel(dataset.sh_degree) 143 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False, load_vq= args.load_vq) 144 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 145 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 146 | 147 | if not skip_train: 148 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 149 | 150 | if not skip_test: 151 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 152 | if circular: 153 | render_circular_video(dataset.model_path, scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background,radius) 154 | # by default generate ellipse path, other options include spiral, circular, or other generate_xxx_path function from utils.pose_utils 155 | # Modify trajectory function in render_video's enumerate 156 | if video: 157 | render_video(dataset.model_path, scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 158 | #sample virtual view 159 | if args.gaussians: 160 | gaussian_render(dataset.model_path, scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, args) 161 | 162 | 163 | if __name__ == "__main__": 164 | # Set up command line argument parser 165 | parser = ArgumentParser(description="Testing script parameters") 166 | model = ModelParams(parser, sentinel=True) 167 | pipeline = PipelineParams(parser) 168 | parser.add_argument("--iteration", default=-1, type=int) 169 | parser.add_argument("--skip_train", action="store_true") 170 | parser.add_argument("--skip_test", action="store_true") 171 | parser.add_argument("--quiet", action="store_true") 172 | parser.add_argument("--video", action="store_true") 173 | parser.add_argument("--circular", action="store_true") 174 | parser.add_argument("--radius", default=5, type=float) 175 | parser.add_argument("--gaussians", action="store_true") 176 | parser.add_argument("--mean", default=0, type=float) 177 | parser.add_argument("--std", default=0.03, type=float) 178 | parser.add_argument("--load_vq", action="store_true") 179 | args = get_combined_args(parser) 180 | print("Rendering " + args.model_path) 181 | 182 | # Initialize system state (RNG) 183 | safe_state(args.quiet) 184 | 185 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.video, args.circular, args.radius, args) -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | 22 | class Scene: 23 | gaussians: GaussianModel 24 | # modified 25 | def __init__( 26 | self, 27 | args: ModelParams, 28 | gaussians: GaussianModel, 29 | load_iteration=None, 30 | shuffle=True, 31 | resolution_scales=[1.0], 32 | new_sh=0, 33 | load_vq=False 34 | ): 35 | """b 36 | :param path: Path to colmap scene main folder. 37 | """ 38 | self.model_path = args.model_path 39 | self.loaded_iter = None 40 | self.gaussians = gaussians 41 | 42 | if load_iteration: 43 | if load_iteration == -1: 44 | self.loaded_iter = searchForMaxIteration( 45 | os.path.join(self.model_path, "point_cloud") 46 | ) 47 | else: 48 | self.loaded_iter = load_iteration 49 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 50 | 51 | self.train_cameras = {} 52 | self.test_cameras = {} 53 | print(args.source_path) 54 | if os.path.exists(os.path.join(args.source_path, "sparse")): 55 | scene_info = sceneLoadTypeCallbacks["Colmap"]( 56 | args.source_path, args.images, args.eval 57 | ) 58 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 59 | print("Found transforms_train.json file, assuming Blender data set!") 60 | scene_info = sceneLoadTypeCallbacks["Blender"]( 61 | args.source_path, args.white_background, args.eval 62 | ) 63 | else: 64 | assert False, "Could not recognize scene type!" 65 | 66 | if not self.loaded_iter: 67 | with open(scene_info.ply_path, "rb") as src_file, open( 68 | os.path.join(self.model_path, "input.ply"), "wb" 69 | ) as dest_file: 70 | dest_file.write(src_file.read()) 71 | json_cams = [] 72 | camlist = [] 73 | if scene_info.test_cameras: 74 | camlist.extend(scene_info.test_cameras) 75 | if scene_info.train_cameras: 76 | camlist.extend(scene_info.train_cameras) 77 | for id, cam in enumerate(camlist): 78 | json_cams.append(camera_to_JSON(id, cam)) 79 | with open(os.path.join(self.model_path, "cameras.json"), "w") as file: 80 | json.dump(json_cams, file) 81 | 82 | if shuffle: 83 | random.shuffle( 84 | scene_info.train_cameras 85 | ) # Multi-res consistent random shuffling 86 | random.shuffle( 87 | scene_info.test_cameras 88 | ) # Multi-res consistent random shuffling 89 | 90 | self.cameras_extent = scene_info.nerf_normalization["radius"] 91 | 92 | for resolution_scale in resolution_scales: 93 | # temp comment out 94 | print("Loading Training Cameras") 95 | self.train_cameras[resolution_scale] = cameraList_from_camInfos( 96 | scene_info.train_cameras, resolution_scale, args 97 | ) 98 | print("Loading Test Cameras") 99 | self.test_cameras[resolution_scale] = cameraList_from_camInfos( 100 | scene_info.test_cameras, resolution_scale, args 101 | ) 102 | if load_vq: 103 | self.gaussians.load_vq(self.model_path) 104 | 105 | elif new_sh != 0 and self.loaded_iter: 106 | self.gaussians.load_ply_sh( 107 | os.path.join( 108 | self.model_path, 109 | "point_cloud", 110 | "iteration_" + str(self.loaded_iter), 111 | "point_cloud.ply", 112 | ), 113 | new_sh, 114 | ) 115 | elif self.loaded_iter: 116 | self.gaussians.load_ply( 117 | os.path.join( 118 | self.model_path, 119 | "point_cloud", 120 | "iteration_" + str(self.loaded_iter), 121 | 
"point_cloud.ply", 122 | ) 123 | ) 124 | else: 125 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 126 | 127 | def save(self, iteration): 128 | point_cloud_path = os.path.join( 129 | self.model_path, "point_cloud/iteration_{}".format(iteration) 130 | ) 131 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 132 | 133 | def getTrainCameras(self, scale=1.0): 134 | return self.train_cameras[scale] 135 | 136 | def getTestCameras(self, scale=1.0): 137 | return self.test_cameras[scale] 138 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | 18 | class Camera(nn.Module): 19 | def __init__( 20 | self, 21 | colmap_id, 22 | R, 23 | T, 24 | FoVx, 25 | FoVy, 26 | image, 27 | gt_alpha_mask, 28 | image_name, 29 | uid, 30 | trans=np.array([0.0, 0.0, 0.0]), 31 | scale=1.0, 32 | data_device="cuda", 33 | ): 34 | super(Camera, self).__init__() 35 | 36 | self.uid = uid 37 | self.colmap_id = colmap_id 38 | self.R = R 39 | self.T = T 40 | self.FoVx = FoVx 41 | self.FoVy = FoVy 42 | self.image_name = image_name 43 | 44 | try: 45 | self.data_device = torch.device(data_device) 46 | except Exception as e: 47 | print(e) 48 | print( 49 | f"[Warning] Custom device {data_device} failed, fallback to default cuda device" 50 | ) 51 | self.data_device = torch.device("cuda") 52 | 53 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 54 | self.image_width = self.original_image.shape[2] 55 | self.image_height = self.original_image.shape[1] 56 | 57 | if gt_alpha_mask is not None: 58 | self.original_image *= gt_alpha_mask.to(self.data_device) 59 | else: 60 | self.original_image *= torch.ones( 61 | (1, self.image_height, self.image_width), device=self.data_device 62 | ) 63 | 64 | self.zfar = 100.0 65 | self.znear = 0.01 66 | 67 | self.trans = trans 68 | self.scale = scale 69 | 70 | self.world_view_transform = ( 71 | torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 72 | ) 73 | self.projection_matrix = ( 74 | getProjectionMatrix( 75 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 76 | ) 77 | .transpose(0, 1) 78 | .cuda() 79 | ) 80 | self.full_proj_transform = ( 81 | self.world_view_transform.unsqueeze(0).bmm( 82 | self.projection_matrix.unsqueeze(0) 83 | ) 84 | ).squeeze(0) 85 | self.camera_center = self.world_view_transform.inverse()[3, :3] 86 | 87 | 88 | class MiniCam: 89 | def __init__( 90 | self, 91 | width, 92 | height, 93 | fovy, 94 | fovx, 95 | znear, 96 | zfar, 97 | world_view_transform, 98 | full_proj_transform, 99 | ): 100 | self.image_width = width 101 | self.image_height = height 102 | self.FoVy = fovy 103 | self.FoVx = fovx 104 | self.znear = znear 105 | self.zfar = zfar 106 | self.world_view_transform = world_view_transform 107 | self.full_proj_transform = full_proj_transform 108 | view_inv = torch.inverse(self.world_view_transform) 109 | self.camera_center = view_inv[3][:3] 110 | 
-------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"] 18 | ) 19 | Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] 22 | ) 23 | Point3D = collections.namedtuple( 24 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] 25 | ) 26 | CAMERA_MODELS = { 27 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 28 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 29 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 30 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 31 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 32 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 33 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 34 | CameraModel(model_id=7, model_name="FOV", num_params=5), 35 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 36 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 37 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 38 | } 39 | CAMERA_MODEL_IDS = dict( 40 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] 41 | ) 42 | CAMERA_MODEL_NAMES = dict( 43 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] 44 | ) 45 | 46 | 47 | def qvec2rotmat(qvec): 48 | return np.array( 49 | [ 50 | [ 51 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 52 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 53 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 54 | ], 55 | [ 56 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 57 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 58 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 59 | ], 60 | [ 61 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 62 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 63 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 64 | ], 65 | ] 66 | ) 67 | 68 | 69 | def rotmat2qvec(R): 70 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 71 | K = ( 72 | np.array( 73 | [ 74 | [Rxx - Ryy - Rzz, 0, 0, 0], 75 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 76 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 77 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], 78 | ] 79 | ) 80 | / 3.0 81 | ) 82 | eigvals, eigvecs = np.linalg.eigh(K) 83 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 84 | if qvec[0] < 0: 85 | qvec *= -1 86 | return qvec 87 | 88 | 89 | class Image(BaseImage): 90 | def qvec2rotmat(self): 91 | return qvec2rotmat(self.qvec) 92 | 93 | 94 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 95 | """Read and unpack the next bytes from a binary file. 96 | :param fid: 97 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 
98 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 99 | :param endian_character: Any of {@, =, <, >, !} 100 | :return: Tuple of read and unpacked values. 101 | """ 102 | data = fid.read(num_bytes) 103 | return struct.unpack(endian_character + format_char_sequence, data) 104 | 105 | 106 | def read_points3D_text(path): 107 | """ 108 | see: src/base/reconstruction.cc 109 | void Reconstruction::ReadPoints3DText(const std::string& path) 110 | void Reconstruction::WritePoints3DText(const std::string& path) 111 | """ 112 | xyzs = None 113 | rgbs = None 114 | errors = None 115 | num_points = 0 116 | with open(path, "r") as fid: 117 | while True: 118 | line = fid.readline() 119 | if not line: 120 | break 121 | line = line.strip() 122 | if len(line) > 0 and line[0] != "#": 123 | num_points += 1 124 | 125 | xyzs = np.empty((num_points, 3)) 126 | rgbs = np.empty((num_points, 3)) 127 | errors = np.empty((num_points, 1)) 128 | count = 0 129 | with open(path, "r") as fid: 130 | while True: 131 | line = fid.readline() 132 | if not line: 133 | break 134 | line = line.strip() 135 | if len(line) > 0 and line[0] != "#": 136 | elems = line.split() 137 | xyz = np.array(tuple(map(float, elems[1:4]))) 138 | rgb = np.array(tuple(map(int, elems[4:7]))) 139 | error = np.array(float(elems[7])) 140 | xyzs[count] = xyz 141 | rgbs[count] = rgb 142 | errors[count] = error 143 | count += 1 144 | 145 | return xyzs, rgbs, errors 146 | 147 | 148 | def read_points3D_binary(path_to_model_file): 149 | """ 150 | see: src/base/reconstruction.cc 151 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 152 | void Reconstruction::WritePoints3DBinary(const std::string& path) 153 | """ 154 | 155 | with open(path_to_model_file, "rb") as fid: 156 | num_points = read_next_bytes(fid, 8, "Q")[0] 157 | 158 | xyzs = np.empty((num_points, 3)) 159 | rgbs = np.empty((num_points, 3)) 160 | errors = np.empty((num_points, 1)) 161 | 162 | for p_id in range(num_points): 163 | binary_point_line_properties = read_next_bytes( 164 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 165 | ) 166 | xyz = np.array(binary_point_line_properties[1:4]) 167 | rgb = np.array(binary_point_line_properties[4:7]) 168 | error = np.array(binary_point_line_properties[7]) 169 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 170 | 0 171 | ] 172 | track_elems = read_next_bytes( 173 | fid, 174 | num_bytes=8 * track_length, 175 | format_char_sequence="ii" * track_length, 176 | ) 177 | xyzs[p_id] = xyz 178 | rgbs[p_id] = rgb 179 | errors[p_id] = error 180 | return xyzs, rgbs, errors 181 | 182 | 183 | def read_intrinsics_text(path): 184 | """ 185 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 186 | """ 187 | cameras = {} 188 | with open(path, "r") as fid: 189 | while True: 190 | line = fid.readline() 191 | if not line: 192 | break 193 | line = line.strip() 194 | if len(line) > 0 and line[0] != "#": 195 | elems = line.split() 196 | camera_id = int(elems[0]) 197 | model = elems[1] 198 | assert ( 199 | model == "PINHOLE" 200 | ), "While the loader support other types, the rest of the code assumes PINHOLE" 201 | width = int(elems[2]) 202 | height = int(elems[3]) 203 | params = np.array(tuple(map(float, elems[4:]))) 204 | cameras[camera_id] = Camera( 205 | id=camera_id, model=model, width=width, height=height, params=params 206 | ) 207 | return cameras 208 | 209 | 210 | def read_extrinsics_binary(path_to_model_file): 211 | """ 212 | see: src/base/reconstruction.cc 213 
| void Reconstruction::ReadImagesBinary(const std::string& path) 214 | void Reconstruction::WriteImagesBinary(const std::string& path) 215 | """ 216 | images = {} 217 | with open(path_to_model_file, "rb") as fid: 218 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 219 | for _ in range(num_reg_images): 220 | binary_image_properties = read_next_bytes( 221 | fid, num_bytes=64, format_char_sequence="idddddddi" 222 | ) 223 | image_id = binary_image_properties[0] 224 | qvec = np.array(binary_image_properties[1:5]) 225 | tvec = np.array(binary_image_properties[5:8]) 226 | camera_id = binary_image_properties[8] 227 | image_name = "" 228 | current_char = read_next_bytes(fid, 1, "c")[0] 229 | while current_char != b"\x00": # look for the ASCII 0 entry 230 | image_name += current_char.decode("utf-8") 231 | current_char = read_next_bytes(fid, 1, "c")[0] 232 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 233 | 0 234 | ] 235 | x_y_id_s = read_next_bytes( 236 | fid, 237 | num_bytes=24 * num_points2D, 238 | format_char_sequence="ddq" * num_points2D, 239 | ) 240 | xys = np.column_stack( 241 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] 242 | ) 243 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 244 | images[image_id] = Image( 245 | id=image_id, 246 | qvec=qvec, 247 | tvec=tvec, 248 | camera_id=camera_id, 249 | name=image_name, 250 | xys=xys, 251 | point3D_ids=point3D_ids, 252 | ) 253 | return images 254 | 255 | 256 | def read_intrinsics_binary(path_to_model_file): 257 | """ 258 | see: src/base/reconstruction.cc 259 | void Reconstruction::WriteCamerasBinary(const std::string& path) 260 | void Reconstruction::ReadCamerasBinary(const std::string& path) 261 | """ 262 | cameras = {} 263 | with open(path_to_model_file, "rb") as fid: 264 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 265 | for _ in range(num_cameras): 266 | camera_properties = read_next_bytes( 267 | fid, num_bytes=24, format_char_sequence="iiQQ" 268 | ) 269 | camera_id = camera_properties[0] 270 | model_id = camera_properties[1] 271 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 272 | width = camera_properties[2] 273 | height = camera_properties[3] 274 | num_params = CAMERA_MODEL_IDS[model_id].num_params 275 | params = read_next_bytes( 276 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params 277 | ) 278 | cameras[camera_id] = Camera( 279 | id=camera_id, 280 | model=model_name, 281 | width=width, 282 | height=height, 283 | params=np.array(params), 284 | ) 285 | assert len(cameras) == num_cameras 286 | return cameras 287 | 288 | 289 | def read_extrinsics_text(path): 290 | """ 291 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 292 | """ 293 | images = {} 294 | with open(path, "r") as fid: 295 | while True: 296 | line = fid.readline() 297 | if not line: 298 | break 299 | line = line.strip() 300 | if len(line) > 0 and line[0] != "#": 301 | elems = line.split() 302 | image_id = int(elems[0]) 303 | qvec = np.array(tuple(map(float, elems[1:5]))) 304 | tvec = np.array(tuple(map(float, elems[5:8]))) 305 | camera_id = int(elems[8]) 306 | image_name = elems[9] 307 | elems = fid.readline().split() 308 | xys = np.column_stack( 309 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] 310 | ) 311 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 312 | images[image_id] = Image( 313 | id=image_id, 314 | qvec=qvec, 315 | tvec=tvec, 316 | camera_id=camera_id, 317 | name=image_name, 318 | xys=xys, 
319 | point3D_ids=point3D_ids, 320 | ) 321 | return images 322 | 323 | 324 | def read_colmap_bin_array(path): 325 | """ 326 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 327 | 328 | :param path: path to the colmap binary file. 329 | :return: nd array with the floating point values in the value 330 | """ 331 | with open(path, "rb") as fid: 332 | width, height, channels = np.genfromtxt( 333 | fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int 334 | ) 335 | fid.seek(0) 336 | num_delimiter = 0 337 | byte = fid.read(1) 338 | while True: 339 | if byte == b"&": 340 | num_delimiter += 1 341 | if num_delimiter >= 3: 342 | break 343 | byte = fid.read(1) 344 | array = np.fromfile(fid, np.float32) 345 | array = array.reshape((width, height, channels), order="F") 346 | return np.transpose(array, (1, 0, 2)).squeeze() 347 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import ( 17 | read_extrinsics_text, 18 | read_intrinsics_text, 19 | qvec2rotmat, 20 | read_extrinsics_binary, 21 | read_intrinsics_binary, 22 | read_points3D_binary, 23 | read_points3D_text, 24 | ) 25 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 26 | import numpy as np 27 | import json 28 | from pathlib import Path 29 | from plyfile import PlyData, PlyElement 30 | from utils.sh_utils import SH2RGB 31 | from scene.gaussian_model import BasicPointCloud 32 | 33 | 34 | class CameraInfo(NamedTuple): 35 | uid: int 36 | R: np.array 37 | T: np.array 38 | FovY: np.array 39 | FovX: np.array 40 | image: np.array 41 | image_path: str 42 | image_name: str 43 | width: int 44 | height: int 45 | 46 | 47 | class SceneInfo(NamedTuple): 48 | point_cloud: BasicPointCloud 49 | train_cameras: list 50 | test_cameras: list 51 | nerf_normalization: dict 52 | ply_path: str 53 | 54 | 55 | def getNerfppNorm(cam_info): 56 | def get_center_and_diag(cam_centers): 57 | cam_centers = np.hstack(cam_centers) 58 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 59 | center = avg_cam_center 60 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 61 | diagonal = np.max(dist) 62 | return center.flatten(), diagonal 63 | 64 | cam_centers = [] 65 | 66 | for cam in cam_info: 67 | W2C = getWorld2View2(cam.R, cam.T) 68 | C2W = np.linalg.inv(W2C) 69 | cam_centers.append(C2W[:3, 3:4]) 70 | 71 | center, diagonal = get_center_and_diag(cam_centers) 72 | radius = diagonal * 1.1 73 | 74 | translate = -center 75 | 76 | return {"translate": translate, "radius": radius} 77 | 78 | 79 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 80 | cam_infos = [] 81 | for idx, key in enumerate(cam_extrinsics): 82 | sys.stdout.write("\r") 83 | # the exact output you're looking for: 84 | sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics))) 85 | sys.stdout.flush() 86 | 87 | extr = cam_extrinsics[key] 88 | intr = cam_intrinsics[extr.camera_id] 89 | height = intr.height 90 | 
width = intr.width 91 | 92 | uid = intr.id 93 | R = np.transpose(qvec2rotmat(extr.qvec)) 94 | T = np.array(extr.tvec) 95 | 96 | if intr.model == "SIMPLE_PINHOLE": 97 | focal_length_x = intr.params[0] 98 | FovY = focal2fov(focal_length_x, height) 99 | FovX = focal2fov(focal_length_x, width) 100 | elif intr.model == "PINHOLE": 101 | focal_length_x = intr.params[0] 102 | focal_length_y = intr.params[1] 103 | FovY = focal2fov(focal_length_y, height) 104 | FovX = focal2fov(focal_length_x, width) 105 | else: 106 | assert ( 107 | False 108 | ), "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 109 | 110 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 111 | image_name = os.path.basename(image_path).split(".")[0] 112 | image = Image.open(image_path) 113 | 114 | cam_info = CameraInfo( 115 | uid=uid, 116 | R=R, 117 | T=T, 118 | FovY=FovY, 119 | FovX=FovX, 120 | image=image, 121 | image_path=image_path, 122 | image_name=image_name, 123 | width=width, 124 | height=height, 125 | ) 126 | cam_infos.append(cam_info) 127 | sys.stdout.write("\n") 128 | return cam_infos 129 | 130 | 131 | def fetchPly(path): 132 | plydata = PlyData.read(path) 133 | vertices = plydata["vertex"] 134 | positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T 135 | colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0 136 | normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T 137 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 138 | 139 | 140 | def storePly(path, xyz, rgb): 141 | # Define the dtype for the structured array 142 | dtype = [ 143 | ("x", "f4"), 144 | ("y", "f4"), 145 | ("z", "f4"), 146 | ("nx", "f4"), 147 | ("ny", "f4"), 148 | ("nz", "f4"), 149 | ("red", "u1"), 150 | ("green", "u1"), 151 | ("blue", "u1"), 152 | ] 153 | 154 | normals = np.zeros_like(xyz) 155 | 156 | elements = np.empty(xyz.shape[0], dtype=dtype) 157 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 158 | elements[:] = list(map(tuple, attributes)) 159 | 160 | # Create the PlyData object and write to file 161 | vertex_element = PlyElement.describe(elements, "vertex") 162 | ply_data = PlyData([vertex_element]) 163 | ply_data.write(path) 164 | 165 | 166 | def readColmapSceneInfo(path, images, eval, llffhold=8): 167 | try: 168 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 169 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 170 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 171 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 172 | except: 173 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 174 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 175 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 176 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 177 | 178 | reading_dir = "images" if images == None else images 179 | cam_infos_unsorted = readColmapCameras( 180 | cam_extrinsics=cam_extrinsics, 181 | cam_intrinsics=cam_intrinsics, 182 | images_folder=os.path.join(path, reading_dir), 183 | ) 184 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) 185 | 186 | if eval: 187 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 188 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 189 | else: 190 | train_cam_infos = cam_infos 191 | 
test_cam_infos = [] 192 | 193 | nerf_normalization = getNerfppNorm(train_cam_infos) 194 | 195 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 196 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 197 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 198 | if not os.path.exists(ply_path): 199 | print( 200 | "Converting point3d.bin to .ply, will happen only the first time you open the scene." 201 | ) 202 | try: 203 | xyz, rgb, _ = read_points3D_binary(bin_path) 204 | except: 205 | xyz, rgb, _ = read_points3D_text(txt_path) 206 | storePly(ply_path, xyz, rgb) 207 | try: 208 | pcd = fetchPly(ply_path) 209 | except: 210 | pcd = None 211 | 212 | scene_info = SceneInfo( 213 | point_cloud=pcd, 214 | train_cameras=train_cam_infos, 215 | test_cameras=test_cam_infos, 216 | nerf_normalization=nerf_normalization, 217 | ply_path=ply_path, 218 | ) 219 | return scene_info 220 | 221 | 222 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 223 | cam_infos = [] 224 | 225 | with open(os.path.join(path, transformsfile)) as json_file: 226 | contents = json.load(json_file) 227 | fovx = contents["camera_angle_x"] 228 | 229 | frames = contents["frames"] 230 | for idx, frame in enumerate(frames): 231 | cam_name = os.path.join(path, frame["file_path"] + extension) 232 | 233 | # NeRF 'transform_matrix' is a camera-to-world transform 234 | c2w = np.array(frame["transform_matrix"]) 235 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 236 | c2w[:3, 1:3] *= -1 237 | 238 | # get the world-to-camera transform and set R, T 239 | w2c = np.linalg.inv(c2w) 240 | R = np.transpose( 241 | w2c[:3, :3] 242 | ) # R is stored transposed due to 'glm' in CUDA code 243 | T = w2c[:3, 3] 244 | 245 | image_path = os.path.join(path, cam_name) 246 | image_name = Path(cam_name).stem 247 | image = Image.open(image_path) 248 | 249 | im_data = np.array(image.convert("RGBA")) 250 | 251 | bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) 252 | 253 | norm_data = im_data / 255.0 254 | arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * ( 255 | 1 - norm_data[:, :, 3:4] 256 | ) 257 | image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") 258 | 259 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 260 | FovY = fovy 261 | FovX = fovx 262 | 263 | cam_infos.append( 264 | CameraInfo( 265 | uid=idx, 266 | R=R, 267 | T=T, 268 | FovY=FovY, 269 | FovX=FovX, 270 | image=image, 271 | image_path=image_path, 272 | image_name=image_name, 273 | width=image.size[0], 274 | height=image.size[1], 275 | ) 276 | ) 277 | 278 | return cam_infos 279 | 280 | 281 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 282 | print("Reading Training Transforms") 283 | train_cam_infos = readCamerasFromTransforms( 284 | path, "transforms_train.json", white_background, extension 285 | ) 286 | print("Reading Test Transforms") 287 | test_cam_infos = readCamerasFromTransforms( 288 | path, "transforms_test.json", white_background, extension 289 | ) 290 | 291 | if not eval: 292 | train_cam_infos.extend(test_cam_infos) 293 | test_cam_infos = [] 294 | 295 | nerf_normalization = getNerfppNorm(train_cam_infos) 296 | 297 | ply_path = os.path.join(path, "points3d.ply") 298 | if not os.path.exists(ply_path): 299 | # Since this data set has no colmap data, we start with random points 300 | num_pts = 100_000 301 | print(f"Generating random point cloud ({num_pts})...") 302 | 303 | # We create random points inside the bounds 
of the synthetic Blender scenes 304 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 305 | shs = np.random.random((num_pts, 3)) / 255.0 306 | pcd = BasicPointCloud( 307 | points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)) 308 | ) 309 | 310 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 311 | try: 312 | pcd = fetchPly(ply_path) 313 | except: 314 | pcd = None 315 | 316 | scene_info = SceneInfo( 317 | point_cloud=pcd, 318 | train_cameras=train_cam_infos, 319 | test_cameras=test_cam_infos, 320 | nerf_normalization=nerf_normalization, 321 | ply_path=ply_path, 322 | ) 323 | return scene_info 324 | 325 | 326 | sceneLoadTypeCallbacks = { 327 | "Colmap": readColmapSceneInfo, 328 | "Blender": readNerfSyntheticInfo, 329 | } 330 | -------------------------------------------------------------------------------- /scripts/run_distill_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to get the id of an available GPU 4 | get_available_gpu() { 5 | local mem_threshold=500 6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # Initial port number 12 | port=6025 13 | 14 | # Datasets 15 | declare -a run_args=( 16 | "bicycle" 17 | "bonsai" 18 | "counter" 19 | "kitchen" 20 | "room" 21 | "stump" 22 | "garden" 23 | "train" 24 | "truck" 25 | ) 26 | 27 | 28 | # activate psudo view, else using train view for distillation 29 | declare -a virtue_view_arg=( 30 | "--augmented_view" 31 | ) 32 | # compress_gaussian/output5_prune_final_result/bicycle_v_important_score_oneshot_prune_densify0.67_vpow0.1_try3_decay1 33 | # compress_gaussian/output2 34 | for arg in "${run_args[@]}"; do 35 | for view in "${virtue_view_arg[@]}"; do 36 | # Wait for an available GPU 37 | while true; do 38 | gpu_id=$(get_available_gpu) 39 | if [[ -n $gpu_id ]]; then 40 | echo "GPU $gpu_id is available. Starting distill_train.py with dataset '$arg' and options '$view' on port $port" 41 | CUDA_VISIBLE_DEVICES=$gpu_id nohup python distill_train.py \ 42 | -s "PATH/TO/DATASET/$arg" \ 43 | -m "OUTPUT/PATH/${arg}_${prune_percent}" \ 44 | --start_checkpoint "PATH/TO/CHECKPOINT/$arg/chkpnt30000.pth" \ 45 | --iteration 40000 \ 46 | --eval \ 47 | --teacher_model "PATH/TO/TEACHER_CHECKPOINT/${arg}/chkpnt30000.pth" \ 48 | --new_max_sh 2 \ 49 | --position_lr_max_steps 40000 \ 50 | --enable_covariance \ 51 | $view \ 52 | --port $port > "logs/distill_${arg}${view}.log" 2>&1 & 53 | 54 | # Increment the port number for the next run 55 | ((port++)) 56 | # Allow some time for the process to initialize and potentially use GPU memory 57 | sleep 60 58 | break 59 | else 60 | echo "No GPU available at the moment. Retrying in 1 minute." 61 | sleep 60 62 | fi 63 | done 64 | done 65 | done 66 | wait 67 | echo "All distill_train.py runs completed." 
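
Note: the distillation run above loads a fully trained 30k-iteration checkpoint as the teacher and fine-tunes a student whose maximum spherical-harmonics degree is lowered to 2 (--new_max_sh 2), optionally adding augmented pseudo views. The snippet below is only a hedged illustration of the SH-truncation idea, not the actual distill_train.py logic; it assumes the student's higher-order colour coefficients are seeded by slicing the teacher's, using the standard 3DGS layout of (D+1)**2 - 1 "rest" coefficients per colour channel.

import torch

def truncate_sh_rest(teacher_rest: torch.Tensor, new_max_sh: int) -> torch.Tensor:
    # teacher_rest: (N, (D_teacher + 1)**2 - 1, 3) higher-order SH coefficients.
    # Keeping only the first (new_max_sh + 1)**2 - 1 of them lowers the degree
    # while preserving the low-frequency colour content.
    keep = (new_max_sh + 1) ** 2 - 1          # degree 2 -> 8 coefficients
    return teacher_rest[:, :keep, :].contiguous()

teacher_rest = torch.randn(10_000, 15, 3)     # degree-3 teacher: 16 - 1 = 15
student_rest = truncate_sh_rest(teacher_rest, new_max_sh=2)
assert student_rest.shape == (10_000, 8, 3)

In the actual script the truncated student is then fine-tuned for a further 10k iterations (30k to 40k) with the teacher checkpoint passed via --teacher_model as supervision.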
68 | -------------------------------------------------------------------------------- /scripts/run_prune_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to get the id of an available GPU 4 | get_available_gpu() { 5 | local mem_threshold=10000 6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # Initial port number 12 | port=6041 13 | 14 | # Only one dataset specified here, but you could run multiple 15 | declare -a run_args=( 16 | "bicycle" 17 | # "bonsai" 18 | # "counter" 19 | # "kitchen" 20 | # "room" 21 | # "stump" 22 | # "garden" 23 | # "train" 24 | # "truck" 25 | # "chair" 26 | # "drums" 27 | # "ficus" 28 | # "hotdog" 29 | # "lego" 30 | # "mic" 31 | # "materials" 32 | # "ship" 33 | ) 34 | 35 | 36 | # Prune percentages and corresponding decays, volume power 37 | declare -a prune_percents=(0.66) 38 | # decay rate for the following prune. The 2nd prune would prune out 0.5 x 0.6 = 0.3 of the remaining gaussian 39 | declare -a prune_decays=(1) 40 | # The volumetric importance power. The higher it is the more weight the volume is in the Global significant 41 | declare -a v_pow=(0.1) 42 | 43 | # prune type, by default the Global significant listed in the paper, but there are other option that you can play with 44 | declare -a prune_types=( 45 | "v_important_score" 46 | # "important_score" 47 | # "count" 48 | ) 49 | 50 | 51 | # Check that prune_percents, prune_decays, and v_pow arrays have the same length 52 | if [ "${#prune_percents[@]}" -ne "${#prune_decays[@]}" ] || [ "${#prune_percents[@]}" -ne "${#v_pow[@]}" ]; then 53 | echo "The lengths of prune_percents, prune_decays, and v_pow arrays do not match." 54 | exit 1 55 | fi 56 | 57 | # Loop over the arguments array 58 | for arg in "${run_args[@]}"; do 59 | for i in "${!prune_percents[@]}"; do 60 | prune_percent="${prune_percents[i]}" 61 | prune_decay="${prune_decays[i]}" 62 | vp="${v_pow[i]}" 63 | 64 | for prune_type in "${prune_types[@]}"; do 65 | # Wait for an available GPU 66 | while true; do 67 | gpu_id=$(get_available_gpu) 68 | if [[ -n $gpu_id ]]; then 69 | echo "GPU $gpu_id is available. Starting prune_finetune.py with dataset '$arg', prune_percent '$prune_percent', prune_type '$prune_type', prune_decay '$prune_decay', and v_pow '$vp' on port $port" 70 | 71 | CUDA_VISIBLE_DEVICES=$gpu_id nohup python prune_finetune.py \ 72 | -s "PATH/TO/DATASET/$arg" \ 73 | -m "OUTPUT/PATH/${arg}_${prune_percent}" \ 74 | --eval \ 75 | --port $port \ 76 | --start_checkpoint "PATH/TO/CHECKPOINT/$arg/chkpnt30000.pth" \ 77 | --iteration 35000 \ 78 | --prune_percent $prune_percent \ 79 | --prune_type $prune_type \ 80 | --prune_decay $prune_decay \ 81 | --position_lr_max_steps 35000 \ 82 | --v_pow $vp > "logs_prune/${arg}${prune_percent}prunned.log" 2>&1 & 83 | 84 | # Increment the port number for the next run 85 | ((port++)) 86 | # Allow some time for the process to initialize and potentially use GPU memory 87 | sleep 60 88 | break 89 | else 90 | echo "No GPU available at the moment. Retrying in 1 minute." 91 | sleep 60 92 | fi 93 | done 94 | done 95 | done 96 | done 97 | wait 98 | echo "All prune_finetune.py runs completed." 
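
Note: this script ranks Gaussians by the significance measure chosen with --prune_type (v_important_score by default) and removes the lowest-ranked prune_percent fraction. The exact formula lives in prune.calculate_v_imp_score; the sketch below is only a hypothetical illustration of the idea, with hit_count and volume standing in for whatever per-Gaussian statistics prune_list actually accumulates.

import torch

def global_significance(hit_count: torch.Tensor, volume: torch.Tensor,
                        v_pow: float = 0.1) -> torch.Tensor:
    # hit_count: accumulated importance of each Gaussian over the training views.
    # volume:    per-Gaussian 3D volume; raising it to a small power v_pow
    #            (0.1 in the scripts above) keeps huge Gaussians from dominating.
    return hit_count * volume.pow(v_pow)

def keep_mask(score: torch.Tensor, prune_percent: float) -> torch.Tensor:
    # Keep the (1 - prune_percent) highest-scoring Gaussians.
    k = int(prune_percent * score.numel())
    if k == 0:
        return torch.ones_like(score, dtype=torch.bool)
    threshold = torch.kthvalue(score, k).values
    return score > threshold

score = global_significance(torch.rand(1_000), torch.rand(1_000), v_pow=0.1)
print(keep_mask(score, prune_percent=0.66).sum().item(), "Gaussians kept")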
99 | -------------------------------------------------------------------------------- /scripts/run_prune_pt_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to get the id of an available GPU 4 | get_available_gpu() { 5 | local mem_threshold=10000 6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # Initial port number 12 | port=6045 13 | # This is an example script to load from ply file. 14 | # Only one dataset specified here, but you could run multiple 15 | declare -a run_args=( 16 | "bicycle" 17 | # "bonsai" 18 | # "counter" 19 | # "kitchen" 20 | # "room" 21 | # "stump" 22 | # "garden" 23 | # "train" 24 | # "truck" 25 | # "chair" 26 | # "drums" 27 | # "ficus" 28 | # "hotdog" 29 | # "lego" 30 | # "mic" 31 | # "materials" 32 | # "ship" 33 | ) 34 | 35 | 36 | # Prune percentages and corresponding decays, volume power 37 | declare -a prune_percents=(0.66) 38 | # decay rate for the following prune 39 | declare -a prune_decays=(1) 40 | # The volumetric importance power. The higher it is the more weight the volume is in the Global significant 41 | declare -a v_pow=(0.1) 42 | 43 | # prune type, by default the Global significant listed in the paper, but there are other option that you can play with 44 | declare -a prune_types=( 45 | "v_important_score" 46 | # "important_score" 47 | # "count" 48 | ) 49 | 50 | 51 | # Check that prune_percents, prune_decays, and v_pow arrays have the same length 52 | if [ "${#prune_percents[@]}" -ne "${#prune_decays[@]}" ] || [ "${#prune_percents[@]}" -ne "${#v_pow[@]}" ]; then 53 | echo "The lengths of prune_percents, prune_decays, and v_pow arrays do not match." 54 | exit 1 55 | fi 56 | # /ssd1/zhiwen/projects/compress_gaussian/output2/bicycle/point_cloud/iteration_30000/point_cloud.ply 57 | # Loop over the arguments array 58 | for arg in "${run_args[@]}"; do 59 | for i in "${!prune_percents[@]}"; do 60 | prune_percent="${prune_percents[i]}" 61 | prune_decay="${prune_decays[i]}" 62 | vp="${v_pow[i]}" 63 | 64 | for prune_type in "${prune_types[@]}"; do 65 | # Wait for an available GPU 66 | while true; do 67 | gpu_id=$(get_available_gpu) 68 | if [[ -n $gpu_id ]]; then 69 | echo "GPU $gpu_id is available. Starting prune_finetune.py with dataset '$arg', prune_percent '$prune_percent', prune_type '$prune_type', prune_decay '$prune_decay', and v_pow '$vp' on port $port" 70 | 71 | CUDA_VISIBLE_DEVICES=$gpu_id python prune_finetune.py \ 72 | -s "PATH/TO/DATASET/$arg" \ 73 | -m "OUTPUT/PATH/${arg}_${prune_percent}" \ 74 | --eval \ 75 | --port $port \ 76 | --start_pointcloud "PATH/TO/CHECKPOINT/$arg/point_cloud/iteration_30000/point_cloud.ply" \ 77 | --iteration 5000 \ 78 | --test_iterations 5000 \ 79 | --save_iterations 5000 \ 80 | --prune_iterations 2 \ 81 | --prune_percent $prune_percent \ 82 | --prune_type $prune_type \ 83 | --prune_decay $prune_decay \ 84 | --position_lr_init 0.000005 \ 85 | --position_lr_max_steps 5000 \ 86 | --v_pow $vp > "logs_prune/${arg}${prune_percent}_ply_prune2.log" 2>&1 & 87 | 88 | # Increment the port number for the next run 89 | ((port++)) 90 | # Allow some time for the process to initialize and potentially use GPU memory 91 | sleep 60 92 | break 93 | else 94 | echo "No GPU available at the moment. Retrying in 1 minute." 
95 | sleep 60 96 | fi 97 | done 98 | done 99 | done 100 | done 101 | wait 102 | echo "All prune_finetune.py runs completed." 103 | -------------------------------------------------------------------------------- /scripts/run_train_densify_prune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to get the id of an available GPU 4 | get_available_gpu() { 5 | local mem_threshold=5000 6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \ 7 | awk -v threshold="$mem_threshold" -F', ' ' 8 | $2 < threshold { print $1; exit } 9 | ' 10 | } 11 | 12 | port=6035 13 | 14 | # Only one dataset specified here 15 | declare -a run_args=( 16 | "bicycle" 17 | # "bonsai" 18 | # "counter" 19 | # "kitchen" 20 | # "room" 21 | # "stump" 22 | # "garden" 23 | # "train" 24 | # "truck" 25 | ) 26 | 27 | # prune percentage for the first prune 28 | declare -a prune_percents=(0.6) 29 | 30 | # decay rate for the following prune 31 | declare -a prune_decays=(0.6) 32 | 33 | # The volumetric importance power 34 | declare -a v_pow=(0.1) 35 | 36 | # Prune types 37 | declare -a prune_types=( 38 | "v_important_score" 39 | ) 40 | 41 | # Check that prune_percents and prune_decays arrays have the same length 42 | if [ "${#prune_percents[@]}" -ne "${#prune_decays[@]}" ]; then 43 | echo "The number of prune_percents does not match the number of prune_decays." 44 | exit 1 45 | fi 46 | 47 | # Loop over datasets 48 | for arg in "${run_args[@]}"; do 49 | # Loop over each index in prune_percents/decays/v_pow 50 | for i in "${!prune_percents[@]}"; do 51 | prune_percent="${prune_percents[i]}" 52 | prune_decay="${prune_decays[i]}" 53 | vp="${v_pow[i]}" 54 | 55 | # Loop over each prune type 56 | for prune_type in "${prune_types[@]}"; do 57 | 58 | # Wait for an available GPU 59 | while true; do 60 | gpu_id=$(get_available_gpu) 61 | if [[ -n $gpu_id ]]; then 62 | echo "GPU $gpu_id is available. Starting train_densify_prune.py with dataset '$arg', prune_percent '$prune_percent', prune_type '$prune_type', prune_decay '$prune_decay', and v_pow '$vp' on port $port" 63 | 64 | CUDA_VISIBLE_DEVICES=$gpu_id nohup python train_densify_prune.py \ 65 | -s "PATH/TO/DATASET/$arg" \ 66 | -m "OUTPUT/PATH/${arg}" \ 67 | --prune_percent "$prune_percent" \ 68 | --prune_decay "$prune_decay" \ 69 | --prune_iterations 20000 \ 70 | --v_pow "$vp" \ 71 | --eval \ 72 | --port "$port" \ 73 | > "logs/train_${arg}.log" 2>&1 & 74 | 75 | # you need to create the log folder first if it doesn't exist 76 | ((port++)) 77 | 78 | # Give the process time to start using GPU memory 79 | sleep 60 80 | break 81 | else 82 | echo "No GPU available at the moment. Retrying in 1 minute." 83 | sleep 60 84 | fi 85 | done 86 | 87 | done # end for prune_type 88 | done # end for i 89 | done # end for arg 90 | 91 | # Wait for all background processes to finish 92 | wait 93 | echo "All train_densify_prune.py runs completed." 
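
Note: prune_percent, prune_decay and v_pow are passed straight into the training code, where the i-th prune round removes (prune_decay ** i) * prune_percent of the Gaussians that are still alive (this mirrors the gaussians.prune_gaussians((args.prune_decay ** i) * args.prune_percent, v_list) call in train_densify_prune.py). A small worked example of how the decay softens later rounds, using the two-round default schedule from train_densify_prune.py:

# Worked example of the decayed prune schedule.
prune_percent, prune_decay = 0.6, 0.6
prune_iterations = [16_000, 24_000]   # argparse default in train_densify_prune.py

remaining = 1.0
for i, it in enumerate(prune_iterations):
    ratio = (prune_decay ** i) * prune_percent
    remaining *= 1.0 - ratio
    print(f"iter {it}: prune {ratio:.0%} of surviving Gaussians -> {remaining:.1%} of the original remain")

# Round 1 removes 60% (40% remain); round 2 removes 0.6 * 0.6 = 36% of the
# survivors, leaving about 25.6% of the original Gaussians.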
94 | -------------------------------------------------------------------------------- /scripts/run_vectree_quantize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SCENES=(bicycle bonsai counter garden kitchen room stump train truck) 4 | SCENES=(room) 5 | VQ_RATIO=0.6 6 | CODEBOOK_SIZE=8192 7 | 8 | for SCENE in "${SCENES[@]}" # Add more scenes as needed 9 | do 10 | IMP_PATH=./vectree/pruned_distilled/${SCENE} 11 | INPUT_PLY_PATH=./vectree/pruned_distilled/${SCENE}/iteration_40000/point_cloud.ply 12 | SAVE_PATH=./vectree/output/${SCENE} 13 | 14 | CMD="CUDA_VISIBLE_DEVICES=0 python vectree/vectree.py \ 15 | --important_score_npz_path ${IMP_PATH} \ 16 | --input_path ${INPUT_PLY_PATH} \ 17 | --save_path ${SAVE_PATH} \ 18 | --vq_ratio ${VQ_RATIO} \ 19 | --codebook_size ${CODEBOOK_SIZE} \ 20 | " 21 | eval $CMD 22 | done -------------------------------------------------------------------------------- /static/table5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/LightGaussian/6676b983e77baadd909effc56a6aaadafa964dcc/static/table5.png -------------------------------------------------------------------------------- /submodules/simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /submodules/simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == "nt": 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=["spatial.cu", "simple_knn.cu", "ext.cpp"], 27 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}, 28 | ) 29 | ], 30 | cmdclass={"build_ext": BuildExtension}, 31 | ) 32 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #define __CUDACC__ 24 | #include 25 | #include 26 | 27 | namespace cg = cooperative_groups; 28 | 29 | struct CustomMin 30 | { 31 | __device__ __forceinline__ 32 | float3 operator()(const float3& a, const float3& b) const { 33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 34 | } 35 | }; 36 | 37 | struct CustomMax 38 | { 39 | __device__ __forceinline__ 40 | float3 operator()(const float3& a, const float3& b) const { 41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 42 | } 43 | }; 44 | 45 | __host__ __device__ uint32_t prepMorton(uint32_t x) 46 | { 47 | x = (x | (x << 16)) & 0x030000FF; 48 | x = (x | (x << 8)) & 0x0300F00F; 49 | x = (x | (x << 4)) & 0x030C30C3; 50 | x = (x | (x << 2)) & 0x09249249; 51 | return x; 52 | } 53 | 54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 55 | { 56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 59 | 60 | return x | (y << 1) | (z << 2); 61 | } 62 | 63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 64 | { 65 | auto idx = cg::this_grid().thread_rank(); 66 | if (idx >= P) 67 | return; 68 | 69 | codes[idx] = coord2Morton(points[idx], minn, maxx); 70 | } 71 | 72 | struct MinMax 73 | { 74 | float3 minn; 75 | float3 maxx; 76 | }; 77 | 78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | 82 | MinMax me; 83 | if (idx < P) 84 | { 85 | me.minn = points[indices[idx]]; 86 | me.maxx = points[indices[idx]]; 87 | } 88 | else 89 | { 90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 92 | } 93 | 94 | __shared__ MinMax redResult[BOX_SIZE]; 95 | 96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 97 | { 98 | if (threadIdx.x < 2 * off) 99 | redResult[threadIdx.x] = me; 100 | __syncthreads(); 101 | 102 | if (threadIdx.x < off) 103 | { 104 | MinMax other = redResult[threadIdx.x + off]; 105 | me.minn.x = min(me.minn.x, other.minn.x); 106 | me.minn.y = min(me.minn.y, other.minn.y); 107 | me.minn.z = min(me.minn.z, other.minn.z); 108 | me.maxx.x = max(me.maxx.x, other.maxx.x); 109 | me.maxx.y = max(me.maxx.y, other.maxx.y); 110 | me.maxx.z = max(me.maxx.z, other.maxx.z); 111 | } 112 | __syncthreads(); 113 | } 114 | 115 | if (threadIdx.x == 0) 116 | boxes[blockIdx.x] = me; 117 | } 118 | 119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 120 | { 121 | float3 diff = { 0, 0, 0 }; 122 | if (p.x < box.minn.x || p.x > box.maxx.x) 123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 124 | if (p.y < box.minn.y || p.y > box.maxx.y) 125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 126 | if (p.z < box.minn.z || p.z > box.maxx.z) 127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 129 | } 130 | 131 | template 132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 133 
| { 134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 136 | for (int j = 0; j < K; j++) 137 | { 138 | if (knn[j] > dist) 139 | { 140 | float t = knn[j]; 141 | knn[j] = dist; 142 | dist = t; 143 | } 144 | } 145 | } 146 | 147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 148 | { 149 | int idx = cg::this_grid().thread_rank(); 150 | if (idx >= P) 151 | return; 152 | 153 | float3 point = points[indices[idx]]; 154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 155 | 156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 157 | { 158 | if (i == idx) 159 | continue; 160 | updateKBest<3>(point, points[indices[i]], best); 161 | } 162 | 163 | float reject = best[2]; 164 | best[0] = FLT_MAX; 165 | best[1] = FLT_MAX; 166 | best[2] = FLT_MAX; 167 | 168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 169 | { 170 | MinMax box = boxes[b]; 171 | float dist = distBoxPoint(box, point); 172 | if (dist > reject || dist > best[2]) 173 | continue; 174 | 175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 176 | { 177 | if (i == idx) 178 | continue; 179 | updateKBest<3>(point, points[indices[i]], best); 180 | } 181 | } 182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 183 | } 184 | 185 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 186 | { 187 | float3* result; 188 | cudaMalloc(&result, sizeof(float3)); 189 | size_t temp_storage_bytes; 190 | 191 | float3 init = { 0, 0, 0 }, minn, maxx; 192 | 193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 194 | thrust::device_vector temp_storage(temp_storage_bytes); 195 | 196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 198 | 199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 201 | 202 | thrust::device_vector morton(P); 203 | thrust::device_vector morton_sorted(P); 204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 205 | 206 | thrust::device_vector indices(P); 207 | thrust::sequence(indices.begin(), indices.end()); 208 | thrust::device_vector indices_sorted(P); 209 | 210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 211 | temp_storage.resize(temp_storage_bytes); 212 | 213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 214 | 215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 216 | thrust::device_vector boxes(num_boxes); 217 | boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); 218 | boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 219 | 220 | cudaFree(result); 221 | } -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: simple-knn 3 | Version: 0.0.0 4 | 
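
Note: simple_knn.cu sorts the points along a Morton curve, reduces a min/max box per 1024 points, and for every point scans only the boxes that could still hold one of its three nearest neighbours; the value written to dists is the mean of the three smallest squared distances (updateKBest never takes a square root). The exact CPU reference below can serve as a sanity check against distCUDA2 on small inputs; since the CUDA kernel seeds its rejection radius from the three Morton-adjacent points, the two can differ marginally.

import torch

def mean_sq_dist_to_3nn(points: torch.Tensor) -> torch.Tensor:
    # Mean of the squared distances from each point to its 3 nearest neighbours,
    # i.e. what SimpleKNN::knn / distCUDA2 computes, evaluated exactly on CPU.
    d2 = torch.cdist(points, points).pow(2)              # (P, P) squared distances
    d2.fill_diagonal_(float("inf"))                       # a point is not its own neighbour
    nearest3 = d2.topk(3, dim=1, largest=False).values    # three smallest per row
    return nearest3.mean(dim=1)

pts = torch.rand(2_000, 3)
print(mean_sq_dist_to_3nn(pts)[:5])

In 3DGS pipelines this mean squared neighbour distance is typically used to seed each Gaussian's initial scale from the local point spacing.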
-------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | ext.cpp 2 | setup.py 3 | simple_knn.cu 4 | spatial.cu 5 | simple_knn.egg-info/PKG-INFO 6 | simple_knn.egg-info/SOURCES.txt 7 | simple_knn.egg-info/dependency_links.txt 8 | simple_knn.egg-info/top_level.txt -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | simple_knn 2 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/LightGaussian/6676b983e77baadd909effc56a6aaadafa964dcc/submodules/simple-knn/simple_knn/.gitkeep -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data(), means.contiguous().data()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /train_densify_prune.py: -------------------------------------------------------------------------------- 1 | # 2 | # This software is free for non-commercial, research and evaluation use 3 | # under the terms of the LICENSE.md file. 4 | # 5 | # For inquiries contact george.drettakis@inria.fr 6 | # 7 | import os 8 | import torch 9 | from random import randint 10 | from utils.loss_utils import l1_loss, ssim 11 | from gaussian_renderer import render, network_gui 12 | import sys 13 | from lpipsPyTorch import lpips 14 | 15 | from scene import Scene, GaussianModel 16 | from utils.general_utils import safe_state 17 | from utils.logger_utils import training_report, prepare_output_and_logger 18 | 19 | import uuid 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser, Namespace 23 | from arguments import ModelParams, PipelineParams, OptimizationParams 24 | 25 | # from prune_train import prepare_output_and_logger, training_report 26 | from icecream import ic 27 | from os import makedirs 28 | from prune import prune_list, calculate_v_imp_score 29 | import torchvision 30 | from torch.optim.lr_scheduler import ExponentialLR 31 | import csv 32 | import numpy as np 33 | 34 | 35 | try: 36 | from torch.utils.tensorboard import SummaryWriter 37 | 38 | TENSORBOARD_FOUND = True 39 | except ImportError: 40 | TENSORBOARD_FOUND = False 41 | 42 | 43 | def training( 44 | dataset, 45 | opt, 46 | pipe, 47 | testing_iterations, 48 | saving_iterations, 49 | checkpoint_iterations, 50 | checkpoint, 51 | debug_from, 52 | args, 53 | ): 54 | first_iter = 0 55 | tb_writer = prepare_output_and_logger(dataset) 56 | gaussians = GaussianModel(dataset.sh_degree) 57 | scene = Scene(dataset, gaussians) 58 | gaussians.training_setup(opt) 59 | if checkpoint: 60 | (model_params, first_iter) = torch.load(checkpoint) 61 | gaussians.restore(model_params, opt) 62 | 63 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 64 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 65 | 66 | iter_start = torch.cuda.Event(enable_timing=True) 67 | iter_end = torch.cuda.Event(enable_timing=True) 68 | 69 | viewpoint_stack = None 70 | ema_loss_for_log = 0.0 71 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 72 | first_iter += 1 73 | gaussians.scheduler = ExponentialLR(gaussians.optimizer, gamma=0.97) 74 | for iteration in range(first_iter, opt.iterations + 1): 75 | if network_gui.conn == None: 76 | network_gui.try_connect() 77 | while network_gui.conn != None: 78 | try: 79 | net_image_bytes = None 80 | ( 81 | custom_cam, 82 | do_training, 83 | pipe.convert_SHs_python, 84 | pipe.compute_cov3D_python, 85 | keep_alive, 86 | scaling_modifer, 87 | ) = network_gui.receive() 88 | if custom_cam != None: 89 | net_image = render( 90 | custom_cam, gaussians, pipe, background, scaling_modifer 91 | )["render"] 92 | net_image_bytes = memoryview( 93 | (torch.clamp(net_image, min=0, max=1.0) * 255) 94 | .byte() 95 | .permute(1, 2, 0) 96 | .contiguous() 97 | .cpu() 98 | .numpy() 99 | ) 100 | network_gui.send(net_image_bytes, dataset.source_path) 101 | if do_training and ( 102 | (iteration < int(opt.iterations)) or not keep_alive 103 | ): 104 | break 105 | except Exception as e: 106 | network_gui.conn = None 107 | 108 
| iter_start.record() 109 | 110 | gaussians.update_learning_rate(iteration) 111 | 112 | # Every 1000 its we increase the levels of SH up to a maximum degree 113 | if iteration % 1000 == 0: 114 | gaussians.oneupSHdegree() 115 | gaussians.scheduler.step() 116 | 117 | # Pick a random Camera 118 | if not viewpoint_stack: 119 | viewpoint_stack = scene.getTrainCameras().copy() 120 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) 121 | 122 | # Render 123 | if (iteration - 1) == debug_from: 124 | pipe.debug = True 125 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 126 | image, viewspace_point_tensor, visibility_filter, radii = ( 127 | render_pkg["render"], 128 | render_pkg["viewspace_points"], 129 | render_pkg["visibility_filter"], 130 | render_pkg["radii"], 131 | ) 132 | 133 | # Loss 134 | gt_image = viewpoint_cam.original_image.cuda() 135 | Ll1 = l1_loss(image, gt_image) 136 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( 137 | 1.0 - ssim(image, gt_image) 138 | ) 139 | loss.backward() 140 | 141 | iter_end.record() 142 | 143 | with torch.no_grad(): 144 | # Progress bar 145 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 146 | if iteration % 10 == 0: 147 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 148 | progress_bar.update(10) 149 | if iteration == opt.iterations: 150 | progress_bar.close() 151 | 152 | # Log and save 153 | if iteration in saving_iterations: 154 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 155 | scene.save(iteration) 156 | training_report( 157 | tb_writer, 158 | iteration, 159 | Ll1, 160 | loss, 161 | l1_loss, 162 | iter_start.elapsed_time(iter_end), 163 | testing_iterations, 164 | scene, 165 | render, 166 | (pipe, background), 167 | ) 168 | 169 | # Densification 170 | if iteration < opt.densify_until_iter: 171 | # Keep track of max radii in image-space for pruning 172 | gaussians.max_radii2D[visibility_filter] = torch.max( 173 | gaussians.max_radii2D[visibility_filter], radii[visibility_filter] 174 | ) 175 | gaussians.add_densification_stats( 176 | viewspace_point_tensor, visibility_filter 177 | ) 178 | 179 | if ( 180 | iteration > opt.densify_from_iter 181 | and iteration % opt.densification_interval == 0 182 | ): 183 | size_threshold = ( 184 | 20 if iteration > opt.opacity_reset_interval else None 185 | ) 186 | gaussians.densify_and_prune( 187 | opt.densify_grad_threshold, 188 | 0.005, 189 | scene.cameras_extent, 190 | size_threshold, 191 | ) 192 | 193 | if iteration % opt.opacity_reset_interval == 0 or ( 194 | dataset.white_background and iteration == opt.densify_from_iter 195 | ): 196 | gaussians.reset_opacity() 197 | 198 | if iteration in args.prune_iterations: 199 | # TODO Add prunning types 200 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background) 201 | i = args.prune_iterations.index(iteration) 202 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow) 203 | gaussians.prune_gaussians( 204 | (args.prune_decay**i) * args.prune_percent, v_list 205 | ) 206 | 207 | 208 | 209 | # Optimizer step 210 | if iteration < opt.iterations: 211 | gaussians.optimizer.step() 212 | gaussians.optimizer.zero_grad(set_to_none=True) 213 | 214 | if iteration in checkpoint_iterations: 215 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 216 | if not os.path.exists(scene.model_path): 217 | os.makedirs(scene.model_path) 218 | torch.save( 219 | (gaussians.capture(), iteration), 220 | scene.model_path + "/chkpnt" + str(iteration) + ".pth", 221 | ) 
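                # At the last checkpoint iteration, the block below additionally gathers
                # per-Gaussian importance statistics (prune_list), converts them to the
                # global significance score with calculate_v_imp_score(.., args.v_pow)
                # and writes the result to <model_path>/imp_score.npz;
                # run_vectree_quantize.sh later hands this file to vectree via
                # --important_score_npz_path.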
222 | if iteration == checkpoint_iterations[-1]: 223 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background) 224 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow) 225 | np.savez(os.path.join(scene.model_path,"imp_score"), v_list.cpu().detach().numpy()) 226 | 227 | 228 | if __name__ == "__main__": 229 | # Set up command line argument parser 230 | parser = ArgumentParser(description="Training script parameters") 231 | lp = ModelParams(parser) 232 | op = OptimizationParams(parser) 233 | pp = PipelineParams(parser) 234 | parser.add_argument("--ip", type=str, default="127.0.0.1") 235 | parser.add_argument("--port", type=int, default=6009) 236 | parser.add_argument("--debug_from", type=int, default=-1) 237 | parser.add_argument("--detect_anomaly", action="store_true", default=False) 238 | parser.add_argument( 239 | "--test_iterations", 240 | nargs="+", 241 | type=int, 242 | default=[7_000, 30_000], 243 | ) 244 | parser.add_argument( 245 | "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] 246 | ) 247 | parser.add_argument("--quiet", action="store_true") 248 | parser.add_argument( 249 | "--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000] 250 | ) 251 | parser.add_argument("--start_checkpoint", type=str, default=None) 252 | 253 | parser.add_argument( 254 | "--prune_iterations", nargs="+", type=int, default=[16_000, 24_000] 255 | ) 256 | parser.add_argument("--prune_percent", type=float, default=0.5) 257 | parser.add_argument("--v_pow", type=float, default=0.1) 258 | parser.add_argument("--prune_decay", type=float, default=0.8) 259 | args = parser.parse_args(sys.argv[1:]) 260 | args.save_iterations.append(args.iterations) 261 | 262 | print("Optimizing " + args.model_path) 263 | # Initialize system state (RNG) 264 | safe_state(args.quiet) 265 | # Start GUI server, configure and run training 266 | network_gui.init(args.ip, args.port) 267 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 268 | training( 269 | lp.extract(args), 270 | op.extract(args), 271 | pp.extract(args), 272 | args.test_iterations, 273 | args.save_iterations, 274 | args.checkpoint_iterations, 275 | args.start_checkpoint, 276 | args.debug_from, 277 | args, 278 | ) 279 | 280 | # All done 281 | print("\nTraining complete.") 282 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | 20 | def loadCam(args, id, cam_info, resolution_scale): 21 | orig_w, orig_h = cam_info.image.size 22 | 23 | if args.resolution in [1, 2, 4, 8]: 24 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 25 | orig_h / (resolution_scale * args.resolution) 26 | ) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 1600: 30 | global WARNED 31 | if not WARNED: 32 | print( 33 | "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 34 | "If this is not desired, please explicitly specify '--resolution/-r' as 1" 35 | ) 36 | WARNED = True 37 | global_down = orig_w / 1600 38 | else: 39 | global_down = 1 40 | else: 41 | global_down = orig_w / args.resolution 42 | 43 | scale = float(global_down) * float(resolution_scale) 44 | resolution = (int(orig_w / scale), int(orig_h / scale)) 45 | 46 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 47 | 48 | gt_image = resized_image_rgb[:3, ...] 49 | loaded_mask = None 50 | 51 | if resized_image_rgb.shape[1] == 4: 52 | loaded_mask = resized_image_rgb[3:4, ...] 53 | 54 | return Camera( 55 | colmap_id=cam_info.uid, 56 | R=cam_info.R, 57 | T=cam_info.T, 58 | FoVx=cam_info.FovX, 59 | FoVy=cam_info.FovY, 60 | image=gt_image, 61 | gt_alpha_mask=loaded_mask, 62 | image_name=cam_info.image_name, 63 | uid=id, 64 | data_device=args.data_device, 65 | ) 66 | 67 | 68 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 69 | camera_list = [] 70 | 71 | for id, c in enumerate(cam_infos): 72 | camera_list.append(loadCam(args, id, c, resolution_scale)) 73 | 74 | return camera_list 75 | 76 | 77 | def camera_to_JSON(id, camera: Camera): 78 | Rt = np.zeros((4, 4)) 79 | Rt[:3, :3] = camera.R.transpose() 80 | Rt[:3, 3] = camera.T 81 | Rt[3, 3] = 1.0 82 | 83 | W2C = np.linalg.inv(Rt) 84 | pos = W2C[:3, 3] 85 | rot = W2C[:3, :3] 86 | serializable_array_2d = [x.tolist() for x in rot] 87 | camera_entry = { 88 | "id": id, 89 | "img_name": camera.image_name, 90 | "width": camera.width, 91 | "height": camera.height, 92 | "position": pos.tolist(), 93 | "rotation": serializable_array_2d, 94 | "fy": fov2focal(camera.FovY, camera.height), 95 | "fx": fov2focal(camera.FovX, camera.width), 96 | } 97 | return camera_entry 98 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | 32 | def get_expon_lr_func( 33 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 34 | ): 35 | """ 36 | Copied from Plenoxels 37 | 38 | Continuous learning rate decay function. Adapted from JaxNeRF 39 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 40 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 41 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 42 | function of lr_delay_mult, such that the initial learning rate is 43 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 44 | to the normal learning rate when steps>lr_delay_steps. 45 | :param conf: config subtree 'lr' or similar 46 | :param max_steps: int, the number of steps during optimization. 47 | :return HoF which takes step as input 48 | """ 49 | 50 | def helper(step): 51 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 52 | # Disable this parameter 53 | return 0.0 54 | if lr_delay_steps > 0: 55 | # A kind of reverse cosine decay. 56 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 57 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 58 | ) 59 | else: 60 | delay_rate = 1.0 61 | t = np.clip(step / max_steps, 0, 1) 62 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 63 | return delay_rate * log_lerp 64 | 65 | return helper 66 | 67 | 68 | def strip_lowerdiag(L): 69 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 70 | 71 | uncertainty[:, 0] = L[:, 0, 0] 72 | uncertainty[:, 1] = L[:, 0, 1] 73 | uncertainty[:, 2] = L[:, 0, 2] 74 | uncertainty[:, 3] = L[:, 1, 1] 75 | uncertainty[:, 4] = L[:, 1, 2] 76 | uncertainty[:, 5] = L[:, 2, 2] 77 | return uncertainty 78 | 79 | 80 | def strip_symmetric(sym): 81 | return strip_lowerdiag(sym) 82 | 83 | 84 | def build_rotation(r): 85 | norm = torch.sqrt( 86 | r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] 87 | ) 88 | 89 | q = r / norm[:, None] 90 | 91 | R = torch.zeros((q.size(0), 3, 3), device="cuda") 92 | 93 | r = q[:, 0] 94 | x = q[:, 1] 95 | y = q[:, 2] 96 | z = q[:, 3] 97 | 98 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 99 | R[:, 0, 1] = 2 * (x * y - r * z) 100 | R[:, 0, 2] = 2 * (x * z + r * y) 101 | R[:, 1, 0] = 2 * (x * y + r * z) 102 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 103 | R[:, 1, 2] = 2 * (y * z - r * x) 104 | R[:, 2, 0] = 2 * (x * z - r * y) 105 | R[:, 2, 1] = 2 * (y * z + r * x) 106 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 107 | return R 108 | 109 | 110 | def build_scaling_rotation(s, r): 111 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 112 | R = build_rotation(r) 113 | 114 | L[:, 0, 0] = s[:, 0] 115 | L[:, 1, 1] = s[:, 1] 116 | L[:, 2, 2] = s[:, 2] 117 | 118 | L = R @ L 119 | return L 120 | 121 | 122 | def safe_state(silent): 123 | old_f = sys.stdout 124 | 125 | class F: 126 | def __init__(self, silent): 127 | self.silent = silent 128 
| 129 | def write(self, x): 130 | if not self.silent: 131 | if x.endswith("\n"): 132 | old_f.write( 133 | x.replace( 134 | "\n", 135 | " [{}]\n".format( 136 | str(datetime.now().strftime("%d/%m %H:%M:%S")) 137 | ), 138 | ) 139 | ) 140 | else: 141 | old_f.write(x) 142 | 143 | def flush(self): 144 | old_f.flush() 145 | 146 | sys.stdout = F(silent) 147 | 148 | random.seed(0) 149 | np.random.seed(0) 150 | torch.manual_seed(0) 151 | torch.cuda.set_device(torch.device("cuda:0")) 152 | 153 | 154 | class CircularTensor: 155 | def __init__(self, max_size): 156 | self.buffer = torch.empty(max_size) 157 | self.max_size = max_size 158 | self.current_pos = 0 159 | self.current_size = 0 # Tracks the number of elements added 160 | 161 | def add(self, element): 162 | self.buffer[self.current_pos] = element 163 | self.current_pos = (self.current_pos + 1) % self.max_size 164 | if self.current_size < self.max_size: 165 | self.current_size += 1 166 | 167 | def get(self, index): 168 | if index >= self.current_size: 169 | raise IndexError("Index out of bounds") 170 | return self.buffer[index] 171 | 172 | def size(self): 173 | return self.current_size 174 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def 
focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, random, time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import imageio 10 | from pdb import set_trace as st 11 | 12 | 13 | mse2psnr = ( 14 | lambda x: -10.0 * torch.log(x) / torch.log(torch.tensor([10.0], device=x.device)) 15 | ) 16 | 17 | 18 | def img2mse(x, y, mask=None): 19 | if mask is None: 20 | return torch.mean((x - y) ** 2) 21 | else: 22 | return torch.sum((x * mask - y * mask) ** 2) / (torch.sum(mask) + 1e-5) 23 | 24 | 25 | def img2mae(x, y, mask=None): 26 | if mask is None: 27 | return torch.mean(torch.abs(x - y)) 28 | else: 29 | return torch.sum(torch.abs(x * mask - y * mask)) / (torch.sum(mask) + 1e-5) 30 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | 15 | def mse(img1, img2): 16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 17 | 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | -------------------------------------------------------------------------------- /utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from random import randint 4 | from utils.loss_utils import l1_loss, ssim 5 | from lpipsPyTorch import lpips 6 | from scene import Scene, GaussianModel 7 | from utils.general_utils import safe_state 8 | import uuid 9 | from utils.image_utils import psnr 10 | from argparse import Namespace 11 | from icecream import ic 12 | import csv 13 | 14 | try: 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | TENSORBOARD_FOUND = True 18 | except ImportError: 19 | TENSORBOARD_FOUND = False 20 | 21 | 22 | def prepare_output_and_logger(args): 23 | if not args.model_path: 24 | if os.getenv("OAR_JOB_ID"): 25 | unique_str = os.getenv("OAR_JOB_ID") 26 | else: 27 | unique_str = str(uuid.uuid4()) 28 | args.model_path = os.path.join("./output/", unique_str[0:10]) 29 | 30 | # Set up output folder 31 | print("Output folder: {}".format(args.model_path)) 32 | os.makedirs(args.model_path, exist_ok=True) 33 | with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f: 34 | cfg_log_f.write(str(Namespace(**vars(args)))) 35 | 36 | # Create Tensorboard writer 37 | tb_writer = None 38 | if TENSORBOARD_FOUND: 39 | tb_writer = SummaryWriter(args.model_path) 40 | else: 41 | print("Tensorboard not available: not logging progress") 42 | return tb_writer 43 | 44 | 45 | def training_report( 46 | tb_writer, 47 | iteration, 48 | Ll1, 49 | loss, 50 | l1_loss, 51 | elapsed, 52 | testing_iterations, 53 | scene: Scene, 54 | renderFunc, 55 | renderArgs, 56 | ): 57 | if tb_writer: 58 | 
tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration) 59 | tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration) 60 | tb_writer.add_scalar("iter_time", elapsed, iteration) 61 | 62 | # Report test and samples of training set 63 | if iteration in testing_iterations: 64 | ic("report") 65 | headers = [ 66 | "iteration", 67 | "set", 68 | "l1_loss", 69 | "psnr", 70 | "ssim", 71 | "lpips", 72 | "file_size", 73 | "elapsed", 74 | ] 75 | csv_path = os.path.join(scene.model_path, "metric.csv") 76 | # Check if the CSV file exists, if not, create it and write the header 77 | file_exists = os.path.isfile(csv_path) 78 | save_path = os.path.join( 79 | scene.model_path, 80 | "point_cloud/iteration_" + str(iteration), 81 | "point_cloud.ply", 82 | ) 83 | # Check if the file exists 84 | if os.path.exists(save_path): 85 | # Get the size of the file 86 | file_size = os.path.getsize(save_path) 87 | file_size_mb = file_size / 1024 / 1024 # Convert bytes to kilobytes 88 | else: 89 | file_size_mb = None 90 | 91 | with open(csv_path, "a", newline="") as csvfile: 92 | writer = csv.DictWriter(csvfile, fieldnames=headers) 93 | if not file_exists: 94 | writer.writeheader() # file doesn't exist yet, write a header 95 | 96 | torch.cuda.empty_cache() 97 | validation_configs = ({"name": "test", "cameras": scene.getTestCameras()},) 98 | # {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 99 | 100 | for config in validation_configs: 101 | if config["cameras"] and len(config["cameras"]) > 0: 102 | l1_test = 0.0 103 | psnr_test = 0.0 104 | ssim_test = 0.0 105 | lpips_test = 0.0 106 | for idx, viewpoint in enumerate(config["cameras"]): 107 | image = torch.clamp( 108 | renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 109 | 0.0, 110 | 1.0, 111 | ) 112 | gt_image = torch.clamp( 113 | viewpoint.original_image.to("cuda"), 0.0, 1.0 114 | ) 115 | if tb_writer and (idx < 5): 116 | tb_writer.add_images( 117 | config["name"] 118 | + "_view_{}/render".format(viewpoint.image_name), 119 | image[None], 120 | global_step=iteration, 121 | ) 122 | if iteration == testing_iterations[0]: 123 | tb_writer.add_images( 124 | config["name"] 125 | + "_view_{}/ground_truth".format(viewpoint.image_name), 126 | gt_image[None], 127 | global_step=iteration, 128 | ) 129 | l1_test += l1_loss(image, gt_image).mean().double() 130 | psnr_test += psnr(image, gt_image).mean().double() 131 | ssim_test += ssim(image, gt_image).mean().double() 132 | lpips_test += lpips(image, gt_image, net_type="vgg").mean().double() 133 | 134 | psnr_test /= len(config["cameras"]) 135 | l1_test /= len(config["cameras"]) 136 | ssim_test /= len(config["cameras"]) 137 | lpips_test /= len(config["cameras"]) 138 | # sys.stderr.write(f"Iteration {iteration} Evaluating {config['name']}: L1 {l1_test} PSNR {psnr_test} SSIM {ssim_test} LPIPS {lpips_test}\n") 139 | # sys.stderr.flush() 140 | print( 141 | "\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format( 142 | iteration, 143 | config["name"], 144 | l1_test, 145 | psnr_test, 146 | ssim_test, 147 | lpips_test, 148 | ) 149 | ) 150 | if tb_writer: 151 | tb_writer.add_scalar( 152 | config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration 153 | ) 154 | tb_writer.add_scalar( 155 | config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration 156 | ) 157 | tb_writer.add_scalar( 158 | config["name"] + "/loss_viewpoint - ssim", ssim_test, iteration 159 | ) 160 | tb_writer.add_scalar( 161 | 
config["name"] + "/loss_viewpoint - lpips", 162 | lpips_test, 163 | iteration, 164 | ) 165 | if config["name"] == "test": 166 | with open(csv_path, "a", newline="") as csvfile: 167 | writer = csv.DictWriter(csvfile, fieldnames=headers) 168 | writer.writerow( 169 | { 170 | "iteration": iteration, 171 | "set": config["name"], 172 | "l1_loss": l1_test.item(), 173 | "psnr": psnr_test.item(), 174 | "ssim": ssim_test.item(), 175 | "lpips": lpips_test.item(), 176 | "file_size": file_size_mb, 177 | "elapsed": elapsed, 178 | } 179 | ) 180 | 181 | if tb_writer: 182 | tb_writer.add_histogram( 183 | "scene/opacity_histogram", scene.gaussians.get_opacity, iteration 184 | ) 185 | tb_writer.add_scalar( 186 | "total_points", scene.gaussians.get_xyz.shape[0], iteration 187 | ) 188 | 189 | torch.cuda.empty_cache() 190 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def l2_loss(network_output, gt): 23 | return ((network_output - gt) ** 2).mean() 24 | 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor( 28 | [ 29 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 30 | for x in range(window_size) 31 | ] 32 | ) 33 | return gauss / gauss.sum() 34 | 35 | 36 | def create_window(window_size, channel): 37 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 38 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 39 | window = Variable( 40 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 41 | ) 42 | return window 43 | 44 | 45 | def ssim(img1, img2, window_size=11, size_average=True): 46 | channel = img1.size(-3) 47 | window = create_window(window_size, channel) 48 | 49 | if img1.is_cuda: 50 | window = window.cuda(img1.get_device()) 51 | window = window.type_as(img1) 52 | 53 | return _ssim(img1, img2, window, window_size, channel, size_average) 54 | 55 | 56 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 59 | 60 | mu1_sq = mu1.pow(2) 61 | mu2_sq = mu2.pow(2) 62 | mu1_mu2 = mu1 * mu2 63 | 64 | sigma1_sq = ( 65 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 66 | ) 67 | sigma2_sq = ( 68 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 69 | ) 70 | sigma12 = ( 71 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 72 | - mu1_mu2 73 | ) 74 | 75 | C1 = 0.01**2 76 | C2 = 0.03**2 77 | 78 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 79 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 80 | ) 81 | 82 | if size_average: 83 | return ssim_map.mean() 84 | else: 85 | return ssim_map.mean(1).mean(1).mean(1) 86 | 87 | 88 | def img2mse(x, y, mask=None): 89 | if mask is None: 90 | return 
torch.mean((x - y) ** 2) 91 | else: 92 | return torch.sum((x * mask - y * mask) ** 2) / (torch.sum(mask) + 1e-5) 93 | 94 | 95 | def img2mae(x, y, mask=None): 96 | if mask is None: 97 | return torch.mean(torch.abs(x - y)) 98 | else: 99 | return torch.sum(torch.abs(x * mask - y * mask)) / (torch.sum(mask) + 1e-5) 100 | -------------------------------------------------------------------------------- /utils/save_imp_score.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | import sys 16 | from scene import Scene, GaussianModel 17 | from utils.general_utils import safe_state 18 | from tqdm import tqdm 19 | from utils.image_utils import psnr 20 | from argparse import ArgumentParser, Namespace 21 | from arguments import ModelParams, PipelineParams, OptimizationParams 22 | try: 23 | from torch.utils.tensorboard import SummaryWriter 24 | TENSORBOARD_FOUND = True 25 | except ImportError: 26 | TENSORBOARD_FOUND = False 27 | from icecream import ic 28 | import random 29 | import copy 30 | import gc 31 | from os import makedirs 32 | from prune import prune_list, calculate_v_imp_score 33 | import csv 34 | import numpy as np 35 | 36 | def save_imp_score(dataset, opt, pipe, checkpoint, args): 37 | gaussians = GaussianModel(dataset.sh_degree) 38 | scene = Scene(dataset, gaussians) 39 | gaussians.training_setup(opt) 40 | if checkpoint: 41 | (model_params, first_iter) = torch.load(checkpoint) 42 | gaussians.restore(model_params, opt) 43 | 44 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 45 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 46 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background) 47 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow) 48 | np.savez(os.path.join(scene.model_path,"imp_score"), v_list) 49 | 50 | # If you want to print the imp_score: 51 | if args.show_imp_score: 52 | data = np.load(os.path.join(scene.model_path,"imp_score.npz")) 53 | lst = data.files 54 | for item in lst: 55 | ic(item) 56 | ic(data[item].shape) 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | # Set up command line argument parser 62 | parser = ArgumentParser(description="Importance score saving parameters") 63 | lp = ModelParams(parser) 64 | op = OptimizationParams(parser) 65 | pp = PipelineParams(parser) 66 | parser.add_argument('--debug_from', type=int, default=-1) 67 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 68 | parser.add_argument("--start_checkpoint", type=str, default = None) 69 | parser.add_argument("--show_imp_score", action='store_true', default=False) 70 | parser.add_argument("--get_fps",action='store_true', default=False) 71 | parser.add_argument("--quiet", action="store_true") 72 | parser.add_argument("--v_pow", type=float, default=0.1) 73 | 74 | 75 | args = parser.parse_args(sys.argv[1:]) 76 | 77 | print("Computing importance scores for " + args.model_path) 78 | 79 | # Initialize system state (RNG) 80 | safe_state(args.quiet) 81 | save_imp_score(lp.extract(args), op.extract(args), pp.extract(args), args.start_checkpoint, args) 82 | # All done 83 | print("\nImportance scores saved.") 84 | 
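For reference, a minimal sketch of how the scores written by save_imp_score.py can be inspected downstream: the file lands at <model_path>/imp_score.npz under NumPy's default 'arr_0' key, which is also how vectree/vectree.py reloads it via its load_f helper. The path below is illustrative, not a repository default.

import numpy as np

# Illustrative path; substitute the model_path that was passed to save_imp_score.py.
scores = np.load("output/room/imp_score.npz")["arr_0"]
print(scores.shape)                       # one importance score per Gaussian
keep = int(scores.shape[0] * (1 - 0.6))   # with vq_ratio=0.6, this many top-scoring Gaussians
print(keep, "Gaussians would keep their uncompressed SH coefficients")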
-------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396, 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435, 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = ( 78 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 79 | ) 80 | 81 | if deg > 1: 82 | xx, yy, zz = x * x, y * y, z * z 83 | xy, yz, xz = x * y, y * z, x * z 84 | result = ( 85 | result 86 | + C2[0] * xy * sh[..., 4] 87 | + C2[1] * yz * sh[..., 5] 88 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 89 | + C2[3] * xz * sh[..., 7] 90 | + C2[4] * (xx - yy) * sh[..., 8] 91 | ) 92 | 93 | if deg > 2: 94 | result = ( 95 | result 96 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 97 | + C3[1] * xy * z * sh[..., 10] 98 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 99 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 100 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 101 | + C3[5] * z * (xx - yy) * sh[..., 14] 102 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 103 | ) 104 | 105 | if deg > 3: 106 | result = ( 107 | result 108 | + C4[0] * xy * (xx - yy) * sh[..., 16] 109 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 110 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 111 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 112 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 113 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 114 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 115 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 116 | + C4[8] 117 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 118 | * sh[..., 24] 119 | ) 120 | return result 121 | 122 | 123 | def RGB2SH(rgb): 124 | return (rgb - 0.5) / C0 125 | 126 | 127 | def SH2RGB(sh): 128 | return sh * C0 + 0.5 129 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory. 
equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 30 | return max(saved_iters) 31 | -------------------------------------------------------------------------------- /utils/tracker_utils.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import random 3 | 4 | class HardestExamplesTracker: 5 | def __init__(self, max_size=10): 6 | self.max_size = max_size 7 | self.heap = [] 8 | self.total_added = 0 9 | 10 | def add(self, loss, example, label): 11 | # Ensure the label is either "virtual" or "gt" 12 | # assert label in ["virtual", "gt"], "Label must be 'virtual' or 'gt'" 13 | 14 | if len(self.heap) < self.max_size: 15 | heapq.heappush(self.heap, (loss, example, label)) 16 | self.total_added += 1 17 | elif loss > self.heap[0][0]: 18 | heapq.heappushpop(self.heap, (loss, example, label)) 19 | 20 | def get_hardest_examples(self): 21 | # Sort by loss and return examples with their labels 22 | return [(example, label) for loss, example, label in sorted(self.heap, reverse=True)] 23 | 24 | def get_random_example(self): 25 | if not self.heap: 26 | return None 27 | _, example, label = random.choice(self.heap) 28 | return example, label 29 | 30 | def get_hardest_example(self): 31 | if not self.heap: 32 | return None 33 | _, example, label = max(self.heap, key=lambda x: x[0]) 34 | return example, label 35 | 36 | def get_size(self): 37 | return self.total_added 38 | 39 | 40 | -------------------------------------------------------------------------------- /utils/vgg.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torchvision import models 5 | 6 | 7 | class Vgg16(torch.nn.Module): 8 | def __init__(self, requires_grad=False): 9 | super(Vgg16, self).__init__() 10 | vgg_pretrained_features = models.vgg16(weights="VGG16_Weights.DEFAULT").features 11 | self.slice1 = torch.nn.Sequential() 12 | self.slice2 = torch.nn.Sequential() 13 | self.slice3 = torch.nn.Sequential() 14 | self.slice4 = torch.nn.Sequential() 15 | for x in range(4): 16 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 17 | for x in range(4, 9): 18 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 19 | for x in range(9, 16): 20 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(16, 23): 22 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 23 | if not requires_grad: 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def forward(self, X): 28 | h = self.slice1(X) 29 | h_relu1_2 = h 30 | h = self.slice2(h) 31 | h_relu2_2 = h 32 | h = self.slice3(h) 33 | h_relu3_3 = h 34 | h = self.slice4(h) 35 | h_relu4_3 = h 36 | # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 37 | # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 38 | out = {} 39 | out["relu1_2"] = h_relu1_2 40 | out["relu2_2"] = h_relu2_2 41 | out["relu3_3"] = h_relu3_3 42 | out["relu4_3"] = h_relu4_3 43 | return out 44 | -------------------------------------------------------------------------------- /vectree/utils.py: -------------------------------------------------------------------------------- 1 | 
import os, math, torch 2 | import numpy as np 3 | from plyfile import PlyData, PlyElement 4 | 5 | def load_vqgaussian(path, device='cuda'): 6 | def load_f(name, allow_pickle=False,array_name='arr_0'): 7 | return np.load(os.path.join(path,name),allow_pickle=allow_pickle)[array_name] 8 | 9 | metadata = load_f('metadata.npz',allow_pickle=True,array_name='metadata') 10 | metadata = metadata.item() 11 | 12 | ## load basic info 13 | codebook_size = metadata['codebook_size'] 14 | codebook_dim = metadata['codebook_dim'] 15 | bit_length = int(math.log2(codebook_size)) # log_2_K 16 | input_pc_num = metadata['input_pc_num'] # feats.shape[0] 17 | input_pc_dim = metadata['input_pc_dim'] # feats.shape[1] 18 | 19 | # ===================================================== load vq_SH ============================================ 20 | ## loading the two masks 21 | non_vq_mask = load_f('non_vq_mask.npz') 22 | non_vq_mask = np.unpackbits(non_vq_mask) 23 | non_vq_mask = non_vq_mask[:input_pc_num] 24 | non_vq_mask = torch.from_numpy(non_vq_mask).bool().to(device) # non_vq_mask 25 | all_one_mask = torch.ones_like(non_vq_mask).bool().to(device) # all_one_mask 26 | 27 | ## loading codebook and vq indexes 28 | codebook = load_f('codebook.npz') 29 | codebook = torch.from_numpy(codebook).float().to(device) 30 | vq_mask = torch.logical_xor(non_vq_mask, all_one_mask) # vq_mask 31 | vq_elements = vq_mask.sum() 32 | 33 | vq_indexs = load_f('vq_indexs.npz') 34 | vq_indexs = np.unpackbits(vq_indexs) 35 | vq_indexs = vq_indexs[:vq_elements*bit_length].reshape(vq_elements,bit_length) 36 | vq_indexs = torch.from_numpy(vq_indexs).float() 37 | vq_indexs = bin2dec(vq_indexs, bits=bit_length) 38 | vq_indexs = vq_indexs.long().to(device) # vq_indexs 39 | 40 | # ===================================================== load non_vq_SH ========================================== 41 | non_vq_feats = load_f('non_vq_feats.npz') 42 | non_vq_feats = torch.from_numpy(non_vq_feats).float().to(device) 43 | 44 | # =========================================== load xyz & other attr(opacity + 3*scale + 4*rot) =============== 45 | other_attribute = load_f('other_attribute.npz') 46 | other_attribute = torch.from_numpy(other_attribute).float().to(device) 47 | 48 | xyz = load_f('xyz.npz') 49 | xyz = torch.from_numpy(xyz).float().to(device) 50 | # =========================================== build full features ============================================= 51 | full_feats = torch.zeros(input_pc_num, input_pc_dim).to(device) 52 | # --- xyz & other attr--- 53 | full_feats[:, 0:3] = xyz 54 | full_feats[:, -8:] = other_attribute 55 | 56 | # --- nx==ny==nz==0 57 | 58 | # --- vq_SH --- 59 | full_feats[vq_mask, 6:6+codebook_dim] = codebook[vq_indexs] 60 | 61 | # --- non_vq_SH --- 62 | # non_vq_mask = torch.logical_xor(vq_mask, all_one_mask) 63 | full_feats[non_vq_mask, 6:6+codebook_dim] = non_vq_feats 64 | 65 | return full_feats 66 | 67 | 68 | 69 | def read_ply_data(input_file): 70 | ply_data = PlyData.read(input_file) 71 | i = 0 72 | vertex = ply_data['vertex'] 73 | for prop in vertex._property_lookup: 74 | tmp = vertex.data[prop].reshape(-1,1) 75 | if i == 0: 76 | data = tmp 77 | i += 1 78 | else: 79 | data = np.concatenate((data, tmp), axis=1) 80 | return data 81 | 82 | 83 | def write_ply_data(feats, save_ply_path, sh_dim): 84 | def construct_list_of_attributes(): 85 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 86 | # All channels except the 3 DC 87 | for i in range(3): 88 | l.append('f_dc_{}'.format(i)) 89 | for i in range(sh_dim-3-8 if sh_dim==24+3+8 else sh_dim-3): 
90 | l.append('f_rest_{}'.format(i)) 91 | l.append('opacity') 92 | for i in range(3): 93 | l.append('scale_{}'.format(i)) 94 | for i in range(4): 95 | l.append('rot_{}'.format(i)) 96 | return l 97 | 98 | path= save_ply_path+'/point_cloud.ply' 99 | dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes()] # f4:float32,f2:float16 100 | elements = np.empty(feats.shape[0], dtype=dtype_full) 101 | elements[:] = list(map(tuple, feats)) 102 | el = PlyElement.describe(elements, 'vertex') 103 | PlyData([el]).write(path) 104 | 105 | def dec2bin(x, bits): 106 | mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype) 107 | return x.unsqueeze(-1).bitwise_and(mask).ne(0).float() 108 | 109 | def bin2dec(b, bits): 110 | mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype) 111 | return torch.sum(mask * b, -1) 112 | 113 | -------------------------------------------------------------------------------- /vectree/vectree.py: -------------------------------------------------------------------------------- 1 | import os, torch, argparse, math 2 | import numpy as np 3 | from copy import deepcopy 4 | from tqdm import tqdm, trange 5 | 6 | from vq import VectorQuantize 7 | from utils import read_ply_data, write_ply_data, load_vqgaussian 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="vectree quantization") 12 | parser.add_argument("--important_score_npz_path", type=str, default='room') 13 | parser.add_argument("--input_path", type=str, default='room/iteration_40000/point_cloud.ply') 14 | 15 | parser.add_argument("--save_path", type=str, default='./output/room') 16 | parser.add_argument("--no_load_data", type=bool, default=False) 17 | parser.add_argument("--no_save_ply", type=bool, default=False) 18 | parser.add_argument("--sh_degree", type=int, default=2) 19 | 20 | parser.add_argument("--iteration_num", type=float, default=1000) 21 | parser.add_argument("--vq_ratio", type=float, default=0.6) 22 | parser.add_argument("--codebook_size", type=int, default=2**13) # 2**13 = 8192 23 | parser.add_argument("--no_IS", type=bool, default=False) 24 | parser.add_argument("--vq_way", type=str, default='half') 25 | opt = parser.parse_args() 26 | return opt 27 | 28 | 29 | class Quantization(): 30 | def __init__(self, opt): 31 | 32 | # ----- load ply data ----- 33 | if opt.sh_degree == 3: 34 | self.sh_dim = 3+45 35 | elif opt.sh_degree == 2: 36 | self.sh_dim = 3+24 37 | 38 | self.feats = read_ply_data(opt.input_path) 39 | self.feats = torch.tensor(self.feats) 40 | self.feats_bak = self.feats.clone() 41 | self.feats = self.feats[:, 6:6+self.sh_dim] 42 | 43 | # ----- define model ----- 44 | self.model_vq = VectorQuantize( 45 | dim = self.feats.shape[1], 46 | codebook_size = opt.codebook_size, 47 | decay = 0.8, 48 | commitment_weight = 1.0, 49 | use_cosine_sim = False, 50 | threshold_ema_dead_code=0, 51 | ).to(device) 52 | 53 | # ----- other ----- 54 | self.save_path = opt.save_path 55 | self.ply_path = opt.save_path 56 | self.imp_path = opt.important_score_npz_path 57 | self.high = None 58 | self.VQ_CHUNK = 80000 59 | self.k_expire = 10 60 | self.vq_ratio = opt.vq_ratio 61 | 62 | self.no_IS = opt.no_IS 63 | self.no_load_data = opt.no_load_data 64 | self.no_save_ply = opt.no_save_ply 65 | 66 | self.codebook_size = opt.codebook_size 67 | self.iteration_num = opt.iteration_num 68 | self.vq_way = opt.vq_way 69 | 70 | # ----- print info ----- 71 | print("\n================== Print Info ================== ") 72 | print("Input_feats_shape: ", self.feats_bak.shape) 
73 | print("VQ_feats_shape: ", self.feats.shape) 74 | print("SH_degree: ", opt.sh_degree) 75 | print("Quantization_ratio: ", opt.vq_ratio) 76 | print("Add_important_score: ", opt.no_IS==False) 77 | print("Codebook_size: ", opt.codebook_size) 78 | print("================================================ ") 79 | 80 | @torch.no_grad() 81 | def calc_vector_quantized_feature(self): 82 | """ 83 | apply vector quantize on gaussian attributes and return vq indexes 84 | """ 85 | CHUNK = 8192 86 | feat_list = [] 87 | indice_list = [] 88 | self.model_vq.eval() 89 | self.model_vq._codebook.embed.half().float() # 90 | for i in tqdm(range(0, self.feats.shape[0], CHUNK)): 91 | feat, indices, commit = self.model_vq(self.feats[i:i+CHUNK,:].unsqueeze(0).to(device)) 92 | indice_list.append(indices[0]) 93 | feat_list.append(feat[0]) 94 | self.model_vq.train() 95 | all_feat = torch.cat(feat_list).half().float() # [num_elements, feats_dim] 96 | all_indice = torch.cat(indice_list) # [num_elements, 1] 97 | return all_feat, all_indice 98 | 99 | 100 | @torch.no_grad() 101 | def fully_vq_reformat(self): 102 | 103 | print("\n=============== Start vector quantize ===============") 104 | all_feat, all_indice = self.calc_vector_quantized_feature() 105 | 106 | if self.save_path is not None: 107 | save_path = self.save_path 108 | os.makedirs(f'{save_path}/extreme_saving', exist_ok=True) 109 | 110 | # ----- save basic info ----- 111 | metadata = dict() 112 | metadata['input_pc_num'] = self.feats_bak.shape[0] 113 | metadata['input_pc_dim'] = self.feats_bak.shape[1] 114 | metadata['codebook_size'] = self.codebook_size 115 | metadata['codebook_dim'] = self.sh_dim 116 | np.savez_compressed(f'{save_path}/extreme_saving/metadata.npz', metadata=metadata) 117 | 118 | # ===================================================== save vq_SH ============================================= 119 | # ----- save mapping_index (vq_index) ----- 120 | def dec2bin(x, bits): 121 | mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype) 122 | return x.unsqueeze(-1).bitwise_and(mask).ne(0).float() 123 | # vq indice was saved in according to the bit length 124 | self.codebook_vq_index = all_indice[torch.logical_xor(self.all_one_mask,self.non_vq_mask)] # vq_index 125 | bin_indices = dec2bin(self.codebook_vq_index, int(math.log2(self.codebook_size))).bool().cpu().numpy() # mapping_index 126 | np.savez_compressed(f'{save_path}/extreme_saving/vq_indexs.npz',np.packbits(bin_indices.reshape(-1))) 127 | 128 | # ----- save codebook ----- 129 | codebook = self.model_vq._codebook.embed.cpu().half().numpy().squeeze(0) 130 | np.savez_compressed(f'{save_path}/extreme_saving/codebook.npz', codebook) 131 | 132 | # ----- save keep mask (non_vq_feats_index)----- 133 | np.savez_compressed(f'{save_path}/extreme_saving/non_vq_mask.npz',np.packbits(self.non_vq_mask.reshape(-1).cpu().numpy())) 134 | 135 | # ===================================================== save non_vq_SH ============================================= 136 | non_vq_feats = self.feats_bak[self.non_vq_mask, 6:6+self.sh_dim] 137 | wage_non_vq_feats = self.wage_vq(non_vq_feats) 138 | np.savez_compressed(f'{save_path}/extreme_saving/non_vq_feats.npz', wage_non_vq_feats) 139 | 140 | # =========================================== save xyz & other attr(opacity + 3*scale + 4*rot) ==================================== 141 | other_attribute = self.feats_bak[:, -8:] 142 | wage_other_attribute = self.wage_vq(other_attribute) 143 | np.savez_compressed(f'{save_path}/extreme_saving/other_attribute.npz', 
wage_other_attribute) 144 | 145 | xyz = self.feats_bak[:, 0:3] 146 | np.savez_compressed(f'{save_path}/extreme_saving/xyz.npz', xyz) 147 | 148 | 149 | # zip everything together to get final size 150 | os.system(f"zip -r {save_path}/extreme_saving.zip {save_path}/extreme_saving") 151 | size = os.path.getsize(f'{save_path}/extreme_saving.zip') 152 | size_MB = size / 1024.0 / 1024.0 153 | print("Size = {:.2f} MB".format(size_MB)) 154 | 155 | return all_feat, all_indice 156 | 157 | def load_f(self, path, name, allow_pickle=False,array_name='arr_0'): 158 | return np.load(os.path.join(path, name),allow_pickle=allow_pickle)[array_name] 159 | 160 | def wage_vq(self, feats): 161 | if self.vq_way == 'half': 162 | return feats.half() 163 | else: 164 | return feats 165 | 166 | def quantize(self): 167 | if self.no_IS: # no important score 168 | importance = np.ones((self.feats.shape[0])) 169 | else: 170 | importance = self.load_f(self.imp_path, 'imp_score.npz') 171 | 172 | ################################################### 173 | only_vq_some_vector = True 174 | if only_vq_some_vector: 175 | tensor_importance = torch.tensor(importance) 176 | large_val, large_index = torch.topk(tensor_importance, k=int(tensor_importance.shape[0] * (1-self.vq_ratio)), largest=True) 177 | self.all_one_mask = torch.ones_like(tensor_importance).bool() 178 | self.non_vq_mask = torch.zeros_like(tensor_importance).bool() 179 | self.non_vq_mask[large_index] = True 180 | self.non_vq_index = large_index 181 | 182 | IS_non_vq_point = large_val.sum() 183 | IS_all_point = tensor_importance.sum() 184 | IS_percent = IS_non_vq_point/IS_all_point 185 | print("IS_percent: ", IS_percent) 186 | 187 | #=================== Codebook initialization & Update codebook ==================== 188 | self.model_vq.train() 189 | with torch.no_grad(): 190 | self.vq_mask = torch.logical_xor(self.all_one_mask, self.non_vq_mask) 191 | feats_needs_vq = self.feats[self.vq_mask].clone() 192 | imp = tensor_importance[self.vq_mask].float() 193 | k = self.k_expire 194 | if k > self.model_vq.codebook_size: 195 | k = 0 196 | for i in trange(int(self.iteration_num)): 197 | indexes = torch.randint(low=0, high=feats_needs_vq.shape[0], size=[self.VQ_CHUNK]) 198 | vq_weight = imp[indexes].to(device) 199 | vq_feature = feats_needs_vq[indexes,:].to(device) 200 | quantize, embed, loss = self.model_vq(vq_feature.unsqueeze(0), weight=vq_weight.reshape(1,-1,1)) 201 | 202 | replace_val, replace_index = torch.topk(self.model_vq._codebook.cluster_size, k=k, largest=False) 203 | _, most_important_index = torch.topk(vq_weight, k=k, largest=True) 204 | self.model_vq._codebook.embed[:,replace_index,:] = vq_feature[most_important_index,:] 205 | 206 | #=================== Apply vector quantization ==================== 207 | all_feat, all_indices = self.fully_vq_reformat() 208 | 209 | def dequantize(self): 210 | print("\n==================== Load saved data & Dequantize ==================== ") 211 | dequantized_feats = load_vqgaussian(os.path.join(self.save_path,'extreme_saving'), device=device) 212 | 213 | if self.no_save_ply == False: 214 | os.makedirs(f'{self.ply_path}/', exist_ok=True) 215 | write_ply_data(dequantized_feats.cpu().numpy(), self.ply_path, self.sh_dim) 216 | 217 | 218 | if __name__=='__main__': 219 | opt = parse_args() 220 | device = torch.device('cuda') 221 | vq = Quantization(opt) 222 | 223 | vq.quantize() 224 | vq.dequantize() 225 | 226 | print("All done!") 227 | 228 | --------------------------------------------------------------------------------
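To make the storage layout used by fully_vq_reformat and load_vqgaussian concrete, here is a small self-contained round-trip sketch of the index bit-packing: each VQ index is serialized with log2(codebook_size) bits via dec2bin + np.packbits and recovered with np.unpackbits + bin2dec. The dec2bin/bin2dec helpers are copied from vectree/utils.py; the index values themselves are made up for illustration.

import math
import numpy as np
import torch

# Helpers copied from vectree/utils.py.
def dec2bin(x, bits):
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).float()

def bin2dec(b, bits):
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
    return torch.sum(mask * b, -1)

codebook_size = 2 ** 13                              # 8192 entries -> 13 bits per index
bits = int(math.log2(codebook_size))
indices = torch.randint(0, codebook_size, (1000,))   # stand-in for the vq indices of masked Gaussians

# Save path: binarize each index and pack the bits into bytes.
packed = np.packbits(dec2bin(indices, bits).bool().numpy().reshape(-1))

# Load path: unpack, trim the packbits padding, and rebuild the integer indices.
unpacked = np.unpackbits(packed)[: indices.numel() * bits]
restored = bin2dec(torch.from_numpy(unpacked).float().reshape(-1, bits), bits).long()

assert torch.equal(restored, indices)
print(f"{indices.numel()} indices stored in {packed.nbytes} bytes "
      f"(~{8 * packed.nbytes / indices.numel():.1f} bits each)")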