├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── assets ├── pipeline.png ├── results │ ├── D-NeRF │ │ ├── Quantitative.jpg │ │ ├── bouncing.gif │ │ ├── hell.gif │ │ ├── hook.gif │ │ ├── jump.gif │ │ ├── lego.gif │ │ ├── mutant.gif │ │ ├── stand.gif │ │ └── trex.gif │ └── NeRF-DS │ │ └── Quantitative.jpg └── teaser.png ├── convert.py ├── full_eval.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── requirements.txt ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── deform_model.py └── gaussian_model.py ├── train.py ├── train_gui.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── gui_utils.py ├── image_utils.py ├── loss_utils.py ├── pose_utils.py ├── rigid_utils.py ├── sh_utils.py ├── system_utils.py └── time_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | .idea -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/depth-diff-gaussian-rasterization"] 5 | path = submodules/depth-diff-gaussian-rasterization 6 | url = https://github.com/ingra14m/diff-gaussian-rasterization-extentions 7 | branch = filter-norm 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ziyi Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction 2 | 3 | ## [Project page](https://ingra14m.github.io/Deformable-Gaussians/) | [Paper](https://arxiv.org/abs/2309.13101) 4 | 5 | ![Teaser image](assets/teaser.png) 6 | 7 | This repository contains the official implementation associated with the paper "Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction". 8 | 9 | 10 | 11 | ## News 12 | 13 | - **[5/26/2024]** [Lightweight-Deformable-GS](https://github.com/ingra14m/Lightweight-Deformable-GS) has been integrated into this repo. For the original version aligned with the paper, please check the [paper](https://github.com/ingra14m/Deformable-3D-Gaussians/tree/paper) branch. 14 | - **[5/24/2024]** An optimized version, [Lightweight-Deformable-GS](https://github.com/ingra14m/Lightweight-Deformable-GS), has been released. It offers 50% reduced storage, 200% increased FPS, and no decrease in rendering metrics. 15 | - **[2/27/2024]** Deformable-GS has been accepted by CVPR 2024. Another work of ours, [SC-GS](https://yihua7.github.io/SC-GS-web/) (higher quality, fewer points, and faster FPS than vanilla 3D-GS), has also been accepted. See you in Seattle. 16 | - **[11/16/2023]** Full code and real-time viewer released. 17 | - **[11/4/2023]** Updated the computation of LPIPS in metrics.py. Previously, `lpipsPyTorch` was unable to execute on CUDA, prompting us to switch to the `lpips` library (~20x faster). 18 | - **[10/25/2023]** Updated the **real-time viewer** on the project page. Many thanks to @[yihua7](https://github.com/yihua7) for implementing the real-time viewer adapted for Deformable-GS. Also, thanks to @[ashawkey](https://github.com/ashawkey) for releasing the original GUI. 19 | 20 | 21 | 22 | ## Dataset 23 | 24 | In our paper, we use: 25 | 26 | - the synthetic dataset from [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html); 27 | - real-world datasets from [NeRF-DS](https://jokeryan.github.io/projects/nerf-ds/) and [Hyper-NeRF](https://hypernerf.github.io/); 28 | - the dataset in the supplementary materials, which comes from [DeVRF](https://jia-wei-liu.github.io/DeVRF/). 29 | 30 | We organize the datasets as follows: 31 | 32 | ```shell 33 | ├── data 34 | │ | D-NeRF 35 | │ ├── hook 36 | │ ├── standup 37 | │ ├── ... 38 | │ | NeRF-DS 39 | │ ├── as 40 | │ ├── basin 41 | │ ├── ... 42 | │ | HyperNeRF 43 | │ ├── interp 44 | │ ├── misc 45 | │ ├── vrig 46 | ``` 47 | 48 | > I have identified an **inconsistency in the D-NeRF Lego dataset**: the scenes in the training set differ from those in the test set. This discrepancy can be verified by observing the angle of the flipped Lego shovel. To meaningfully evaluate the performance of our method on this dataset, I recommend using the **validation set of the Lego dataset** as the test set.
See more in [D-NeRF dataset used in Deformable-GS](https://github.com/ingra14m/Deformable-3D-Gaussians/releases/tag/v0.1-pre-released). 49 | 50 | 51 | 52 | ## Pipeline 53 | 54 | ![Pipeline image](assets/pipeline.png) 55 | 56 | 57 | 58 | ## Run 59 | 60 | ### Environment 61 | 62 | ```shell 63 | git clone https://github.com/ingra14m/Deformable-3D-Gaussians --recursive 64 | cd Deformable-3D-Gaussians 65 | 66 | conda create -n deformable_gaussian_env python=3.7 67 | conda activate deformable_gaussian_env 68 | 69 | # install pytorch 70 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 71 | 72 | # install dependencies 73 | pip install -r requirements.txt 74 | ``` 75 | 76 | 77 | 78 | ### Train 79 | 80 | **D-NeRF:** 81 | 82 | ```shell 83 | python train.py -s path/to/your/d-nerf/dataset -m output/exp-name --eval --is_blender 84 | ``` 85 | 86 | **NeRF-DS/HyperNeRF:** 87 | 88 | ```shell 89 | python train.py -s path/to/your/real-world/dataset -m output/exp-name --eval --iterations 20000 90 | ``` 91 | 92 | **6DoF Transformation:** 93 | 94 | We have also implemented the 6DoF transformation of 3D-GS, which may improve the metrics but reduces training and inference speed. 95 | 96 | ```shell 97 | # D-NeRF 98 | python train.py -s path/to/your/d-nerf/dataset -m output/exp-name --eval --is_blender --is_6dof 99 | 100 | # NeRF-DS & HyperNeRF 101 | python train.py -s path/to/your/real-world/dataset -m output/exp-name --eval --is_6dof --iterations 20000 102 | ``` 103 | 104 | You can also **train with the GUI:** 105 | 106 | ```shell 107 | python train_gui.py -s path/to/your/dataset -m output/exp-name --eval --is_blender 108 | ``` 109 | 110 | - Click `start` to start training and `stop` to stop training. 111 | - The GUI viewer is still under development; many buttons do not yet have corresponding functions. We plan to: 112 | - [ ] Reload checkpoints from the pre-trained model. 113 | - [ ] Complete the functions of the other vacant buttons in the GUI. 114 | 115 | 116 | 117 | ### Render & Evaluation 118 | 119 | ```shell 120 | python render.py -m output/exp-name --mode render 121 | python metrics.py -m output/exp-name 122 | ``` 123 | 124 | We provide several modes for rendering: 125 | 126 | - `render`: render all the test images 127 | - `time`: time-interpolation tasks for the D-NeRF dataset 128 | - `all`: time and view synthesis tasks for the D-NeRF dataset 129 | - `view`: view synthesis tasks for the D-NeRF dataset 130 | - `original`: time and view synthesis tasks for real-world datasets 131 | 132 | 133 | 134 | ## Results 135 | 136 | ### D-NeRF Dataset 137 | 138 | **Quantitative Results** 139 | 140 | ![Quantitative results on D-NeRF](assets/results/D-NeRF/Quantitative.jpg) 141 | 142 | **Qualitative Results** 143 | 144 | ![bouncing](assets/results/D-NeRF/bouncing.gif) ![hell](assets/results/D-NeRF/hell.gif) ![hook](assets/results/D-NeRF/hook.gif) ![jump](assets/results/D-NeRF/jump.gif) 145 | 146 | ![lego](assets/results/D-NeRF/lego.gif) ![mutant](assets/results/D-NeRF/mutant.gif) ![standup](assets/results/D-NeRF/stand.gif) ![trex](assets/results/D-NeRF/trex.gif) 147 | 148 | **400x400 Resolution** 149 | 150 | | | PSNR | SSIM | LPIPS (VGG) | FPS | Mem (MB) | Num. of GS |
151 | | -------- | ----- | ------ | ----------- | ---- | ----- | -------- | 152 | | bouncing | 41.46 | 0.9958 | 0.0046 | 112 | 13.16 | 55622 | 153 | | hell | 42.11 | 0.9885 | 0.0153 | 375 | 3.72 | 15733 | 154 | | hook | 37.77 | 0.9897 | 0.0103 | 128 | 11.74 | 49613 | 155 | | jump | 39.10 | 0.9930 | 0.0090 | 217 | 6.81 | 28808 | 156 | | mutant | 43.73 | 0.9969 | 0.0029 | 124 | 11.45 | 48423 | 157 | | standup | 45.38 | 0.9967 | 0.0032 | 210 | 5.94 | 25102 | 158 | | trex | 38.40 | 0.9959 | 0.0041 | 85 | 18.6 | 78624 | 159 | | Average | 41.14 | 0.9938 | 0.0070 | 179 | 10.20 | 43132 | 160 | 161 | ### NeRF-DS Dataset 162 | 163 | ![Quantitative results on NeRF-DS](assets/results/NeRF-DS/Quantitative.jpg) 164 | 165 | See more visualizations on our [project page](https://ingra14m.github.io/Deformable-Gaussians/). 166 | 167 | 168 | 169 | ### HyperNeRF Dataset 170 | 171 | Since the **camera poses** in HyperNeRF are less precise than those in NeRF-DS, we use HyperNeRF only for partial visualization and for showing failure cases, and do not include it in the calculation of quantitative metrics. The results on the HyperNeRF dataset can be viewed on the [project page](https://ingra14m.github.io/Deformable-Gaussians/). 172 | 173 | 174 | 175 | ### Real-Time Viewer 176 | 177 | https://github.com/ingra14m/Deformable-3D-Gaussians/assets/63096187/ec26d0b9-c126-4e23-b773-dcedcf386f36 178 | 179 | 180 | 181 | ## Acknowledgments 182 | 183 | We sincerely thank the authors of [3D-GS](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/), [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html), [HyperNeRF](https://hypernerf.github.io/), [NeRF-DS](https://jokeryan.github.io/projects/nerf-ds/), and [DeVRF](https://jia-wei-liu.github.io/DeVRF/), whose code and datasets were used in our work. We thank [Zihao Wang](https://github.com/Alen-Wong) for debugging in the early stage, which kept this work from sinking. We also thank the reviewers and AC for not being influenced by PR and for evaluating our work fairly. This work was mainly supported by ByteDance MMLab. 184 | 185 | 186 | 187 | 188 | ## BibTex 189 | 190 | ``` 191 | @article{yang2023deformable3dgs, 192 | title={Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction}, 193 | author={Yang, Ziyi and Gao, Xinyu and Zhou, Wen and Jiao, Shaohui and Zhang, Yuqing and Jin, Xiaogang}, 194 | journal={arXiv preprint arXiv:2309.13101}, 195 | year={2023} 196 | } 197 | ``` 198 | 199 | We also thank the authors of [3D Gaussians](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) for their excellent code; please consider also citing their repository: 200 | 201 | ``` 202 | @Article{kerbl3Dgaussians, 203 | author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George}, 204 | title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering}, 205 | journal = {ACM Transactions on Graphics}, 206 | number = {4}, 207 | volume = {42}, 208 | month = {July}, 209 | year = {2023}, 210 | url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/} 211 | } 212 | ``` 213 | 214 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 34 | else: 35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 36 | else: 37 | if t == bool: 38 | group.add_argument("--" + key, default=value, action="store_true") 39 | else: 40 | group.add_argument("--" + key, default=value, type=t) 41 | 42 | def extract(self, args): 43 | group = GroupParams() 44 | for arg in vars(args).items(): 45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 46 | setattr(group, arg[0], arg[1]) 47 | return group 48 | 49 | 50 | class ModelParams(ParamGroup): 51 | def __init__(self, parser, sentinel=False): 52 | self.sh_degree = 3 53 | self._source_path = "" 54 | self._model_path = "" 55 | self._images = "images" 56 | self._resolution = -1 57 | self._white_background = False 58 | self.data_device = "cuda" 59 | self.eval = False 60 | self.load2gpu_on_the_fly = False 61 | self.is_blender = False 62 | self.is_6dof = False 63 | super().__init__(parser, "Loading Parameters", sentinel) 64 | 65 | def extract(self, args): 66 | g = super().extract(args) 67 | g.source_path = os.path.abspath(g.source_path) 68 | return g 69 | 70 | 71 | class PipelineParams(ParamGroup): 72 | def __init__(self, parser): 73 | self.convert_SHs_python = False 74 | self.compute_cov3D_python = False 75 | self.debug = False 76 | super().__init__(parser, "Pipeline Parameters") 77 | 78 | 79 | class OptimizationParams(ParamGroup): 80 | def __init__(self, parser): 81 | self.iterations = 40_000 82 | self.warm_up = 3_000 83 | self.position_lr_init = 0.00016 84 | self.position_lr_final = 0.0000016 85 | self.position_lr_delay_mult = 0.01 86 | self.position_lr_max_steps = 30_000 87 | self.deform_lr_max_steps = 40_000 88 | self.feature_lr = 0.0025 89 | self.opacity_lr = 0.05 90 | self.scaling_lr = 0.001 91 | self.rotation_lr = 0.001 92 | self.percent_dense = 0.01 93 | self.lambda_dssim = 0.2 94 | self.densification_interval = 100 95 | self.opacity_reset_interval = 3000 96 | self.densify_from_iter = 500 97 | self.densify_until_iter = 15_000 98 | self.densify_grad_threshold = 0.0007 99 | super().__init__(parser, "Optimization Parameters") 100 | 101 | 102 | def get_combined_args(parser: ArgumentParser): 103 | cmdlne_string = sys.argv[1:] 104 | cfgfile_string = "Namespace()" 105 | args_cmdline = parser.parse_args(cmdlne_string) 106 | 107 | try: 108 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 109 | print("Looking for config file in", cfgfilepath) 110 | with open(cfgfilepath) as cfg_file: 111 | print("Config file found: {}".format(cfgfilepath)) 112 | cfgfile_string = cfg_file.read() 113 | except TypeError: 114 | print("Config file not found at") 115 | pass 116 | args_cfgfile = eval(cfgfile_string) 117 | 118 | merged_dict = vars(args_cfgfile).copy() 119 | for k, v in vars(args_cmdline).items(): 120 | if v != None: 121 | merged_dict[k] = v 122 | return Namespace(**merged_dict) 123 | 
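The `ParamGroup` classes above register their attributes as CLI flags (a leading underscore also adds a one-letter shorthand such as `-s` or `-m`), and `get_combined_args` merges the `cfg_args` Namespace saved in a trained model directory with the current command line. Below is a minimal, illustrative sketch of how these groups are typically wired together, mirroring the pattern used by `render.py` later in this repository; the script itself is not part of the repo and only demonstrates the API.

```python
# Illustrative sketch only (not a repo file); assumes it is run from the
# repository root so that the `arguments` package is importable.
from argparse import ArgumentParser

from arguments import ModelParams, PipelineParams, OptimizationParams, get_combined_args

parser = ArgumentParser(description="Example of wiring the parameter groups")
model = ModelParams(parser, sentinel=True)  # --source_path/-s, --model_path/-m, --is_blender, --is_6dof, ...
pipeline = PipelineParams(parser)           # --convert_SHs_python, --compute_cov3D_python, --debug
opt = OptimizationParams(parser)            # --iterations, --warm_up, learning rates, densification options

# get_combined_args expects -m/--model_path to point at a trained model folder
# containing a `cfg_args` file; values stored there are overridden by anything
# passed on the command line.
args = get_combined_args(parser)

dataset = model.extract(args)   # GroupParams; source_path is resolved to an absolute path
pipe = pipeline.extract(args)
hyper = opt.extract(args)
print(dataset.model_path, hyper.iterations)
```

The extracted `GroupParams` objects are what the scripts pass around, e.g. `render_sets(model.extract(args), args.iteration, pipeline.extract(args), ...)` at the bottom of `render.py`.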
-------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/pipeline.png -------------------------------------------------------------------------------- /assets/results/D-NeRF/Quantitative.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/Quantitative.jpg -------------------------------------------------------------------------------- /assets/results/D-NeRF/bouncing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/bouncing.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/hell.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/hell.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/hook.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/hook.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/jump.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/jump.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/lego.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/lego.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/mutant.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/mutant.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/stand.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/stand.gif -------------------------------------------------------------------------------- /assets/results/D-NeRF/trex.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/trex.gif -------------------------------------------------------------------------------- /assets/results/NeRF-DS/Quantitative.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/NeRF-DS/Quantitative.jpg -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/teaser.png -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | import shutil 15 | 16 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 17 | parser = ArgumentParser("Colmap converter") 18 | parser.add_argument("--no_gpu", action='store_true') 19 | parser.add_argument("--skip_matching", action='store_true') 20 | parser.add_argument("--source_path", "-s", required=True, type=str) 21 | parser.add_argument("--camera", default="OPENCV", type=str) 22 | parser.add_argument("--colmap_executable", default="", type=str) 23 | parser.add_argument("--resize", action="store_true") 24 | parser.add_argument("--magick_executable", default="", type=str) 25 | args = parser.parse_args() 26 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 27 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 28 | use_gpu = 1 if not args.no_gpu else 0 29 | 30 | if not args.skip_matching: 31 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 32 | 33 | ## Feature extraction 34 | os.system(colmap_command + " feature_extractor "\ 35 | "--database_path " + args.source_path + "/distorted/database.db \ 36 | --image_path " + args.source_path + "/input \ 37 | --ImageReader.single_camera 1 \ 38 | --ImageReader.camera_model " + args.camera + " \ 39 | --SiftExtraction.use_gpu " + str(use_gpu)) 40 | 41 | ## Feature matching 42 | os.system(colmap_command + " exhaustive_matcher \ 43 | --database_path " + args.source_path + "/distorted/database.db \ 44 | --SiftMatching.use_gpu " + str(use_gpu)) 45 | 46 | ### Bundle adjustment 47 | # The default Mapper tolerance is unnecessarily large, 48 | # decreasing it speeds up bundle adjustment steps. 49 | os.system(colmap_command + " mapper \ 50 | --database_path " + args.source_path + "/distorted/database.db \ 51 | --image_path " + args.source_path + "/input \ 52 | --output_path " + args.source_path + "/distorted/sparse \ 53 | --Mapper.ba_global_function_tolerance=0.000001") 54 | 55 | ### Image undistortion 56 | ## We need to undistort our images into ideal pinhole intrinsics. 
57 | os.system(colmap_command + " image_undistorter \ 58 | --image_path " + args.source_path + "/input \ 59 | --input_path " + args.source_path + "/distorted/sparse/0 \ 60 | --output_path " + args.source_path + "\ 61 | --output_type COLMAP") 62 | 63 | files = os.listdir(args.source_path + "/sparse") 64 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 65 | # Copy each file from the source directory to the destination directory 66 | for file in files: 67 | if file == '0': 68 | continue 69 | source_file = os.path.join(args.source_path, "sparse", file) 70 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 71 | shutil.move(source_file, destination_file) 72 | 73 | if(args.resize): 74 | print("Copying and resizing...") 75 | 76 | # Resize images. 77 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 78 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 79 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 80 | # Get the list of files in the source directory 81 | files = os.listdir(args.source_path + "/images") 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | source_file = os.path.join(args.source_path, "images", file) 85 | 86 | destination_file = os.path.join(args.source_path, "images_2", file) 87 | shutil.copy2(source_file, destination_file) 88 | os.system(magick_command + " mogrify -resize 50% " + destination_file) 89 | 90 | destination_file = os.path.join(args.source_path, "images_4", file) 91 | shutil.copy2(source_file, destination_file) 92 | os.system(magick_command + " mogrify -resize 25% " + destination_file) 93 | 94 | destination_file = os.path.join(args.source_path, "images_8", file) 95 | shutil.copy2(source_file, destination_file) 96 | os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 97 | 98 | print("Done.") -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 44 | for scene in mipnerf360_indoor_scenes: 45 | source = args.mipnerf360 + "/" + scene 46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 47 | for scene in tanks_and_temples_scenes: 48 | source = args.tanksandtemples + "/" + scene 49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 50 | for scene in deep_blending_scenes: 51 | source = args.deepblending + "/" + scene 52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 53 | 54 | if not args.skip_rendering: 55 | all_sources = [] 56 | for scene in mipnerf360_outdoor_scenes: 57 | all_sources.append(args.mipnerf360 + "/" + scene) 58 | for scene in mipnerf360_indoor_scenes: 59 | all_sources.append(args.mipnerf360 + "/" + scene) 60 | for scene in tanks_and_temples_scenes: 61 | all_sources.append(args.tanksandtemples + "/" + scene) 62 | for scene in deep_blending_scenes: 63 | all_sources.append(args.deepblending + "/" + scene) 64 | 65 | common_args = " --quiet --eval --skip_train" 66 | for scene, source in zip(all_scenes, all_sources): 67 | os.system( 68 | "python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 69 | os.system( 70 | "python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 71 | 72 | if not args.skip_metrics: 73 | scenes_string = "" 74 | for scene in all_scenes: 75 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 76 | 77 | os.system("python metrics.py -m " + scenes_string) 78 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 
4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | from utils.rigid_utils import from_homogenous, to_homogenous 18 | 19 | 20 | def quaternion_multiply(q1, q2): 21 | w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] 22 | w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] 23 | 24 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 25 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 26 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 27 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 28 | 29 | return torch.stack((w, x, y, z), dim=-1) 30 | 31 | 32 | def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, d_xyz, d_rotation, d_scaling, is_6dof=False, 33 | scaling_modifier=1.0, override_color=None): 34 | """ 35 | Render the scene. 36 | 37 | Background tensor (bg_color) must be on GPU! 38 | """ 39 | 40 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 41 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 42 | screenspace_points_densify = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 43 | try: 44 | screenspace_points.retain_grad() 45 | screenspace_points_densify.retain_grad() 46 | except: 47 | pass 48 | 49 | # Set up rasterization configuration 50 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 51 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 52 | 53 | raster_settings = GaussianRasterizationSettings( 54 | image_height=int(viewpoint_camera.image_height), 55 | image_width=int(viewpoint_camera.image_width), 56 | tanfovx=tanfovx, 57 | tanfovy=tanfovy, 58 | bg=bg_color, 59 | scale_modifier=scaling_modifier, 60 | viewmatrix=viewpoint_camera.world_view_transform, 61 | projmatrix=viewpoint_camera.full_proj_transform, 62 | sh_degree=pc.active_sh_degree, 63 | campos=viewpoint_camera.camera_center, 64 | prefiltered=False, 65 | debug=pipe.debug, 66 | ) 67 | 68 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 69 | 70 | if is_6dof: 71 | if torch.is_tensor(d_xyz) is False: 72 | means3D = pc.get_xyz 73 | else: 74 | means3D = from_homogenous( 75 | torch.bmm(d_xyz, to_homogenous(pc.get_xyz).unsqueeze(-1)).squeeze(-1)) 76 | else: 77 | means3D = pc.get_xyz + d_xyz 78 | opacity = pc.get_opacity 79 | 80 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 81 | # scaling / rotation by the rasterizer. 82 | scales = None 83 | rotations = None 84 | cov3D_precomp = None 85 | if pipe.compute_cov3D_python: 86 | cov3D_precomp = pc.get_covariance(scaling_modifier) 87 | else: 88 | scales = pc.get_scaling + d_scaling 89 | rotations = pc.get_rotation + d_rotation 90 | 91 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 92 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
93 | shs = None 94 | colors_precomp = None 95 | if colors_precomp is None: 96 | if pipe.convert_SHs_python: 97 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2) 98 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 99 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 100 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 101 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 102 | else: 103 | shs = pc.get_features 104 | else: 105 | colors_precomp = override_color 106 | 107 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 108 | rendered_image, radii, depth = rasterizer( 109 | means3D=means3D, 110 | means2D=screenspace_points, 111 | means2D_densify=screenspace_points_densify, 112 | shs=shs, 113 | colors_precomp=colors_precomp, 114 | opacities=opacity, 115 | scales=scales, 116 | rotations=rotations, 117 | cov3D_precomp=cov3D_precomp) 118 | 119 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 120 | # They will be excluded from value updates used in the splitting criteria. 121 | return {"render": rendered_image, 122 | "viewspace_points": screenspace_points, 123 | "viewspace_points_densify": screenspace_points_densify, 124 | "visibility_filter": radii > 0, 125 | "radii": radii, 126 | "depth": depth} 127 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | 27 | def init(wish_host, wish_port): 28 | global host, port, listener 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | 35 | 36 | def try_connect(): 37 | global conn, addr, listener 38 | try: 39 | conn, addr = listener.accept() 40 | print(f"\nConnected by {addr}") 41 | conn.settimeout(None) 42 | except Exception as inst: 43 | pass 44 | 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, 'little') 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | 54 | def send(message_bytes, verify): 55 | global conn 56 | if message_bytes != None: 57 | conn.sendall(message_bytes) 58 | conn.sendall(len(verify).to_bytes(4, 'little')) 59 | conn.sendall(bytes(verify, 'ascii')) 60 | 61 | 62 | def receive(): 63 | message = read() 64 | 65 | width = message["resolution_x"] 66 | height = message["resolution_y"] 67 | 68 | if width != 0 and height != 0: 69 | try: 70 | do_training = bool(message["train"]) 71 | fovy = message["fov_y"] 72 | fovx = message["fov_x"] 73 | znear = message["z_near"] 74 | zfar = message["z_far"] 75 | do_shs_python = bool(message["shs_python"]) 76 | do_rot_scale_python = bool(message["rot_scale_python"]) 77 | keep_alive = bool(message["keep_alive"]) 78 | scaling_modifier = message["scaling_modifier"] 79 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 80 | world_view_transform[:, 1] = -world_view_transform[:, 1] 81 | world_view_transform[:, 2] = -world_view_transform[:, 2] 82 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 83 | full_proj_transform[:, 1] = -full_proj_transform[:, 1] 84 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 85 | except Exception as e: 86 | print("") 87 | traceback.print_exc() 88 | raise e 89 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 90 | else: 91 | return None, None, None, None, None, None 92 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | 18 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | # from lpipsPyTorch import lpips 19 | import lpips 20 | import json 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser 24 | 25 | 26 | def readImages(renders_dir, gt_dir): 27 | renders = [] 28 | gts = [] 29 | image_names = [] 30 | for fname in os.listdir(renders_dir): 31 | render = Image.open(renders_dir / fname) 32 | gt = Image.open(gt_dir / fname) 33 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 34 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 35 | image_names.append(fname) 36 | return renders, gts, image_names 37 | 38 | 39 | def evaluate(model_paths): 40 | full_dict = {} 41 | per_view_dict = {} 42 | full_dict_polytopeonly = {} 43 | per_view_dict_polytopeonly = {} 44 | print("") 45 | 46 | for scene_dir in model_paths: 47 | try: 48 | print("Scene:", scene_dir) 49 | full_dict[scene_dir] = {} 50 | per_view_dict[scene_dir] = {} 51 | full_dict_polytopeonly[scene_dir] = {} 52 | per_view_dict_polytopeonly[scene_dir] = {} 53 | 54 | test_dir = Path(scene_dir) / "test" 55 | 56 | for method in os.listdir(test_dir): 57 | if not method.startswith("ours"): 58 | continue 59 | print("Method:", method) 60 | 61 | full_dict[scene_dir][method] = {} 62 | per_view_dict[scene_dir][method] = {} 63 | full_dict_polytopeonly[scene_dir][method] = {} 64 | per_view_dict_polytopeonly[scene_dir][method] = {} 65 | 66 | method_dir = test_dir / method 67 | gt_dir = method_dir / "gt" 68 | renders_dir = method_dir / "renders" 69 | renders, gts, image_names = readImages(renders_dir, gt_dir) 70 | 71 | ssims = [] 72 | psnrs = [] 73 | lpipss = [] 74 | 75 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 76 | ssims.append(ssim(renders[idx], gts[idx])) 77 | psnrs.append(psnr(renders[idx], gts[idx])) 78 | lpipss.append(lpips_fn(renders[idx], gts[idx]).detach()) 79 | 80 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 81 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 82 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 83 | print("") 84 | 85 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 86 | "PSNR": torch.tensor(psnrs).mean().item(), 87 | "LPIPS": torch.tensor(lpipss).mean().item()}) 88 | per_view_dict[scene_dir][method].update( 89 | {"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 90 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 91 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 92 | 93 | with open(scene_dir + "/results.json", 'w') as fp: 94 | json.dump(full_dict[scene_dir], fp, indent=True) 95 | with open(scene_dir + "/per_view.json", 'w') as fp: 96 | json.dump(per_view_dict[scene_dir], fp, indent=True) 97 | except: 98 | print("Unable to compute metrics for model", scene_dir) 99 | 100 | 101 | if __name__ == "__main__": 102 | device = torch.device("cuda:0") 103 | torch.cuda.set_device(device) 104 | lpips_fn = lpips.LPIPS(net='vgg').to(device) 105 | 106 | # Set up command line argument parser 107 | parser = ArgumentParser(description="Training script parameters") 108 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", 
type=str, default=[]) 109 | args = parser.parse_args() 110 | evaluate(args.model_paths) 111 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene, DeformModel 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from utils.pose_utils import pose_spherical, render_wander_path 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args 23 | from gaussian_renderer import GaussianModel 24 | import imageio 25 | import numpy as np 26 | import time 27 | 28 | 29 | def render_set(model_path, load2gpu_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, deform): 30 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 31 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 32 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 33 | 34 | makedirs(render_path, exist_ok=True) 35 | makedirs(gts_path, exist_ok=True) 36 | makedirs(depth_path, exist_ok=True) 37 | 38 | t_list = [] 39 | 40 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 41 | if load2gpu_on_the_fly: 42 | view.load2device() 43 | fid = view.fid 44 | xyz = gaussians.get_xyz 45 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 46 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 47 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 48 | rendering = results["render"] 49 | depth = results["depth"] 50 | depth = depth / (depth.max() + 1e-5) 51 | 52 | gt = view.original_image[0:3, :, :] 53 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 54 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 55 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 56 | 57 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 58 | fid = view.fid 59 | xyz = gaussians.get_xyz 60 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 61 | 62 | torch.cuda.synchronize() 63 | t_start = time.time() 64 | 65 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 66 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 67 | 68 | torch.cuda.synchronize() 69 | t_end = time.time() 70 | t_list.append(t_end - t_start) 71 | 72 | t = np.array(t_list[5:]) 73 | fps = 1.0 / t.mean() 74 | print(f'Test FPS: \033[1;35m{fps:.5f}\033[0m, Num. 
of GS: {xyz.shape[0]}') 75 | 76 | 77 | def interpolate_time(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, deform): 78 | render_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "renders") 79 | depth_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "depth") 80 | 81 | makedirs(render_path, exist_ok=True) 82 | makedirs(depth_path, exist_ok=True) 83 | 84 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 85 | 86 | frame = 150 87 | idx = torch.randint(0, len(views), (1,)).item() 88 | view = views[idx] 89 | renderings = [] 90 | for t in tqdm(range(0, frame, 1), desc="Rendering progress"): 91 | fid = torch.Tensor([t / (frame - 1)]).cuda() 92 | xyz = gaussians.get_xyz 93 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 94 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 95 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 96 | rendering = results["render"] 97 | renderings.append(to8b(rendering.cpu().numpy())) 98 | depth = results["depth"] 99 | depth = depth / (depth.max() + 1e-5) 100 | 101 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(t) + ".png")) 102 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(t) + ".png")) 103 | 104 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 105 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 106 | 107 | 108 | def interpolate_view(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, timer): 109 | render_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "renders") 110 | depth_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "depth") 111 | # acc_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "acc") 112 | 113 | makedirs(render_path, exist_ok=True) 114 | makedirs(depth_path, exist_ok=True) 115 | # makedirs(acc_path, exist_ok=True) 116 | 117 | frame = 150 118 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 119 | 120 | idx = torch.randint(0, len(views), (1,)).item() 121 | view = views[idx] # Choose a specific time for rendering 122 | 123 | render_poses = torch.stack(render_wander_path(view), 0) 124 | # render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 125 | # 0) 126 | 127 | renderings = [] 128 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 129 | fid = view.fid 130 | 131 | matrix = np.linalg.inv(np.array(pose)) 132 | R = -np.transpose(matrix[:3, :3]) 133 | R[:, 0] = -R[:, 0] 134 | T = -matrix[:3, 3] 135 | 136 | view.reset_extrinsic(R, T) 137 | 138 | xyz = gaussians.get_xyz 139 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 140 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input) 141 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 142 | rendering = results["render"] 143 | renderings.append(to8b(rendering.cpu().numpy())) 144 | depth = results["depth"] 145 | depth = depth / (depth.max() + 1e-5) 146 | # acc = results["acc"] 147 | 148 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 149 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 150 | # 
torchvision.utils.save_image(acc, os.path.join(acc_path, '{0:05d}'.format(i) + ".png")) 151 | 152 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 153 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 154 | 155 | 156 | def interpolate_all(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, deform): 157 | render_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "renders") 158 | depth_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "depth") 159 | 160 | makedirs(render_path, exist_ok=True) 161 | makedirs(depth_path, exist_ok=True) 162 | 163 | frame = 150 164 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 165 | 0) 166 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 167 | 168 | idx = torch.randint(0, len(views), (1,)).item() 169 | view = views[idx] # Choose a specific time for rendering 170 | 171 | renderings = [] 172 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 173 | fid = torch.Tensor([i / (frame - 1)]).cuda() 174 | 175 | matrix = np.linalg.inv(np.array(pose)) 176 | R = -np.transpose(matrix[:3, :3]) 177 | R[:, 0] = -R[:, 0] 178 | T = -matrix[:3, 3] 179 | 180 | view.reset_extrinsic(R, T) 181 | 182 | xyz = gaussians.get_xyz 183 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 184 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 185 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 186 | rendering = results["render"] 187 | renderings.append(to8b(rendering.cpu().numpy())) 188 | depth = results["depth"] 189 | depth = depth / (depth.max() + 1e-5) 190 | 191 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 192 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 193 | 194 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 195 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 196 | 197 | 198 | def interpolate_poses(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, timer): 199 | render_path = os.path.join(model_path, name, "interpolate_pose_{}".format(iteration), "renders") 200 | depth_path = os.path.join(model_path, name, "interpolate_pose_{}".format(iteration), "depth") 201 | 202 | makedirs(render_path, exist_ok=True) 203 | makedirs(depth_path, exist_ok=True) 204 | # makedirs(acc_path, exist_ok=True) 205 | frame = 520 206 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 207 | 208 | idx = torch.randint(0, len(views), (1,)).item() 209 | view_begin = views[0] # Choose a specific time for rendering 210 | view_end = views[-1] 211 | view = views[idx] 212 | 213 | R_begin = view_begin.R 214 | R_end = view_end.R 215 | t_begin = view_begin.T 216 | t_end = view_end.T 217 | 218 | renderings = [] 219 | for i in tqdm(range(frame), desc="Rendering progress"): 220 | fid = view.fid 221 | 222 | ratio = i / (frame - 1) 223 | 224 | R_cur = (1 - ratio) * R_begin + ratio * R_end 225 | T_cur = (1 - ratio) * t_begin + ratio * t_end 226 | 227 | view.reset_extrinsic(R_cur, T_cur) 228 | 229 | xyz = gaussians.get_xyz 230 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 231 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input) 232 | 233 | results = render(view, 
gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 234 | rendering = results["render"] 235 | renderings.append(to8b(rendering.cpu().numpy())) 236 | depth = results["depth"] 237 | depth = depth / (depth.max() + 1e-5) 238 | 239 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 240 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 241 | 242 | 243 | def interpolate_view_original(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, 244 | timer): 245 | render_path = os.path.join(model_path, name, "interpolate_hyper_view_{}".format(iteration), "renders") 246 | depth_path = os.path.join(model_path, name, "interpolate_hyper_view_{}".format(iteration), "depth") 247 | # acc_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "acc") 248 | 249 | makedirs(render_path, exist_ok=True) 250 | makedirs(depth_path, exist_ok=True) 251 | 252 | frame = 1000 253 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 254 | 255 | R = [] 256 | T = [] 257 | for view in views: 258 | R.append(view.R) 259 | T.append(view.T) 260 | 261 | view = views[0] 262 | renderings = [] 263 | for i in tqdm(range(frame), desc="Rendering progress"): 264 | fid = torch.Tensor([i / (frame - 1)]).cuda() 265 | 266 | query_idx = i / frame * len(views) 267 | begin_idx = int(np.floor(query_idx)) 268 | end_idx = int(np.ceil(query_idx)) 269 | if end_idx == len(views): 270 | break 271 | view_begin = views[begin_idx] 272 | view_end = views[end_idx] 273 | R_begin = view_begin.R 274 | R_end = view_end.R 275 | t_begin = view_begin.T 276 | t_end = view_end.T 277 | 278 | ratio = query_idx - begin_idx 279 | 280 | R_cur = (1 - ratio) * R_begin + ratio * R_end 281 | T_cur = (1 - ratio) * t_begin + ratio * t_end 282 | 283 | view.reset_extrinsic(R_cur, T_cur) 284 | 285 | xyz = gaussians.get_xyz 286 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 287 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input) 288 | 289 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof) 290 | rendering = results["render"] 291 | renderings.append(to8b(rendering.cpu().numpy())) 292 | depth = results["depth"] 293 | depth = depth / (depth.max() + 1e-5) 294 | 295 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 296 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 297 | 298 | 299 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool, 300 | mode: str): 301 | with torch.no_grad(): 302 | gaussians = GaussianModel(dataset.sh_degree) 303 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 304 | deform = DeformModel(dataset.is_blender, dataset.is_6dof) 305 | deform.load_weights(dataset.model_path) 306 | 307 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 308 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 309 | 310 | if mode == "render": 311 | render_func = render_set 312 | elif mode == "time": 313 | render_func = interpolate_time 314 | elif mode == "view": 315 | render_func = interpolate_view 316 | elif mode == "pose": 317 | render_func = interpolate_poses 318 | elif mode == "original": 319 | render_func = interpolate_view_original 320 | else: 321 | render_func = interpolate_all 322 | 323 | if not skip_train: 324 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, dataset.is_6dof, 
"train", scene.loaded_iter, 325 | scene.getTrainCameras(), gaussians, pipeline, 326 | background, deform) 327 | 328 | if not skip_test: 329 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, dataset.is_6dof, "test", scene.loaded_iter, 330 | scene.getTestCameras(), gaussians, pipeline, 331 | background, deform) 332 | 333 | 334 | if __name__ == "__main__": 335 | # Set up command line argument parser 336 | parser = ArgumentParser(description="Testing script parameters") 337 | model = ModelParams(parser, sentinel=True) 338 | pipeline = PipelineParams(parser) 339 | parser.add_argument("--iteration", default=-1, type=int) 340 | parser.add_argument("--skip_train", action="store_true") 341 | parser.add_argument("--skip_test", action="store_true") 342 | parser.add_argument("--quiet", action="store_true") 343 | parser.add_argument("--mode", default='render', choices=['render', 'time', 'view', 'all', 'pose', 'original']) 344 | args = get_combined_args(parser) 345 | print("Rendering " + args.model_path) 346 | 347 | # Initialize system state (RNG) 348 | safe_state(args.quiet) 349 | 350 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.mode) 351 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | submodules/depth-diff-gaussian-rasterization 2 | submodules/simple-knn 3 | plyfile==0.8.1 4 | tqdm 5 | imageio==2.27.0 6 | opencv-python 7 | imageio-ffmpeg 8 | scipy 9 | dearpygui 10 | lpips 11 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from scene.deform_model import DeformModel 19 | from arguments import ModelParams 20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 21 | 22 | 23 | class Scene: 24 | gaussians: GaussianModel 25 | 26 | def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, 27 | resolution_scales=[1.0]): 28 | """b 29 | :param path: Path to colmap scene main folder. 
30 | """ 31 | self.model_path = args.model_path 32 | self.loaded_iter = None 33 | self.gaussians = gaussians 34 | 35 | if load_iteration: 36 | if load_iteration == -1: 37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 38 | else: 39 | self.loaded_iter = load_iteration 40 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 41 | 42 | self.train_cameras = {} 43 | self.test_cameras = {} 44 | 45 | if os.path.exists(os.path.join(args.source_path, "sparse")): 46 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 47 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 48 | print("Found transforms_train.json file, assuming Blender data set!") 49 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 50 | elif os.path.exists(os.path.join(args.source_path, "cameras_sphere.npz")): 51 | print("Found cameras_sphere.npz file, assuming DTU data set!") 52 | scene_info = sceneLoadTypeCallbacks["DTU"](args.source_path, "cameras_sphere.npz", "cameras_sphere.npz") 53 | elif os.path.exists(os.path.join(args.source_path, "dataset.json")): 54 | print("Found dataset.json file, assuming Nerfies data set!") 55 | scene_info = sceneLoadTypeCallbacks["nerfies"](args.source_path, args.eval) 56 | elif os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")): 57 | print("Found calibration_full.json, assuming Neu3D data set!") 58 | scene_info = sceneLoadTypeCallbacks["plenopticVideo"](args.source_path, args.eval, 24) 59 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")): 60 | print("Found calibration_full.json, assuming Dynamic-360 data set!") 61 | scene_info = sceneLoadTypeCallbacks["dynamic360"](args.source_path) 62 | else: 63 | assert False, "Could not recognize scene type!" 
64 | 65 | if not self.loaded_iter: 66 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 67 | 'wb') as dest_file: 68 | dest_file.write(src_file.read()) 69 | json_cams = [] 70 | camlist = [] 71 | if scene_info.test_cameras: 72 | camlist.extend(scene_info.test_cameras) 73 | if scene_info.train_cameras: 74 | camlist.extend(scene_info.train_cameras) 75 | for id, cam in enumerate(camlist): 76 | json_cams.append(camera_to_JSON(id, cam)) 77 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 78 | json.dump(json_cams, file) 79 | 80 | if shuffle: 81 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 82 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 83 | 84 | self.cameras_extent = scene_info.nerf_normalization["radius"] 85 | 86 | for resolution_scale in resolution_scales: 87 | print("Loading Training Cameras") 88 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, 89 | args) 90 | print("Loading Test Cameras") 91 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, 92 | args) 93 | 94 | if self.loaded_iter: 95 | self.gaussians.load_ply(os.path.join(self.model_path, 96 | "point_cloud", 97 | "iteration_" + str(self.loaded_iter), 98 | "point_cloud.ply"), 99 | og_number_points=len(scene_info.point_cloud.points)) 100 | else: 101 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 102 | 103 | def save(self, iteration): 104 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 105 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 106 | 107 | def getTrainCameras(self, scale=1.0): 108 | return self.train_cameras[scale] 109 | 110 | def getTestCameras(self, scale=1.0): 111 | return self.test_cameras[scale] 112 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | 18 | class Camera(nn.Module): 19 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", fid=None, depth=None): 21 | super(Camera, self).__init__() 22 | 23 | self.uid = uid 24 | self.colmap_id = colmap_id 25 | self.R = R 26 | self.T = T 27 | self.FoVx = FoVx 28 | self.FoVy = FoVy 29 | self.image_name = image_name 30 | 31 | try: 32 | self.data_device = torch.device(data_device) 33 | except Exception as e: 34 | print(e) 35 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device") 36 | self.data_device = torch.device("cuda") 37 | 38 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 39 | self.fid = torch.Tensor(np.array([fid])).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | self.depth = torch.Tensor(depth).to(self.data_device) if depth is not None else None 43 | 44 | if gt_alpha_mask is not None: 45 | self.original_image *= gt_alpha_mask.to(self.data_device) 46 | else: 47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 48 | 49 | self.zfar = 100.0 50 | self.znear = 0.01 51 | 52 | self.trans = trans 53 | self.scale = scale 54 | 55 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to( 56 | self.data_device) 57 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, 58 | fovY=self.FoVy).transpose(0, 1).to(self.data_device) 59 | self.full_proj_transform = ( 60 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 61 | self.camera_center = self.world_view_transform.inverse()[3, :3] 62 | 63 | def reset_extrinsic(self, R, T): 64 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, self.scale)).transpose(0, 1).cuda() 65 | self.full_proj_transform = ( 66 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 67 | self.camera_center = self.world_view_transform.inverse()[3, :3] 68 | 69 | def load2device(self, data_device='cuda'): 70 | self.original_image = self.original_image.to(data_device) 71 | self.world_view_transform = self.world_view_transform.to(data_device) 72 | self.projection_matrix = self.projection_matrix.to(data_device) 73 | self.full_proj_transform = self.full_proj_transform.to(data_device) 74 | self.camera_center = self.camera_center.to(data_device) 75 | self.fid = self.fid.to(data_device) 76 | 77 | 78 | class MiniCam: 79 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 80 | self.image_width = width 81 | self.image_height = height 82 | self.FoVy = fovy 83 | self.FoVx = fovx 84 | self.znear = znear 85 | self.zfar = zfar 86 | self.world_view_transform = world_view_transform 87 | self.full_proj_transform = full_proj_transform 88 | view_inv = torch.inverse(self.world_view_transform) 89 | self.camera_center = view_inv[3][:3] 90 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) 54 | 55 | 56 | def rotmat2qvec(R): 57 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 58 | K = np.array([ 59 | [Rxx - Ryy - Rzz, 0, 0, 0], 60 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 61 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 62 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 63 | eigvals, eigvecs = np.linalg.eigh(K) 64 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 65 | if qvec[0] < 0: 66 | qvec *= -1 67 | return qvec 68 | 69 | 70 | class Image(BaseImage): 71 | def qvec2rotmat(self): 72 | return qvec2rotmat(self.qvec) 73 | 74 | 75 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 76 | """Read and unpack the next bytes from a binary file. 77 | :param fid: 78 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 79 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 80 | :param endian_character: Any of {@, =, <, >, !} 81 | :return: Tuple of read and unpacked values. 
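Example: read_next_bytes(fid, 8, "Q") returns a 1-tuple holding a single unsigned 64-bit integer, as used below to read the point/image/camera counts.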
82 | """ 83 | data = fid.read(num_bytes) 84 | return struct.unpack(endian_character + format_char_sequence, data) 85 | 86 | 87 | def read_points3D_text(path): 88 | """ 89 | see: src/base/reconstruction.cc 90 | void Reconstruction::ReadPoints3DText(const std::string& path) 91 | void Reconstruction::WritePoints3DText(const std::string& path) 92 | """ 93 | xyzs = None 94 | rgbs = None 95 | errors = None 96 | with open(path, "r") as fid: 97 | while True: 98 | line = fid.readline() 99 | if not line: 100 | break 101 | line = line.strip() 102 | if len(line) > 0 and line[0] != "#": 103 | elems = line.split() 104 | xyz = np.array(tuple(map(float, elems[1:4]))) 105 | rgb = np.array(tuple(map(int, elems[4:7]))) 106 | error = np.array(float(elems[7])) 107 | if xyzs is None: 108 | xyzs = xyz[None, ...] 109 | rgbs = rgb[None, ...] 110 | errors = error[None, ...] 111 | else: 112 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 113 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 114 | errors = np.append(errors, error[None, ...], axis=0) 115 | return xyzs, rgbs, errors 116 | 117 | 118 | def read_points3D_binary(path_to_model_file): 119 | """ 120 | see: src/base/reconstruction.cc 121 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 122 | void Reconstruction::WritePoints3DBinary(const std::string& path) 123 | """ 124 | 125 | with open(path_to_model_file, "rb") as fid: 126 | num_points = read_next_bytes(fid, 8, "Q")[0] 127 | 128 | xyzs = np.empty((num_points, 3)) 129 | rgbs = np.empty((num_points, 3)) 130 | errors = np.empty((num_points, 1)) 131 | 132 | for p_id in range(num_points): 133 | binary_point_line_properties = read_next_bytes( 134 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 135 | xyz = np.array(binary_point_line_properties[1:4]) 136 | rgb = np.array(binary_point_line_properties[4:7]) 137 | error = np.array(binary_point_line_properties[7]) 138 | track_length = read_next_bytes( 139 | fid, num_bytes=8, format_char_sequence="Q")[0] 140 | track_elems = read_next_bytes( 141 | fid, num_bytes=8 * track_length, 142 | format_char_sequence="ii" * track_length) 143 | xyzs[p_id] = xyz 144 | rgbs[p_id] = rgb 145 | errors[p_id] = error 146 | return xyzs, rgbs, errors 147 | 148 | 149 | def read_intrinsics_text(path): 150 | """ 151 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 152 | """ 153 | cameras = {} 154 | with open(path, "r") as fid: 155 | while True: 156 | line = fid.readline() 157 | if not line: 158 | break 159 | line = line.strip() 160 | if len(line) > 0 and line[0] != "#": 161 | elems = line.split() 162 | camera_id = int(elems[0]) 163 | model = elems[1] 164 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 165 | width = int(elems[2]) 166 | height = int(elems[3]) 167 | params = np.array(tuple(map(float, elems[4:]))) 168 | cameras[camera_id] = Camera(id=camera_id, model=model, 169 | width=width, height=height, 170 | params=params) 171 | return cameras 172 | 173 | 174 | def read_extrinsics_binary(path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void Reconstruction::ReadImagesBinary(const std::string& path) 178 | void Reconstruction::WriteImagesBinary(const std::string& path) 179 | """ 180 | images = {} 181 | with open(path_to_model_file, "rb") as fid: 182 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 183 | for _ in range(num_reg_images): 184 | binary_image_properties = read_next_bytes( 185 | fid, num_bytes=64, format_char_sequence="idddddddi") 186 | 
image_id = binary_image_properties[0] 187 | qvec = np.array(binary_image_properties[1:5]) 188 | tvec = np.array(binary_image_properties[5:8]) 189 | camera_id = binary_image_properties[8] 190 | image_name = "" 191 | current_char = read_next_bytes(fid, 1, "c")[0] 192 | while current_char != b"\x00": # look for the ASCII 0 entry 193 | image_name += current_char.decode("utf-8") 194 | current_char = read_next_bytes(fid, 1, "c")[0] 195 | num_points2D = read_next_bytes(fid, num_bytes=8, 196 | format_char_sequence="Q")[0] 197 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, 198 | format_char_sequence="ddq" * num_points2D) 199 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 200 | tuple(map(float, x_y_id_s[1::3]))]) 201 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 202 | images[image_id] = Image( 203 | id=image_id, qvec=qvec, tvec=tvec, 204 | camera_id=camera_id, name=image_name, 205 | xys=xys, point3D_ids=point3D_ids) 206 | return images 207 | 208 | 209 | def read_intrinsics_binary(path_to_model_file): 210 | """ 211 | see: src/base/reconstruction.cc 212 | void Reconstruction::WriteCamerasBinary(const std::string& path) 213 | void Reconstruction::ReadCamerasBinary(const std::string& path) 214 | """ 215 | cameras = {} 216 | with open(path_to_model_file, "rb") as fid: 217 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 218 | for _ in range(num_cameras): 219 | camera_properties = read_next_bytes( 220 | fid, num_bytes=24, format_char_sequence="iiQQ") 221 | camera_id = camera_properties[0] 222 | model_id = camera_properties[1] 223 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 224 | width = camera_properties[2] 225 | height = camera_properties[3] 226 | num_params = CAMERA_MODEL_IDS[model_id].num_params 227 | params = read_next_bytes(fid, num_bytes=8 * num_params, 228 | format_char_sequence="d" * num_params) 229 | cameras[camera_id] = Camera(id=camera_id, 230 | model=model_name, 231 | width=width, 232 | height=height, 233 | params=np.array(params)) 234 | assert len(cameras) == num_cameras 235 | return cameras 236 | 237 | 238 | def read_extrinsics_text(path): 239 | """ 240 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 241 | """ 242 | images = {} 243 | with open(path, "r") as fid: 244 | while True: 245 | line = fid.readline() 246 | if not line: 247 | break 248 | line = line.strip() 249 | if len(line) > 0 and line[0] != "#": 250 | elems = line.split() 251 | image_id = int(elems[0]) 252 | qvec = np.array(tuple(map(float, elems[1:5]))) 253 | tvec = np.array(tuple(map(float, elems[5:8]))) 254 | camera_id = int(elems[8]) 255 | image_name = elems[9] 256 | elems = fid.readline().split() 257 | xys = np.column_stack([tuple(map(float, elems[0::3])), 258 | tuple(map(float, elems[1::3]))]) 259 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 260 | images[image_id] = Image( 261 | id=image_id, qvec=qvec, tvec=tvec, 262 | camera_id=camera_id, name=image_name, 263 | xys=xys, point3D_ids=point3D_ids) 264 | return images 265 | 266 | 267 | def read_colmap_bin_array(path): 268 | """ 269 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 270 | 271 | :param path: path to the colmap binary file. 
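The file starts with a text header "width&height&channels&", followed by raw float32 values stored in column-major (Fortran) order.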
272 | :return: nd array with the floating point values in the value 273 | """ 274 | with open(path, "rb") as fid: 275 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 276 | usecols=(0, 1, 2), dtype=int) 277 | fid.seek(0) 278 | num_delimiter = 0 279 | byte = fid.read(1) 280 | while True: 281 | if byte == b"&": 282 | num_delimiter += 1 283 | if num_delimiter >= 3: 284 | break 285 | byte = fid.read(1) 286 | array = np.fromfile(fid, np.float32) 287 | array = array.reshape((width, height, channels), order="F") 288 | return np.transpose(array, (1, 0, 2)).squeeze() 289 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple, Optional 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | import imageio 22 | from glob import glob 23 | import cv2 as cv 24 | from pathlib import Path 25 | from plyfile import PlyData, PlyElement 26 | from utils.sh_utils import SH2RGB 27 | from scene.gaussian_model import BasicPointCloud 28 | from utils.camera_utils import camera_nerfies_from_JSON 29 | 30 | 31 | class CameraInfo(NamedTuple): 32 | uid: int 33 | R: np.array 34 | T: np.array 35 | FovY: np.array 36 | FovX: np.array 37 | image: np.array 38 | image_path: str 39 | image_name: str 40 | width: int 41 | height: int 42 | fid: float 43 | depth: Optional[np.array] = None 44 | 45 | 46 | class SceneInfo(NamedTuple): 47 | point_cloud: BasicPointCloud 48 | train_cameras: list 49 | test_cameras: list 50 | nerf_normalization: dict 51 | ply_path: str 52 | 53 | 54 | def load_K_Rt_from_P(filename, P=None): 55 | if P is None: 56 | lines = open(filename).read().splitlines() 57 | if len(lines) == 4: 58 | lines = lines[1:] 59 | lines = [[x[0], x[1], x[2], x[3]] 60 | for x in (x.split(" ") for x in lines)] 61 | P = np.asarray(lines).astype(np.float32).squeeze() 62 | 63 | out = cv.decomposeProjectionMatrix(P) 64 | K = out[0] 65 | R = out[1] 66 | t = out[2] 67 | 68 | K = K / K[2, 2] 69 | 70 | pose = np.eye(4, dtype=np.float32) 71 | pose[:3, :3] = R.transpose() 72 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 73 | 74 | return K, pose 75 | 76 | 77 | def getNerfppNorm(cam_info): 78 | def get_center_and_diag(cam_centers): 79 | cam_centers = np.hstack(cam_centers) 80 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 81 | center = avg_cam_center 82 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 83 | diagonal = np.max(dist) 84 | return center.flatten(), diagonal 85 | 86 | cam_centers = [] 87 | 88 | for cam in cam_info: 89 | W2C = getWorld2View2(cam.R, cam.T) 90 | C2W = np.linalg.inv(W2C) 91 | cam_centers.append(C2W[:3, 3:4]) 92 | 93 | center, diagonal = get_center_and_diag(cam_centers) 94 | radius = diagonal * 1.1 95 | 96 | translate = -center 97 | 98 | return {"translate": translate, 
"radius": radius} 99 | 100 | 101 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 102 | cam_infos = [] 103 | num_frames = len(cam_extrinsics) 104 | for idx, key in enumerate(cam_extrinsics): 105 | sys.stdout.write('\r') 106 | # the exact output you're looking for: 107 | sys.stdout.write( 108 | "Reading camera {}/{}".format(idx + 1, len(cam_extrinsics))) 109 | sys.stdout.flush() 110 | 111 | extr = cam_extrinsics[key] 112 | intr = cam_intrinsics[extr.camera_id] 113 | height = intr.height 114 | width = intr.width 115 | 116 | uid = intr.id 117 | R = np.transpose(qvec2rotmat(extr.qvec)) 118 | T = np.array(extr.tvec) 119 | 120 | if intr.model == "SIMPLE_PINHOLE": 121 | focal_length_x = intr.params[0] 122 | FovY = focal2fov(focal_length_x, height) 123 | FovX = focal2fov(focal_length_x, width) 124 | elif intr.model == "PINHOLE": 125 | focal_length_x = intr.params[0] 126 | focal_length_y = intr.params[1] 127 | FovY = focal2fov(focal_length_y, height) 128 | FovX = focal2fov(focal_length_x, width) 129 | else: 130 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 131 | 132 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 133 | image_name = os.path.basename(image_path).split(".")[0] 134 | image = Image.open(image_path) 135 | 136 | fid = int(image_name) / (num_frames - 1) 137 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 138 | image_path=image_path, image_name=image_name, width=width, height=height, fid=fid) 139 | cam_infos.append(cam_info) 140 | sys.stdout.write('\n') 141 | return cam_infos 142 | 143 | 144 | def fetchPly(path): 145 | plydata = PlyData.read(path) 146 | vertices = plydata['vertex'] 147 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 148 | colors = np.vstack([vertices['red'], vertices['green'], 149 | vertices['blue']]).T / 255.0 150 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 151 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 152 | 153 | 154 | def storePly(path, xyz, rgb): 155 | # Define the dtype for the structured array 156 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 157 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 158 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 159 | 160 | normals = np.zeros_like(xyz) 161 | 162 | elements = np.empty(xyz.shape[0], dtype=dtype) 163 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 164 | elements[:] = list(map(tuple, attributes)) 165 | 166 | # Create the PlyData object and write to file 167 | vertex_element = PlyElement.describe(elements, 'vertex') 168 | ply_data = PlyData([vertex_element]) 169 | ply_data.write(path) 170 | 171 | 172 | def readColmapSceneInfo(path, images, eval, llffhold=8): 173 | try: 174 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 175 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 176 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 177 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 178 | except: 179 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 180 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 181 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 182 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 183 | 184 | reading_dir = "images" if images == None else images 185 | cam_infos_unsorted = 
readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, 186 | images_folder=os.path.join(path, reading_dir)) 187 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) 188 | 189 | if eval: 190 | train_cam_infos = [c for idx, c in enumerate( 191 | cam_infos) if idx % llffhold != 0] 192 | test_cam_infos = [c for idx, c in enumerate( 193 | cam_infos) if idx % llffhold == 0] 194 | else: 195 | train_cam_infos = cam_infos 196 | test_cam_infos = [] 197 | 198 | nerf_normalization = getNerfppNorm(train_cam_infos) 199 | 200 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 201 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 202 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 203 | if not os.path.exists(ply_path): 204 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 205 | try: 206 | xyz, rgb, _ = read_points3D_binary(bin_path) 207 | except: 208 | xyz, rgb, _ = read_points3D_text(txt_path) 209 | storePly(ply_path, xyz, rgb) 210 | try: 211 | pcd = fetchPly(ply_path) 212 | except: 213 | pcd = None 214 | 215 | scene_info = SceneInfo(point_cloud=pcd, 216 | train_cameras=train_cam_infos, 217 | test_cameras=test_cam_infos, 218 | nerf_normalization=nerf_normalization, 219 | ply_path=ply_path) 220 | return scene_info 221 | 222 | 223 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 224 | cam_infos = [] 225 | 226 | with open(os.path.join(path, transformsfile)) as json_file: 227 | contents = json.load(json_file) 228 | fovx = contents["camera_angle_x"] 229 | 230 | frames = contents["frames"] 231 | for idx, frame in enumerate(frames): 232 | cam_name = os.path.join(path, frame["file_path"] + extension) 233 | frame_time = frame['time'] 234 | 235 | matrix = np.linalg.inv(np.array(frame["transform_matrix"])) 236 | R = -np.transpose(matrix[:3, :3]) 237 | R[:, 0] = -R[:, 0] 238 | T = -matrix[:3, 3] 239 | 240 | image_path = os.path.join(path, cam_name) 241 | image_name = Path(cam_name).stem 242 | image = Image.open(image_path) 243 | 244 | im_data = np.array(image.convert("RGBA")) 245 | 246 | bg = np.array( 247 | [1, 1, 1]) if white_background else np.array([0, 0, 0]) 248 | 249 | norm_data = im_data / 255.0 250 | mask = norm_data[..., 3:4] 251 | 252 | arr = norm_data[:, :, :3] * norm_data[:, :, 253 | 3:4] + bg * (1 - norm_data[:, :, 3:4]) 254 | image = Image.fromarray( 255 | np.array(arr * 255.0, dtype=np.byte), "RGB") 256 | 257 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 258 | FovY = fovx 259 | FovX = fovy 260 | 261 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 262 | image_path=image_path, image_name=image_name, width=image.size[ 263 | 0], 264 | height=image.size[1], fid=frame_time)) 265 | 266 | return cam_infos 267 | 268 | 269 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 270 | print("Reading Training Transforms") 271 | train_cam_infos = readCamerasFromTransforms( 272 | path, "transforms_train.json", white_background, extension) 273 | print("Reading Test Transforms") 274 | test_cam_infos = readCamerasFromTransforms( 275 | path, "transforms_test.json", white_background, extension) 276 | 277 | if not eval: 278 | train_cam_infos.extend(test_cam_infos) 279 | test_cam_infos = [] 280 | 281 | nerf_normalization = getNerfppNorm(train_cam_infos) 282 | 283 | ply_path = os.path.join(path, "points3d.ply") 284 | if not os.path.exists(ply_path): 285 | # Since this data set has no colmap 
data, we start with random points 286 | num_pts = 100_000 287 | print(f"Generating random point cloud ({num_pts})...") 288 | 289 | # We create random points inside the bounds of the synthetic Blender scenes 290 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 291 | shs = np.random.random((num_pts, 3)) / 255.0 292 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB( 293 | shs), normals=np.zeros((num_pts, 3))) 294 | 295 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 296 | try: 297 | pcd = fetchPly(ply_path) 298 | except: 299 | pcd = None 300 | 301 | scene_info = SceneInfo(point_cloud=pcd, 302 | train_cameras=train_cam_infos, 303 | test_cameras=test_cam_infos, 304 | nerf_normalization=nerf_normalization, 305 | ply_path=ply_path) 306 | return scene_info 307 | 308 | 309 | def readDTUCameras(path, render_camera, object_camera): 310 | camera_dict = np.load(os.path.join(path, render_camera)) 311 | images_lis = sorted(glob(os.path.join(path, 'image/*.png'))) 312 | masks_lis = sorted(glob(os.path.join(path, 'mask/*.png'))) 313 | n_images = len(images_lis) 314 | cam_infos = [] 315 | cam_idx = 0 316 | for idx in range(0, n_images): 317 | image_path = images_lis[idx] 318 | image = np.array(Image.open(image_path)) 319 | mask = np.array(imageio.imread(masks_lis[idx])) / 255.0 320 | image = Image.fromarray((image * mask).astype(np.uint8)) 321 | world_mat = camera_dict['world_mat_%d' % idx].astype(np.float32) 322 | fid = camera_dict['fid_%d' % idx] / (n_images / 12 - 1) 323 | image_name = Path(image_path).stem 324 | scale_mat = camera_dict['scale_mat_%d' % idx].astype(np.float32) 325 | P = world_mat @ scale_mat 326 | P = P[:3, :4] 327 | 328 | K, pose = load_K_Rt_from_P(None, P) 329 | a = pose[0:1, :] 330 | b = pose[1:2, :] 331 | c = pose[2:3, :] 332 | 333 | pose = np.concatenate([a, -c, -b, pose[3:, :]], 0) 334 | 335 | S = np.eye(3) 336 | S[1, 1] = -1 337 | S[2, 2] = -1 338 | pose[1, 3] = -pose[1, 3] 339 | pose[2, 3] = -pose[2, 3] 340 | pose[:3, :3] = S @ pose[:3, :3] @ S 341 | 342 | a = pose[0:1, :] 343 | b = pose[1:2, :] 344 | c = pose[2:3, :] 345 | 346 | pose = np.concatenate([a, c, b, pose[3:, :]], 0) 347 | 348 | pose[:, 3] *= 0.5 349 | 350 | matrix = np.linalg.inv(pose) 351 | R = -np.transpose(matrix[:3, :3]) 352 | R[:, 0] = -R[:, 0] 353 | T = -matrix[:3, 3] 354 | 355 | FovY = focal2fov(K[0, 0], image.size[1]) 356 | FovX = focal2fov(K[0, 0], image.size[0]) 357 | cam_info = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 358 | image_path=image_path, image_name=image_name, width=image.size[ 359 | 0], height=image.size[1], 360 | fid=fid) 361 | cam_infos.append(cam_info) 362 | sys.stdout.write('\n') 363 | return cam_infos 364 | 365 | 366 | def readNeuSDTUInfo(path, render_camera, object_camera): 367 | print("Reading DTU Info") 368 | train_cam_infos = readDTUCameras(path, render_camera, object_camera) 369 | 370 | nerf_normalization = getNerfppNorm(train_cam_infos) 371 | 372 | ply_path = os.path.join(path, "points3d.ply") 373 | if not os.path.exists(ply_path): 374 | # Since this data set has no colmap data, we start with random points 375 | num_pts = 100_000 376 | print(f"Generating random point cloud ({num_pts})...") 377 | 378 | # We create random points inside the bounds of the synthetic Blender scenes 379 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 380 | shs = np.random.random((num_pts, 3)) / 255.0 381 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB( 382 | shs), normals=np.zeros((num_pts, 3))) 383 | 384 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 385 | try: 386 | pcd = 
fetchPly(ply_path) 387 | except: 388 | pcd = None 389 | 390 | scene_info = SceneInfo(point_cloud=pcd, 391 | train_cameras=train_cam_infos, 392 | test_cameras=[], 393 | nerf_normalization=nerf_normalization, 394 | ply_path=ply_path) 395 | return scene_info 396 | 397 | 398 | def readNerfiesCameras(path): 399 | with open(f'{path}/scene.json', 'r') as f: 400 | scene_json = json.load(f) 401 | with open(f'{path}/metadata.json', 'r') as f: 402 | meta_json = json.load(f) 403 | with open(f'{path}/dataset.json', 'r') as f: 404 | dataset_json = json.load(f) 405 | 406 | coord_scale = scene_json['scale'] 407 | scene_center = scene_json['center'] 408 | 409 | name = path.split('/')[-2] 410 | if name.startswith('vrig'): 411 | train_img = dataset_json['train_ids'] 412 | val_img = dataset_json['val_ids'] 413 | all_img = train_img + val_img 414 | ratio = 0.25 415 | elif name.startswith('NeRF'): 416 | train_img = dataset_json['train_ids'] 417 | val_img = dataset_json['val_ids'] 418 | all_img = train_img + val_img 419 | ratio = 1.0 420 | elif name.startswith('interp'): 421 | all_id = dataset_json['ids'] 422 | train_img = all_id[::4] 423 | val_img = all_id[2::4] 424 | all_img = train_img + val_img 425 | ratio = 0.5 426 | else: # for hypernerf 427 | train_img = dataset_json['ids'][::4] 428 | all_img = train_img 429 | ratio = 0.5 430 | 431 | train_num = len(train_img) 432 | 433 | all_cam = [meta_json[i]['camera_id'] for i in all_img] 434 | all_time = [meta_json[i]['time_id'] for i in all_img] 435 | max_time = max(all_time) 436 | all_time = [meta_json[i]['time_id'] / max_time for i in all_img] 437 | selected_time = set(all_time) 438 | 439 | # all poses 440 | all_cam_params = [] 441 | for im in all_img: 442 | camera = camera_nerfies_from_JSON(f'{path}/camera/{im}.json', ratio) 443 | camera['position'] = camera['position'] - scene_center 444 | camera['position'] = camera['position'] * coord_scale 445 | all_cam_params.append(camera) 446 | 447 | all_img = [f'{path}/rgb/{int(1 / ratio)}x/{i}.png' for i in all_img] 448 | 449 | cam_infos = [] 450 | for idx in range(len(all_img)): 451 | image_path = all_img[idx] 452 | image = np.array(Image.open(image_path)) 453 | image = Image.fromarray((image).astype(np.uint8)) 454 | image_name = Path(image_path).stem 455 | 456 | orientation = all_cam_params[idx]['orientation'].T 457 | position = -all_cam_params[idx]['position'] @ orientation 458 | focal = all_cam_params[idx]['focal_length'] 459 | fid = all_time[idx] 460 | T = position 461 | R = orientation 462 | 463 | FovY = focal2fov(focal, image.size[1]) 464 | FovX = focal2fov(focal, image.size[0]) 465 | cam_info = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 466 | image_path=image_path, image_name=image_name, width=image.size[ 467 | 0], height=image.size[1], 468 | fid=fid) 469 | cam_infos.append(cam_info) 470 | 471 | sys.stdout.write('\n') 472 | return cam_infos, train_num, scene_center, coord_scale 473 | 474 | 475 | def readNerfiesInfo(path, eval): 476 | print("Reading Nerfies Info") 477 | cam_infos, train_num, scene_center, scene_scale = readNerfiesCameras(path) 478 | 479 | if eval: 480 | train_cam_infos = cam_infos[:train_num] 481 | test_cam_infos = cam_infos[train_num:] 482 | else: 483 | train_cam_infos = cam_infos 484 | test_cam_infos = [] 485 | 486 | nerf_normalization = getNerfppNorm(train_cam_infos) 487 | 488 | ply_path = os.path.join(path, "points3d.ply") 489 | if not os.path.exists(ply_path): 490 | print(f"Generating point cloud from nerfies...") 491 | 492 | xyz = np.load(os.path.join(path, 
"points.npy")) 493 | xyz = (xyz - scene_center) * scene_scale 494 | num_pts = xyz.shape[0] 495 | shs = np.random.random((num_pts, 3)) / 255.0 496 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB( 497 | shs), normals=np.zeros((num_pts, 3))) 498 | 499 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 500 | try: 501 | pcd = fetchPly(ply_path) 502 | except: 503 | pcd = None 504 | 505 | scene_info = SceneInfo(point_cloud=pcd, 506 | train_cameras=train_cam_infos, 507 | test_cameras=test_cam_infos, 508 | nerf_normalization=nerf_normalization, 509 | ply_path=ply_path) 510 | return scene_info 511 | 512 | 513 | def readCamerasFromNpy(path, npy_file, split, hold_id, num_images): 514 | cam_infos = [] 515 | video_paths = sorted(glob(os.path.join(path, 'frames/*'))) 516 | poses_bounds = np.load(os.path.join(path, npy_file)) 517 | 518 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) 519 | H, W, focal = poses[0, :, -1] 520 | 521 | n_cameras = poses.shape[0] 522 | poses = np.concatenate( 523 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 524 | bottoms = np.array([0, 0, 0, 1]).reshape( 525 | 1, -1, 4).repeat(poses.shape[0], axis=0) 526 | poses = np.concatenate([poses, bottoms], axis=1) 527 | poses = poses @ np.diag([1, -1, -1, 1]) 528 | 529 | i_test = np.array(hold_id) 530 | video_list = i_test if split != 'train' else list( 531 | set(np.arange(n_cameras)) - set(i_test)) 532 | 533 | for i in video_list: 534 | video_path = video_paths[i] 535 | c2w = poses[i] 536 | images_names = sorted(os.listdir(video_path)) 537 | n_frames = num_images 538 | 539 | matrix = np.linalg.inv(np.array(c2w)) 540 | R = np.transpose(matrix[:3, :3]) 541 | T = matrix[:3, 3] 542 | 543 | for idx, image_name in enumerate(images_names[:num_images]): 544 | image_path = os.path.join(video_path, image_name) 545 | image = Image.open(image_path) 546 | frame_time = idx / (n_frames - 1) 547 | 548 | FovX = focal2fov(focal, image.size[0]) 549 | FovY = focal2fov(focal, image.size[1]) 550 | 551 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovX=FovX, FovY=FovY, 552 | image=image, 553 | image_path=image_path, image_name=image_name, 554 | width=image.size[0], height=image.size[1], fid=frame_time)) 555 | 556 | idx += 1 557 | return cam_infos 558 | 559 | 560 | def readPlenopticVideoDataset(path, eval, num_images, hold_id=[0]): 561 | print("Reading Training Camera") 562 | train_cam_infos = readCamerasFromNpy(path, 'poses_bounds.npy', split="train", hold_id=hold_id, 563 | num_images=num_images) 564 | 565 | print("Reading Training Camera") 566 | test_cam_infos = readCamerasFromNpy( 567 | path, 'poses_bounds.npy', split="test", hold_id=hold_id, num_images=num_images) 568 | 569 | if not eval: 570 | train_cam_infos.extend(test_cam_infos) 571 | test_cam_infos = [] 572 | 573 | nerf_normalization = getNerfppNorm(train_cam_infos) 574 | ply_path = os.path.join(path, 'points3D.ply') 575 | if not os.path.exists(ply_path): 576 | num_pts = 100_000 577 | print(f"Generating random point cloud ({num_pts})...") 578 | 579 | # We create random points inside the bounds of the synthetic Blender scenes 580 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 581 | shs = np.random.random((num_pts, 3)) / 255.0 582 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB( 583 | shs), normals=np.zeros((num_pts, 3))) 584 | 585 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 586 | 587 | try: 588 | pcd = fetchPly(ply_path) 589 | except: 590 | pcd = None 591 | 592 | scene_info = SceneInfo(point_cloud=pcd, 593 | train_cameras=train_cam_infos, 594 | test_cameras=test_cam_infos, 595 | 
nerf_normalization=nerf_normalization, 596 | ply_path=ply_path) 597 | return scene_info 598 | 599 | 600 | sceneLoadTypeCallbacks = { 601 | "Colmap": readColmapSceneInfo, # colmap dataset reader from official 3D Gaussian [https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/] 602 | "Blender": readNerfSyntheticInfo, # D-NeRF dataset [https://drive.google.com/file/d/1uHVyApwqugXTFuIRRlE4abTW8_rrVeIK/view?usp=sharing] 603 | "DTU": readNeuSDTUInfo, # DTU dataset used in Tensor4D [https://github.com/DSaurus/Tensor4D] 604 | "nerfies": readNerfiesInfo, # NeRFies & HyperNeRF dataset proposed by [https://github.com/google/hypernerf/releases/tag/v0.1] 605 | "plenopticVideo": readPlenopticVideoDataset, # Neural 3D dataset in [https://github.com/facebookresearch/Neural_3D_Video] 606 | } 607 | -------------------------------------------------------------------------------- /scene/deform_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.time_utils import DeformNetwork 5 | import os 6 | from utils.system_utils import searchForMaxIteration 7 | from utils.general_utils import get_expon_lr_func 8 | 9 | 10 | class DeformModel: 11 | def __init__(self, is_blender=False, is_6dof=False): 12 | self.deform = DeformNetwork(is_blender=is_blender, is_6dof=is_6dof).cuda() 13 | self.optimizer = None 14 | self.spatial_lr_scale = 5 15 | 16 | def step(self, xyz, time_emb): 17 | return self.deform(xyz, time_emb) 18 | 19 | def train_setting(self, training_args): 20 | l = [ 21 | {'params': list(self.deform.parameters()), 22 | 'lr': training_args.position_lr_init * self.spatial_lr_scale, 23 | "name": "deform"} 24 | ] 25 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 26 | 27 | self.deform_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 28 | lr_final=training_args.position_lr_final, 29 | lr_delay_mult=training_args.position_lr_delay_mult, 30 | max_steps=training_args.deform_lr_max_steps) 31 | 32 | def save_weights(self, model_path, iteration): 33 | out_weights_path = os.path.join(model_path, "deform/iteration_{}".format(iteration)) 34 | os.makedirs(out_weights_path, exist_ok=True) 35 | torch.save(self.deform.state_dict(), os.path.join(out_weights_path, 'deform.pth')) 36 | 37 | def load_weights(self, model_path, iteration=-1): 38 | if iteration == -1: 39 | loaded_iter = searchForMaxIteration(os.path.join(model_path, "deform")) 40 | else: 41 | loaded_iter = iteration 42 | weights_path = os.path.join(model_path, "deform/iteration_{}/deform.pth".format(loaded_iter)) 43 | self.deform.load_state_dict(torch.load(weights_path)) 44 | 45 | def update_learning_rate(self, iteration): 46 | for param_group in self.optimizer.param_groups: 47 | if param_group["name"] == "deform": 48 | lr = self.deform_scheduler_args(iteration) 49 | param_group['lr'] = lr 50 | return lr 51 | -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation 23 | 24 | 25 | class GaussianModel: 26 | def __init__(self, sh_degree: int): 27 | 28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 30 | actual_covariance = L @ L.transpose(1, 2) 31 | symm = strip_symmetric(actual_covariance) 32 | return symm 33 | 34 | self.active_sh_degree = 0 35 | self.max_sh_degree = sh_degree 36 | 37 | self._xyz = torch.empty(0) 38 | self._features_dc = torch.empty(0) 39 | self._features_rest = torch.empty(0) 40 | self._scaling = torch.empty(0) 41 | self._rotation = torch.empty(0) 42 | self._opacity = torch.empty(0) 43 | self.max_radii2D = torch.empty(0) 44 | self.xyz_gradient_accum = torch.empty(0) 45 | 46 | self.optimizer = None 47 | 48 | self.scaling_activation = torch.exp 49 | self.scaling_inverse_activation = torch.log 50 | 51 | self.covariance_activation = build_covariance_from_scaling_rotation 52 | 53 | self.opacity_activation = torch.sigmoid 54 | self.inverse_opacity_activation = inverse_sigmoid 55 | 56 | self.rotation_activation = torch.nn.functional.normalize 57 | 58 | @property 59 | def get_scaling(self): 60 | return self.scaling_activation(self._scaling) 61 | 62 | @property 63 | def get_rotation(self): 64 | return self.rotation_activation(self._rotation) 65 | 66 | @property 67 | def get_xyz(self): 68 | return self._xyz 69 | 70 | @property 71 | def get_features(self): 72 | features_dc = self._features_dc 73 | features_rest = self._features_rest 74 | return torch.cat((features_dc, features_rest), dim=1) 75 | 76 | @property 77 | def get_opacity(self): 78 | return self.opacity_activation(self._opacity) 79 | 80 | def get_covariance(self, scaling_modifier=1): 81 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 82 | 83 | def oneupSHdegree(self): 84 | if self.active_sh_degree < self.max_sh_degree: 85 | self.active_sh_degree += 1 86 | 87 | def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float): 88 | self.spatial_lr_scale = 5 89 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 90 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 91 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 92 | features[:, :3, 0] = fused_color 93 | features[:, 3:, 1:] = 0.0 94 | 95 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 96 | 97 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 98 | scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3) 99 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 100 | rots[:, 0] = 1 101 | 102 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 103 | 104 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 105 | self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 
2).contiguous().requires_grad_(True)) 106 | self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) 107 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 108 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 109 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 110 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 111 | 112 | def training_setup(self, training_args): 113 | self.percent_dense = training_args.percent_dense 114 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 115 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 116 | 117 | self.spatial_lr_scale = 5 118 | 119 | l = [ 120 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 121 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 122 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 123 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 124 | {'params': [self._scaling], 'lr': training_args.scaling_lr * self.spatial_lr_scale, "name": "scaling"}, 125 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 126 | ] 127 | 128 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 129 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 130 | lr_final=training_args.position_lr_final * self.spatial_lr_scale, 131 | lr_delay_mult=training_args.position_lr_delay_mult, 132 | max_steps=training_args.position_lr_max_steps) 133 | 134 | def update_learning_rate(self, iteration): 135 | ''' Learning rate scheduling per step ''' 136 | for param_group in self.optimizer.param_groups: 137 | if param_group["name"] == "xyz": 138 | lr = self.xyz_scheduler_args(iteration) 139 | param_group['lr'] = lr 140 | return lr 141 | 142 | def construct_list_of_attributes(self): 143 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 144 | # All channels except the 3 DC 145 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): 146 | l.append('f_dc_{}'.format(i)) 147 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): 148 | l.append('f_rest_{}'.format(i)) 149 | l.append('opacity') 150 | for i in range(self._scaling.shape[1]): 151 | l.append('scale_{}'.format(i)) 152 | for i in range(self._rotation.shape[1]): 153 | l.append('rot_{}'.format(i)) 154 | return l 155 | 156 | def save_ply(self, path): 157 | mkdir_p(os.path.dirname(path)) 158 | 159 | xyz = self._xyz.detach().cpu().numpy() 160 | normals = np.zeros_like(xyz) 161 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 162 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 163 | opacities = self._opacity.detach().cpu().numpy() 164 | scale = self._scaling.detach().cpu().numpy() 165 | rotation = self._rotation.detach().cpu().numpy() 166 | 167 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 168 | 169 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 170 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 171 | elements[:] = list(map(tuple, attributes)) 172 | el = PlyElement.describe(elements, 'vertex') 173 | PlyData([el]).write(path) 174 | 175 | def reset_opacity(self): 176 | opacities_new = 
inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)) 177 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 178 | self._opacity = optimizable_tensors["opacity"] 179 | 180 | def load_ply(self, path, og_number_points=-1): 181 | self.og_number_points = og_number_points 182 | plydata = PlyData.read(path) 183 | 184 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 185 | np.asarray(plydata.elements[0]["y"]), 186 | np.asarray(plydata.elements[0]["z"])), axis=1) 187 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 188 | 189 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 190 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 191 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 192 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 193 | 194 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 195 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 196 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 197 | for idx, attr_name in enumerate(extra_f_names): 198 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 199 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 200 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 201 | 202 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 203 | scales = np.zeros((xyz.shape[0], len(scale_names))) 204 | for idx, attr_name in enumerate(scale_names): 205 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 206 | 207 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 208 | rots = np.zeros((xyz.shape[0], len(rot_names))) 209 | for idx, attr_name in enumerate(rot_names): 210 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 211 | 212 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 213 | self._features_dc = nn.Parameter( 214 | torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( 215 | True)) 216 | self._features_rest = nn.Parameter( 217 | torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( 218 | True)) 219 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 220 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 221 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 222 | 223 | self.active_sh_degree = self.max_sh_degree 224 | 225 | def replace_tensor_to_optimizer(self, tensor, name): 226 | optimizable_tensors = {} 227 | for group in self.optimizer.param_groups: 228 | if group["name"] == name: 229 | stored_state = self.optimizer.state.get(group['params'][0], None) 230 | stored_state["exp_avg"] = torch.zeros_like(tensor) 231 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 232 | 233 | del self.optimizer.state[group['params'][0]] 234 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 235 | self.optimizer.state[group['params'][0]] = stored_state 236 | 237 | optimizable_tensors[group["name"]] = group["params"][0] 238 | return optimizable_tensors 239 | 240 | def _prune_optimizer(self, mask): 241 | optimizable_tensors = {} 242 
| for group in self.optimizer.param_groups: 243 | stored_state = self.optimizer.state.get(group['params'][0], None) 244 | if stored_state is not None: 245 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 246 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 247 | 248 | del self.optimizer.state[group['params'][0]] 249 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 250 | self.optimizer.state[group['params'][0]] = stored_state 251 | 252 | optimizable_tensors[group["name"]] = group["params"][0] 253 | else: 254 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 255 | optimizable_tensors[group["name"]] = group["params"][0] 256 | return optimizable_tensors 257 | 258 | def prune_points(self, mask): 259 | valid_points_mask = ~mask 260 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 261 | 262 | self._xyz = optimizable_tensors["xyz"] 263 | self._features_dc = optimizable_tensors["f_dc"] 264 | self._features_rest = optimizable_tensors["f_rest"] 265 | self._opacity = optimizable_tensors["opacity"] 266 | self._scaling = optimizable_tensors["scaling"] 267 | self._rotation = optimizable_tensors["rotation"] 268 | 269 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 270 | 271 | self.denom = self.denom[valid_points_mask] 272 | self.max_radii2D = self.max_radii2D[valid_points_mask] 273 | 274 | def cat_tensors_to_optimizer(self, tensors_dict): 275 | optimizable_tensors = {} 276 | for group in self.optimizer.param_groups: 277 | assert len(group["params"]) == 1 278 | extension_tensor = tensors_dict[group["name"]] 279 | stored_state = self.optimizer.state.get(group['params'][0], None) 280 | if stored_state is not None: 281 | 282 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), 283 | dim=0) 284 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), 285 | dim=0) 286 | 287 | del self.optimizer.state[group['params'][0]] 288 | group["params"][0] = nn.Parameter( 289 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 290 | self.optimizer.state[group['params'][0]] = stored_state 291 | 292 | optimizable_tensors[group["name"]] = group["params"][0] 293 | else: 294 | group["params"][0] = nn.Parameter( 295 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 296 | optimizable_tensors[group["name"]] = group["params"][0] 297 | 298 | return optimizable_tensors 299 | 300 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 301 | new_rotation): 302 | d = {"xyz": new_xyz, 303 | "f_dc": new_features_dc, 304 | "f_rest": new_features_rest, 305 | "opacity": new_opacities, 306 | "scaling": new_scaling, 307 | "rotation": new_rotation} 308 | 309 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 310 | self._xyz = optimizable_tensors["xyz"] 311 | self._features_dc = optimizable_tensors["f_dc"] 312 | self._features_rest = optimizable_tensors["f_rest"] 313 | self._opacity = optimizable_tensors["opacity"] 314 | self._scaling = optimizable_tensors["scaling"] 315 | self._rotation = optimizable_tensors["rotation"] 316 | 317 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 318 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 319 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 320 | 321 | def densify_and_split(self, 
grads, grad_threshold, scene_extent, N=2): 322 | n_init_points = self.get_xyz.shape[0] 323 | # Extract points that satisfy the gradient condition 324 | padded_grad = torch.zeros((n_init_points), device="cuda") 325 | padded_grad[:grads.shape[0]] = grads.squeeze() 326 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 327 | selected_pts_mask = torch.logical_and(selected_pts_mask, 328 | torch.max(self.get_scaling, 329 | dim=1).values > self.percent_dense * scene_extent) 330 | 331 | stds = self.get_scaling[selected_pts_mask].repeat(N, 1) 332 | means = torch.zeros((stds.size(0), 3), device="cuda") 333 | samples = torch.normal(mean=means, std=stds) 334 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1) 335 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 336 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)) 337 | new_rotation = self._rotation[selected_pts_mask].repeat(N, 1) 338 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1) 339 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1) 340 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) 341 | 342 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) 343 | 344 | prune_filter = torch.cat( 345 | (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 346 | self.prune_points(prune_filter) 347 | 348 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 349 | # Extract points that satisfy the gradient condition 350 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 351 | selected_pts_mask = torch.logical_and(selected_pts_mask, 352 | torch.max(self.get_scaling, 353 | dim=1).values <= self.percent_dense * scene_extent) 354 | 355 | new_xyz = self._xyz[selected_pts_mask] 356 | new_features_dc = self._features_dc[selected_pts_mask] 357 | new_features_rest = self._features_rest[selected_pts_mask] 358 | new_opacities = self._opacity[selected_pts_mask] 359 | new_scaling = self._scaling[selected_pts_mask] 360 | new_rotation = self._rotation[selected_pts_mask] 361 | 362 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 363 | new_rotation) 364 | 365 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 366 | grads = self.xyz_gradient_accum / self.denom 367 | grads[grads.isnan()] = 0.0 368 | 369 | self.densify_and_clone(grads, max_grad, extent) 370 | self.densify_and_split(grads, max_grad, extent) 371 | 372 | prune_mask = (self.get_opacity < min_opacity).squeeze() 373 | if max_screen_size: 374 | big_points_vs = self.max_radii2D > max_screen_size 375 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 376 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 377 | self.prune_points(prune_mask) 378 | 379 | torch.cuda.empty_cache() 380 | 381 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 382 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, 383 | keepdim=True) 384 | self.denom[update_filter] += 1 385 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 
2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim, kl_divergence 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel, DeformModel 19 | from utils.general_utils import safe_state, get_linear_noise_func 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | 26 | try: 27 | from torch.utils.tensorboard import SummaryWriter 28 | 29 | TENSORBOARD_FOUND = True 30 | except ImportError: 31 | TENSORBOARD_FOUND = False 32 | 33 | 34 | def training(dataset, opt, pipe, testing_iterations, saving_iterations): 35 | tb_writer = prepare_output_and_logger(dataset) 36 | gaussians = GaussianModel(dataset.sh_degree) 37 | deform = DeformModel(dataset.is_blender, dataset.is_6dof) 38 | deform.train_setting(opt) 39 | 40 | scene = Scene(dataset, gaussians) 41 | gaussians.training_setup(opt) 42 | 43 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 44 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 45 | 46 | iter_start = torch.cuda.Event(enable_timing=True) 47 | iter_end = torch.cuda.Event(enable_timing=True) 48 | 49 | viewpoint_stack = None 50 | ema_loss_for_log = 0.0 51 | best_psnr = 0.0 52 | best_iteration = 0 53 | progress_bar = tqdm(range(opt.iterations), desc="Training progress") 54 | smooth_term = get_linear_noise_func(lr_init=0.1, lr_final=1e-15, lr_delay_mult=0.01, max_steps=20000) 55 | for iteration in range(1, opt.iterations + 1): 56 | if network_gui.conn == None: 57 | network_gui.try_connect() 58 | while network_gui.conn != None: 59 | try: 60 | net_image_bytes = None 61 | custom_cam, do_training, pipe.do_shs_python, pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive() 62 | if custom_cam != None: 63 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 64 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 65 | 0).contiguous().cpu().numpy()) 66 | network_gui.send(net_image_bytes, dataset.source_path) 67 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 68 | break 69 | except Exception as e: 70 | network_gui.conn = None 71 | 72 | iter_start.record() 73 | 74 | # Every 1000 its we increase the levels of SH up to a maximum degree 75 | if iteration % 1000 == 0: 76 | gaussians.oneupSHdegree() 77 | 78 | # Pick a random Camera 79 | if not viewpoint_stack: 80 | viewpoint_stack = scene.getTrainCameras().copy() 81 | 82 | total_frame = len(viewpoint_stack) 83 | time_interval = 1 / total_frame 84 | 85 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) 86 | if dataset.load2gpu_on_the_fly: 87 | viewpoint_cam.load2device() 88 | fid = viewpoint_cam.fid 89 | 90 | if iteration < opt.warm_up: 91 | d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0 92 | else: 93 | N = gaussians.get_xyz.shape[0] 94 | time_input = fid.unsqueeze(0).expand(N, -1) 95 | 96 | ast_noise = 0 if dataset.is_blender else torch.randn(1, 1, device='cuda').expand(N, -1) * 
time_interval * smooth_term(iteration) 97 | d_xyz, d_rotation, d_scaling = deform.step(gaussians.get_xyz.detach(), time_input + ast_noise) 98 | 99 | # Render 100 | render_pkg_re = render(viewpoint_cam, gaussians, pipe, background, d_xyz, d_rotation, d_scaling, dataset.is_6dof) 101 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg_re["render"], render_pkg_re[ 102 | "viewspace_points"], render_pkg_re["visibility_filter"], render_pkg_re["radii"] 103 | # depth = render_pkg_re["depth"] 104 | 105 | # Loss 106 | gt_image = viewpoint_cam.original_image.cuda() 107 | Ll1 = l1_loss(image, gt_image) 108 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 109 | loss.backward() 110 | 111 | iter_end.record() 112 | 113 | if dataset.load2gpu_on_the_fly: 114 | viewpoint_cam.load2device('cpu') 115 | 116 | with torch.no_grad(): 117 | # Progress bar 118 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 119 | if iteration % 10 == 0: 120 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 121 | progress_bar.update(10) 122 | if iteration == opt.iterations: 123 | progress_bar.close() 124 | 125 | # Keep track of max radii in image-space for pruning 126 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], 127 | radii[visibility_filter]) 128 | 129 | # Log and save 130 | cur_psnr = training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), 131 | testing_iterations, scene, render, (pipe, background), deform, 132 | dataset.load2gpu_on_the_fly, dataset.is_6dof) 133 | if iteration in testing_iterations: 134 | if cur_psnr.item() > best_psnr: 135 | best_psnr = cur_psnr.item() 136 | best_iteration = iteration 137 | 138 | if iteration in saving_iterations: 139 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 140 | scene.save(iteration) 141 | deform.save_weights(args.model_path, iteration) 142 | 143 | # Densification 144 | if iteration < opt.densify_until_iter: 145 | viewspace_point_tensor_densify = render_pkg_re["viewspace_points_densify"] 146 | gaussians.add_densification_stats(viewspace_point_tensor_densify, visibility_filter) 147 | 148 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 149 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 150 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 151 | 152 | if iteration % opt.opacity_reset_interval == 0 or ( 153 | dataset.white_background and iteration == opt.densify_from_iter): 154 | gaussians.reset_opacity() 155 | 156 | # Optimizer step 157 | if iteration < opt.iterations: 158 | gaussians.optimizer.step() 159 | gaussians.update_learning_rate(iteration) 160 | deform.optimizer.step() 161 | gaussians.optimizer.zero_grad(set_to_none=True) 162 | deform.optimizer.zero_grad() 163 | deform.update_learning_rate(iteration) 164 | 165 | print("Best PSNR = {} in Iteration {}".format(best_psnr, best_iteration)) 166 | 167 | 168 | def prepare_output_and_logger(args): 169 | if not args.model_path: 170 | if os.getenv('OAR_JOB_ID'): 171 | unique_str = os.getenv('OAR_JOB_ID') 172 | else: 173 | unique_str = str(uuid.uuid4()) 174 | args.model_path = os.path.join("./output/", unique_str[0:10]) 175 | 176 | # Set up output folder 177 | print("Output folder: {}".format(args.model_path)) 178 | os.makedirs(args.model_path, exist_ok=True) 179 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 180 
| cfg_log_f.write(str(Namespace(**vars(args)))) 181 | 182 | # Create Tensorboard writer 183 | tb_writer = None 184 | if TENSORBOARD_FOUND: 185 | tb_writer = SummaryWriter(args.model_path) 186 | else: 187 | print("Tensorboard not available: not logging progress") 188 | return tb_writer 189 | 190 | 191 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc, 192 | renderArgs, deform, load2gpu_on_the_fly, is_6dof=False): 193 | if tb_writer: 194 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 195 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 196 | tb_writer.add_scalar('iter_time', elapsed, iteration) 197 | 198 | test_psnr = 0.0 199 | # Report test and samples of training set 200 | if iteration in testing_iterations: 201 | torch.cuda.empty_cache() 202 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()}, 203 | {'name': 'train', 204 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in 205 | range(5, 30, 5)]}) 206 | 207 | for config in validation_configs: 208 | if config['cameras'] and len(config['cameras']) > 0: 209 | images = torch.tensor([], device="cuda") 210 | gts = torch.tensor([], device="cuda") 211 | for idx, viewpoint in enumerate(config['cameras']): 212 | if load2gpu_on_the_fly: 213 | viewpoint.load2device() 214 | fid = viewpoint.fid 215 | xyz = scene.gaussians.get_xyz 216 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 217 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 218 | image = torch.clamp( 219 | renderFunc(viewpoint, scene.gaussians, *renderArgs, d_xyz, d_rotation, d_scaling, is_6dof)["render"], 220 | 0.0, 1.0) 221 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 222 | images = torch.cat((images, image.unsqueeze(0)), dim=0) 223 | gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) 224 | 225 | if load2gpu_on_the_fly: 226 | viewpoint.load2device('cpu') 227 | if tb_writer and (idx < 5): 228 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), 229 | image[None], global_step=iteration) 230 | if iteration == testing_iterations[0]: 231 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), 232 | gt_image[None], global_step=iteration) 233 | 234 | l1_test = l1_loss(images, gts) 235 | psnr_test = psnr(images, gts).mean() 236 | if config['name'] == 'test' or len(validation_configs[0]['cameras']) == 0: 237 | test_psnr = psnr_test 238 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 239 | if tb_writer: 240 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 241 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 242 | 243 | if tb_writer: 244 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 245 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 246 | torch.cuda.empty_cache() 247 | 248 | return test_psnr 249 | 250 | 251 | if __name__ == "__main__": 252 | # Set up command line argument parser 253 | parser = ArgumentParser(description="Training script parameters") 254 | lp = ModelParams(parser) 255 | op = OptimizationParams(parser) 256 | pp = PipelineParams(parser) 257 | parser.add_argument('--ip', type=str, default="127.0.0.1") 258 | parser.add_argument('--port', type=int, default=6009) 
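    # --ip/--port only matter for the optional network viewer; network_gui.init() is
    # left commented out near the bottom of this file, so they are unused by default.
    # Example invocation (assuming the usual -s/-m/--eval flags registered above by the
    # ModelParams/OptimizationParams argument groups):
    #   python train.py -s data/D-NeRF/hook -m output/hook --eval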
259 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 260 | parser.add_argument("--test_iterations", nargs="+", type=int, 261 | default=[5000, 6000, 7_000] + list(range(10000, 40001, 1000))) 262 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 20_000, 30_000, 40000]) 263 | parser.add_argument("--quiet", action="store_true") 264 | args = parser.parse_args(sys.argv[1:]) 265 | args.save_iterations.append(args.iterations) 266 | 267 | print("Optimizing " + args.model_path) 268 | 269 | # Initialize system state (RNG) 270 | safe_state(args.quiet) 271 | 272 | # Start GUI server, configure and run training 273 | # network_gui.init(args.ip, args.port) 274 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 275 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations) 276 | 277 | # All done 278 | print("\nTraining complete.") 279 | -------------------------------------------------------------------------------- /train_gui.py: -------------------------------------------------------------------------------- 1 | 7 # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import time 14 | import torch 15 | from random import randint 16 | from utils.loss_utils import l1_loss, ssim, kl_divergence 17 | from gaussian_renderer import render, network_gui 18 | import sys 19 | from scene import Scene, GaussianModel, DeformModel 20 | from utils.general_utils import safe_state, get_linear_noise_func 21 | import uuid 22 | import tqdm 23 | from utils.image_utils import psnr 24 | from argparse import ArgumentParser, Namespace 25 | from arguments import ModelParams, PipelineParams, OptimizationParams 26 | from train import training_report 27 | import math 28 | from utils.gui_utils import orbit_camera, OrbitCamera 29 | import numpy as np 30 | import dearpygui.dearpygui as dpg 31 | 32 | 33 | try: 34 | from torch.utils.tensorboard import SummaryWriter 35 | 36 | TENSORBOARD_FOUND = True 37 | except ImportError: 38 | TENSORBOARD_FOUND = False 39 | 40 | 41 | def getProjectionMatrix(znear, zfar, fovX, fovY): 42 | tanHalfFovY = math.tan((fovY / 2)) 43 | tanHalfFovX = math.tan((fovX / 2)) 44 | 45 | P = torch.zeros(4, 4) 46 | 47 | z_sign = 1.0 48 | 49 | P[0, 0] = 1 / tanHalfFovX 50 | P[1, 1] = 1 / tanHalfFovY 51 | P[3, 2] = z_sign 52 | P[2, 2] = z_sign * zfar / (zfar - znear) 53 | P[2, 3] = -(zfar * znear) / (zfar - znear) 54 | return P 55 | 56 | 57 | class MiniCam: 58 | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, fid): 59 | # c2w (pose) should be in NeRF convention. 60 | 61 | self.image_width = width 62 | self.image_height = height 63 | self.FoVy = fovy 64 | self.FoVx = fovx 65 | self.znear = znear 66 | self.zfar = zfar 67 | self.fid = fid 68 | 69 | w2c = np.linalg.inv(c2w) 70 | 71 | # rectify... 
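        # The incoming c2w follows the NeRF/OpenGL convention (see the note above); flipping
        # the y/z rows of the rotation and negating the translation below converts the
        # world-to-view matrix to the camera convention expected by the Gaussian rasterizer.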
72 | w2c[1:3, :3] *= -1 73 | w2c[:3, 3] *= -1 74 | 75 | self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda() 76 | self.projection_matrix = ( 77 | getProjectionMatrix( 78 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 79 | ) 80 | .transpose(0, 1) 81 | .cuda() 82 | ) 83 | self.full_proj_transform = self.world_view_transform @ self.projection_matrix 84 | self.camera_center = -torch.tensor(c2w[:3, 3]).cuda() 85 | 86 | 87 | class GUI: 88 | def __init__(self, args, dataset, opt, pipe, testing_iterations, saving_iterations) -> None: 89 | self.dataset = dataset 90 | self.args = args 91 | self.opt = opt 92 | self.pipe = pipe 93 | self.testing_iterations = testing_iterations 94 | self.saving_iterations = saving_iterations 95 | 96 | self.tb_writer = prepare_output_and_logger(dataset) 97 | self.gaussians = GaussianModel(dataset.sh_degree) 98 | self.deform = DeformModel(is_blender=dataset.is_blender, is_6dof=dataset.is_6dof) 99 | self.deform.train_setting(opt) 100 | 101 | self.scene = Scene(dataset, self.gaussians) 102 | self.gaussians.training_setup(opt) 103 | 104 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 105 | self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 106 | 107 | self.iter_start = torch.cuda.Event(enable_timing=True) 108 | self.iter_end = torch.cuda.Event(enable_timing=True) 109 | self.iteration = 1 110 | 111 | self.viewpoint_stack = None 112 | self.ema_loss_for_log = 0.0 113 | self.best_psnr = 0.0 114 | self.best_iteration = 0 115 | self.progress_bar = tqdm.tqdm(range(opt.iterations), desc="Training progress") 116 | self.smooth_term = get_linear_noise_func(lr_init=0.1, lr_final=1e-15, lr_delay_mult=0.01, max_steps=20000) 117 | 118 | # For UI 119 | self.visualization_mode = 'RGB' 120 | 121 | self.gui = args.gui # enable gui 122 | self.W = args.W 123 | self.H = args.H 124 | self.cam = OrbitCamera(args.W, args.H, r=args.radius, fovy=args.fovy) 125 | 126 | self.mode = "render" 127 | self.seed = "random" 128 | self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32) 129 | self.training = False 130 | 131 | if self.gui: 132 | dpg.create_context() 133 | self.register_dpg() 134 | self.test_step() 135 | 136 | def __del__(self): 137 | if self.gui: 138 | dpg.destroy_context() 139 | 140 | def register_dpg(self): 141 | ### register texture 142 | with dpg.texture_registry(show=False): 143 | dpg.add_raw_texture( 144 | self.W, 145 | self.H, 146 | self.buffer_image, 147 | format=dpg.mvFormat_Float_rgb, 148 | tag="_texture", 149 | ) 150 | 151 | ### register window 152 | # the rendered image, as the primary window 153 | with dpg.window( 154 | tag="_primary_window", 155 | width=self.W, 156 | height=self.H, 157 | pos=[0, 0], 158 | no_move=True, 159 | no_title_bar=True, 160 | no_scrollbar=True, 161 | ): 162 | # add the texture 163 | dpg.add_image("_texture") 164 | 165 | # dpg.set_primary_window("_primary_window", True) 166 | 167 | # control window 168 | with dpg.window( 169 | label="Control", 170 | tag="_control_window", 171 | width=600, 172 | height=self.H, 173 | pos=[self.W, 0], 174 | no_move=True, 175 | no_title_bar=True, 176 | ): 177 | # button theme 178 | with dpg.theme() as theme_button: 179 | with dpg.theme_component(dpg.mvButton): 180 | dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) 181 | dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) 182 | dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) 183 | dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) 184 | 
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) 185 | 186 | # timer stuff 187 | with dpg.group(horizontal=True): 188 | dpg.add_text("Infer time: ") 189 | dpg.add_text("no data", tag="_log_infer_time") 190 | 191 | def callback_setattr(sender, app_data, user_data): 192 | setattr(self, user_data, app_data) 193 | 194 | # init stuff 195 | with dpg.collapsing_header(label="Initialize", default_open=True): 196 | 197 | # seed stuff 198 | def callback_set_seed(sender, app_data): 199 | self.seed = app_data 200 | self.seed_everything() 201 | 202 | dpg.add_input_text( 203 | label="seed", 204 | default_value=self.seed, 205 | on_enter=True, 206 | callback=callback_set_seed, 207 | ) 208 | 209 | # input stuff 210 | def callback_select_input(sender, app_data): 211 | # only one item 212 | for k, v in app_data["selections"].items(): 213 | dpg.set_value("_log_input", k) 214 | self.load_input(v) 215 | 216 | self.need_update = True 217 | 218 | with dpg.file_dialog( 219 | directory_selector=False, 220 | show=False, 221 | callback=callback_select_input, 222 | file_count=1, 223 | tag="file_dialog_tag", 224 | width=700, 225 | height=400, 226 | ): 227 | dpg.add_file_extension("Images{.jpg,.jpeg,.png}") 228 | 229 | with dpg.group(horizontal=True): 230 | dpg.add_button( 231 | label="input", 232 | callback=lambda: dpg.show_item("file_dialog_tag"), 233 | ) 234 | dpg.add_text("", tag="_log_input") 235 | 236 | # save current model 237 | with dpg.group(horizontal=True): 238 | dpg.add_text("Visualization: ") 239 | 240 | def callback_vismode(sender, app_data, user_data): 241 | self.visualization_mode = user_data 242 | if user_data == 'Node': 243 | self.node_vis_fea = True if not hasattr(self, 'node_vis_fea') else not self.node_vis_fea 244 | print("Visualize node features" if self.node_vis_fea else "Visualize node importance") 245 | if self.node_vis_fea or True: 246 | from motion import visualize_featuremap 247 | if True: #self.renderer.gaussians.motion_model.soft_edge: 248 | if hasattr(self.renderer.gaussians.motion_model, 'nodes_fea'): 249 | node_rgb = visualize_featuremap(self.renderer.gaussians.motion_model.nodes_fea.detach().cpu().numpy()) 250 | self.node_rgb = torch.from_numpy(node_rgb).cuda() 251 | else: 252 | self.node_rgb = None 253 | else: 254 | self.node_rgb = None 255 | else: 256 | node_imp = self.renderer.gaussians.motion_model.cal_node_importance(x=self.renderer.gaussians.get_xyz) 257 | node_imp = (node_imp - node_imp.min()) / (node_imp.max() - node_imp.min()) 258 | node_rgb = torch.zeros([node_imp.shape[0], 3], dtype=torch.float32).cuda() 259 | node_rgb[..., 0] = node_imp 260 | node_rgb[..., -1] = 1 - node_imp 261 | self.node_rgb = node_rgb 262 | 263 | dpg.add_button( 264 | label="RGB", 265 | tag="_button_vis_rgb", 266 | callback=callback_vismode, 267 | user_data='RGB', 268 | ) 269 | dpg.bind_item_theme("_button_vis_rgb", theme_button) 270 | 271 | dpg.add_button( 272 | label="UV_COOR", 273 | tag="_button_vis_uv", 274 | callback=callback_vismode, 275 | user_data='UV_COOR', 276 | ) 277 | dpg.bind_item_theme("_button_vis_uv", theme_button) 278 | dpg.add_button( 279 | label="MotionMask", 280 | tag="_button_vis_motion_mask", 281 | callback=callback_vismode, 282 | user_data='MotionMask', 283 | ) 284 | dpg.bind_item_theme("_button_vis_motion_mask", theme_button) 285 | 286 | dpg.add_button( 287 | label="Node", 288 | tag="_button_vis_node", 289 | callback=callback_vismode, 290 | user_data='Node', 291 | ) 292 | dpg.bind_item_theme("_button_vis_node", theme_button) 293 | 294 | def callback_use_const_var(sender, 
app_data): 295 | self.use_const_var = not self.use_const_var 296 | dpg.add_button( 297 | label="Const Var", 298 | tag="_button_const_var", 299 | callback=callback_use_const_var 300 | ) 301 | dpg.bind_item_theme("_button_const_var", theme_button) 302 | 303 | with dpg.group(horizontal=True): 304 | dpg.add_text("Scale Const: ") 305 | def callback_vis_scale_const(sender): 306 | self.vis_scale_const = 10 ** dpg.get_value(sender) 307 | self.need_update = True 308 | dpg.add_slider_float( 309 | label="Log vis_scale_const (For debugging)", 310 | default_value=-3, 311 | max_value=-.5, 312 | min_value=-5, 313 | callback=callback_vis_scale_const, 314 | ) 315 | 316 | # save current model 317 | with dpg.group(horizontal=True): 318 | dpg.add_text("Temporal Speed: ") 319 | self.video_speed = 1. 320 | def callback_speed_control(sender): 321 | self.video_speed = dpg.get_value(sender) 322 | self.need_update = True 323 | dpg.add_slider_float( 324 | label="Play speed", 325 | default_value=1., 326 | max_value=2., 327 | min_value=0.0, 328 | callback=callback_speed_control, 329 | ) 330 | 331 | # save current model 332 | with dpg.group(horizontal=True): 333 | dpg.add_text("Save: ") 334 | 335 | def callback_save(sender, app_data, user_data): 336 | self.save_model(mode=user_data) 337 | 338 | dpg.add_button( 339 | label="model", 340 | tag="_button_save_model", 341 | callback=callback_save, 342 | user_data='model', 343 | ) 344 | dpg.bind_item_theme("_button_save_model", theme_button) 345 | 346 | dpg.add_button( 347 | label="geo", 348 | tag="_button_save_mesh", 349 | callback=callback_save, 350 | user_data='geo', 351 | ) 352 | dpg.bind_item_theme("_button_save_mesh", theme_button) 353 | 354 | dpg.add_button( 355 | label="geo+tex", 356 | tag="_button_save_mesh_with_tex", 357 | callback=callback_save, 358 | user_data='geo+tex', 359 | ) 360 | dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button) 361 | 362 | dpg.add_button( 363 | label="pcl", 364 | tag="_button_save_pcl", 365 | callback=callback_save, 366 | user_data='pcl', 367 | ) 368 | dpg.bind_item_theme("_button_save_pcl", theme_button) 369 | 370 | def call_back_save_train(sender, app_data, user_data): 371 | self.render_all_train_data() 372 | dpg.add_button( 373 | label="save_train", 374 | tag="_button_save_train", 375 | callback=call_back_save_train, 376 | ) 377 | 378 | # training stuff 379 | with dpg.collapsing_header(label="Train", default_open=True): 380 | # lr and train button 381 | with dpg.group(horizontal=True): 382 | dpg.add_text("Train: ") 383 | 384 | def callback_train(sender, app_data): 385 | if self.training: 386 | self.training = False 387 | dpg.configure_item("_button_train", label="start") 388 | else: 389 | # self.prepare_train() 390 | self.training = True 391 | dpg.configure_item("_button_train", label="stop") 392 | 393 | dpg.add_button( 394 | label="start", tag="_button_train", callback=callback_train 395 | ) 396 | dpg.bind_item_theme("_button_train", theme_button) 397 | 398 | with dpg.group(horizontal=True): 399 | dpg.add_text("", tag="_log_train_psnr") 400 | dpg.add_text("", tag="_log_train_log") 401 | 402 | # rendering options 403 | with dpg.collapsing_header(label="Rendering", default_open=True): 404 | # mode combo 405 | def callback_change_mode(sender, app_data): 406 | self.mode = app_data 407 | self.need_update = True 408 | 409 | dpg.add_combo( 410 | ("render", "depth"), 411 | label="mode", 412 | default_value=self.mode, 413 | callback=callback_change_mode, 414 | ) 415 | 416 | # fov slider 417 | def callback_set_fovy(sender, app_data): 
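                    # app_data is the slider value in degrees; store it in radians and
                    # flag the viewport for a redraw.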
418 | self.cam.fovy = np.deg2rad(app_data) 419 | self.need_update = True 420 | 421 | dpg.add_slider_int( 422 | label="FoV (vertical)", 423 | min_value=1, 424 | max_value=120, 425 | format="%d deg", 426 | default_value=np.rad2deg(self.cam.fovy), 427 | callback=callback_set_fovy, 428 | ) 429 | 430 | ### register camera handler 431 | 432 | def callback_camera_drag_rotate_or_draw_mask(sender, app_data): 433 | if not dpg.is_item_focused("_primary_window"): 434 | return 435 | 436 | dx = app_data[1] 437 | dy = app_data[2] 438 | 439 | self.cam.orbit(dx, dy) 440 | self.need_update = True 441 | 442 | def callback_camera_wheel_scale(sender, app_data): 443 | if not dpg.is_item_focused("_primary_window"): 444 | return 445 | 446 | delta = app_data 447 | 448 | self.cam.scale(delta) 449 | self.need_update = True 450 | 451 | def callback_camera_drag_pan(sender, app_data): 452 | if not dpg.is_item_focused("_primary_window"): 453 | return 454 | 455 | dx = app_data[1] 456 | dy = app_data[2] 457 | 458 | self.cam.pan(dx, dy) 459 | self.need_update = True 460 | 461 | with dpg.handler_registry(): 462 | # for camera moving 463 | dpg.add_mouse_drag_handler( 464 | button=dpg.mvMouseButton_Left, 465 | callback=callback_camera_drag_rotate_or_draw_mask, 466 | ) 467 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) 468 | dpg.add_mouse_drag_handler( 469 | button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan 470 | ) 471 | 472 | dpg.create_viewport( 473 | title="Deformable-Gaussian", 474 | width=self.W + 600, 475 | height=self.H + (45 if os.name == "nt" else 0), 476 | resizable=False, 477 | ) 478 | 479 | ### global theme 480 | with dpg.theme() as theme_no_padding: 481 | with dpg.theme_component(dpg.mvAll): 482 | # set all padding to 0 to avoid scroll bar 483 | dpg.add_theme_style( 484 | dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core 485 | ) 486 | dpg.add_theme_style( 487 | dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core 488 | ) 489 | dpg.add_theme_style( 490 | dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core 491 | ) 492 | 493 | dpg.bind_item_theme("_primary_window", theme_no_padding) 494 | 495 | dpg.setup_dearpygui() 496 | 497 | ### register a larger font 498 | # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf 499 | if os.path.exists("LXGWWenKai-Regular.ttf"): 500 | with dpg.font_registry(): 501 | with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font: 502 | dpg.bind_font(default_font) 503 | 504 | # dpg.show_metrics() 505 | 506 | dpg.show_viewport() 507 | 508 | def render(self): 509 | assert self.gui 510 | while dpg.is_dearpygui_running(): 511 | # update texture every frame 512 | if self.training: 513 | self.train_step() 514 | self.test_step() 515 | dpg.render_dearpygui_frame() 516 | 517 | # no gui mode 518 | def train(self, iters=5000): 519 | if iters > 0: 520 | for i in tqdm.trange(iters): 521 | self.train_step() 522 | 523 | 524 | def train_step(self): 525 | if network_gui.conn == None: 526 | network_gui.try_connect() 527 | while network_gui.conn != None: 528 | try: 529 | net_image_bytes = None 530 | custom_cam, do_training, self.pipe.do_shs_python, self.pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive() 531 | if custom_cam != None: 532 | net_image = render(custom_cam, self.gaussians, self.pipe, self.background, scaling_modifer)["render"] 533 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 534 | 
0).contiguous().cpu().numpy()) 535 | network_gui.send(net_image_bytes, self.dataset.source_path) 536 | if do_training and ((self.iteration < int(self.opt.iterations)) or not keep_alive): 537 | break 538 | except Exception as e: 539 | network_gui.conn = None 540 | 541 | self.iter_start.record() 542 | 543 | # Every 1000 its we increase the levels of SH up to a maximum degree 544 | if self.iteration % 1000 == 0: 545 | self.gaussians.oneupSHdegree() 546 | 547 | # Pick a random Camera 548 | if not self.viewpoint_stack: 549 | self.viewpoint_stack = self.scene.getTrainCameras().copy() 550 | 551 | total_frame = len(self.viewpoint_stack) 552 | time_interval = 1 / total_frame 553 | 554 | viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1)) 555 | if self.dataset.load2gpu_on_the_fly: 556 | viewpoint_cam.load2device() 557 | fid = viewpoint_cam.fid 558 | 559 | if self.iteration < self.opt.warm_up: 560 | d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0 561 | else: 562 | N = self.gaussians.get_xyz.shape[0] 563 | time_input = fid.unsqueeze(0).expand(N, -1) 564 | ast_noise = 0 if self.dataset.is_blender else torch.randn(1, 1, device='cuda').expand(N, -1) * time_interval * self.smooth_term(self.iteration) 565 | d_xyz, d_rotation, d_scaling = self.deform.step(self.gaussians.get_xyz.detach(), time_input + ast_noise) 566 | 567 | # Render 568 | render_pkg_re = render(viewpoint_cam, self.gaussians, self.pipe, self.background, d_xyz, d_rotation, d_scaling, self.dataset.is_6dof) 569 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg_re["render"], render_pkg_re[ 570 | "viewspace_points"], render_pkg_re["visibility_filter"], render_pkg_re["radii"] 571 | # depth = render_pkg_re["depth"] 572 | 573 | # Loss 574 | gt_image = viewpoint_cam.original_image.cuda() 575 | Ll1 = l1_loss(image, gt_image) 576 | loss = (1.0 - self.opt.lambda_dssim) * Ll1 + self.opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 577 | loss.backward() 578 | 579 | self.iter_end.record() 580 | 581 | if self.dataset.load2gpu_on_the_fly: 582 | viewpoint_cam.load2device('cpu') 583 | 584 | with torch.no_grad(): 585 | # Progress bar 586 | self.ema_loss_for_log = 0.4 * loss.item() + 0.6 * self.ema_loss_for_log 587 | if self.iteration % 10 == 0: 588 | self.progress_bar.set_postfix({"Loss": f"{self.ema_loss_for_log:.{7}f}"}) 589 | self.progress_bar.update(10) 590 | if self.iteration == self.opt.iterations: 591 | self.progress_bar.close() 592 | 593 | # Keep track of max radii in image-space for pruning 594 | self.gaussians.max_radii2D[visibility_filter] = torch.max(self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 595 | 596 | # Log and save 597 | cur_psnr = training_report(self.tb_writer, self.iteration, Ll1, loss, l1_loss, self.iter_start.elapsed_time(self.iter_end), self.testing_iterations, self.scene, render, (self.pipe, self.background), self.deform, self.dataset.load2gpu_on_the_fly, self.dataset.is_6dof) 598 | if self.iteration in self.testing_iterations: 599 | if cur_psnr.item() > self.best_psnr: 600 | self.best_psnr = cur_psnr.item() 601 | self.best_iteration = self.iteration 602 | 603 | if self.iteration in self.saving_iterations: 604 | print("\n[ITER {}] Saving Gaussians".format(self.iteration)) 605 | self.scene.save(self.iteration) 606 | self.deform.save_weights(args.model_path, self.iteration) 607 | 608 | # Densification 609 | if self.iteration < self.opt.densify_until_iter: 610 | self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 611 | 612 | if 
self.iteration > self.opt.densify_from_iter and self.iteration % self.opt.densification_interval == 0: 613 | size_threshold = 20 if self.iteration > self.opt.opacity_reset_interval else None 614 | self.gaussians.densify_and_prune(self.opt.densify_grad_threshold, 0.005, self.scene.cameras_extent, size_threshold) 615 | 616 | if self.iteration % self.opt.opacity_reset_interval == 0 or ( 617 | self.dataset.white_background and self.iteration == self.opt.densify_from_iter): 618 | self.gaussians.reset_opacity() 619 | 620 | # Optimizer step 621 | if self.iteration < self.opt.iterations: 622 | self.gaussians.optimizer.step() 623 | self.gaussians.update_learning_rate(self.iteration) 624 | self.gaussians.optimizer.zero_grad(set_to_none=True) 625 | self.deform.optimizer.step() 626 | self.deform.optimizer.zero_grad() 627 | self.deform.update_learning_rate(self.iteration) 628 | 629 | if self.gui: 630 | dpg.set_value( 631 | "_log_train_psnr", 632 | "Best PSNR = {} in Iteration {}".format(self.best_psnr, self.best_iteration) 633 | ) 634 | else: 635 | print("Best PSNR = {} in Iteration {}".format(self.best_psnr, self.best_iteration)) 636 | self.iteration += 1 637 | 638 | if self.gui: 639 | dpg.set_value( 640 | "_log_train_log", 641 | f"step = {self.iteration: 5d} loss = {loss.item():.4f}", 642 | ) 643 | 644 | @torch.no_grad() 645 | def test_step(self): 646 | 647 | starter = torch.cuda.Event(enable_timing=True) 648 | ender = torch.cuda.Event(enable_timing=True) 649 | starter.record() 650 | 651 | if not hasattr(self, 't0'): 652 | self.t0 = time.time() 653 | self.fps_of_fid = 10 654 | 655 | cur_cam = MiniCam( 656 | self.cam.pose, 657 | self.W, 658 | self.H, 659 | self.cam.fovy, 660 | self.cam.fovx, 661 | self.cam.near, 662 | self.cam.far, 663 | fid=torch.remainder(torch.tensor((time.time()-self.t0) * self.fps_of_fid).float().cuda() / len(self.scene.getTrainCameras()), 1.) 664 | ) 665 | fid = cur_cam.fid 666 | 667 | if self.iteration < self.opt.warm_up: 668 | d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0 669 | else: 670 | N = self.gaussians.get_xyz.shape[0] 671 | time_input = fid.unsqueeze(0).expand(N, -1) 672 | d_xyz, d_rotation, d_scaling = self.deform.step(self.gaussians.get_xyz.detach(), time_input) 673 | 674 | out = render(viewpoint_camera=cur_cam, pc=self.gaussians, pipe=self.pipe, bg_color=self.background, d_xyz=d_xyz, d_rotation=d_rotation, d_scaling=d_scaling, is_6dof=self.dataset.is_6dof) 675 | 676 | buffer_image = out[self.mode] # [3, H, W] 677 | 678 | if self.mode in ['depth', 'alpha']: 679 | buffer_image = buffer_image.repeat(3, 1, 1) 680 | if self.mode == 'depth': 681 | buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20) 682 | 683 | buffer_image = torch.nn.functional.interpolate( 684 | buffer_image.unsqueeze(0), 685 | size=(self.H, self.W), 686 | mode="bilinear", 687 | align_corners=False, 688 | ).squeeze(0) 689 | 690 | self.buffer_image = ( 691 | buffer_image.permute(1, 2, 0) 692 | .contiguous() 693 | .clamp(0, 1) 694 | .contiguous() 695 | .detach() 696 | .cpu() 697 | .numpy() 698 | ) 699 | 700 | self.need_update = True 701 | 702 | ender.record() 703 | torch.cuda.synchronize() 704 | t = starter.elapsed_time(ender) 705 | 706 | if self.gui: 707 | dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS FID: {fid.item()})") 708 | dpg.set_value( 709 | "_texture", self.buffer_image 710 | ) # buffer must be contiguous, else seg fault! 
711 | 712 | # no gui mode 713 | def train(self, iters=5000): 714 | if iters > 0: 715 | for i in tqdm.trange(iters): 716 | self.train_step() 717 | 718 | def prepare_output_and_logger(args): 719 | if not args.model_path: 720 | if os.getenv('OAR_JOB_ID'): 721 | unique_str = os.getenv('OAR_JOB_ID') 722 | else: 723 | unique_str = str(uuid.uuid4()) 724 | args.model_path = os.path.join("./output/", unique_str[0:10]) 725 | 726 | # Set up output folder 727 | print("Output folder: {}".format(args.model_path)) 728 | os.makedirs(args.model_path, exist_ok=True) 729 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 730 | cfg_log_f.write(str(Namespace(**vars(args)))) 731 | 732 | # Create Tensorboard writer 733 | tb_writer = None 734 | if TENSORBOARD_FOUND: 735 | tb_writer = SummaryWriter(args.model_path) 736 | else: 737 | print("Tensorboard not available: not logging progress") 738 | return tb_writer 739 | 740 | 741 | if __name__ == "__main__": 742 | # Set up command line argument parser 743 | parser = ArgumentParser(description="Training script parameters") 744 | lp = ModelParams(parser) 745 | op = OptimizationParams(parser) 746 | pp = PipelineParams(parser) 747 | 748 | parser.add_argument('--gui', action='store_false', help="start a GUI") 749 | parser.add_argument('--W', type=int, default=800, help="GUI width") 750 | parser.add_argument('--H', type=int, default=800, help="GUI height") 751 | parser.add_argument('--elevation', type=float, default=0, help="default GUI camera elevation") 752 | parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") 753 | parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") 754 | 755 | parser.add_argument('--ip', type=str, default="127.0.0.1") 756 | parser.add_argument('--port', type=int, default=6009) 757 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 758 | parser.add_argument("--test_iterations", nargs="+", type=int, 759 | default=[5000, 6000, 7_000] + list(range(10000, 40001, 1000))) 760 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 20_000, 30_000, 40000]) 761 | parser.add_argument("--quiet", action="store_true") 762 | args = parser.parse_args(sys.argv[1:]) 763 | args.save_iterations.append(args.iterations) 764 | 765 | print("Optimizing " + args.model_path) 766 | 767 | # Initialize system state (RNG) 768 | safe_state(args.quiet) 769 | 770 | # Start GUI server, configure and run training 771 | # network_gui.init(args.ip, args.port) 772 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 773 | gui = GUI(args=args, dataset=lp.extract(args), opt=op.extract(args), pipe=pp.extract(args),testing_iterations=args.test_iterations, saving_iterations=args.save_iterations) 774 | 775 | if args.gui: 776 | gui.render() 777 | # else: 778 | # gui.train(args.iterations) 779 | 780 | # All done 781 | print("\nTraining complete.") 782 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch, ArrayToTorch 15 | from utils.graphics_utils import fov2focal 16 | import json 17 | 18 | WARNED = False 19 | 20 | 21 | def loadCam(args, id, cam_info, resolution_scale): 22 | orig_w, orig_h = cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 26 | orig_h / (resolution_scale * args.resolution)) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 1600: 30 | global WARNED 31 | if not WARNED: 32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 34 | WARNED = True 35 | global_down = orig_w / 1600 36 | else: 37 | global_down = 1 38 | else: 39 | global_down = orig_w / args.resolution 40 | 41 | scale = float(global_down) * float(resolution_scale) 42 | resolution = (int(orig_w / scale), int(orig_h / scale)) 43 | 44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 45 | 46 | gt_image = resized_image_rgb[:3, ...] 47 | loaded_mask = None 48 | 49 | if resized_image_rgb.shape[1] == 4: 50 | loaded_mask = resized_image_rgb[3:4, ...] 51 | 52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 54 | image=gt_image, gt_alpha_mask=loaded_mask, 55 | image_name=cam_info.image_name, uid=id, 56 | data_device=args.data_device if not args.load2gpu_on_the_fly else 'cpu', fid=cam_info.fid, 57 | depth=cam_info.depth) 58 | 59 | 60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 61 | camera_list = [] 62 | 63 | for id, c in enumerate(cam_infos): 64 | camera_list.append(loadCam(args, id, c, resolution_scale)) 65 | 66 | return camera_list 67 | 68 | 69 | def camera_to_JSON(id, camera: Camera): 70 | Rt = np.zeros((4, 4)) 71 | Rt[:3, :3] = camera.R.transpose() 72 | Rt[:3, 3] = camera.T 73 | Rt[3, 3] = 1.0 74 | 75 | W2C = np.linalg.inv(Rt) 76 | pos = W2C[:3, 3] 77 | rot = W2C[:3, :3] 78 | serializable_array_2d = [x.tolist() for x in rot] 79 | camera_entry = { 80 | 'id': id, 81 | 'img_name': camera.image_name, 82 | 'width': camera.width, 83 | 'height': camera.height, 84 | 'position': pos.tolist(), 85 | 'rotation': serializable_array_2d, 86 | 'fy': fov2focal(camera.FovY, camera.height), 87 | 'fx': fov2focal(camera.FovX, camera.width) 88 | } 89 | return camera_entry 90 | 91 | 92 | def camera_nerfies_from_JSON(path, scale): 93 | """Loads a JSON camera into memory.""" 94 | with open(path, 'r') as fp: 95 | camera_json = json.load(fp) 96 | 97 | # Fix old camera JSON. 
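    # (older Nerfies-style exports store the tangential distortion under the key 'tangential')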
98 | if 'tangential' in camera_json: 99 | camera_json['tangential_distortion'] = camera_json['tangential'] 100 | 101 | return dict( 102 | orientation=np.array(camera_json['orientation']), 103 | position=np.array(camera_json['position']), 104 | focal_length=camera_json['focal_length'] * scale, 105 | principal_point=np.array(camera_json['principal_point']) * scale, 106 | skew=camera_json['skew'], 107 | pixel_aspect_ratio=camera_json['pixel_aspect_ratio'], 108 | radial_distortion=np.array(camera_json['radial_distortion']), 109 | tangential_distortion=np.array(camera_json['tangential_distortion']), 110 | image_size=np.array((int(round(camera_json['image_size'][0] * scale)), 111 | int(round(camera_json['image_size'][1] * scale)))), 112 | ) 113 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | 32 | def ArrayToTorch(array, resolution): 33 | # resized_image = np.resize(array, resolution) 34 | resized_image_torch = torch.from_numpy(array) 35 | 36 | if len(resized_image_torch.shape) == 3: 37 | return resized_image_torch.permute(2, 0, 1) 38 | else: 39 | return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1) 40 | 41 | 42 | def get_expon_lr_func( 43 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 44 | ): 45 | """ 46 | Copied from Plenoxels 47 | 48 | Continuous learning rate decay function. Adapted from JaxNeRF 49 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 50 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 51 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 52 | function of lr_delay_mult, such that the initial learning rate is 53 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 54 | to the normal learning rate when steps>lr_delay_steps. 55 | :param conf: config subtree 'lr' or similar 56 | :param max_steps: int, the number of steps during optimization. 57 | :return HoF which takes step as input 58 | """ 59 | 60 | def helper(step): 61 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 62 | # Disable this parameter 63 | return 0.0 64 | if lr_delay_steps > 0: 65 | # A kind of reverse cosine decay. 
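            # delay_rate ramps from lr_delay_mult at step 0 up to 1 once step >= lr_delay_steps;
            # the base schedule below is log-linear between lr_init and lr_final.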
66 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 67 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 68 | ) 69 | else: 70 | delay_rate = 1.0 71 | t = np.clip(step / max_steps, 0, 1) 72 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 73 | return delay_rate * log_lerp 74 | 75 | return helper 76 | 77 | 78 | def get_linear_noise_func( 79 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 80 | ): 81 | """ 82 | Copied from Plenoxels 83 | 84 | Continuous learning rate decay function. Adapted from JaxNeRF 85 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 86 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 87 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 88 | function of lr_delay_mult, such that the initial learning rate is 89 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 90 | to the normal learning rate when steps>lr_delay_steps. 91 | :param conf: config subtree 'lr' or similar 92 | :param max_steps: int, the number of steps during optimization. 93 | :return HoF which takes step as input 94 | """ 95 | 96 | def helper(step): 97 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 98 | # Disable this parameter 99 | return 0.0 100 | if lr_delay_steps > 0: 101 | # A kind of reverse cosine decay. 102 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 103 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 104 | ) 105 | else: 106 | delay_rate = 1.0 107 | t = np.clip(step / max_steps, 0, 1) 108 | log_lerp = lr_init * (1 - t) + lr_final * t 109 | return delay_rate * log_lerp 110 | 111 | return helper 112 | 113 | 114 | def strip_lowerdiag(L): 115 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 116 | 117 | uncertainty[:, 0] = L[:, 0, 0] 118 | uncertainty[:, 1] = L[:, 0, 1] 119 | uncertainty[:, 2] = L[:, 0, 2] 120 | uncertainty[:, 3] = L[:, 1, 1] 121 | uncertainty[:, 4] = L[:, 1, 2] 122 | uncertainty[:, 5] = L[:, 2, 2] 123 | return uncertainty 124 | 125 | 126 | def strip_symmetric(sym): 127 | return strip_lowerdiag(sym) 128 | 129 | 130 | def build_rotation(r): 131 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 132 | 133 | q = r / norm[:, None] 134 | 135 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 136 | 137 | r = q[:, 0] 138 | x = q[:, 1] 139 | y = q[:, 2] 140 | z = q[:, 3] 141 | 142 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 143 | R[:, 0, 1] = 2 * (x * y - r * z) 144 | R[:, 0, 2] = 2 * (x * z + r * y) 145 | R[:, 1, 0] = 2 * (x * y + r * z) 146 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 147 | R[:, 1, 2] = 2 * (y * z - r * x) 148 | R[:, 2, 0] = 2 * (x * z - r * y) 149 | R[:, 2, 1] = 2 * (y * z + r * x) 150 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 151 | return R 152 | 153 | 154 | def build_scaling_rotation(s, r): 155 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 156 | R = build_rotation(r) 157 | 158 | L[:, 0, 0] = s[:, 0] 159 | L[:, 1, 1] = s[:, 1] 160 | L[:, 2, 2] = s[:, 2] 161 | 162 | L = R @ L 163 | return L 164 | 165 | 166 | def safe_state(silent): 167 | old_f = sys.stdout 168 | 169 | class F: 170 | def __init__(self, silent): 171 | self.silent = silent 172 | 173 | def write(self, x): 174 | if not self.silent: 175 | if x.endswith("\n"): 176 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 177 | else: 178 | old_f.write(x) 179 | 180 | def flush(self): 181 | 
old_f.flush() 182 | 183 | sys.stdout = F(silent) 184 | 185 | random.seed(0) 186 | np.random.seed(0) 187 | torch.manual_seed(0) 188 | torch.cuda.set_device(torch.device("cuda:0")) 189 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /utils/gui_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation as R 3 | 4 | import torch 5 | 6 | def dot(x, y): 7 | if isinstance(x, np.ndarray): 8 | return np.sum(x * y, -1, keepdims=True) 9 | else: 10 | return torch.sum(x * y, -1, keepdim=True) 11 | 12 | 13 | def length(x, eps=1e-20): 14 | if isinstance(x, np.ndarray): 15 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 16 | else: 17 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 18 | 19 | 20 | def safe_normalize(x, eps=1e-20): 21 | return x / length(x, eps) 22 | 23 | 24 | def look_at(campos, target, opengl=True): 25 | # campos: [N, 3], camera/eye position 26 | # target: [N, 3], object to look at 27 | # return: 
[N, 3, 3], rotation matrix 28 | if not opengl: 29 | # camera forward aligns with -z 30 | forward_vector = safe_normalize(target - campos) 31 | up_vector = np.array([0, 1, 0], dtype=np.float32) 32 | right_vector = safe_normalize(np.cross(forward_vector, up_vector)) 33 | up_vector = safe_normalize(np.cross(right_vector, forward_vector)) 34 | else: 35 | # camera forward aligns with +z 36 | forward_vector = safe_normalize(campos - target) 37 | up_vector = np.array([0, 1, 0], dtype=np.float32) 38 | right_vector = safe_normalize(np.cross(up_vector, forward_vector)) 39 | up_vector = safe_normalize(np.cross(forward_vector, right_vector)) 40 | R = np.stack([right_vector, up_vector, forward_vector], axis=1) 41 | return R 42 | 43 | 44 | # elevation & azimuth to pose (cam2world) matrix 45 | def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True): 46 | # radius: scalar 47 | # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90) 48 | # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90) 49 | # return: [4, 4], camera pose matrix 50 | if is_degree: 51 | elevation = np.deg2rad(elevation) 52 | azimuth = np.deg2rad(azimuth) 53 | x = radius * np.cos(elevation) * np.sin(azimuth) 54 | y = - radius * np.sin(elevation) 55 | z = radius * np.cos(elevation) * np.cos(azimuth) 56 | if target is None: 57 | target = np.zeros([3], dtype=np.float32) 58 | campos = np.array([x, y, z]) + target # [3] 59 | T = np.eye(4, dtype=np.float32) 60 | T[:3, :3] = look_at(campos, target, opengl) 61 | T[:3, 3] = campos 62 | return T 63 | 64 | 65 | class OrbitCamera: 66 | def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100): 67 | self.W = W 68 | self.H = H 69 | self.radius = r # camera distance from center 70 | self.fovy = np.deg2rad(fovy) # deg 2 rad 71 | self.near = near 72 | self.far = far 73 | self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point 74 | # self.rot = R.from_matrix(np.eye(3)) 75 | self.rot = R.from_matrix(np.array([[1., 0., 0.,], 76 | [0., 0., -1.], 77 | [0., 1., 0.]])) 78 | self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! 79 | self.side = np.array([1, 0, 0], dtype=np.float32) 80 | 81 | @property 82 | def fovx(self): 83 | return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H) 84 | 85 | @property 86 | def campos(self): 87 | return self.pose[:3, 3] 88 | 89 | # pose (c2w) 90 | @property 91 | def pose(self): 92 | # first move camera to radius 93 | res = np.eye(4, dtype=np.float32) 94 | res[2, 3] = self.radius # opengl convention... 
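        # c2w is composed below as rot @ translate(+z by radius), after which the
        # translation is offset by -self.center.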
95 | # rotate 96 | rot = np.eye(4, dtype=np.float32) 97 | rot[:3, :3] = self.rot.as_matrix() 98 | res = rot @ res 99 | # translate 100 | res[:3, 3] -= self.center 101 | return res 102 | 103 | # view (w2c) 104 | @property 105 | def view(self): 106 | return np.linalg.inv(self.pose) 107 | 108 | # projection (perspective) 109 | @property 110 | def perspective(self): 111 | y = np.tan(self.fovy / 2) 112 | aspect = self.W / self.H 113 | return np.array( 114 | [ 115 | [1 / (y * aspect), 0, 0, 0], 116 | [0, -1 / y, 0, 0], 117 | [ 118 | 0, 119 | 0, 120 | -(self.far + self.near) / (self.far - self.near), 121 | -(2 * self.far * self.near) / (self.far - self.near), 122 | ], 123 | [0, 0, -1, 0], 124 | ], 125 | dtype=np.float32, 126 | ) 127 | 128 | # intrinsics 129 | @property 130 | def intrinsics(self): 131 | focal = self.H / (2 * np.tan(self.fovy / 2)) 132 | return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) 133 | 134 | @property 135 | def mvp(self): 136 | return self.perspective @ np.linalg.inv(self.pose) # [4, 4] 137 | 138 | def orbit(self, dx, dy): 139 | # rotate along camera up/side axis! 140 | side = self.rot.as_matrix()[:3, 0] 141 | up = self.rot.as_matrix()[:3, 1] 142 | rotvec_x = up * np.radians(-0.05 * dx) 143 | rotvec_y = side * np.radians(-0.05 * dy) 144 | self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot 145 | 146 | def scale(self, delta): 147 | self.radius *= 1.1 ** (-delta) 148 | 149 | def pan(self, dx, dy, dz=0): 150 | # pan in camera coordinate system (careful on the sensitivity!) 151 | self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz]) 152 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | 15 | def mse(img1, img2): 16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 17 | 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def kl_divergence(rho, rho_hat): 23 | rho_hat = torch.mean(torch.sigmoid(rho_hat), 0) 24 | rho = torch.tensor([rho] * len(rho_hat)).cuda() 25 | return torch.mean( 26 | rho * torch.log(rho / (rho_hat + 1e-5)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + 1e-5))) 27 | 28 | 29 | def l2_loss(network_output, gt): 30 | return ((network_output - gt) ** 2).mean() 31 | 32 | 33 | def gaussian(window_size, sigma): 34 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 35 | return gauss / gauss.sum() 36 | 37 | 38 | def create_window(window_size, channel): 39 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 40 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 41 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 42 | return window 43 | 44 | 45 | def ssim(img1, img2, window_size=11, size_average=True): 46 | channel = img1.size(-3) 47 | window = create_window(window_size, channel) 48 | 49 | if img1.is_cuda: 50 | window = window.cuda(img1.get_device()) 51 | window = window.type_as(img1) 52 | 53 | return _ssim(img1, img2, window, window_size, channel, size_average) 54 | 55 | 56 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 59 | 60 | mu1_sq = mu1.pow(2) 61 | mu2_sq = mu2.pow(2) 62 | mu1_mu2 = mu1 * mu2 63 | 64 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 65 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 66 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 67 | 68 | C1 = 0.01 ** 2 69 | C2 = 0.03 ** 2 70 | 71 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.graphics_utils import fov2focal 4 | 5 | trans_t = lambda t: torch.Tensor([ 6 | [1, 0, 0, 0], 7 | [0, 1, 0, 0], 8 | [0, 0, 1, t], 9 | [0, 0, 0, 1]]).float() 10 | 11 | rot_phi = lambda phi: torch.Tensor([ 12 | [1, 0, 0, 0], 13 | [0, np.cos(phi), -np.sin(phi), 0], 14 | [0, np.sin(phi), np.cos(phi), 0], 15 | [0, 0, 0, 1]]).float() 16 | 17 | rot_theta = lambda th: torch.Tensor([ 18 | [np.cos(th), 0, -np.sin(th), 0], 19 | [0, 1, 0, 0], 20 | [np.sin(th), 0, np.cos(th), 0], 21 | [0, 0, 0, 1]]).float() 22 | 23 | 24 | def rodrigues_mat_to_rot(R): 25 | eps = 1e-16 26 | trc = np.trace(R) 27 | trc2 = (trc - 1.) / 2. 
28 | # sinacostrc2 = np.sqrt(1 - trc2 * trc2) 29 | s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]) 30 | if (1 - trc2 * trc2) >= eps: 31 | tHeta = np.arccos(trc2) 32 | tHetaf = tHeta / (2 * (np.sin(tHeta))) 33 | else: 34 | tHeta = np.real(np.arccos(trc2)) 35 | tHetaf = 0.5 / (1 - tHeta / 6) 36 | omega = tHetaf * s 37 | return omega 38 | 39 | 40 | def rodrigues_rot_to_mat(r): 41 | wx, wy, wz = r 42 | theta = np.sqrt(wx * wx + wy * wy + wz * wz) 43 | a = np.cos(theta) 44 | b = (1 - np.cos(theta)) / (theta * theta) 45 | c = np.sin(theta) / theta 46 | R = np.zeros([3, 3]) 47 | R[0, 0] = a + b * (wx * wx) 48 | R[0, 1] = b * wx * wy - c * wz 49 | R[0, 2] = b * wx * wz + c * wy 50 | R[1, 0] = b * wx * wy + c * wz 51 | R[1, 1] = a + b * (wy * wy) 52 | R[1, 2] = b * wy * wz - c * wx 53 | R[2, 0] = b * wx * wz - c * wy 54 | R[2, 1] = b * wz * wy + c * wx 55 | R[2, 2] = a + b * (wz * wz) 56 | return R 57 | 58 | 59 | def pose_spherical(theta, phi, radius): 60 | c2w = trans_t(radius) 61 | c2w = rot_phi(phi / 180. * np.pi) @ c2w 62 | c2w = rot_theta(theta / 180. * np.pi) @ c2w 63 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 64 | return c2w 65 | 66 | 67 | def render_wander_path(view): 68 | focal_length = fov2focal(view.FoVy, view.image_height) 69 | R = view.R 70 | R[:, 1] = -R[:, 1] 71 | R[:, 2] = -R[:, 2] 72 | T = -view.T.reshape(-1, 1) 73 | pose = np.concatenate([R, T], -1) 74 | 75 | num_frames = 60 76 | max_disp = 5000.0 # 64 , 48 77 | 78 | max_trans = max_disp / focal_length # Maximum camera translation to satisfy max_disp parameter 79 | output_poses = [] 80 | 81 | for i in range(num_frames): 82 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) 83 | y_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 # * 3.0 / 4.0 84 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 85 | 86 | i_pose = np.concatenate([ 87 | np.concatenate( 88 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1), 89 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :] 90 | ], axis=0) # [np.newaxis, :, :] 91 | 92 | i_pose = np.linalg.inv(i_pose) # torch.tensor(np.linalg.inv(i_pose)).float() 93 | 94 | ref_pose = np.concatenate([pose, np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0) 95 | 96 | render_pose = np.dot(ref_pose, i_pose) 97 | output_poses.append(torch.Tensor(render_pose)) 98 | 99 | return output_poses 100 | -------------------------------------------------------------------------------- /utils/rigid_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def skew(w: torch.Tensor) -> torch.Tensor: 5 | """Build a skew matrix ("cross product matrix") for vector w. 6 | 7 | Modern Robotics Eqn 3.30. 8 | 9 | Args: 10 | w: (N, 3) A 3-vector 11 | 12 | Returns: 13 | W: (N, 3, 3) A skew matrix such that W @ v == w x v 14 | """ 15 | zeros = torch.zeros(w.shape[0], device=w.device) 16 | w_skew_list = [zeros, -w[:, 2], w[:, 1], 17 | w[:, 2], zeros, -w[:, 0], 18 | -w[:, 1], w[:, 0], zeros] 19 | w_skew = torch.stack(w_skew_list, dim=-1).reshape(-1, 3, 3) 20 | return w_skew 21 | 22 | 23 | def rp_to_se3(R: torch.Tensor, p: torch.Tensor) -> torch.Tensor: 24 | """Rotation and translation to homogeneous transform. 25 | 26 | Args: 27 | R: (3, 3) An orthonormal rotation matrix. 28 | p: (3,) A 3-vector representing an offset. 
29 | 30 | Returns: 31 | X: (4, 4) The homogeneous transformation matrix described by rotating by R 32 | and translating by p. 33 | """ 34 | bottom_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=R.device).repeat(R.shape[0], 1, 1) 35 | transform = torch.cat([torch.cat([R, p], dim=-1), bottom_row], dim=1) 36 | 37 | return transform 38 | 39 | 40 | def exp_so3(w: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: 41 | """Exponential map from Lie algebra so3 to Lie group SO3. 42 | 43 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula. 44 | 45 | Args: 46 | w: (3,) An axis of rotation. 47 | theta: An angle of rotation. 48 | 49 | Returns: 50 | R: (3, 3) An orthonormal rotation matrix representing a rotation of 51 | magnitude theta about axis w. 52 | """ 53 | W = skew(w) 54 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device) 55 | W_sqr = torch.bmm(W, W) # batch matrix multiplication 56 | R = identity + torch.sin(theta.unsqueeze(-1)) * W + (1.0 - torch.cos(theta.unsqueeze(-1))) * W_sqr 57 | return R 58 | 59 | 60 | def exp_se3(S: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: 61 | """Exponential map from Lie algebra se(3) to Lie group SE(3). 62 | 63 | Modern Robotics Eqn 3.88. 64 | 65 | Args: 66 | S: (6,) A screw axis of motion. 67 | theta: Magnitude of motion. 68 | 69 | Returns: 70 | a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating 71 | motion of magnitude theta about S for one second. 72 | """ 73 | w, v = torch.split(S, 3, dim=-1) 74 | W = skew(w) 75 | R = exp_so3(w, theta) 76 | 77 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device) 78 | W_sqr = torch.bmm(W, W) 79 | theta = theta.view(-1, 1, 1) 80 | 81 | p = torch.bmm((theta * identity + (1.0 - torch.cos(theta)) * W + (theta - torch.sin(theta)) * W_sqr), 82 | v.unsqueeze(-1)) 83 | return rp_to_se3(R, p) 84 | 85 | 86 | def to_homogenous(v: torch.Tensor) -> torch.Tensor: 87 | """Converts a vector to a homogeneous coordinate vector by appending a 1. 88 | 89 | Args: 90 | v: A tensor representing a vector or batch of vectors. 91 | 92 | Returns: 93 | A tensor with an additional dimension set to 1. 94 | """ 95 | return torch.cat([v, torch.ones_like(v[..., :1])], dim=-1) 96 | 97 | 98 | def from_homogenous(v: torch.Tensor) -> torch.Tensor: 99 | """Converts a homogeneous coordinate vector to a standard vector by dividing by the last element. 100 | 101 | Args: 102 | v: A tensor representing a homogeneous coordinate vector or batch of homogeneous coordinate vectors. 103 | 104 | Returns: 105 | A tensor with the last dimension removed. 106 | """ 107 | return v[..., :3] / v[..., -1:] 108 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution.
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | 115 | def RGB2SH(rgb): 116 | return (rgb - 0.5) / C0 117 | 118 | 119 | def 
SH2RGB(sh): 120 | return sh * C0 + 0.5 121 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory. equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 30 | return max(saved_iters) 31 | -------------------------------------------------------------------------------- /utils/time_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.rigid_utils import exp_se3 5 | 6 | 7 | def get_embedder(multires, i=1): 8 | if i == -1: 9 | return nn.Identity(), 3 10 | 11 | embed_kwargs = { 12 | 'include_input': True, 13 | 'input_dims': i, 14 | 'max_freq_log2': multires - 1, 15 | 'num_freqs': multires, 16 | 'log_sampling': True, 17 | 'periodic_fns': [torch.sin, torch.cos], 18 | } 19 | 20 | embedder_obj = Embedder(**embed_kwargs) 21 | embed = lambda x, eo=embedder_obj: eo.embed(x) 22 | return embed, embedder_obj.out_dim 23 | 24 | 25 | class Embedder: 26 | def __init__(self, **kwargs): 27 | self.kwargs = kwargs 28 | self.create_embedding_fn() 29 | 30 | def create_embedding_fn(self): 31 | embed_fns = [] 32 | d = self.kwargs['input_dims'] 33 | out_dim = 0 34 | if self.kwargs['include_input']: 35 | embed_fns.append(lambda x: x) 36 | out_dim += d 37 | 38 | max_freq = self.kwargs['max_freq_log2'] 39 | N_freqs = self.kwargs['num_freqs'] 40 | 41 | if self.kwargs['log_sampling']: 42 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 43 | else: 44 | freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) 45 | 46 | for freq in freq_bands: 47 | for p_fn in self.kwargs['periodic_fns']: 48 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 49 | out_dim += d 50 | 51 | self.embed_fns = embed_fns 52 | self.out_dim = out_dim 53 | 54 | def embed(self, inputs): 55 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 56 | 57 | 58 | class DeformNetwork(nn.Module): 59 | def __init__(self, D=8, W=256, input_ch=3, output_ch=59, multires=10, is_blender=False, is_6dof=False): 60 | super(DeformNetwork, self).__init__() 61 | self.D = D 62 | self.W = W 63 | self.input_ch = input_ch 64 | self.output_ch = output_ch 65 | self.t_multires = 6 if is_blender else 10 66 | self.skips = [D // 2] 67 | 68 | self.embed_time_fn, time_input_ch = get_embedder(self.t_multires, 1) 69 | self.embed_fn, xyz_input_ch = get_embedder(multires, 3) 70 | self.input_ch = xyz_input_ch + time_input_ch 71 | 72 | if is_blender: 73 | # Better for D-NeRF Dataset 74 | self.time_out = 30 75 | 76 | self.timenet = nn.Sequential( 77 | nn.Linear(time_input_ch, 256), nn.ReLU(inplace=True), 78 | nn.Linear(256, self.time_out)) 79 | 80 | self.linear = nn.ModuleList( 81 | [nn.Linear(xyz_input_ch + self.time_out, W)] + [ 82 | nn.Linear(W, W) if i not in self.skips else nn.Linear(W + xyz_input_ch + self.time_out, W) 83 | for i in range(D - 1)] 84 | ) 85 | 86 | else: 87 | self.linear = nn.ModuleList( 88 | [nn.Linear(self.input_ch, W)] + [ 89 | nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) 90 | for i in range(D - 1)] 91 | ) 92 | 93 | self.is_blender = is_blender 94 | self.is_6dof = is_6dof 95 | 96 | if is_6dof: 97 | self.branch_w = nn.Linear(W, 3) 98 | self.branch_v = nn.Linear(W, 3) 99 | else: 100 | self.gaussian_warp = nn.Linear(W, 3) 101 | self.gaussian_rotation = nn.Linear(W, 4) 102 | self.gaussian_scaling = nn.Linear(W, 3) 103 | 104 | def forward(self, x, t): 105 | t_emb = self.embed_time_fn(t) 106 | if self.is_blender: 107 | t_emb = self.timenet(t_emb) # better for D-NeRF Dataset 108 | x_emb = self.embed_fn(x) 109 | h = torch.cat([x_emb, t_emb], dim=-1) 110 | for i, l in enumerate(self.linear): 111 | h = self.linear[i](h) 112 | h = F.relu(h) 113 | if i in self.skips: 114 | h = torch.cat([x_emb, t_emb, h], -1) 115 | 116 | if self.is_6dof: 117 | w = self.branch_w(h) 118 | v = self.branch_v(h) 119 | theta = torch.norm(w, dim=-1, keepdim=True) 120 | w = w / (theta + 1e-5) 121 | v = v / (theta + 1e-5) 122 | screw_axis = torch.cat([w, v], dim=-1) 123 | d_xyz = exp_se3(screw_axis, theta) 124 | else: 125 | d_xyz = self.gaussian_warp(h) 126 | scaling = self.gaussian_scaling(h) 127 | rotation = self.gaussian_rotation(h) 128 | 129 | return d_xyz, rotation, scaling 130 | --------------------------------------------------------------------------------
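
A minimal usage sketch of the `DeformNetwork` defined in `utils/time_utils.py` above. The point count, the timestamp value, and the `is_blender=True` flag are illustrative assumptions; in the repository the network is constructed and driven from `train.py` rather than called directly like this:

```python
import torch
from utils.time_utils import DeformNetwork

# Toy inputs: N canonical Gaussian centers and one timestamp per point.
N = 1024
xyz = torch.rand(N, 3)       # canonical Gaussian positions
t = torch.full((N, 1), 0.5)  # timestamp, assumed normalized to [0, 1] as in the D-NeRF loaders

deform = DeformNetwork(is_blender=True)        # D-NeRF-style synthetic scenes
d_xyz, d_rotation, d_scaling = deform(xyz, t)  # per-point offsets for position, rotation (quaternion), scaling
print(d_xyz.shape, d_rotation.shape, d_scaling.shape)
# torch.Size([1024, 3]) torch.Size([1024, 4]) torch.Size([1024, 3])
```

During training these per-point offsets deform the canonical Gaussians at each queried time before rasterization.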