├── .gitignore ├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── assets └── main.png ├── convert.py ├── data ├── dtu │ └── .gitkeep ├── llff │ └── .gitkeep └── nerf_synthetic │ └── .gitkeep ├── dpt ├── get_depth_map.sh ├── get_depth_map_for_blender.py ├── get_depth_map_for_llff_dtu.py └── utils_io.py ├── encoding.py ├── environment.yml ├── gaussian_renderer └── __init__.py ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── gridencoder.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── metrics_count.py ├── metrics_dtu.py ├── render.py ├── render_sh.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py ├── gaussian_model_sh.py └── neural_renderer.py ├── scripts ├── copy_mask_dtu.sh ├── organize_dtu_dataset.sh ├── run_blender.sh ├── run_dtu.sh ├── run_llff.sh └── run_llff_mvs.sh ├── shencoder ├── __init__.py ├── backend.py ├── setup.py ├── shencoder.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── sphere_harmonics.py └── src │ ├── bindings.cpp │ ├── shencoder.cu │ └── shencoder.h ├── spiral.py ├── submodules └── .gitkeep ├── train_blender.py ├── train_dtu.py ├── train_llff.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── pose_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | 6 | *.pyc 7 | .vscode 8 | output* 9 | build 10 | diff_rasterization/diff_rast.egg-info 11 | diff_rasterization/dist 12 | tensorboard_3d 13 | screenshots 14 | 15 | 16 | data/* 17 | !data/dtu 18 | !data/llff 19 | !data/nerf_synthetic 20 | data/dtu/* 21 | data/llff/* 22 | data/nerf_synthetic/* 23 | !*.gitkeep -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 
84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DNGaussian: Optimizing Sparse-View 3D Gaussian Radiance Fields with Global-Local Depth Normalization 2 | 3 | This is the official repository for our CVPR 2024 paper **DNGaussian: Optimizing Sparse-View 3D Gaussian Radiance Fields with Global-Local Depth Normalization**. 4 | 5 | [Paper](https://arxiv.org/abs/2403.06912) | [Project](https://fictionarry.github.io/DNGaussian/) | [Video](https://www.youtube.com/watch?v=WKXCFNJHZ4o) 6 | 7 | ![image](assets/main.png) 8 | 9 | 10 | ## Installation 11 | 12 | Tested on Ubuntu 18.04, CUDA 11.3, PyTorch 1.12.1 13 | 14 | `````` 15 | conda env create --file environment.yml 16 | conda activate dngaussian 17 | 18 | cd submodules 19 | git clone git@github.com:ashawkey/diff-gaussian-rasterization.git --recursive 20 | git clone https://gitlab.inria.fr/bkerbl/simple-knn.git 21 | pip install ./diff-gaussian-rasterization ./simple-knn 22 | `````` 23 | 24 | If encountering installation problem of the `diff-gaussian-rasterization` or `gridencoder`, you may get some help from [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [torch-ngp](https://github.com/ashawkey/torch-ngp). 25 | 26 | 27 | ## Evaluation 28 | 29 | ### LLFF 30 | 31 | 1. Download LLFF from [the official download link](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). 32 | 33 | 2. Generate monocular depths by DPT: 34 | 35 | ```bash 36 | cd dpt 37 | python get_depth_map_for_llff_dtu.py --root_path $ --benchmark LLFF 38 | ``` 39 | 40 | 3. Start training and testing: 41 | 42 | ```bash 43 | # for example 44 | bash scripts/run_llff.sh data/llff/fern output/llff/fern ${gpu_id} 45 | ``` 46 | 47 | 48 | 49 | ### DTU 50 | 51 | 1. Download DTU dataset 52 | 53 | - Download the DTU dataset "Rectified (123 GB)" from the [official website](https://roboimagedata.compute.dtu.dk/?page_id=36/), and extract it. 54 | - Download masks (used for evaluation only) from [this link](https://drive.google.com/file/d/1Yt5T3LJ9DZDiHbtd9PDFNHqJAd7wt-_E/view?usp=sharing). 55 | 56 | 57 | 2. Organize DTU for few-shot setting 58 | 59 | ```bash 60 | bash scripts/organize_dtu_dataset.sh $rectified_path 61 | ``` 62 | 63 | 3. Format 64 | 65 | - Poses: following [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting), run `convert.py` to get the poses and the undistorted images by COLMAP. 66 | - Render Path: following [LLFF](https://github.com/Fyusion/LLFF) to get the `poses_bounds.npy` from the COLMAP data. (Optional) 67 | 68 | 69 | 4. Generate monocular depths by DPT: 70 | 71 | ```bash 72 | cd dpt 73 | python get_depth_map_for_llff_dtu.py --root_path $ --benchmark DTU 74 | ``` 75 | 76 | 5. Set the mask path and the expected output model path in `copy_mask_dtu.sh` for evaluation. (default: "data/dtu/submission_data/idrmasks" and "output/dtu") 77 | 78 | 6. Start training and testing: 79 | 80 | ```bash 81 | # for example 82 | bash scripts/run_dtu.sh data/dtu/scan8 output/dtu/scan8 ${gpu_id} 83 | ``` 84 | 85 | 86 | 87 | ### Blender 88 | 89 | 1. Download the NeRF Synthetic dataset from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1?usp=sharing). 90 | 91 | 2. Generate monocular depths by DPT: 92 | 93 | ```bash 94 | cd dpt 95 | python get_depth_map_for_blender.py --root_path $ 96 | ``` 97 | 98 | 3. 
Start training and testing: 99 | 100 | ```bash 101 | # for example 102 | # There are some scene-specific settings for the Blender dataset; please refer to "run_blender.sh". 103 | bash scripts/run_blender.sh data/nerf_synthetic/drums output/blender/drums ${gpu_id} 104 | ``` 105 | 106 | 107 | ## Reproducing Results 108 | Due to the randomness of the densification process and random initialization, the metrics may be unstable in some scenes, especially PSNR. 109 | 110 | 111 | ### Checkpoints and Results 112 | You can download our provided checkpoints from [here](https://drive.google.com/drive/folders/1V8XGg1MXJDb-bK3NAEo5Gw2GLLByF7FM?usp=sharing). These results were reproduced with a lower error tolerance bound to stay aligned with this repo, which differs from the setting used in the paper. This can lead to higher metrics but worse visual quality. 113 | 114 | 115 | ### MVS Point Cloud Initialization 116 | 117 | If more stable performance is needed, we recommend trying the dense initialization from [FSGS](https://github.com/VITA-Group/FSGS). 118 | 119 | Here we provide an example script for LLFF that modifies only a few hyperparameters to adapt our method to this initialization: 120 | 121 | ```bash 122 | # Follow FSGS to obtain "data/llff/$/3_views/dense/fused.ply" first 123 | bash scripts/run_llff_mvs.sh data/llff/$ output_dense/$ ${gpu_id} 124 | ``` 125 | 126 | However, there may still be some randomness. 127 | 128 | For reference, the best results we obtained over two random trials are as follows: 129 | 130 | | PSNR | LPIPS | SSIM (SK) | SSIM (GS) | 131 | | ------ | ------ | ----- | ----- | 132 | | 19.942 | 0.228 | 0.682 | 0.687 | 133 | 134 | where GS refers to the SSIM calculation originally provided by 3DGS, and SK denotes the sklearn-based calculation used in most previous NeRF-based methods. 135 | 136 | 137 | ## Customized Dataset 138 | Similar to Gaussian Splatting, our method can read standard COLMAP-format datasets. Please customize your sampling rule in `scene/dataset_readers.py`, and refer to our DTU preprocessing to see how to organize a COLMAP-format dataset from raw RGB images. 139 | 140 | 141 | 142 | ## Citation 143 | 144 | Please consider citing our work as below if you find this repository helpful to your project: 145 | 146 | ``` 147 | @article{li2024dngaussian, 148 | title={DNGaussian: Optimizing Sparse-View 3D Gaussian Radiance Fields with Global-Local Depth Normalization}, 149 | author={Li, Jiahe and Zhang, Jiawei and Bai, Xiao and Zheng, Jin and Ning, Xin and Zhou, Jun and Gu, Lin}, 150 | journal={arXiv preprint arXiv:2403.06912}, 151 | year={2024} 152 | } 153 | ``` 154 | 155 | ## Acknowledgement 156 | 157 | This code is developed on top of [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) with [simple-knn](https://gitlab.inria.fr/bkerbl/simple-knn) and a modified [diff-gaussian-rasterization](https://github.com/ashawkey/diff-gaussian-rasterization). The implementation of the neural renderer is based on [torch-ngp](https://github.com/ashawkey/torch-ngp). The [DPT](https://github.com/isl-org/MiDaS)-related code is partially adapted from [SparseNeRF](https://github.com/Wanggcong/SparseNeRF). Thanks to these great projects! 158 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved.
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self.dataset = "LLFF" 54 | self._resolution = -1 55 | self._white_background = False 56 | self.data_device = "cuda:0" 57 | self.eval = False 58 | self.rand_pcd = False 59 | self.mvs_pcd = False 60 | self.n_sparse = -1 61 | super().__init__(parser, "Loading Parameters", sentinel) 62 | 63 | def extract(self, args): 64 | g = super().extract(args) 65 | g.source_path = os.path.abspath(g.source_path) 66 | return g 67 | 68 | class PipelineParams(ParamGroup): 69 | def __init__(self, parser): 70 | self.convert_SHs_python = False 71 | self.compute_cov3D_python = False 72 | self.debug = False 73 | super().__init__(parser, "Pipeline Parameters") 74 | 75 | class OptimizationParams(ParamGroup): 76 | def __init__(self, parser): 77 | self.iterations = 30_000 78 | self.position_lr_init = 0.00016 79 | self.position_lr_final = 0.0000016 80 | self.position_lr_delay_mult = 0.01 81 | self.position_lr_max_steps = 30_000 82 | self.position_lr_delay_steps = 0 83 | self.position_lr_start = 0 84 | self.feature_lr = 0.0025 85 | self.opacity_lr = 0.05 86 | self.scaling_lr = 0.005 87 | self.rotation_lr = 0.001 88 | self.percent_dense = 0.01 89 | 90 | self.neural_grid = 5e-3 91 | self.neural_net = 5e-4 92 | self.error_tolerance = 0.2 93 | self.split_opacity_thresh = 0.1 94 | self.soft_depth_start = 1000 95 | self.hard_depth_start = 0 96 | 97 | self.shape_pena = 0.001 98 | self.scale_pena = 0.001 99 | self.opa_pena = 0.01 100 | 101 | self.lambda_dssim = 0.2 102 | self.densification_interval = 100 103 | self.opacity_reset_interval = 3000 104 | self.densify_from_iter = 500 105 | self.densify_until_iter = 15_000 106 | self.densify_grad_threshold = 0.0001 107 | self.prune_threshold = 0.01 108 | # self.densify_grad_threshold = 0.002 109 | super().__init__(parser, "Optimization Parameters") 110 | 111 | def get_combined_args(parser : ArgumentParser): 112 | cmdlne_string = sys.argv[1:] 113 | cfgfile_string = "Namespace()" 114 | args_cmdline = parser.parse_args(cmdlne_string) 115 | 116 | try: 117 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 118 | 
print("Looking for config file in", cfgfilepath) 119 | with open(cfgfilepath) as cfg_file: 120 | print("Config file found: {}".format(cfgfilepath)) 121 | cfgfile_string = cfg_file.read() 122 | except TypeError: 123 | print("Config file not found at") 124 | pass 125 | args_cfgfile = eval(cfgfile_string) 126 | 127 | merged_dict = vars(args_cfgfile).copy() 128 | for k,v in vars(args_cmdline).items(): 129 | if v != None: 130 | merged_dict[k] = v 131 | return Namespace(**merged_dict) 132 | -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/DNGaussian/24a0d7de512ea5c2caaf5f7380db357046e353c8/assets/main.png -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 0 if args.no_gpu else 1 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 
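# For reference, the command assembled below roughly expands to: colmap mapper --database_path <source_path>/distorted/database.db --image_path <source_path>/input --output_path <source_path>/distorted/sparse (the tolerance flag is kept commented out below).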
58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse") 62 | # --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /data/dtu/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/DNGaussian/24a0d7de512ea5c2caaf5f7380db357046e353c8/data/dtu/.gitkeep -------------------------------------------------------------------------------- /data/llff/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/DNGaussian/24a0d7de512ea5c2caaf5f7380db357046e353c8/data/llff/.gitkeep -------------------------------------------------------------------------------- /data/nerf_synthetic/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/DNGaussian/24a0d7de512ea5c2caaf5f7380db357046e353c8/data/nerf_synthetic/.gitkeep -------------------------------------------------------------------------------- /dpt/get_depth_map.sh: -------------------------------------------------------------------------------- 1 | 2 | benchmark=LLFF # LLFF 3 | root_path=../data/distorted/nerf_llff_data 4 | 5 | python get_depth_map_for_llff_dtu.py --root_path $root_path --benchmark $benchmark 6 | -------------------------------------------------------------------------------- /dpt/get_depth_map_for_blender.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | # import matplotlib.pyplot as plt 5 | import utils_io 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | import glob 11 | 12 | import ssl 13 | ssl._create_default_https_context = ssl._create_unverified_context 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-r', '--root_path', type=str) 17 | args = parser.parse_args() 18 | 19 | 20 | 21 | model_type = "DPT_Large" # MiDaS v3 - Large (highest accuracy, slowest inference speed) 22 | # model_type = "DPT_Hybrid" # MiDaS v3 - Hybrid (medium accuracy, medium inference speed) 23 | # model_type = "MiDaS_small" # MiDaS v2.1 - Small (lowest accuracy, highest inference speed) 24 | # model_type = "DPT_BEiT_L_384" 25 | 26 | midas = torch.hub.load("intel-isl/MiDaS", model_type) 27 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 28 | midas.to(device) 29 | midas.eval() 30 | 31 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 32 | 33 | if "DPT" in model_type: 34 | transform = midas_transforms.dpt_transform 35 | else: 36 | transform = midas_transforms.small_transform 37 | 38 | 39 | for dataset_id in ["chair", "drums", "ficus", "hotdog", "lego", "materials", "mic", "ship"]: 40 | if args.root_path[-1]!="/": 41 | root_path = args.root_path+'/' 42 | else: 43 | root_path = args.root_path 44 | 45 | # output_path = root_path 46 | 47 | root_path_1 = root_path+dataset_id+'/train/*png' 48 | root_path_2 = root_path+dataset_id+'/test/*png' 49 | image_paths_1 = sorted(glob.glob(root_path_1)) 50 | image_paths_2 = sorted(glob.glob(root_path_2)) 51 | image_path_pkg = [image_paths_1, image_paths_2] 52 | 53 | output_path_1 = os.path.join('/'.join(root_path_1.split('/')[:-1]), 'depth_maps') 54 | output_path_2 = os.path.join('/'.join(root_path_2.split('/')[:-1]), 'depth_maps') 55 | output_path_pkg = [output_path_1, output_path_2] 56 | 57 | print('image_paths:', image_path_pkg) 58 | 59 | downsampling = 2 60 | for output_path in 
output_path_pkg: 61 | if not os.path.exists(output_path): 62 | os.makedirs(output_path, exist_ok=True) 63 | for image_paths, output_path in zip(image_path_pkg, output_path_pkg): 64 | for k in range(len(image_paths)): 65 | filename = image_paths[k] 66 | img = cv2.imread(filename) 67 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 68 | # img = cv2.resize(img, (img.shape[1] // 8, img.shape[0] // 8), interpolation=cv2.INTER_CUBIC) 69 | print('k, img.shape:', k, img.shape) #(1213, 1546, 3) 70 | h, w = img.shape[:2] 71 | input_batch = transform(img).to(device) 72 | 73 | with torch.no_grad(): 74 | prediction = midas(input_batch) 75 | prediction = torch.nn.functional.interpolate( 76 | prediction.unsqueeze(1), 77 | size=(h//downsampling, w//downsampling), 78 | mode="bicubic", 79 | align_corners=False, 80 | ).squeeze() 81 | 82 | output = prediction.cpu().numpy() 83 | name = 'depth_'+filename.split('/')[-1] 84 | print('######### output_path and name:', output_path, name) 85 | output_file_name = os.path.join(output_path, name.split('.')[0]) 86 | # utils.io.write_depth(output_file_name.split('.')[0], output, bits=2) 87 | utils_io.write_depth(output_file_name, output, bits=2) -------------------------------------------------------------------------------- /dpt/get_depth_map_for_llff_dtu.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | # import matplotlib.pyplot as plt 5 | import utils_io 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | import glob 11 | 12 | import ssl 13 | ssl._create_default_https_context = ssl._create_unverified_context 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-b', '--benchmark', type=str) 17 | # parser.add_argument('-d', '--dataset_id', type=str) 18 | parser.add_argument('-r', '--root_path', type=str) 19 | args = parser.parse_args() 20 | 21 | 22 | 23 | if args.benchmark=="DTU": 24 | model_type = "DPT_Large" 25 | scenes = ["scan30", "scan34", "scan41", "scan45", "scan82", "scan103", "scan38", "scan21", "scan40", "scan55", "scan63", "scan31", "scan8", "scan110", "scan114"] 26 | elif args.benchmark=="LLFF": 27 | model_type = "DPT_Hybrid" 28 | scenes = ["fern", "flower", "fortress", "horns", "leaves", "orchids", "room", "trex"] 29 | 30 | midas = torch.hub.load("intel-isl/MiDaS", model_type) 31 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 32 | midas.to(device) 33 | midas.eval() 34 | 35 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 36 | 37 | if "DPT" in model_type: 38 | transform = midas_transforms.dpt_transform 39 | else: 40 | transform = midas_transforms.small_transform 41 | 42 | 43 | for dataset_id in scenes: 44 | if args.root_path[-1]!="/": 45 | root_path = args.root_path+'/' 46 | else: 47 | root_path = args.root_path 48 | 49 | # output_path = root_path 50 | if args.benchmark=="DTU": 51 | root_path_1 = root_path+dataset_id+'/images/*3_r5000*' 52 | image_paths_1 = sorted(glob.glob(root_path_1)) 53 | image_path_pkg = [image_paths_1] 54 | downsampling = 4 55 | 56 | elif args.benchmark=="LLFF": 57 | root_path_1 = root_path+dataset_id+'/images/*.JPG' 58 | root_path_2 = root_path+dataset_id+'/images/*.jpg' 59 | image_paths_1 = sorted(glob.glob(root_path_1)) 60 | image_paths_2 = sorted(glob.glob(root_path_2)) 61 | image_path_pkg = [image_paths_1, image_paths_2] 62 | # root_path = root_path+'/*png' 63 | downsampling = 8 64 | 65 | 66 | output_path = os.path.join(root_path+dataset_id, 'depth_maps') 67 | 68 | 69 | 
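# Each depth map produced below is written by utils_io.write_depth as depth_<image_name>.pfm/.png inside <scene>/depth_maps/, at a resolution downsampled by the factor set above (4 for DTU, 8 for LLFF).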
print('image_paths:', image_path_pkg) 70 | 71 | if not os.path.exists(output_path): 72 | os.makedirs(output_path, exist_ok=True) 73 | for image_paths in image_path_pkg: 74 | for k in range(len(image_paths)): 75 | filename = image_paths[k] 76 | img = cv2.imread(filename) 77 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 78 | # img = cv2.resize(img, (img.shape[1] // 8, img.shape[0] // 8), interpolation=cv2.INTER_CUBIC) 79 | print('k, img.shape:', k, img.shape) #(1213, 1546, 3) 80 | h, w = img.shape[:2] 81 | input_batch = transform(img).to(device) 82 | 83 | with torch.no_grad(): 84 | prediction = midas(input_batch) 85 | prediction = torch.nn.functional.interpolate( 86 | prediction.unsqueeze(1), 87 | size=(h//downsampling, w//downsampling), 88 | mode="bicubic", 89 | align_corners=False, 90 | ).squeeze() 91 | 92 | output = prediction.cpu().numpy() 93 | name = 'depth_'+filename.split('/')[-1] 94 | print('######### output_path and name:', output_path, name) 95 | output_file_name = os.path.join(output_path, name.split('.')[0]) 96 | # utils.io.write_depth(output_file_name.split('.')[0], output, bits=2) 97 | utils_io.write_depth(output_file_name, output, bits=2) -------------------------------------------------------------------------------- /dpt/utils_io.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth. 2 | """ 3 | import sys 4 | import re 5 | import numpy as np 6 | import cv2 7 | import torch 8 | 9 | from PIL import Image 10 | 11 | 12 | # from .pallete import get_mask_pallete 13 | 14 | def read_pfm(path): 15 | """Read pfm file. 16 | 17 | Args: 18 | path (str): path to file 19 | 20 | Returns: 21 | tuple: (data, scale) 22 | """ 23 | with open(path, "rb") as file: 24 | 25 | color = None 26 | width = None 27 | height = None 28 | scale = None 29 | endian = None 30 | 31 | header = file.readline().rstrip() 32 | if header.decode("ascii") == "PF": 33 | color = True 34 | elif header.decode("ascii") == "Pf": 35 | color = False 36 | else: 37 | raise Exception("Not a PFM file: " + path) 38 | 39 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 40 | if dim_match: 41 | width, height = list(map(int, dim_match.groups())) 42 | else: 43 | raise Exception("Malformed PFM header.") 44 | 45 | scale = float(file.readline().decode("ascii").rstrip()) 46 | if scale < 0: 47 | # little-endian 48 | endian = "<" 49 | scale = -scale 50 | else: 51 | # big-endian 52 | endian = ">" 53 | 54 | data = np.fromfile(file, endian + "f") 55 | shape = (height, width, 3) if color else (height, width) 56 | 57 | data = np.reshape(data, shape) 58 | data = np.flipud(data) 59 | 60 | return data, scale 61 | 62 | def read_pfm_mvsnerf(filename): 63 | file = open(filename, 'rb') 64 | color = None 65 | width = None 66 | height = None 67 | scale = None 68 | endian = None 69 | 70 | header = file.readline().decode('utf-8').rstrip() 71 | if header == 'PF': 72 | color = True 73 | elif header == 'Pf': 74 | color = False 75 | else: 76 | raise Exception('Not a PFM file.') 77 | 78 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 79 | if dim_match: 80 | width, height = map(int, dim_match.groups()) 81 | else: 82 | raise Exception('Malformed PFM header.') 83 | 84 | scale = float(file.readline().rstrip()) 85 | if scale < 0: # little-endian 86 | endian = '<' 87 | scale = -scale 88 | else: 89 | endian = '>' # big-endian 90 | 91 | data = np.fromfile(file, endian + 'f') 92 | shape = (height, width, 3) if color else (height, width) 93 | 94 |
data = np.reshape(data, shape) 95 | data = np.flipud(data) 96 | file.close() 97 | return data, scale 98 | 99 | def write_pfm(path, image, scale=1): 100 | """Write pfm file. 101 | 102 | Args: 103 | path (str): pathto file 104 | image (array): data 105 | scale (int, optional): Scale. Defaults to 1. 106 | """ 107 | 108 | with open(path, "wb") as file: 109 | color = None 110 | 111 | if image.dtype.name != "float32": 112 | raise Exception("Image dtype must be float32.") 113 | 114 | image = np.flipud(image) 115 | 116 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 117 | color = True 118 | elif ( 119 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 120 | ): # greyscale 121 | color = False 122 | else: 123 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 124 | 125 | file.write("PF\n" if color else "Pf\n".encode()) 126 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 127 | 128 | endian = image.dtype.byteorder 129 | 130 | if endian == "<" or endian == "=" and sys.byteorder == "little": 131 | scale = -scale 132 | 133 | file.write("%f\n".encode() % scale) 134 | 135 | image.tofile(file) 136 | 137 | 138 | def read_image(path): 139 | """Read image and output RGB image (0-1). 140 | 141 | Args: 142 | path (str): path to file 143 | 144 | Returns: 145 | array: RGB image (0-1) 146 | """ 147 | img = cv2.imread(path) 148 | 149 | if img.ndim == 2: 150 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 151 | 152 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 153 | 154 | return img 155 | 156 | 157 | def resize_image(img): 158 | """Resize image and make it fit for network. 159 | 160 | Args: 161 | img (array): image 162 | 163 | Returns: 164 | tensor: data ready for network 165 | """ 166 | height_orig = img.shape[0] 167 | width_orig = img.shape[1] 168 | 169 | if width_orig > height_orig: 170 | scale = width_orig / 384 171 | else: 172 | scale = height_orig / 384 173 | 174 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 175 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 176 | 177 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 178 | 179 | img_resized = ( 180 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 181 | ) 182 | img_resized = img_resized.unsqueeze(0) 183 | 184 | return img_resized 185 | 186 | 187 | def resize_depth(depth, width, height): 188 | """Resize depth map and bring to CPU (numpy). 189 | 190 | Args: 191 | depth (tensor): depth 192 | width (int): image width 193 | height (int): image height 194 | 195 | Returns: 196 | array: processed depth 197 | """ 198 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 199 | 200 | depth_resized = cv2.resize( 201 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 202 | ) 203 | 204 | return depth_resized 205 | 206 | 207 | def write_depth(path, depth, bits=1, absolute_depth=False): 208 | """Write depth map to pfm and png file. 
209 | 210 | Args: 211 | path (str): filepath without extension 212 | depth (array): depth 213 | """ 214 | write_pfm(path + ".pfm", depth.astype(np.float32)) 215 | 216 | if absolute_depth: 217 | out = depth 218 | else: 219 | depth_min = depth.min() 220 | depth_max = depth.max() 221 | 222 | max_val = (2 ** (8 * bits)) - 1 223 | 224 | if depth_max - depth_min > np.finfo("float").eps: 225 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 226 | else: 227 | out = np.zeros(depth.shape, dtype=depth.dtype) 228 | # print('depth:', depth.min(), depth.max()) 229 | # print('out:', out.min(), out.max()) 230 | if bits == 1: 231 | cv2.imwrite(path + ".png", out.astype("uint8"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) 232 | elif bits == 2: 233 | cv2.imwrite(path + ".png", out.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) 234 | 235 | return 236 | 237 | ''' 238 | def write_segm_img(path, image, labels, palette="detail", alpha=0.5): 239 | """Write depth map to pfm and png file. 240 | 241 | Args: 242 | path (str): filepath without extension 243 | image (array): input image 244 | labels (array): labeling of the image 245 | """ 246 | 247 | mask = get_mask_pallete(labels, "ade20k") 248 | 249 | img = Image.fromarray(np.uint8(255*image)).convert("RGBA") 250 | seg = mask.convert("RGBA") 251 | 252 | out = Image.blend(img, seg, alpha) 253 | 254 | out.save(path + ".png") 255 | 256 | return 257 | ''' 258 | -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FreqEncoder(nn.Module): 6 | def __init__(self, input_dim, max_freq_log2, N_freqs, 7 | log_sampling=True, include_input=True, 8 | periodic_fns=(torch.sin, torch.cos)): 9 | 10 | super().__init__() 11 | 12 | self.input_dim = input_dim 13 | self.include_input = include_input 14 | self.periodic_fns = periodic_fns 15 | 16 | self.output_dim = 0 17 | if self.include_input: 18 | self.output_dim += self.input_dim 19 | 20 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 21 | 22 | if log_sampling: 23 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) 24 | else: 25 | self.freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq_log2, N_freqs) 26 | 27 | self.freq_bands = self.freq_bands.numpy().tolist() 28 | 29 | def forward(self, input, **kwargs): 30 | 31 | out = [] 32 | if self.include_input: 33 | out.append(input) 34 | 35 | for i in range(len(self.freq_bands)): 36 | freq = self.freq_bands[i] 37 | for p_fn in self.periodic_fns: 38 | out.append(p_fn(input * freq)) 39 | 40 | out = torch.cat(out, dim=-1) 41 | 42 | 43 | return out 44 | 45 | def get_encoder(encoding, input_dim=3, 46 | multires=6, 47 | degree=4, 48 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, 49 | **kwargs): 50 | 51 | if encoding == 'None': 52 | return lambda x, **kwargs: x, input_dim 53 | 54 | elif encoding == 'frequency': 55 | #encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 56 | from freqencoder import FreqEncoder 57 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 58 | 59 | elif encoding == 'sphere_harmonics': 60 | from shencoder import SHEncoder 61 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 62 | 63 | elif encoding == 'hashgrid': 64 | from gridencoder import GridEncoder 65 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) 66 | 67 | elif encoding == 'tiledgrid': 68 | from gridencoder import GridEncoder 69 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) 70 | 71 | elif encoding == 'ash': 72 | from ashencoder import AshEncoder 73 | encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) 74 | 75 | else: 76 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') 77 | 78 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dngaussian 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.3 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | 16 | - pip: 17 | - ./gridencoder 18 | - pillow 19 | - tensorboard 20 | - opencv-python 21 | - timm 22 | - scikit-image 23 | - imageio 24 | - matplotlib 25 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from scene.gaussian_model_sh import GaussianModelSH 17 | from utils.sh_utils import eval_sh 18 | from utils.graphics_utils import fov2focal 19 | 20 | 21 | def render_neural(viewpoint_camera, pc : GaussianModel): 22 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_xyz.shape[0], 1)) 23 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 24 | sigma, color = pc.neural_renderer(pc.get_xyz.cuda(), dir_pp_normalized.cuda()) 25 | # opacity = 1 - torch.exp(-sigma.view(-1, 1)) 26 | opacity = sigma.view(-1, 1) 27 | return pc.combine_opacity(opacity), color 28 | 29 | 30 | def mip_scales(viewpoint_camera, pc : GaussianModel): 31 | dist_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_xyz.shape[0], 1)).norm(dim=1, keepdim=True) 32 | focal = fov2focal(viewpoint_camera.FoVx, viewpoint_camera.image_width) 33 | return dist_pp / focal 34 | 35 | 36 | 37 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, inference=False): 38 | """ 39 | Render the scene. 40 | 41 | Background tensor (bg_color) must be on GPU! 42 | """ 43 | 44 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 45 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 46 | try: 47 | screenspace_points.retain_grad() 48 | except: 49 | pass 50 | 51 | # Set up rasterization configuration 52 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 53 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 54 | 55 | raster_settings = GaussianRasterizationSettings( 56 | image_height=int(viewpoint_camera.image_height), 57 | image_width=int(viewpoint_camera.image_width), 58 | tanfovx=tanfovx, 59 | tanfovy=tanfovy, 60 | bg=bg_color, 61 | scale_modifier=scaling_modifier, 62 | viewmatrix=viewpoint_camera.world_view_transform, 63 | projmatrix=viewpoint_camera.full_proj_transform, 64 | sh_degree=pc.active_sh_degree, 65 | campos=viewpoint_camera.camera_center, 66 | prefiltered=False, 67 | debug=pipe.debug 68 | ) 69 | 70 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 71 | 72 | means3D = pc.get_xyz 73 | means2D = screenspace_points 74 | # opacity = pc.get_opacity 75 | 76 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 77 | # scaling / rotation by the rasterizer. 78 | scales = None 79 | rotations = None 80 | cov3D_precomp = None 81 | if pipe.compute_cov3D_python: 82 | cov3D_precomp = pc.get_covariance(scaling_modifier) 83 | else: 84 | scales = pc.get_scaling 85 | rotations = pc.get_rotation 86 | 87 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 88 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
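# In this renderer the per-Gaussian opacity and color are predicted by the neural renderer (see render_neural above) instead of SH coefficients, so `shs` stays None and precomputed colors are passed to the rasterizer.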
89 | shs = None 90 | colors_precomp = None 91 | 92 | # pre 93 | opacity, colors_precomp = render_neural(viewpoint_camera, pc) 94 | if inference: 95 | opacity = pc.get_opacity_ 96 | 97 | # sh 98 | # opacity = pc.get_opacity 99 | # shs = pc.get_features 100 | 101 | # Ashawkey version 102 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 103 | means3D=means3D, 104 | means2D=means2D, 105 | shs=shs, 106 | colors_precomp=colors_precomp, 107 | opacities=opacity, 108 | scales=scales, 109 | rotations=rotations, 110 | cov3D_precomp=cov3D_precomp, 111 | ) 112 | 113 | 114 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 115 | # They will be excluded from value updates used in the splitting criteria. 116 | return {"render": rendered_image, 117 | "depth": rendered_depth, 118 | "alpha": rendered_alpha, 119 | "viewspace_points": screenspace_points, 120 | "visibility_filter" : radii > 0, 121 | "radii": radii, 122 | "opacity": opacity, 123 | "color": colors_precomp} 124 | 125 | 126 | 127 | 128 | def render_for_depth(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, value=0.95): 129 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 130 | try: 131 | screenspace_points.retain_grad() 132 | except: 133 | pass 134 | 135 | # Set up rasterization configuration 136 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 137 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 138 | 139 | raster_settings = GaussianRasterizationSettings( 140 | image_height=int(viewpoint_camera.image_height), 141 | image_width=int(viewpoint_camera.image_width), 142 | tanfovx=tanfovx, 143 | tanfovy=tanfovy, 144 | bg=bg_color, 145 | scale_modifier=scaling_modifier, 146 | viewmatrix=viewpoint_camera.world_view_transform, 147 | projmatrix=viewpoint_camera.full_proj_transform, 148 | sh_degree=pc.active_sh_degree, 149 | campos=viewpoint_camera.camera_center, 150 | prefiltered=False, 151 | debug=pipe.debug 152 | ) 153 | 154 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 155 | 156 | means3D = pc.get_xyz 157 | means2D = screenspace_points 158 | # opacity = pc.get_opacity 159 | opacity = torch.ones(pc.get_xyz.shape[0], 1, device=pc.get_xyz.device) * value 160 | 161 | with torch.no_grad(): 162 | scales = None 163 | rotations = None 164 | cov3D_precomp = None 165 | if pipe.compute_cov3D_python: 166 | cov3D_precomp = pc.get_covariance(scaling_modifier) 167 | else: 168 | scales = pc.get_scaling.detach() 169 | rotations = pc.get_rotation.detach() 170 | 171 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 172 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 173 | shs = None 174 | colors_precomp = torch.ones_like(pc.get_xyz) 175 | 176 | 177 | # Ashawkey version 178 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 179 | means3D=means3D, 180 | means2D=means2D, 181 | shs=shs, 182 | colors_precomp=colors_precomp, 183 | opacities=opacity, 184 | scales=scales, 185 | rotations=rotations, 186 | cov3D_precomp=cov3D_precomp, 187 | ) 188 | 189 | 190 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 191 | # They will be excluded from value updates used in the splitting criteria. 
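# Note that render_for_depth rasterizes with a constant opacity (`value`) and all-ones colors, with scales and rotations detached, so only the returned depth and alpha maps are meaningful here.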
192 | return {"render": rendered_image, 193 | "depth": rendered_depth, 194 | "alpha": rendered_alpha, 195 | "viewspace_points": screenspace_points, 196 | "visibility_filter" : radii > 0, 197 | "radii": radii} 198 | 199 | 200 | 201 | def render_for_opa(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0): 202 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 203 | try: 204 | screenspace_points.retain_grad() 205 | except: 206 | pass 207 | 208 | # Set up rasterization configuration 209 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 210 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 211 | 212 | raster_settings = GaussianRasterizationSettings( 213 | image_height=int(viewpoint_camera.image_height), 214 | image_width=int(viewpoint_camera.image_width), 215 | tanfovx=tanfovx, 216 | tanfovy=tanfovy, 217 | bg=bg_color, 218 | scale_modifier=scaling_modifier, 219 | viewmatrix=viewpoint_camera.world_view_transform, 220 | projmatrix=viewpoint_camera.full_proj_transform, 221 | sh_degree=pc.active_sh_degree, 222 | campos=viewpoint_camera.camera_center, 223 | prefiltered=False, 224 | debug=pipe.debug 225 | ) 226 | 227 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 228 | 229 | means3D = pc.get_xyz.detach() 230 | means2D = screenspace_points 231 | opacity = pc.get_opacity 232 | 233 | scales = None 234 | rotations = None 235 | cov3D_precomp = None 236 | if pipe.compute_cov3D_python: 237 | cov3D_precomp = pc.get_covariance(scaling_modifier) 238 | else: 239 | scales = pc.get_scaling.detach() 240 | rotations = pc.get_rotation.detach() 241 | 242 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 243 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 244 | shs = None 245 | colors_precomp = torch.ones_like(pc.get_xyz) 246 | 247 | 248 | # Ashawkey version 249 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 250 | means3D=means3D, 251 | means2D=means2D, 252 | shs=shs, 253 | colors_precomp=colors_precomp, 254 | opacities=opacity, 255 | scales=scales, 256 | rotations=rotations, 257 | cov3D_precomp=cov3D_precomp, 258 | ) 259 | 260 | 261 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 262 | # They will be excluded from value updates used in the splitting criteria. 263 | return {"render": rendered_image, 264 | "depth": rendered_depth, 265 | "alpha": rendered_alpha, 266 | "viewspace_points": screenspace_points, 267 | "visibility_filter" : radii > 0, 268 | "radii": radii, 269 | "opacity": opacity} 270 | 271 | 272 | 273 | 274 | #----------- for SH 275 | 276 | 277 | 278 | 279 | def render_sh(viewpoint_camera, pc : GaussianModelSH, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, inference=False): 280 | """ 281 | Render the scene. 282 | 283 | Background tensor (bg_color) must be on GPU! 284 | """ 285 | 286 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 287 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 288 | try: 289 | screenspace_points.retain_grad() 290 | except: 291 | pass 292 | 293 | # Set up rasterization configuration 294 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 295 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 296 | 297 | raster_settings = GaussianRasterizationSettings( 298 | image_height=int(viewpoint_camera.image_height), 299 | image_width=int(viewpoint_camera.image_width), 300 | tanfovx=tanfovx, 301 | tanfovy=tanfovy, 302 | bg=bg_color, 303 | scale_modifier=scaling_modifier, 304 | viewmatrix=viewpoint_camera.world_view_transform, 305 | projmatrix=viewpoint_camera.full_proj_transform, 306 | sh_degree=pc.active_sh_degree, 307 | campos=viewpoint_camera.camera_center, 308 | prefiltered=False, 309 | debug=pipe.debug 310 | ) 311 | 312 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 313 | 314 | means3D = pc.get_xyz 315 | means2D = screenspace_points 316 | # opacity = pc.get_opacity 317 | 318 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 319 | # scaling / rotation by the rasterizer. 320 | scales = None 321 | rotations = None 322 | cov3D_precomp = None 323 | if pipe.compute_cov3D_python: 324 | cov3D_precomp = pc.get_covariance(scaling_modifier) 325 | else: 326 | scales = pc.get_scaling 327 | rotations = pc.get_rotation 328 | 329 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 330 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 331 | shs = None 332 | colors_precomp = None 333 | 334 | # sh 335 | opacity = pc.get_opacity 336 | shs = pc.get_features 337 | 338 | # Ashawkey version 339 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 340 | means3D=means3D, 341 | means2D=means2D, 342 | shs=shs, 343 | colors_precomp=colors_precomp, 344 | opacities=opacity, 345 | scales=scales, 346 | rotations=rotations, 347 | cov3D_precomp=cov3D_precomp, 348 | ) 349 | 350 | 351 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 352 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 353 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 354 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 355 | color = torch.clamp_min(sh2rgb + 0.5, 0.0) 356 | 357 | 358 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 359 | # They will be excluded from value updates used in the splitting criteria. 
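# Besides the rasterized outputs, render_sh decodes per-Gaussian RGB from the SH features via eval_sh and returns it as "color", mirroring the neural-renderer variant of render above.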
360 | return {"render": rendered_image, 361 | "depth": rendered_depth, 362 | "alpha": rendered_alpha, 363 | "viewspace_points": screenspace_points, 364 | "visibility_filter" : radii > 0, 365 | "radii": radii, 366 | "opacity": opacity, 367 | "color": color} 368 | 369 | 370 | 371 | 372 | def render_for_depth_sh(viewpoint_camera, pc : GaussianModelSH, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, value=0.95): 373 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 374 | try: 375 | screenspace_points.retain_grad() 376 | except: 377 | pass 378 | 379 | # Set up rasterization configuration 380 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 381 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 382 | 383 | raster_settings = GaussianRasterizationSettings( 384 | image_height=int(viewpoint_camera.image_height), 385 | image_width=int(viewpoint_camera.image_width), 386 | tanfovx=tanfovx, 387 | tanfovy=tanfovy, 388 | bg=bg_color, 389 | scale_modifier=scaling_modifier, 390 | viewmatrix=viewpoint_camera.world_view_transform, 391 | projmatrix=viewpoint_camera.full_proj_transform, 392 | sh_degree=pc.active_sh_degree, 393 | campos=viewpoint_camera.camera_center, 394 | prefiltered=False, 395 | debug=pipe.debug 396 | ) 397 | 398 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 399 | 400 | means3D = pc.get_xyz 401 | means2D = screenspace_points 402 | # opacity = pc.get_opacity 403 | opacity = torch.ones(pc.get_xyz.shape[0], 1, device=pc.get_xyz.device) * value 404 | 405 | with torch.no_grad(): 406 | scales = None 407 | rotations = None 408 | cov3D_precomp = None 409 | if pipe.compute_cov3D_python: 410 | cov3D_precomp = pc.get_covariance(scaling_modifier) 411 | else: 412 | scales = pc.get_scaling.detach() 413 | rotations = pc.get_rotation.detach() 414 | 415 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 416 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 417 | shs = None 418 | colors_precomp = torch.ones_like(pc.get_xyz) 419 | 420 | 421 | # Ashawkey version 422 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 423 | means3D=means3D, 424 | means2D=means2D, 425 | shs=shs, 426 | colors_precomp=colors_precomp, 427 | opacities=opacity, 428 | scales=scales, 429 | rotations=rotations, 430 | cov3D_precomp=cov3D_precomp, 431 | ) 432 | 433 | 434 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 435 | # They will be excluded from value updates used in the splitting criteria. 
436 | return {"render": rendered_image, 437 | "depth": rendered_depth, 438 | "alpha": rendered_alpha, 439 | "viewspace_points": screenspace_points, 440 | "visibility_filter" : radii > 0, 441 | "radii": radii} 442 | 443 | 444 | 445 | def render_for_opa_sh(viewpoint_camera, pc : GaussianModelSH, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0): 446 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 447 | try: 448 | screenspace_points.retain_grad() 449 | except: 450 | pass 451 | 452 | # Set up rasterization configuration 453 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 454 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 455 | 456 | raster_settings = GaussianRasterizationSettings( 457 | image_height=int(viewpoint_camera.image_height), 458 | image_width=int(viewpoint_camera.image_width), 459 | tanfovx=tanfovx, 460 | tanfovy=tanfovy, 461 | bg=bg_color, 462 | scale_modifier=scaling_modifier, 463 | viewmatrix=viewpoint_camera.world_view_transform, 464 | projmatrix=viewpoint_camera.full_proj_transform, 465 | sh_degree=pc.active_sh_degree, 466 | campos=viewpoint_camera.camera_center, 467 | prefiltered=False, 468 | debug=pipe.debug 469 | ) 470 | 471 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 472 | 473 | means3D = pc.get_xyz.detach() 474 | means2D = screenspace_points 475 | opacity = pc.get_opacity 476 | 477 | scales = None 478 | rotations = None 479 | cov3D_precomp = None 480 | if pipe.compute_cov3D_python: 481 | cov3D_precomp = pc.get_covariance(scaling_modifier) 482 | else: 483 | scales = pc.get_scaling.detach() 484 | rotations = pc.get_rotation.detach() 485 | 486 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 487 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 488 | shs = None 489 | colors_precomp = torch.ones_like(pc.get_xyz) 490 | 491 | 492 | # Ashawkey version 493 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 494 | means3D=means3D, 495 | means2D=means2D, 496 | shs=shs, 497 | colors_precomp=colors_precomp, 498 | opacities=opacity, 499 | scales=scales, 500 | rotations=rotations, 501 | cov3D_precomp=cov3D_precomp, 502 | ) 503 | 504 | 505 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 506 | # They will be excluded from value updates used in the splitting criteria. 
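    # This variant makes the opposite choice: means3D, scales and rotations are all detached
    # and colors_precomp is fixed to ones, so opacity is the only Gaussian attribute that can
    # receive gradients through this render.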
507 | return {"render": rendered_image, 508 | "depth": rendered_depth, 509 | "alpha": rendered_alpha, 510 | "viewspace_points": screenspace_points, 511 | "visibility_filter" : radii > 0, 512 | "radii": radii, 513 | "opacity": opacity} 514 | 515 | 516 | -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | 'hash': 0, 16 | 'tiled': 1, 17 | } 18 | 19 | _interp_to_id = { 20 | 'linear': 0, 21 | 'smoothstep': 1, 22 | } 23 | 24 | class _grid_encode(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): 28 | # inputs: [B, D], float in [0, 1] 29 | # embeddings: [sO, C], float 30 | # offsets: [L + 1], int 31 | # RETURN: [B, F], float 32 | 33 | inputs = inputs.contiguous() 34 | 35 | B, D = inputs.shape # batch size, coord dim 36 | L = offsets.shape[0] - 1 # level 37 | C = embeddings.shape[1] # embedding dim for each level 38 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 39 | H = base_resolution # base resolution 40 | 41 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 42 | # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
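        # With the GridEncoder defaults defined later in this file (L = 16 levels, C = 2
        # channels per level), the kernel fills an [L, B, C] buffer that is then permuted
        # and reshaped to [B, L * C], i.e. [B, 32].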
43 | if torch.is_autocast_enabled() and C % 2 == 0: 44 | embeddings = embeddings.to(torch.half) 45 | 46 | # L first, optimize cache for cuda kernel, but needs an extra permute later 47 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 48 | 49 | if calc_grad_inputs: 50 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) 51 | else: 52 | dy_dx = None 53 | 54 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) 55 | 56 | # permute back to [B, L * C] 57 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 58 | 59 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 60 | ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] 61 | ctx.align_corners = align_corners 62 | 63 | return outputs 64 | 65 | @staticmethod 66 | #@once_differentiable 67 | @custom_bwd 68 | def backward(ctx, grad): 69 | 70 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 71 | B, D, C, L, S, H, gridtype, interpolation = ctx.dims 72 | align_corners = ctx.align_corners 73 | 74 | # grad: [B, L * C] --> [L, B, C] 75 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 76 | 77 | grad_embeddings = torch.zeros_like(embeddings) 78 | 79 | if dy_dx is not None: 80 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 81 | else: 82 | grad_inputs = None 83 | 84 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) 85 | 86 | if dy_dx is not None: 87 | grad_inputs = grad_inputs.to(inputs.dtype) 88 | 89 | return grad_inputs, grad_embeddings, None, None, None, None, None, None, None 90 | 91 | 92 | 93 | grid_encode = _grid_encode.apply 94 | 95 | 96 | class GridEncoder(nn.Module): 97 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): 98 | super().__init__() 99 | 100 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 101 | if desired_resolution is not None: 102 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) 103 | 104 | self.input_dim = input_dim # coord dims, 2 or 3 105 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 106 | self.level_dim = level_dim # encode channels per level 107 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
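        # e.g. with the constructor defaults (num_levels=16, base_resolution=16,
        # per_level_scale=2) the level-i grid resolution is ceil(16 * 2 ** i):
        # 16 at the coarsest level and 524288 at the finest.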
108 | self.log2_hashmap_size = log2_hashmap_size 109 | self.base_resolution = base_resolution 110 | self.output_dim = num_levels * level_dim 111 | self.gridtype = gridtype 112 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 113 | self.interpolation = interpolation 114 | self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" 115 | self.align_corners = align_corners 116 | 117 | # allocate parameters 118 | offsets = [] 119 | offset = 0 120 | self.max_params = 2 ** log2_hashmap_size 121 | for i in range(num_levels): 122 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 123 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 124 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 125 | offsets.append(offset) 126 | offset += params_in_level 127 | offsets.append(offset) 128 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 129 | self.register_buffer('offsets', offsets) 130 | 131 | self.n_params = offsets[-1] * level_dim 132 | 133 | # parameters 134 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 135 | 136 | self.reset_parameters() 137 | 138 | def reset_parameters(self): 139 | std = 1e-4 140 | self.embeddings.data.uniform_(-std, std) 141 | 142 | def __repr__(self): 143 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" 144 | 145 | def forward(self, inputs, bound=1): 146 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 147 | # return: [..., num_levels * level_dim] 148 | 149 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 150 | 151 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 152 | 153 | prefix_shape = list(inputs.shape[:-1]) 154 | inputs = inputs.view(-1, self.input_dim) 155 | 156 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) 157 | outputs = outputs.view(prefix_shape + [self.output_dim]) 158 | 159 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 160 | 161 | return outputs 162 | 163 | # always run in float precision! 164 | @torch.cuda.amp.autocast(enabled=False) 165 | def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): 166 | # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. 
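        # Intended call pattern (a sketch; `encoder`, `loss` and `optimizer` are placeholder
        # names, not objects defined in this file):
        #     loss.backward()
        #     encoder.grad_total_variation(weight=1e-7)  # accumulates the TV gradient into embeddings.grad
        #     optimizer.step()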
167 | 168 | D = self.input_dim 169 | C = self.embeddings.shape[1] # embedding dim for each level 170 | L = self.offsets.shape[0] - 1 # level 171 | S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 172 | H = self.base_resolution # base resolution 173 | 174 | if inputs is None: 175 | # randomized in [0, 1] 176 | inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) 177 | else: 178 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 179 | inputs = inputs.view(-1, self.input_dim) 180 | B = inputs.shape[0] 181 | 182 | if self.embeddings.grad is None: 183 | raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') 184 | 185 | _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) -------------------------------------------------------------------------------- /gridencoder/gridencoder.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: gridencoder 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /gridencoder/gridencoder.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /disk2/lijiahe/dngaussian-code/gridencoder/src/bindings.cpp 3 | /disk2/lijiahe/dngaussian-code/gridencoder/src/gridencoder.cu 4 | gridencoder.egg-info/PKG-INFO 5 | gridencoder.egg-info/SOURCES.txt 6 | gridencoder.egg-info/dependency_links.txt 7 | gridencoder.egg-info/top_level.txt -------------------------------------------------------------------------------- /gridencoder/gridencoder.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /gridencoder/gridencoder.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | _gridencoder 2 | -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
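# The MSVC (cl.exe) lookup below mirrors the one in gridencoder/backend.py and only matters
# when building on Windows. Building follows the usual setuptools flow (for example
# `pip install .` or `python setup.py install` from this directory; an assumed invocation,
# not one quoted from this repo's scripts), after which the compiled module imports as
# `_gridencoder`.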
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | 17 | #endif -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
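
    Returns:
        torch.Tensor: the LPIPS distance between `x` and `y`, summed over the batch.

    Example (a minimal sketch; the 1x3x64x64 shapes are purely illustrative):
        >>> x = torch.rand(1, 3, 64, 64)
        >>> y = torch.rand(1, 3, 64, 64)
        >>> d = lpips(x, y, net_type='vgg')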
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | 19 | from skimage.metrics import structural_similarity 20 | 21 | 22 | from lpipsPyTorch import lpips 23 | import json 24 | from tqdm import tqdm 25 | from utils.image_utils import psnr 26 | from argparse import ArgumentParser 27 | 28 | def readImages(renders_dir, gt_dir): 29 | renders = [] 30 | gts = [] 31 | image_names = [] 32 | for fname in os.listdir(renders_dir): 33 | render = Image.open(renders_dir / fname) 34 | gt = Image.open(gt_dir / fname) 35 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 36 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 37 | image_names.append(fname) 38 | return renders, gts, image_names 39 | 40 | def evaluate(model_paths): 41 | 42 | full_dict = {} 43 | per_view_dict = {} 44 | full_dict_polytopeonly = {} 45 | per_view_dict_polytopeonly = {} 46 | print("") 47 | 48 | for scene_dir in model_paths: 49 | print("Scene:", scene_dir) 50 | full_dict[scene_dir] = {} 51 | per_view_dict[scene_dir] = {} 52 | full_dict_polytopeonly[scene_dir] = {} 53 | per_view_dict_polytopeonly[scene_dir] = {} 54 | 55 | # for acc.. 56 | test_dir = Path(scene_dir) / "eval" 57 | eval_dir = Path(scene_dir) / "eval" 58 | 59 | for test_dir in [eval_dir]: 60 | dataset = test_dir.stem 61 | for method in os.listdir(test_dir): 62 | print("Method:", method, dataset) 63 | 64 | full_dict[scene_dir][method] = {} 65 | per_view_dict[scene_dir][method] = {} 66 | full_dict_polytopeonly[scene_dir][method] = {} 67 | per_view_dict_polytopeonly[scene_dir][method] = {} 68 | 69 | method_dir = test_dir / method 70 | gt_dir = method_dir/ "gt" 71 | renders_dir = method_dir / "renders" 72 | renders, gts, image_names = readImages(renders_dir, gt_dir) 73 | 74 | ssims = [] 75 | ssims_sk = [] 76 | psnrs = [] 77 | lpipss = [] 78 | 79 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress", ascii=True, dynamic_ncols=True): 80 | ssims.append(ssim(renders[idx], gts[idx])) 81 | ssims_sk.append(structural_similarity(renders[idx][0].permute(1,2,0).cpu().numpy(), gts[idx][0].permute(1,2,0).cpu().numpy(), multichannel=True, channel_axis=2 ,data_range=1.0)) 82 | psnrs.append(psnr(renders[idx], gts[idx])) 83 | 84 | # Following previous works to keep the range of RGB in [0, 1]. 
(however, may be a mistake : https://github.com/richzhang/PerceptualSimilarity) 85 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 86 | 87 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 88 | print(" SSIM_sk : {:>12.7f}".format(torch.tensor(ssims_sk).mean(), ".5")) 89 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 90 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 91 | print("") 92 | 93 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 94 | "SSIM_sk": torch.tensor(ssims_sk).mean().item(), 95 | "PSNR": torch.tensor(psnrs).mean().item(), 96 | "LPIPS": torch.tensor(lpipss).mean().item()}) 97 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 98 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 99 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 100 | 101 | with open(scene_dir + "/results_{}.json".format(dataset), 'w') as fp: 102 | json.dump(full_dict[scene_dir], fp, indent=True) 103 | with open(scene_dir + "/per_view_{}.json".format(dataset), 'w') as fp: 104 | json.dump(per_view_dict[scene_dir], fp, indent=True) 105 | 106 | 107 | if __name__ == "__main__": 108 | device = torch.device("cuda:0") 109 | torch.cuda.set_device(device) 110 | 111 | # Set up command line argument parser 112 | parser = ArgumentParser(description="Training script parameters") 113 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 114 | args = parser.parse_args() 115 | evaluate(args.model_paths) 116 | -------------------------------------------------------------------------------- /metrics_count.py: -------------------------------------------------------------------------------- 1 | # A tool to quickly count the mean metrics of one dir 2 | # usage: 3 | # $ python metrics_count.py output/ 6000 4 | 5 | import os 6 | import json 7 | 8 | import numpy as np 9 | import sys 10 | 11 | dataset_path = sys.argv[1] 12 | model_id = "ours_" + sys.argv[2] 13 | 14 | 15 | ssims_sk = [] 16 | ssims_gs = [] 17 | psnrs = [] 18 | lpipss = [] 19 | avgs = [] 20 | 21 | 22 | def psnr_to_mse(psnr): 23 | """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" 24 | return np.exp(-0.1 * np.log(10.) * psnr) 25 | 26 | def compute_avg_error(psnr, ssim, lpips): 27 | """The 'average' error used in the paper.""" 28 | mse = psnr_to_mse(psnr) 29 | dssim = np.sqrt(1 - ssim) 30 | return np.exp(np.mean(np.log(np.array([mse, dssim, lpips])))) 31 | 32 | 33 | for fname in os.listdir(dataset_path): 34 | 35 | with open(os.path.join(dataset_path, fname, 'results_eval.json')) as f: 36 | result=json.load(f) 37 | ssims_sk.append(result[model_id]["SSIM_sk"]) 38 | ssims_gs.append(result[model_id]["SSIM"]) 39 | psnrs.append(result[model_id]["PSNR"]) 40 | lpipss.append(result[model_id]["LPIPS"]) 41 | avgs.append(compute_avg_error(psnrs[-1], ssims_sk[-1], lpipss[-1])) 42 | 43 | # print(np.mean(psnrs), np.mean(lpipss), np.mean(ssims_sk), np.mean(ssims_gs), np.mean(avgs)) 44 | print(np.mean(psnrs), np.mean(lpipss), np.mean(ssims_sk), np.mean(avgs)) -------------------------------------------------------------------------------- /metrics_dtu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision 17 | import torchvision.transforms.functional as tf 18 | from utils.loss_utils import ssim 19 | 20 | from skimage.metrics import structural_similarity 21 | 22 | from lpipsPyTorch import lpips 23 | import json 24 | from tqdm import tqdm 25 | from utils.image_utils import psnr 26 | from argparse import ArgumentParser 27 | 28 | def readImages(renders_dir, gt_dir, mask_dir): 29 | renders = [] 30 | gts = [] 31 | masks = [] 32 | image_names = [] 33 | for fname in os.listdir(renders_dir): 34 | render = Image.open(renders_dir / fname) 35 | gt = Image.open(gt_dir / fname) 36 | mask = Image.open(mask_dir / fname) 37 | mask = mask.resize((mask.size[0] // 4, mask.size[1] // 4)) 38 | render = render.resize(mask.size) 39 | gt = gt.resize(mask.size) 40 | mask = tf.to_tensor(mask).unsqueeze(0)[:, :3, :, :].cuda() 41 | mask_bin = (mask == 1.) 42 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda() * mask + (1-mask)) 43 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda() * mask + (1-mask)) 44 | masks.append(mask_bin) 45 | image_names.append(fname) 46 | return renders, gts, image_names, masks 47 | 48 | 49 | def evaluate(model_paths): 50 | 51 | full_dict = {} 52 | per_view_dict = {} 53 | full_dict_polytopeonly = {} 54 | per_view_dict_polytopeonly = {} 55 | print("") 56 | 57 | for scene_dir in model_paths: 58 | print("Scene:", scene_dir) 59 | full_dict[scene_dir] = {} 60 | per_view_dict[scene_dir] = {} 61 | full_dict_polytopeonly[scene_dir] = {} 62 | per_view_dict_polytopeonly[scene_dir] = {} 63 | 64 | test_dir = Path(scene_dir) / "eval" 65 | 66 | for test_dir in [test_dir]: 67 | dataset = test_dir.stem 68 | for method in os.listdir(test_dir): 69 | print("Method:", method, dataset) 70 | 71 | full_dict[scene_dir][method] = {} 72 | per_view_dict[scene_dir][method] = {} 73 | full_dict_polytopeonly[scene_dir][method] = {} 74 | per_view_dict_polytopeonly[scene_dir][method] = {} 75 | 76 | method_dir = test_dir / method 77 | mask_dir = Path(scene_dir) / "mask" 78 | gt_dir = method_dir/ "gt" 79 | renders_dir = method_dir / "renders" 80 | renders, gts, image_names, masks = readImages(renders_dir, gt_dir, mask_dir) 81 | 82 | os.makedirs(mask_dir / "masked", exist_ok=True) 83 | 84 | for idx, img in enumerate(tqdm(renders, desc="save", ascii=True, dynamic_ncols=True)): 85 | torchvision.utils.save_image(img, os.path.join(mask_dir / "masked", '{0:05d}'.format(idx) + ".png")) 86 | 87 | ssims = [] 88 | ssims_sk = [] 89 | psnrs = [] 90 | lpipss = [] 91 | 92 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress", ascii=True, dynamic_ncols=True): 93 | ssims.append(ssim(renders[idx], gts[idx])) 94 | ssims_sk.append(structural_similarity(renders[idx][0].permute(1,2,0).cpu().numpy(), gts[idx][0].permute(1,2,0).cpu().numpy(), multichannel=True, channel_axis=2, data_range=1.0)) 95 | psnrs.append(psnr(renders[idx][masks[idx]][None, ...], gts[idx][masks[idx]][None, ...])) 96 | 97 | # Following previous works to keep the range of RGB in [0, 1]. 
(however, may be a mistake : https://github.com/richzhang/PerceptualSimilarity) 98 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 99 | 100 | 101 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 102 | print(" SSIM_sk : {:>12.7f}".format(torch.tensor(ssims_sk).mean(), ".5")) 103 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 104 | print(" LPIPS : {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 105 | print("") 106 | 107 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 108 | "SSIM_sk": torch.tensor(ssims_sk).mean().item(), 109 | "PSNR": torch.tensor(psnrs).mean().item(), 110 | "LPIPS": torch.tensor(lpipss).mean().item()}) 111 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 112 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 113 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 114 | 115 | with open(scene_dir + "/results_{}_mask.json".format(dataset), 'w') as fp: 116 | json.dump(full_dict[scene_dir], fp, indent=True) 117 | with open(scene_dir + "/per_view_{}_mask.json".format(dataset), 'w') as fp: 118 | json.dump(per_view_dict[scene_dir], fp, indent=True) 119 | 120 | 121 | if __name__ == "__main__": 122 | device = torch.device("cuda:0") 123 | torch.cuda.set_device(device) 124 | 125 | # Set up command line argument parser 126 | parser = ArgumentParser(description="Training script parameters") 127 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 128 | args = parser.parse_args() 129 | evaluate(args.model_paths) 130 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | 24 | 25 | import numpy as np 26 | import matplotlib.cm as cm 27 | 28 | 29 | def weighted_percentile(x, w, ps, assume_sorted=False): 30 | """Compute the weighted percentile(s) of a single vector.""" 31 | x = x.reshape([-1]) 32 | w = w.reshape([-1]) 33 | if not assume_sorted: 34 | sortidx = np.argsort(x) 35 | x, w = x[sortidx], w[sortidx] 36 | acc_w = np.cumsum(w) 37 | return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x) 38 | 39 | def visualize_cmap(value, 40 | weight, 41 | colormap, 42 | lo=None, 43 | hi=None, 44 | percentile=99., 45 | curve_fn=lambda x: x, 46 | modulus=None, 47 | matte_background=True): 48 | """Visualize a 1D image and a 1D weighting according to some colormap. 49 | 50 | Args: 51 | value: A 1D image. 52 | weight: A weight map, in [0, 1]. 53 | colormap: A colormap function. 54 | lo: The lower bound to use when rendering, if None then use a percentile. 
55 | hi: The upper bound to use when rendering, if None then use a percentile. 56 | percentile: What percentile of the value map to crop to when automatically 57 | generating `lo` and `hi`. Depends on `weight` as well as `value'. 58 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 59 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 60 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If 61 | `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. 62 | matte_background: If True, matte the image over a checkerboard. 63 | 64 | Returns: 65 | A colormap rendering. 66 | """ 67 | # Identify the values that bound the middle of `value' according to `weight`. 68 | lo_auto, hi_auto = weighted_percentile( 69 | value, weight, [50 - percentile / 2, 50 + percentile / 2]) 70 | 71 | # If `lo` or `hi` are None, use the automatically-computed bounds above. 72 | eps = np.finfo(np.float32).eps 73 | lo = lo or (lo_auto - eps) 74 | hi = hi or (hi_auto + eps) 75 | 76 | # Curve all values. 77 | value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] 78 | 79 | # Wrap the values around if requested. 80 | if modulus: 81 | value = np.mod(value, modulus) / modulus 82 | else: 83 | # Otherwise, just scale to [0, 1]. 84 | value = np.nan_to_num( 85 | np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)) 86 | 87 | if colormap: 88 | colorized = colormap(value)[:, :, :3] 89 | else: 90 | assert len(value.shape) == 3 and value.shape[-1] == 3 91 | colorized = value 92 | 93 | return colorized 94 | 95 | depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) 96 | 97 | 98 | 99 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, near=0): 100 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 101 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 102 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 103 | 104 | makedirs(render_path, exist_ok=True) 105 | makedirs(gts_path, exist_ok=True) 106 | makedirs(depth_path, exist_ok=True) 107 | 108 | if near > 0: 109 | mask_near = None 110 | for idx, view in enumerate(tqdm(views, desc="Rendering progress", ascii=True, dynamic_ncols=True)): 111 | mask_temp = (gaussians.get_xyz - view.camera_center.repeat(gaussians.get_xyz.shape[0], 1)).norm(dim=1, keepdim=True) < near 112 | mask_near = mask_near + mask_temp if mask_near is not None else mask_temp 113 | gaussians.prune_points_inference(mask_near) 114 | 115 | for idx, view in enumerate(tqdm(views, desc="Rendering progress", ascii=True, dynamic_ncols=True)): 116 | render_pkg = render(view, gaussians, pipeline, background, inference=True) 117 | rendering = render_pkg["render"] 118 | gt = view.original_image[0:3, :, :] 119 | depth = (render_pkg['depth'] - render_pkg['depth'].min()) / (render_pkg['depth'].max() - render_pkg['depth'].min()) + 1 * (1 - render_pkg["alpha"]) 120 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 121 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 122 | torchvision.utils.save_image(1 - depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 123 | torchvision.utils.save_image(render_pkg["alpha"], os.path.join(depth_path, 'alpha_{0:05d}'.format(idx) + ".png")) 124 | 125 | depth_est = depth.squeeze().cpu().numpy() 126 | depth_est = visualize_cmap(depth_est, np.ones_like(depth_est), 
cm.get_cmap('turbo'), curve_fn=depth_curve_fn).copy() 127 | depth_est = torch.as_tensor(depth_est).permute(2,0,1) 128 | torchvision.utils.save_image(depth_est, os.path.join(depth_path, 'color_{0:05d}'.format(idx) + ".png")) 129 | 130 | 131 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, near : int): 132 | with torch.no_grad(): 133 | gaussians = GaussianModel(dataset.sh_degree) 134 | (model_params, _) = torch.load(os.path.join(dataset.model_path, "chkpnt_latest.pth")) 135 | gaussians.restore(model_params) 136 | gaussians.neural_renderer.keep_sigma=True 137 | 138 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 139 | 140 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 141 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 142 | 143 | render_set(dataset.model_path, "eval", scene.loaded_iter, scene.getEvalCameras(), gaussians, pipeline, background, near) 144 | if not skip_train: 145 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 146 | 147 | if not skip_test: 148 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 149 | 150 | if __name__ == "__main__": 151 | # Set up command line argument parser 152 | parser = ArgumentParser(description="Testing script parameters") 153 | model = ModelParams(parser, sentinel=True) 154 | pipeline = PipelineParams(parser) 155 | parser.add_argument("--iteration", default=-1, type=int) 156 | parser.add_argument("--skip_train", action="store_true") 157 | parser.add_argument("--skip_test", action="store_true") 158 | parser.add_argument("--quiet", action="store_true") 159 | parser.add_argument("--near", default=0, type=int) 160 | args = get_combined_args(parser) 161 | print("Rendering " + args.model_path) 162 | 163 | # Initialize system state (RNG) 164 | safe_state(args.quiet) 165 | 166 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.near) -------------------------------------------------------------------------------- /render_sh.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render_sh 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModelSH 23 | 24 | 25 | 26 | import numpy as np 27 | import matplotlib.cm as cm 28 | 29 | 30 | def weighted_percentile(x, w, ps, assume_sorted=False): 31 | """Compute the weighted percentile(s) of a single vector.""" 32 | x = x.reshape([-1]) 33 | w = w.reshape([-1]) 34 | if not assume_sorted: 35 | sortidx = np.argsort(x) 36 | x, w = x[sortidx], w[sortidx] 37 | acc_w = np.cumsum(w) 38 | return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x) 39 | 40 | def visualize_cmap(value, 41 | weight, 42 | colormap, 43 | lo=None, 44 | hi=None, 45 | percentile=99., 46 | curve_fn=lambda x: x, 47 | modulus=None, 48 | matte_background=True): 49 | """Visualize a 1D image and a 1D weighting according to some colormap. 50 | 51 | Args: 52 | value: A 1D image. 53 | weight: A weight map, in [0, 1]. 54 | colormap: A colormap function. 55 | lo: The lower bound to use when rendering, if None then use a percentile. 56 | hi: The upper bound to use when rendering, if None then use a percentile. 57 | percentile: What percentile of the value map to crop to when automatically 58 | generating `lo` and `hi`. Depends on `weight` as well as `value'. 59 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 60 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 61 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If 62 | `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. 63 | matte_background: If True, matte the image over a checkerboard. 64 | 65 | Returns: 66 | A colormap rendering. 67 | """ 68 | # Identify the values that bound the middle of `value' according to `weight`. 69 | lo_auto, hi_auto = weighted_percentile( 70 | value, weight, [50 - percentile / 2, 50 + percentile / 2]) 71 | 72 | # If `lo` or `hi` are None, use the automatically-computed bounds above. 73 | eps = np.finfo(np.float32).eps 74 | lo = lo or (lo_auto - eps) 75 | hi = hi or (hi_auto + eps) 76 | 77 | # Curve all values. 78 | value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] 79 | 80 | # Wrap the values around if requested. 81 | if modulus: 82 | value = np.mod(value, modulus) / modulus 83 | else: 84 | # Otherwise, just scale to [0, 1]. 
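        # i.e. value <- clip((value - min(lo, hi)) / |hi - lo|, 0, 1), with NaN mapped to 0
        # so the colormap lookup below never sees invalid entries.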
85 | value = np.nan_to_num( 86 | np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)) 87 | 88 | if colormap: 89 | colorized = colormap(value)[:, :, :3] 90 | else: 91 | assert len(value.shape) == 3 and value.shape[-1] == 3 92 | colorized = value 93 | 94 | return colorized 95 | 96 | depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) 97 | 98 | 99 | 100 | 101 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 102 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 103 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 104 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 105 | 106 | makedirs(render_path, exist_ok=True) 107 | makedirs(gts_path, exist_ok=True) 108 | makedirs(depth_path, exist_ok=True) 109 | 110 | for idx, view in enumerate(tqdm(views, desc="Rendering progress", ascii=True, dynamic_ncols=True)): 111 | render_pkg = render_sh(view, gaussians, pipeline, background) 112 | rendering = render_pkg["render"] 113 | gt = view.original_image[0:3, :, :] 114 | depth = 1.0 - (render_pkg['depth'] - render_pkg['depth'].min()) / (render_pkg['depth'].max() - render_pkg['depth'].min()) 115 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 116 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 117 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 118 | torchvision.utils.save_image(render_pkg['alpha'], os.path.join(depth_path, 'aplha_{0:05d}'.format(idx) + ".png")) 119 | 120 | depth_est = (1 - depth * render_pkg["alpha"]).squeeze().cpu().numpy() 121 | depth_est = visualize_cmap(depth_est, np.ones_like(depth_est), cm.get_cmap('turbo'), curve_fn=depth_curve_fn).copy() 122 | depth_est = torch.as_tensor(depth_est).permute(2,0,1) 123 | torchvision.utils.save_image(depth_est, os.path.join(depth_path, 'color_{0:05d}'.format(idx) + ".png")) 124 | 125 | 126 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): 127 | with torch.no_grad(): 128 | gaussians = GaussianModelSH(dataset.sh_degree) 129 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 130 | 131 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 132 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 133 | 134 | if not skip_train: 135 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 136 | 137 | if not skip_test: 138 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 139 | render_set(dataset.model_path, "eval", scene.loaded_iter, scene.getEvalCameras(), gaussians, pipeline, background) 140 | 141 | if __name__ == "__main__": 142 | # Set up command line argument parser 143 | parser = ArgumentParser(description="Testing script parameters") 144 | model = ModelParams(parser, sentinel=True) 145 | pipeline = PipelineParams(parser) 146 | parser.add_argument("--iteration", default=-1, type=int) 147 | parser.add_argument("--skip_train", action="store_true") 148 | parser.add_argument("--skip_test", action="store_true") 149 | parser.add_argument("--quiet", action="store_true") 150 | args = get_combined_args(parser) 151 | print("Rendering " + args.model_path) 152 | 153 | # Initialize system state (RNG) 154 | 
safe_state(args.quiet) 155 | 156 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from scene.gaussian_model_sh import GaussianModelSH 19 | from arguments import ModelParams 20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, renderCameraList_from_camInfos 21 | 22 | class Scene: 23 | 24 | gaussians : GaussianModel 25 | 26 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 27 | """b 28 | :param path: Path to colmap scene main folder. 29 | """ 30 | self.model_path = args.model_path 31 | self.source_path = args.source_path 32 | self.loaded_iter = None 33 | self.gaussians = gaussians 34 | 35 | if load_iteration: 36 | if load_iteration == -1: 37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 38 | else: 39 | self.loaded_iter = load_iteration 40 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 41 | 42 | self.train_cameras = {} 43 | self.test_cameras = {} 44 | self.eval_cameras = {} 45 | 46 | if os.path.exists(os.path.join(args.source_path, "sparse")): 47 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.dataset, args.eval, args.rand_pcd, args.mvs_pcd, N_sparse = args.n_sparse) 48 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 49 | print("Found transforms_train.json file, assuming Blender data set!") 50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, args.rand_pcd, N_sparse = args.n_sparse) 51 | else: 52 | assert False, "Could not recognize scene type!" 
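        # Whichever branch ran, `scene_info` now holds the initial point cloud, the
        # train/test/eval camera splits and the NeRF-style normalization radius used below.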
53 | 54 | if not self.loaded_iter: 55 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 56 | dest_file.write(src_file.read()) 57 | json_cams = [] 58 | camlist = [] 59 | if scene_info.test_cameras: 60 | camlist.extend(scene_info.test_cameras) 61 | if scene_info.train_cameras: 62 | camlist.extend(scene_info.train_cameras) 63 | if scene_info.eval_cameras: 64 | camlist.extend(scene_info.eval_cameras) 65 | for id, cam in enumerate(camlist): 66 | json_cams.append(camera_to_JSON(id, cam)) 67 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 68 | json.dump(json_cams, file) 69 | 70 | if shuffle: 71 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 72 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 73 | random.shuffle(scene_info.eval_cameras) # Multi-res consistent random shuffling 74 | 75 | self.cameras_extent = scene_info.nerf_normalization["radius"] 76 | 77 | for resolution_scale in resolution_scales: 78 | print("Loading Training Cameras", resolution_scales) 79 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 80 | print("Loading Test Cameras", resolution_scales) 81 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 82 | print("Loading Eval Cameras", resolution_scales) 83 | self.eval_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.eval_cameras, resolution_scale, args) 84 | 85 | if self.loaded_iter: 86 | self.gaussians.load_ply(os.path.join(self.model_path, 87 | "point_cloud", 88 | "iteration_" + str(self.loaded_iter), 89 | "point_cloud.ply")) 90 | else: 91 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 92 | 93 | def save(self, iteration, color=None): 94 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 95 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 96 | if color is not None: 97 | self.gaussians.save_ply_color(os.path.join(point_cloud_path, "point_cloud_color.ply"), color) 98 | 99 | def getTrainCameras(self, scale=1.0): 100 | return self.train_cameras[scale] 101 | 102 | def getTestCameras(self, scale=1.0): 103 | return self.test_cameras[scale] 104 | 105 | def getEvalCameras(self, scale=1.0): 106 | return self.eval_cameras[scale] 107 | 108 | 109 | 110 | 111 | class RenderScene: 112 | 113 | gaussians : GaussianModel 114 | 115 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, spiral=True, resolution_scales=[1.0]): 116 | """b 117 | :param path: Path to colmap scene main folder. 
118 | """ 119 | self.model_path = args.model_path 120 | self.loaded_iter = None 121 | self.gaussians = gaussians 122 | 123 | if load_iteration: 124 | if load_iteration == -1: 125 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 126 | else: 127 | self.loaded_iter = load_iteration 128 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 129 | 130 | self.test_cameras = {} 131 | 132 | if 'scan' in args.source_path: 133 | scene_info = sceneLoadTypeCallbacks["SpiralDTU"](args.source_path) 134 | else: 135 | scene_info = sceneLoadTypeCallbacks["Spiral"](args.source_path) 136 | 137 | self.cameras_extent = scene_info.nerf_normalization["radius"] 138 | 139 | for resolution_scale in resolution_scales: 140 | print("Loading Render Cameras", resolution_scales) 141 | self.test_cameras[resolution_scale] = renderCameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 142 | 143 | if self.loaded_iter: 144 | self.gaussians.load_ply(os.path.join(self.model_path, 145 | "point_cloud", 146 | "iteration_" + str(self.loaded_iter), 147 | "point_cloud.ply")) 148 | else: 149 | pass 150 | 151 | 152 | def getRenderCameras(self, scale=1.0): 153 | return self.test_cameras[scale] 154 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, depth_mono, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | 32 | try: 33 | self.data_device = torch.device(data_device) 34 | except Exception as e: 35 | print(e) 36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 37 | self.data_device = torch.device("cuda") 38 | if torch.cuda.is_available(): 39 | torch.cuda.set_device(self.data_device) 40 | 41 | self.depth_mono = None 42 | self.original_image = None 43 | if depth_mono is not None: 44 | self.depth_mono = depth_mono.to(self.data_device) 45 | # self.mono_scale = torch.nn.parameter.Parameter(data=torch.tensor(1.0, device=data_device), requires_grad=True) 46 | # self.mono_bias = torch.nn.parameter.Parameter(data=torch.tensor(0.0, device=data_device), requires_grad=True) 47 | # self.mono_optimizer = torch.optim.Adam(self.parameters(), lr=0.001) 48 | 49 | if image is not None: 50 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 51 | self.image_width = self.original_image.shape[2] 52 | self.image_height = self.original_image.shape[1] 53 | 54 | if gt_alpha_mask is not None: 55 | self.original_image *= gt_alpha_mask.to(self.data_device) 56 | else: 57 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 58 | 59 | 
self.zfar = 100.0 60 | self.znear = 0.01 61 | 62 | self.trans = trans 63 | self.scale = scale 64 | 65 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 66 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 67 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 68 | self.camera_center = self.world_view_transform.inverse()[3, :3] 69 | 70 | class MiniCam: 71 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 72 | self.image_width = width 73 | self.image_height = height 74 | self.FoVy = fovy 75 | self.FoVx = fovx 76 | self.znear = znear 77 | self.zfar = zfar 78 | self.world_view_transform = world_view_transform 79 | self.full_proj_transform = full_proj_transform 80 | view_inv = torch.inverse(self.world_view_transform) 81 | self.camera_center = view_inv[3][:3] 82 | 83 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * 
qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 
| """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | # assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 
| line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/neural_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | 7 | 8 | class MLP(nn.Module): 9 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers): 10 | super().__init__() 11 | self.dim_in = dim_in 12 | self.dim_out = dim_out 13 | self.dim_hidden = dim_hidden 14 | self.num_layers = num_layers 15 | 16 | net = [] 17 | for l in range(num_layers): 18 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False)) 19 | 20 | self.net = nn.ModuleList(net) 21 | 22 | def forward(self, x): 23 | for l in range(self.num_layers): 24 | x = self.net[l](x) 25 | if l != self.num_layers - 1: 26 | x = F.relu(x, inplace=True) 27 | # x = F.dropout(x, p=0.1, training=self.training) 28 | 29 | return x 30 | 31 | 32 | class GridRenderer(nn.Module): 33 | def __init__(self, 34 | bound = 1., 35 | coord_center=[0., 0., 0.], 36 | keep_sigma=False 37 | ): 38 | super().__init__() 39 | self.register_buffer('bound', torch.as_tensor(bound, dtype=torch.float32).detach()) 40 | self.register_buffer('coord_center', torch.as_tensor(coord_center, dtype=torch.float32).detach()) 41 | 42 | self.keep_sigma = keep_sigma 43 | self.sigma_results_static = None 44 | 45 | self.num_levels = 16 46 | self.level_dim = 2 47 | self.base_resolution = 16 48 | self.table_size = 19 49 | self.desired_resolution = 512 50 | self.encoder_x, self.in_dim_x = self.create_encoder() 51 | 52 | ## sigma network 53 | self.num_layers = 3 54 | self.hidden_dim = 64 55 | self.geo_feat_dim = 64 56 | self.sigma_net = MLP(self.in_dim_x, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers) 57 | ## color network 58 | self.num_layers_color = 2 59 | self.hidden_dim_color = 64 60 | self.encoder_dir, self.in_dim_dir = get_encoder('sphere_harmonics') 61 | self.color_net = MLP(self.in_dim_dir + 
self.geo_feat_dim, 3, self.hidden_dim_color, self.num_layers_color) 62 | 63 | 64 | def create_encoder(self): 65 | self.encoder_x, self.in_dim_x = get_encoder( 66 | 'hashgrid', input_dim=3, num_levels=self.num_levels, level_dim=self.level_dim, 67 | base_resolution=self.base_resolution, log2_hashmap_size=self.table_size, desired_resolution=self.desired_resolution * self.bound.cpu()) 68 | return self.encoder_x, self.in_dim_x 69 | 70 | def recover_from_ckpt(self, state_dict): 71 | self.bound = state_dict['bound'] 72 | self.encoder_x, self.in_dim_x = self.create_encoder() 73 | self.load_state_dict(state_dict) 74 | 75 | def encode_x(self, x): 76 | # x: [N, 3], in [-bound, bound] 77 | return self.encoder_x(x - self.coord_center, bound=self.bound) 78 | 79 | 80 | def forward(self, x, d): 81 | # x: [N, 3], in [-bound, bound] 82 | # d: [N, 3], nomalized in [-1, 1] 83 | enc_x = self.encode_x(x) 84 | sigma_result = self.density(x, enc_x) 85 | sigma = sigma_result['sigma'] 86 | color = self.color(sigma_result, d) 87 | return sigma, color 88 | 89 | 90 | def color(self, sigma_result, d): 91 | geo_feat = sigma_result['geo_feat'] 92 | enc_d = self.encoder_dir(d) 93 | h = torch.cat([enc_d, geo_feat], dim=-1) 94 | 95 | h_color = self.color_net(h) 96 | color = torch.sigmoid(h_color)*(1 + 2*0.001) - 0.001 97 | return color 98 | 99 | 100 | def density(self, x, enc_x=None): 101 | # x: [N, 3], in [-bound, bound] 102 | if self.keep_sigma and self.sigma_results_static is not None: 103 | return self.sigma_results_static 104 | 105 | if enc_x is None: 106 | enc_x = self.encode_x(x) 107 | 108 | h = self.sigma_net(enc_x) 109 | sigma = h[..., 0] 110 | # sigma = torch.exp(h[..., 0]) 111 | # sigma = torch.sigmoid(h[..., 0]) 112 | geo_feat = h[..., 1:] 113 | 114 | if self.keep_sigma: 115 | self.sigma_results_static = { 116 | 'sigma': sigma, 117 | 'geo_feat': geo_feat, 118 | } 119 | 120 | return { 121 | 'sigma': sigma, 122 | 'geo_feat': geo_feat, 123 | } 124 | 125 | 126 | # optimizer utils 127 | def get_params(self, lr, lr_net, wd=0): 128 | 129 | params = [ 130 | {'params': self.encoder_x.parameters(), 'name': 'neural_encoder', 'lr': lr}, 131 | {'params': self.sigma_net.parameters(), 'name': 'neural_sigma', 'lr': lr_net, 'weight_decay': wd}, 132 | {'params': self.color_net.parameters(), 'name': 'neural_color', 'lr': lr_net, 'weight_decay': wd}, 133 | ] 134 | 135 | return params -------------------------------------------------------------------------------- /scripts/copy_mask_dtu.sh: -------------------------------------------------------------------------------- 1 | base="output/dtu/" 2 | mask_path="data/dtu/submission_data/idrmasks" 3 | 4 | for scan_id in scan30 scan34 scan41 scan45 scan82 scan103 scan38 scan21 scan40 scan55 scan63 scan31 scan8 scan110 scan114 5 | do 6 | if [ -d $base/$scan_id ]; then 7 | # rm -r $base/$scan_id/mask 8 | mkdir $base/$scan_id/mask 9 | id=0 10 | if [ -d ${mask_path}/$scan_id/mask ]; then 11 | for file in ${mask_path}/scan8/* 12 | do 13 | # echo $file 14 | file_name=$(printf "%05d" $id).png; 15 | cp ${file//scan8/$scan_id'/mask'} $base/$scan_id/mask/$file_name 16 | ((id = id + 1)) 17 | done 18 | 19 | else 20 | 21 | for file in ${mask_path}/$scan_id/* 22 | do 23 | # echo $file 24 | file_name=$(printf "%05d" $id).png; 25 | cp $file $base/$scan_id/mask/$file_name 26 | ((id = id + 1)) 27 | done 28 | fi 29 | fi 30 | 31 | done -------------------------------------------------------------------------------- /scripts/organize_dtu_dataset.sh: 
-------------------------------------------------------------------------------- 1 | rectified_path=$1 # ../data/DTU/Rectified 2 | 3 | 4 | for scan_id in scan30 scan34 scan41 scan45 scan82 scan103 scan38 scan21 scan40 scan55 scan63 scan31 scan8 scan110 scan114 5 | do 6 | echo $scan_id 7 | mkdir -p ./data/dtu/$scan_id/input 8 | cp $rectified_path/$scan_id/*_3_r5000.png ./data/dtu/$scan_id/input/ 9 | done -------------------------------------------------------------------------------- /scripts/run_blender.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | workspace=$2 3 | export CUDA_VISIBLE_DEVICES=$3 4 | 5 | 6 | ## (sorry that this part is not so elegant) 7 | 8 | 9 | ## For the later scenes, we do not need to apply soft depth supervision. 10 | 11 | 12 | ## for materials, drums 13 | 14 | python train_blender.py -s $dataset --model_path $workspace -r 2 --eval --n_sparse 8 --rand_pcd --iterations 6000 --lambda_dssim 0.6 --white_background \ 15 | --densify_grad_threshold 0.001 --prune_threshold 0.01 --densify_until_iter 6000 --percent_dense 0.01 \ 16 | --densify_from_iter 500 \ 17 | --position_lr_init 0.00016 --position_lr_final 0.0000016 --position_lr_max_steps 1000 --position_lr_start 5000 \ 18 | --test_iterations 1000 2000 3000 4500 6000 --save_iterations 1000 2000 3000 6000 \ 19 | --hard_depth_start 0 --soft_depth_start 9999999 \ 20 | --split_opacity_thresh 0.1 --error_tolerance 0.001 \ 21 | --scaling_lr 0.005 \ 22 | --shape_pena 0.000 --opa_pena 0.000 --scale_pena 0.000 \ 23 | 24 | python render.py -s $dataset --model_path $workspace -r 2 25 | python metrics.py --model_path $workspace 26 | 27 | 28 | 29 | 30 | ## for ship, lego, ficus, hotdog, SH performs better 31 | 32 | # python train_blender.py -s $dataset --model_path $workspace -r 2 --eval --n_sparse 8 --rand_pcd --iterations 6000 --lambda_dssim 0.2 --white_background \ 33 | # --densify_grad_threshold 0.0002 --prune_threshold 0.005 --densify_until_iter 6000 --percent_dense 0.01 \ 34 | # --densify_from_iter 500 \ 35 | # --position_lr_init 0.00016 --position_lr_final 0.0000016 --position_lr_max_steps 1000 --position_lr_start 5000 \ 36 | # --test_iterations 1000 2000 3000 4500 6000 --save_iterations 1000 2000 3000 6000 \ 37 | # --hard_depth_start 0 \ 38 | # --error_tolerance 0.01 \ 39 | # --scaling_lr 0.005 \ 40 | # --shape_pena 0.000 --opa_pena 0.000 --scale_pena 0.000 \ 41 | # --use_SH 42 | 43 | # python render_sh.py -s $dataset --model_path $workspace -r 2 44 | # python metrics.py --model_path $workspace 45 | 46 | 47 | 48 | 49 | ## for chair, mic: the sampled views fully cover the scene, so the model no longer needs monocular depth
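## (depth supervision is disabled in the commented command below by setting --hard_depth_start 99999, so it never activates within the 30000 iterations)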
50 | 51 | # python train_blender.py -s $dataset --model_path $workspace -r 2 --eval --n_sparse 8 --rand_pcd --iterations 30000 --lambda_dssim 0.2 --white_background \ 52 | # --densify_grad_threshold 0.0002 --prune_threshold 0.005 --densify_until_iter 15000 --percent_dense 0.01 \ 53 | # --densify_from_iter 500 \ 54 | # --position_lr_init 0.00016 --position_lr_final 0.0000016 --position_lr_max_steps 30000 --position_lr_start 0 \ 55 | # --test_iterations 1000 2000 3000 4500 6000 --save_iterations 1000 2000 3000 6000 \ 56 | # --hard_depth_start 99999 \ 57 | # --error_tolerance 0.2 \ 58 | # --scaling_lr 0.005 \ 59 | # --shape_pena 0.000 --opa_pena 0.000 --scale_pena 0.000 \ 60 | # --use_SH 61 | 62 | # python render_sh.py -s $dataset --model_path $workspace -r 2 63 | # python metrics.py --model_path $workspace -------------------------------------------------------------------------------- /scripts/run_dtu.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | workspace=$2 3 | export CUDA_VISIBLE_DEVICES=$3 4 | 5 | 6 | python train_dtu.py --dataset DTU -s $dataset --model_path $workspace -r 4 --eval --n_sparse 3 --rand_pcd --iterations 6000 --lambda_dssim 0.6 \ 7 | --densify_grad_threshold 0.001 --prune_threshold 0.01 --densify_until_iter 6000 --percent_dense 0.1 \ 8 | --position_lr_init 0.0016 --position_lr_final 0.000016 --position_lr_max_steps 5500 --position_lr_start 500 \ 9 | --test_iterations 100 1000 2000 3000 4500 6000 --save_iterations 100 500 1000 3000 6000\ 10 | --error_tolerance 0.01 \ 11 | --opacity_lr 0.05 --scaling_lr 0.003 \ 12 | --shape_pena 0.005 --opa_pena 0.001 --scale_pena 0.005\ 13 | 14 | bash ./scripts/copy_mask_dtu.sh 15 | 16 | python render.py -s $dataset --model_path $workspace -r 4 17 | python spiral.py -s $dataset --model_path $workspace -r 4 18 | 19 | python metrics_dtu.py --model_path $workspace 20 | -------------------------------------------------------------------------------- /scripts/run_llff.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | workspace=$2 3 | export CUDA_VISIBLE_DEVICES=$3 4 | 5 | 6 | python train_llff.py -s $dataset --model_path $workspace -r 8 --eval --n_sparse 3 --rand_pcd --iterations 6000 --lambda_dssim 0.2 \ 7 | --densify_grad_threshold 0.0013 --prune_threshold 0.01 --densify_until_iter 6000 --percent_dense 0.01 \ 8 | --position_lr_init 0.016 --position_lr_final 0.00016 --position_lr_max_steps 5500 --position_lr_start 500 \ 9 | --split_opacity_thresh 0.1 --error_tolerance 0.00025 \ 10 | --scaling_lr 0.003 \ 11 | --shape_pena 0.002 --opa_pena 0.001 \ 12 | --near 10 13 | 14 | # set a larger "--error_tolerance" may get more smooth results in visualization 15 | 16 | 17 | python render.py -s $dataset --model_path $workspace -r 8 --near 10 18 | python spiral.py -s $dataset --model_path $workspace -r 8 --near 10 19 | 20 | 21 | python metrics.py --model_path $workspace 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /scripts/run_llff_mvs.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | workspace=$2 3 | export CUDA_VISIBLE_DEVICES=$3 4 | 5 | 6 | python train_llff.py -s $dataset --model_path $workspace -r 8 --eval --n_sparse 3 --mvs_pcd --iterations 6000 --lambda_dssim 0.25 \ 7 | --densify_grad_threshold 0.0008 --prune_threshold 0.01 --densify_until_iter 6000 --percent_dense 0.01 \ 8 | --position_lr_init 0.0005 --position_lr_final 0.000005 
--position_lr_max_steps 5500 --position_lr_start 500 \ 9 | --split_opacity_thresh 0.1 --error_tolerance 0.01 \ 10 | --scaling_lr 0.005 \ 11 | --shape_pena 0.002 --opa_pena 0.001 --scale_pena 0\ 12 | --test_iterations 1000 2000 3000 4500 6000 \ 13 | --near 10 14 | 15 | # set a larger "--error_tolerance" may get more smooth results in visualization 16 | 17 | 18 | python render.py -s $dataset --model_path $workspace -r 8 --near 10 19 | python spiral.py -s $dataset --model_path $workspace -r 8 --near 10 20 | 21 | 22 | python metrics.py --model_path $workspace 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_sh_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'shencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='shencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_shencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'shencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /shencoder/shencoder.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: shencoder 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /shencoder/shencoder.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /home/ub/Desktop/sparse3dgs/shencoder/src/bindings.cpp 3 | /home/ub/Desktop/sparse3dgs/shencoder/src/shencoder.cu 4 | shencoder.egg-info/PKG-INFO 5 | shencoder.egg-info/SOURCES.txt 6 | shencoder.egg-info/dependency_links.txt 7 | shencoder.egg-info/top_level.txt -------------------------------------------------------------------------------- /shencoder/shencoder.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /shencoder/shencoder.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | _shencoder 2 | -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = None 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | ctx.dims = [B, input_dim, degree] 36 | 37 | return outputs 38 | 39 | @staticmethod 40 | #@once_differentiable 41 | @custom_bwd 42 | def backward(ctx, grad): 43 | # grad: [B, C * C] 44 | 45 | inputs, dy_dx = ctx.saved_tensors 46 | 47 | if dy_dx is not None: 48 | grad = grad.contiguous() 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs 
= torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 0 ~ 4 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only support input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 | 87 | return outputs -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx); 10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /spiral.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import RenderScene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | 24 | 25 | 26 | import numpy as np 27 | import matplotlib.cm as cm 28 | 29 | 30 | def weighted_percentile(x, w, ps, assume_sorted=False): 31 | """Compute the weighted percentile(s) of a single vector.""" 32 | x = x.reshape([-1]) 33 | w = w.reshape([-1]) 34 | if not assume_sorted: 35 | sortidx = np.argsort(x) 36 | x, w = x[sortidx], w[sortidx] 37 | acc_w = np.cumsum(w) 38 | return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x) 39 | 40 | def visualize_cmap(value, 41 | weight, 42 | colormap, 43 | lo=None, 44 | hi=None, 45 | percentile=99., 46 | curve_fn=lambda x: x, 47 | modulus=None, 48 | matte_background=True): 49 | """Visualize a 1D image and a 1D weighting according to some colormap. 50 | 51 | Args: 52 | value: A 1D image. 53 | weight: A weight map, in [0, 1]. 54 | colormap: A colormap function. 55 | lo: The lower bound to use when rendering, if None then use a percentile. 56 | hi: The upper bound to use when rendering, if None then use a percentile. 57 | percentile: What percentile of the value map to crop to when automatically 58 | generating `lo` and `hi`. Depends on `weight` as well as `value'. 59 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 60 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 61 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If 62 | `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. 63 | matte_background: If True, matte the image over a checkerboard. 64 | 65 | Returns: 66 | A colormap rendering. 67 | """ 68 | # Identify the values that bound the middle of `value' according to `weight`. 69 | lo_auto, hi_auto = weighted_percentile( 70 | value, weight, [50 - percentile / 2, 50 + percentile / 2]) 71 | 72 | # If `lo` or `hi` are None, use the automatically-computed bounds above. 73 | eps = np.finfo(np.float32).eps 74 | lo = lo or (lo_auto - eps) 75 | hi = hi or (hi_auto + eps) 76 | 77 | # Curve all values. 78 | value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] 79 | 80 | # Wrap the values around if requested. 81 | if modulus: 82 | value = np.mod(value, modulus) / modulus 83 | else: 84 | # Otherwise, just scale to [0, 1]. 
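# (the next statement rescales value to [0, 1] via clip((value - min(lo, hi)) / |hi - lo|, 0, 1); nan_to_num clears any NaN/Inf left by the division)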
85 | value = np.nan_to_num( 86 | np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)) 87 | 88 | if colormap: 89 | colorized = colormap(value)[:, :, :3] 90 | else: 91 | assert len(value.shape) == 3 and value.shape[-1] == 3 92 | colorized = value 93 | 94 | return colorized 95 | 96 | depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) 97 | 98 | 99 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, near): 100 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration)) 101 | 102 | makedirs(render_path, exist_ok=True) 103 | 104 | mask_near = None 105 | for idx, view in enumerate(tqdm(views, desc="Rendering progress", ascii=True, dynamic_ncols=True)): 106 | mask_temp = (gaussians.get_xyz - view.camera_center.repeat(gaussians.get_xyz.shape[0], 1)).norm(dim=1, keepdim=True) < near 107 | mask_near = mask_near + mask_temp if mask_near is not None else mask_temp 108 | gaussians.prune_points_inference(mask_near) 109 | 110 | for idx, view in enumerate(tqdm(views, desc="Rendering progress", ascii=True, dynamic_ncols=True)): 111 | with torch.no_grad(): 112 | render_pkg = render(view, gaussians, pipeline, background, inference=True) 113 | rendering = render_pkg["render"] 114 | depth = (render_pkg['depth'] - render_pkg['depth'].min()) / (render_pkg['depth'].max() - render_pkg['depth'].min()) + 1 * (1 - render_pkg["alpha"]) 115 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 116 | torchvision.utils.save_image(1 - depth, os.path.join(render_path, 'depth_{0:05d}'.format(idx) + ".png")) 117 | 118 | depth_est = depth.squeeze().cpu().numpy() 119 | depth_est = visualize_cmap(depth_est, np.ones_like(depth_est), cm.get_cmap('turbo'), curve_fn=depth_curve_fn).copy() 120 | depth_est = torch.as_tensor(depth_est).permute(2,0,1) 121 | torchvision.utils.save_image(depth_est, os.path.join(render_path, 'cdepth_{0:05d}'.format(idx) + ".png")) 122 | 123 | os.system(f"ffmpeg -i " + render_path + f"/%5d.png -q 2 " + model_path + "/out_{}.mp4 -y".format(model_path.split('/')[-1])) 124 | os.system(f"ffmpeg -i " + render_path + f"/depth_%5d.png -q 2 " + model_path + "/out_depth_{}.mp4 -y".format(model_path.split('/')[-1])) 125 | os.system(f"ffmpeg -i " + render_path + f"/cdepth_%5d.png -q 2 " + model_path + "/out_cdepth_{}.mp4 -y".format(model_path.split('/')[-1])) 126 | 127 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, near): 128 | with torch.no_grad(): 129 | gaussians = GaussianModel(dataset.sh_degree) 130 | (model_params, _) = torch.load(os.path.join(dataset.model_path, "chkpnt_latest.pth")) 131 | gaussians.restore(model_params) 132 | gaussians.neural_renderer.keep_sigma=True 133 | 134 | scene = RenderScene(dataset, gaussians, load_iteration=iteration, spiral=True) 135 | 136 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 137 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 138 | 139 | render_set(dataset.model_path, "render", scene.loaded_iter, scene.getRenderCameras(), gaussians, pipeline, background, near) 140 | 141 | if __name__ == "__main__": 142 | # Set up command line argument parser 143 | parser = ArgumentParser(description="Testing script parameters") 144 | model = ModelParams(parser, sentinel=True) 145 | pipeline = PipelineParams(parser) 146 | parser.add_argument("--iteration", default=-1, type=int) 147 | parser.add_argument("--quiet", action="store_true") 148 | parser.add_argument("--near", default=0, type=float) 
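# Example invocation (paths are illustrative; scripts/run_llff.sh passes the same flags):
#   python spiral.py -s data/llff/fern --model_path output/llff/fern -r 8 --near 10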
149 | args = get_combined_args(parser) 150 | print("Rendering " + args.model_path) 151 | 152 | # Initialize system state (RNG) 153 | safe_state(args.quiet) 154 | 155 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.near) -------------------------------------------------------------------------------- /submodules/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/DNGaussian/24a0d7de512ea5c2caaf5f7380db357046e353c8/submodules/.gitkeep -------------------------------------------------------------------------------- /train_dtu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import torchvision 15 | from os import makedirs 16 | from random import randint 17 | from utils.graphics_utils import fov2focal 18 | from utils.loss_utils import l1_loss, patch_norm_mse_loss, patch_norm_mse_loss_global, ssim 19 | # from utils.loss_utils import mssim as ssim 20 | from gaussian_renderer import render, render_for_depth, render_for_opa # , network_gui 21 | import sys 22 | from scene import Scene, GaussianModel 23 | from utils.general_utils import safe_state 24 | import uuid 25 | from tqdm import tqdm 26 | from utils.image_utils import psnr 27 | from argparse import ArgumentParser, Namespace 28 | from arguments import ModelParams, PipelineParams, OptimizationParams 29 | try: 30 | from torch.utils.tensorboard import SummaryWriter 31 | print('Launch TensorBoard') 32 | TENSORBOARD_FOUND = True 33 | except ImportError: 34 | TENSORBOARD_FOUND = False 35 | 36 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 37 | 38 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 39 | first_iter = 0 40 | tb_writer = prepare_output_and_logger(dataset, opt) 41 | gaussians = GaussianModel(dataset.sh_degree) 42 | scene = Scene(dataset, gaussians) 43 | gaussians.training_setup(opt) 44 | if checkpoint: 45 | (model_params, _) = torch.load(checkpoint) 46 | gaussians.load_shape(model_params, opt) 47 | 48 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 49 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 50 | 51 | iter_start = torch.cuda.Event(enable_timing = True) 52 | iter_end = torch.cuda.Event(enable_timing = True) 53 | 54 | viewpoint_stack = None 55 | ema_loss_for_log = 0.0 56 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress", ascii=True, dynamic_ncols=True) 57 | first_iter += 1 58 | 59 | ema_loss_hard = 0.0 60 | 61 | if args.dataset == 'DTU': 62 | patch_range = (17, 53) 63 | 64 | for iteration in range(first_iter, opt.iterations + 1): 65 | 66 | iter_start.record() 67 | 68 | gaussians.update_learning_rate(max(iteration - opt.position_lr_start, 0)) 69 | 70 | # Every 1000 its we increase the levels of SH up to a maximum degree 71 | # if iteration % 1000 == 0: 72 | # gaussians.oneupSHdegree() 73 | 74 | # Pick a random Camera 75 | if not viewpoint_stack: 76 | viewpoint_stack = scene.getTrainCameras().copy() 77 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 78 | 
gt_image = viewpoint_cam.original_image.cuda() 79 | 80 | # Render 81 | if (iteration - 1) == debug_from: 82 | pipe.debug = True 83 | 84 | bg_mask = None 85 | if args.dataset == 'DTU': 86 | if 'scan110' not in scene.source_path: 87 | bg_mask = (gt_image.max(0, keepdim=True).values < 30/255) 88 | else: 89 | bg_mask = (gt_image.max(0, keepdim=True).values < 15/255) 90 | bg_mask_clone = bg_mask.clone() 91 | for i in range(1, 50): 92 | bg_mask[:, i:] *= bg_mask_clone[:, :-i] 93 | gt_image[bg_mask.repeat(3,1,1)] = 0. 94 | 95 | # -------------------------------------------------- DEPTH -------------------------------------------- 96 | if iteration > opt.hard_depth_start: 97 | render_pkg = render_for_depth(viewpoint_cam, gaussians, pipe, background) 98 | depth = render_pkg["depth"] 99 | 100 | # Depth loss 101 | loss_hard = 0 102 | depth_mono = 255.0 - viewpoint_cam.depth_mono 103 | if args.dataset == 'DTU': 104 | depth_mono[bg_mask] = depth_mono[~bg_mask].mean() 105 | depth[bg_mask] = depth[~bg_mask].mean().detach() 106 | 107 | 108 | loss_l2_dpt = patch_norm_mse_loss(depth[None,...], depth_mono[None,...], randint(patch_range[0], patch_range[1]), opt.error_tolerance) 109 | loss_hard += 0.1 * loss_l2_dpt 110 | 111 | loss_global = patch_norm_mse_loss_global(depth[None,...], depth_mono[None,...], randint(patch_range[0], patch_range[1]), opt.error_tolerance) 112 | loss_hard += 1 * loss_global 113 | 114 | loss_hard.backward() 115 | 116 | # Optimizer step 117 | if iteration < opt.iterations: 118 | gaussians.optimizer.step() 119 | gaussians.optimizer.zero_grad(set_to_none = True) 120 | 121 | 122 | # if iteration > opt.densify_from_iter: 123 | # gaussians.prune(opt.prune_threshold) 124 | 125 | 126 | # -------------------------------------------------- soft -------------------------------------------- 127 | ema_loss_hard = 0.1 * (loss_hard.item()) + 0.9 * ema_loss_hard 128 | if iteration > opt.soft_depth_start and ema_loss_hard < 0.1: 129 | render_pkg = render_for_opa(viewpoint_cam, gaussians, pipe, background) 130 | viewspace_point_tensor, visibility_filter = render_pkg["viewspace_points"], render_pkg["visibility_filter"] 131 | depth, alpha = render_pkg["depth"], render_pkg["alpha"] 132 | 133 | # Depth loss 134 | loss_soft = 0 135 | depth_mono = 255.0 - viewpoint_cam.depth_mono 136 | if args.dataset == 'DTU': 137 | depth_mono[bg_mask] = depth_mono[~bg_mask].mean() 138 | depth[bg_mask] = depth[~bg_mask].mean().detach() 139 | 140 | loss_l2_dpt = patch_norm_mse_loss(depth[None,...], depth_mono[None,...], randint(patch_range[0], patch_range[1]), opt.error_tolerance) 141 | loss_soft += 0.1 * loss_l2_dpt 142 | 143 | loss_global = patch_norm_mse_loss_global(depth[None,...], depth_mono[None,...], randint(patch_range[0], patch_range[1]), opt.error_tolerance) 144 | loss_soft += 1 * loss_global 145 | 146 | loss_soft.backward() 147 | 148 | # Optimizer step 149 | if iteration < opt.iterations: 150 | gaussians.optimizer.step() 151 | gaussians.optimizer.zero_grad(set_to_none = True) 152 | 153 | 154 | 155 | if args.dataset == 'DTU': 156 | render_pkg = render_for_opa(viewpoint_cam, gaussians, pipe, background) 157 | (render_pkg["alpha"][bg_mask]**2).mean().backward() 158 | gaussians.optimizer.step() 159 | gaussians.optimizer.zero_grad(set_to_none = True) 160 | 161 | 162 | 163 | # ---------------------------------------------- Photometric -------------------------------------------- 164 | 165 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 166 | image, viewspace_point_tensor, visibility_filter, radii = 
render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 167 | # depth 168 | depth, opacity, alpha = render_pkg["depth"], render_pkg["opacity"], render_pkg['alpha'] # [visibility_filter] 169 | 170 | # Loss 171 | Ll1 = l1_loss(image, gt_image) 172 | loss = Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 173 | 174 | # Reg 175 | loss_reg = torch.tensor(0., device=loss.device) 176 | shape_pena = (gaussians.get_scaling.max(dim=1).values / gaussians.get_scaling.min(dim=1).values).mean() 177 | # scale_pena = (gaussians.get_scaling.max(dim=1).values).std() 178 | scale_pena = ((gaussians.get_scaling.max(dim=1, keepdim=True).values)**2).mean() 179 | opa_pena = 1 - (opacity[opacity > 0.2]**2).mean() + ((1 - opacity[opacity < 0.2])**2).mean() 180 | 181 | # loss_reg += 0.01*shape_pena + 0.001*scale_pena + 0.01*opa_pena 182 | loss_reg += opt.shape_pena*shape_pena + opt.scale_pena*scale_pena + opt.opa_pena*opa_pena 183 | loss += loss_reg 184 | 185 | loss.backward() 186 | 187 | # ================================================================================ 188 | 189 | iter_end.record() 190 | 191 | with torch.no_grad(): 192 | # Progress bar 193 | if not loss.isnan(): 194 | ema_loss_for_log = 0.4 * (loss.item()) + 0.6 * ema_loss_for_log 195 | if iteration % 10 == 0: 196 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 197 | progress_bar.update(10) 198 | if iteration == opt.iterations: 199 | progress_bar.close() 200 | 201 | # Log and save 202 | clean_iterations = testing_iterations + [first_iter] 203 | clean_views(iteration, clean_iterations, scene, gaussians, pipe, background) 204 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 205 | if (iteration in saving_iterations): 206 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 207 | scene.save(iteration, render(viewpoint_cam, gaussians, pipe, background)["color"]) 208 | 209 | # Densification 210 | if iteration < opt.densify_until_iter and iteration not in clean_iterations: 211 | # Keep track of max radii in image-space for pruning 212 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 213 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 214 | 215 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 216 | size_threshold = max_dist = None 217 | 218 | if args.dataset == "DTU": 219 | if 'scan110' not in scene.source_path: 220 | color = render(viewpoint_cam, gaussians, pipe, background)["color"] 221 | black_mask = color.max(-1, keepdim=True).values < 20/255 222 | gaussians.xyz_gradient_accum[black_mask] /= 10 223 | gaussians._opacity[black_mask] = gaussians.inverse_opacity_activation(torch.ones_like(gaussians._opacity[black_mask]) * 0.1) 224 | 225 | if 'scan114' not in scene.source_path and 'scan21' not in scene.source_path: 226 | white_mask = color.min(-1, keepdim=True).values > 240/255 227 | gaussians.xyz_gradient_accum[white_mask] /= 2 228 | 229 | if iteration % 2001 == 0: 230 | gaussians._opacity[white_mask] = gaussians.inverse_opacity_activation(torch.ones_like(gaussians._opacity[white_mask]) * 0.1) 231 | 232 | gaussians.densify_and_prune(opt.densify_grad_threshold, opt.prune_threshold, scene.cameras_extent, size_threshold, opt.split_opacity_thresh, max_dist) 233 | 234 | # if iteration % opt.opacity_reset_interval == 0 or 
(dataset.white_background and iteration == opt.densify_from_iter): 235 | # gaussians.reset_opacity() 236 | 237 | # Optimizer step 238 | if iteration < opt.iterations: 239 | gaussians.optimizer.step() 240 | gaussians.optimizer.zero_grad(set_to_none = True) 241 | 242 | if (iteration in checkpoint_iterations): 243 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 244 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 245 | if iteration == opt.iterations: 246 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 247 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt_latest.pth") 248 | 249 | 250 | def prepare_output_and_logger(args, opt): 251 | if not args.model_path: 252 | if os.getenv('OAR_JOB_ID'): 253 | unique_str=os.getenv('OAR_JOB_ID') 254 | else: 255 | unique_str = str(uuid.uuid4()) 256 | args.model_path = os.path.join("./output/", unique_str[0:10]) 257 | 258 | # Set up output folder 259 | print("Output folder: {}".format(args.model_path)) 260 | os.makedirs(args.model_path, exist_ok = True) 261 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 262 | cfg_log_f.write(str(Namespace(**vars(args)))) 263 | with open(os.path.join(args.model_path, "opt_args"), 'w') as opt_log_f: 264 | opt_log_f.write(str(Namespace(**vars(opt)))) 265 | 266 | # Create Tensorboard writer 267 | tb_writer = None 268 | if TENSORBOARD_FOUND: 269 | tb_writer = SummaryWriter(args.model_path) 270 | else: 271 | print("Tensorboard not available: not logging progress") 272 | return tb_writer 273 | 274 | 275 | @torch.no_grad() 276 | def clean_views(iteration, test_iterations, scene, gaussians, pipe, background): 277 | if iteration in test_iterations: 278 | visible_pnts = None 279 | for viewpoint_cam in scene.getTrainCameras().copy(): 280 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 281 | visibility_filter = render_pkg["visibility_filter"] 282 | if visible_pnts is None: 283 | visible_pnts = visibility_filter 284 | visible_pnts += visibility_filter 285 | unvisible_pnts = ~visible_pnts 286 | gaussians.prune_points(unvisible_pnts) 287 | 288 | 289 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, depth_loss=torch.tensor(0), reg_loss=torch.tensor(0)): 290 | if tb_writer: 291 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 292 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 293 | tb_writer.add_scalar('iter_time', elapsed, iteration) 294 | tb_writer.add_scalar('train_loss_patches/depth_kl_loss', depth_loss.item(), iteration) 295 | tb_writer.add_scalar('train_loss_patches/reg_loss', reg_loss.item(), iteration) 296 | 297 | # Report test and samples of training set 298 | if iteration in testing_iterations: 299 | torch.cuda.empty_cache() 300 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 301 | {'name': 'eval', 'cameras' : scene.getEvalCameras()}, 302 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 303 | 304 | for config in validation_configs: 305 | if config['cameras'] and len(config['cameras']) > 0: 306 | l1_test = 0.0 307 | psnr_test = 0.0 308 | for idx, viewpoint in enumerate(config['cameras']): 309 | render_results = renderFunc(viewpoint, scene.gaussians, *renderArgs) 310 | image = torch.clamp(render_results["render"], 0.0, 1.0) 311 | depth = 
render_results["depth"] 312 | depth = 1 - (depth - depth.min()) / (depth.max() - depth.min()) 313 | alpha = render_results["alpha"] 314 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 315 | bg_mask = (gt_image.max(0, keepdim=True).values < 30/255) 316 | bg_mask_clone = bg_mask.clone() 317 | for i in range(1, 50): 318 | bg_mask[:, i:] *= bg_mask_clone[:, :-i] 319 | white_mask = (gt_image.min(0, keepdim=True).values > 240/255) 320 | if tb_writer and (idx < 5): 321 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 322 | tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth[None], global_step=iteration) 323 | tb_writer.add_images(config['name'] + "_view_{}_alpha/alpha".format(viewpoint.image_name), alpha[None], global_step=iteration) 324 | tb_writer.add_images(config['name'] + "_view_{}_alpha/mask".format(viewpoint.image_name), bg_mask[None], global_step=iteration) 325 | tb_writer.add_images(config['name'] + "_view_{}_alpha/white_mask".format(viewpoint.image_name), white_mask[None], global_step=iteration) 326 | 327 | if iteration == testing_iterations[0]: 328 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 329 | l1_test += l1_loss(image, gt_image).mean().double() 330 | psnr_test += psnr(image, gt_image).mean().double() 331 | psnr_test /= len(config['cameras']) 332 | l1_test /= len(config['cameras']) 333 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 334 | if tb_writer: 335 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 336 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 337 | 338 | if tb_writer: 339 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 340 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 341 | torch.cuda.empty_cache() 342 | 343 | if __name__ == "__main__": 344 | # Set up command line argument parser 345 | parser = ArgumentParser(description="Training script parameters") 346 | lp = ModelParams(parser) 347 | op = OptimizationParams(parser) 348 | pp = PipelineParams(parser) 349 | # parser.add_argument('--ip', type=str, default="127.0.0.1") 350 | # parser.add_argument('--port', type=int, default=6009) 351 | parser.add_argument('--debug_from', type=int, default=-1) 352 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 353 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1000, 2000, 3000, 6000]) 354 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1000, 2000, 3000, 6000]) 355 | parser.add_argument("--quiet", action="store_true") 356 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 357 | parser.add_argument("--start_checkpoint", type=str, default = None) 358 | args = parser.parse_args(sys.argv[1:]) 359 | args.save_iterations.append(args.iterations) 360 | # args.checkpoint_iterations.append(args.iterations) 361 | 362 | print("Optimizing " + args.model_path) 363 | 364 | # Initialize system state (RNG) 365 | safe_state(args.quiet) 366 | 367 | # Start GUI server, configure and run training 368 | # network_gui.init(args.ip, args.port) 369 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 370 | training(lp.extract(args), op.extract(args), 
pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 371 | 372 | # All done 373 | print("\nTraining complete.") 374 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | try: 13 | from scene.cameras import Camera 14 | except: 15 | from scene_sh.cameras import Camera 16 | import numpy as np 17 | from utils.general_utils import NPtoTorch, PILtoTorch 18 | from utils.graphics_utils import fov2focal 19 | 20 | WARNED = False 21 | 22 | def loadCam(args, id, cam_info, resolution_scale): 23 | orig_w, orig_h = cam_info.image.size 24 | 25 | if args.resolution in [1, 2, 4, 8]: 26 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 6400: 30 | global WARNED 31 | if not WARNED: 32 | print("[ INFO ] Encountered quite large input images (>6.4K pixels width), rescaling to 6.4K.\n " 33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 34 | WARNED = True 35 | global_down = orig_w / 6400 36 | else: 37 | global_down = 1 38 | else: 39 | global_down = orig_w / args.resolution 40 | 41 | scale = float(global_down) * float(resolution_scale) 42 | resolution = (int(orig_w / scale), int(orig_h / scale)) 43 | 44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 45 | resized_depth_mono = PILtoTorch(cam_info.depth_mono, resolution) 46 | 47 | gt_image = resized_image_rgb[:3, ...] 48 | loaded_mask = None 49 | 50 | if resized_image_rgb.shape[1] == 4: 51 | loaded_mask = resized_image_rgb[3:4, ...]
52 | 53 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 54 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 55 | image=gt_image, gt_alpha_mask=loaded_mask, depth_mono=resized_depth_mono, 56 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 57 | 58 | 59 | def loadRenderCam(args, id, cam_info, resolution_scale): 60 | orig_w, orig_h = cam_info.width, cam_info.height 61 | 62 | if args.resolution in [1, 2, 4, 8]: 63 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 64 | else: # should be a type that converts to float 65 | if args.resolution == -1: 66 | if orig_w > 6400: 67 | global WARNED 68 | if not WARNED: 69 | print("[ INFO ] Encountered quite large input images (>6.4K pixels width), rescaling to 6.4K.\n " 70 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 71 | WARNED = True 72 | global_down = orig_w / 6400 73 | else: 74 | global_down = 1 75 | else: 76 | global_down = orig_w / args.resolution 77 | 78 | scale = float(global_down) * float(resolution_scale) 79 | resolution = (int(orig_w / scale), int(orig_h / scale)) 80 | 81 | cam = Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 82 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 83 | image=None, gt_alpha_mask=None, depth_mono=None, 84 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 85 | cam.image_width, cam.image_height = resolution 86 | return cam 87 | 88 | 89 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 90 | camera_list = [] 91 | 92 | for id, c in enumerate(cam_infos): 93 | camera_list.append(loadCam(args, id, c, resolution_scale)) 94 | 95 | return camera_list 96 | 97 | def renderCameraList_from_camInfos(cam_infos, resolution_scale, args): 98 | camera_list = [] 99 | 100 | for id, c in enumerate(cam_infos): 101 | camera_list.append(loadRenderCam(args, id, c, resolution_scale)) 102 | 103 | return camera_list 104 | 105 | 106 | def camera_to_JSON(id, camera : Camera): 107 | Rt = np.zeros((4, 4)) 108 | Rt[:3, :3] = camera.R.transpose() 109 | Rt[:3, 3] = camera.T 110 | Rt[3, 3] = 1.0 111 | 112 | W2C = np.linalg.inv(Rt) 113 | pos = W2C[:3, 3] 114 | rot = W2C[:3, :3] 115 | serializable_array_2d = [x.tolist() for x in rot] 116 | camera_entry = { 117 | 'id' : id, 118 | 'img_name' : camera.image_name, 119 | 'width' : camera.width, 120 | 'height' : camera.height, 121 | 'position': pos.tolist(), 122 | 'rotation': serializable_array_2d, 123 | 'fy' : fov2focal(camera.FovY, camera.height), 124 | 'fx' : fov2focal(camera.FovX, camera.width) 125 | } 126 | return camera_entry 127 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def NPtoTorch(arr, resolution): 30 | resized_image = torch.from_numpy(arr)[None, ...] 31 | resized_image = torch.nn.functional.interpolate(resized_image, size=resolution, mode='bilinear', align_corners=None) 32 | return resized_image 33 | 34 | def get_expon_lr_func( 35 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 36 | ): 37 | """ 38 | Copied from Plenoxels 39 | 40 | Continuous learning rate decay function. Adapted from JaxNeRF 41 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 42 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 43 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 44 | function of lr_delay_mult, such that the initial learning rate is 45 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 46 | to the normal learning rate when steps>lr_delay_steps. 47 | :param lr_init, lr_final, lr_delay_steps, lr_delay_mult: schedule parameters described above 48 | :param max_steps: int, the number of steps during optimization. 49 | :return HoF which takes step as input 50 | """ 51 | 52 | def helper(step): 53 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 54 | # Disable this parameter 55 | return 0.0 56 | if lr_delay_steps > 0: 57 | # A kind of reverse cosine decay.
58 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 59 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 60 | ) 61 | else: 62 | delay_rate = 1.0 63 | t = np.clip(step / max_steps, 0, 1) 64 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 65 | return delay_rate * log_lerp 66 | 67 | return helper 68 | 69 | def strip_lowerdiag(L): 70 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 71 | 72 | uncertainty[:, 0] = L[:, 0, 0] 73 | uncertainty[:, 1] = L[:, 0, 1] 74 | uncertainty[:, 2] = L[:, 0, 2] 75 | uncertainty[:, 3] = L[:, 1, 1] 76 | uncertainty[:, 4] = L[:, 1, 2] 77 | uncertainty[:, 5] = L[:, 2, 2] 78 | return uncertainty 79 | 80 | def strip_symmetric(sym): 81 | return strip_lowerdiag(sym) 82 | 83 | def build_rotation(r): 84 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 85 | 86 | q = r / norm[:, None] 87 | 88 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 89 | 90 | r = q[:, 0] 91 | x = q[:, 1] 92 | y = q[:, 2] 93 | z = q[:, 3] 94 | 95 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 96 | R[:, 0, 1] = 2 * (x*y - r*z) 97 | R[:, 0, 2] = 2 * (x*z + r*y) 98 | R[:, 1, 0] = 2 * (x*y + r*z) 99 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 100 | R[:, 1, 2] = 2 * (y*z - r*x) 101 | R[:, 2, 0] = 2 * (x*z - r*y) 102 | R[:, 2, 1] = 2 * (y*z + r*x) 103 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 104 | return R 105 | 106 | def build_scaling_rotation(s, r): 107 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 108 | R = build_rotation(r) 109 | 110 | L[:,0,0] = s[:,0] 111 | L[:,1,1] = s[:,1] 112 | L[:,2,2] = s[:,2] 113 | 114 | L = R @ L 115 | return L 116 | 117 | def safe_state(silent): 118 | old_f = sys.stdout 119 | class F: 120 | def __init__(self, silent): 121 | self.silent = silent 122 | 123 | def write(self, x): 124 | if not self.silent: 125 | if x.endswith("\n"): 126 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 127 | else: 128 | old_f.write(x) 129 | 130 | def flush(self): 131 | old_f.flush() 132 | 133 | sys.stdout = F(silent) 134 | 135 | random.seed(0) 136 | np.random.seed(0) 137 | torch.manual_seed(0) 138 | torch.cuda.set_device(torch.device("cuda:0")) 139 | 140 | 141 | def quaternion_to_matrix(quaternions): 142 | """ 143 | Convert rotations given as quaternions to rotation matrices. 144 | 145 | Args: 146 | quaternions: quaternions with real part first, 147 | as tensor of shape (..., 4). 148 | 149 | Returns: 150 | Rotation matrices as tensor of shape (..., 3, 3). 151 | """ 152 | r, i, j, k = torch.unbind(quaternions, -1) 153 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 154 | 155 | o = torch.stack( 156 | ( 157 | 1 - two_s * (j * j + k * k), 158 | two_s * (i * j - k * r), 159 | two_s * (i * k + j * r), 160 | two_s * (i * j + k * r), 161 | 1 - two_s * (i * i + k * k), 162 | two_s * (j * k - i * r), 163 | two_s * (i * k - j * r), 164 | two_s * (j * k + i * r), 165 | 1 - two_s * (i * i + j * j), 166 | ), 167 | -1, 168 | ) 169 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 170 | 171 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def normalize(input, mean=None, std=None): 19 | input_mean = torch.mean(input, dim=1, keepdim=True) if mean is None else mean 20 | input_std = torch.std(input, dim=1, keepdim=True) if std is None else std 21 | return (input - input_mean) / (input_std + 1e-2*torch.std(input.reshape(-1))) 22 | 23 | def shuffle(input): 24 | # shuffle dim=1 25 | idx = torch.randperm(input[0].shape[1]) 26 | for i in range(input.shape[0]): 27 | input[i] = input[i][:, idx].view(input[i].shape) 28 | 29 | def loss_depth_smoothness(depth, img): 30 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:] 31 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :] 32 | weight_x = torch.exp(-torch.abs(img_grad_x).mean(1).unsqueeze(1)) 33 | weight_y = torch.exp(-torch.abs(img_grad_y).mean(1).unsqueeze(1)) 34 | 35 | loss = (((depth[:, :, :, :-1] - depth[:, :, :, 1:]).abs() * weight_x).sum() + 36 | ((depth[:, :, :-1, :] - depth[:, :, 1:, :]).abs() * weight_y).sum()) / \ 37 | (weight_x.sum() + weight_y.sum()) 38 | return loss 39 | 40 | def loss_depth_grad(depth, img): 41 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:] 42 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :] 43 | weight_x = img_grad_x / (torch.abs(img_grad_x) + 1e-6) 44 | weight_y = img_grad_y / (torch.abs(img_grad_y) + 1e-6) 45 | 46 | depth_grad_x = depth[:, :, :, :-1] - depth[:, :, :, 1:] 47 | depth_grad_y = depth[:, :, :-1, :] - depth[:, :, 1:, :] 48 | grad_x = depth_grad_x / (torch.abs(depth_grad_x) + 1e-6) 49 | grad_y = depth_grad_y / (torch.abs(depth_grad_y) + 1e-6) 50 | 51 | loss = l1_loss(grad_x, weight_x) + l1_loss(grad_y, weight_y) 52 | return loss 53 | 54 | 55 | def l1_loss(network_output, gt): 56 | return torch.abs((network_output - gt)).mean() 57 | 58 | def l2_loss(network_output, gt): 59 | return ((network_output - gt) ** 2).mean() 60 | 61 | def margin_l2_loss(network_output, gt, margin, return_mask=False): 62 | mask = (network_output - gt).abs() > margin 63 | if not return_mask: 64 | return ((network_output - gt)[mask] ** 2).mean() 65 | else: 66 | return ((network_output - gt)[mask] ** 2).mean(), mask 67 | 68 | def margin_l1_loss(network_output, gt, margin, return_mask=False): 69 | mask = (network_output - gt).abs() > margin 70 | if not return_mask: 71 | return ((network_output - gt)[mask].abs()).mean() 72 | else: 73 | return ((network_output - gt)[mask].abs()).mean(), mask 74 | 75 | 76 | def kl_loss(input, target): 77 | input = F.log_softmax(input, dim=-1) 78 | target = F.softmax(target, dim=-1) 79 | return F.kl_div(input, target, reduction="batchmean") 80 | 81 | def patchify(input, patch_size): 82 | patches = F.unfold(input, kernel_size=patch_size, stride=patch_size).permute(0,2,1).view(-1, 1*patch_size*patch_size) 83 | return patches 84 | 85 | def patch_norm_mse_loss(input, target, patch_size, margin, return_mask=False): 86 | input_patches = normalize(patchify(input, patch_size)) 87 | target_patches = normalize(patchify(target, patch_size)) 88 | return margin_l2_loss(input_patches, target_patches, margin, return_mask) 89 | 90 | def patch_norm_mse_loss_global(input, target, patch_size, margin, return_mask=False): 91 | input_patches = normalize(patchify(input, patch_size), std = input.std().detach()) 92 | target_patches = normalize(patchify(target, patch_size), std = target.std().detach()) 93 | return margin_l2_loss(input_patches, 
target_patches, margin, return_mask) 94 | 95 | def patch_norm_l1_loss_global(input, target, patch_size, margin, return_mask=False): 96 | input_patches = normalize(patchify(input, patch_size), std = input.std().detach()) 97 | target_patches = normalize(patchify(target, patch_size), std = target.std().detach()) 98 | return margin_l1_loss(input_patches, target_patches, margin, return_mask) 99 | 100 | def patch_norm_l1_loss(input, target, patch_size, margin, return_mask=False): 101 | input_patches = normalize(patchify(input, patch_size)) 102 | target_patches = normalize(patchify(target, patch_size)) 103 | return margin_l1_loss(input_patches, target_patches, margin, return_mask) 104 | 105 | 106 | def gaussian(window_size, sigma): 107 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 108 | return gauss / gauss.sum() 109 | 110 | def create_window(window_size, channel): 111 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 112 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 113 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 114 | return window 115 | 116 | 117 | def margin_ssim(img1, img2, window_size=11, size_average=True): 118 | result = ssim(img1, img2, window_size, False) 119 | print(result.shape) 120 | 121 | 122 | def ssim(img1, img2, window_size=11, size_average=True): 123 | channel = img1.size(-3) 124 | window = create_window(window_size, channel) 125 | 126 | if img1.is_cuda: 127 | window = window.cuda(img1.get_device()) 128 | window = window.type_as(img1) 129 | 130 | return _ssim(img1, img2, window, window_size, channel, size_average) 131 | 132 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 133 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 134 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 135 | 136 | mu1_sq = mu1.pow(2) 137 | mu2_sq = mu2.pow(2) 138 | mu1_mu2 = mu1 * mu2 139 | 140 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 141 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 142 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 143 | 144 | C1 = 0.01 ** 2 145 | C2 = 0.03 ** 2 146 | 147 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 148 | 149 | if size_average: 150 | return ssim_map.mean() 151 | else: 152 | return ssim_map.mean(1).mean(1).mean(1) 153 | 154 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | 
return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
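For readers piecing the utilities above together, the sketch below illustrates one way the patch-normalized depth losses from utils/loss_utils.py could be applied between a rendered depth map and a monocular depth prior. It is a minimal, self-contained example run from the repository root; the tensor shapes, patch size, margin and weighting are illustrative assumptions and are not taken from the training scripts listed above.

import torch
from utils.loss_utils import patch_norm_mse_loss, patch_norm_mse_loss_global

# Illustrative sketch only: the shapes, patch_size, margin and the 0.1 weighting
# are assumed values, not the settings used by the training scripts in this repo.
rendered_depth = torch.rand(1, 1, 400, 400)  # depth from the rasterizer, shape (N, 1, H, W)
mono_depth = torch.rand(1, 1, 400, 400)      # monocular depth prior, same shape

# Local variant: each patch is normalized by its own mean/std, so only the
# relative depth structure within a patch is penalized; the margin ignores
# discrepancies smaller than the threshold.
local_loss = patch_norm_mse_loss(rendered_depth, mono_depth, patch_size=8, margin=0.1)

# Global variant: patches are normalized with the image-wide std instead, which
# also constrains the larger-scale depth ordering across patches.
global_loss = patch_norm_mse_loss_global(rendered_depth, mono_depth, patch_size=8, margin=0.1)

depth_loss = local_loss + 0.1 * global_loss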