├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── assets ├── ablation-asg.png ├── ablation-c2f.png ├── pipeline-anchor.png ├── pipeline.png ├── real.png ├── synthetic.png └── teaser.png ├── convert.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── metrics.py ├── render.py ├── render_anchor.py ├── requirements.txt ├── run_anchor.sh ├── run_wo_anchor.sh ├── scene ├── __init__.py ├── anchor_gaussian_model.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py └── specular_model.py ├── train.py ├── train_anchor.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── pose_utils.py ├── quaternion_utils.py ├── sh_utils.py ├── spec_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/depth-diff-gaussian-rasterization"] 5 | path = submodules/depth-diff-gaussian-rasterization 6 | url = https://github.com/ingra14m/diff-gaussian-rasterization-extentions 7 | branch = filter-norm 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziyi Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Spec-Gaussian: Anisotropic View-Dependent Appearance for 3D Gaussian Splatting
2 |
3 | ## [Project Page](https://ingra14m.github.io/Spec-Gaussian-website/) | [Paper](https://arxiv.org/abs/2402.15870) | [Anisotropic Dataset](https://drive.google.com/drive/folders/1hH7qMSbTyR392PYgsqeMhAnaAxwxzemc?usp=drive_link)
4 |
5 | ![teaser](assets/teaser.png)
6 |
7 | This project builds on my previously released [My-exp-Gaussian](https://github.com/ingra14m/My-exp-Gaussian) and aims to enhance 3D Gaussian Splatting for modeling scenes with specular highlights. I hope this work can assist researchers who need to model specular highlights with splatting.
8 |
9 |
10 |
11 | **Note** that the current Spec-Gaussian is significantly improved in quality over the first arXiv version (2024.02). Please refer to the latest version on arXiv.
12 |
13 | ## News
14 |
15 | - **[11/15/2024]** Updated the training scripts for the current version.
16 | - **[9/26/2024]** Spec-Gaussian has been accepted to NeurIPS 2024. We also release our anisotropic dataset [here](https://drive.google.com/drive/folders/1hH7qMSbTyR392PYgsqeMhAnaAxwxzemc?usp=drive_link).
17 |
18 |
19 |
20 | ## Dataset
21 |
22 | In our paper, we use:
23 |
24 | - synthetic datasets from [NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1), [NSVF](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip), and our [Anisotropic Synthetic Dataset](https://drive.google.com/drive/folders/1hH7qMSbTyR392PYgsqeMhAnaAxwxzemc?usp=drive_link)
25 | - real-world scenes from [Mip-NeRF 360](https://jonbarron.info/mipnerf360/).
26 |
27 | The data should be organized as follows:
28 |
29 | ```shell
30 | data/
31 | ├── NeRF
32 | │   ├── Chair/
33 | │   ├── Drums/
34 | │   ├── ...
35 | ├── NSVF
36 | │   ├── Bike/
37 | │   ├── Lifestyle/
38 | │   ├── ...
39 | ├── Spec-GS
40 | │   ├── ashtray/
41 | │   ├── dishes/
42 | │   ├── ...
43 | ├── Mip-360
44 | │   ├── bicycle/
45 | │   ├── bonsai/
46 | │   ├── ...
47 | ├── tandt_db
48 | │   ├── db/
49 | │   │   ├── drjohnson/
50 | │   │   ├── playroom/
51 | │   ├── tandt/
52 | │   │   ├── train/
53 | │   │   ├── truck/
54 | ```
55 |
56 |
57 |
58 | ## Pipeline
59 |
60 | ![pipeline](assets/pipeline.png)
61 |
62 |
63 |
64 | ## Run
65 |
66 | ### Environment
67 |
68 | ```shell
69 | git clone https://github.com/ingra14m/Spec-Gaussian --recursive
70 | cd Spec-Gaussian
71 |
72 | conda create -n spec-gaussian-env python=3.7
73 | conda activate spec-gaussian-env
74 |
75 | # install pytorch
76 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
77 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
78 |
79 | # install dependencies
80 | pip install -r requirements.txt
81 | ```
82 |
83 |
84 |
85 | ### Train
86 |
87 | We provide the scripts [`run_wo_anchor.sh`](https://github.com/ingra14m/Spec-Gaussian/blob/main/run_wo_anchor.sh) and [`run_anchor.sh`](https://github.com/ingra14m/Spec-Gaussian/blob/main/run_anchor.sh) that were used to generate the results reported in the paper.
88 |
89 | In general, the version without anchor Gaussians achieves better rendering quality, while the version with anchor Gaussians trains and renders faster.
For researchers who want to explore Spec-Gaussian, we provide the following general training commands.
90 |
91 | **Train without anchor**
92 |
93 | ```shell
94 | python train.py -s your/path/to/the/dataset -m your/path/to/save --eval
95 |
96 | ## For synthetic bounded scenes
97 | python train.py -s data/nerf_synthetic/drums -m outputs/nerf/drums --eval
98 |
99 | ## For real-world unbounded indoor scenes
100 | python train.py -s data/mipnerf-360/bonsai -m outputs/mip360/bonsai --eval -r 2 --is_real --is_indoor --asg_degree 12
101 |
102 | ## For real-world unbounded outdoor scenes
103 | python train.py -s data/mipnerf-360/bicycle -m outputs/mip360/bicycle --eval -r 4 --is_real --asg_degree 12
104 | ```
105 |
106 |
107 |
108 | **[Extra, for acceleration] Train with anchor**
109 |
110 | ```shell
111 | python train_anchor.py -s your/path/to/the/dataset -m your/path/to/save --eval
112 |
113 | ## For synthetic bounded scenes
114 | python train_anchor.py -s data/nerf_synthetic/drums -m outputs/nerf/drums --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000
115 |
116 | ## For Mip-NeRF 360 scenes
117 | python train_anchor.py -s data/mipnerf-360/bonsai -m outputs/mip360/bonsai --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -r [2|4]
118 | ```
119 |
120 |
121 |
122 | ## Results
123 |
124 | ### Synthetic Scenes
125 |
126 | ![synthetic](assets/synthetic.png)
127 |
128 |
129 |
130 | ### Real-world Scenes
131 |
132 | ![real](assets/real.png)
133 |
134 |
135 |
136 | ### Ablation
137 |
138 | ![ablation-asg](assets/ablation-asg.png)
139 |
140 | ![ablation-c2f](assets/ablation-c2f.png)
141 |
142 |
143 |
144 | ### Align with Rip-NeRF
145 |
146 | Tri-MipRF and Rip-NeRF use both the train and val sets as training data. We report results on the NeRF-synthetic dataset under the same setting.
147 | | Scene | PSNR | SSIM | LPIPS |
148 | | --------- | ------ | --------- | -------- |
149 | | chair | 37.33 | 0.9907 | 0.0088 |
150 | | drums | 28.50 | 0.9669 | 0.0288 |
151 | | ficus | 38.08 | 0.9922 | 0.0081 |
152 | | hotdog | 39.86 | 0.9895 | 0.0148 |
153 | | lego | 38.44 | 0.9876 | 0.0121 |
154 | | materials | 32.64 | 0.9738 | 0.0285 |
155 | | mic | 38.57 | 0.995 | 0.0045 |
156 | | ship | 33.66 | 0.9248 | 0.0906 |
157 | | Average | **35.89** | **0.9776** | **0.0245** |
158 | | Rip-NeRF | 35.44 | 0.973 | 0.037 |
159 |
160 |
161 |
162 | ## Acknowledgments
163 |
164 | This work was mainly supported by ByteDance MMLab. I'm very grateful for the help from Chao Wan of Cornell University during the rebuttal.
165 |
166 |
167 |
168 | ## BibTex
169 |
170 | ```bibtex
171 | @article{yang2024spec,
172 |     title={Spec-Gaussian: Anisotropic View-Dependent Appearance for 3D Gaussian Splatting},
173 |     author={Yang, Ziyi and Gao, Xinyu and Sun, Yangtian and Huang, Yihua and Lyu, Xiaoyang and Zhou, Wen and Jiao, Shaohui and Qi, Xiaojuan and Jin, Xiaogang},
174 |     journal={arXiv preprint arXiv:2402.15870},
175 |     year={2024}
176 | }
177 | ```
178 |
179 | Thanks also to the authors of [3D Gaussians](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) and [Scaffold-GS](https://github.com/city-super/Scaffold-GS) for their excellent code; please consider citing these repositories as well.
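As a supplement to the training commands above, a typical rendering-and-evaluation pass after training looks like the following. This is only a sketch assembled from the flags defined in `render.py`, `render_anchor.py`, and `metrics.py` in this repository; the output paths are placeholders.

```shell
# Render the held-out test views of a trained (non-anchor) model; training views are skipped
python render.py -m outputs/nerf/drums --skip_train

# Same for the anchor-based variant
python render_anchor.py -m outputs/nerf/drums-anchor --skip_train

# Render a fly-through video along an ellipse fitted to the training cameras
python render.py -m outputs/mip360/bonsai --mode video

# Compute PSNR / SSIM / LPIPS over the saved renders (reads <model>/test/ours_<iter>/{renders,gt})
python metrics.py -m outputs/nerf/drums
```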
180 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 34 | else: 35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 36 | else: 37 | if t == bool: 38 | group.add_argument("--" + key, default=value, action="store_true") 39 | else: 40 | group.add_argument("--" + key, default=value, type=t) 41 | 42 | def extract(self, args): 43 | group = GroupParams() 44 | for arg in vars(args).items(): 45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 46 | setattr(group, arg[0], arg[1]) 47 | return group 48 | 49 | 50 | class ModelParams(ParamGroup): 51 | def __init__(self, parser, sentinel=False): 52 | self.sh_degree = 3 53 | self.asg_degree = 24 54 | self._source_path = "" 55 | self._model_path = "" 56 | self._images = "images" 57 | self._resolution = -1 58 | self._white_background = False 59 | self.data_device = "cuda" 60 | self.eval = False 61 | self.load2gpu_on_the_fly = False 62 | self.is_real = False 63 | self.is_indoor = False 64 | self.add_val = False 65 | super().__init__(parser, "Loading Parameters", sentinel) 66 | 67 | def extract(self, args): 68 | g = super().extract(args) 69 | g.source_path = os.path.abspath(g.source_path) 70 | return g 71 | 72 | 73 | class PipelineParams(ParamGroup): 74 | def __init__(self, parser): 75 | self.convert_SHs_python = False 76 | self.compute_cov3D_python = False 77 | self.debug = False 78 | super().__init__(parser, "Pipeline Parameters") 79 | 80 | 81 | class OptimizationParams(ParamGroup): 82 | def __init__(self, parser): 83 | self.iterations = 30_000 84 | self.position_lr_init = 0.00016 85 | self.position_lr_final = 0.0000016 86 | self.position_lr_delay_mult = 0.01 87 | self.position_lr_max_steps = 30_000 88 | self.specular_lr_max_steps = 30_000 89 | self.feature_lr = 0.0025 90 | self.opacity_lr = 0.05 91 | self.scaling_lr = 0.005 92 | self.rotation_lr = 0.001 93 | self.percent_dense = 0.01 94 | self.lambda_dssim = 0.2 95 | self.normal_lr = 0.0002 96 | self.densification_interval = 100 97 | self.opacity_reset_interval = 3000 98 | self.densify_from_iter = 500 99 | self.densify_until_iter = 15_000 100 | self.densify_grad_threshold = 0.0005 101 | super().__init__(parser, "Optimization Parameters") 102 | 103 | 104 | class AnchorModelParams(ParamGroup): 105 | def __init__(self, parser, sentinel=False): 106 | self.sh_degree = 3 107 | self.feat_dim = 32 108 | self.n_offsets = 10 109 | self.voxel_size = 0.001 # if voxel_size<=0, using 1nn dist 110 | self.update_depth = 3 111 | 
self.update_init_factor = 16 112 | self.update_hierachy_factor = 4 113 | self._source_path = "" 114 | self._model_path = "" 115 | self._images = "images" 116 | self._resolution = -1 117 | self._white_background = False 118 | self.data_device = "cuda" 119 | self.eval = False 120 | self.load2gpu_on_the_fly = False 121 | super().__init__(parser, "Loading Parameters", sentinel) 122 | 123 | def extract(self, args): 124 | g = super().extract(args) 125 | g.source_path = os.path.abspath(g.source_path) 126 | return g 127 | 128 | 129 | class AnchorOptimizationParams(ParamGroup): 130 | def __init__(self, parser): 131 | self.iterations = 30_000 132 | self.position_lr_init = 0.0 133 | self.position_lr_final = 0.0 134 | self.position_lr_delay_mult = 0.01 135 | self.position_lr_max_steps = 30_000 136 | 137 | self.offset_lr_init = 0.01 138 | self.offset_lr_final = 0.0001 139 | self.offset_lr_delay_mult = 0.01 140 | self.offset_lr_max_steps = 30_000 141 | 142 | self.feature_lr = 0.0075 143 | self.opacity_lr = 0.02 144 | self.scaling_lr = 0.007 145 | self.rotation_lr = 0.002 146 | 147 | self.mlp_opacity_lr_init = 0.002 148 | self.mlp_opacity_lr_final = 0.00002 149 | self.mlp_opacity_lr_delay_mult = 0.01 150 | self.mlp_opacity_lr_max_steps = 30_000 151 | 152 | self.mlp_cov_lr_init = 0.004 153 | self.mlp_cov_lr_final = 0.004 154 | self.mlp_cov_lr_delay_mult = 0.01 155 | self.mlp_cov_lr_max_steps = 30_000 156 | 157 | self.mlp_color_lr_init = 0.008 158 | self.mlp_color_lr_final = 0.00005 159 | self.mlp_color_lr_delay_mult = 0.01 160 | self.mlp_color_lr_max_steps = 30_000 161 | 162 | self.mlp_featurebank_lr_init = 0.01 163 | self.mlp_featurebank_lr_final = 0.00001 164 | self.mlp_featurebank_lr_delay_mult = 0.01 165 | self.mlp_featurebank_lr_max_steps = 30_000 166 | 167 | self.percent_dense = 0.01 168 | self.lambda_dssim = 0.2 169 | 170 | # for anchor densification 171 | self.start_stat = 500 172 | self.update_from = 1500 173 | self.update_interval = 100 174 | self.update_until = 15_000 175 | 176 | self.min_opacity = 0.005 177 | self.success_threshold = 0.8 178 | self.densify_grad_threshold = 0.0006 179 | 180 | # for coarse to fine 181 | self.use_c2f = False 182 | self.c2f_init_factor = 0.125 183 | self.c2f_until_iter = 20_000 184 | 185 | super().__init__(parser, "Optimization Parameters") 186 | 187 | 188 | def get_combined_args(parser: ArgumentParser): 189 | cmdlne_string = sys.argv[1:] 190 | cfgfile_string = "Namespace()" 191 | args_cmdline = parser.parse_args(cmdlne_string) 192 | 193 | try: 194 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 195 | print("Looking for config file in", cfgfilepath) 196 | with open(cfgfilepath) as cfg_file: 197 | print("Config file found: {}".format(cfgfilepath)) 198 | cfgfile_string = cfg_file.read() 199 | except TypeError: 200 | print("Config file not found at") 201 | pass 202 | args_cfgfile = eval(cfgfile_string) 203 | 204 | merged_dict = vars(args_cfgfile).copy() 205 | for k, v in vars(args_cmdline).items(): 206 | if v != None: 207 | merged_dict[k] = v 208 | return Namespace(**merged_dict) 209 | -------------------------------------------------------------------------------- /assets/ablation-asg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/ablation-asg.png -------------------------------------------------------------------------------- /assets/ablation-c2f.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/ablation-c2f.png -------------------------------------------------------------------------------- /assets/pipeline-anchor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/pipeline-anchor.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/pipeline.png -------------------------------------------------------------------------------- /assets/real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/real.png -------------------------------------------------------------------------------- /assets/synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/synthetic.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Specular-Gaussians/0b40ce3ac034761798002f82b4f5b32e80547688/assets/teaser.png -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | import shutil 15 | 16 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
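# Note (summary of the steps below): raw photos go in <source_path>/input/.
# The script writes the COLMAP database and reconstruction to <source_path>/distorted/,
# moves the final model to <source_path>/sparse/0/, places undistorted images in
# <source_path>/images/, and (with --resize) creates half / quarter / eighth resolution
# copies in images_2, images_4 and images_8.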
17 | parser = ArgumentParser("Colmap converter") 18 | parser.add_argument("--no_gpu", action='store_true') 19 | parser.add_argument("--skip_matching", action='store_true') 20 | parser.add_argument("--source_path", "-s", required=True, type=str) 21 | parser.add_argument("--camera", default="OPENCV", type=str) 22 | parser.add_argument("--colmap_executable", default="", type=str) 23 | parser.add_argument("--resize", action="store_true") 24 | parser.add_argument("--magick_executable", default="", type=str) 25 | args = parser.parse_args() 26 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 27 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 28 | use_gpu = 1 if not args.no_gpu else 0 29 | 30 | if not args.skip_matching: 31 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 32 | 33 | ## Feature extraction 34 | os.system(colmap_command + " feature_extractor " \ 35 | "--database_path " + args.source_path + "/distorted/database.db \ 36 | --image_path " + args.source_path + "/input \ 37 | --ImageReader.single_camera 1 \ 38 | --ImageReader.camera_model " + args.camera + " \ 39 | --SiftExtraction.use_gpu " + str(use_gpu)) 40 | 41 | ## Feature matching 42 | os.system(colmap_command + " exhaustive_matcher \ 43 | --database_path " + args.source_path + "/distorted/database.db \ 44 | --SiftMatching.use_gpu " + str(use_gpu)) 45 | 46 | ### Bundle adjustment 47 | # The default Mapper tolerance is unnecessarily large, 48 | # decreasing it speeds up bundle adjustment steps. 49 | os.system(colmap_command + " mapper \ 50 | --database_path " + args.source_path + "/distorted/database.db \ 51 | --image_path " + args.source_path + "/input \ 52 | --output_path " + args.source_path + "/distorted/sparse \ 53 | --Mapper.ba_global_function_tolerance=0.000001") 54 | 55 | ### Image undistortion 56 | ## We need to undistort our images into ideal pinhole intrinsics. 57 | os.system(colmap_command + " image_undistorter \ 58 | --image_path " + args.source_path + "/input \ 59 | --input_path " + args.source_path + "/distorted/sparse/0 \ 60 | --output_path " + args.source_path + "\ 61 | --output_type COLMAP") 62 | 63 | files = os.listdir(args.source_path + "/sparse") 64 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 65 | # Copy each file from the source directory to the destination directory 66 | for file in files: 67 | if file == '0': 68 | continue 69 | source_file = os.path.join(args.source_path, "sparse", file) 70 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 71 | shutil.move(source_file, destination_file) 72 | 73 | if (args.resize): 74 | print("Copying and resizing...") 75 | 76 | # Resize images. 
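    # The 50% / 25% / 12.5% copies created below fill images_2, images_4 and images_8,
    # the folders read when training at reduced resolution (e.g. the "-r 2" / "-r 4" runs in the README).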
77 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 78 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 79 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 80 | # Get the list of files in the source directory 81 | files = os.listdir(args.source_path + "/images") 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | source_file = os.path.join(args.source_path, "images", file) 85 | 86 | destination_file = os.path.join(args.source_path, "images_2", file) 87 | shutil.copy2(source_file, destination_file) 88 | os.system(magick_command + " mogrify -resize 50% " + destination_file) 89 | 90 | destination_file = os.path.join(args.source_path, "images_4", file) 91 | shutil.copy2(source_file, destination_file) 92 | os.system(magick_command + " mogrify -resize 25% " + destination_file) 93 | 94 | destination_file = os.path.join(args.source_path, "images_8", file) 95 | shutil.copy2(source_file, destination_file) 96 | os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 97 | 98 | print("Done.") 99 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from einops import repeat 14 | import math 15 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 16 | from scene.gaussian_model import GaussianModel 17 | from scene.anchor_gaussian_model import AnchorGaussianModel 18 | from utils.sh_utils import eval_sh 19 | 20 | 21 | def quaternion_multiply(q1, q2): 22 | w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] 23 | w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] 24 | 25 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 26 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 27 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 28 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 29 | 30 | return torch.stack((w, x, y, z), dim=-1) 31 | 32 | 33 | def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, mlp_color, hybrid=True, 34 | scaling_modifier=1.0, voxel_visible_mask=None, override_color=None): 35 | """ 36 | Render the scene. 37 | 38 | Background tensor (bg_color) must be on GPU! 39 | """ 40 | 41 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 42 | screenspace_points = torch.zeros_like(pc.get_xyz if voxel_visible_mask is None else pc.get_xyz[voxel_visible_mask], 43 | dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 44 | screenspace_points_densify = torch.zeros_like(pc.get_xyz if voxel_visible_mask is None else pc.get_xyz[voxel_visible_mask], 45 | dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 46 | try: 47 | screenspace_points.retain_grad() 48 | screenspace_points_densify.retain_grad() 49 | except: 50 | pass 51 | 52 | # Set up rasterization configuration 53 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 54 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 55 | 56 | raster_settings = GaussianRasterizationSettings( 57 | image_height=int(viewpoint_camera.image_height), 58 | image_width=int(viewpoint_camera.image_width), 59 | tanfovx=tanfovx, 60 | tanfovy=tanfovy, 61 | bg=bg_color, 62 | scale_modifier=scaling_modifier, 63 | viewmatrix=viewpoint_camera.world_view_transform, 64 | projmatrix=viewpoint_camera.full_proj_transform, 65 | sh_degree=pc.active_sh_degree, 66 | campos=viewpoint_camera.camera_center, 67 | prefiltered=False, 68 | debug=pipe.debug, 69 | ) 70 | 71 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 72 | 73 | if voxel_visible_mask is None: 74 | means3D = pc.get_xyz 75 | opacity = pc.get_opacity 76 | scales = pc.get_scaling 77 | rotations = pc.get_rotation 78 | else: 79 | means3D = pc.get_xyz[voxel_visible_mask] 80 | opacity = pc.get_opacity[voxel_visible_mask] 81 | scales = pc.get_scaling[voxel_visible_mask] 82 | rotations = pc.get_rotation[voxel_visible_mask] 83 | 84 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 85 | # scaling / rotation by the rasterizer. 86 | cov3D_precomp = None 87 | if pipe.compute_cov3D_python: 88 | cov3D_precomp = pc.get_covariance(scaling_modifier) 89 | 90 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 91 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 92 | shs = None 93 | colors_precomp = None 94 | # colors_precomp = mlp_color 95 | if colors_precomp is None: 96 | if hybrid: 97 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2) 98 | dir_pp = (means3D - viewpoint_camera.camera_center.repeat(means3D.shape[0], 1)) 99 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 100 | sh2rgb = eval_sh(pc.active_sh_degree, 101 | shs_view if voxel_visible_mask is None else shs_view[voxel_visible_mask], 102 | dir_pp_normalized) 103 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + mlp_color 104 | else: 105 | # shs = pc.get_features 106 | colors_precomp = mlp_color 107 | else: 108 | colors_precomp = override_color 109 | 110 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 111 | rendered_image, radii, depth = rasterizer( 112 | means3D=means3D, 113 | means2D=screenspace_points, 114 | means2D_densify=screenspace_points_densify, 115 | shs=None, 116 | colors_precomp=colors_precomp, 117 | opacities=opacity, 118 | scales=scales, 119 | rotations=rotations, 120 | cov3D_precomp=cov3D_precomp) 121 | 122 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 123 | # They will be excluded from value updates used in the splitting criteria. 
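    # Note: colors_precomp already holds the hybrid appearance at this point -- the clamped SH base
    # color plus the ASG/MLP specular term passed in as mlp_color (or mlp_color alone when hybrid=False).
    # The packet returned below is what the training loop consumes: the rendered image, the two
    # screen-space means tensors (their gradients are used by the densification/splitting criteria),
    # the per-Gaussian visibility mask and radii, and the rasterized depth map.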
124 | return {"render": rendered_image, 125 | "viewspace_points": screenspace_points, 126 | "viewspace_points_densify": screenspace_points_densify, 127 | "visibility_filter": radii > 0, 128 | "radii": radii, 129 | "depth": depth} 130 | 131 | 132 | def prefilter_voxel(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, 133 | override_color=None): 134 | """ 135 | Render the scene. 136 | 137 | Background tensor (bg_color) must be on GPU! 138 | """ 139 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 140 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, 141 | device="cuda") + 0 142 | try: 143 | screenspace_points.retain_grad() 144 | except: 145 | pass 146 | 147 | # Set up rasterization configuration 148 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 149 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 150 | 151 | raster_settings = GaussianRasterizationSettings( 152 | image_height=int(viewpoint_camera.image_height), 153 | image_width=int(viewpoint_camera.image_width), 154 | tanfovx=tanfovx, 155 | tanfovy=tanfovy, 156 | bg=bg_color, 157 | scale_modifier=scaling_modifier, 158 | viewmatrix=viewpoint_camera.world_view_transform, 159 | projmatrix=viewpoint_camera.full_proj_transform, 160 | sh_degree=1, 161 | campos=viewpoint_camera.camera_center, 162 | prefiltered=False, 163 | debug=pipe.debug 164 | ) 165 | 166 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 167 | 168 | means3D = pc.get_xyz 169 | 170 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 171 | # scaling / rotation by the rasterizer. 172 | scales = None 173 | rotations = None 174 | cov3D_precomp = None 175 | if pipe.compute_cov3D_python: 176 | cov3D_precomp = pc.get_covariance(scaling_modifier) 177 | else: 178 | scales = pc.get_scaling 179 | rotations = pc.get_rotation 180 | 181 | radii_pure = rasterizer.visible_filter(means3D=means3D, 182 | scales=scales[:, :3], 183 | rotations=rotations, 184 | cov3D_precomp=cov3D_precomp) 185 | 186 | return radii_pure > 0 187 | 188 | 189 | def generate_neural_gaussians(viewpoint_camera, pc: AnchorGaussianModel, visible_mask=None, is_training=False): 190 | ## view frustum filtering for acceleration 191 | if visible_mask is None: 192 | visible_mask = torch.ones(pc.get_anchor.shape[0], dtype=torch.bool, device=pc.get_anchor.device) 193 | 194 | feat = pc._anchor_feat[visible_mask] 195 | anchor = pc.get_anchor[visible_mask] 196 | grid_offsets = pc._offset[visible_mask] 197 | grid_scaling = pc.get_scaling[visible_mask] 198 | 199 | ## get view properties for anchor 200 | ob_view = anchor - viewpoint_camera.camera_center 201 | # dist 202 | ob_dist = ob_view.norm(dim=1, keepdim=True) 203 | # view 204 | ob_view = ob_view / ob_dist 205 | 206 | cat_local_view = torch.cat([feat, ob_view, ob_dist], dim=1) # [N, c+3] 207 | 208 | # get offset's opacity 209 | neural_opacity = pc.get_opacity_mlp(cat_local_view) # [N, k] 210 | 211 | # opacity mask generation 212 | neural_opacity = neural_opacity.reshape([-1, 1]) 213 | mask = (neural_opacity > 0.0) 214 | mask = mask.view(-1) 215 | 216 | # select opacity 217 | opacity = neural_opacity[mask] 218 | 219 | # # get offset's color 220 | # color = pc.get_color_mlp(feat, ob_view) 221 | # color = color.reshape([anchor.shape[0]*pc.n_offsets, 3])# [mask] 222 | 223 | # get offset's cov 224 | scale_rot = pc.get_cov_mlp(cat_local_view) 225 | scale_rot = scale_rot.reshape([anchor.shape[0] * 
pc.n_offsets, 7]) # [mask] 226 | 227 | # offsets 228 | offsets = grid_offsets.view([-1, 3]) # [mask] 229 | center_normal = -torch.mean(grid_offsets, dim=1) 230 | 231 | # combine for parallel masking 232 | concatenated = torch.cat([grid_scaling, anchor, feat, center_normal], dim=-1) 233 | concatenated_repeated = repeat(concatenated, 'n (c) -> (n k) (c)', k=pc.n_offsets) 234 | concatenated_all = torch.cat([concatenated_repeated, scale_rot, offsets], dim=-1) 235 | masked = concatenated_all[mask] 236 | scaling_repeat, repeat_anchor, repeat_feat, repeat_normal, scale_rot, offsets = masked.split([6, 3, 32, 3, 7, 3], 237 | dim=-1) 238 | 239 | # post-process cov 240 | scaling = scaling_repeat[:, 3:] * torch.sigmoid(scale_rot[:, :3]) # * (1+torch.sigmoid(repeat_dist)) 241 | rot = pc.rotation_activation(scale_rot[:, 3:7]) 242 | 243 | # post-process offsets to get centers for gaussians 244 | offsets = offsets * scaling_repeat[:, :3] 245 | xyz = repeat_anchor + offsets 246 | 247 | # knn_res = knn_points(xyz[None], anchor[None], None, None, K=4+1) 248 | # idx = knn_res.idx[0] 249 | # feat_color = feat[idx].mean(1) 250 | 251 | # post color, fast convergence 252 | dir_view = xyz - viewpoint_camera.camera_center 253 | # dist 254 | dir_dist = dir_view.norm(dim=1, keepdim=True) 255 | # view 256 | dir_view = dir_view / dir_dist 257 | 258 | color = pc.get_color_mlp(repeat_feat, dir_view, repeat_normal, offsets) 259 | 260 | if is_training: 261 | return xyz, color, opacity, scaling, rot, neural_opacity, mask 262 | else: 263 | return xyz, color, opacity, scaling, rot 264 | 265 | 266 | def anchor_render(viewpoint_camera, pc: AnchorGaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, 267 | visible_mask=None, 268 | retain_grad=False, down_sampling=1): 269 | """ 270 | Render the scene. 271 | 272 | Background tensor (bg_color) must be on GPU! 273 | """ 274 | is_training = pc.get_color_mlp.training 275 | 276 | if is_training: 277 | xyz, color, opacity, scaling, rot, neural_opacity, mask = generate_neural_gaussians(viewpoint_camera, pc, 278 | visible_mask, 279 | is_training=is_training) 280 | else: 281 | xyz, color, opacity, scaling, rot = generate_neural_gaussians(viewpoint_camera, pc, visible_mask, 282 | is_training=is_training) 283 | 284 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 285 | screenspace_points = torch.zeros_like(xyz, dtype=pc.get_anchor.dtype, requires_grad=True, device="cuda") + 0 286 | screenspace_points_densify = torch.zeros_like(xyz, dtype=pc.get_anchor.dtype, requires_grad=True, device="cuda") + 0 287 | if retain_grad: 288 | try: 289 | screenspace_points.retain_grad() 290 | screenspace_points_densify.retain_grad() 291 | except: 292 | pass 293 | 294 | # Set up rasterization configuration 295 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 296 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 297 | 298 | raster_settings = GaussianRasterizationSettings( 299 | image_height=int(viewpoint_camera.image_height * down_sampling), 300 | image_width=int(viewpoint_camera.image_width * down_sampling), 301 | tanfovx=tanfovx, 302 | tanfovy=tanfovy, 303 | bg=bg_color, 304 | scale_modifier=scaling_modifier, 305 | viewmatrix=viewpoint_camera.world_view_transform, 306 | projmatrix=viewpoint_camera.full_proj_transform, 307 | sh_degree=1, 308 | campos=viewpoint_camera.camera_center, 309 | prefiltered=False, 310 | debug=pipe.debug 311 | ) 312 | 313 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 314 | 315 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 316 | rendered_image, radii, depth = rasterizer( 317 | means3D=xyz, 318 | means2D=screenspace_points, 319 | means2D_densify=screenspace_points_densify, 320 | shs=None, 321 | colors_precomp=color, 322 | opacities=opacity, 323 | scales=scaling, 324 | rotations=rot, 325 | cov3D_precomp=None) 326 | 327 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 328 | if is_training: 329 | return {"render": rendered_image, 330 | "viewspace_points": screenspace_points, 331 | "viewspace_points_densify": screenspace_points_densify, 332 | "visibility_filter": radii > 0, 333 | "radii": radii, 334 | "selection_mask": mask, 335 | "neural_opacity": neural_opacity, 336 | "scaling": scaling, 337 | "depth": depth, 338 | } 339 | else: 340 | return {"render": rendered_image, 341 | "viewspace_points": screenspace_points, 342 | "visibility_filter": radii > 0, 343 | "radii": radii, 344 | "depth": depth, 345 | } 346 | 347 | 348 | def anchor_prefilter_voxel(viewpoint_camera, pc: AnchorGaussianModel, pipe, bg_color: torch.Tensor, 349 | scaling_modifier=1.0, 350 | override_color=None): 351 | """ 352 | Render the scene. 353 | 354 | Background tensor (bg_color) must be on GPU! 355 | """ 356 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 357 | screenspace_points = torch.zeros_like(pc.get_anchor, dtype=pc.get_anchor.dtype, requires_grad=True, 358 | device="cuda") + 0 359 | try: 360 | screenspace_points.retain_grad() 361 | except: 362 | pass 363 | 364 | # Set up rasterization configuration 365 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 366 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 367 | 368 | raster_settings = GaussianRasterizationSettings( 369 | image_height=int(viewpoint_camera.image_height), 370 | image_width=int(viewpoint_camera.image_width), 371 | tanfovx=tanfovx, 372 | tanfovy=tanfovy, 373 | bg=bg_color, 374 | scale_modifier=scaling_modifier, 375 | viewmatrix=viewpoint_camera.world_view_transform, 376 | projmatrix=viewpoint_camera.full_proj_transform, 377 | sh_degree=1, 378 | campos=viewpoint_camera.camera_center, 379 | prefiltered=False, 380 | debug=pipe.debug 381 | ) 382 | 383 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 384 | 385 | means3D = pc.get_anchor 386 | 387 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 388 | # scaling / rotation by the rasterizer. 389 | scales = None 390 | rotations = None 391 | cov3D_precomp = None 392 | if pipe.compute_cov3D_python: 393 | cov3D_precomp = pc.get_covariance(scaling_modifier) 394 | else: 395 | scales = pc.get_scaling 396 | rotations = pc.get_rotation 397 | 398 | radii_pure = rasterizer.visible_filter(means3D=means3D, 399 | scales=scales[:, :3], 400 | rotations=rotations, 401 | cov3D_precomp=cov3D_precomp) 402 | 403 | return radii_pure > 0 404 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | 27 | def init(wish_host, wish_port): 28 | global host, port, listener 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | 35 | 36 | def try_connect(): 37 | global conn, addr, listener 38 | try: 39 | conn, addr = listener.accept() 40 | print(f"\nConnected by {addr}") 41 | conn.settimeout(None) 42 | except Exception as inst: 43 | pass 44 | 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, 'little') 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | 54 | def send(message_bytes, verify): 55 | global conn 56 | if message_bytes != None: 57 | conn.sendall(message_bytes) 58 | conn.sendall(len(verify).to_bytes(4, 'little')) 59 | conn.sendall(bytes(verify, 'ascii')) 60 | 61 | 62 | def receive(): 63 | message = read() 64 | 65 | width = message["resolution_x"] 66 | height = message["resolution_y"] 67 | 68 | if width != 0 and height != 0: 69 | try: 70 | do_training = bool(message["train"]) 71 | fovy = message["fov_y"] 72 | fovx = message["fov_x"] 73 | znear = message["z_near"] 74 | zfar = message["z_far"] 75 | do_shs_python = bool(message["shs_python"]) 76 | do_rot_scale_python = bool(message["rot_scale_python"]) 77 | keep_alive = bool(message["keep_alive"]) 78 | scaling_modifier = message["scaling_modifier"] 79 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 80 | world_view_transform[:, 1] = -world_view_transform[:, 1] 81 | world_view_transform[:, 2] = -world_view_transform[:, 2] 82 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 83 | full_proj_transform[:, 1] = -full_proj_transform[:, 1] 84 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 85 | except Exception as e: 86 | print("") 87 | traceback.print_exc() 88 | raise e 89 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 90 | else: 91 | return None, None, None, None, None, None 92 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | lpips_fn = lpips.LPIPS(net='vgg').to('cuda') 25 | 26 | 27 | def readImages(renders_dir, gt_dir): 28 | renders = [] 29 | gts = [] 30 | image_names = [] 31 | for fname in os.listdir(renders_dir): 32 | render = Image.open(renders_dir / fname) 33 | gt = Image.open(gt_dir / fname) 34 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 35 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 36 | image_names.append(fname) 37 | return renders, gts, image_names 38 | 39 | 40 | def evaluate(model_paths): 41 | full_dict = {} 42 | per_view_dict = {} 43 | full_dict_polytopeonly = {} 44 | per_view_dict_polytopeonly = {} 45 | print("") 46 | 47 | for scene_dir in model_paths: 48 | try: 49 | print("Scene:", scene_dir) 50 | full_dict[scene_dir] = {} 51 | per_view_dict[scene_dir] = {} 52 | full_dict_polytopeonly[scene_dir] = {} 53 | per_view_dict_polytopeonly[scene_dir] = {} 54 | 55 | test_dir = Path(scene_dir) / "test" 56 | 57 | for method in os.listdir(test_dir): 58 | print("Method:", method) 59 | 60 | full_dict[scene_dir][method] = {} 61 | per_view_dict[scene_dir][method] = {} 62 | full_dict_polytopeonly[scene_dir][method] = {} 63 | per_view_dict_polytopeonly[scene_dir][method] = {} 64 | 65 | method_dir = test_dir / method 66 | gt_dir = method_dir / "gt" 67 | renders_dir = method_dir / "renders" 68 | renders, gts, image_names = readImages(renders_dir, gt_dir) 69 | 70 | ssims = [] 71 | psnrs = [] 72 | lpipss = [] 73 | 74 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 75 | ssims.append(ssim(renders[idx], gts[idx])) 76 | psnrs.append(psnr(renders[idx], gts[idx])) 77 | lpipss.append(lpips_fn(renders[idx], gts[idx]).detach()) 78 | 79 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 80 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 81 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 82 | print("") 83 | 84 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 85 | "PSNR": torch.tensor(psnrs).mean().item(), 86 | "LPIPS": torch.tensor(lpipss).mean().item()}) 87 | per_view_dict[scene_dir][method].update( 88 | {"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 89 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 90 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 91 | 92 | with open(scene_dir + "/results.json", 'w') as fp: 93 | json.dump(full_dict[scene_dir], fp, indent=True) 94 | with open(scene_dir + "/per_view.json", 'w') as fp: 95 | json.dump(per_view_dict[scene_dir], fp, indent=True) 96 | except: 97 | print("Unable to compute metrics for model", scene_dir) 98 | 99 | 100 | if __name__ == "__main__": 101 | device = torch.device("cuda:0") 102 | torch.cuda.set_device(device) 103 | 104 | # Set up command line argument parser 105 | parser = ArgumentParser(description="Training script parameters") 106 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 107 | args = parser.parse_args() 108 | evaluate(args.model_paths) 
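    # Usage sketch (example paths): after rendering with render.py / render_anchor.py, run
    #   python metrics.py -m outputs/nerf/drums outputs/nerf/lego
    # Each model path must contain test/<method>/renders and test/<method>/gt (render.py writes
    # these as test/ours_<iteration>/...); averaged metrics go to <model_path>/results.json and
    # per-image scores to <model_path>/per_view.json.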
109 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene, SpecularModel 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render, prefilter_voxel 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from utils.pose_utils import pose_spherical, generate_ellipse_path 21 | from utils.graphics_utils import getWorld2View2 22 | from argparse import ArgumentParser 23 | from arguments import ModelParams, PipelineParams, OptimizationParams, get_combined_args 24 | from gaussian_renderer import GaussianModel 25 | import imageio 26 | import numpy as np 27 | import time 28 | 29 | 30 | def render_set(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, specular, 31 | use_filter): 32 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 33 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 34 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 35 | normal_path = os.path.join(model_path, name, "ours_{}".format(iteration), "normals") 36 | 37 | makedirs(render_path, exist_ok=True) 38 | makedirs(gts_path, exist_ok=True) 39 | makedirs(depth_path, exist_ok=True) 40 | makedirs(normal_path, exist_ok=True) 41 | 42 | t_list = [] 43 | voxel_visible_mask = None 44 | 45 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 46 | if use_filter: 47 | voxel_visible_mask = prefilter_voxel(view, gaussians, pipeline, background) 48 | dir_pp = (gaussians.get_xyz - view.camera_center.repeat(gaussians.get_features.shape[0], 1)) 49 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 50 | normal = gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 51 | if use_filter: 52 | mlp_color = specular.step(gaussians.get_asg_features[voxel_visible_mask], 53 | dir_pp_normalized[voxel_visible_mask], normal[voxel_visible_mask]) 54 | else: 55 | mlp_color = specular.step(gaussians.get_asg_features, dir_pp_normalized, normal) 56 | results = render(view, gaussians, pipeline, background, mlp_color, voxel_visible_mask=voxel_visible_mask) 57 | normal_image = \ 58 | render(view, gaussians, pipeline, background, normal[voxel_visible_mask] * 0.5 + 0.5, hybrid=False, 59 | voxel_visible_mask=voxel_visible_mask)["render"] 60 | rendering = results["render"] 61 | depth = results["depth"] 62 | depth = depth / (depth.max() + 1e-5) 63 | 64 | gt = view.original_image[0:3, :, :] 65 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 66 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 67 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 68 | torchvision.utils.save_image(normal_image, os.path.join(normal_path, '{0:05d}'.format(idx) + ".png")) 69 | 70 | for idx, view in enumerate(tqdm(views, desc="FPS test progress")): 71 | torch.cuda.synchronize() 72 | t_start = time.time() 
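        # FPS measurement: only the visibility prefilter, the per-Gaussian ASG/MLP specular query
        # and the rasterization call are timed; image saving is excluded, and the first 5 frames
        # are dropped as warm-up before averaging (t_list[5:] below).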
73 | 74 | if use_filter: 75 | voxel_visible_mask = prefilter_voxel(view, gaussians, pipeline, background) 76 | dir_pp = (gaussians.get_xyz - view.camera_center.repeat(gaussians.get_features.shape[0], 1)) 77 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 78 | normal = gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 79 | if use_filter: 80 | mlp_color = specular.step(gaussians.get_asg_features[voxel_visible_mask], 81 | dir_pp_normalized[voxel_visible_mask], normal[voxel_visible_mask]) 82 | else: 83 | mlp_color = specular.step(gaussians.get_asg_features, dir_pp_normalized, normal) 84 | results = render(view, gaussians, pipeline, background, mlp_color, voxel_visible_mask=voxel_visible_mask) 85 | 86 | torch.cuda.synchronize() 87 | t_end = time.time() 88 | t_list.append(t_end - t_start) 89 | 90 | t = np.array(t_list[5:]) 91 | fps = 1.0 / t.mean() 92 | print(f'Test FPS: \033[1;35m{fps:.5f}\033[0m') 93 | 94 | 95 | def interpolate_all(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, specular, 96 | use_filter): 97 | render_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "renders") 98 | depth_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "depth") 99 | 100 | makedirs(render_path, exist_ok=True) 101 | makedirs(depth_path, exist_ok=True) 102 | 103 | frame = 520 104 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 0) 105 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 106 | 107 | idx = torch.randint(0, len(views), (1,)).item() 108 | view = views[idx] # Choose a specific time for rendering 109 | 110 | renderings = [] 111 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 112 | matrix = np.linalg.inv(np.array(pose)) 113 | R = -np.transpose(matrix[:3, :3]) 114 | R[:, 0] = -R[:, 0] 115 | T = -matrix[:3, 3] 116 | 117 | view.reset_extrinsic(R, T) 118 | 119 | voxel_visible_mask = prefilter_voxel(view, gaussians, pipeline, background) if use_filter else \ 120 | torch.ones_like(gaussians.get_xyz)[..., 0].bool() 121 | dir_pp = (gaussians.get_xyz - view.camera_center.repeat(gaussians.get_features.shape[0], 1)) 122 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 123 | normal = gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 124 | mlp_color = specular.step(gaussians.get_asg_features[voxel_visible_mask], dir_pp_normalized[voxel_visible_mask], 125 | normal[voxel_visible_mask]) 126 | results = render(view, gaussians, pipeline, background, mlp_color, voxel_visible_mask=voxel_visible_mask) 127 | rendering = results["render"] 128 | renderings.append(to8b(rendering.cpu().numpy())) 129 | # depth = results["depth"] 130 | # depth = depth / (depth.max() + 1e-5) 131 | 132 | # torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 133 | # torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 134 | 135 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 136 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 137 | 138 | 139 | def render_video(model_path, iteration, views, gaussians, pipeline, background, specular): 140 | render_path = os.path.join(model_path, 'video', "ours_{}".format(iteration)) 141 | makedirs(render_path, exist_ok=True) 142 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 
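    # to8b maps a float image in [0, 1] to uint8; the (C, H, W) frames collected below are stacked,
    # transposed to (N, H, W, C) and written out with imageio.mimwrite as a 60 fps mp4.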
143 | view = views[0] 144 | renderings = [] 145 | for idx, pose in enumerate(tqdm(generate_ellipse_path(views, n_frames=600), desc="Rendering progress")): 146 | view.world_view_transform = torch.tensor( 147 | getWorld2View2(pose[:3, :3].T, pose[:3, 3], view.trans, view.scale)).transpose(0, 1).cuda() 148 | view.full_proj_transform = ( 149 | view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0) 150 | view.camera_center = view.world_view_transform.inverse()[3, :3] 151 | voxel_visible_mask = prefilter_voxel(view, gaussians, pipeline, background) 152 | dir_pp = (gaussians.get_xyz - view.camera_center.repeat(gaussians.get_features.shape[0], 1)) 153 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 154 | normal = gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 155 | mlp_color = specular.step(gaussians.get_asg_features[voxel_visible_mask], 156 | dir_pp_normalized[voxel_visible_mask], normal[voxel_visible_mask]) 157 | rendering = render(view, gaussians, pipeline, background, mlp_color, voxel_visible_mask=voxel_visible_mask)["render"] 158 | renderings.append(to8b(rendering.cpu().numpy())) 159 | # torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 160 | 161 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 162 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 163 | 164 | 165 | def render_sets(dataset: ModelParams, iteration: int, opt: OptimizationParams, pipeline: PipelineParams, 166 | skip_train: bool, skip_test: bool, mode: str): 167 | with torch.no_grad(): 168 | gaussians = GaussianModel(dataset.sh_degree, dataset.asg_degree) 169 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 170 | specular = SpecularModel(dataset.is_real, dataset.is_indoor) 171 | specular.load_weights(dataset.model_path) 172 | 173 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 174 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 175 | use_filter = dataset.is_real 176 | 177 | if mode == "render": 178 | render_func = render_set 179 | elif mode == "all": 180 | render_func = interpolate_all 181 | elif mode == 'video': 182 | render_video(dataset.model_path, scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, 183 | background, specular) 184 | return 185 | 186 | if not skip_train: 187 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, "train", scene.loaded_iter, 188 | scene.getTrainCameras(), gaussians, pipeline, 189 | background, specular, use_filter) 190 | 191 | if not skip_test: 192 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, "test", scene.loaded_iter, 193 | scene.getTestCameras(), gaussians, pipeline, 194 | background, specular, use_filter) 195 | 196 | 197 | if __name__ == "__main__": 198 | # Set up command line argument parser 199 | parser = ArgumentParser(description="Testing script parameters") 200 | model = ModelParams(parser, sentinel=True) 201 | op = OptimizationParams(parser) 202 | pipeline = PipelineParams(parser) 203 | parser.add_argument("--iteration", default=-1, type=int) 204 | parser.add_argument("--skip_train", action="store_true") 205 | parser.add_argument("--skip_test", action="store_true") 206 | parser.add_argument("--quiet", action="store_true") 207 | parser.add_argument("--mode", default='render', choices=['render', 'all', 'video']) 208 | args = get_combined_args(parser) 209 | print("Rendering " + args.model_path) 
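    # --mode render -> render_set (saves renders / gt / depth / normals and runs the FPS test),
    # --mode all    -> interpolate_all (360-degree spherical camera sweep written as video.mp4),
    # --mode video  -> render_video (ellipse path fitted to the training cameras, also video.mp4).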
210 | 211 | # Initialize system state (RNG) 212 | safe_state(args.quiet) 213 | 214 | render_sets(model.extract(args), args.iteration, op.extract(args), pipeline.extract(args), args.skip_train, 215 | args.skip_test, args.mode) 216 | -------------------------------------------------------------------------------- /render_anchor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import os 12 | import torch 13 | from os import makedirs 14 | import numpy as np 15 | 16 | from scene import AnchorScene 17 | import time 18 | from gaussian_renderer import anchor_render, anchor_prefilter_voxel 19 | import torchvision 20 | from tqdm import tqdm 21 | from utils.general_utils import safe_state 22 | from utils.pose_utils import generate_ellipse_path, pose_spherical 23 | from utils.graphics_utils import getWorld2View2 24 | from argparse import ArgumentParser 25 | from arguments import AnchorModelParams, PipelineParams, get_combined_args 26 | from gaussian_renderer import AnchorGaussianModel 27 | import imageio 28 | 29 | 30 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 31 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 32 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 33 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 34 | 35 | makedirs(render_path, exist_ok=True) 36 | makedirs(gts_path, exist_ok=True) 37 | makedirs(depth_path, exist_ok=True) 38 | 39 | t_list = [] 40 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 41 | voxel_visible_mask = anchor_prefilter_voxel(view, gaussians, pipeline, background) 42 | render_pkg = anchor_render(view, gaussians, pipeline, background, visible_mask=voxel_visible_mask) 43 | rendering = render_pkg["render"] 44 | gt = view.original_image[0:3, :, :] 45 | depth = render_pkg["depth"] 46 | depth = depth / (depth.max() + 1e-5) 47 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 48 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 49 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 50 | 51 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 52 | torch.cuda.synchronize(); 53 | t0 = time.time() 54 | voxel_visible_mask = anchor_prefilter_voxel(view, gaussians, pipeline, background) 55 | render_pkg = anchor_render(view, gaussians, pipeline, background, visible_mask=voxel_visible_mask) 56 | torch.cuda.synchronize(); 57 | t1 = time.time() 58 | 59 | t_list.append(t1 - t0) 60 | 61 | t = np.array(t_list[5:]) 62 | fps = 1.0 / t.mean() 63 | print(f'Test FPS: \033[1;35m{fps:.5f}\033[0m') 64 | 65 | 66 | def render_video(model_path, iteration, views, gaussians, pipeline, background): 67 | render_path = os.path.join(model_path, 'video', "ours_{}".format(iteration)) 68 | makedirs(render_path, exist_ok=True) 69 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 70 | view = views[0] 71 | renderings = [] 72 | for idx, pose in enumerate(tqdm(generate_ellipse_path(views, n_frames=600), desc="Rendering progress")): 
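        # For each pose on the fitted ellipse, overwrite the camera extrinsics of `view` in place,
        # prefilter the anchors visible from that pose, and rasterize one frame of the video.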
73 | view.world_view_transform = torch.tensor( 74 | getWorld2View2(pose[:3, :3].T, pose[:3, 3], view.trans, view.scale)).transpose(0, 1).cuda() 75 | view.full_proj_transform = ( 76 | view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0) 77 | view.camera_center = view.world_view_transform.inverse()[3, :3] 78 | voxel_visible_mask = anchor_prefilter_voxel(view, gaussians, pipeline, background) 79 | rendering = anchor_render(view, gaussians, pipeline, background, visible_mask=voxel_visible_mask)["render"] 80 | renderings.append(to8b(rendering.cpu().numpy())) 81 | # torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 82 | 83 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 84 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 85 | 86 | 87 | def interpolate_all(model_path, iteration, views, gaussians, pipeline, background): 88 | render_path = os.path.join(model_path, "interpolate_all_{}".format(iteration), "renders") 89 | depth_path = os.path.join(model_path, "interpolate_all_{}".format(iteration), "depth") 90 | 91 | os.makedirs(render_path, exist_ok=True) 92 | os.makedirs(depth_path, exist_ok=True) 93 | 94 | frame = 520 95 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 0) 96 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 97 | 98 | idx = torch.randint(0, len(views), (1,)).item() 99 | view = views[idx] # Choose a specific time for rendering 100 | 101 | renderings = [] 102 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 103 | matrix = np.linalg.inv(np.array(pose)) 104 | R = -np.transpose(matrix[:3, :3]) 105 | R[:, 0] = -R[:, 0] 106 | T = -matrix[:3, 3] 107 | 108 | view.reset_extrinsic(R, T) 109 | 110 | voxel_visible_mask = anchor_prefilter_voxel(view, gaussians, pipeline, background) 111 | rendering = anchor_render(view, gaussians, pipeline, background, visible_mask=voxel_visible_mask)["render"] 112 | renderings.append(to8b(rendering.cpu().numpy())) 113 | # depth = results["depth"] 114 | # depth = depth / (depth.max() + 1e-5) 115 | 116 | # torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 117 | # torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 118 | 119 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 120 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 121 | 122 | 123 | def render_sets(dataset: AnchorModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, 124 | skip_test: bool, mode: str): 125 | with torch.no_grad(): 126 | gaussians = AnchorGaussianModel(dataset.feat_dim, dataset.n_offsets, dataset.voxel_size, dataset.update_depth, 127 | dataset.update_init_factor, dataset.update_hierachy_factor) 128 | scene = AnchorScene(dataset, gaussians, load_iteration=iteration, shuffle=False) 129 | 130 | gaussians.eval() 131 | 132 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 133 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 134 | 135 | if mode == "real-360": 136 | render_video(dataset.model_path, scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, 137 | background) 138 | elif mode == "syn-360": 139 | interpolate_all(dataset.model_path, scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, 140 | background) 141 | else: 142 | if not 
skip_train: 143 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, 144 | background) 145 | if not skip_test: 146 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, 147 | background) 148 | 149 | 150 | if __name__ == "__main__": 151 | # Set up command line argument parser 152 | parser = ArgumentParser(description="Testing script parameters") 153 | model = AnchorModelParams(parser, sentinel=True) 154 | pipeline = PipelineParams(parser) 155 | parser.add_argument("--iteration", default=-1, type=int) 156 | parser.add_argument("--skip_train", action="store_true") 157 | parser.add_argument("--skip_test", action="store_true") 158 | parser.add_argument("--quiet", action="store_true") 159 | parser.add_argument("--mode", default='render', choices=['render', 'syn-360', 'real-360']) 160 | args = get_combined_args(parser) 161 | print("Rendering " + args.model_path) 162 | 163 | # Initialize system state (RNG) 164 | safe_state(args.quiet) 165 | 166 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.mode) 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | submodules/depth-diff-gaussian-rasterization 2 | submodules/simple-knn 3 | plyfile==0.8.1 4 | tqdm 5 | imageio==2.27.0 6 | opencv-python 7 | imageio-ffmpeg 8 | scipy 9 | lpips 10 | einops 11 | -------------------------------------------------------------------------------- /run_anchor.sh: -------------------------------------------------------------------------------- 1 | # For mip-360 dataset 2 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/bonsai -m outputs/mip360/bonsai-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -r 2 3 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/counter -m outputs/mip360/counter-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 --use_c2f -r 2 4 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/kitchen -m outputs/mip360/kitchen-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -r 2 5 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/room -m outputs/mip360/room-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -r 2 6 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/bicycle -m outputs/mip360/bicycle-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 --use_c2f -r 4 7 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/flowers -m outputs/mip360/flowers-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 --use_c2f -r 4 8 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/garden -m outputs/mip360/garden-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -r 4 9 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/stump -m outputs/mip360/stump-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 --use_c2f -r 4 10 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/mipnerf-360/treehill -m outputs/mip360/treehill-anchor --eval --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 --use_c2f -r 
4 11 | 12 | # For nerf_synthetic dataset 13 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/chair -m outputs/blender/chair-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 14 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/drums -m outputs/blender/drums-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 15 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/ficus -m outputs/blender/ficus-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 16 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/hotdog -m outputs/blender/hotdog-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 17 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/lego -m outputs/blender/lego-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 18 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/materials -m outputs/blender/materials-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 19 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/mic -m outputs/blender/mic-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 20 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/ship -m outputs/blender/ship-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 21 | 22 | # For nsvf_synthetic dataset 23 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Bike/ -m outputs/nsvf/Bike-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 24 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Lifestyle/ -m outputs/nsvf/Lifestyle-anchor --eval -w --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 25 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Palace/ -m outputs/nsvf/Palace-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 26 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Robot/ -m outputs/nsvf/Robot-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 27 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Steamtrain/ -m outputs/nsvf/Steamtrain-anchor --eval -w --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 28 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Spaceship/ -m outputs/nsvf/Spaceship-anchor --eval -w --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 29 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Toad/ -m outputs/nsvf/Toad-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 30 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Wineholder/ -m outputs/nsvf/Wineholder-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 31 | 32 | # For our anisotropic dataset 33 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/ashtray -m outputs/asg/ashtray-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 34 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/dishes -m outputs/asg/dishes-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 35 | 
python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/headphone -m outputs/asg/headphone-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 36 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/jupyter -m outputs/asg/jupyter-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 37 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/lock -m outputs/asg/lock-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 38 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/plane -m outputs/asg/plane-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 39 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/record -m outputs/asg/record-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 40 | python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/asg/teapot -m outputs/asg/teapot-anchor --eval --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 41 | 42 | # python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/tandt_db/tandt/train -m outputs/tandt/train --eval --voxel_size 0.01 --update_init_factor 16 --iterations 30_000 43 | # python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/tandt_db/tandt/truck -m outputs/tandt/truck --eval --voxel_size 0.01 --update_init_factor 16 --iterations 30_000 44 | 45 | # python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/tandt_db/db/drjohnson -m outputs/db/drjohnson --eval --voxel_size 0.005 --update_init_factor 16 --iterations 30_000 --use_c2f 46 | # python train_anchor.py -s /media/data_nix/yzy/Git_Project/data/tandt_db/db/playroom -m outputs/db/playroom --eval --voxel_size 0.005 --update_init_factor 16 --iterations 30_000 --use_c2f 47 | -------------------------------------------------------------------------------- /run_wo_anchor.sh: -------------------------------------------------------------------------------- 1 | # For mip-360 dataset 2 | python train.py -s data/mipnerf-360/bicycle -m outputs/mip360/bicycle --eval -r 4 --is_real --asg_degree 12 3 | python train.py -s data/mipnerf-360/bonsai -m outputs/mip360/bonsai --eval -r 2 --is_real --is_indoor --asg_degree 12 4 | python train.py -s data/mipnerf-360/counter -m outputs/mip360/counter --eval -r 2 --is_real --is_indoor --asg_degree 12 5 | python train.py -s data/mipnerf-360/flowers -m outputs/mip360/flowers --eval -r 4 --is_real --asg_degree 12 6 | python train.py -s data/mipnerf-360/garden -m outputs/mip360/garden --eval -r 4 --is_real --asg_degree 12 7 | python train.py -s data/mipnerf-360/kitchen -m outputs/mip360/kitchen --eval -r 2 --is_real --is_indoor --asg_degree 12 8 | python train.py -s data/mipnerf-360/room -m outputs/mip360/room --eval -r 2 --is_real --is_indoor --asg_degree 12 9 | python train.py -s data/mipnerf-360/stump -m outputs/mip360/stump --eval -r 4 --is_real --asg_degree 12 10 | python train.py -s data/mipnerf-360/treehill -m outputs/mip360/treehill --eval -r 4 --is_real --asg_degree 12 11 | 12 | 13 | # For nerf_synthetic dataset 14 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/chair -m outputs/blender/chair --eval 15 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/drums -m outputs/blender/drums --eval 16 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/ficus -m outputs/blender/ficus --eval 17 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/hotdog -m 
outputs/blender/hotdog --eval 18 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/lego -m outputs/blender/lego --eval 19 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/materials -m outputs/blender/materials --eval 20 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/mic -m outputs/blender/mic --eval 21 | python train.py -s /media/data_nix/yzy/Git_Project/data/nerf_synthetic/ship -m outputs/blender/ship --eval 22 | 23 | # For nsvf_synthetic dataset 24 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Bike/ -m outputs/nsvf/Bike --eval 25 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Lifestyle/ -m outputs/nsvf/Lifestyle --eval -w 26 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Palace/ -m outputs/nsvf/Palace --eval 27 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Robot/ -m outputs/nsvf/Robot --eval 28 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Steamtrain/ -m outputs/nsvf/Steamtrain --eval -w 29 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Spaceship/ -m outputs/nsvf/Spaceship --eval -w 30 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Toad/ -m outputs/nsvf/Toad --eval 31 | python train.py -s /media/data_nix/yzy/Git_Project/data/Synthetic_NSVF/Wineholder/ -m outputs/nsvf/Wineholder --eval 32 | 33 | # For our anisotropic dataset 34 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/ashtray -m outputs/asg/ashtray --eval 35 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/dishes -m outputs/asg/dishes --eval 36 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/headphone -m outputs/asg/headphone --eval 37 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/jupyter -m outputs/asg/jupyter --eval 38 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/lock -m outputs/asg/lock --eval 39 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/plane -m outputs/asg/plane --eval 40 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/record -m outputs/asg/record --eval 41 | python train.py -s /media/data_nix/yzy/Git_Project/data/asg/teapot -m outputs/asg/teapot --eval 42 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from scene.specular_model import SpecularModel 19 | from scene.anchor_gaussian_model import AnchorGaussianModel 20 | from arguments import ModelParams, AnchorModelParams 21 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 22 | 23 | 24 | class Scene: 25 | gaussians: GaussianModel 26 | 27 | def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, 28 | resolution_scales=[1.0]): 29 | """ 30 | :param path: Path to colmap scene main folder. 31 | """ 32 | self.model_path = args.model_path 33 | self.loaded_iter = None 34 | self.gaussians = gaussians 35 | 36 | if load_iteration: 37 | if load_iteration == -1: 38 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 39 | else: 40 | self.loaded_iter = load_iteration 41 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 42 | 43 | self.train_cameras = {} 44 | self.test_cameras = {} 45 | 46 | if os.path.exists(os.path.join(args.source_path, "sparse")): 47 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 48 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 49 | print("Found transforms_train.json file, assuming Blender data set!") 50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, args.add_val) 51 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")): 52 | print("Found transforms.json file, assuming Blender data set!") 53 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 54 | elif os.path.exists(os.path.join(args.source_path, "cameras_sphere.npz")): 55 | print("Found cameras_sphere.npz file, assuming DTU data set!") 56 | scene_info = sceneLoadTypeCallbacks["DTU"](args.source_path, "cameras_sphere.npz", "cameras_sphere.npz") 57 | elif os.path.exists(os.path.join(args.source_path, "bbox.txt")): 58 | print("Found bbox.txt file, assuming NSVF Blender data set!") 59 | scene_info = sceneLoadTypeCallbacks["NSVF"](args.source_path, args.white_background, args.eval) 60 | else: 61 | assert False, "Could not recognize scene type!"
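        # The branches above pick a loader from the files present in source_path: a COLMAP "sparse/" folder (e.g. Mip-NeRF 360 or Tanks&Temples captures), "transforms_train.json" or "transforms.json" (Blender-style synthetic scenes, including the anisotropic Spec-GS set), "cameras_sphere.npz" (DTU), and "bbox.txt" (NSVF).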
62 | 63 | if not self.loaded_iter: 64 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 65 | 'wb') as dest_file: 66 | dest_file.write(src_file.read()) 67 | json_cams = [] 68 | camlist = [] 69 | if scene_info.test_cameras: 70 | camlist.extend(scene_info.test_cameras) 71 | if scene_info.train_cameras: 72 | camlist.extend(scene_info.train_cameras) 73 | for id, cam in enumerate(camlist): 74 | json_cams.append(camera_to_JSON(id, cam)) 75 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 76 | json.dump(json_cams, file) 77 | 78 | if shuffle: 79 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 80 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 81 | 82 | self.cameras_extent = scene_info.nerf_normalization["radius"] 83 | 84 | for resolution_scale in resolution_scales: 85 | print("Loading Training Cameras") 86 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, 87 | args) 88 | print("Loading Test Cameras") 89 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, 90 | args) 91 | 92 | if self.loaded_iter: 93 | self.gaussians.load_ply(os.path.join(self.model_path, 94 | "point_cloud", 95 | "iteration_" + str(self.loaded_iter), 96 | "point_cloud.ply"), 97 | og_number_points=len(scene_info.point_cloud.points)) 98 | else: 99 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 100 | 101 | def save(self, iteration): 102 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 103 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 104 | 105 | def getTrainCameras(self, scale=1.0): 106 | return self.train_cameras[scale] 107 | 108 | def getTestCameras(self, scale=1.0): 109 | return self.test_cameras[scale] 110 | 111 | 112 | class AnchorScene: 113 | gaussians: AnchorGaussianModel 114 | 115 | def __init__(self, args: AnchorModelParams, gaussians: AnchorGaussianModel, load_iteration=None, shuffle=True, 116 | resolution_scales=[1.0]): 117 | """b 118 | :param path: Path to colmap scene main folder. 119 | """ 120 | self.model_path = args.model_path 121 | self.loaded_iter = None 122 | self.gaussians = gaussians 123 | 124 | if load_iteration: 125 | if load_iteration == -1: 126 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 127 | else: 128 | self.loaded_iter = load_iteration 129 | 130 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 131 | 132 | self.train_cameras = {} 133 | self.test_cameras = {} 134 | 135 | if os.path.exists(os.path.join(args.source_path, "sparse")): 136 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 137 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 138 | print("Found transforms_train.json file, assuming Blender data set!") 139 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 140 | elif os.path.exists(os.path.join(args.source_path, "bbox.txt")): 141 | print("Found bbox.txt file, assuming NSVF Blender data set!") 142 | scene_info = sceneLoadTypeCallbacks["NSVF"](args.source_path, args.white_background, args.eval) 143 | else: 144 | assert False, "Could not recognize scene type!" 
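        # Typical construction mirrors render_anchor.py above: build the anchor model first, then hand it to AnchorScene, which either restores a saved iteration (sparse Gaussians plus MLP checkpoint) or initialises anchors from the scene point cloud, e.g.
        #   gaussians = AnchorGaussianModel(dataset.feat_dim, dataset.n_offsets, dataset.voxel_size, dataset.update_depth, dataset.update_init_factor, dataset.update_hierachy_factor)
        #   scene = AnchorScene(dataset, gaussians, load_iteration=-1, shuffle=False)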
145 | 146 | if not self.loaded_iter: 147 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 148 | 'wb') as dest_file: 149 | dest_file.write(src_file.read()) 150 | json_cams = [] 151 | camlist = [] 152 | if scene_info.test_cameras: 153 | camlist.extend(scene_info.test_cameras) 154 | if scene_info.train_cameras: 155 | camlist.extend(scene_info.train_cameras) 156 | for id, cam in enumerate(camlist): 157 | json_cams.append(camera_to_JSON(id, cam)) 158 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 159 | json.dump(json_cams, file) 160 | 161 | if shuffle: 162 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 163 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 164 | 165 | self.cameras_extent = scene_info.nerf_normalization["radius"] 166 | 167 | for resolution_scale in resolution_scales: 168 | print("Loading Training Cameras") 169 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, 170 | args) 171 | print("Loading Test Cameras") 172 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, 173 | args) 174 | 175 | if self.loaded_iter: 176 | self.gaussians.load_ply_sparse_gaussian(os.path.join(self.model_path, 177 | "point_cloud", 178 | "iteration_" + str(self.loaded_iter), 179 | "point_cloud.ply")) 180 | self.gaussians.load_mlp_checkpoints(os.path.join(self.model_path, 181 | "point_cloud", 182 | "iteration_" + str(self.loaded_iter), 183 | "checkpoint.pth")) 184 | else: 185 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 186 | 187 | def save(self, iteration): 188 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 189 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 190 | self.gaussians.save_mlp_checkpoints(os.path.join(point_cloud_path, "checkpoint.pth")) 191 | 192 | def getTrainCameras(self, scale=1.0): 193 | return self.train_cameras[scale] 194 | 195 | def getTestCameras(self, scale=1.0): 196 | return self.test_cameras[scale] 197 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | 18 | class Camera(nn.Module): 19 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", depth=None): 21 | super(Camera, self).__init__() 22 | 23 | self.uid = uid 24 | self.colmap_id = colmap_id 25 | self.R = R 26 | self.T = T 27 | self.FoVx = FoVx 28 | self.FoVy = FoVy 29 | self.image_name = image_name 30 | 31 | try: 32 | self.data_device = torch.device(data_device) 33 | except Exception as e: 34 | print(e) 35 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device") 36 | self.data_device = torch.device("cuda") 37 | 38 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 39 | self.image_width = self.original_image.shape[2] 40 | self.image_height = self.original_image.shape[1] 41 | self.depth = torch.Tensor(depth).to(self.data_device) if depth is not None else None 42 | 43 | if gt_alpha_mask is not None: 44 | self.original_image *= gt_alpha_mask.to(self.data_device) 45 | else: 46 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 47 | 48 | self.zfar = 100.0 49 | self.znear = 0.01 50 | 51 | self.trans = trans 52 | self.scale = scale 53 | 54 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to( 55 | self.data_device) 56 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, 57 | fovY=self.FoVy).transpose(0, 1).to(self.data_device) 58 | self.full_proj_transform = ( 59 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 60 | self.camera_center = self.world_view_transform.inverse()[3, :3] 61 | 62 | def reset_extrinsic(self, R, T): 63 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, self.scale)).transpose(0, 1).cuda() 64 | self.full_proj_transform = ( 65 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 66 | self.camera_center = self.world_view_transform.inverse()[3, :3] 67 | 68 | def load2device(self, data_device='cuda'): 69 | self.original_image = self.original_image.to(data_device) 70 | self.world_view_transform = self.world_view_transform.to(data_device) 71 | self.projection_matrix = self.projection_matrix.to(data_device) 72 | self.full_proj_transform = self.full_proj_transform.to(data_device) 73 | self.camera_center = self.camera_center.to(data_device) 74 | self.fid = self.fid.to(data_device) 75 | 76 | 77 | class MiniCam: 78 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 79 | self.image_width = width 80 | self.image_height = height 81 | self.FoVy = fovy 82 | self.FoVx = fovx 83 | self.znear = znear 84 | self.zfar = zfar 85 | self.world_view_transform = world_view_transform 86 | self.full_proj_transform = full_proj_transform 87 | view_inv = torch.inverse(self.world_view_transform) 88 | self.camera_center = view_inv[3][:3] 89 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All 
rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) 54 | 55 | 56 | def rotmat2qvec(R): 57 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 58 | K = np.array([ 59 | [Rxx - Ryy - Rzz, 0, 0, 0], 60 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 61 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 62 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 63 | eigvals, eigvecs = np.linalg.eigh(K) 64 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 65 | if qvec[0] < 0: 66 | qvec *= -1 67 | return qvec 68 | 69 | 70 | class Image(BaseImage): 71 | def qvec2rotmat(self): 72 | return qvec2rotmat(self.qvec) 73 | 74 | 75 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 76 | """Read and unpack the next bytes from a binary file. 77 | :param fid: 78 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 79 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 80 | :param endian_character: Any of {@, =, <, >, !} 81 | :return: Tuple of read and unpacked values. 
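    e.g. read_next_bytes(fid, 8, "Q")[0] unpacks a single unsigned 64-bit count, as used below for num_points and num_reg_images.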
82 | """ 83 | data = fid.read(num_bytes) 84 | return struct.unpack(endian_character + format_char_sequence, data) 85 | 86 | 87 | def read_points3D_text(path): 88 | """ 89 | see: src/base/reconstruction.cc 90 | void Reconstruction::ReadPoints3DText(const std::string& path) 91 | void Reconstruction::WritePoints3DText(const std::string& path) 92 | """ 93 | xyzs = None 94 | rgbs = None 95 | errors = None 96 | with open(path, "r") as fid: 97 | while True: 98 | line = fid.readline() 99 | if not line: 100 | break 101 | line = line.strip() 102 | if len(line) > 0 and line[0] != "#": 103 | elems = line.split() 104 | xyz = np.array(tuple(map(float, elems[1:4]))) 105 | rgb = np.array(tuple(map(int, elems[4:7]))) 106 | error = np.array(float(elems[7])) 107 | if xyzs is None: 108 | xyzs = xyz[None, ...] 109 | rgbs = rgb[None, ...] 110 | errors = error[None, ...] 111 | else: 112 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 113 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 114 | errors = np.append(errors, error[None, ...], axis=0) 115 | return xyzs, rgbs, errors 116 | 117 | 118 | def read_points3D_binary(path_to_model_file): 119 | """ 120 | see: src/base/reconstruction.cc 121 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 122 | void Reconstruction::WritePoints3DBinary(const std::string& path) 123 | """ 124 | 125 | with open(path_to_model_file, "rb") as fid: 126 | num_points = read_next_bytes(fid, 8, "Q")[0] 127 | 128 | xyzs = np.empty((num_points, 3)) 129 | rgbs = np.empty((num_points, 3)) 130 | errors = np.empty((num_points, 1)) 131 | 132 | for p_id in range(num_points): 133 | binary_point_line_properties = read_next_bytes( 134 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 135 | xyz = np.array(binary_point_line_properties[1:4]) 136 | rgb = np.array(binary_point_line_properties[4:7]) 137 | error = np.array(binary_point_line_properties[7]) 138 | track_length = read_next_bytes( 139 | fid, num_bytes=8, format_char_sequence="Q")[0] 140 | track_elems = read_next_bytes( 141 | fid, num_bytes=8 * track_length, 142 | format_char_sequence="ii" * track_length) 143 | xyzs[p_id] = xyz 144 | rgbs[p_id] = rgb 145 | errors[p_id] = error 146 | return xyzs, rgbs, errors 147 | 148 | 149 | def read_intrinsics_text(path): 150 | """ 151 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 152 | """ 153 | cameras = {} 154 | with open(path, "r") as fid: 155 | while True: 156 | line = fid.readline() 157 | if not line: 158 | break 159 | line = line.strip() 160 | if len(line) > 0 and line[0] != "#": 161 | elems = line.split() 162 | camera_id = int(elems[0]) 163 | model = elems[1] 164 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 165 | width = int(elems[2]) 166 | height = int(elems[3]) 167 | params = np.array(tuple(map(float, elems[4:]))) 168 | cameras[camera_id] = Camera(id=camera_id, model=model, 169 | width=width, height=height, 170 | params=params) 171 | return cameras 172 | 173 | 174 | def read_extrinsics_binary(path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void Reconstruction::ReadImagesBinary(const std::string& path) 178 | void Reconstruction::WriteImagesBinary(const std::string& path) 179 | """ 180 | images = {} 181 | with open(path_to_model_file, "rb") as fid: 182 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 183 | for _ in range(num_reg_images): 184 | binary_image_properties = read_next_bytes( 185 | fid, num_bytes=64, format_char_sequence="idddddddi") 186 | 
image_id = binary_image_properties[0] 187 | qvec = np.array(binary_image_properties[1:5]) 188 | tvec = np.array(binary_image_properties[5:8]) 189 | camera_id = binary_image_properties[8] 190 | image_name = "" 191 | current_char = read_next_bytes(fid, 1, "c")[0] 192 | while current_char != b"\x00": # look for the ASCII 0 entry 193 | image_name += current_char.decode("utf-8") 194 | current_char = read_next_bytes(fid, 1, "c")[0] 195 | num_points2D = read_next_bytes(fid, num_bytes=8, 196 | format_char_sequence="Q")[0] 197 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, 198 | format_char_sequence="ddq" * num_points2D) 199 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 200 | tuple(map(float, x_y_id_s[1::3]))]) 201 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 202 | images[image_id] = Image( 203 | id=image_id, qvec=qvec, tvec=tvec, 204 | camera_id=camera_id, name=image_name, 205 | xys=xys, point3D_ids=point3D_ids) 206 | return images 207 | 208 | 209 | def read_intrinsics_binary(path_to_model_file): 210 | """ 211 | see: src/base/reconstruction.cc 212 | void Reconstruction::WriteCamerasBinary(const std::string& path) 213 | void Reconstruction::ReadCamerasBinary(const std::string& path) 214 | """ 215 | cameras = {} 216 | with open(path_to_model_file, "rb") as fid: 217 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 218 | for _ in range(num_cameras): 219 | camera_properties = read_next_bytes( 220 | fid, num_bytes=24, format_char_sequence="iiQQ") 221 | camera_id = camera_properties[0] 222 | model_id = camera_properties[1] 223 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 224 | width = camera_properties[2] 225 | height = camera_properties[3] 226 | num_params = CAMERA_MODEL_IDS[model_id].num_params 227 | params = read_next_bytes(fid, num_bytes=8 * num_params, 228 | format_char_sequence="d" * num_params) 229 | cameras[camera_id] = Camera(id=camera_id, 230 | model=model_name, 231 | width=width, 232 | height=height, 233 | params=np.array(params)) 234 | assert len(cameras) == num_cameras 235 | return cameras 236 | 237 | 238 | def read_extrinsics_text(path): 239 | """ 240 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 241 | """ 242 | images = {} 243 | with open(path, "r") as fid: 244 | while True: 245 | line = fid.readline() 246 | if not line: 247 | break 248 | line = line.strip() 249 | if len(line) > 0 and line[0] != "#": 250 | elems = line.split() 251 | image_id = int(elems[0]) 252 | qvec = np.array(tuple(map(float, elems[1:5]))) 253 | tvec = np.array(tuple(map(float, elems[5:8]))) 254 | camera_id = int(elems[8]) 255 | image_name = elems[9] 256 | elems = fid.readline().split() 257 | xys = np.column_stack([tuple(map(float, elems[0::3])), 258 | tuple(map(float, elems[1::3]))]) 259 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 260 | images[image_id] = Image( 261 | id=image_id, qvec=qvec, tvec=tvec, 262 | camera_id=camera_id, name=image_name, 263 | xys=xys, point3D_ids=point3D_ids) 264 | return images 265 | 266 | 267 | def read_colmap_bin_array(path): 268 | """ 269 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 270 | 271 | :param path: path to the colmap binary file. 
272 | :return: nd array with the floating point values in the value 273 | """ 274 | with open(path, "rb") as fid: 275 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 276 | usecols=(0, 1, 2), dtype=int) 277 | fid.seek(0) 278 | num_delimiter = 0 279 | byte = fid.read(1) 280 | while True: 281 | if byte == b"&": 282 | num_delimiter += 1 283 | if num_delimiter >= 3: 284 | break 285 | byte = fid.read(1) 286 | array = np.fromfile(fid, np.float32) 287 | array = array.reshape((width, height, channels), order="F") 288 | return np.transpose(array, (1, 0, 2)).squeeze() 289 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple, Optional 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | import imageio 22 | from glob import glob 23 | import cv2 as cv 24 | from pathlib import Path 25 | from plyfile import PlyData, PlyElement 26 | from utils.sh_utils import SH2RGB 27 | from scene.gaussian_model import BasicPointCloud 28 | from utils.camera_utils import camera_nerfies_from_JSON 29 | 30 | 31 | class CameraInfo(NamedTuple): 32 | uid: int 33 | R: np.array 34 | T: np.array 35 | FovY: np.array 36 | FovX: np.array 37 | image: np.array 38 | image_path: str 39 | image_name: str 40 | width: int 41 | height: int 42 | depth: Optional[np.array] = None 43 | 44 | 45 | class SceneInfo(NamedTuple): 46 | point_cloud: BasicPointCloud 47 | train_cameras: list 48 | test_cameras: list 49 | nerf_normalization: dict 50 | ply_path: str 51 | 52 | 53 | def load_K_Rt_from_P(filename, P=None): 54 | if P is None: 55 | lines = open(filename).read().splitlines() 56 | if len(lines) == 4: 57 | lines = lines[1:] 58 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 59 | P = np.asarray(lines).astype(np.float32).squeeze() 60 | 61 | out = cv.decomposeProjectionMatrix(P) 62 | K = out[0] 63 | R = out[1] 64 | t = out[2] 65 | 66 | K = K / K[2, 2] 67 | 68 | pose = np.eye(4, dtype=np.float32) 69 | pose[:3, :3] = R.transpose() 70 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 71 | 72 | return K, pose 73 | 74 | 75 | def getNerfppNorm(cam_info): 76 | def get_center_and_diag(cam_centers): 77 | cam_centers = np.hstack(cam_centers) 78 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 79 | center = avg_cam_center 80 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 81 | diagonal = np.max(dist) 82 | return center.flatten(), diagonal 83 | 84 | cam_centers = [] 85 | 86 | for cam in cam_info: 87 | W2C = getWorld2View2(cam.R, cam.T) 88 | C2W = np.linalg.inv(W2C) 89 | cam_centers.append(C2W[:3, 3:4]) 90 | 91 | center, diagonal = get_center_and_diag(cam_centers) 92 | radius = diagonal * 1.1 93 | 94 | translate = -center 95 | 96 | return {"translate": translate, "radius": radius} 97 | 98 
| 99 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 100 | cam_infos = [] 101 | num_frames = len(cam_extrinsics) 102 | for idx, key in enumerate(cam_extrinsics): 103 | sys.stdout.write('\r') 104 | # the exact output you're looking for: 105 | sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics))) 106 | sys.stdout.flush() 107 | 108 | extr = cam_extrinsics[key] 109 | intr = cam_intrinsics[extr.camera_id] 110 | height = intr.height 111 | width = intr.width 112 | 113 | uid = intr.id 114 | R = np.transpose(qvec2rotmat(extr.qvec)) 115 | T = np.array(extr.tvec) 116 | 117 | if intr.model == "SIMPLE_PINHOLE": 118 | focal_length_x = intr.params[0] 119 | FovY = focal2fov(focal_length_x, height) 120 | FovX = focal2fov(focal_length_x, width) 121 | elif intr.model == "PINHOLE" or intr.model == "OPENCV" or intr.model == "SIMPLE_RADIAL": 122 | focal_length_x = intr.params[0] 123 | focal_length_y = intr.params[1] 124 | FovY = focal2fov(focal_length_y, height) 125 | FovX = focal2fov(focal_length_x, width) 126 | else: 127 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 128 | 129 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 130 | image_name = os.path.basename(image_path).split(".")[0] 131 | image = Image.open(image_path) 132 | 133 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 134 | image_path=image_path, image_name=image_name, width=width, height=height) 135 | cam_infos.append(cam_info) 136 | sys.stdout.write('\n') 137 | return cam_infos 138 | 139 | 140 | def fetchPly(path): 141 | plydata = PlyData.read(path) 142 | vertices = plydata['vertex'] 143 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 144 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 145 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 146 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 147 | 148 | 149 | def storePly(path, xyz, rgb): 150 | # Define the dtype for the structured array 151 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 152 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 153 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 154 | 155 | normals = np.zeros_like(xyz) 156 | 157 | elements = np.empty(xyz.shape[0], dtype=dtype) 158 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 159 | elements[:] = list(map(tuple, attributes)) 160 | 161 | # Create the PlyData object and write to file 162 | vertex_element = PlyElement.describe(elements, 'vertex') 163 | ply_data = PlyData([vertex_element]) 164 | ply_data.write(path) 165 | 166 | 167 | def readColmapSceneInfo(path, images, eval, llffhold=8): 168 | try: 169 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 170 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 171 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 172 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 173 | except: 174 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 175 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 176 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 177 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 178 | 179 | reading_dir = "images" if images == None else images 180 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, 
cam_intrinsics=cam_intrinsics, 181 | images_folder=os.path.join(path, reading_dir)) 182 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) 183 | 184 | if eval: 185 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 186 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 187 | else: 188 | train_cam_infos = cam_infos 189 | test_cam_infos = [] 190 | 191 | nerf_normalization = getNerfppNorm(train_cam_infos) 192 | 193 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 194 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 195 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 196 | if not os.path.exists(ply_path): 197 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 198 | try: 199 | xyz, rgb, _ = read_points3D_binary(bin_path) 200 | except: 201 | xyz, rgb, _ = read_points3D_text(txt_path) 202 | storePly(ply_path, xyz, rgb) 203 | try: 204 | pcd = fetchPly(ply_path) 205 | except: 206 | pcd = None 207 | 208 | scene_info = SceneInfo(point_cloud=pcd, 209 | train_cameras=train_cam_infos, 210 | test_cameras=test_cam_infos, 211 | nerf_normalization=nerf_normalization, 212 | ply_path=ply_path) 213 | return scene_info 214 | 215 | 216 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 217 | cam_infos = [] 218 | 219 | with open(os.path.join(path, transformsfile)) as json_file: 220 | contents = json.load(json_file) 221 | fovx = contents["camera_angle_x"] 222 | 223 | frames = contents["frames"] 224 | for idx, frame in enumerate(frames): 225 | cam_name = os.path.join(path, frame["file_path"] + extension) 226 | 227 | matrix = np.linalg.inv(np.array(frame["transform_matrix"])) 228 | R = -np.transpose(matrix[:3, :3]) 229 | R[:, 0] = -R[:, 0] 230 | T = -matrix[:3, 3] 231 | 232 | image_path = os.path.join(path, cam_name) 233 | image_name = Path(cam_name).stem 234 | image = Image.open(image_path) 235 | # depth = imageio.imread(depth_name) 236 | 237 | im_data = np.array(image.convert("RGBA")) 238 | 239 | bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) 240 | 241 | norm_data = im_data / 255.0 242 | 243 | arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 244 | image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") 245 | 246 | # if "train" in transformsfile: 247 | # normal_path = os.path.join(path, frame["file_path"] + "_normal.png") 248 | # normal_raw = (imageio.imread(normal_path) / 255.0) 249 | # normal = normal_raw[..., :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[..., 3:4]) 250 | # else: 251 | # normal = None 252 | 253 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 254 | FovY = fovx 255 | FovX = fovy 256 | 257 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 258 | image_path=image_path, image_name=image_name, width=image.size[0], 259 | height=image.size[1], depth=None)) 260 | 261 | return cam_infos 262 | 263 | 264 | def readNerfSyntheticInfo(path, white_background, eval, read_val=False, extension=".png"): 265 | try: 266 | print("Reading Training Transforms") 267 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 268 | print("Reading Test Transforms") 269 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 270 | 271 | if not eval: 272 | train_cam_infos.extend(test_cam_infos) 273 | test_cam_infos = [] 
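            # Without --eval the held-out test cameras are folded back into the training set; when read_val is set, the transforms_val.json views are likewise appended to training.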
274 | 275 | if read_val: 276 | print("Reading Val Transforms") 277 | val_cam_infos = readCamerasFromTransforms(path, "transforms_val.json", white_background, extension) 278 | train_cam_infos.extend(val_cam_infos) 279 | except: 280 | print("Reading All Transforms") 281 | cam_infos = readCamerasFromTransforms(path, "transforms.json", white_background, extension) 282 | 283 | if eval: 284 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % 3 != 0] 285 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % 3 == 0] 286 | 287 | nerf_normalization = getNerfppNorm(train_cam_infos) 288 | 289 | ply_path = os.path.join(path, "points3d.ply") 290 | if not os.path.exists(ply_path): 291 | # Since this data set has no colmap data, we start with random points 292 | num_pts = 100_000 293 | print(f"Generating random point cloud ({num_pts})...") 294 | 295 | # We create random points inside the bounds of the synthetic Blender scenes 296 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 297 | shs = np.random.random((num_pts, 3)) / 255.0 298 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 299 | 300 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 301 | try: 302 | pcd = fetchPly(ply_path) 303 | except: 304 | pcd = None 305 | 306 | scene_info = SceneInfo(point_cloud=pcd, 307 | train_cameras=train_cam_infos, 308 | test_cameras=test_cam_infos, 309 | nerf_normalization=nerf_normalization, 310 | ply_path=ply_path) 311 | return scene_info 312 | 313 | 314 | def readDTUCameras(path, render_camera, object_camera): 315 | camera_dict = np.load(os.path.join(path, render_camera)) 316 | images_lis = sorted(glob(os.path.join(path, 'image/*.png'))) 317 | masks_lis = sorted(glob(os.path.join(path, 'mask/*.png'))) 318 | n_images = len(images_lis) 319 | cam_infos = [] 320 | for idx in range(0, n_images): 321 | image_path = images_lis[idx] 322 | image = np.array(Image.open(image_path)) 323 | mask = np.array(imageio.imread(masks_lis[idx])) / 255.0 324 | image = Image.fromarray((image * mask).astype(np.uint8)) 325 | world_mat = camera_dict['world_mat_%d' % idx].astype(np.float32) 326 | image_name = Path(image_path).stem 327 | scale_mat = camera_dict['scale_mat_%d' % idx].astype(np.float32) 328 | P = world_mat @ scale_mat 329 | P = P[:3, :4] 330 | 331 | K, pose = load_K_Rt_from_P(None, P) 332 | a = pose[0:1, :] 333 | b = pose[1:2, :] 334 | c = pose[2:3, :] 335 | 336 | pose = np.concatenate([a, -c, -b, pose[3:, :]], 0) 337 | 338 | S = np.eye(3) 339 | S[1, 1] = -1 340 | S[2, 2] = -1 341 | pose[1, 3] = -pose[1, 3] 342 | pose[2, 3] = -pose[2, 3] 343 | pose[:3, :3] = S @ pose[:3, :3] @ S 344 | 345 | a = pose[0:1, :] 346 | b = pose[1:2, :] 347 | c = pose[2:3, :] 348 | 349 | pose = np.concatenate([a, c, b, pose[3:, :]], 0) 350 | 351 | pose[:, 3] *= 0.5 352 | 353 | matrix = np.linalg.inv(pose) 354 | R = -np.transpose(matrix[:3, :3]) 355 | R[:, 0] = -R[:, 0] 356 | T = -matrix[:3, 3] 357 | 358 | FovY = focal2fov(K[0, 0], image.size[1]) 359 | FovX = focal2fov(K[0, 0], image.size[0]) 360 | cam_info = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 361 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]) 362 | cam_infos.append(cam_info) 363 | sys.stdout.write('\n') 364 | return cam_infos 365 | 366 | 367 | def readNeuSDTUInfo(path, render_camera, object_camera): 368 | print("Reading DTU Info") 369 | train_cam_infos = readDTUCameras(path, render_camera, object_camera) 370 | 371 | nerf_normalization = getNerfppNorm(train_cam_infos) 372 
| 373 | ply_path = os.path.join(path, "points3d.ply") 374 | if not os.path.exists(ply_path): 375 | # Since this data set has no colmap data, we start with random points 376 | num_pts = 100_000 377 | print(f"Generating random point cloud ({num_pts})...") 378 | 379 | # We create random points inside the bounds of the synthetic Blender scenes 380 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 381 | shs = np.random.random((num_pts, 3)) / 255.0 382 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 383 | 384 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 385 | try: 386 | pcd = fetchPly(ply_path) 387 | except: 388 | pcd = None 389 | 390 | scene_info = SceneInfo(point_cloud=pcd, 391 | train_cameras=train_cam_infos, 392 | test_cameras=[], 393 | nerf_normalization=nerf_normalization, 394 | ply_path=ply_path) 395 | return scene_info 396 | 397 | 398 | def readCamerasFromNSVFPoses(path, idx, white_background, extension=".png"): 399 | cam_infos = [] 400 | all_poses = sorted(os.listdir(os.path.join(path, "pose"))) 401 | all_rgbs = sorted(os.listdir(os.path.join(path, "rgb"))) 402 | 403 | with open(os.path.join(path, "intrinsics.txt")) as f: 404 | focal = float(f.readline().split()[0]) 405 | for i in idx: 406 | cam_name = os.path.join(path, "pose", all_poses[i]) 407 | c2w = np.loadtxt(cam_name) 408 | w2c = np.linalg.inv(c2w) 409 | 410 | R = np.transpose(w2c[:3, :3]) 411 | T = w2c[:3, 3] 412 | 413 | image_path = os.path.join(path, "rgb", all_rgbs[i]) 414 | image_name = Path(cam_name).stem 415 | image = Image.open(image_path) 416 | 417 | im_data = np.array(image.convert("RGBA")) 418 | 419 | bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) 420 | 421 | norm_data = im_data / 255.0 422 | arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 423 | image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") 424 | 425 | # given focal in pixel unit 426 | FovY = focal2fov(focal, image.size[1]) 427 | FovX = focal2fov(focal, image.size[0]) 428 | 429 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 430 | image_path=image_path, image_name=image_name, width=image.size[0], 431 | height=image.size[1])) 432 | 433 | return cam_infos 434 | 435 | 436 | def readNSVFSyntheticInfo(path, white_background, eval, extension=".png"): 437 | all_rgbs = sorted(os.listdir(os.path.join(path, "rgb"))) 438 | 439 | train_idx = [idx for idx, file_name in enumerate(all_rgbs) if file_name.startswith("0_")] 440 | test_idx = [idx for idx, file_name in enumerate(all_rgbs) if file_name.startswith("2_")] 441 | 442 | print("Reading Training Transforms") 443 | train_cam_infos = readCamerasFromNSVFPoses(path, train_idx, white_background, extension) 444 | print("Reading Test Transforms") 445 | test_cam_infos = readCamerasFromNSVFPoses(path, test_idx, white_background, extension) 446 | 447 | if not eval: 448 | train_cam_infos.extend(test_cam_infos) 449 | test_cam_infos = [] 450 | 451 | nerf_normalization = getNerfppNorm(train_cam_infos) 452 | 453 | ply_path = os.path.join(path, "points3d.ply") 454 | if not os.path.exists(ply_path): 455 | # Since this data set has no colmap data, we start with random points 456 | num_pts = 10_000 457 | print(f"Generating random point cloud ({num_pts})...") 458 | 459 | # We create random points inside the bounds of the NSVF synthetic Blender scenes 460 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 461 | shs = np.random.random((num_pts, 3)) / 255.0 462 | pcd = BasicPointCloud(points=xyz, 
colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 463 | 464 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 465 | try: 466 | pcd = fetchPly(ply_path) 467 | except: 468 | pcd = None 469 | 470 | scene_info = SceneInfo(point_cloud=pcd, 471 | train_cameras=train_cam_infos, 472 | test_cameras=test_cam_infos, 473 | nerf_normalization=nerf_normalization, 474 | ply_path=ply_path) 475 | return scene_info 476 | 477 | 478 | sceneLoadTypeCallbacks = { 479 | "Colmap": readColmapSceneInfo, 480 | "Blender": readNerfSyntheticInfo, 481 | "DTU": readNeuSDTUInfo, 482 | "NSVF": readNSVFSyntheticInfo, 483 | } 484 | -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation, get_linear_noise_func 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation, flip_align_view, get_minimum_axis 23 | 24 | 25 | class GaussianModel: 26 | def __init__(self, sh_degree: int, asg_degree: int): 27 | 28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 30 | actual_covariance = L @ L.transpose(1, 2) 31 | symm = strip_symmetric(actual_covariance) 32 | return symm 33 | 34 | self.active_sh_degree = 0 35 | self.max_sh_degree = sh_degree 36 | self.max_asg_degree = asg_degree 37 | 38 | self._xyz = torch.empty(0) 39 | self._features_dc = torch.empty(0) 40 | self._features_rest = torch.empty(0) 41 | self._scaling = torch.empty(0) 42 | self._rotation = torch.empty(0) 43 | self._opacity = torch.empty(0) 44 | self.max_radii2D = torch.empty(0) 45 | self.xyz_gradient_accum = torch.empty(0) 46 | self._features_asg = torch.empty(0) 47 | 48 | self.optimizer = None 49 | 50 | self.scaling_activation = torch.exp 51 | self.scaling_inverse_activation = torch.log 52 | 53 | self.covariance_activation = build_covariance_from_scaling_rotation 54 | 55 | self.opacity_activation = torch.sigmoid 56 | self.inverse_opacity_activation = inverse_sigmoid 57 | 58 | self.rotation_activation = torch.nn.functional.normalize 59 | 60 | @property 61 | def get_asg_features(self): 62 | return self._features_asg 63 | 64 | @property 65 | def get_scaling(self): 66 | return self.scaling_activation(self._scaling) 67 | 68 | @property 69 | def get_rotation(self): 70 | return self.rotation_activation(self._rotation) 71 | 72 | @property 73 | def get_xyz(self): 74 | return self._xyz 75 | 76 | @property 77 | def get_features(self): 78 | features_dc = self._features_dc 79 | features_rest = self._features_rest 80 | return torch.cat((features_dc, features_rest), dim=1) 81 | 82 | @property 83 | def get_opacity(self): 84 | return self.opacity_activation(self._opacity) 85 | 86 | def get_covariance(self, scaling_modifier=1): 87 | 
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 88 | 89 | def get_normal_axis(self, dir_pp_normalized=None, return_delta=False): 90 | normal_axis = self.get_minimum_axis 91 | normal_axis, positive = flip_align_view(normal_axis, dir_pp_normalized) 92 | normal = normal_axis / normal_axis.norm(dim=1, keepdim=True) # (N, 3) 93 | return normal 94 | 95 | @property 96 | def get_minimum_axis(self): 97 | return get_minimum_axis(self.get_scaling, self.get_rotation) 98 | 99 | def oneupSHdegree(self): 100 | if self.active_sh_degree < self.max_sh_degree: 101 | self.active_sh_degree += 1 102 | 103 | def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float): 104 | self.spatial_lr_scale = 5 105 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 106 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 107 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 108 | features[:, :3, 0] = fused_color 109 | features[:, 3:, 1:] = 0.0 110 | asg_features = torch.zeros((fused_color.shape[0], self.max_asg_degree)).float().cuda() 111 | 112 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 113 | 114 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 115 | scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3) 116 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 117 | rots[:, 0] = 1 118 | 119 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 120 | 121 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 122 | self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) 123 | self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) 124 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 125 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 126 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 127 | self._features_asg = nn.Parameter(asg_features.requires_grad_(True)) 128 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 129 | 130 | def training_setup(self, training_args): 131 | self.percent_dense = training_args.percent_dense 132 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 133 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 134 | 135 | self.spatial_lr_scale = 5 136 | 137 | l = [ 138 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 139 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 140 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 141 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 142 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 143 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}, 144 | {'params': [self._features_asg], 'lr': training_args.feature_lr, "name": "f_asg"}, 145 | ] 146 | 147 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 148 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 149 | lr_final=training_args.position_lr_final * self.spatial_lr_scale, 150 | 
lr_delay_mult=training_args.position_lr_delay_mult, 151 | max_steps=training_args.position_lr_max_steps) 152 | 153 | def update_learning_rate(self, iteration): 154 | ''' Learning rate scheduling per step ''' 155 | for param_group in self.optimizer.param_groups: 156 | if param_group["name"] == "xyz": 157 | lr = self.xyz_scheduler_args(iteration) 158 | param_group['lr'] = lr 159 | 160 | def construct_list_of_attributes(self): 161 | l = ['x', 'y', 'z'] 162 | # All channels except the 3 DC 163 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): 164 | l.append('f_dc_{}'.format(i)) 165 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): 166 | l.append('f_rest_{}'.format(i)) 167 | l.append('opacity') 168 | for i in range(self._scaling.shape[1]): 169 | l.append('scale_{}'.format(i)) 170 | for i in range(self._rotation.shape[1]): 171 | l.append('rot_{}'.format(i)) 172 | for i in range(self._features_asg.shape[1]): 173 | l.append('f_asg_{}'.format(i)) 174 | return l 175 | 176 | def save_ply(self, path): 177 | mkdir_p(os.path.dirname(path)) 178 | 179 | xyz = self._xyz.detach().cpu().numpy() 180 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 181 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 182 | opacities = self._opacity.detach().cpu().numpy() 183 | scale = self._scaling.detach().cpu().numpy() 184 | rotation = self._rotation.detach().cpu().numpy() 185 | f_asg = self._features_asg.detach().cpu().numpy() 186 | 187 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 188 | 189 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 190 | attributes = np.concatenate( 191 | (xyz, f_dc, f_rest, opacities, scale, rotation, f_asg), axis=1) 192 | elements[:] = list(map(tuple, attributes)) 193 | el = PlyElement.describe(elements, 'vertex') 194 | PlyData([el]).write(path) 195 | 196 | def reset_opacity(self): 197 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)) 198 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 199 | self._opacity = optimizable_tensors["opacity"] 200 | 201 | def load_ply(self, path, og_number_points=-1): 202 | self.og_number_points = og_number_points 203 | plydata = PlyData.read(path) 204 | 205 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 206 | np.asarray(plydata.elements[0]["y"]), 207 | np.asarray(plydata.elements[0]["z"])), axis=1) 208 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 209 | 210 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 211 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 212 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 213 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 214 | 215 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 216 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 217 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 218 | for idx, attr_name in enumerate(extra_f_names): 219 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 220 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 221 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 222 | 223 | scale_names = [p.name for p in plydata.elements[0].properties if 
p.name.startswith("scale_")] 224 | scales = np.zeros((xyz.shape[0], len(scale_names))) 225 | for idx, attr_name in enumerate(scale_names): 226 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 227 | 228 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 229 | rots = np.zeros((xyz.shape[0], len(rot_names))) 230 | for idx, attr_name in enumerate(rot_names): 231 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 232 | 233 | asg_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_asg_")] 234 | f_asgs = np.zeros((xyz.shape[0], len(asg_names))) 235 | for idx, attr_name in enumerate(asg_names): 236 | f_asgs[:, idx] = np.asarray(plydata.elements[0][attr_name]) 237 | 238 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 239 | self._features_dc = nn.Parameter( 240 | torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( 241 | True)) 242 | self._features_rest = nn.Parameter( 243 | torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( 244 | True)) 245 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 246 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 247 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 248 | self._features_asg = nn.Parameter(torch.tensor(f_asgs, dtype=torch.float, device="cuda").requires_grad_(True)) 249 | self.active_sh_degree = self.max_sh_degree 250 | 251 | def replace_tensor_to_optimizer(self, tensor, name): 252 | optimizable_tensors = {} 253 | for group in self.optimizer.param_groups: 254 | if group["name"] == name: 255 | stored_state = self.optimizer.state.get(group['params'][0], None) 256 | stored_state["exp_avg"] = torch.zeros_like(tensor) 257 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 258 | 259 | del self.optimizer.state[group['params'][0]] 260 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 261 | self.optimizer.state[group['params'][0]] = stored_state 262 | 263 | optimizable_tensors[group["name"]] = group["params"][0] 264 | return optimizable_tensors 265 | 266 | def _prune_optimizer(self, mask): 267 | optimizable_tensors = {} 268 | for group in self.optimizer.param_groups: 269 | stored_state = self.optimizer.state.get(group['params'][0], None) 270 | if stored_state is not None: 271 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 272 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 273 | 274 | del self.optimizer.state[group['params'][0]] 275 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 276 | self.optimizer.state[group['params'][0]] = stored_state 277 | 278 | optimizable_tensors[group["name"]] = group["params"][0] 279 | else: 280 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 281 | optimizable_tensors[group["name"]] = group["params"][0] 282 | return optimizable_tensors 283 | 284 | def prune_points(self, mask): 285 | valid_points_mask = ~mask 286 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 287 | 288 | self._xyz = optimizable_tensors["xyz"] 289 | self._features_dc = optimizable_tensors["f_dc"] 290 | self._features_rest = optimizable_tensors["f_rest"] 291 | self._opacity = optimizable_tensors["opacity"] 292 | 
self._scaling = optimizable_tensors["scaling"] 293 | self._rotation = optimizable_tensors["rotation"] 294 | self._features_asg = optimizable_tensors["f_asg"] 295 | 296 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 297 | 298 | self.denom = self.denom[valid_points_mask] 299 | self.max_radii2D = self.max_radii2D[valid_points_mask] 300 | 301 | def cat_tensors_to_optimizer(self, tensors_dict): 302 | optimizable_tensors = {} 303 | for group in self.optimizer.param_groups: 304 | assert len(group["params"]) == 1 305 | extension_tensor = tensors_dict[group["name"]] 306 | stored_state = self.optimizer.state.get(group['params'][0], None) 307 | if stored_state is not None: 308 | 309 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), 310 | dim=0) 311 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), 312 | dim=0) 313 | 314 | del self.optimizer.state[group['params'][0]] 315 | group["params"][0] = nn.Parameter( 316 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 317 | self.optimizer.state[group['params'][0]] = stored_state 318 | 319 | optimizable_tensors[group["name"]] = group["params"][0] 320 | else: 321 | group["params"][0] = nn.Parameter( 322 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 323 | optimizable_tensors[group["name"]] = group["params"][0] 324 | 325 | return optimizable_tensors 326 | 327 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 328 | new_rotation, new_feature_asg): 329 | d = {"xyz": new_xyz, 330 | "f_dc": new_features_dc, 331 | "f_rest": new_features_rest, 332 | "opacity": new_opacities, 333 | "scaling": new_scaling, 334 | "rotation": new_rotation, 335 | "f_asg": new_feature_asg} 336 | 337 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 338 | self._xyz = optimizable_tensors["xyz"] 339 | self._features_dc = optimizable_tensors["f_dc"] 340 | self._features_rest = optimizable_tensors["f_rest"] 341 | self._opacity = optimizable_tensors["opacity"] 342 | self._scaling = optimizable_tensors["scaling"] 343 | self._rotation = optimizable_tensors["rotation"] 344 | self._features_asg = optimizable_tensors["f_asg"] 345 | 346 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 347 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 348 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 349 | 350 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 351 | n_init_points = self.get_xyz.shape[0] 352 | # Extract points that satisfy the gradient condition 353 | padded_grad = torch.zeros((n_init_points), device="cuda") 354 | padded_grad[:grads.shape[0]] = grads.squeeze() 355 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 356 | selected_pts_mask = torch.logical_and(selected_pts_mask, 357 | torch.max(self.get_scaling, 358 | dim=1).values > self.percent_dense * scene_extent) 359 | 360 | stds = self.get_scaling[selected_pts_mask].repeat(N, 1) 361 | means = torch.zeros((stds.size(0), 3), device="cuda") 362 | samples = torch.normal(mean=means, std=stds) 363 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1) 364 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 365 | new_scaling = 
self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)) 366 | new_rotation = self._rotation[selected_pts_mask].repeat(N, 1) 367 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1) 368 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1) 369 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) 370 | new_feature_asg = self._features_asg[selected_pts_mask].repeat(N, 1) 371 | 372 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, 373 | new_feature_asg) 374 | 375 | prune_filter = torch.cat( 376 | (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 377 | self.prune_points(prune_filter) 378 | 379 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 380 | # Extract points that satisfy the gradient condition 381 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 382 | selected_pts_mask = torch.logical_and(selected_pts_mask, 383 | torch.max(self.get_scaling, 384 | dim=1).values <= self.percent_dense * scene_extent) 385 | 386 | new_xyz = self._xyz[selected_pts_mask] 387 | new_features_dc = self._features_dc[selected_pts_mask] 388 | new_features_rest = self._features_rest[selected_pts_mask] 389 | new_opacities = self._opacity[selected_pts_mask] 390 | new_scaling = self._scaling[selected_pts_mask] 391 | new_rotation = self._rotation[selected_pts_mask] 392 | new_feature_asg = self._features_asg[selected_pts_mask] 393 | 394 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 395 | new_rotation, new_feature_asg) 396 | 397 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 398 | grads = self.xyz_gradient_accum / self.denom 399 | grads[grads.isnan()] = 0.0 400 | 401 | self.densify_and_clone(grads, max_grad, extent) 402 | self.densify_and_split(grads, max_grad, extent) 403 | 404 | prune_mask = (self.get_opacity < min_opacity).squeeze() 405 | if max_screen_size: 406 | big_points_vs = self.max_radii2D > max_screen_size 407 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 408 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 409 | self.prune_points(prune_mask) 410 | 411 | torch.cuda.empty_cache() 412 | 413 | def add_densification_stats(self, viewspace_point_tensor, update_filter, voxel_visible_mask, use_filter=False): 414 | if use_filter: 415 | self.xyz_gradient_accum[voxel_visible_mask] += torch.norm( 416 | viewspace_point_tensor.grad[update_filter, :2], dim=-1, 417 | keepdim=True) 418 | self.denom[voxel_visible_mask] += 1 419 | else: 420 | self.xyz_gradient_accum[update_filter] += torch.norm( 421 | viewspace_point_tensor.grad[update_filter, :2], dim=-1, 422 | keepdim=True) 423 | self.denom[update_filter] += 1 424 | -------------------------------------------------------------------------------- /scene/specular_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.spec_utils import SpecularNetwork, SpecularNetworkReal 5 | import os 6 | from utils.system_utils import searchForMaxIteration 7 | from utils.general_utils import get_expon_lr_func, get_linear_noise_func 8 | 9 | 10 | class SpecularModel: 11 | def __init__(self, is_real=False, is_indoor=False): 12 | self.specular = 
SpecularNetworkReal(is_indoor).cuda() if is_real else SpecularNetwork().cuda() 13 | self.optimizer = None 14 | self.spatial_lr_scale = 5 15 | 16 | def step(self, asg_feature, viewdir, normal): 17 | return self.specular(asg_feature, viewdir, normal) 18 | 19 | def train_setting(self, training_args): 20 | l = [ 21 | {'params': list(self.specular.parameters()), 22 | 'lr': training_args.feature_lr / 10, 23 | "name": "specular"} 24 | ] 25 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 26 | 27 | self.specular_scheduler_args = get_linear_noise_func(lr_init=training_args.feature_lr, 28 | lr_final=training_args.feature_lr / 20, 29 | lr_delay_mult=training_args.position_lr_delay_mult, 30 | max_steps=training_args.specular_lr_max_steps) 31 | 32 | def save_weights(self, model_path, iteration): 33 | out_weights_path = os.path.join(model_path, "specular/iteration_{}".format(iteration)) 34 | os.makedirs(out_weights_path, exist_ok=True) 35 | torch.save(self.specular.state_dict(), os.path.join(out_weights_path, 'specular.pth')) 36 | 37 | def load_weights(self, model_path, iteration=-1): 38 | if iteration == -1: 39 | loaded_iter = searchForMaxIteration(os.path.join(model_path, "specular")) 40 | else: 41 | loaded_iter = iteration 42 | weights_path = os.path.join(model_path, "specular/iteration_{}/specular.pth".format(loaded_iter)) 43 | self.specular.load_state_dict(torch.load(weights_path)) 44 | 45 | def update_learning_rate(self, iteration): 46 | for param_group in self.optimizer.param_groups: 47 | if param_group["name"] == "specular": 48 | lr = self.specular_scheduler_args(iteration) 49 | param_group['lr'] = lr 50 | return lr 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim, l2_loss 16 | from gaussian_renderer import render, network_gui, prefilter_voxel 17 | import sys 18 | from scene import Scene, GaussianModel, SpecularModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from utils.general_utils import get_linear_noise_func 24 | from argparse import ArgumentParser, Namespace 25 | from arguments import ModelParams, PipelineParams, OptimizationParams 26 | from render import render_sets 27 | from metrics import evaluate 28 | import lpips 29 | 30 | try: 31 | from torch.utils.tensorboard import SummaryWriter 32 | 33 | TENSORBOARD_FOUND = True 34 | except ImportError: 35 | TENSORBOARD_FOUND = False 36 | 37 | 38 | def training(dataset, opt, pipe, testing_iterations, saving_iterations): 39 | tb_writer = prepare_output_and_logger(dataset) 40 | gaussians = GaussianModel(dataset.sh_degree, dataset.asg_degree) 41 | specular_mlp = SpecularModel(dataset.is_real, dataset.is_indoor) 42 | specular_mlp.train_setting(opt) 43 | 44 | scene = Scene(dataset, gaussians) 45 | gaussians.training_setup(opt) 46 | 47 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 48 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 49 | 50 | iter_start = torch.cuda.Event(enable_timing=True) 51 | iter_end = torch.cuda.Event(enable_timing=True) 52 | 53 | viewpoint_stack = None 54 | ema_loss_for_log = 0.0 55 | best_psnr = 0.0 56 | best_iteration = 0 57 | last_ssim = 0 58 | last_lpips = 0 59 | use_filter = dataset.is_real 60 | progress_bar = tqdm(range(opt.iterations), desc="Training progress") 61 | voxel_visible_mask = None 62 | for iteration in range(1, opt.iterations + 1): 63 | 64 | iter_start.record() 65 | 66 | # Every 1000 its we increase the levels of SH up to a maximum degree 67 | if iteration % 1000 == 0: 68 | gaussians.oneupSHdegree() 69 | 70 | # Pick a random Camera 71 | if not viewpoint_stack: 72 | viewpoint_stack = scene.getTrainCameras().copy() 73 | 74 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) 75 | if dataset.load2gpu_on_the_fly: 76 | viewpoint_cam.load2device() 77 | 78 | N = gaussians.get_xyz.shape[0] 79 | 80 | if use_filter: 81 | voxel_visible_mask = prefilter_voxel(viewpoint_cam, gaussians, pipe, background) 82 | 83 | if iteration > 3000: 84 | dir_pp = (gaussians.get_xyz - viewpoint_cam.camera_center.repeat(gaussians.get_features.shape[0], 1)) 85 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 86 | normal = gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 87 | if use_filter: 88 | mlp_color = specular_mlp.step(gaussians.get_asg_features[voxel_visible_mask], 89 | dir_pp_normalized[voxel_visible_mask], 90 | normal.detach()[voxel_visible_mask]) 91 | else: 92 | mlp_color = specular_mlp.step(gaussians.get_asg_features, dir_pp_normalized, normal.detach()) 93 | else: 94 | mlp_color = 0 95 | 96 | render_pkg = render(viewpoint_cam, gaussians, pipe, background, mlp_color, 97 | voxel_visible_mask=voxel_visible_mask) 98 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg[ 99 | "viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 100 | 101 | # Loss 102 | gt_image = viewpoint_cam.original_image.cuda() 103 | Ll1 = l1_loss(image, gt_image) 104 | loss = 
(1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 105 | loss.backward() 106 | 107 | iter_end.record() 108 | 109 | if dataset.load2gpu_on_the_fly: 110 | viewpoint_cam.load2device('cpu') 111 | 112 | with torch.no_grad(): 113 | # Progress bar 114 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 115 | if iteration % 10 == 0: 116 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 117 | progress_bar.update(10) 118 | if iteration == opt.iterations: 119 | progress_bar.close() 120 | 121 | # Keep track of max radii in image-space for pruning 122 | if use_filter: 123 | gaussians.max_radii2D[voxel_visible_mask] = torch.max( 124 | gaussians.max_radii2D[voxel_visible_mask], 125 | radii[visibility_filter]) 126 | else: 127 | gaussians.max_radii2D[visibility_filter] = torch.max( 128 | gaussians.max_radii2D[visibility_filter], 129 | radii[visibility_filter]) 130 | 131 | # Log and save 132 | cur_psnr = training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), 133 | testing_iterations, scene, render, (pipe, background), specular_mlp, 134 | dataset.load2gpu_on_the_fly, use_filter) 135 | 136 | if iteration in testing_iterations: 137 | if iteration == testing_iterations[-1]: 138 | cur_psnr, last_ssim, last_lpips = test_report(tb_writer, iteration, Ll1, loss, l1_loss, 139 | iter_start.elapsed_time(iter_end), 140 | testing_iterations, scene, render, (pipe, background), 141 | specular_mlp, 142 | dataset.load2gpu_on_the_fly, use_filter) 143 | if cur_psnr > best_psnr: 144 | best_psnr = cur_psnr 145 | best_iteration = iteration 146 | 147 | if iteration in saving_iterations: 148 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 149 | scene.save(iteration) 150 | specular_mlp.save_weights(args.model_path, iteration) 151 | 152 | # Densification 153 | if iteration < opt.densify_until_iter: 154 | viewspace_point_tensor_densify = render_pkg["viewspace_points_densify"] 155 | gaussians.add_densification_stats(viewspace_point_tensor_densify, visibility_filter, voxel_visible_mask, 156 | use_filter) 157 | 158 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 159 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 160 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 161 | 162 | if iteration % opt.opacity_reset_interval == 0 or ( 163 | dataset.white_background and iteration == opt.densify_from_iter): 164 | gaussians.reset_opacity() 165 | 166 | # Optimizer step 167 | if iteration < opt.iterations: 168 | gaussians.optimizer.step() 169 | gaussians.update_learning_rate(iteration) 170 | specular_mlp.optimizer.step() 171 | gaussians.optimizer.zero_grad(set_to_none=True) 172 | specular_mlp.optimizer.zero_grad() 173 | specular_mlp.update_learning_rate(iteration) 174 | 175 | print("Best PSNR = {} in Iteration {}, SSIM = {}, LPIPS = {}".format(best_psnr, best_iteration, last_ssim, 176 | last_lpips)) 177 | 178 | 179 | def prepare_output_and_logger(args): 180 | if not args.model_path: 181 | if os.getenv('OAR_JOB_ID'): 182 | unique_str = os.getenv('OAR_JOB_ID') 183 | else: 184 | unique_str = str(uuid.uuid4()) 185 | args.model_path = os.path.join("./output/", unique_str[0:10]) 186 | 187 | # Set up output folder 188 | print("Output folder: {}".format(args.model_path)) 189 | os.makedirs(args.model_path, exist_ok=True) 190 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 191 | 
cfg_log_f.write(str(Namespace(**vars(args)))) 192 | 193 | # Create Tensorboard writer 194 | tb_writer = None 195 | if TENSORBOARD_FOUND: 196 | tb_writer = SummaryWriter(args.model_path) 197 | else: 198 | print("Tensorboard not available: not logging progress") 199 | return tb_writer 200 | 201 | 202 | def test_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc, 203 | renderArgs, specular_mlp, load2gpu_on_the_fly, use_filter): 204 | if tb_writer: 205 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 206 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 207 | tb_writer.add_scalar('iter_time', elapsed, iteration) 208 | 209 | l1_test = 0.0 210 | psnr_test = 0.0 211 | ssim_test = 0.0 212 | lpips_test = 0.0 213 | voxel_visible_mask = None 214 | lpips_fn = lpips.LPIPS(net='vgg').to('cuda') 215 | # Report test and samples of training set 216 | if iteration in testing_iterations: 217 | torch.cuda.empty_cache() 218 | config = {'name': 'test', 'cameras': scene.getTestCameras()} 219 | 220 | if config['cameras'] and len(config['cameras']) > 0: 221 | images = torch.tensor([], device="cuda") 222 | gts = torch.tensor([], device="cuda") 223 | for idx, viewpoint in enumerate(config['cameras']): 224 | if load2gpu_on_the_fly: 225 | viewpoint.load2device() 226 | 227 | if use_filter: 228 | voxel_visible_mask = prefilter_voxel(viewpoint, scene.gaussians, *renderArgs) 229 | dir_pp = (scene.gaussians.get_xyz - viewpoint.camera_center.repeat( 230 | scene.gaussians.get_features.shape[0], 1)) 231 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 232 | normal = scene.gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 233 | if use_filter: 234 | mlp_color = specular_mlp.step(scene.gaussians.get_asg_features[voxel_visible_mask], 235 | dir_pp_normalized[voxel_visible_mask], normal[voxel_visible_mask]) 236 | else: 237 | mlp_color = specular_mlp.step(scene.gaussians.get_asg_features, dir_pp_normalized, normal) 238 | 239 | image = torch.clamp( 240 | renderFunc(viewpoint, scene.gaussians, *renderArgs, mlp_color, 241 | voxel_visible_mask=voxel_visible_mask)["render"], 0.0, 1.0) 242 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 243 | 244 | l1_test += l1_loss(image, gt_image).mean().double() 245 | psnr_test += psnr(image, gt_image).mean().double() 246 | ssim_test += ssim(image, gt_image).mean().double() 247 | lpips_test += lpips_fn(image, gt_image).mean().double() 248 | 249 | if load2gpu_on_the_fly: 250 | viewpoint.load2device('cpu') 251 | if tb_writer and (idx < 5): 252 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), 253 | image[None], global_step=iteration) 254 | if iteration == testing_iterations[0]: 255 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), 256 | gt_image[None], global_step=iteration) 257 | 258 | l1_test /= len(config['cameras']) 259 | psnr_test /= len(config['cameras']) 260 | ssim_test /= len(config['cameras']) 261 | lpips_test /= len(config['cameras']) 262 | 263 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 264 | if tb_writer: 265 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 266 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 267 | 268 | if tb_writer: 269 | tb_writer.add_histogram("scene/opacity_histogram", 
scene.gaussians.get_opacity, iteration) 270 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 271 | torch.cuda.empty_cache() 272 | 273 | return psnr_test, ssim_test, lpips_test 274 | 275 | 276 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc, 277 | renderArgs, specular_mlp, load2gpu_on_the_fly, use_filter): 278 | if tb_writer: 279 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 280 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 281 | tb_writer.add_scalar('iter_time', elapsed, iteration) 282 | 283 | test_psnr = 0.0 284 | voxel_visible_mask = None 285 | # Report test and samples of training set 286 | if iteration in testing_iterations[:-1]: 287 | torch.cuda.empty_cache() 288 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()}, 289 | {'name': 'train', 290 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in 291 | range(5, 30, 5)]}) 292 | 293 | for config in validation_configs: 294 | if config['cameras'] and len(config['cameras']) > 0: 295 | images = torch.tensor([], device="cuda") 296 | gts = torch.tensor([], device="cuda") 297 | for idx, viewpoint in enumerate(config['cameras']): 298 | if load2gpu_on_the_fly: 299 | viewpoint.load2device() 300 | 301 | if use_filter: 302 | voxel_visible_mask = prefilter_voxel(viewpoint, scene.gaussians, *renderArgs) 303 | dir_pp = (scene.gaussians.get_xyz - viewpoint.camera_center.repeat( 304 | scene.gaussians.get_features.shape[0], 1)) 305 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 306 | normal = scene.gaussians.get_normal_axis(dir_pp_normalized=dir_pp_normalized, return_delta=True) 307 | if use_filter: 308 | mlp_color = specular_mlp.step(scene.gaussians.get_asg_features[voxel_visible_mask], 309 | dir_pp_normalized[voxel_visible_mask], normal[voxel_visible_mask]) 310 | else: 311 | mlp_color = specular_mlp.step(scene.gaussians.get_asg_features, dir_pp_normalized, normal) 312 | 313 | image = torch.clamp( 314 | renderFunc(viewpoint, scene.gaussians, *renderArgs, mlp_color, 315 | voxel_visible_mask=voxel_visible_mask)["render"], 0.0, 1.0) 316 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 317 | images = torch.cat((images, image.unsqueeze(0)), dim=0) 318 | gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) 319 | 320 | if load2gpu_on_the_fly: 321 | viewpoint.load2device('cpu') 322 | if tb_writer and (idx < 5): 323 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), 324 | image[None], global_step=iteration) 325 | if iteration == testing_iterations[0]: 326 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), 327 | gt_image[None], global_step=iteration) 328 | 329 | l1_test = l1_loss(images, gts) 330 | psnr_test = psnr(images, gts).mean() 331 | if config['name'] == 'test' or len(validation_configs[0]['cameras']) == 0: 332 | test_psnr = psnr_test 333 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 334 | if tb_writer: 335 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 336 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 337 | 338 | if tb_writer: 339 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 340 | tb_writer.add_scalar('total_points', 
scene.gaussians.get_xyz.shape[0], iteration) 341 | torch.cuda.empty_cache() 342 | 343 | return test_psnr 344 | 345 | 346 | if __name__ == "__main__": 347 | # Set up command line argument parser 348 | parser = ArgumentParser(description="Training script parameters") 349 | lp = ModelParams(parser) 350 | op = OptimizationParams(parser) 351 | pp = PipelineParams(parser) 352 | parser.add_argument('--ip', type=str, default="127.0.0.1") 353 | parser.add_argument('--port', type=int, default=6009) 354 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 355 | parser.add_argument("--test_iterations", nargs="+", type=int, 356 | default=[7_000] + list(range(20000, 30001, 1000))) 357 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 358 | parser.add_argument("--quiet", action="store_true") 359 | args = parser.parse_args(sys.argv[1:]) 360 | args.save_iterations.append(args.iterations) 361 | 362 | print("Optimizing " + args.model_path) 363 | 364 | # Initialize system state (RNG) 365 | safe_state(args.quiet) 366 | 367 | # Start GUI server, configure and run training 368 | # network_gui.init(args.ip, args.port) 369 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 370 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations) 371 | 372 | # All done 373 | print("\nTraining complete.") 374 | 375 | # rendering 376 | print(f'\nStarting Rendering~') 377 | render_sets(lp.extract(args), -1, op.extract(args), pp.extract(args), skip_train=True, skip_test=False, 378 | mode="render") 379 | print("\nRendering complete.") 380 | 381 | # calc metrics 382 | # print("\nStarting evaluation...") 383 | # evaluate([str(args.model_path)]) 384 | # print("\nEvaluating complete.") 385 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch, ArrayToTorch 15 | from utils.graphics_utils import fov2focal 16 | import json 17 | 18 | WARNED = False 19 | 20 | 21 | def loadCam(args, id, cam_info, resolution_scale): 22 | orig_w, orig_h = cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 26 | orig_h / (resolution_scale * args.resolution)) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 1600: 30 | global WARNED 31 | if not WARNED: 32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 34 | WARNED = True 35 | global_down = orig_w / 1600 36 | else: 37 | global_down = 1 38 | else: 39 | global_down = orig_w / args.resolution 40 | 41 | scale = float(global_down) * float(resolution_scale) 42 | resolution = (int(orig_w / scale), int(orig_h / scale)) 43 | 44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 45 | 46 | gt_image = resized_image_rgb[:3, ...] 
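    # PILtoTorch (utils/general_utils.py) returns a channel-first (C, H, W) float tensor in
    # [0, 1], so the slice above keeps the three colour channels; for RGBA inputs the fourth
    # channel is split off below and forwarded to Camera as gt_alpha_mask. As a minimal
    # sketch of that split:
    #     rgb   = resized_image_rgb[:3, ...]   # (3, H, W) colour image
    #     alpha = resized_image_rgb[3:4, ...]  # (1, H, W) alpha mask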
47 | loaded_mask = None 48 | 49 | if resized_image_rgb.shape[1] == 4: 50 | loaded_mask = resized_image_rgb[3:4, ...] 51 | 52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 54 | image=gt_image, gt_alpha_mask=loaded_mask, 55 | image_name=cam_info.image_name, uid=id, 56 | data_device=args.data_device if not args.load2gpu_on_the_fly else 'cpu', depth=cam_info.depth) 57 | 58 | 59 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 60 | camera_list = [] 61 | 62 | for id, c in enumerate(cam_infos): 63 | camera_list.append(loadCam(args, id, c, resolution_scale)) 64 | 65 | return camera_list 66 | 67 | 68 | def camera_to_JSON(id, camera: Camera): 69 | Rt = np.zeros((4, 4)) 70 | Rt[:3, :3] = camera.R.transpose() 71 | Rt[:3, 3] = camera.T 72 | Rt[3, 3] = 1.0 73 | 74 | W2C = np.linalg.inv(Rt) 75 | pos = W2C[:3, 3] 76 | rot = W2C[:3, :3] 77 | serializable_array_2d = [x.tolist() for x in rot] 78 | camera_entry = { 79 | 'id': id, 80 | 'img_name': camera.image_name, 81 | 'width': camera.width, 82 | 'height': camera.height, 83 | 'position': pos.tolist(), 84 | 'rotation': serializable_array_2d, 85 | 'fy': fov2focal(camera.FovY, camera.height), 86 | 'fx': fov2focal(camera.FovX, camera.width) 87 | } 88 | return camera_entry 89 | 90 | 91 | def camera_nerfies_from_JSON(path, scale): 92 | """Loads a JSON camera into memory.""" 93 | with open(path, 'r') as fp: 94 | camera_json = json.load(fp) 95 | 96 | # Fix old camera JSON. 97 | if 'tangential' in camera_json: 98 | camera_json['tangential_distortion'] = camera_json['tangential'] 99 | 100 | return dict( 101 | orientation=np.array(camera_json['orientation']), 102 | position=np.array(camera_json['position']), 103 | focal_length=camera_json['focal_length'] * scale, 104 | principal_point=np.array(camera_json['principal_point']) * scale, 105 | skew=camera_json['skew'], 106 | pixel_aspect_ratio=camera_json['pixel_aspect_ratio'], 107 | radial_distortion=np.array(camera_json['radial_distortion']), 108 | tangential_distortion=np.array(camera_json['tangential_distortion']), 109 | image_size=np.array((int(round(camera_json['image_size'][0] * scale)), 110 | int(round(camera_json['image_size'][1] * scale)))), 111 | ) 112 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | 32 | def ArrayToTorch(array, resolution): 33 | # resized_image = np.resize(array, resolution) 34 | resized_image_torch = torch.from_numpy(array) 35 | 36 | if len(resized_image_torch.shape) == 3: 37 | return resized_image_torch.permute(2, 0, 1) 38 | else: 39 | return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1) 40 | 41 | 42 | def get_expon_lr_func( 43 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 44 | ): 45 | """ 46 | Copied from Plenoxels 47 | 48 | Continuous learning rate decay function. Adapted from JaxNeRF 49 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 50 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 51 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 52 | function of lr_delay_mult, such that the initial learning rate is 53 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 54 | to the normal learning rate when steps>lr_delay_steps. 55 | :param conf: config subtree 'lr' or similar 56 | :param max_steps: int, the number of steps during optimization. 57 | :return HoF which takes step as input 58 | """ 59 | 60 | def helper(step): 61 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 62 | # Disable this parameter 63 | return 0.0 64 | if lr_delay_steps > 0: 65 | # A kind of reverse cosine decay. 66 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 67 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 68 | ) 69 | else: 70 | delay_rate = 1.0 71 | t = np.clip(step / max_steps, 0, 1) 72 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 73 | return delay_rate * log_lerp 74 | 75 | return helper 76 | 77 | 78 | def get_linear_noise_func( 79 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 80 | ): 81 | """ 82 | Copied from Plenoxels 83 | 84 | Continuous learning rate decay function. Adapted from JaxNeRF 85 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 86 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 87 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 88 | function of lr_delay_mult, such that the initial learning rate is 89 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 90 | to the normal learning rate when steps>lr_delay_steps. 91 | :param conf: config subtree 'lr' or similar 92 | :param max_steps: int, the number of steps during optimization. 93 | :return HoF which takes step as input 94 | """ 95 | 96 | def helper(step): 97 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 98 | # Disable this parameter 99 | return 0.0 100 | if lr_delay_steps > 0: 101 | # A kind of reverse cosine decay. 
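            # delay_rate ramps smoothly from lr_delay_mult at step 0 up to 1.0 once step
            # reaches lr_delay_steps, following sin(0.5 * pi * step / lr_delay_steps); it
            # scales the interpolated learning rate computed below. E.g. with
            # lr_delay_mult=0.1 the schedule starts at 10% of its nominal value.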
102 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 103 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 104 | ) 105 | else: 106 | delay_rate = 1.0 107 | t = np.clip(step / max_steps, 0, 1) 108 | log_lerp = lr_init * (1 - t) + lr_final * t 109 | return delay_rate * log_lerp 110 | 111 | return helper 112 | 113 | 114 | def strip_lowerdiag(L): 115 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 116 | 117 | uncertainty[:, 0] = L[:, 0, 0] 118 | uncertainty[:, 1] = L[:, 0, 1] 119 | uncertainty[:, 2] = L[:, 0, 2] 120 | uncertainty[:, 3] = L[:, 1, 1] 121 | uncertainty[:, 4] = L[:, 1, 2] 122 | uncertainty[:, 5] = L[:, 2, 2] 123 | return uncertainty 124 | 125 | 126 | def strip_symmetric(sym): 127 | return strip_lowerdiag(sym) 128 | 129 | 130 | def build_rotation(r): 131 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 132 | 133 | q = r / norm[:, None] 134 | 135 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 136 | 137 | r = q[:, 0] 138 | x = q[:, 1] 139 | y = q[:, 2] 140 | z = q[:, 3] 141 | 142 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 143 | R[:, 0, 1] = 2 * (x * y - r * z) 144 | R[:, 0, 2] = 2 * (x * z + r * y) 145 | R[:, 1, 0] = 2 * (x * y + r * z) 146 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 147 | R[:, 1, 2] = 2 * (y * z - r * x) 148 | R[:, 2, 0] = 2 * (x * z - r * y) 149 | R[:, 2, 1] = 2 * (y * z + r * x) 150 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 151 | return R 152 | 153 | 154 | def build_scaling_rotation(s, r): 155 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 156 | R = build_rotation(r) 157 | 158 | L[:, 0, 0] = s[:, 0] 159 | L[:, 1, 1] = s[:, 1] 160 | L[:, 2, 2] = s[:, 2] 161 | 162 | L = R @ L 163 | return L 164 | 165 | 166 | def safe_state(silent): 167 | old_f = sys.stdout 168 | 169 | class F: 170 | def __init__(self, silent): 171 | self.silent = silent 172 | 173 | def write(self, x): 174 | if not self.silent: 175 | if x.endswith("\n"): 176 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 177 | else: 178 | old_f.write(x) 179 | 180 | def flush(self): 181 | old_f.flush() 182 | 183 | sys.stdout = F(silent) 184 | 185 | random.seed(0) 186 | np.random.seed(0) 187 | torch.manual_seed(0) 188 | torch.cuda.set_device(torch.device("cuda:0")) 189 | 190 | 191 | def get_minimum_axis(scales, rotations): 192 | sorted_idx = torch.argsort(scales, descending=False, dim=-1) 193 | R = build_rotation(rotations) 194 | R_sorted = torch.gather(R, dim=1, index=sorted_idx[:, :, None].repeat(1, 1, 3)).squeeze() 195 | x_axis = R_sorted[:, 0, :] # normalized by defaut 196 | 197 | return x_axis 198 | 199 | 200 | def flip_align_view(normal, viewdir): 201 | # normal: (N, 3), viewdir: (N, 3) 202 | dotprod = torch.sum( 203 | normal * -viewdir, dim=-1, keepdims=True) # (N, 1) 204 | non_flip = dotprod >= 0 # (N, 1) 205 | normal_flipped = normal * torch.where(non_flip, 1, -1) # (N, 3) 206 | return normal_flipped, non_flip 207 | 208 | 209 | def depth2normal(depth: torch.Tensor, focal: float = None): 210 | if depth.dim() == 2: 211 | depth = depth[None, None] 212 | elif depth.dim() == 3: 213 | depth = depth.squeeze()[None, None] 214 | if focal is None: 215 | focal = depth.shape[-1] / 2 / np.tan(torch.pi / 6) 216 | depth = torch.cat([depth[:, :, :1], depth, depth[:, :, -1:]], dim=2) 217 | depth = torch.cat([depth[..., :1], depth, depth[..., -1:]], dim=3) 218 | kernel = torch.tensor([[[0, 0, 0], 219 | [-.5, 0, .5], 220 | [0, 0, 0]], 221 | [[0, -.5, 0], 222 | [0, 0, 
0], 223 | [0, .5, 0]]], device=depth.device, dtype=depth.dtype)[:, None] 224 | normal = torch.nn.functional.conv2d(depth, kernel, padding='valid')[0].permute(1, 2, 0) 225 | normal = normal / (depth[0, 0, 1:-1, 1:-1, None] + 1e-10) * focal 226 | normal = torch.cat([normal, torch.ones_like(normal[..., :1])], dim=-1) 227 | normal = normal / normal.norm(dim=-1, keepdim=True) 228 | return normal.permute(2, 0, 1) 229 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | 15 | def mse(img1, img2): 16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 17 | 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def kl_divergence(rho, rho_hat): 23 | rho_hat = torch.mean(torch.sigmoid(rho_hat), 0) 24 | rho = torch.tensor([rho] * len(rho_hat)).cuda() 25 | return torch.mean( 26 | rho * torch.log(rho / (rho_hat + 1e-5)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + 1e-5))) 27 | 28 | 29 | def l2_loss(network_output, gt): 30 | return ((network_output - gt) ** 2).mean() 31 | 32 | 33 | def gaussian(window_size, sigma): 34 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 35 | return gauss / gauss.sum() 36 | 37 | 38 | def create_window(window_size, channel): 39 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 40 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 41 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 42 | return window 43 | 44 | 45 | def ssim(img1, img2, window_size=11, size_average=True): 46 | channel = img1.size(-3) 47 | window = create_window(window_size, channel) 48 | 49 | if img1.is_cuda: 50 | window = window.cuda(img1.get_device()) 51 | window = window.type_as(img1) 52 | 53 | return _ssim(img1, img2, window, window_size, channel, size_average) 54 | 55 | 56 | def _ssim(img1, img2, window, window_size, channel, size_average=True, stride=None): 57 | if stride: 58 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel, stride=stride) 59 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel, stride=stride) 60 | else: 61 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 62 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 63 | 64 | mu1_sq = mu1.pow(2) 65 | mu2_sq = mu2.pow(2) 66 | mu1_mu2 = mu1 * mu2 67 | 68 | if stride: 69 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel, stride=stride) - mu1_sq 70 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel, stride=stride) - mu2_sq 71 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel, stride=stride) - mu1_mu2 72 | else: 73 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 74 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 75 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 76 | 77 | C1 = 
0.01 ** 2 78 | C2 = 0.03 ** 2 79 | 80 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 81 | 82 | if size_average: 83 | return ssim_map.mean() 84 | else: 85 | return ssim_map.mean(1).mean(1).mean(1) 86 | 87 | 88 | class SSIM(torch.nn.Module): 89 | def __init__(self, window_size=3, size_average=True, stride=3): 90 | super(SSIM, self).__init__() 91 | self.window_size = window_size 92 | self.size_average = size_average 93 | self.channel = 1 94 | self.stride = stride 95 | self.window = create_window(window_size, self.channel) 96 | 97 | def forward(self, img1, img2): 98 | """ 99 | img1, img2: torch.Tensor([b,c,h,w]) 100 | """ 101 | (_, channel, _, _) = img1.size() 102 | 103 | if channel == self.channel and self.window.data.type() == img1.data.type(): 104 | window = self.window 105 | else: 106 | window = create_window(self.window_size, channel) 107 | 108 | if img1.is_cuda: 109 | window = window.cuda(img1.get_device()) 110 | window = window.type_as(img1) 111 | 112 | self.window = window 113 | self.channel = channel 114 | 115 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average, stride=self.stride) 116 | 117 | 118 | class S3IM(torch.nn.Module): 119 | r"""Implements the Stochastic Structural SIMilarity (S3IM) algorithm. 120 | It is proposed in the ICCV2023 paper 121 | `S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`. 122 | 123 | Arguments: 124 | kernel_size (int): kernel size in ssim's convolution (default: 4) 125 | stride (int): stride in ssim's convolution (default: 4) 126 | repeat_time (int): number of re-shuffles of the virtual patch (default: 10) 127 | patch_height (int): height of the virtual patch (default: 64) 128 | patch_width (int): width of the virtual patch (default: 64) 129 | """ 130 | 131 | def __init__(self, kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=64): 132 | super(S3IM, self).__init__() 133 | self.kernel_size = kernel_size 134 | self.stride = stride 135 | self.repeat_time = repeat_time 136 | self.patch_height = patch_height 137 | self.patch_width = patch_width 138 | self.ssim_loss = SSIM(window_size=self.kernel_size, stride=self.stride) 139 | 140 | def forward(self, src_vec, tar_vec): 141 | loss = 0.0 142 | index_list = [] 143 | for i in range(self.repeat_time): 144 | if i == 0: 145 | tmp_index = torch.arange(len(tar_vec)) 146 | index_list.append(tmp_index) 147 | else: 148 | ran_idx = torch.randperm(len(tar_vec)) 149 | index_list.append(ran_idx) 150 | res_index = torch.cat(index_list) 151 | tar_all = tar_vec[res_index] 152 | src_all = src_vec[res_index] 153 | tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.patch_height, self.patch_width * self.repeat_time) 154 | src_patch = src_all.permute(1, 0).reshape(1, 3, self.patch_height, self.patch_width * self.repeat_time) 155 | loss = (1 - self.ssim_loss(src_patch, tar_patch)) 156 | return loss 157 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.graphics_utils import fov2focal, getWorld2View2 4 | 5 | trans_t = lambda t: torch.Tensor([ 6 | [1, 0, 0, 0], 7 | [0, 1, 0, 0], 8 | [0, 0, 1, t], 9 | [0, 0, 0, 1]]).float() 10 | 11 | rot_phi = lambda phi: torch.Tensor([ 12 | [1, 0, 0, 0], 13 | [0, np.cos(phi), -np.sin(phi), 0], 14 | [0, np.sin(phi), np.cos(phi), 0], 15 | [0, 0, 0, 1]]).float() 16 | 17 |
rot_theta = lambda th: torch.Tensor([ 18 | [np.cos(th), 0, -np.sin(th), 0], 19 | [0, 1, 0, 0], 20 | [np.sin(th), 0, np.cos(th), 0], 21 | [0, 0, 0, 1]]).float() 22 | 23 | 24 | def rodrigues_mat_to_rot(R): 25 | eps = 1e-16 26 | trc = np.trace(R) 27 | trc2 = (trc - 1.) / 2. 28 | # sinacostrc2 = np.sqrt(1 - trc2 * trc2) 29 | s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]) 30 | if (1 - trc2 * trc2) >= eps: 31 | tHeta = np.arccos(trc2) 32 | tHetaf = tHeta / (2 * (np.sin(tHeta))) 33 | else: 34 | tHeta = np.real(np.arccos(trc2)) 35 | tHetaf = 0.5 / (1 - tHeta / 6) 36 | omega = tHetaf * s 37 | return omega 38 | 39 | 40 | def rodrigues_rot_to_mat(r): 41 | wx, wy, wz = r 42 | theta = np.sqrt(wx * wx + wy * wy + wz * wz) 43 | a = np.cos(theta) 44 | b = (1 - np.cos(theta)) / (theta * theta) 45 | c = np.sin(theta) / theta 46 | R = np.zeros([3, 3]) 47 | R[0, 0] = a + b * (wx * wx) 48 | R[0, 1] = b * wx * wy - c * wz 49 | R[0, 2] = b * wx * wz + c * wy 50 | R[1, 0] = b * wx * wy + c * wz 51 | R[1, 1] = a + b * (wy * wy) 52 | R[1, 2] = b * wy * wz - c * wx 53 | R[2, 0] = b * wx * wz - c * wy 54 | R[2, 1] = b * wz * wy + c * wx 55 | R[2, 2] = a + b * (wz * wz) 56 | return R 57 | 58 | 59 | def pose_spherical(theta, phi, radius): 60 | c2w = trans_t(radius) 61 | c2w = rot_phi(phi / 180. * np.pi) @ c2w 62 | c2w = rot_theta(theta / 180. * np.pi) @ c2w 63 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 64 | return c2w 65 | 66 | 67 | def render_wander_path(view): 68 | focal_length = fov2focal(view.FoVy, view.image_height) 69 | R = view.R 70 | R[:, 1] = -R[:, 1] 71 | R[:, 2] = -R[:, 2] 72 | T = -view.T.reshape(-1, 1) 73 | pose = np.concatenate([R, T], -1) 74 | 75 | num_frames = 60 76 | max_disp = 5000.0 # 64 , 48 77 | 78 | max_trans = max_disp / focal_length # Maximum camera translation to satisfy max_disp parameter 79 | output_poses = [] 80 | 81 | for i in range(num_frames): 82 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) 83 | y_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 # * 3.0 / 4.0 84 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 85 | 86 | i_pose = np.concatenate([ 87 | np.concatenate( 88 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1), 89 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :] 90 | ], axis=0) # [np.newaxis, :, :] 91 | 92 | i_pose = np.linalg.inv(i_pose) # torch.tensor(np.linalg.inv(i_pose)).float() 93 | 94 | ref_pose = np.concatenate([pose, np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0) 95 | 96 | render_pose = np.dot(ref_pose, i_pose) 97 | output_poses.append(torch.Tensor(render_pose)) 98 | 99 | return output_poses 100 | 101 | 102 | def integrate_weights_np(w): 103 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1. 104 | 105 | The output's size on the last dimension is one greater than that of the input, 106 | because we're computing the integral corresponding to the endpoints of a step 107 | function, not the integral of the interior/bin values. 108 | 109 | Args: 110 | w: Tensor, which will be integrated along the last axis. This is assumed to 111 | sum to 1 along the last axis, and this function will (silently) break if 112 | that is not the case. 
113 | 114 | Returns: 115 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 116 | """ 117 | cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1)) 118 | shape = cw.shape[:-1] + (1,) 119 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1. 120 | cw0 = np.concatenate([np.zeros(shape), cw, 121 | np.ones(shape)], axis=-1) 122 | return cw0 123 | 124 | 125 | def invert_cdf_np(u, t, w_logits): 126 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" 127 | # Compute the PDF and CDF for each weight vector. 128 | w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True) 129 | cw = integrate_weights_np(w) 130 | # Interpolate into the inverse CDF. 131 | interp_fn = np.interp 132 | t_new = interp_fn(u, cw, t) 133 | return t_new 134 | 135 | 136 | def sample_np(rand, 137 | t, 138 | w_logits, 139 | num_samples, 140 | single_jitter=False, 141 | deterministic_center=False): 142 | """ 143 | numpy version of sample() 144 | """ 145 | eps = np.finfo(np.float32).eps 146 | 147 | # Draw uniform samples. 148 | if not rand: 149 | if deterministic_center: 150 | pad = 1 / (2 * num_samples) 151 | u = np.linspace(pad, 1. - pad - eps, num_samples) 152 | else: 153 | u = np.linspace(0, 1. - eps, num_samples) 154 | u = np.broadcast_to(u, t.shape[:-1] + (num_samples,)) 155 | else: 156 | # `u` is in [0, 1) --- it can be zero, but it can never be 1. 157 | u_max = eps + (1 - eps) / num_samples 158 | max_jitter = (1 - u_max) / (num_samples - 1) - eps 159 | d = 1 if single_jitter else num_samples 160 | u = np.linspace(0, 1 - u_max, num_samples) + \ 161 | np.random.rand(*t.shape[:-1], d) * max_jitter 162 | 163 | return invert_cdf_np(u, t, w_logits) 164 | 165 | 166 | def pad_poses(p): 167 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" 168 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) 169 | return np.concatenate([p[..., :3, :4], bottom], axis=-2) 170 | 171 | 172 | def unpad_poses(p): 173 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" 174 | return p[..., :3, :4] 175 | 176 | 177 | def transform_poses_pca(poses): 178 | """Transforms poses so principal components lie on XYZ axes. 179 | 180 | Args: 181 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 182 | 183 | Returns: 184 | A tuple (poses, transform), with the transformed poses and the applied 185 | camera_to_world transforms. 186 | """ 187 | t = poses[:, :3, 3] 188 | t_mean = t.mean(axis=0) 189 | t = t - t_mean 190 | 191 | eigval, eigvec = np.linalg.eig(t.T @ t) 192 | # Sort eigenvectors in order of largest to smallest eigenvalue. 193 | inds = np.argsort(eigval)[::-1] 194 | eigvec = eigvec[:, inds] 195 | rot = eigvec.T 196 | if np.linalg.det(rot) < 0: 197 | rot = np.diag(np.array([1, 1, -1])) @ rot 198 | 199 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 200 | poses_recentered = unpad_poses(transform @ pad_poses(poses)) 201 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) 202 | 203 | # Flip coordinate system if z component of y-axis is negative 204 | if poses_recentered.mean(axis=0)[2, 1] < 0: 205 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 206 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 207 | 208 | # Just make sure it's in the [-1, 1]^3 cube 209 | scale_factor = 1.
/ np.max(np.abs(poses_recentered[:, :3, 3])) 210 | poses_recentered[:, :3, 3] *= scale_factor 211 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform 212 | return poses_recentered, transform 213 | 214 | 215 | def focus_point_fn(poses): 216 | """Calculate nearest point to all focal axes in poses.""" 217 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 218 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 219 | mt_m = np.transpose(m, [0, 2, 1]) @ m 220 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 221 | return focus_pt 222 | 223 | 224 | def normalize(x): 225 | return x / np.linalg.norm(x) 226 | 227 | 228 | def viewmatrix(z, up, pos): 229 | vec2 = normalize(z) 230 | vec1_avg = up 231 | vec0 = normalize(np.cross(vec1_avg, vec2)) 232 | vec1 = normalize(np.cross(vec2, vec0)) 233 | m = np.stack([vec0, vec1, vec2, pos], 1) 234 | return m 235 | 236 | 237 | def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.): 238 | poses = [] 239 | for view in views: 240 | tmp_view = np.eye(4) 241 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 242 | tmp_view = np.linalg.inv(tmp_view) 243 | tmp_view[:, 1:3] *= -1 244 | poses.append(tmp_view) 245 | poses = np.stack(poses, 0) 246 | poses, transform = transform_poses_pca(poses) 247 | 248 | # Calculate the focal point for the path (cameras point toward this). 249 | center = focus_point_fn(poses) 250 | offset = np.array([center[0], center[1], center[2] * 0]) 251 | # Calculate scaling for ellipse axes based on input camera positions. 252 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 253 | 254 | # Use ellipse that is symmetric about the focal point in xy. 255 | low = -sc + offset 256 | high = sc + offset 257 | # Optional height variation need not be symmetric 258 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 259 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 260 | 261 | def get_positions(theta): 262 | # Interpolate between bounds with trig functions to get ellipse in x-y. 263 | # Optionally also interpolate in z to change camera height along path. 264 | return np.stack([ 265 | (low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)), 266 | (low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)), 267 | z_variation * (z_low[2] + (z_high - z_low)[2] * 268 | (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)), 269 | ], -1) 270 | 271 | theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True) 272 | positions = get_positions(theta) 273 | 274 | if const_speed: 275 | # Resample theta angles so that the velocity is closer to constant. 276 | lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 277 | theta = sample_np(None, theta, np.log(lengths), n_frames + 1) 278 | positions = get_positions(theta) 279 | 280 | # Throw away duplicated last position. 281 | positions = positions[:-1] 282 | 283 | # Set path's up vector to axis closest to average of input pose up vectors. 
284 | avg_up = poses[:, :3, 1].mean(0) 285 | avg_up = avg_up / np.linalg.norm(avg_up) 286 | ind_up = np.argmax(np.abs(avg_up)) 287 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 288 | 289 | render_poses = [] 290 | for p in positions: 291 | render_pose = np.eye(4) 292 | render_pose[:3] = viewmatrix(p - center, up, p) 293 | render_pose = np.linalg.inv(transform) @ render_pose 294 | render_pose[:3, 1:3] *= -1 295 | render_poses.append(np.linalg.inv(render_pose)) 296 | return render_poses 297 | -------------------------------------------------------------------------------- /utils/quaternion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def quaternion_product(p, q): 6 | p_r = p[..., [0]] 7 | p_i = p[..., 1:] 8 | q_r = q[..., [0]] 9 | q_i = q[..., 1:] 10 | 11 | out_r = p_r * q_r - (p_i * q_i).sum(dim=-1) 12 | out_i = p_r * q_i + q_r * p_i + torch.linalg.cross(p_i, q_i, dim=-1) 13 | 14 | return torch.cat([out_r, out_i], dim=-1) 15 | 16 | 17 | def quaternion_inverse(p): 18 | p_r = p[..., [0]] 19 | p_i = -p[..., 1:] 20 | 21 | return torch.cat([p_r, p_i], dim=-1) 22 | 23 | 24 | def quaternion_rotate(p, q): 25 | q_inv = quaternion_inverse(q) 26 | 27 | qp = quaternion_product(q, p) 28 | out = quaternion_product(qp, q_inv) 29 | return out 30 | 31 | 32 | def build_q(vec, angle): 33 | out_r = torch.cos(angle / 2) 34 | out_i = torch.sin(angle / 2) * vec 35 | 36 | return torch.cat([out_r, out_i], dim=-1) 37 | 38 | 39 | def cartesian2quaternion(x): 40 | zeros_ = x.new_zeros([*x.shape[:-1], 1]) 41 | return torch.cat([zeros_, x], dim=-1) 42 | 43 | 44 | def spherical2cartesian(theta, phi): 45 | x = torch.cos(phi) * torch.sin(theta) 46 | y = torch.sin(phi) * torch.sin(theta) 47 | z = torch.cos(theta) 48 | 49 | return [x, y, z] 50 | 51 | 52 | def init_predefined_omega(n_theta, n_phi): 53 | theta_list = torch.linspace(0, np.pi, n_theta) 54 | phi_list = torch.linspace(0, np.pi * 2, n_phi) 55 | 56 | out_omega = [] 57 | out_omega_lambda = [] 58 | out_omega_mu = [] 59 | 60 | for i in range(n_theta): 61 | theta = theta_list[i].view(1, 1) 62 | 63 | for j in range(n_phi): 64 | phi = phi_list[j].view(1, 1) 65 | 66 | omega = spherical2cartesian(theta, phi) 67 | omega = torch.stack(omega, dim=-1).view(1, 3) 68 | 69 | omega_lambda = spherical2cartesian(theta + np.pi / 2, phi) 70 | omega_lambda = torch.stack(omega_lambda, dim=-1).view(1, 3) 71 | 72 | p = cartesian2quaternion(omega_lambda) 73 | q = build_q(omega, torch.tensor(np.pi / 2).view(1, 1)) 74 | omega_mu = quaternion_rotate(p, q)[..., 1:] 75 | 76 | out_omega.append(omega) 77 | out_omega_lambda.append(omega_lambda) 78 | out_omega_mu.append(omega_mu) 79 | 80 | out_omega = torch.stack(out_omega, dim=0) 81 | out_omega_lambda = torch.stack(out_omega_lambda, dim=0) 82 | out_omega_mu = torch.stack(out_omega_mu, dim=0) 83 | 84 | return out_omega, out_omega_lambda, out_omega_mu 85 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-4 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | 115 | def RGB2SH(rgb): 116 | return (rgb - 0.5) / C0 117 | 118 | 119 | def SH2RGB(sh): 120 | return sh * C0 + 0.5 121 | -------------------------------------------------------------------------------- /utils/spec_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from utils.quaternion_utils import init_predefined_omega 6 | 7 | 8 | def get_embedder(multires, i=1): 9 | if i == -1: 10 | return nn.Identity(), 3 11 | 12 | embed_kwargs = { 13 | 'include_input': True, 14 | 'input_dims': i, 15 | 'max_freq_log2': multires - 1, 16 | 'num_freqs': multires, 17 | 'log_sampling': True, 18 | 'periodic_fns': [torch.sin, torch.cos], 19 | } 20 | 21 | embedder_obj = Embedder(**embed_kwargs) 22 | embed = lambda x, eo=embedder_obj: eo.embed(x) 23 | return embed, embedder_obj.out_dim 24 | 25 | 26 | # Positional encoding (NeRF, section 5.1) 27 | class Embedder: 28 | def __init__(self, **kwargs): 29 | self.kwargs = kwargs 30 | self.create_embedding_fn() 31 | 32 | def create_embedding_fn(self): 33 | embed_fns = [] 34 | d = self.kwargs['input_dims'] 35 | out_dim = 0 36 | if self.kwargs['include_input']: 37 | embed_fns.append(lambda x: x) 38 | out_dim += d 39 | 40 | max_freq = self.kwargs['max_freq_log2'] 41 | N_freqs = self.kwargs['num_freqs'] 42 | 43 | if self.kwargs['log_sampling']: 44 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 45 | else: 46 | freq_bands = torch.linspace(2. ** 0., 2.
** max_freq, steps=N_freqs) 47 | 48 | for freq in freq_bands: 49 | for p_fn in self.kwargs['periodic_fns']: 50 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 51 | out_dim += d 52 | 53 | self.embed_fns = embed_fns 54 | self.out_dim = out_dim 55 | 56 | def embed(self, inputs): 57 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 58 | 59 | 60 | def positional_encoding(positions, freqs): 61 | freq_bands = (2 ** torch.arange(freqs).float()).to(positions.device) # (F,) 62 | pts = (positions[..., None] * freq_bands).reshape( 63 | positions.shape[:-1] + (freqs * positions.shape[-1],)) # (..., DF) 64 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) 65 | return pts 66 | 67 | 68 | class RenderingEquationEncoding(torch.nn.Module): 69 | def __init__(self, num_theta, num_phi, device): 70 | super(RenderingEquationEncoding, self).__init__() 71 | 72 | self.num_theta = num_theta 73 | self.num_phi = num_phi 74 | 75 | omega, omega_la, omega_mu = init_predefined_omega(num_theta, num_phi) 76 | self.omega = omega.view(1, num_theta, num_phi, 3).to(device) 77 | self.omega_la = omega_la.view(1, num_theta, num_phi, 3).to(device) 78 | self.omega_mu = omega_mu.view(1, num_theta, num_phi, 3).to(device) 79 | 80 | def forward(self, omega_o, a, la, mu): 81 | Smooth = F.relu((omega_o[:, None, None] * self.omega).sum(dim=-1, keepdim=True)) # N, num_theta, num_phi, 1 82 | 83 | la = F.softplus(la - 1) 84 | mu = F.softplus(mu - 1) 85 | exp_input = -la * (self.omega_la * omega_o[:, None, None]).sum(dim=-1, keepdim=True).pow(2) - mu * ( 86 | self.omega_mu * omega_o[:, None, None]).sum(dim=-1, keepdim=True).pow(2) 87 | out = a * Smooth * torch.exp(exp_input) 88 | 89 | return out 90 | 91 | 92 | class SGEnvmap(torch.nn.Module): 93 | def __init__(self, numLgtSGs=32, device='cuda'): 94 | super(SGEnvmap, self).__init__() 95 | 96 | self.lgtSGs = nn.Parameter(torch.randn(numLgtSGs, 7).cuda()) # lobe + lambda + mu 97 | self.lgtSGs.data[..., 3:4] *= 100. 98 | self.lgtSGs.data[..., -3:] = 0. 
99 | self.lgtSGs.requires_grad = True 100 | 101 | def forward(self, viewdirs): 102 | lgtSGLobes = self.lgtSGs[..., :3] / (torch.norm(self.lgtSGs[..., :3], dim=-1, keepdim=True) + 1e-7) 103 | lgtSGLambdas = torch.abs(self.lgtSGs[..., 3:4]) # sharpness 104 | lgtSGMus = torch.abs(self.lgtSGs[..., -3:]) # positive values 105 | pred_radiance = lgtSGMus[None] * torch.exp( 106 | lgtSGLambdas[None] * (torch.sum(viewdirs[:, None, :] * lgtSGLobes[None], dim=-1, keepdim=True) - 1.)) 107 | reflection = torch.sum(pred_radiance, dim=1) 108 | 109 | return reflection 110 | 111 | 112 | class ASGRender(torch.nn.Module): 113 | def __init__(self, viewpe=2, featureC=128, num_theta=4, num_phi=8): 114 | super(ASGRender, self).__init__() 115 | 116 | self.num_theta = num_theta 117 | self.num_phi = num_phi 118 | self.ch_normal_dot_viewdir = 1 119 | self.in_mlpC = 2 * viewpe * 3 + 3 + self.num_theta * self.num_phi * 2 + self.ch_normal_dot_viewdir 120 | self.viewpe = viewpe 121 | self.ree_function = RenderingEquationEncoding(self.num_theta, self.num_phi, 'cuda') 122 | 123 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 124 | layer2 = torch.nn.Linear(featureC, featureC) 125 | layer3 = torch.nn.Linear(featureC, 3) 126 | 127 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 128 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 129 | 130 | def reflect(self, viewdir, normal): 131 | out = 2 * (viewdir * normal).sum(dim=-1, keepdim=True) * normal - viewdir 132 | return out 133 | 134 | def safe_normalize(self, x, eps=1e-8): 135 | return x / (torch.norm(x, dim=-1, keepdim=True) + eps) 136 | 137 | def forward(self, pts, viewdirs, features, normal): 138 | asg_params = features.view(-1, self.num_theta, self.num_phi, 4) # [N, num_theta, num_phi, 4] 139 | a, la, mu = torch.split(asg_params, [2, 1, 1], dim=-1) 140 | 141 | reflect_dir = self.safe_normalize(self.reflect(-viewdirs, normal)) 142 | 143 | color_feature = self.ree_function(reflect_dir, a, la, mu) 144 | color_feature = color_feature.view(color_feature.size(0), -1) # [N, num_theta * num_phi * 2] 145 | 146 | normal_dot_viewdir = ((-viewdirs) * normal).sum(dim=-1, keepdim=True) # [N, 1] 147 | indata = [color_feature, normal_dot_viewdir] 148 | if self.viewpe > -1: 149 | indata += [viewdirs] 150 | if self.viewpe > 0: 151 | indata += [positional_encoding(viewdirs, self.viewpe)] 152 | mlp_in = torch.cat(indata, dim=-1) 153 | rgb = self.mlp(mlp_in) 154 | 155 | return rgb 156 | 157 | 158 | class ASGRenderReal(torch.nn.Module): 159 | def __init__(self, viewpe=2, featureC=32, num_theta=2, num_phi=4, is_indoor=False): 160 | super(ASGRenderReal, self).__init__() 161 | 162 | self.num_theta = num_theta 163 | self.num_phi = num_phi 164 | self.in_mlpC = 2 * viewpe * 3 + 3 + self.num_theta * self.num_phi * 2 165 | self.viewpe = viewpe 166 | self.ree_function = RenderingEquationEncoding(self.num_theta, self.num_phi, 'cuda') 167 | 168 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 169 | layer2 = torch.nn.Linear(featureC, featureC) 170 | layer3 = torch.nn.Linear(featureC, 3) 171 | 172 | if is_indoor: 173 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 174 | else: 175 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer3) 176 | 177 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 178 | 179 | def forward(self, pts, viewdirs, features, normal): 180 | asg_params = features.view(-1, self.num_theta, self.num_phi, 4) # [N, num_theta, num_phi, 4] 181 | a, la, mu = torch.split(asg_params,
[2, 1, 1], dim=-1) 182 | 183 | color_feature = self.ree_function(viewdirs, a, la, mu) 184 | color_feature = color_feature.view(color_feature.size(0), -1) # [N, num_theta * num_phi * 2] 185 | 186 | indata = [color_feature] 187 | if self.viewpe > -1: 188 | indata += [viewdirs] 189 | if self.viewpe > 0: 190 | indata += [positional_encoding(viewdirs, self.viewpe)] 191 | mlp_in = torch.cat(indata, dim=-1) 192 | rgb = self.mlp(mlp_in) 193 | 194 | return rgb 195 | 196 | 197 | class SpecularNetwork(nn.Module): 198 | def __init__(self): 199 | super(SpecularNetwork, self).__init__() 200 | 201 | self.asg_feature = 24 202 | self.num_theta = 4 203 | self.num_phi = 8 204 | self.view_pe = 2 205 | self.hidden_feature = 128 206 | self.asg_hidden = self.num_theta * self.num_phi * 4 207 | 208 | self.gaussian_feature = nn.Linear(self.asg_feature, self.asg_hidden) 209 | 210 | self.render_module = ASGRender(self.view_pe, self.hidden_feature, self.num_theta, self.num_phi) 211 | 212 | def forward(self, x, view, normal): 213 | feature = self.gaussian_feature(x) 214 | spec = self.render_module(x, view, feature, normal) 215 | 216 | return spec 217 | 218 | 219 | class SpecularNetworkReal(nn.Module): 220 | def __init__(self, is_indoor=False): 221 | super(SpecularNetworkReal, self).__init__() 222 | 223 | self.asg_feature = 12 224 | self.num_theta = 2 225 | self.num_phi = 4 226 | self.view_pe = 2 227 | self.hidden_feature = 32 228 | self.asg_hidden = self.num_theta * self.num_phi * 4 229 | 230 | self.gaussian_feature = nn.Linear(self.asg_feature, self.asg_hidden) 231 | 232 | self.render_module = ASGRenderReal(self.view_pe, self.hidden_feature, self.num_theta, self.num_phi, is_indoor) 233 | 234 | def forward(self, x, view, normal): 235 | feature = self.gaussian_feature(x) 236 | spec = self.render_module(x, view, feature, normal) 237 | 238 | return spec 239 | 240 | 241 | class AnchorSpecularNetwork(nn.Module): 242 | def __init__(self, feature_dims): 243 | super(AnchorSpecularNetwork, self).__init__() 244 | 245 | self.asg_feature = feature_dims 246 | self.num_theta = 2 247 | self.num_phi = 4 248 | self.asg_hidden = self.num_theta * self.num_phi * 4 249 | 250 | self.gaussian_feature = nn.Linear(self.asg_feature + 3, self.asg_hidden) 251 | self.gaussian_diffuse = nn.Linear(self.asg_feature, 3) 252 | self.gaussian_normal = nn.Linear(self.asg_feature + 3, 3) 253 | 254 | self.render_module = ASGRender(2, 64, num_theta=2, num_phi=4) 255 | 256 | def forward(self, x, view, normal_center, offset): 257 | feature = self.gaussian_feature(torch.cat([x, view], dim=-1)) 258 | diffuse = self.gaussian_diffuse(x) 259 | normal_delta = self.gaussian_normal(torch.cat([x, offset], dim=-1)) 260 | normal = F.normalize(normal_center + normal_delta, dim=-1) 261 | spec = self.render_module(x, view, feature, normal) 262 | rgb = diffuse + spec 263 | 264 | return rgb 265 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory.
equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 30 | return max(saved_iters) 31 | --------------------------------------------------------------------------------
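
For reference, the loss helpers collected in `utils/loss_utils.py` above are typically combined into the photometric objective used by 3DGS-style trainers: an L1 term blended with a D-SSIM term. The sketch below only illustrates that pattern; it is not code from this repository, and the variable names and the `lambda_dssim = 0.2` weight are illustrative assumptions rather than values read from `train.py`.

```python
import torch

from utils.loss_utils import l1_loss, ssim


def photometric_loss(image: torch.Tensor, gt_image: torch.Tensor, lambda_dssim: float = 0.2) -> torch.Tensor:
    """Weighted L1 + D-SSIM loss in the style of 3D Gaussian Splatting (hypothetical helper)."""
    # Both tensors are expected in [0, 1] with the channel dimension at -3, e.g. [3, H, W].
    ll1 = l1_loss(image, gt_image)
    # ssim() returns the mean SSIM (higher is better), so (1 - ssim) acts as the dissimilarity term.
    return (1.0 - lambda_dssim) * ll1 + lambda_dssim * (1.0 - ssim(image, gt_image))
```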