├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── assets └── profile_res.png ├── configs └── cache │ ├── cache_F_4.json │ ├── cache_F_8.json │ ├── cache_MLP.json │ └── cache_T_14_F_4.json ├── convert.py ├── convert_frames.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── ntc ├── __init__.py └── flame_steak_ntc_params_F_4.pth ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── cache_profile.ipynb ├── cache_warmup.ipynb ├── copy_cams.py └── extract_fvv.py ├── test └── flame_steak_suite │ ├── cfg_args.json │ ├── flame_steak_init │ └── point_cloud │ │ └── iteration_15000 │ │ └── point_cloud.ply │ └── frame000000 │ ├── distorted │ ├── database.db │ └── sparse │ │ └── 0 │ │ ├── cameras.bin │ │ ├── images.bin │ │ ├── points3D.bin │ │ └── project.ini │ ├── images │ ├── cam00.png │ ├── cam01.png │ ├── cam02.png │ ├── cam03.png │ ├── cam04.png │ ├── cam05.png │ ├── cam06.png │ ├── cam07.png │ ├── cam08.png │ ├── cam09.png │ ├── cam10.png │ ├── cam11.png │ ├── cam12.png │ ├── cam13.png │ ├── cam14.png │ ├── cam15.png │ ├── cam16.png │ ├── cam17.png │ ├── cam18.png │ ├── cam19.png │ └── cam20.png │ └── sparse │ └── 0 │ ├── cameras.bin │ ├── images.bin │ └── points3D.bin ├── train.py ├── train_frames.py └── utils ├── __init__.py ├── camera_utils.py ├── debug_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | *.log 10 | *.dump 11 | *.mp4 12 | *.npy 13 | *.out 14 | __pycache__ 15 | dataset -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/SJoJoK/3DGStreamRasterizer 7 | ignore = dirty 8 | [submodule "SIBR_viewers"] 9 | path = SIBR_viewers 10 | url = https://gitlab.inria.fr/sibr/sibr_core 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jac Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3DGStream
2 | 
3 | Official repository for the paper "3DGStream: On-the-fly Training of 3D Gaussians for Efficient Streaming of Photo-Realistic Free-Viewpoint Videos".
4 | 
5 | > **3DGStream: On-the-fly Training of 3D Gaussians for Efficient Streaming of Photo-Realistic Free-Viewpoint Videos**
6 | > [Jiakai Sun](https://sjojok.github.io), Han Jiao, [Guangyuan Li](https://guangyuankk.github.io/), Zhanjie Zhang, Lei Zhao, Wei Xing
7 | > *CVPR 2024 __Highlight__*
8 | > [Project](https://sjojok.github.io/3dgstream)
9 | | [Paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Sun_3DGStream_On-the-Fly_Training_of_3D_Gaussians_for_Efficient_Streaming_of_CVPR_2024_paper.pdf)
10 | | [Suppl.](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Sun_3DGStream_On-the-Fly_Training_CVPR_2024_supplemental.pdf)
11 | | [Bibtex](#bibtex)
12 | | [Viewer](https://github.com/SJoJoK/3DGStreamViewer)
13 | 
14 | 
15 | 
16 | ## Release Roadmap
17 | 
18 | - [x] Open-source [3DGStream Viewer](https://github.com/SJoJoK/3DGStreamViewer)
19 | 
20 | - [x] Free-Viewpoint Video
21 | 
22 | - [x] Unorganized code with few instructions (around May 2024)
23 | 
24 | - [x] Pre-Release
25 | 
26 | - [ ] Refactored code with added comments (after CVPR 2024)
27 | 
28 | - [ ] 3DGStream v2 (hopefully in 2025)
29 | 
30 | ## Step-by-step Tutorial for 3DGStream (May Ver.)
31 | 
32 | 1. Follow the instructions in [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) to set up the environment and submodules; after that, you need to install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn).
33 | 
34 | You can use the same Python environment configured for gaussian-splatting. However, it is necessary to install tiny-cuda-nn and reinstall submodules/diff-gaussian-rasterization by running `pip install submodules/diff-gaussian-rasterization`. Additionally, we recommend using PyTorch 2.0 or higher for enhanced performance, as we utilize `torch.compile`. If you are using a PyTorch version lower than 2.0, you may need to comment out the lines of code where `torch.compile` is used (see the sketch below).
35 | 
36 | The code is tested on:
37 | 
38 | ```
39 | OS: Ubuntu 22.04
40 | GPU: RTX A6000/3090
41 | Driver: 535.86.05
42 | CUDA: 11.8
43 | Python: 3.8
44 | Pytorch: 2.0.1+cu118
45 | tinycudann: 1.7
46 | ```
47 | 
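As mentioned above, on PyTorch < 2.0 you can also guard `torch.compile` behind a version check instead of commenting it out by hand. A minimal sketch; `maybe_compile` is a hypothetical helper, not part of this repository:

```python
import torch

def maybe_compile(module):
    # torch.compile was introduced in PyTorch 2.0; on older versions,
    # fall back to returning the module unchanged.
    if hasattr(torch, "compile"):
        return torch.compile(module)
    return module
```

Wrapping each `torch.compile(...)` call site in such a helper keeps the same code runnable on both PyTorch 1.x and 2.x.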
48 | 3. Follow the instructions in [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) to create your COLMAP dataset based on the images of timestep 0, which will end up like:
49 | 
50 | ```
51 | <location>
52 | |---images
53 | |   |---<image 0>
54 | |   |---<image 1>
55 | |   |---...
56 | |---distorted
57 | |   |---sparse
58 | |       |---0
59 | |           |---cameras.bin
60 | |           |---images.bin
61 | |           |---points3D.bin
62 | |---sparse
63 |     |---0
64 |         |---cameras.bin
65 |         |---images.bin
66 |         |---points3D.bin
67 | ```
68 | 
69 | You can use *test/flame_steak_suite/frame000000* for experiments on the `flame steak` scene.
70 | 
71 | 4. Follow the instructions in [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) to get a **high-quality** init_3dgs (sh_degree = 1, i.e., train with `--sh_degree 1`) from the above COLMAP dataset, which will end up like:
72 | 
73 | ```
74 | <model_path>
75 | |---point_cloud
76 | |   |---iteration_7000
77 | |   |   |---point_cloud.ply
78 | |   |---iteration_15000
79 | |   |---...
80 | |---...
81 | ```
82 | 
83 | You can use *test/flame_steak_suite/flame_steak_init* for experiments on the `flame steak` scene.
84 | 
85 | Since the training of 3DGStream is orthogonal to that of init_3dgs, you are free to use any method that enhances the quality of init_3dgs, provided that the resulting ply file remains compatible with the original [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting).
86 | 
87 | 5. Prepare the multi-view video dataset:
88 | 
89 | 1. Extract the frames and organize them like this:
90 | 
91 | ```
92 | <video_path>
93 | |---frame000001
94 | |   |---<image 0>
95 | |   |---<image 1>
96 | |   |---...
97 | |---frame000002
98 | |---...
99 | |---frame000299
100 | ```
101 | 
102 | If you intend to use the data we have prepared in *test/flame_steak_suite*, ensure that the images are named following the pattern `cam00.png`, ..., `cam20.png`. This is necessary because COLMAP references images by their file names.
103 | 
104 | For convenience, we assume that you extract the frames of the `flame steak` scene into *dataset/flame_steak*. This means your folder structure should look like this:
105 | 
106 | ```
107 | dataset/flame_steak
108 | |---frame000001
109 | |   |---cam00.png
110 | |   |---cam01.png
111 | |   |---...
112 | |---frame000002
113 | |---...
114 | |---frame000299
115 | ```
116 | 
117 | 2. Copy the camera info by `python scripts/copy_cams.py --source <path to frame000000> --scene <video_path>`:
118 | 
119 | ```
120 | <video_path>
121 | |---frame000001
122 | |   |---sparse
123 | |   |   |---...
124 | |   |---<image 0>
125 | |   |---<image 1>
126 | |   |---...
127 | |---frame000002
128 | |---frame000299
129 | |---distorted
130 | |   |---...
131 | |---...
132 | ```
133 | 
134 | You can run
135 | 
136 | ```bash
137 | python scripts/copy_cams.py --source test/flame_steak_suite/frame000000 --scene dataset/flame_steak
138 | ```
139 | 
140 | to prepare for conducting experiments on the `flame steak` scene.
141 | 
142 | 3. Undistort the images by `python convert_frames.py -s <video_path> --resize`; the dataset will then end up like this:
143 | 
144 | ```
145 | <video_path>
146 | |---frame000001
147 | |   |---sparse
148 | |   |---images
149 | |   |   |---<undistorted image 0>
150 | |   |   |---<undistorted image 1>
151 | |   |   |---...
152 | |   |---<image 0>
153 | |   |---<image 1>
154 | |   |---...
155 | |---frame000002
156 | |---...
157 | |---frame000299
158 | ```
159 | 
160 | You can run
161 | 
162 | ```bash
163 | python convert_frames.py -s dataset/flame_steak --resize
164 | ```
165 | 
166 | to prepare for conducting experiments on the `flame steak` scene.
167 | 
168 | **For multi-view datasets with distortion such as [MeetRoom](https://github.com/AlgoHunt/StreamRF), undistortion is critical to improving the reconstruction quality.
169 | We followed the settings of the original [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and performed undistortion.** A small script for sanity-checking the per-frame layout is sketched below.
170 | 
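Before moving on, it can be worth verifying that every frame folder contains the same camera images, since COLMAP references images by file name. A minimal sketch, assuming the *dataset/flame_steak* layout above (21 cameras, 299 frames); `check_frames` is a hypothetical helper, not part of this repository:

```python
import os

def check_frames(video_path, first=1, last=299, n_cams=21):
    # Each frameNNNNNN folder should contain cam00.png ... cam20.png.
    expected = {f"cam{i:02d}.png" for i in range(n_cams)}
    for idx in range(first, last + 1):
        frame_dir = os.path.join(video_path, f"frame{idx:06d}")
        if not os.path.isdir(frame_dir):
            print("missing folder:", frame_dir)
            continue
        missing = expected - set(os.listdir(frame_dir))
        if missing:
            print(frame_dir, "is missing", sorted(missing))

check_frames("dataset/flame_steak")
```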
171 | 6. Warm-up the NTC
172 | 
173 | Please refer to the *scripts/cache_warmup.ipynb* notebook to perform a warm-up of the NTC.
174 | 
175 | For better performance, it is crucial to define the corners of the axis-aligned bounding box (AABB) that approximately encloses your scene. For instance, in a scene like `flame salmon`, the AABB should encompass the room while excluding any external landscape elements. To set the coordinates of the AABB corners, hard-code them directly into the `get_xyz_bound` function.
176 | 
177 | **If you find that the loss is NaN while the NTC is being warmed up, please refer to [this issue](https://github.com/SJoJoK/3DGStream/issues/16) for a solution.**
178 | 
179 | 7. GO!
180 | 
181 | Everything is set up, just run
182 | 
183 | ```bash
184 | python train_frames.py --read_config --config_path <config_path> -o <output_path> -m <model_path> -v <video_path> --image <image_folder_name> --first_load_iteration <first_load_iteration>
185 | ```
186 | 
187 | Parameter explanations:
188 | * `<config_path>`: We provide a configuration file containing all necessary parameters, available at *test/flame_steak_suite/cfg_args.json*.
189 | * `<model_path>`: Please refer to section 2 of this guidance.
190 | * `<video_path>`: Please refer to section 4.2 of this guidance.
191 | * `<image_folder_name>`: Typically named `images`, `images_2`, or `images_4`. 3DGStream will use the images located at *\<video_path\>/\<frame id\>/\<image_folder_name\>* as input.
192 | * `<first_load_iteration>`: 3DGStream will initialize the 3DGS using the point cloud at *\<model_path\>/point_cloud/iteration_\<first_load_iteration\>/point_cloud.ply*.
193 | * Use `--eval` when you have a test/train split. You may need to review and modify `readColmapSceneInfo` in *scene/dataset_readers.py* accordingly.
194 | * Specify `--resolution` only when necessary, as reading and resizing large images is time-consuming. Consider resizing the images before 3DGStream processes them.
195 | * About NTC:
196 |     - `--ntc_conf_path`: Set this to the path of the NTC configuration file (see *scripts/cache_warmup.ipynb*, *configs/cache/* and [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn)).
197 |     - `--ntc_path`: Set this to the path of the pre-warmed parameters (see *scripts/cache_warmup.ipynb*).
198 | 
199 | You can run
200 | 
201 | ```bash
202 | python train_frames.py --read_config --config_path test/flame_steak_suite/cfg_args.json -o output/Code-Release -m test/flame_steak_suite/flame_steak_init/ -v dataset/flame_steak --image images_2 --first_load_iteration 15000 --quiet
203 | ```
204 | 
205 | to conduct the experiments on the `flame steak` scene.
206 | 
207 | 8. Evaluate Performance
208 | 
209 | * PSNR: Average among all test images
210 | 
211 | * Per-frame Storage: Average among all frames (including the first frame)
212 | 
213 | For a multi-view video with 300 frames, the per-frame storage is $$\frac{(\text{init3dgs})+299\times(\text{NTC}+\text{new3dgs})}{300}$$ (a small calculator for this is sketched after this section).
214 | 
215 | * Per-frame Training Time: Average among all frames (including the first frame)
216 | 
217 | * Rendering Speed
218 | 
219 | There are several ways to evaluate the rendering speed:
220 | 
221 | * **[SIBR-Viewer](https://gitlab.inria.fr/sibr/sibr_core)** (as presented in our paper)
222 | 
223 | Integrate 3DGStream into SIBR-Viewer for an accurate measurement. If integration is too complex, approximate the rendering speed by:
224 | 
225 | 1. Rendering the init_3dgs with the SIBR-Viewer to get the baseline rendering speed.
226 | 
227 | 2. Profiling `query_ntc_eval` using *scripts/cache_profile.ipynb*.
228 | 
229 | 3. Summing the measurements for an estimated total rendering speed, like this:
230 | 
231 | | Step | Overhead (ms) | FPS |
232 | | ---------------- | ------------- | ---- |
233 | | Render w/o NTC | 2.56 | 390 |
234 | | + Query NTC | 0.62 | |
235 | | + Transformation | 0.02 | |
236 | | + SH Rotation | 1.46 | |
237 | | Total | 4.66 | 215 |
238 | 
239 | To isolate the overhead for each process, you can comment out the other parts of the code.
240 | 
241 | * **[3DGStreamViewer](https://github.com/SJoJoK/3DGStreamViewer)**
242 | 
243 | You can use *scripts/extract_fvv.py* to re-arrange the output of 3DGStream and render it with the 3DGStreamViewer.
244 | 
245 | * **Custom Script**
246 | 
247 | Write a script that loads all NTCs and additional_3dgs and renders the test images for every frame. For guidance, you can look at the implementation within [3DGStreamViewer](https://github.com/SJoJoK/3DGStreamViewer).
248 | 
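As a worked companion to the per-frame storage formula above, a minimal sketch (sizes in MB; `per_frame_storage` is a hypothetical helper, not part of this repository, and it assumes the NTC and additional 3DGS have roughly constant size across frames):

```python
def per_frame_storage(init_3dgs_mb, ntc_mb, new_3dgs_mb, n_frames=300):
    # (init3dgs + 299 * (NTC + new3dgs)) / 300, generalized to n_frames:
    # the first frame stores the init_3dgs, and every later frame stores
    # an NTC plus the additional 3DGS spawned for that frame.
    return (init_3dgs_mb + (n_frames - 1) * (ntc_mb + new_3dgs_mb)) / n_frames
```

If the per-frame sizes vary, sum the actual file sizes of each frame's NTC and additional 3DGS instead of multiplying by a constant.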
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.extent = 0 50 | self.sh_degree = 3 51 | self._source_path = "" 52 | self._model_path = "" 53 | self._output_path = "" 54 | self._video_path = "" 55 | self.ply_name = "points3D.ply" 56 | self._images = "images" 57 | self._resolution = -1 58 | self._white_background = False 59 | self.data_device = "cuda" 60 | self.eval = False 61 | super().__init__(parser, "Loading Parameters", sentinel) 62 | 63 | def extract(self, args): 64 | g = super().extract(args) 65 | g.source_path = os.path.abspath(g.source_path) 66 | return g 67 | 68 | class PipelineParams(ParamGroup): 69 | def __init__(self, parser): 70 | self.convert_SHs_python = False 71 | self.compute_cov3D_python = False 72 | self.debug = False 73 | self.bwd_depth = False 74 | self.opt_type='3DGStream' 75 | super().__init__(parser, "Pipeline Parameters") 76 | 77 | class OptimizationParams(ParamGroup): 78 | def __init__(self, parser): 79 | self.iterations = 30_000 80 | self.iterations_s2 = 0 81 | self.first_load_iteration = 15000 82 | self.position_lr_init = 0.00016 83 | self.position_lr_final = 0.0000016 84 | self.position_lr_delay_mult = 0.01 85 | self.position_lr_max_steps = 30_000 86 | self.feature_lr = 0.0025 87 | self.opacity_lr = 0.05 88 | self.scaling_lr = 0.005 89 | self.rotation_lr = 0.001 90 | self.percent_dense = 0.01 91 | self.lambda_dssim = 0.2 92 | self.depth_smooth = 0.0 93 | self.ntc_lr = None 94 | self.lambda_dxyz = 0.0 95 | self.lambda_drot= 0.0 96 | self.densification_interval = 100 97 | self.opacity_reset_interval = 3000 98 | self.densify_from_iter = 500 99 | self.densify_until_iter = 15_000 100 | self.densify_grad_threshold = 0.0002 101 | self.ntc_conf_path = "" 102 | self.ntc_path = "" 103 | self.batch_size = 1 104 | self.spawn_type = "spawn" 105 | self.s2_type = "split" 106 | self.s2_adding = False 107 | self.num_of_split=1 108 | self.num_of_spawn=2 109 | self.std_scale=1 110 | self.min_opacity = 0.005 111 | self.rotate_sh = True 112 | self.only_mlp = False 113 | super().__init__(parser, "Optimization Parameters") 114 | 115 | def get_combined_args(parser : ArgumentParser): 116 | cmdlne_string = sys.argv[1:] 117 | cfgfile_string = "Namespace()" 118 | args_cmdline = parser.parse_args(cmdlne_string) 119 | 120 | try: 121 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 122 | 
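        # Note: "cfg_args" stores the repr of an argparse Namespace written at
        # training time; it is rebuilt via eval() below, and any non-None
        # command-line values then override the file's values in the merge loop.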
print("Looking for config file in", cfgfilepath) 123 | with open(cfgfilepath) as cfg_file: 124 | print("Config file found: {}".format(cfgfilepath)) 125 | cfgfile_string = cfg_file.read() 126 | except TypeError: 127 | print("Config file not found at") 128 | pass 129 | args_cfgfile = eval(cfgfile_string) 130 | 131 | merged_dict = vars(args_cfgfile).copy() 132 | for k,v in vars(args_cmdline).items(): 133 | if v != None: 134 | merged_dict[k] = v 135 | return Namespace(**merged_dict) 136 | -------------------------------------------------------------------------------- /assets/profile_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/assets/profile_res.png -------------------------------------------------------------------------------- /configs/cache/cache_F_4.json: -------------------------------------------------------------------------------- 1 | { 2 | "loss": { 3 | "otype": "RelativeL2Luminance" 4 | }, 5 | "optimizer": { 6 | "otype": "Adam", 7 | "learning_rate": 1e-3, 8 | "beta1": 0.9, 9 | "beta2": 0.99, 10 | "epsilon": 1e-15, 11 | "l2_reg": 1e-6 12 | }, 13 | "encoding": { 14 | "otype": "HashGrid", 15 | "n_dims_to_encode": 3, 16 | "per_level_scale": 2.0, 17 | "log2_hashmap_size": 15, 18 | "base_resolution": 16, 19 | "n_levels": 16, 20 | "n_features_per_level": 4 21 | }, 22 | "network": { 23 | "otype": "FullyFusedMLP", 24 | "activation": "ReLU", 25 | "output_activation": "None", 26 | "n_neurons": 64, 27 | "n_hidden_layers": 2 28 | }, 29 | "others": { 30 | "otype": "EMA", 31 | "decay": 0.99, 32 | "nested": { 33 | "otype": "Adam", 34 | "learning_rate": 1e-2, 35 | "beta1": 0.9, 36 | "beta2": 0.99, 37 | "epsilon": 1e-15, 38 | "l2_reg": 1e-6 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /configs/cache/cache_F_8.json: -------------------------------------------------------------------------------- 1 | { 2 | "loss": { 3 | "otype": "RelativeL2Luminance" 4 | }, 5 | "optimizer": { 6 | "otype": "Adam", 7 | "learning_rate": 1e-3, 8 | "beta1": 0.9, 9 | "beta2": 0.99, 10 | "epsilon": 1e-15, 11 | "l2_reg": 1e-6 12 | }, 13 | "encoding": { 14 | "otype": "HashGrid", 15 | "n_dims_to_encode": 3, 16 | "per_level_scale": 2.0, 17 | "log2_hashmap_size": 15, 18 | "base_resolution": 16, 19 | "n_levels": 16, 20 | "n_features_per_level": 8 21 | }, 22 | "network": { 23 | "otype": "FullyFusedMLP", 24 | "activation": "ReLU", 25 | "output_activation": "None", 26 | "n_neurons": 64, 27 | "n_hidden_layers": 2 28 | }, 29 | "others": { 30 | "otype": "EMA", 31 | "decay": 0.99, 32 | "nested": { 33 | "otype": "Adam", 34 | "learning_rate": 1e-2, 35 | "beta1": 0.9, 36 | "beta2": 0.99, 37 | "epsilon": 1e-15, 38 | "l2_reg": 1e-6 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /configs/cache/cache_MLP.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "otype": "FullyFusedMLP", 4 | "activation": "ReLU", 5 | "output_activation": "None", 6 | "n_neurons": 64, 7 | "n_hidden_layers": 2 8 | } 9 | } -------------------------------------------------------------------------------- /configs/cache/cache_T_14_F_4.json: -------------------------------------------------------------------------------- 1 | { 2 | "loss": { 3 | "otype": "RelativeL2Luminance" 4 | }, 5 | "optimizer": { 6 | "otype": "Adam", 7 | "learning_rate": 1e-3, 8 | "beta1": 0.9, 9 | 
"beta2": 0.99, 10 | "epsilon": 1e-15, 11 | "l2_reg": 1e-6 12 | }, 13 | "encoding": { 14 | "otype": "HashGrid", 15 | "n_dims_to_encode": 3, 16 | "per_level_scale": 2.0, 17 | "log2_hashmap_size": 14, 18 | "base_resolution": 16, 19 | "n_levels": 16, 20 | "n_features_per_level": 4 21 | }, 22 | "network": { 23 | "otype": "FullyFusedMLP", 24 | "activation": "ReLU", 25 | "output_activation": "None", 26 | "n_neurons": 64, 27 | "n_hidden_layers": 2 28 | }, 29 | "others": { 30 | "otype": "EMA", 31 | "decay": 0.99, 32 | "nested": { 33 | "otype": "Adam", 34 | "learning_rate": 1e-2, 35 | "beta1": 0.9, 36 | "beta2": 0.99, 37 | "epsilon": 1e-15, 38 | "l2_reg": 1e-6 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | parser.add_argument("--image_dir", default="inputs", type=str) 27 | args = parser.parse_args() 28 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 29 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 30 | use_gpu = 1 if not args.no_gpu else 0 31 | 32 | inputDir='/'+args.image_dir 33 | 34 | if not args.skip_matching: 35 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 36 | 37 | ## Feature extraction 38 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 39 | "--database_path " + args.source_path + "/distorted/database.db \ 40 | --image_path " + args.source_path + inputDir + " \ 41 | --ImageReader.single_camera 1 \ 42 | --ImageReader.camera_model " + args.camera + " \ 43 | --SiftExtraction.use_gpu " + str(use_gpu) 44 | print(feat_extracton_cmd) 45 | exit_code = os.system(feat_extracton_cmd) 46 | if exit_code != 0: 47 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 48 | exit(exit_code) 49 | 50 | ## Feature matching 51 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 52 | --database_path " + args.source_path + "/distorted/database.db \ 53 | --SiftMatching.use_gpu " + str(use_gpu) 54 | exit_code = os.system(feat_matching_cmd) 55 | if exit_code != 0: 56 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 57 | exit(exit_code) 58 | 59 | ### Bundle adjustment 60 | # The default Mapper tolerance is unnecessarily large, 61 | # decreasing it speeds up bundle adjustment steps. 
62 |     mapper_cmd = (colmap_command + " mapper \
63 |         --database_path " + args.source_path + "/distorted/database.db \
64 |         --image_path "  + args.source_path + inputDir + " \
65 |         --output_path "  + args.source_path + "/distorted/sparse \
66 |         --Mapper.ba_global_function_tolerance=0.000001")
67 |     exit_code = os.system(mapper_cmd)
68 |     if exit_code != 0:
69 |         logging.error(f"Mapper failed with code {exit_code}. Exiting.")
70 |         exit(exit_code)
71 | 
72 | ### Image undistortion
73 | ## We need to undistort our images into ideal pinhole intrinsics.
74 | img_undist_cmd = (colmap_command + " image_undistorter \
75 |     --image_path " + args.source_path + inputDir + " \
76 |     --input_path " + args.source_path + "/distorted/sparse/0 \
77 |     --output_path " + args.source_path + "\
78 |     --output_type COLMAP")
79 | exit_code = os.system(img_undist_cmd)
80 | if exit_code != 0:
81 |     logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
82 |     exit(exit_code)
83 | 
84 | files = os.listdir(args.source_path + "/sparse")
85 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
86 | # Move each file from the source directory to the destination directory
87 | for file in files:
88 |     if file == '0':
89 |         continue
90 |     source_file = os.path.join(args.source_path, "sparse", file)
91 |     destination_file = os.path.join(args.source_path, "sparse", "0", file)
92 |     shutil.move(source_file, destination_file)
93 | 
94 | if(args.resize):
95 |     print("Copying and resizing...")
96 | 
97 |     # Resize images.
98 |     os.makedirs(args.source_path + "/images_2", exist_ok=True)
99 |     os.makedirs(args.source_path + "/images_4", exist_ok=True)
100 |     os.makedirs(args.source_path + "/images_8", exist_ok=True)
101 |     # Get the list of files in the source directory
102 |     files = os.listdir(args.source_path + "/images")
103 |     # Copy each file from the source directory to the destination directory
104 |     for file in files:
105 |         source_file = os.path.join(args.source_path, "images", file)
106 | 
107 |         destination_file = os.path.join(args.source_path, "images_2", file)
108 |         shutil.copy2(source_file, destination_file)
109 |         exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
110 |         if exit_code != 0:
111 |             logging.error(f"50% resize failed with code {exit_code}. Exiting.")
112 |             exit(exit_code)
113 | 
114 |         destination_file = os.path.join(args.source_path, "images_4", file)
115 |         shutil.copy2(source_file, destination_file)
116 |         exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
117 |         if exit_code != 0:
118 |             logging.error(f"25% resize failed with code {exit_code}. Exiting.")
119 |             exit(exit_code)
120 | 
121 |         destination_file = os.path.join(args.source_path, "images_8", file)
122 |         shutil.copy2(source_file, destination_file)
123 |         exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
124 |         if exit_code != 0:
125 |             logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
126 |             exit(exit_code)
127 | 
128 | print("Done.")
129 | 
--------------------------------------------------------------------------------
/convert_frames.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact  george.drettakis@inria.fr
10 | #
11 | 
12 | import os
13 | import logging
14 | from argparse import ArgumentParser
15 | import shutil
16 | 
17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
18 | parser = ArgumentParser("Colmap converter")
19 | parser.add_argument("--no_gpu", action='store_true')
20 | parser.add_argument("--skip_matching", action='store_true')
21 | parser.add_argument("--skip_undistortion", action='store_true')
22 | parser.add_argument("--source_path", "-s", required=True, type=str)
23 | parser.add_argument("--camera", default="OPENCV", type=str)
24 | parser.add_argument("--colmap_executable", default="", type=str)
25 | parser.add_argument("--resize", action="store_true")
26 | parser.add_argument("--magick_executable", default="", type=str)
27 | parser.add_argument("--last_frame_id", default=299, type=int)
28 | 
29 | args = parser.parse_args()
30 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
31 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
32 | use_gpu = 1 if not args.no_gpu else 0
33 | 
34 | for id in range(1,args.last_frame_id+1):
35 |     frame_id = f'{id:0>6}'
36 |     inputDir='/frame'+frame_id
37 | 
38 |     print("Processing "+inputDir)
39 | 
40 |     ### Image undistortion
41 |     ## We need to undistort our images into ideal pinhole intrinsics.
42 |     img_undist_cmd = (colmap_command + " image_undistorter \
43 |         --image_path " + args.source_path + inputDir + " \
44 |         --input_path " + args.source_path + "/distorted/sparse/0 \
45 |         --output_path " + args.source_path + inputDir + " \
46 |         --output_type COLMAP")
47 |     exit_code = os.system(img_undist_cmd)
48 |     if exit_code != 0:
49 |         logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
50 |         exit(exit_code)
51 | 
52 |     files = os.listdir(args.source_path + inputDir + "/sparse")
53 |     os.makedirs(args.source_path + inputDir + "/sparse/0", exist_ok=True)
54 |     # Move each file from the source directory to the destination directory
55 |     for file in files:
56 |         if file == '0':
57 |             continue
58 |         source_file = os.path.join(args.source_path, inputDir[1:], "sparse", file)
59 |         destination_file = os.path.join(args.source_path, inputDir[1:], "sparse", "0", file)
60 |         shutil.move(source_file, destination_file)
61 | 
62 |     if(args.resize):
63 |         print("Copying and resizing...")
64 |         # Resize images.
65 |         os.makedirs(os.path.join(args.source_path, inputDir[1:]) + "/images_2", exist_ok=True)
66 |         # os.makedirs(args.source_path + "/images_4", exist_ok=True)
67 |         # os.makedirs(args.source_path + "/images_8", exist_ok=True)
68 | 
69 |         # Get the list of files in the source directory
70 |         files = os.listdir(args.source_path + inputDir + "/images")
71 |         # Copy each file from the source directory to the destination directory
72 |         for file in files:
73 |             source_file = os.path.join(args.source_path, inputDir[1:], "images", file)
74 |             destination_file = os.path.join(args.source_path, inputDir[1:], "images_2", file)
75 |             shutil.copy2(source_file, destination_file)
76 |             print("Resizing " + source_file + " to " + destination_file)
77 |             exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
78 |             if exit_code != 0:
79 |                 logging.error(f"50% resize failed with code {exit_code}. 
Exiting.") 80 | exit(exit_code) 81 | 82 | # destination_file = os.path.join(args.source_path, "images_4", file) 83 | # shutil.copy2(source_file, destination_file) 84 | # exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 85 | # if exit_code != 0: 86 | # logging.error(f"25% resize failed with code {exit_code}. Exiting.") 87 | # exit(exit_code) 88 | 89 | # destination_file = os.path.join(args.source_path, "images_8", file) 90 | # shutil.copy2(source_file, destination_file) 91 | # exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 92 | # if exit_code != 0: 93 | # logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 94 | # exit(exit_code) 95 | 96 | print("Done.") -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug, 49 | bwd_depth=pipe.bwd_depth 50 | ) 51 | 52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 53 | 54 | means3D = pc.get_xyz 55 | means2D = screenspace_points 56 | opacity = pc.get_opacity 57 | 58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 59 | # scaling / rotation by the rasterizer. 60 | scales = None 61 | rotations = None 62 | cov3D_precomp = None 63 | if pipe.compute_cov3D_python: 64 | cov3D_precomp = pc.get_covariance(scaling_modifier) 65 | else: 66 | scales = pc.get_scaling 67 | rotations = pc.get_rotation 68 | 69 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 70 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
71 | shs = None 72 | colors_precomp = None 73 | if colors_precomp is None: 74 | if pipe.convert_SHs_python: 75 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 76 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 77 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 78 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 79 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 80 | else: 81 | shs = pc.get_features 82 | else: 83 | colors_precomp = override_color 84 | 85 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 86 | rendered_image, radii, depth = rasterizer( 87 | means3D = means3D, 88 | means2D = means2D, 89 | shs = shs, 90 | colors_precomp = colors_precomp, 91 | opacities = opacity, 92 | scales = scales, 93 | rotations = rotations, 94 | cov3D_precomp = cov3D_precomp) 95 | 96 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 97 | # They will be excluded from value updates used in the splitting criteria. 98 | return {"render": rendered_image, 99 | "viewspace_points": screenspace_points, 100 | "visibility_filter" : radii > 0, 101 | "radii": radii, 102 | "depth":depth} 103 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | try: 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | except Exception as inst: 35 | pass 36 | 37 | def try_connect(): 38 | global conn, addr, listener 39 | try: 40 | conn, addr = listener.accept() 41 | print(f"\nConnected by {addr}") 42 | conn.settimeout(None) 43 | except Exception as inst: 44 | pass 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, 'little') 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | def send(message_bytes, verify): 54 | global conn 55 | if message_bytes != None: 56 | conn.sendall(message_bytes) 57 | conn.sendall(len(verify).to_bytes(4, 'little')) 58 | conn.sendall(bytes(verify, 'ascii')) 59 | 60 | def receive(): 61 | message = read() 62 | 63 | width = message["resolution_x"] 64 | height = message["resolution_y"] 65 | 66 | if width != 0 and height != 0: 67 | try: 68 | do_training = bool(message["train"]) 69 | fovy = message["fov_y"] 70 | fovx = message["fov_x"] 71 | znear = message["z_near"] 72 | zfar = message["z_far"] 73 | do_shs_python = bool(message["shs_python"]) 74 | do_rot_scale_python = bool(message["rot_scale_python"]) 75 | keep_alive = bool(message["keep_alive"]) 76 | scaling_modifier = message["scaling_modifier"] 77 | world_view_transform = 
torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 78 | world_view_transform[:,1] = -world_view_transform[:,1] 79 | world_view_transform[:,2] = -world_view_transform[:,2] 80 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 81 | full_proj_transform[:,1] = -full_proj_transform[:,1] 82 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 83 | except Exception as e: 84 | print("") 85 | traceback.print_exc() 86 | raise e 87 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 88 | else: 89 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | 
self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /ntc/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | class NeuralTransformationCache(torch.nn.Module): 3 | def __init__(self, model, xyz_bound_min, xyz_bound_max): 4 | super(NeuralTransformationCache, self).__init__() 5 | self.model = model 6 | self.register_buffer('xyz_bound_min',xyz_bound_min) 7 | self.register_buffer('xyz_bound_max',xyz_bound_max) 8 | 9 | def dump(self, path): 10 | torch.save(self.state_dict(),path) 11 | 12 | def get_contracted_xyz(self, xyz): 13 | with torch.no_grad(): 14 | contracted_xyz=(xyz-self.xyz_bound_min)/(self.xyz_bound_max-self.xyz_bound_min) 15 | return contracted_xyz 16 | 17 | def forward(self, xyz:torch.Tensor): 18 | contracted_xyz=self.get_contracted_xyz(xyz) # Shape: [N, 3] 19 | 20 | mask = (contracted_xyz >= 0) & (contracted_xyz <= 1) 21 | mask = mask.all(dim=1) 22 | 23 | ntc_inputs=torch.cat([contracted_xyz[mask]],dim=-1) 24 | resi=self.model(ntc_inputs) 25 | 26 | masked_d_xyz=resi[:,:3] 27 | masked_d_rot=resi[:,3:7] 28 | # masked_d_opacity=resi[:,7:None] 29 | 30 | d_xyz = torch.full((xyz.shape[0], 3), 0.0, dtype=torch.half, device="cuda") 31 | d_rot = torch.full((xyz.shape[0], 4), 0.0, dtype=torch.half, device="cuda") 32 | d_rot[:, 0] = 1.0 33 | # d_opacity = self._origin_d_opacity.clone() 34 | 35 | d_xyz[mask] = masked_d_xyz 36 | d_rot[mask] = masked_d_rot 37 | 38 | return mask, d_xyz, d_rot 39 | 40 | 41 | -------------------------------------------------------------------------------- /ntc/flame_steak_ntc_params_F_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/ntc/flame_steak_ntc_params_F_4.pth -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | class Scene: 21 | 22 | gaussians : GaussianModel 23 | 24 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 25 | """b 26 | :param path: Path to colmap scene main folder. 27 | """ 28 | self.model_path = args.model_path 29 | try: 30 | self.output_path = args.output_path 31 | except: 32 | self.output_path = args.model_path 33 | self.loaded_iter = None 34 | self.gaussians = gaussians 35 | 36 | if load_iteration: 37 | if load_iteration == -1: 38 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 39 | else: 40 | self.loaded_iter = load_iteration 41 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 42 | 43 | self.train_cameras = {} 44 | self.test_cameras = {} 45 | 46 | if os.path.exists(os.path.join(args.source_path, "sparse")): 47 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, ply_name=args.ply_name) 48 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 49 | print("Found transforms_train.json file, assuming Blender data set!") 50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 51 | else: 52 | assert False, "Could not recognize scene type!" 53 | 54 | if not self.loaded_iter: 55 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.output_path, "input.ply") , 'wb') as dest_file: 56 | dest_file.write(src_file.read()) 57 | json_cams = [] 58 | camlist = [] 59 | if scene_info.test_cameras: 60 | camlist.extend(scene_info.test_cameras) 61 | if scene_info.train_cameras: 62 | camlist.extend(scene_info.train_cameras) 63 | for id, cam in enumerate(camlist): 64 | json_cams.append(camera_to_JSON(id, cam)) 65 | with open(os.path.join(self.output_path, "cameras.json"), 'w') as file: 66 | json.dump(json_cams, file) 67 | 68 | if shuffle: 69 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 70 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 71 | 72 | if args.extent == 0: 73 | self.cameras_extent = scene_info.nerf_normalization["radius"] 74 | else: 75 | self.cameras_extent = args.extent 76 | for resolution_scale in resolution_scales: 77 | print("Loading Training Cameras") 78 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 79 | print("Loading Test Cameras") 80 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 81 | 82 | if self.loaded_iter: 83 | self.gaussians.load_ply(os.path.join(self.model_path, 84 | "point_cloud", 85 | "iteration_" + str(self.loaded_iter), 86 | "point_cloud.ply"), self.cameras_extent) 87 | else: 88 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 89 | 90 | def save(self, iteration, save_type='all'): 91 | point_cloud_path = os.path.join(self.output_path, "point_cloud/iteration_{}".format(iteration)) 92 | if save_type=='added' or save_type=='origin': 93 | 
self.gaussians.save_ply(os.path.join(point_cloud_path, save_type, "point_cloud.ply"), save_type) 94 | elif save_type=='all': 95 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"), save_type) 96 | else: 97 | raise NotImplementedError('Not Implemented!') 98 | 99 | def dump_NTC(self): 100 | NTC_path = os.path.join(self.output_path, "NTC.pth") 101 | self.gaussians.ntc.dump(NTC_path) 102 | 103 | def getTrainCameras(self, scale=1.0): 104 | return self.train_cameras[scale] 105 | 106 | def getTestCameras(self, scale=1.0): 107 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", gt_depth=None 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | 32 | try: 33 | self.data_device = torch.device(data_device) 34 | except Exception as e: 35 | print(e) 36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 37 | self.data_device = torch.device("cuda") 38 | 39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 40 | self.depth = gt_depth.to(self.data_device) if gt_depth is not None else None 41 | self.image_width = self.original_image.shape[2] 42 | self.image_height = self.original_image.shape[1] 43 | 44 | if gt_alpha_mask is not None: 45 | self.original_image *= gt_alpha_mask.to(self.data_device) 46 | else: 47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 48 | 49 | self.zfar = 100.0 50 | self.znear = 0.01 51 | 52 | self.trans = trans 53 | self.scale = scale 54 | 55 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 56 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 57 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 58 | self.camera_center = self.world_view_transform.inverse()[3, :3] 59 | 60 | class MiniCam: 61 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 62 | self.image_width = width 63 | self.image_height = height 64 | self.FoVy = fovy 65 | self.FoVx = fovx 66 | self.znear = znear 67 | self.zfar = zfar 68 | self.world_view_transform = world_view_transform 69 | self.full_proj_transform = full_proj_transform 70 | view_inv = torch.inverse(self.world_view_transform) 71 | self.camera_center = view_inv[3][:3] 72 | 73 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | with open(path, "r") as fid: 93 | while True: 94 | line = fid.readline() 95 | if not line: 96 | break 97 | line = line.strip() 98 | if len(line) > 0 and line[0] != "#": 99 | elems = line.split() 100 | xyz = np.array(tuple(map(float, elems[1:4]))) 101 | rgb = np.array(tuple(map(int, elems[4:7]))) 102 | error = np.array(float(elems[7])) 103 | if xyzs is None: 104 | xyzs = xyz[None, ...] 105 | rgbs = rgb[None, ...] 106 | errors = error[None, ...] 107 | else: 108 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 109 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 110 | errors = np.append(errors, error[None, ...], axis=0) 111 | return xyzs, rgbs, errors 112 | 113 | def read_points3D_binary(path_to_model_file): 114 | """ 115 | see: src/base/reconstruction.cc 116 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 117 | void Reconstruction::WritePoints3DBinary(const std::string& path) 118 | """ 119 | 120 | 121 | with open(path_to_model_file, "rb") as fid: 122 | num_points = read_next_bytes(fid, 8, "Q")[0] 123 | 124 | xyzs = np.empty((num_points, 3)) 125 | rgbs = np.empty((num_points, 3)) 126 | errors = np.empty((num_points, 1)) 127 | 128 | for p_id in range(num_points): 129 | binary_point_line_properties = read_next_bytes( 130 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 131 | xyz = np.array(binary_point_line_properties[1:4]) 132 | rgb = np.array(binary_point_line_properties[4:7]) 133 | error = np.array(binary_point_line_properties[7]) 134 | track_length = read_next_bytes( 135 | fid, num_bytes=8, format_char_sequence="Q")[0] 136 | track_elems = read_next_bytes( 137 | fid, num_bytes=8*track_length, 138 | format_char_sequence="ii"*track_length) 139 | xyzs[p_id] = xyz 140 | rgbs[p_id] = rgb 141 | errors[p_id] = error 142 | return xyzs, rgbs, errors 143 | 144 | def read_intrinsics_text(path): 145 | """ 146 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 147 | """ 148 | cameras = {} 149 | with open(path, "r") as fid: 150 | while True: 151 | line = fid.readline() 152 | if not line: 153 | break 154 | line = line.strip() 155 | if len(line) > 0 and line[0] != "#": 156 | elems = line.split() 157 | camera_id = int(elems[0]) 158 | model = elems[1] 159 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 160 | width = int(elems[2]) 161 | height = int(elems[3]) 162 | params = np.array(tuple(map(float, elems[4:]))) 163 | cameras[camera_id] = Camera(id=camera_id, model=model, 164 | width=width, height=height, 165 | params=params) 166 | return cameras 167 | 168 | def read_extrinsics_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for _ in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 180 | image_id = 
binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_intrinsics_binary(path_to_model_file): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::WriteCamerasBinary(const std::string& path) 207 | void Reconstruction::ReadCamerasBinary(const std::string& path) 208 | """ 209 | cameras = {} 210 | with open(path_to_model_file, "rb") as fid: 211 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 212 | for _ in range(num_cameras): 213 | camera_properties = read_next_bytes( 214 | fid, num_bytes=24, format_char_sequence="iiQQ") 215 | camera_id = camera_properties[0] 216 | model_id = camera_properties[1] 217 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 218 | width = camera_properties[2] 219 | height = camera_properties[3] 220 | num_params = CAMERA_MODEL_IDS[model_id].num_params 221 | params = read_next_bytes(fid, num_bytes=8*num_params, 222 | format_char_sequence="d"*num_params) 223 | cameras[camera_id] = Camera(id=camera_id, 224 | model=model_name, 225 | width=width, 226 | height=height, 227 | params=np.array(params)) 228 | assert len(cameras) == num_cameras 229 | return cameras 230 | 231 | 232 | def read_extrinsics_text(path): 233 | """ 234 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 235 | """ 236 | images = {} 237 | with open(path, "r") as fid: 238 | while True: 239 | line = fid.readline() 240 | if not line: 241 | break 242 | line = line.strip() 243 | if len(line) > 0 and line[0] != "#": 244 | elems = line.split() 245 | image_id = int(elems[0]) 246 | qvec = np.array(tuple(map(float, elems[1:5]))) 247 | tvec = np.array(tuple(map(float, elems[5:8]))) 248 | camera_id = int(elems[8]) 249 | image_name = elems[9] 250 | elems = fid.readline().split() 251 | xys = np.column_stack([tuple(map(float, elems[0::3])), 252 | tuple(map(float, elems[1::3]))]) 253 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 254 | images[image_id] = Image( 255 | id=image_id, qvec=qvec, tvec=tvec, 256 | camera_id=camera_id, name=image_name, 257 | xys=xys, point3D_ids=point3D_ids) 258 | return images 259 | 260 | 261 | def read_colmap_bin_array(path): 262 | """ 263 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 264 | 265 | :param path: path to the colmap binary file. 
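    The file begins with an ASCII "width&height&channels&" header followed by
    raw float32 data stored in column-major order, as parsed below.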
266 | :return: nd array with the floating point values in the value 267 | """ 268 | with open(path, "rb") as fid: 269 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 270 | usecols=(0, 1, 2), dtype=int) 271 | fid.seek(0) 272 | num_delimiter = 0 273 | byte = fid.read(1) 274 | while True: 275 | if byte == b"&": 276 | num_delimiter += 1 277 | if num_delimiter >= 3: 278 | break 279 | byte = fid.read(1) 280 | array = np.fromfile(fid, np.float32) 281 | array = array.reshape((width, height, channels), order="F") 282 | return np.transpose(array, (1, 0, 2)).squeeze() 283 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | 26 | class CameraInfo(NamedTuple): 27 | uid: int 28 | R: np.array 29 | T: np.array 30 | FovY: np.array 31 | FovX: np.array 32 | image: np.array 33 | image_path: str 34 | image_name: str 35 | width: int 36 | height: int 37 | 38 | class SceneInfo(NamedTuple): 39 | point_cloud: BasicPointCloud 40 | train_cameras: list 41 | test_cameras: list 42 | nerf_normalization: dict 43 | ply_path: str 44 | 45 | def getNerfppNorm(cam_info): 46 | def get_center_and_diag(cam_centers): 47 | cam_centers = np.hstack(cam_centers) 48 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 49 | center = avg_cam_center 50 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 51 | diagonal = np.max(dist) 52 | return center.flatten(), diagonal 53 | 54 | cam_centers = [] 55 | 56 | for cam in cam_info: 57 | W2C = getWorld2View2(cam.R, cam.T) 58 | C2W = np.linalg.inv(W2C) 59 | cam_centers.append(C2W[:3, 3:4]) 60 | 61 | center, diagonal = get_center_and_diag(cam_centers) 62 | radius = diagonal * 1.1 63 | 64 | translate = -center 65 | 66 | return {"translate": translate, "radius": radius} 67 | 68 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 69 | cam_infos = [] 70 | for idx, key in enumerate(cam_extrinsics): 71 | sys.stdout.write('\r') 72 | # the exact output you're looking for: 73 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 74 | sys.stdout.flush() 75 | 76 | extr = cam_extrinsics[key] 77 | intr = cam_intrinsics[extr.camera_id] 78 | height = intr.height 79 | width = intr.width 80 | 81 | uid = intr.id 82 | R = np.transpose(qvec2rotmat(extr.qvec)) 83 | T = np.array(extr.tvec) 84 | 85 | if intr.model=="SIMPLE_PINHOLE": 86 | focal_length_x = intr.params[0] 87 | FovY = focal2fov(focal_length_x, height) 88 | FovX = focal2fov(focal_length_x, width) 89 | elif intr.model=="PINHOLE": 90 | 
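# PINHOLE stores separate x/y focal lengths, so the vertical FoV below comes from
# focal_length_y; SIMPLE_PINHOLE above reuses its single focal length for both axes.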
focal_length_x = intr.params[0]
91 | focal_length_y = intr.params[1]
92 | FovY = focal2fov(focal_length_y, height)
93 | FovX = focal2fov(focal_length_x, width)
94 | else:
95 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
96 |
97 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
98 | image_name = os.path.basename(image_path).split(".")[0]
99 | image = Image.open(image_path)
100 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
101 | image_path=image_path, image_name=image_name, width=width, height=height)
102 | cam_infos.append(cam_info)
103 | sys.stdout.write('\n')
104 | return cam_infos
105 |
106 | def fetchPly(path):
107 | plydata = PlyData.read(path)
108 | vertices = plydata['vertex']
109 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
110 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
111 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
112 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
113 |
114 | def storePly(path, xyz, rgb):
115 | # Define the dtype for the structured array
116 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
117 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
118 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
119 |
120 | normals = np.zeros_like(xyz)
121 |
122 | elements = np.empty(xyz.shape[0], dtype=dtype)
123 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
124 | elements[:] = list(map(tuple, attributes))
125 |
126 | # Create the PlyData object and write to file
127 | vertex_element = PlyElement.describe(elements, 'vertex')
128 | ply_data = PlyData([vertex_element])
129 | ply_data.write(path)
130 |
131 | def readColmapSceneInfo(path, images, eval, llffhold=8, testidx=[0], ply_name="points3D.ply"):
132 | try:
133 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
134 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
135 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
136 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
137 | except:
138 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
139 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
140 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
141 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
142 |
143 | reading_dir = "images" if images is None else images
144 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
145 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
146 |
147 | if eval:
148 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx not in testidx]
149 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx in testidx]
150 | else:
151 | train_cam_infos = cam_infos
152 | test_cam_infos = []
153 |
154 | nerf_normalization = getNerfppNorm(train_cam_infos)
155 |
156 | ply_path = os.path.join(path, "sparse/0/"+ply_name)
157 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
158 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
159 | if not os.path.exists(ply_path):
160 | print("Converting points3D.bin to .ply; this will happen only the first time you open the scene.")
161 | try:
162 | xyz, rgb, _ = read_points3D_binary(bin_path)
163 | except:
164 | xyz, rgb, _
= read_points3D_text(txt_path) 165 | storePly(ply_path, xyz, rgb) 166 | try: 167 | pcd = fetchPly(ply_path) 168 | except: 169 | pcd = None 170 | 171 | scene_info = SceneInfo(point_cloud=pcd, 172 | train_cameras=train_cam_infos, 173 | test_cameras=test_cam_infos, 174 | nerf_normalization=nerf_normalization, 175 | ply_path=ply_path) 176 | return scene_info 177 | 178 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 179 | cam_infos = [] 180 | 181 | with open(os.path.join(path, transformsfile)) as json_file: 182 | contents = json.load(json_file) 183 | fovx = contents["camera_angle_x"] 184 | 185 | frames = contents["frames"] 186 | for idx, frame in enumerate(frames): 187 | cam_name = os.path.join(path, frame["file_path"] + extension) 188 | 189 | matrix = np.linalg.inv(np.array(frame["transform_matrix"])) 190 | R = -np.transpose(matrix[:3,:3]) 191 | R[:,0] = -R[:,0] 192 | T = -matrix[:3, 3] 193 | 194 | image_path = os.path.join(path, cam_name) 195 | image_name = Path(cam_name).stem 196 | image = Image.open(image_path) 197 | 198 | im_data = np.array(image.convert("RGBA")) 199 | 200 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 201 | 202 | norm_data = im_data / 255.0 203 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 204 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 205 | 206 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 207 | FovY = fovy 208 | FovX = fovx 209 | 210 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 211 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 212 | 213 | return cam_infos 214 | 215 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 216 | print("Reading Training Transforms") 217 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 218 | print("Reading Test Transforms") 219 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 220 | 221 | if not eval: 222 | train_cam_infos.extend(test_cam_infos) 223 | test_cam_infos = [] 224 | 225 | nerf_normalization = getNerfppNorm(train_cam_infos) 226 | 227 | ply_path = os.path.join(path, "points3d.ply") 228 | if not os.path.exists(ply_path): 229 | # Since this data set has no colmap data, we start with random points 230 | num_pts = 100_000 231 | print(f"Generating random point cloud ({num_pts})...") 232 | 233 | # We create random points inside the bounds of the synthetic Blender scenes 234 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 235 | shs = np.random.random((num_pts, 3)) / 255.0 236 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 237 | 238 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 239 | try: 240 | pcd = fetchPly(ply_path) 241 | except: 242 | pcd = None 243 | 244 | scene_info = SceneInfo(point_cloud=pcd, 245 | train_cameras=train_cam_infos, 246 | test_cameras=test_cam_infos, 247 | nerf_normalization=nerf_normalization, 248 | ply_path=ply_path) 249 | return scene_info 250 | 251 | sceneLoadTypeCallbacks = { 252 | "Colmap": readColmapSceneInfo, 253 | "Blender" : readNerfSyntheticInfo 254 | } -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, 
https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, strip_symmetric, build_scaling_rotation, build_rotation, quaternion_multiply 15 | from utils.debug_utils import save_cal_graph, save_tensor_img 16 | from torch import nn 17 | import os 18 | from utils.system_utils import mkdir_p 19 | from plyfile import PlyData, PlyElement 20 | from utils.sh_utils import RGB2SH, rotate_sh_by_matrix, rotate_sh_by_quaternion 21 | from simple_knn._C import distCUDA2 22 | from utils.graphics_utils import BasicPointCloud 23 | import tinycudann as tcnn 24 | from ntc import NeuralTransformationCache 25 | import commentjson as ctjs 26 | 27 | class GaussianModel: 28 | 29 | def setup_functions(self): 30 | 31 | # @torch.compile 32 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 33 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 34 | actual_covariance = L @ L.transpose(1, 2) 35 | symm = strip_symmetric(actual_covariance) 36 | return symm 37 | 38 | self.scaling_activation = torch.exp 39 | self.scaling_inverse_activation = torch.log 40 | 41 | self.covariance_activation = build_covariance_from_scaling_rotation 42 | self.rotation_compose = quaternion_multiply 43 | self.opacity_activation = torch.sigmoid 44 | self.inverse_opacity_activation = inverse_sigmoid 45 | 46 | self.rotation_activation = torch.nn.functional.normalize 47 | 48 | def __init__(self, sh_degree : int, rotate_sh:bool = False): 49 | self.active_sh_degree = 0 50 | self.max_sh_degree = sh_degree 51 | self._xyz = torch.empty(0) 52 | self._features_dc = torch.empty(0) 53 | self._features_rest = torch.empty(0) 54 | self._scaling = torch.empty(0) 55 | self._rotation = torch.empty(0) 56 | self._opacity = torch.empty(0) 57 | 58 | self._xyz_bound_min = None 59 | self._xyz_bound_max = None 60 | 61 | self._d_xyz = None 62 | self._d_rot = None 63 | self._d_rot_matrix = None 64 | self._d_scaling = None 65 | self._d_opacity = None 66 | 67 | self._new_xyz = None 68 | self._new_rot = None 69 | self._new_scaling = None 70 | self._new_opacity = None 71 | self._new_feature = None 72 | self._rotate_sh=rotate_sh 73 | 74 | self._added_xyz = None 75 | self._added_features_dc = None 76 | self._added_features_rest = None 77 | self._added_opacity = None 78 | self._added_scaling = None 79 | self._added_rotation = None 80 | self._added_mask = None 81 | 82 | self.max_radii2D = torch.empty(0) 83 | self.xyz_gradient_accum = torch.empty(0) 84 | self.color_gradient_accum = torch.empty(0) 85 | self.denom = torch.empty(0) 86 | self.optimizer = None 87 | self.percent_dense = 0 88 | self.spatial_lr_scale = 0 89 | self.setup_functions() 90 | 91 | def capture(self): 92 | return ( 93 | self.active_sh_degree, 94 | self._xyz, 95 | self._features_dc, 96 | self._features_rest, 97 | self._scaling, 98 | self._rotation, 99 | self._opacity, 100 | self.max_radii2D, 101 | self.xyz_gradient_accum, 102 | self.denom, 103 | self.optimizer.state_dict(), 104 | self.spatial_lr_scale, 105 | ) 106 | 107 | def restore(self, model_args, training_args): 108 | (self.active_sh_degree, 109 | self._xyz, 110 | self._features_dc, 111 | self._features_rest, 112 | self._scaling, 113 | self._rotation, 114 | self._opacity, 115 | self.max_radii2D, 116 | 
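# the remaining entries unpack into locals: training_setup() below must run first
# so the optimizer exists before its state_dict is restored,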
xyz_gradient_accum, 117 | denom, 118 | opt_dict, 119 | self.spatial_lr_scale) = model_args 120 | self.training_setup(training_args) 121 | self.xyz_gradient_accum = xyz_gradient_accum 122 | self.denom = denom 123 | self.optimizer.load_state_dict(opt_dict) 124 | 125 | @property 126 | def get_scaling(self): 127 | if self._new_scaling is not None: 128 | return self._new_scaling 129 | elif self._added_scaling is not None: 130 | return self.scaling_activation(torch.cat((self._scaling, self._added_scaling), dim=0)) 131 | else: 132 | return self.scaling_activation(self._scaling) 133 | 134 | @property 135 | def get_rotation(self): 136 | if self._new_rot is not None: 137 | return self._new_rot 138 | elif self._added_rotation is not None: 139 | return self.rotation_activation(torch.cat((self._rotation, self._added_rotation), dim=0)) 140 | else: 141 | return self.rotation_activation(self._rotation) 142 | 143 | @property 144 | def get_xyz(self): 145 | if self._new_xyz is not None: 146 | return self._new_xyz 147 | elif self._added_xyz is not None: 148 | return torch.cat((self._xyz, self._added_xyz), dim=0) 149 | else: 150 | return self._xyz 151 | 152 | @property 153 | def get_features(self): 154 | if self._new_feature is not None: 155 | return self._new_feature 156 | elif self._added_features_dc is not None and self._added_features_rest is not None: 157 | features_dc = torch.cat((self._features_dc, self._added_features_dc), dim=0) 158 | features_rest = torch.cat((self._features_rest, self._added_features_rest), dim=0) 159 | return torch.cat((features_dc, features_rest), dim=1) 160 | else: 161 | features_dc = self._features_dc 162 | features_rest = self._features_rest 163 | return torch.cat((features_dc, features_rest), dim=1) 164 | 165 | @property 166 | def get_opacity(self): 167 | if self._new_opacity is not None: 168 | return self._new_opacity 169 | elif self._added_opacity is not None: 170 | return self.opacity_activation(torch.cat((self._opacity, self._added_opacity), dim=0)) 171 | else: 172 | return self.opacity_activation(self._opacity) 173 | 174 | def get_covariance(self, scaling_modifier = 1): 175 | return self.covariance_activation(self.get_scaling, scaling_modifier, self.get_rotation) 176 | 177 | def oneupSHdegree(self): 178 | if self.active_sh_degree < self.max_sh_degree: 179 | self.active_sh_degree += 1 180 | 181 | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): 182 | self.spatial_lr_scale = spatial_lr_scale 183 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 184 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 185 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 186 | features[:, :3, 0 ] = fused_color 187 | features[:, 3:, 1:] = 0.0 188 | 189 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 190 | 191 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 192 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 193 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 194 | rots[:, 0] = 1 195 | 196 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 197 | 198 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 199 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 200 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 
2).contiguous().requires_grad_(True)) 201 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 202 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 203 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 204 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 205 | 206 | def training_setup(self, training_args): 207 | self.percent_dense = training_args.percent_dense 208 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 209 | self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 210 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 211 | 212 | l = [ 213 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 214 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 215 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 216 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 217 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 218 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 219 | ] 220 | 221 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 222 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, 223 | lr_final=training_args.position_lr_final*self.spatial_lr_scale, 224 | lr_delay_mult=training_args.position_lr_delay_mult, 225 | max_steps=training_args.position_lr_max_steps) 226 | 227 | def update_learning_rate(self, iteration): 228 | ''' Learning rate scheduling per step ''' 229 | for param_group in self.optimizer.param_groups: 230 | if param_group["name"] == "xyz": 231 | lr = self.xyz_scheduler_args(iteration) 232 | param_group['lr'] = lr 233 | return lr 234 | 235 | def construct_list_of_attributes(self): 236 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 237 | # All channels except the 3 DC 238 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 239 | l.append('f_dc_{}'.format(i)) 240 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 241 | l.append('f_rest_{}'.format(i)) 242 | l.append('opacity') 243 | for i in range(self._scaling.shape[1]): 244 | l.append('scale_{}'.format(i)) 245 | for i in range(self._rotation.shape[1]): 246 | l.append('rot_{}'.format(i)) 247 | return l 248 | 249 | def save_ply(self, path, save_type='all'): 250 | mkdir_p(os.path.dirname(path)) 251 | if save_type=='added': 252 | xyz = self._added_xyz.detach().cpu().numpy() 253 | normals = np.zeros_like(xyz) 254 | f_dc = self._added_features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 255 | f_rest = self._added_features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 256 | opacities = self._added_opacity.detach().cpu().numpy() 257 | scale = self._added_scaling.detach().cpu().numpy() 258 | rotation = self._added_rotation.detach().cpu().numpy() 259 | elif save_type=='origin': 260 | xyz = self._xyz.detach().cpu().numpy() 261 | normals = np.zeros_like(xyz) 262 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 263 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 264 | opacities = self._opacity.detach().cpu().numpy() 265 | scale = self._scaling.detach().cpu().numpy() 266 | rotation = 
self._rotation.detach().cpu().numpy() 267 | elif save_type=='all': 268 | xyz = self.get_xyz.detach().cpu().numpy() 269 | normals = np.zeros_like(xyz) 270 | f_dc = self.get_features[:,0:1,:].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 271 | f_rest = self.get_features[:,1:,:].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 272 | opacities = self.inverse_opacity_activation(self.get_opacity).detach().cpu().numpy() 273 | scale = self.scaling_inverse_activation(self.get_scaling).detach().cpu().numpy() 274 | rotation = self.get_rotation.detach().cpu().numpy() 275 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 276 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 277 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 278 | elements[:] = list(map(tuple, attributes)) 279 | el = PlyElement.describe(elements, 'vertex') 280 | PlyData([el]).write(path) 281 | 282 | def reset_opacity(self): 283 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 284 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 285 | self._opacity = optimizable_tensors["opacity"] 286 | 287 | def load_ply(self, path, spatial_lr_scale=0): 288 | plydata = PlyData.read(path) 289 | 290 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 291 | np.asarray(plydata.elements[0]["y"]), 292 | np.asarray(plydata.elements[0]["z"])), axis=1) 293 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 294 | 295 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 296 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 297 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 298 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 299 | 300 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 301 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 302 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 303 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 304 | for idx, attr_name in enumerate(extra_f_names): 305 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 306 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 307 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 308 | 309 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 310 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 311 | scales = np.zeros((xyz.shape[0], len(scale_names))) 312 | for idx, attr_name in enumerate(scale_names): 313 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 314 | 315 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 316 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 317 | rots = np.zeros((xyz.shape[0], len(rot_names))) 318 | for idx, attr_name in enumerate(rot_names): 319 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 320 | 321 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 322 | self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 323 | self._features_rest = nn.Parameter(torch.tensor(features_extra, 
dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 324 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 325 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 326 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 327 | self.spatial_lr_scale = spatial_lr_scale 328 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 329 | self.active_sh_degree = self.max_sh_degree 330 | 331 | def replace_tensor_to_optimizer(self, tensor, name): 332 | optimizable_tensors = {} 333 | for group in self.optimizer.param_groups: 334 | if group["name"] == name: 335 | stored_state = self.optimizer.state.get(group['params'][0], None) 336 | stored_state["exp_avg"] = torch.zeros_like(tensor) 337 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 338 | 339 | del self.optimizer.state[group['params'][0]] 340 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 341 | self.optimizer.state[group['params'][0]] = stored_state 342 | 343 | optimizable_tensors[group["name"]] = group["params"][0] 344 | return optimizable_tensors 345 | 346 | def _prune_optimizer(self, mask): 347 | optimizable_tensors = {} 348 | for group in self.optimizer.param_groups: 349 | stored_state = self.optimizer.state.get(group['params'][0], None) 350 | if stored_state is not None: 351 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 352 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 353 | 354 | del self.optimizer.state[group['params'][0]] 355 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 356 | self.optimizer.state[group['params'][0]] = stored_state 357 | 358 | optimizable_tensors[group["name"]] = group["params"][0] 359 | else: 360 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 361 | optimizable_tensors[group["name"]] = group["params"][0] 362 | return optimizable_tensors 363 | 364 | def prune_points(self, mask): 365 | valid_points_mask = ~mask 366 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 367 | 368 | self._xyz = optimizable_tensors["xyz"] 369 | self._features_dc = optimizable_tensors["f_dc"] 370 | self._features_rest = optimizable_tensors["f_rest"] 371 | self._opacity = optimizable_tensors["opacity"] 372 | self._scaling = optimizable_tensors["scaling"] 373 | self._rotation = optimizable_tensors["rotation"] 374 | 375 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 376 | self.color_gradient_accum = self.color_gradient_accum[valid_points_mask] 377 | self.denom = self.denom[valid_points_mask] 378 | self.max_radii2D = self.max_radii2D[valid_points_mask] 379 | 380 | def cat_tensors_to_optimizer(self, tensors_dict): 381 | optimizable_tensors = {} 382 | for group in self.optimizer.param_groups: 383 | assert len(group["params"]) == 1 384 | extension_tensor = tensors_dict[group["name"]] 385 | stored_state = self.optimizer.state.get(group['params'][0], None) 386 | if stored_state is not None: 387 | 388 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 389 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 390 | 391 | del self.optimizer.state[group['params'][0]] 392 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), 
dim=0).requires_grad_(True)) 393 | self.optimizer.state[group['params'][0]] = stored_state 394 | 395 | optimizable_tensors[group["name"]] = group["params"][0] 396 | else: 397 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 398 | optimizable_tensors[group["name"]] = group["params"][0] 399 | 400 | return optimizable_tensors 401 | 402 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): 403 | d = {"xyz": new_xyz, 404 | "f_dc": new_features_dc, 405 | "f_rest": new_features_rest, 406 | "opacity": new_opacities, 407 | "scaling" : new_scaling, 408 | "rotation" : new_rotation} 409 | 410 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 411 | self._xyz = optimizable_tensors["xyz"] 412 | self._features_dc = optimizable_tensors["f_dc"] 413 | self._features_rest = optimizable_tensors["f_rest"] 414 | self._opacity = optimizable_tensors["opacity"] 415 | self._scaling = optimizable_tensors["scaling"] 416 | self._rotation = optimizable_tensors["rotation"] 417 | 418 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 419 | self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 420 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 421 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 422 | 423 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 424 | n_init_points = self.get_xyz.shape[0] 425 | # Extract points that satisfy the gradient condition 426 | padded_grad = torch.zeros((n_init_points), device="cuda") 427 | padded_grad[:grads.shape[0]] = grads.squeeze() 428 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 429 | selected_pts_mask = torch.logical_and(selected_pts_mask, 430 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 431 | 432 | stds = self.get_scaling[selected_pts_mask].repeat(N,1) 433 | means =torch.zeros((stds.size(0), 3),device="cuda") 434 | samples = torch.normal(mean=means, std=stds) 435 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 436 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 437 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 438 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 439 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 440 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 441 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 442 | 443 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) 444 | 445 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 446 | self.prune_points(prune_filter) 447 | 448 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 449 | # Extract points that satisfy the gradient condition 450 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 451 | selected_pts_mask = torch.logical_and(selected_pts_mask, 452 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 453 | 454 | new_xyz = self._xyz[selected_pts_mask] 455 | new_features_dc = self._features_dc[selected_pts_mask] 456 | new_features_rest = 
self._features_rest[selected_pts_mask] 457 | new_opacities = self._opacity[selected_pts_mask] 458 | new_scaling = self._scaling[selected_pts_mask] 459 | new_rotation = self._rotation[selected_pts_mask] 460 | 461 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) 462 | 463 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 464 | prune_mask=(self.denom==0).squeeze() 465 | self.prune_points(prune_mask) 466 | grads = self.xyz_gradient_accum / self.denom 467 | grads[grads.isnan()] = 0.0 468 | 469 | self.densify_and_clone(grads, max_grad, extent) 470 | self.densify_and_split(grads, max_grad, extent) 471 | 472 | prune_mask = (self.get_opacity < min_opacity).squeeze() 473 | if max_screen_size: 474 | big_points_vs = self.max_radii2D > max_screen_size 475 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 476 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 477 | self.prune_points(prune_mask) 478 | 479 | torch.cuda.empty_cache() 480 | 481 | def adding_postfix(self, added_xyz, added_features_dc, added_features_rest, added_opacities, added_scaling, added_rotation): 482 | d = {"added_xyz": added_xyz, 483 | "added_f_dc": added_features_dc, 484 | "added_f_rest": added_features_rest, 485 | "added_opacity": added_opacities, 486 | "added_scaling" : added_scaling, 487 | "added_rotation" : added_rotation} 488 | 489 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 490 | self._added_xyz = optimizable_tensors["added_xyz"] 491 | self._added_features_dc = optimizable_tensors["added_f_dc"] 492 | self._added_features_rest = optimizable_tensors["added_f_rest"] 493 | self._added_opacity = optimizable_tensors["added_opacity"] 494 | self._added_scaling = optimizable_tensors["added_scaling"] 495 | self._added_rotation = optimizable_tensors["added_rotation"] 496 | 497 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 498 | self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 499 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 500 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 501 | 502 | added_mask=torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=torch.bool) 503 | added_mask[-self._added_xyz.shape[0]:]=True 504 | self._added_mask=added_mask 505 | 506 | def adding_and_clone(self, grads, grad_threshold, scene_extent): 507 | # Extract points that satisfy the gradient condition 508 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 509 | selected_pts_mask = torch.logical_and(selected_pts_mask, 510 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 511 | 512 | new_xyz = self._xyz[selected_pts_mask] 513 | new_features_dc = self._features_dc[selected_pts_mask] 514 | new_features_rest = self._features_rest[selected_pts_mask] 515 | new_opacities = self._opacity[selected_pts_mask] 516 | new_scaling = self._scaling[selected_pts_mask] 517 | new_rotation = self._rotation[selected_pts_mask] 518 | 519 | self.adding_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) 520 | 521 | def adding_and_split(self, grads, grad_threshold, std_scale, num_of_split=1): 522 | # Extract points that satisfy the gradient condition 523 | contracted_xyz=self.get_contracted_xyz() 524 | mask = (contracted_xyz >= 0) & (contracted_xyz <= 1) 525 | mask = mask.all(dim=1) 526 | 
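# keep only Gaussians whose contracted coordinates fall inside the unit cube,
# i.e. inside the NTC's bounding box, before sampling split candidates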
num_of_split=num_of_split 527 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 528 | selected_pts_mask = torch.logical_and(selected_pts_mask, mask) 529 | stds = std_scale*self.get_scaling[selected_pts_mask].repeat(num_of_split,1) 530 | means =torch.zeros((stds.size(0), 3),device="cuda") 531 | samples = torch.normal(mean=means, std=stds) 532 | rots = build_rotation(self.get_rotation[selected_pts_mask]).repeat(num_of_split,1,1) 533 | 534 | added_xyz = (torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(num_of_split, 1)).detach().requires_grad_(True) 535 | added_scaling = (self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(num_of_split,1) / (0.8*num_of_split))).detach().requires_grad_(True) 536 | added_rotation = (self.get_rotation[selected_pts_mask].repeat(num_of_split,1)).detach().requires_grad_(True) 537 | added_features_dc = (self.get_features[:,0:1,:][selected_pts_mask].repeat(num_of_split,1,1)).detach().requires_grad_(True) 538 | added_features_rest = (self.get_features[:,1:,:][selected_pts_mask].repeat(num_of_split,1,1)).detach().requires_grad_(True) 539 | added_opacity = (self.inverse_opacity_activation(self.get_opacity[selected_pts_mask]).repeat(num_of_split,1)).detach().requires_grad_(True) 540 | 541 | self.adding_postfix(added_xyz, added_features_dc, added_features_rest, added_opacity, added_scaling, added_rotation) 542 | 543 | def adding_and_prune(self, training_args, extent): 544 | grads = self.xyz_gradient_accum / self.denom 545 | grads[grads.isnan()] = 0.0 546 | if training_args.s2_adding: 547 | self.adding_and_split(grads, training_args.densify_grad_threshold, training_args.std_scale, training_args.num_of_split) 548 | self.prune_added_points(training_args.min_opacity, extent) 549 | 550 | torch.cuda.empty_cache() 551 | 552 | def prune_added_points(self, min_opacity, extent): 553 | prune_mask = (self.get_opacity < min_opacity).squeeze() 554 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 555 | prune_mask = torch.logical_or(prune_mask, big_points_ws)[-self._added_xyz.shape[0]:] 556 | valid_points_mask = ~prune_mask 557 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 558 | 559 | self._added_xyz = optimizable_tensors["added_xyz"] 560 | self._added_features_dc = optimizable_tensors["added_f_dc"] 561 | self._added_features_rest = optimizable_tensors["added_f_rest"] 562 | self._added_opacity = optimizable_tensors["added_opacity"] 563 | self._added_scaling = optimizable_tensors["added_scaling"] 564 | self._added_rotation = optimizable_tensors["added_rotation"] 565 | 566 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 567 | self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 568 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 569 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 570 | 571 | added_mask=torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=torch.bool) 572 | added_mask[-self._added_xyz.shape[0]:]=True 573 | self._added_mask=added_mask 574 | torch.cuda.empty_cache() 575 | 576 | def training_one_frame_s2_setup(self, training_args): 577 | grads = self.xyz_gradient_accum / self.denom 578 | grads[grads.isnan()] = 0.0 579 | 580 | contracted_xyz=self.get_contracted_xyz() 581 | mask = (contracted_xyz >= 0) & (contracted_xyz <= 1) 582 | mask = mask.all(dim=1) 583 | 584 | if training_args.spawn_type=='clone': 585 | # Clone 586 
| selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= training_args.densify_grad_threshold, True, False) 587 | selected_pts_mask = torch.logical_and(selected_pts_mask, mask) 588 | self._added_xyz = self.get_xyz[selected_pts_mask].detach().clone().requires_grad_(True) 589 | self._added_features_dc = self.get_features[:,0:1,:][selected_pts_mask].detach().clone().requires_grad_(True) 590 | self._added_features_rest = self.get_features[:,1:,:][selected_pts_mask].detach().clone().requires_grad_(True) 591 | self._added_opacity = self._opacity[selected_pts_mask].detach().clone().requires_grad_(True) 592 | self._added_scaling = self._scaling[selected_pts_mask].detach().clone().requires_grad_(True) 593 | self._added_rotation = self.get_rotation[selected_pts_mask].detach().clone().requires_grad_(True) 594 | 595 | elif training_args.spawn_type=='split': 596 | # Split 597 | num_of_split=training_args.num_of_split 598 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= training_args.densify_grad_threshold, True, False) 599 | selected_pts_mask = torch.logical_and(selected_pts_mask, mask) 600 | stds = training_args.std_scale*self.get_scaling[selected_pts_mask].repeat(num_of_split,1) 601 | means =torch.zeros((stds.size(0), 3),device="cuda") 602 | samples = torch.normal(mean=means, std=stds) 603 | rots = build_rotation(self.get_rotation[selected_pts_mask]).repeat(num_of_split,1,1) 604 | self._added_xyz = (torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(num_of_split, 1)).detach().requires_grad_(True) 605 | self._added_scaling = (self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(num_of_split,1) / (0.8*num_of_split))).detach().requires_grad_(True) 606 | self._added_rotation = (self.get_rotation[selected_pts_mask].repeat(num_of_split,1)).detach().requires_grad_(True) 607 | self._added_features_dc = (self.get_features[:,0:1,:][selected_pts_mask].repeat(num_of_split,1,1)).detach().requires_grad_(True) 608 | self._added_features_rest = (self.get_features[:,1:,:][selected_pts_mask].repeat(num_of_split,1,1)).detach().requires_grad_(True) 609 | self._added_opacity = (self._opacity[selected_pts_mask].repeat(num_of_split,1)).detach().requires_grad_(True) 610 | 611 | elif training_args.spawn_type=='spawn': 612 | # Spawn 613 | num_of_spawn=training_args.num_of_spawn 614 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= training_args.densify_grad_threshold, True, False) 615 | selected_pts_mask = torch.logical_and(selected_pts_mask, mask) 616 | N=selected_pts_mask.sum() 617 | stds = training_args.std_scale*self.get_scaling[selected_pts_mask].repeat(num_of_spawn,1) 618 | means =torch.zeros((stds.size(0), 3),device="cuda") 619 | samples = torch.normal(mean=means, std=stds) 620 | rots = build_rotation(self.get_rotation[selected_pts_mask]).repeat(num_of_spawn,1,1) 621 | self._added_xyz = (torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(num_of_spawn, 1)).detach().requires_grad_(True) 622 | 623 | # self._added_scaling = self.scaling_inverse_activation(torch.tensor([0.1,0.1,0.1],device='cuda').repeat(N*num_of_spawn, 1)).detach().requires_grad_(True) 624 | self._added_rotation = torch.tensor([1.,0.,0.,0.],device='cuda').repeat(N*num_of_spawn, 1).detach().requires_grad_(True) 625 | # self._added_features_dc = ((torch.ones_like(self.get_features[:,0:1,:][selected_pts_mask])/2).repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 626 | # self._added_features_rest = 
((torch.zeros_like(self.get_features[:,1:,:][selected_pts_mask])).repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 627 | self._added_opacity = self.inverse_opacity_activation(torch.tensor([0.1],device='cuda')).repeat(N*num_of_spawn, 1).detach().requires_grad_(True) 628 | 629 | self._added_scaling = (self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(num_of_spawn,1) / (0.8*num_of_spawn))).detach().requires_grad_(True) 630 | # self._added_rotation = (self.get_rotation[selected_pts_mask].repeat(num_of_spawn,1)).detach().requires_grad_(True) 631 | self._added_features_dc = (self.get_features[:,0:1,:][selected_pts_mask].repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 632 | self._added_features_rest = (self.get_features[:,1:,:][selected_pts_mask].repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 633 | # self._added_opacity = (self._opacity[selected_pts_mask].repeat(num_of_spawn,1)).detach().requires_grad_(True) 634 | 635 | elif training_args.spawn_type=='random': 636 | # Spawn 637 | num_of_spawn=training_args.num_of_spawn 638 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= training_args.densify_grad_threshold, True, False) 639 | selected_pts_mask = torch.logical_and(selected_pts_mask, mask) 640 | N=selected_pts_mask.sum() 641 | 642 | self._added_xyz = (torch.rand([N*num_of_spawn,3],device='cuda')*(self._xyz_bound_max-self._xyz_bound_min)+self._xyz_bound_min).detach().requires_grad_(True) 643 | 644 | # self._added_scaling = self.scaling_inverse_activation(torch.tensor([0.1,0.1,0.1],device='cuda').repeat(N*num_of_spawn, 1)).detach().requires_grad_(True) 645 | self._added_rotation = torch.tensor([1.,0.,0.,0.],device='cuda').repeat(N*num_of_spawn, 1).detach().requires_grad_(True) 646 | # self._added_features_dc = ((torch.ones_like(self.get_features[:,0:1,:][selected_pts_mask])/2).repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 647 | # self._added_features_rest = ((torch.zeros_like(self.get_features[:,1:,:][selected_pts_mask])).repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 648 | self._added_opacity = self.inverse_opacity_activation(torch.tensor([0.1],device='cuda')).repeat(N*num_of_spawn, 1).detach().requires_grad_(True) 649 | 650 | self._added_scaling = (self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(num_of_spawn,1) / (0.8*num_of_spawn))).detach().requires_grad_(True) 651 | # self._added_rotation = (self.get_rotation[selected_pts_mask].repeat(num_of_spawn,1)).detach().requires_grad_(True) 652 | self._added_features_dc = (self.get_features[:,0:1,:][selected_pts_mask].repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 653 | self._added_features_rest = (self.get_features[:,1:,:][selected_pts_mask].repeat(num_of_spawn,1,1)).detach().requires_grad_(True) 654 | # self._added_opacity = (self._opacity[selected_pts_mask].repeat(num_of_spawn,1)).detach().requires_grad_(True) 655 | # Optimizer 656 | l = [ 657 | {'params': [self._added_xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "added_xyz"}, 658 | {'params': [self._added_features_dc], 'lr': training_args.feature_lr, "name": "added_f_dc"}, 659 | {'params': [self._added_features_rest], 'lr': training_args.feature_lr / 20.0, "name": "added_f_rest"}, 660 | {'params': [self._added_opacity], 'lr': training_args.opacity_lr, "name": "added_opacity"}, 661 | {'params': [self._added_scaling], 'lr': training_args.scaling_lr, "name": "added_scaling"}, 662 | {'params': [self._added_rotation], 'lr': training_args.rotation_lr, "name": 
"added_rotation"} 663 | ] 664 | 665 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 666 | 667 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 668 | self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 669 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 670 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 671 | 672 | added_mask=torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=torch.bool) 673 | added_mask[-self._added_xyz.shape[0]:]=True 674 | self._added_mask=added_mask 675 | 676 | torch.cuda.empty_cache() 677 | 678 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 679 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) 680 | self.color_gradient_accum[update_filter] += torch.norm(self._features_dc.grad[update_filter].squeeze(), dim=-1, keepdim=True) 681 | self.denom[update_filter] += 1 682 | 683 | def query_ntc(self): 684 | mask, self._d_xyz, self._d_rot = self.ntc(self._xyz) 685 | 686 | self._new_xyz = self._d_xyz + self._xyz 687 | self._new_rot = self.rotation_compose(self._rotation, self._d_rot) 688 | if self._rotate_sh == True: 689 | self._new_feature = torch.cat((self._features_dc, self._features_rest), dim=1) # [N, SHs, RGB] 690 | 691 | # self._d_rot_matrix=build_rotation(self._d_rot) 692 | # self._new_feature[mask][:,1:4,0] = rotate_sh_by_matrix(self._features_rest[mask][...,0],1,self._d_rot_matrix[mask]) 693 | # self._new_feature[mask][:,1:4,1] = rotate_sh_by_matrix(self._features_rest[mask][...,1],1,self._d_rot_matrix[mask]) 694 | # self._new_feature[mask][:,1:4,2] = rotate_sh_by_matrix(self._features_rest[mask][...,2],1,self._d_rot_matrix[mask]) 695 | 696 | # This is a bit faster... 
697 | permuted_feature = self._new_feature.permute(0, 2, 1)[mask] # [N, RGB, SHs] 698 | reshaped_feature = permuted_feature.reshape(-1,4) 699 | repeated_quat = self.rotation_activation(self._d_rot[mask]).repeat(3, 1) 700 | rotated_reshaped_feature = rotate_sh_by_quaternion(sh=reshaped_feature[...,1:],l=1,q=repeated_quat) # [3N, SHs(l=1)] 701 | rotated_permuted_feature = rotated_reshaped_feature.reshape(-1,3,3) # [N, RGB, SHs(l=1)] 702 | self._new_feature[mask][:,1:4]=rotated_permuted_feature.permute(0,2,1) 703 | 704 | 705 | 706 | def update_by_ntc(self): 707 | self._xyz = self.get_xyz.detach() 708 | self._features_dc = self.get_features[:,0:1,:].detach() 709 | self._features_rest = self.get_features[:,1:,:].detach() 710 | self._opacity = self._opacity.detach() 711 | self._scaling = self._scaling.detach() 712 | self._rotation = self.get_rotation.detach() 713 | 714 | self._d_xyz = None 715 | self._d_rot = None 716 | self._d_rot_matrix = None 717 | self._d_scaling = None 718 | self._d_opacity = None 719 | 720 | self._new_xyz = None 721 | self._new_rot = None 722 | self._new_scaling = None 723 | self._new_opacity = None 724 | self._new_feature = None 725 | 726 | def get_contracted_xyz(self): 727 | with torch.no_grad(): 728 | xyz = self.get_xyz 729 | xyz_bound_min, xyz_bound_max = self.get_xyz_bound(86.6) 730 | normalzied_xyz=(xyz-xyz_bound_min)/(xyz_bound_max-xyz_bound_min) 731 | return normalzied_xyz 732 | 733 | def get_xyz_bound(self, percentile=86.6): 734 | with torch.no_grad(): 735 | if self._xyz_bound_min is None: 736 | half_percentile = (100 - percentile) / 200 737 | self._xyz_bound_min = torch.quantile(self._xyz,half_percentile,dim=0) 738 | self._xyz_bound_max = torch.quantile(self._xyz,1 - half_percentile,dim=0) 739 | return self._xyz_bound_min, self._xyz_bound_max 740 | 741 | def training_one_frame_setup(self,training_args): 742 | ntc_conf_path=training_args.ntc_conf_path 743 | with open(ntc_conf_path) as ntc_conf_file: 744 | ntc_conf = ctjs.load(ntc_conf_file) 745 | if training_args.only_mlp: 746 | model=tcnn.Network(n_input_dims=3, n_output_dims=8, network_config=ntc_conf["network"]).to(torch.device("cuda")) 747 | else: 748 | model=tcnn.NetworkWithInputEncoding(n_input_dims=3, n_output_dims=8, encoding_config=ntc_conf["encoding"], network_config=ntc_conf["network"]).to(torch.device("cuda")) 749 | self.ntc=NeuralTransformationCache(model,self.get_xyz_bound()[0],self.get_xyz_bound()[1]) 750 | self.ntc.load_state_dict(torch.load(training_args.ntc_path)) 751 | self._xyz_bound_min = self.ntc.xyz_bound_min 752 | self._xyz_bound_max = self.ntc.xyz_bound_max 753 | if training_args.ntc_lr is not None: 754 | ntc_lr=training_args.ntc_lr 755 | else: 756 | ntc_lr=ntc_conf["optimizer"]["learning_rate"] 757 | self.ntc_optimizer = torch.optim.Adam(self.ntc.parameters(), 758 | lr=ntc_lr) 759 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 760 | self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 761 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 762 | 763 | def get_masked_gaussian(self, mask): 764 | new_gaussian = GaussianModel(self.max_sh_degree) 765 | new_gaussian._xyz = self.get_xyz[mask].detach() 766 | new_gaussian._features_dc = self.get_features[:,0:1,:][mask].detach() 767 | new_gaussian._features_rest = self.get_features[:,1:,:][mask].detach() 768 | new_gaussian._scaling = self.scaling_inverse_activation(self.get_scaling)[mask].detach() 769 | new_gaussian._rotation = self.get_rotation[mask].detach() 770 | 
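# scalings and opacities are stored pre-activation, so invert the activations
# before copying them into the masked model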
741 |     def training_one_frame_setup(self,training_args):
742 |         ntc_conf_path=training_args.ntc_conf_path
743 |         with open(ntc_conf_path) as ntc_conf_file:
744 |             ntc_conf = ctjs.load(ntc_conf_file)
745 |         if training_args.only_mlp:
746 |             model=tcnn.Network(n_input_dims=3, n_output_dims=8, network_config=ntc_conf["network"]).to(torch.device("cuda"))
747 |         else:
748 |             model=tcnn.NetworkWithInputEncoding(n_input_dims=3, n_output_dims=8, encoding_config=ntc_conf["encoding"], network_config=ntc_conf["network"]).to(torch.device("cuda"))
749 |         self.ntc=NeuralTransformationCache(model,self.get_xyz_bound()[0],self.get_xyz_bound()[1])
750 |         self.ntc.load_state_dict(torch.load(training_args.ntc_path))
751 |         self._xyz_bound_min = self.ntc.xyz_bound_min
752 |         self._xyz_bound_max = self.ntc.xyz_bound_max
753 |         if training_args.ntc_lr is not None:
754 |             ntc_lr=training_args.ntc_lr
755 |         else:
756 |             ntc_lr=ntc_conf["optimizer"]["learning_rate"]
757 |         self.ntc_optimizer = torch.optim.Adam(self.ntc.parameters(),
758 |                                                 lr=ntc_lr)
759 |         self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
760 |         self.color_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
761 |         self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
762 | 
763 |     def get_masked_gaussian(self, mask):
764 |         new_gaussian = GaussianModel(self.max_sh_degree)
765 |         new_gaussian._xyz = self.get_xyz[mask].detach()
766 |         new_gaussian._features_dc = self.get_features[:,0:1,:][mask].detach()
767 |         new_gaussian._features_rest = self.get_features[:,1:,:][mask].detach()
768 |         new_gaussian._scaling = self.scaling_inverse_activation(self.get_scaling)[mask].detach()
769 |         new_gaussian._rotation = self.get_rotation[mask].detach()
770 |         new_gaussian._opacity = self.inverse_opacity_activation(self.get_opacity)[mask].detach()
771 |         new_gaussian.xyz_gradient_accum = torch.zeros((new_gaussian._xyz.shape[0], 1), device="cuda")
772 |         new_gaussian.color_gradient_accum = torch.zeros((new_gaussian._xyz.shape[0], 1), device="cuda")
773 |         new_gaussian.denom = torch.zeros((new_gaussian._xyz.shape[0], 1), device="cuda")
774 |         new_gaussian.max_radii2D = torch.zeros((new_gaussian._xyz.shape[0]), device="cuda")
775 |         return new_gaussian
776 | 
777 |     def query_ntc_eval(self):
778 |         with torch.no_grad():
779 |             mask, self._d_xyz, self._d_rot = self.ntc(self.get_xyz)
780 | 
781 |             self._new_xyz = self._d_xyz + self._xyz
782 |             self._new_rot = self.rotation_compose(self._rotation, self._d_rot)
783 |             if self._rotate_sh:
784 |                 self._new_feature = torch.cat((self._features_dc, self._features_rest), dim=1) # [N, SHs, RGB]
785 |                 # This is a bit faster...
786 |                 permuted_feature = self._new_feature.permute(0, 2, 1)[mask] # [N, RGB, SHs]
787 |                 reshaped_feature = permuted_feature.reshape(-1,4)
788 |                 repeated_quat = self.rotation_activation(self._d_rot[mask]).repeat(3, 1)
789 |                 rotated_reshaped_feature = rotate_sh_by_quaternion(sh=reshaped_feature[...,1:],l=1,q=repeated_quat) # [3N, SHs(l=1)]
790 |                 rotated_permuted_feature = rotated_reshaped_feature.reshape(-1,3,3) # [N, RGB, SHs(l=1)]
791 |                 self._new_feature[:,1:4][mask]=rotated_permuted_feature.permute(0,2,1) # as in query_ntc: write through the slice view, not a masked copy
--------------------------------------------------------------------------------
/scripts/cache_profile.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import sys\n",
10 |     "import os\n",
11 |     "project_directory = '..'\n",
12 |     "## Use insert rather than append so that Python searches the project directory first\n",
13 |     "sys.path.insert(0, os.path.abspath(project_directory))\n",
14 |     "\n",
15 |     "import tinycudann as tcnn\n",
16 |     "import commentjson as ctjs\n",
17 |     "import torch\n",
18 |     "import numpy as np\n",
19 |     "from typing import NamedTuple\n",
20 |     "from plyfile import PlyData, PlyElement\n",
21 |     "import torch.nn as nn\n",
22 |     "import torch.nn.functional as F\n",
23 |     "from scene import Scene, GaussianModel\n",
24 |     "scene=\"flame_steak\"\n",
25 |     "postfixs=['F_4']\n",
26 |     "ntc_conf_paths=['../configs/cache/cache_'+postfix+'.json' for postfix in postfixs]\n",
27 |     "pcd_path='../test/flame_steak_suite/flame_steak_init/point_cloud/iteration_15000/point_cloud.ply'\n",
28 |     "save_paths=['../ntc/flame_steak_ntc_params_'+postfix+'.pth' for postfix in postfixs]\n",
29 |     "ntcs=[]\n",
30 |     "gaussians = GaussianModel(1)\n",
31 |     "gaussians.load_ply(pcd_path)"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "code",
36 |    "execution_count": null,
37 |    "metadata": {},
38 |    "outputs": [],
39 |    "source": [
40 |     "class BasicPointCloud(NamedTuple):\n",
41 |     "    points : np.array\n",
42 |     "    colors : np.array\n",
43 |     "    normals : np.array\n",
44 |     "\n",
45 |     "def fetchXYZ(path):\n",
46 |     "    plydata = PlyData.read(path)\n",
47 |     "    xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n",
48 |     "                    np.asarray(plydata.elements[0][\"y\"]),\n",
49 |     "                    np.asarray(plydata.elements[0][\"z\"])), axis=1)\n",
50 |     "    return torch.tensor(xyz, dtype=torch.float, device=\"cuda\")\n",
51 |     "\n",
52 |     "def get_xyz_bound(xyz, percentile=80):\n",
53 |     "    ## Hard-code the coordinate of the corners here!!\n",
54 |     "    return torch.tensor([-20, -15, 
5]).cuda(), torch.tensor([15, 10, 23]).cuda()\n", 55 | "\n", 56 | "def get_contracted_xyz(xyz):\n", 57 | " xyz_bound_min, xyz_bound_max = get_xyz_bound(xyz, 80)\n", 58 | " normalzied_xyz=(xyz-xyz_bound_min)/(xyz_bound_max-xyz_bound_min)\n", 59 | " return normalzied_xyz\n", 60 | "\n", 61 | "@torch.compile\n", 62 | "def quaternion_multiply(a, b):\n", 63 | " a_norm=nn.functional.normalize(a)\n", 64 | " b_norm=nn.functional.normalize(b)\n", 65 | " w1, x1, y1, z1 = a_norm[:, 0], a_norm[:, 1], a_norm[:, 2], a_norm[:, 3]\n", 66 | " w2, x2, y2, z2 = b_norm[:, 0], b_norm[:, 1], b_norm[:, 2], b_norm[:, 3]\n", 67 | "\n", 68 | " w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2\n", 69 | " x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2\n", 70 | " y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2\n", 71 | " z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2\n", 72 | "\n", 73 | " return torch.stack([w, x, y, z], dim=1)\n", 74 | "\n", 75 | "def quaternion_loss(q1, q2):\n", 76 | " cos_theta = F.cosine_similarity(q1, q2, dim=1)\n", 77 | " cos_theta = torch.clamp(cos_theta, -1+1e-7, 1-1e-7)\n", 78 | " return 1-torch.pow(cos_theta, 2).mean()\n", 79 | "\n", 80 | "def l1loss(network_output, gt):\n", 81 | " return torch.abs((network_output - gt)).mean()" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "from ntc import NeuralTransformationCache\n", 91 | "for idx, ntc_conf_path in enumerate(ntc_conf_paths):\n", 92 | " with open(ntc_conf_path) as ntc_conf_file:\n", 93 | " ntc_conf = ctjs.load(ntc_conf_file)\n", 94 | " model=tcnn.NetworkWithInputEncoding(n_input_dims=3, n_output_dims=8, encoding_config=ntc_conf[\"encoding\"], network_config=ntc_conf[\"network\"]).to(torch.device(\"cuda\"))\n", 95 | " ntc=NeuralTransformationCache(model,torch.tensor([0.,0.,0.]).cuda(),torch.tensor([0.,0.,0.]).cuda())\n", 96 | " ntc.load_state_dict(torch.load(save_paths[idx]))\n", 97 | " ntcs.append(ntc)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "import time\n", 107 | "import torch\n", 108 | "gaussians.ntc = ntc\n", 109 | "gaussians._rotate_sh = True\n", 110 | "\n", 111 | "# Since torch.compile is JIT compilation, we need to call the function once to trigger the compilation\n", 112 | "gaussians.query_ntc_eval()\n", 113 | "\n", 114 | "for idx, ntc in enumerate(ntcs):\n", 115 | " gaussians.ntc = ntc\n", 116 | " torch.cuda.synchronize()\n", 117 | " start = time.time()\n", 118 | " for i in range(300):\n", 119 | " gaussians.query_ntc_eval()\n", 120 | " torch.cuda.synchronize()\n", 121 | " end = time.time()\n", 122 | " print(f\"Time: {((end - start) / 300.0):.5f}s for {postfixs[idx]} in {scene} scene.\")\n" 123 | ] 124 | } 125 | ], 126 | "metadata": { 127 | "kernelspec": { 128 | "display_name": "torch2.0py38", 129 | "language": "python", 130 | "name": "python3" 131 | }, 132 | "language_info": { 133 | "codemirror_mode": { 134 | "name": "ipython", 135 | "version": 3 136 | }, 137 | "file_extension": ".py", 138 | "mimetype": "text/x-python", 139 | "name": "python", 140 | "nbconvert_exporter": "python", 141 | "pygments_lexer": "ipython3", 142 | "version": "3.8.17" 143 | } 144 | }, 145 | "nbformat": 4, 146 | "nbformat_minor": 2 147 | } 148 | -------------------------------------------------------------------------------- /scripts/cache_warmup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 
5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "project_directory = '..'\n", 12 | "sys.path.append(os.path.abspath(project_directory))\n", 13 | "\n", 14 | "import tinycudann as tcnn\n", 15 | "import commentjson as ctjs\n", 16 | "import torch\n", 17 | "import numpy as np\n", 18 | "from typing import NamedTuple\n", 19 | "from plyfile import PlyData, PlyElement\n", 20 | "import torch.nn as nn\n", 21 | "import torch.nn.functional as F" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "postfixs=['F_4']\n", 31 | "ntc_conf_paths=['../configs/cache/cache_'+postfix+'.json' for postfix in postfixs]\n", 32 | "pcd_path='../test/flame_steak_suite/flame_steak_init/point_cloud/iteration_15000/point_cloud.ply'\n", 33 | "save_paths=['../ntc/flame_steak_ntc_params_'+postfix+'.pth' for postfix in postfixs]" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "class BasicPointCloud(NamedTuple):\n", 43 | " points : np.array\n", 44 | " colors : np.array\n", 45 | " normals : np.array\n", 46 | "\n", 47 | "def fetchXYZ(path):\n", 48 | " plydata = PlyData.read(path)\n", 49 | " xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n", 50 | " np.asarray(plydata.elements[0][\"y\"]),\n", 51 | " np.asarray(plydata.elements[0][\"z\"])), axis=1)\n", 52 | " return torch.tensor(xyz, dtype=torch.float, device=\"cuda\")\n", 53 | "\n", 54 | "def get_xyz_bound(xyz, percentile=80):\n", 55 | " ## Hard-code the coordinate of the corners here!!\n", 56 | " return torch.tensor([-20, -15, 5]).cuda(), torch.tensor([15, 10, 23]).cuda()\n", 57 | "\n", 58 | "def get_contracted_xyz(xyz):\n", 59 | " xyz_bound_min, xyz_bound_max = get_xyz_bound(xyz, 80)\n", 60 | " normalzied_xyz=(xyz-xyz_bound_min)/(xyz_bound_max-xyz_bound_min)\n", 61 | " return normalzied_xyz\n", 62 | "\n", 63 | "@torch.compile\n", 64 | "def quaternion_multiply(a, b):\n", 65 | " a_norm=nn.functional.normalize(a)\n", 66 | " b_norm=nn.functional.normalize(b)\n", 67 | " w1, x1, y1, z1 = a_norm[:, 0], a_norm[:, 1], a_norm[:, 2], a_norm[:, 3]\n", 68 | " w2, x2, y2, z2 = b_norm[:, 0], b_norm[:, 1], b_norm[:, 2], b_norm[:, 3]\n", 69 | "\n", 70 | " w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2\n", 71 | " x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2\n", 72 | " y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2\n", 73 | " z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2\n", 74 | "\n", 75 | " return torch.stack([w, x, y, z], dim=1)\n", 76 | "\n", 77 | "def quaternion_loss(q1, q2):\n", 78 | " cos_theta = F.cosine_similarity(q1, q2, dim=1)\n", 79 | " cos_theta = torch.clamp(cos_theta, -1+1e-7, 1-1e-7)\n", 80 | " return 1-torch.pow(cos_theta, 2).mean()\n", 81 | "\n", 82 | "def l1loss(network_output, gt):\n", 83 | " return torch.abs((network_output - gt)).mean()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "ntcs=[]\n", 93 | "for ntc_conf_path in ntc_conf_paths: \n", 94 | " with open(ntc_conf_path) as ntc_conf_file:\n", 95 | " ntc_conf = ctjs.load(ntc_conf_file)\n", 96 | " ntc=tcnn.NetworkWithInputEncoding(n_input_dims=3, n_output_dims=8, encoding_config=ntc_conf[\"encoding\"], network_config=ntc_conf[\"network\"]).to(torch.device(\"cuda\"))\n", 97 | " ntc_optimizer = torch.optim.Adam(ntc.parameters(), lr=1e-4)\n", 98 | " xyz=fetchXYZ(pcd_path)\n", 99 | " 
normalzied_xyz=get_contracted_xyz(xyz)\n", 100 | " mask = (normalzied_xyz >= 0) & (normalzied_xyz <= 1)\n", 101 | " mask = mask.all(dim=1)\n", 102 | " ntc_inputs=torch.cat([normalzied_xyz[mask]],dim=-1)\n", 103 | " noisy_inputs = ntc_inputs + 0.01 * torch.rand_like(ntc_inputs)\n", 104 | " d_xyz_gt=torch.tensor([0.,0.,0.]).cuda()\n", 105 | " d_rot_gt=torch.tensor([1.,0.,0.,0.]).cuda()\n", 106 | " dummy_gt=torch.tensor([1.]).cuda()\n", 107 | " def cacheloss(resi):\n", 108 | " masked_d_xyz=resi[:,:3]\n", 109 | " masked_d_rot=resi[:,3:7]\n", 110 | " masked_dummy=resi[:,7:8]\n", 111 | " loss_xyz=l1loss(masked_d_xyz,d_xyz_gt)\n", 112 | " loss_rot=quaternion_loss(masked_d_rot,d_rot_gt)\n", 113 | " loss_dummy=l1loss(masked_dummy,dummy_gt)\n", 114 | " loss=loss_xyz+loss_rot+loss_dummy\n", 115 | " return loss\n", 116 | " for iteration in range(0,3000): \n", 117 | " ntc_inputs_w_noisy = torch.cat([noisy_inputs, ntc_inputs, torch.rand_like(ntc_inputs)],dim=0) \n", 118 | " ntc_output=ntc(ntc_inputs_w_noisy)\n", 119 | " loss=cacheloss(ntc_output)\n", 120 | " if iteration % 100 ==0:\n", 121 | " print(loss)\n", 122 | " loss.backward()\n", 123 | " ntc_optimizer.step()\n", 124 | " ntc_optimizer.zero_grad(set_to_none = True)\n", 125 | " ntcs.append(ntc)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "from ntc import NeuralTransformationCache\n", 135 | "for idx, save_path in enumerate(save_paths):\n", 136 | " ntc=NeuralTransformationCache(ntcs[idx],get_xyz_bound(xyz)[0],get_xyz_bound(xyz)[1])\n", 137 | " torch.save(ntc.state_dict(),save_path)" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "torch2.0py38", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.8.17" 158 | }, 159 | "orig_nbformat": 4 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 2 163 | } 164 | -------------------------------------------------------------------------------- /scripts/copy_cams.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | 5 | def copy_sparse_to_frames(source, scene): 6 | sparse_dir = os.path.join(source, 'sparse') 7 | if not os.path.isdir(sparse_dir): 8 | print(f"Error: The directory '{sparse_dir}' does not exist.") 9 | return 10 | 11 | for item in os.listdir(scene): 12 | frame_dir = os.path.join(scene, item) 13 | if os.path.isdir(frame_dir) and item.startswith('frame'): 14 | dest_sparse_dir = os.path.join(frame_dir, 'sparse') 15 | if os.path.exists(dest_sparse_dir): 16 | shutil.rmtree(dest_sparse_dir) 17 | shutil.copytree(sparse_dir, dest_sparse_dir) 18 | print(f"Copied to {dest_sparse_dir}") 19 | 20 | def copy_distorted_to_scene(source, scene): 21 | distorted_dir = os.path.join(source, 'distorted') 22 | if not os.path.isdir(distorted_dir): 23 | print(f"Error: The directory '{distorted_dir}' does not exist.") 24 | return 25 | 26 | dest_distorted_dir = os.path.join(scene, 'distorted') 27 | if os.path.exists(dest_distorted_dir): 28 | shutil.rmtree(dest_distorted_dir) 29 | shutil.copytree(distorted_dir, dest_distorted_dir) 30 | print(f"Copied to {dest_distorted_dir}") 31 | 32 | def main(args): 33 
|     copy_sparse_to_frames(args.source, args.scene)
34 |     copy_distorted_to_scene(args.source, args.scene)
35 | 
36 | if __name__ == "__main__":
37 |     parser = argparse.ArgumentParser(description='Copy directories to specified locations.')
38 |     parser.add_argument('--source', type=str, help='The source directory containing sparse and distorted folders.')
39 |     parser.add_argument('--scene', type=str, help='The scene directory.')
40 | 
41 |     args = parser.parse_args()
42 |     main(args)
43 | 
--------------------------------------------------------------------------------
/scripts/extract_fvv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import json
4 | 
5 | base_paths = ["./Code-Release/flame_steak"]
6 | 
7 | for base_path in base_paths:
8 |     frame_folders = sorted([f for f in os.listdir(base_path) if f.startswith('frame') and os.path.isdir(os.path.join(base_path, f))])
9 | 
10 |     with open(os.path.join(base_path,'cfg_args.json'), 'r') as f:
11 |         cfg_args = json.load(f)
12 |     ntc_conf_path=cfg_args['ntc_conf_path']
13 |     init_3dg_path=os.path.join(cfg_args['model_path'], 'point_cloud', f"iteration_{cfg_args['first_load_iteration']}", 'point_cloud.ply') # double-quoted f-string: nesting single quotes inside it is a SyntaxError before Python 3.12
14 | 
15 |     ntcs_path = os.path.join(base_path, 'NTCs')
16 |     addition_3dgs_path = os.path.join(base_path, 'additional_3dgs')
17 |     pre_frame_3dgs_path = os.path.join(base_path, 'pre-frame_3dgs')
18 |     raw_path = os.path.join(base_path, 'raw')
19 | 
20 |     os.makedirs(ntcs_path, exist_ok=True)
21 |     os.makedirs(addition_3dgs_path, exist_ok=True)
22 |     os.makedirs(pre_frame_3dgs_path, exist_ok=True)
23 |     os.makedirs(raw_path, exist_ok=True)
24 | 
25 |     shutil.copy(ntc_conf_path, os.path.join(ntcs_path, 'config.json') )
26 |     shutil.copy(init_3dg_path, os.path.join(base_path, 'init_3dgs.ply') )
27 | 
28 |     for folder in frame_folders:
29 |         frame_id = int(folder[-6:])
30 | 
31 |         ntc_path = os.path.join(base_path, folder, 'NTC.pth')
32 |         if os.path.isfile(ntc_path):
33 |             ntc_target_path = os.path.join(ntcs_path, f'NTC_{frame_id-1:06}.pth')
34 |             shutil.copy(ntc_path, ntc_target_path)
35 | 
36 |         addition_3dgs_source_path = os.path.join(base_path, folder, 'point_cloud', 'iteration_250', 'added', 'point_cloud.ply')
37 |         if os.path.isfile(addition_3dgs_source_path):
38 |             addition_3dgs_target_path = os.path.join(addition_3dgs_path, f'additions_{frame_id-1:06}.ply')
39 |             shutil.copy(addition_3dgs_source_path, addition_3dgs_target_path)
40 | 
41 |         shutil.move(os.path.join(base_path, folder), raw_path)
42 | 
43 |     print(f"Files in {base_path} have been reorganized.")
44 | 
--------------------------------------------------------------------------------
/test/flame_steak_suite/cfg_args.json:
--------------------------------------------------------------------------------
1 | {
2 |     "extent": 0,
3 |     "sh_degree": 1,
4 |     "source_path": "",
5 |     "model_path": "test/flame_steak_suite/flame_steak_init",
6 |     "output_path": "output/Code-Release/flame_steak",
7 |     "video_path": "dataset/DyNeRF/frames/flame_steak",
8 |     "ply_name": "points3D.ply",
9 |     "images": "images_2",
10 |     "resolution": 1,
11 |     "white_background": false,
12 |     "data_device": "cuda",
13 |     "eval": true,
14 |     "iterations": 150,
15 |     "iterations_s2": 100,
16 |     "first_load_iteration": 15000,
17 |     "position_lr_init": 0.0024,
18 |     "position_lr_final": 2.4e-05,
19 |     "position_lr_delay_mult": 0.01,
20 |     "position_lr_max_steps": 30000,
21 |     "feature_lr": 0.0375,
22 |     "opacity_lr": 0.75,
23 |     "scaling_lr": 0.075,
24 |     "rotation_lr": 0.015,
25 |     "percent_dense": 0.01,
26 | 
"lambda_dssim": 0.2, 27 | "depth_smooth": 0.0, 28 | "ntc_lr": 0.002, 29 | "lambda_dxyz": 0, 30 | "lambda_drot": 0, 31 | "densification_interval": 20, 32 | "opacity_reset_interval": 3000, 33 | "densify_from_iter": 130, 34 | "densify_until_iter": 15000, 35 | "densify_grad_threshold": 0.00015, 36 | "ntc_conf_path": "configs/cache/cache_F_4.json", 37 | "ntc_path": "ntc/flame_steak_ntc_params_F_4.pth", 38 | "batch_size": 1, 39 | "spawn_type": "spawn", 40 | "s2_type": "spawn", 41 | "s2_adding": true, 42 | "num_of_split": 1, 43 | "num_of_spawn": 1, 44 | "std_scale": 2, 45 | "min_opacity": 0.01, 46 | "rotate_sh": false, 47 | "only_mlp": false, 48 | "convert_SHs_python": false, 49 | "compute_cov3D_python": false, 50 | "debug": false, 51 | "bwd_depth": false, 52 | "opt_type": "3DGStream", 53 | "ip": "127.0.0.1", 54 | "port": 6009, 55 | "debug_from": -1, 56 | "detect_anomaly": false, 57 | "test_iterations": [ 58 | 150, 59 | 250 60 | ], 61 | "save_iterations": [ 62 | 150 63 | ], 64 | "frame_start": 1, 65 | "frame_end": 300, 66 | "quiet": false, 67 | "checkpoint_iterations": [], 68 | "start_checkpoint": null, 69 | "read_config": true, 70 | "load_iteration": 150 71 | } -------------------------------------------------------------------------------- /test/flame_steak_suite/flame_steak_init/point_cloud/iteration_15000/point_cloud.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/flame_steak_init/point_cloud/iteration_15000/point_cloud.ply -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/distorted/database.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/distorted/database.db -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/distorted/sparse/0/cameras.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/distorted/sparse/0/cameras.bin -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/distorted/sparse/0/images.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/distorted/sparse/0/images.bin -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/distorted/sparse/0/points3D.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/distorted/sparse/0/points3D.bin -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/distorted/sparse/0/project.ini: -------------------------------------------------------------------------------- 1 | log_to_stderr=false 2 | random_seed=0 3 | log_level=2 4 | database_path=dataset/DyNeRF/frames/flame_steak/distorted/database.db 5 | image_path=dataset/DyNeRF/frames/flame_steak/frame000000 
6 | [Mapper] 7 | ignore_watermarks=false 8 | multiple_models=true 9 | extract_colors=true 10 | ba_refine_focal_length=true 11 | ba_refine_principal_point=false 12 | ba_refine_extra_params=true 13 | fix_existing_images=false 14 | tri_ignore_two_view_tracks=true 15 | min_num_matches=15 16 | max_num_models=50 17 | max_model_overlap=20 18 | min_model_size=10 19 | init_image_id1=-1 20 | init_image_id2=-1 21 | init_num_trials=200 22 | num_threads=-1 23 | ba_min_num_residuals_for_multi_threading=50000 24 | ba_local_num_images=6 25 | ba_local_max_num_iterations=25 26 | ba_global_images_freq=500 27 | ba_global_points_freq=250000 28 | ba_global_max_num_iterations=50 29 | ba_global_max_refinements=5 30 | ba_local_max_refinements=2 31 | snapshot_images_freq=0 32 | init_min_num_inliers=100 33 | init_max_reg_trials=2 34 | abs_pose_min_num_inliers=30 35 | max_reg_trials=3 36 | tri_max_transitivity=1 37 | tri_complete_max_transitivity=5 38 | tri_re_max_trials=1 39 | min_focal_length_ratio=0.10000000000000001 40 | max_focal_length_ratio=10 41 | max_extra_param=1 42 | ba_local_function_tolerance=0 43 | ba_global_images_ratio=1.1000000000000001 44 | ba_global_points_ratio=1.1000000000000001 45 | ba_global_function_tolerance=9.9999999999999995e-07 46 | ba_global_max_refinement_change=0.00050000000000000001 47 | ba_local_max_refinement_change=0.001 48 | init_max_error=4 49 | init_max_forward_motion=0.94999999999999996 50 | init_min_tri_angle=16 51 | abs_pose_max_error=12 52 | abs_pose_min_inlier_ratio=0.25 53 | filter_max_reproj_error=4 54 | filter_min_tri_angle=1.5 55 | local_ba_min_tri_angle=6 56 | tri_create_max_angle_error=2 57 | tri_continue_max_angle_error=2 58 | tri_merge_max_reproj_error=4 59 | tri_complete_max_reproj_error=4 60 | tri_re_max_angle_error=5 61 | tri_re_min_ratio=0.20000000000000001 62 | tri_min_angle=1.5 63 | snapshot_path= 64 | -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam00.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam01.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam02.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam03.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam04.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam04.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam05.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam06.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam07.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam08.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam09.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam10.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam11.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam12.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam13.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam14.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam14.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam15.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam16.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam17.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam18.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam19.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/images/cam20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/images/cam20.png -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/sparse/0/cameras.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/sparse/0/cameras.bin -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/sparse/0/images.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/sparse/0/images.bin -------------------------------------------------------------------------------- /test/flame_steak_suite/frame000000/sparse/0/points3D.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/test/flame_steak_suite/frame000000/sparse/0/points3D.bin 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | import json 19 | from scene import Scene, GaussianModel 20 | from utils.general_utils import safe_state 21 | import uuid 22 | import kornia 23 | from tqdm import tqdm 24 | from utils.image_utils import psnr 25 | from argparse import ArgumentParser, Namespace 26 | from arguments import ModelParams, PipelineParams, OptimizationParams 27 | try: 28 | from torch.utils.tensorboard import SummaryWriter 29 | TENSORBOARD_FOUND = True 30 | except ImportError: 31 | TENSORBOARD_FOUND = False 32 | 33 | def training(dataset, opt, pipe, load_iteration, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 34 | first_iter = 0 35 | tb_writer = prepare_output_and_logger(dataset) 36 | gaussians = GaussianModel(dataset.sh_degree) 37 | scene = Scene(dataset, gaussians, load_iteration) 38 | gaussians.training_setup(opt) 39 | if checkpoint: 40 | (model_params, first_iter) = torch.load(checkpoint) 41 | gaussians.restore(model_params, opt) 42 | 43 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 44 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 45 | 46 | iter_start = torch.cuda.Event(enable_timing = True) 47 | iter_end = torch.cuda.Event(enable_timing = True) 48 | 49 | viewpoint_stack = None 50 | ema_loss_for_log = 0.0 51 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 52 | first_iter += 1 53 | for iteration in range(first_iter, opt.iterations + 1): 54 | if network_gui.conn == None: 55 | network_gui.try_connect() 56 | while network_gui.conn != None: 57 | try: 58 | net_image_bytes = None 59 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 60 | if custom_cam != None: 61 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 62 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 63 | network_gui.send(net_image_bytes, dataset.source_path) 64 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 65 | break 66 | except Exception as e: 67 | network_gui.conn = None 68 | 69 | iter_start.record() 70 | 71 | gaussians.update_learning_rate(iteration) 72 | 73 | # Every 1000 its we increase the levels of SH up to a maximum degree 74 | if iteration % 1000 == 0: 75 | gaussians.oneupSHdegree() 76 | 77 | loss = torch.tensor(0.).cuda() 78 | for batch_iteraion in range(opt.batch_size): 79 | # Pick a random Camera 80 | if not viewpoint_stack: 81 | viewpoint_stack = scene.getTrainCameras().copy() 82 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 83 | 84 | # Render 85 | if (iteration - 1) == debug_from: 86 | pipe.debug = True 87 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 88 | image, depth, 
viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["depth"],render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 89 | 90 | # Loss 91 | gt_image = viewpoint_cam.original_image.cuda() 92 | Ll1 = l1_loss(image, gt_image) 93 | inv_depth=1./(depth+ 0.0000001).unsqueeze(dim=0) 94 | inv_depth_downsampled=torch.nn.functional.interpolate(inv_depth,scale_factor=0.5) 95 | gt_image_downsampled=torch.nn.functional.interpolate(gt_image.unsqueeze(dim=0),scale_factor=0.5) 96 | Lds = kornia.losses.inverse_depth_smoothness_loss(inv_depth_downsampled,gt_image_downsampled) 97 | loss += (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 98 | if(opt.depth_smooth>0): 99 | loss+=opt.depth_smooth * Lds 100 | loss/=opt.batch_size 101 | loss.backward() 102 | iter_end.record() 103 | 104 | with torch.no_grad(): 105 | # Progress bar 106 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 107 | if iteration % 10 == 0: 108 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 109 | progress_bar.update(10) 110 | if iteration == opt.iterations: 111 | progress_bar.close() 112 | 113 | # Log and save 114 | training_report(tb_writer, iteration, Ll1, Lds, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 115 | if (iteration in saving_iterations): 116 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 117 | scene.save(iteration) 118 | 119 | # Densification 120 | if iteration < opt.densify_until_iter: 121 | # Keep track of max radii in image-space for pruning 122 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 123 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 124 | 125 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 126 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 127 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 128 | 129 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 130 | gaussians.reset_opacity() 131 | 132 | # Optimizer step 133 | if iteration < opt.iterations: 134 | gaussians.optimizer.step() 135 | gaussians.optimizer.zero_grad(set_to_none = True) 136 | 137 | if (iteration in checkpoint_iterations): 138 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 139 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 140 | 141 | def prepare_output_and_logger(args): 142 | if not args.model_path: 143 | if os.getenv('OAR_JOB_ID'): 144 | unique_str=os.getenv('OAR_JOB_ID') 145 | else: 146 | unique_str = str(uuid.uuid4()) 147 | args.model_path = os.path.join("./output/", unique_str[0:10]) 148 | 149 | # Set up output folder 150 | print("Output folder: {}".format(args.model_path)) 151 | os.makedirs(args.model_path, exist_ok = True) 152 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 153 | cfg_log_f.write(str(Namespace(**vars(args)))) 154 | 155 | # Create Tensorboard writer 156 | tb_writer = None 157 | if TENSORBOARD_FOUND: 158 | tb_writer = SummaryWriter(args.model_path) 159 | else: 160 | print("Tensorboard not available: not logging progress") 161 | return tb_writer 162 | 163 | def training_report(tb_writer, iteration, Ll1, Lds, loss, l1_loss, elapsed, testing_iterations, scene 
: Scene, renderFunc, renderArgs): 164 | if tb_writer: 165 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 166 | tb_writer.add_scalar('train_loss_patches/ds_loss', Lds.item(), iteration) 167 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 168 | tb_writer.add_scalar('iter_time', elapsed, iteration) 169 | 170 | # Report test and samples of training set 171 | if iteration in testing_iterations: 172 | torch.cuda.empty_cache() 173 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 174 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 175 | 176 | for config in validation_configs: 177 | if config['cameras'] and len(config['cameras']) > 0: 178 | l1_test = 0.0 179 | psnr_test = 0.0 180 | for idx, viewpoint in enumerate(config['cameras']): 181 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) 182 | image, depth = torch.clamp(render_pkg["render"], 0.0, 1.0), render_pkg["depth"] 183 | depth_vis=depth/(depth.max()+1e-5) 184 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 185 | if tb_writer and (idx < 5): 186 | tb_writer.add_image(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration) 187 | tb_writer.add_image(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth_vis, global_step=iteration) 188 | if iteration == testing_iterations[0]: 189 | tb_writer.add_image(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration) 190 | l1_test += l1_loss(image, gt_image).mean().double() 191 | psnr_test += psnr(image, gt_image).mean().double() 192 | psnr_test /= len(config['cameras']) 193 | l1_test /= len(config['cameras']) 194 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 195 | if tb_writer: 196 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 197 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 198 | 199 | if tb_writer: 200 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 201 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 202 | torch.cuda.empty_cache() 203 | 204 | def train_model(lp,op,pp,args): 205 | args.save_iterations.append(args.iterations) 206 | if args.depth_smooth==0: 207 | args.bwd_depth=False 208 | print("Optimizing " + args.model_path) 209 | 210 | # Initialize system state (RNG) 211 | safe_state(args.quiet) 212 | 213 | # Start GUI server, configure and run training 214 | network_gui.init(args.ip, args.port) 215 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 216 | training(lp.extract(args), op.extract(args), pp.extract(args), args.load_iteration, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 217 | 218 | # All done 219 | print("\nTraining complete.") 220 | 221 | if __name__ == "__main__": 222 | # Set up command line argument parser 223 | parser = ArgumentParser(description="Training script parameters") 224 | lp = ModelParams(parser) 225 | op = OptimizationParams(parser) 226 | pp = PipelineParams(parser) 227 | parser.add_argument('--ip', type=str, default="127.0.0.1") 228 | parser.add_argument('--port', type=int, default=6009) 229 | parser.add_argument('--debug_from', type=int, default=-1) 230 | 
parser.add_argument('--load_iteration', type=int, default=None) 231 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 232 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 10_000, 15_000, 20_000, 24_000, 27_000, 30_000]) 233 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 15_000, 20_000, 24_000, 27_000, 30_000]) 234 | parser.add_argument("--quiet", action="store_true") 235 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 236 | parser.add_argument("--start_checkpoint", type=str, default = None) 237 | parser.add_argument("--read_config", action='store_true', default=False) 238 | parser.add_argument("--config_path", type=str, default = None) 239 | args = parser.parse_args(sys.argv[1:]) 240 | if args.output_path == "": 241 | args.output_path=args.model_path 242 | if args.read_config and args.config_path is not None: 243 | with open(args.config_path, 'r') as f: 244 | config = json.load(f) 245 | for key, value in config.items(): 246 | if key not in ["output_path", "source_path", "model_path"]: 247 | setattr(args, key, value) 248 | serializable_namespace = {k: v for k, v in vars(args).items() if isinstance(v, (int, float, str, bool, list, dict, tuple, type(None)))} 249 | json_namespace = json.dumps(serializable_namespace) 250 | os.makedirs(args.model_path, exist_ok = True) 251 | with open(os.path.join(args.model_path, "cfg_args.json"), 'w') as f: 252 | f.write(json_namespace) 253 | train_model(lp,op,pp,args) -------------------------------------------------------------------------------- /train_frames.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import time
12 | import os
13 | import torch
14 | from random import randint
15 | from utils.loss_utils import l1_loss, ssim, quaternion_loss, d_xyz_gt, d_rot_gt
16 | from gaussian_renderer import render, network_gui
17 | import sys
18 | import json
19 | from scene import Scene, GaussianModel
20 | from utils.general_utils import safe_state
21 | import uuid
22 | from tqdm import tqdm
23 | from utils.image_utils import psnr
24 | from utils.debug_utils import save_tensor_img
25 | from argparse import ArgumentParser, Namespace
26 | from arguments import ModelParams, PipelineParams, OptimizationParams
27 | import re
28 | try:
29 |     from torch.utils.tensorboard import SummaryWriter
30 |     TENSORBOARD_FOUND = True
31 | except ImportError:
32 |     TENSORBOARD_FOUND = False
33 | 
34 | def training_one_frame(dataset, opt, pipe, load_iteration, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
35 |     start_time=time.time()
36 |     last_s1_res = []
37 |     last_s2_res = []
38 |     first_iter = 0
39 |     tb_writer = prepare_output_and_logger(dataset)
40 |     gaussians = GaussianModel(dataset.sh_degree,opt.rotate_sh)
41 |     scene = Scene(dataset, gaussians, load_iteration=load_iteration, shuffle=False)
42 |     gaussians.training_one_frame_setup(opt)
43 |     if checkpoint:
44 |         (model_params, first_iter) = torch.load(checkpoint)
45 |         gaussians.restore(model_params, opt)
46 | 
47 |     bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
48 |     background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
49 | 
50 |     iter_start = torch.cuda.Event(enable_timing = True)
51 |     iter_end = torch.cuda.Event(enable_timing = True)
52 |     viewpoint_stack = None
53 |     ema_loss_for_log = 0.0
54 |     progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
55 |     first_iter += 1
56 |     s1_start_time=time.time()
57 |     # Train the NTC
58 |     for iteration in range(first_iter, opt.iterations + 1):
59 |         iter_start.record()
60 | 
61 |         # gaussians.update_learning_rate(iteration)
62 | 
63 |         # Every 1000 its we increase the levels of SH up to a maximum degree
64 |         if iteration % 1000 == 0:
65 |             gaussians.oneupSHdegree()
66 | 
67 |         # Query the NTC
68 |         gaussians.query_ntc()
69 | 
70 |         loss = torch.tensor(0.).cuda()
71 | 
72 | 
73 |         # A simple batch loop: accumulate the loss over batch_size randomly drawn training cameras
74 |         for batch_iteraion in range(opt.batch_size):
75 | 
76 |             # Pick a random Camera
77 |             if not viewpoint_stack:
78 |                 viewpoint_stack = scene.getTrainCameras().copy()
79 |             viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
80 | 
81 |             # Render
82 |             if (iteration - 1) == debug_from:
83 |                 pipe.debug = True
84 |             render_pkg = render(viewpoint_cam, gaussians, pipe, background)
85 |             image, depth, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["depth"],render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
86 | 
87 |             # Loss
88 |             gt_image = viewpoint_cam.original_image.cuda()
89 |             Ll1 = l1_loss(image, gt_image)
90 |             Lds = torch.tensor(0.).cuda()
91 |             loss += (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
92 | 
93 |         loss/=opt.batch_size
94 |         loss.backward()
95 |         iter_end.record()
96 | 
97 |         with torch.no_grad():
98 |             # Progress bar
99 |             ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
100 |             if iteration % 10 == 0:
101 |                 progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
102 |                 progress_bar.update(10)
103 |             if iteration == opt.iterations:
104 |                 progress_bar.close()
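            # Stage 1 above trains only the NTC (the optimizer step below calls
            # gaussians.ntc_optimizer); the Gaussian attributes stay frozen and
            # are re-posed each iteration through query_ntc(). Roughly, per
            # Gaussian i the NTC predicts a residual transform that is composed
            # as (illustration only, mirroring query_ntc):
            #
            #     new_xyz[i] = xyz[i] + d_xyz[i]
            #     new_rot[i] = rotation_compose(rot[i], d_rot[i])  # quaternion product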
105 | 106 | # Log and save 107 | s1_res = training_report(tb_writer, iteration, Ll1, Lds, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 108 | if s1_res is not None: 109 | last_s1_res.append(s1_res) 110 | if (iteration in saving_iterations): 111 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 112 | scene.save(iteration=iteration, save_type='all') 113 | 114 | # Tracking Densification Stats 115 | if iteration > opt.densify_from_iter: 116 | # Keep track of max radii in image-space for pruning 117 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 118 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 119 | 120 | # Optimizer step 121 | if iteration < opt.iterations: 122 | gaussians.ntc_optimizer.step() 123 | gaussians.ntc_optimizer.zero_grad(set_to_none = True) 124 | 125 | if (iteration in checkpoint_iterations): 126 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 127 | torch.save((gaussians.capture(), iteration), scene.output_path + "/chkpnt" + str(iteration) + ".pth") 128 | 129 | s1_end_time=time.time() 130 | # Densify 131 | if(opt.iterations_s2>0): 132 | # Dump the NTC 133 | scene.dump_NTC() 134 | # Update Gaussians by NTC 135 | gaussians.update_by_ntc() 136 | # Prune, Clone and setting up 137 | gaussians.training_one_frame_s2_setup(opt) 138 | progress_bar = tqdm(range(opt.iterations, opt.iterations + opt.iterations_s2), desc="Training progress of Stage 2") 139 | 140 | # Train the new Gaussians 141 | for iteration in range(opt.iterations + 1, opt.iterations + opt.iterations_s2 + 1): 142 | iter_start.record() 143 | 144 | # Update Learning Rate 145 | # gaussians.update_learning_rate(iteration) 146 | 147 | loss = torch.tensor(0.).cuda() 148 | 149 | for batch_iteraion in range(opt.batch_size): 150 | 151 | # Pick a random Camera 152 | if not viewpoint_stack: 153 | viewpoint_stack = scene.getTrainCameras().copy() 154 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 155 | 156 | # Render 157 | if (iteration - 1) == debug_from: 158 | pipe.debug = True 159 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 160 | image, depth, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["depth"],render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 161 | # Loss 162 | gt_image = viewpoint_cam.original_image.cuda() 163 | Ll1 = l1_loss(image, gt_image) 164 | loss += (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 165 | 166 | loss/=opt.batch_size 167 | loss.backward() 168 | iter_end.record() 169 | 170 | with torch.no_grad(): 171 | # Progress bar 172 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 173 | if (iteration - opt.iterations) % 10 == 0: 174 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 175 | progress_bar.update(10) 176 | if iteration == opt.iterations + opt.iterations_s2: 177 | progress_bar.close() 178 | 179 | # Log and save 180 | s2_res = training_report(tb_writer, iteration, Ll1, Lds, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 181 | if s2_res is not None: 182 | last_s2_res.append(s2_res) 183 | if (iteration in saving_iterations): 184 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 185 | scene.save(iteration=iteration, save_type='added') 186 | 187 | # Densification 188 | if (iteration - 
opt.iterations) % opt.densification_interval == 0:
189 |                 gaussians.adding_and_prune(opt,scene.cameras_extent)
190 | 
191 |             # Optimizer step
192 |             if iteration < opt.iterations + opt.iterations_s2:
193 |                 gaussians.optimizer.step()
194 |                 gaussians.optimizer.zero_grad(set_to_none = True)
195 |     s2_end_time=time.time()
196 | 
197 |     # Compute the total training time of each stage
198 |     pre_time = s1_start_time - start_time
199 |     s1_time = s1_end_time - s1_start_time
200 |     s2_time = s2_end_time - s1_end_time
201 | 
202 |     return last_s1_res, last_s2_res, pre_time, s1_time, s2_time
203 | 
204 | def prepare_output_and_logger(args):
205 |     if not args.output_path:
206 |         if os.getenv('OAR_JOB_ID'):
207 |             unique_str=os.getenv('OAR_JOB_ID')
208 |         else:
209 |             unique_str = str(uuid.uuid4())
210 |         args.output_path = os.path.join("./output/", unique_str[0:10])
211 | 
212 |     # Set up output folder
213 |     print("Output folder: {}".format(args.output_path))
214 |     os.makedirs(args.output_path, exist_ok = True)
215 |     with open(os.path.join(args.output_path, "cfg_args"), 'w') as cfg_log_f:
216 |         cfg_log_f.write(str(Namespace(**vars(args))))
217 | 
218 |     # Create Tensorboard writer
219 |     tb_writer = None
220 |     if TENSORBOARD_FOUND:
221 |         tb_writer = SummaryWriter(args.output_path)
222 |     else:
223 |         print("Tensorboard not available: not logging progress")
224 |     return tb_writer
225 | 
226 | def training_report(tb_writer, iteration, Ll1, Lds, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
227 |     last_test_psnr=0.0
228 |     if tb_writer:
229 |         tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
230 |         tb_writer.add_scalar('train_loss_patches/ds_loss', Lds.item(), iteration)
231 |         tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
232 |         tb_writer.add_scalar('iter_time', elapsed, iteration)
233 | 
234 |     # Report test and samples of training set
235 |     if iteration in testing_iterations:
236 |         torch.cuda.empty_cache()
237 |         validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
238 |                               # {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}
239 |                               )
240 | 
241 |         for config in validation_configs:
242 |             if config['cameras'] and len(config['cameras']) > 0:
243 |                 l1_test = 0.0
244 |                 psnr_test = 0.0
245 |                 for idx, viewpoint in enumerate(config['cameras']):
246 |                     render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs)
247 |                     # if scene.gaussians._added_mask is not None:
248 |                     #     added_pkg = renderFunc(viewpoint, scene.gaussians.get_masked_gaussian(scene.gaussians._added_mask), *renderArgs)
249 |                     image, depth = torch.clamp(render_pkg["render"], 0.0, 1.0), render_pkg["depth"]
250 |                     depth_vis=depth/(depth.max()+1e-5)
251 |                     gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
252 |                     if tb_writer and (idx < 5):
253 |                         tb_writer.add_image(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration)
254 |                         # tb_writer.add_image(config['name'] + "_view_{}/diff".format(viewpoint.image_name), (gt_image-image).abs().mean(dim=0, keepdim=True), global_step=iteration)
255 |                         # tb_writer.add_image(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth_vis, global_step=iteration)
256 |                         # if scene.gaussians._added_mask is not None:
257 |                         #     tb_writer.add_image(config['name'] + "_view_{}/added_gaussians".format(viewpoint.image_name), torch.clamp(added_pkg["render"], 0.0, 1.0), global_step=iteration)
258 |                         if iteration == testing_iterations[0]:
259 | 
tb_writer.add_image(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration) 260 | l1_test += l1_loss(image, gt_image).mean().double() 261 | psnr_test += psnr(image, gt_image).mean().double() 262 | psnr_test /= len(config['cameras']) 263 | l1_test /= len(config['cameras']) 264 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 265 | if tb_writer: 266 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 267 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 268 | if config['name'] == 'test': 269 | last_test_psnr = psnr_test 270 | last_test_image = image 271 | last_gt = gt_image 272 | 273 | if tb_writer: 274 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 275 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 276 | torch.cuda.empty_cache() 277 | 278 | return {'last_test_psnr':last_test_psnr.cpu().numpy() 279 | , 'last_test_image':last_test_image.cpu() 280 | , 'last_points_num':scene.gaussians.get_xyz.shape[0] 281 | # , 'last_gt':last_gt.cpu() 282 | } 283 | 284 | def train_one_frame(lp,op,pp,args): 285 | args.save_iterations.append(args.iterations + args.iterations_s2) 286 | if args.depth_smooth==0: 287 | args.bwd_depth=False 288 | print("Optimizing " + args.output_path) 289 | res_dict={} 290 | if(args.opt_type=='3DGStream'): 291 | s1_ress, s2_ress, pre_time, s1_time, s2_time = training_one_frame(lp.extract(args), op.extract(args), pp.extract(args), args.load_iteration, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 292 | 293 | # All done 294 | print("\nTraining complete.") 295 | print(f"Preparation: {pre_time}") 296 | if pre_time > 2: 297 | print(f"If preparation is time-consuming, consider down-scaling the images BEFORE running 3DGStream.") 298 | print(f"Stage 1: {s1_time}") 299 | print(f"Stage 2: {s2_time}") 300 | if s1_ress !=[]: 301 | for idx, s1_res in enumerate(s1_ress): 302 | save_tensor_img(s1_res['last_test_image'],os.path.join(args.output_path,f'{idx}_rendering1')) 303 | res_dict[f'stage1/psnr_{idx}']=s1_res['last_test_psnr'] 304 | res_dict[f'stage1/points_num_{idx}']=s1_res['last_points_num'] 305 | res_dict[f'stage1/time']=s1_time 306 | if s2_ress !=[]: 307 | for idx, s2_res in enumerate(s2_ress): 308 | save_tensor_img(s2_res['last_test_image'],os.path.join(args.output_path,f'{idx}_rendering2')) 309 | res_dict[f'stage2/psnr_{idx}']=s2_res['last_test_psnr'] 310 | res_dict[f'stage2/points_num_{idx}']=s2_res['last_points_num'] 311 | res_dict[f'stage2/time']=s2_time 312 | return res_dict 313 | 314 | def train_frames(lp, op, pp, args): 315 | # Initialize system state (RNG) 316 | safe_state(args.quiet) 317 | video_path=args.video_path 318 | output_path=args.output_path 319 | model_path=args.model_path 320 | load_iteration = args.load_iteration 321 | sub_paths = os.listdir(video_path) 322 | pattern = re.compile(r'frame(\d+)') 323 | frames = sorted( 324 | (item for item in sub_paths if pattern.match(item)), 325 | key=lambda x: int(pattern.match(x).group(1)) 326 | ) 327 | frames=frames[args.frame_start:args.frame_end] 328 | if args.frame_start==1: 329 | args.load_iteration = args.first_load_iteration 330 | for frame in frames: 331 | start_time = time.time() 332 | args.source_path = os.path.join(video_path, frame) 333 | args.output_path = os.path.join(output_path, frame) 334 | 
342 | if __name__ == "__main__": 343 | # Set up command line argument parser 344 | parser = ArgumentParser(description="Training script parameters") 345 | lp = ModelParams(parser) 346 | op = OptimizationParams(parser) 347 | pp = PipelineParams(parser) 348 | parser.add_argument('--ip', type=str, default="127.0.0.1") 349 | parser.add_argument('--port', type=int, default=6009) 350 | parser.add_argument('--frame_start', type=int, default=1) 351 | parser.add_argument('--frame_end', type=int, default=150) 352 | parser.add_argument('--load_iteration', type=int, default=None) 353 | parser.add_argument('--debug_from', type=int, default=-1) 354 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 355 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 50, 100]) 356 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 50, 100]) 357 | parser.add_argument("--quiet", action="store_true") 358 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 359 | parser.add_argument("--start_checkpoint", type=str, default = None) 360 | parser.add_argument("--read_config", action='store_true', default=False) 361 | parser.add_argument("--config_path", type=str, default = None) 362 | args = parser.parse_args(sys.argv[1:]) 363 | if args.output_path == "": 364 | args.output_path=args.model_path 365 | if args.read_config and args.config_path is not None: 366 | with open(args.config_path, 'r') as f: 367 | config = json.load(f) 368 | for key, value in config.items(): 369 | if key not in ["output_path", "source_path", "model_path", "video_path", "debug_from"]: 370 | setattr(args, key, value) 371 | serializable_namespace = {k: v for k, v in vars(args).items() if isinstance(v, (int, float, str, bool, list, dict, tuple, type(None)))} 372 | json_namespace = json.dumps(serializable_namespace) 373 | os.makedirs(args.output_path, exist_ok = True) 374 | with open(os.path.join(args.output_path, "cfg_args.json"), 'w') as f: 375 | f.write(json_namespace) 376 | # train_one_frame(lp,op,pp,args) 377 | train_frames(lp,op,pp,args) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStream/747ddfef646edf3ea628f2bd13b7bedce7c5fe47/utils/__init__.py -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 24 | else: # should be a type that converts to float 25 | if args.resolution == -1: 26 | if orig_w > 1600: 27 | global WARNED 28 | if not WARNED: 29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 31 | WARNED = True 32 | global_down = orig_w / 1600 33 | else: 34 | global_down = 1 35 | else: 36 | global_down = orig_w / args.resolution 37 | 38 | scale = float(global_down) * float(resolution_scale) 39 | resolution = (int(orig_w / scale), int(orig_h / scale)) 40 | 41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 42 | cam_info.image.close() 43 | gt_image = resized_image_rgb[:3, ...] 44 | loaded_mask = None 45 | 46 | if resized_image_rgb.shape[1] == 4: 47 | loaded_mask = resized_image_rgb[3:4, ...] 48 | 49 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 50 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 51 | image=gt_image, gt_alpha_mask=loaded_mask, 52 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 53 | 54 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 55 | camera_list = [] 56 | 57 | for id, c in enumerate(cam_infos): 58 | camera_list.append(loadCam(args, id, c, resolution_scale)) 59 | 60 | return camera_list 61 | 62 | def camera_to_JSON(id, camera : Camera): 63 | Rt = np.zeros((4, 4)) 64 | Rt[:3, :3] = camera.R.transpose() 65 | Rt[:3, 3] = camera.T 66 | Rt[3, 3] = 1.0 67 | 68 | W2C = np.linalg.inv(Rt) 69 | pos = W2C[:3, 3] 70 | rot = W2C[:3, :3] 71 | serializable_array_2d = [x.tolist() for x in rot] 72 | camera_entry = { 73 | 'id' : id, 74 | 'img_name' : camera.image_name, 75 | 'width' : camera.width, 76 | 'height' : camera.height, 77 | 'position': pos.tolist(), 78 | 'rotation': serializable_array_2d, 79 | 'fy' : fov2focal(camera.FovY, camera.height), 80 | 'fx' : fov2focal(camera.FovX, camera.width) 81 | } 82 | return camera_entry 83 | --------------------------------------------------------------------------------
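The resolution logic in `loadCam` above is what the down-scaling advice printed by `train_one_frame` refers to: with `--resolution/-r` left at -1, images wider than 1600 px are capped. A minimal sketch of the arithmetic, assuming a hypothetical 3840x2160 input and a `resolution_scale` of 1.0:

```
# Sketch of loadCam's scale selection for a 3840x2160 input with -r -1
orig_w, orig_h = 3840, 2160   # hypothetical source images
resolution_scale = 1.0        # assumed here; passed in by cameraList_from_camInfos
global_down = orig_w / 1600   # 2.4, from the orig_w > 1600 branch
scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale))
print(resolution)             # (1600, 900)
```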
/utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torchviz import make_dot 3 | # Used for debugging 4 | def save_tensor_img(img, name='rendering'): 5 | torchvision.utils.save_image(img, name+".png") 6 | 7 | def save_cal_graph(var,name='cal_graph'): 8 | dot = make_dot(var) 9 | dot.format = 'png' 10 | dot.render(filename=name, directory='./', cleanup=True) -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | if resolution != pil_image.size: 23 | resized_image_PIL = pil_image.resize(resolution) 24 | else: 25 | resized_image_PIL = pil_image 26 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 27 | if len(resized_image.shape) == 3: 28 | return resized_image.permute(2, 0, 1) 29 | else: 30 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 31 | 32 | def get_expon_lr_func( 33 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 34 | ): 35 | """ 36 | Copied from Plenoxels 37 | 38 | Continuous learning rate decay function. Adapted from JaxNeRF 39 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 40 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 41 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 42 | function of lr_delay_mult, such that the initial learning rate is 43 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 44 | to the normal learning rate when steps>lr_delay_steps. 45 | :param conf: config subtree 'lr' or similar 46 | :param max_steps: int, the number of steps during optimization. 47 | :return HoF which takes step as input 48 | """ 49 | 50 | def helper(step): 51 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 52 | # Disable this parameter 53 | return 0.0 54 | if lr_delay_steps > 0: 55 | # A kind of reverse cosine decay. 56 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 57 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 58 | ) 59 | else: 60 | delay_rate = 1.0 61 | t = np.clip(step / max_steps, 0, 1) 62 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 63 | return delay_rate * log_lerp 64 | 65 | return helper 66 | 67 | def strip_lowerdiag(L): 68 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 69 | 70 | uncertainty[:, 0] = L[:, 0, 0] 71 | uncertainty[:, 1] = L[:, 0, 1] 72 | uncertainty[:, 2] = L[:, 0, 2] 73 | uncertainty[:, 3] = L[:, 1, 1] 74 | uncertainty[:, 4] = L[:, 1, 2] 75 | uncertainty[:, 5] = L[:, 2, 2] 76 | return uncertainty 77 | 78 | def strip_symmetric(sym): 79 | return strip_lowerdiag(sym) 80 | 81 | # @torch.compile 82 | def build_rotation(r): 83 | q = torch.nn.functional.normalize(r) 84 | 85 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 86 | 87 | # Unpack the quaternion components 88 | rr, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3] 89 | 90 | # Precompute terms that are reused below 91 | xx, yy, zz = x * x, y * y, z * z 92 | xy, xz, yz = x * y, x * z, y * z 93 | rx, ry, rz = rr * x, rr * y, rr * z 94 | 95 | # Fill the rotation matrix with in-place operations 96 | R[:, 0, 0] = 1 - 2 * (yy + zz) 97 | R[:, 0, 1] = 2 * (xy - rz) 98 | R[:, 0, 2] = 2 * (xz + ry) 99 | R[:, 1, 0] = 2 * (xy + rz) 100 | R[:, 1, 1] = 1 - 2 * (xx + zz) 101 | R[:, 1, 2] = 2 * (yz - rx) 102 | R[:, 2, 0] = 2 * (xz - ry) 103 | R[:, 2, 1] = 2 * (yz + rx) 104 | R[:, 2, 2] = 1 - 2 * (xx + yy) 105 | return R 106 | 107 | def build_scaling_rotation(s, r): 108 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 109 | R = build_rotation(r) 110 | 111 | L[:,0,0] = s[:,0] 112 | L[:,1,1] = s[:,1] 113 | L[:,2,2] = s[:,2] 114 | 115 | L = R @ L 116 | return L 117 | 
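# A quick sanity check for build_rotation: the unit quaternion
# (w, x, y, z) = (0.7071, 0., 0., 0.7071) encodes a 90-degree rotation about
# the z-axis, so
#     build_rotation(torch.tensor([[0.7071, 0., 0., 0.7071]], device="cuda"))
# yields, up to numerical precision,
#     [[0., -1., 0.],
#      [1.,  0., 0.],
#      [0.,  0., 1.]]
# and build_scaling_rotation composes this with a per-axis scale as R @ diag(s).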
118 | def safe_state(silent): 119 | old_f = sys.stdout 120 | class F: 121 | def __init__(self, silent): 122 | self.silent = silent 123 | 124 | def write(self, x): 125 | if not self.silent: 126 | if x.endswith("\n"): 127 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 128 | else: 129 | old_f.write(x) 130 | 131 | def flush(self): 132 | old_f.flush() 133 | 134 | sys.stdout = F(silent) 135 | 136 | random.seed(0) 137 | np.random.seed(0) 138 | torch.manual_seed(0) 139 | torch.cuda.set_device(torch.device("cuda:0")) 140 | 141 | @torch.compile() 142 | def quaternion_multiply(a, b): 143 | """ 144 | Multiply two sets of quaternions. 145 | 146 | Parameters: 147 | a (Tensor): A tensor containing N quaternions, shape = [N, 4] 148 | b (Tensor): A tensor containing N quaternions, shape = [N, 4] 149 | 150 | Returns: 151 | Tensor: A tensor containing the product of the input quaternions, shape = [N, 4] 152 | """ 153 | a_norm=torch.nn.functional.normalize(a) 154 | b_norm=torch.nn.functional.normalize(b) 155 | w1, x1, y1, z1 = a_norm[:, 0], a_norm[:, 1], a_norm[:, 2], a_norm[:, 3] 156 | w2, x2, y2, z2 = b_norm[:, 0], b_norm[:, 1], b_norm[:, 2], b_norm[:, 3] 157 | 158 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 159 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 160 | y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 161 | z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 162 | 163 | return torch.stack([w, x, y, z], dim=1) -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from math import exp 17 | 18 | d_xyz_gt=torch.tensor([0.,0.,0.]).cuda() 19 | d_rot_gt=torch.tensor([1.,0.,0.,0.]).cuda() 20 | d_scaling_gt=torch.tensor([1.,1.,1.]).cuda() 21 | d_opacity_gt=torch.tensor([1.]).cuda() 22 | 23 | def l1_loss(network_output, gt): 24 | return torch.abs((network_output - gt)).mean() 25 | 26 | def l2_loss(network_output, gt): 27 | return ((network_output - gt) ** 2).mean() 28 | 29 | def gaussian(window_size, sigma): 30 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 31 | return gauss / gauss.sum() 32 | 33 | def quaternion_loss(q1, q2): 34 | cos_theta = F.cosine_similarity(q1, q2, dim=1) 35 | return 1-torch.pow(cos_theta, 2).mean() 36 | 37 | def create_window(window_size, channel): 38 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 39 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 40 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 41 | return window 42 | 43 | def ssim(img1, img2, window_size=11, size_average=True): 44 | channel = img1.size(-3) 45 | window = create_window(window_size, channel) 46 | 47 | if img1.is_cuda: 48 | window = window.cuda(img1.get_device()) 49 | window = window.type_as(img1) 50 | 51 | return _ssim(img1, img2, window, window_size, channel, size_average) 52 | 53 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 54 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 55 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 56 | 57 | mu1_sq = mu1.pow(2) 58 | mu2_sq = mu2.pow(2) 59 | mu1_mu2 = mu1 * mu2 60 | 61 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 62 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 63 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 64 | 65 | C1 = 0.01 ** 2 66 | C2 = 0.03 ** 2 67 | 68 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 69 | 70 | if size_average: 71 | return ssim_map.mean() 72 | else: 73 | return ssim_map.mean(1).mean(1).mean(1) 74 | 75 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | from utils.general_utils import build_rotation 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | INV_C1 = 2.046653415892977 29 | C2 = [ 30 | 1.0925484305920792, 31 | -1.0925484305920792, 32 | 0.31539156525252005, 33 | -1.0925484305920792, 34 | 0.5462742152960396 35 | ] 36 | C3 = [ 37 | -0.5900435899266435, 38 | 2.890611442640554, 39 | -0.4570457994644658, 40 | 0.3731763325901154, 41 | -0.4570457994644658, 42 | 1.445305721320277, 43 | -0.5900435899266435 44 | ] 45 | C4 = [ 46 | 2.5033429417967046, 47 | -1.7701307697799304, 48 | 0.9461746957575601, 49 | -0.6690465435572892, 50 | 0.10578554691520431, 51 | -0.6690465435572892, 52 | 0.47308734787878004, 53 | -1.7701307697799304, 54 | 0.6258357354491761, 55 | ] 56 | INV_A = torch.tensor([ 57 | [0, 0, INV_C1], 58 | [INV_C1, 0, 0], 59 | [0, INV_C1, 0] 60 | ]).cuda() 61 | INV_A_T = INV_A.T 62 | 63 | def eval_sh(deg, sh, dirs): 64 | """ 65 | Evaluate spherical harmonics at unit directions 66 | using hardcoded SH polynomials. 67 | Works with torch/np/jnp. 68 | ... Can be 0 or more batch dimensions. 69 | Args: 70 | deg: int SH deg. Currently, 0-4 supported 71 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 72 | dirs: jnp.ndarray unit directions [..., 3] 73 | Returns: 74 | [..., C] 75 | """ 76 | assert deg <= 4 and deg >= 0 77 | coeff = (deg + 1) ** 2 78 | assert sh.shape[-1] >= coeff 79 | 80 | result = C0 * sh[..., 0] 81 | if deg > 0: 82 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 83 | result = (result - 84 | C1 * y * sh[..., 1] + 85 | C1 * z * sh[..., 2] - 86 | C1 * x * sh[..., 3]) 87 | 88 | if deg > 1: 89 | xx, yy, zz = x * x, y * y, z * z 90 | xy, yz, xz = x * y, y * z, x * z 91 | result = (result + 92 | C2[0] * xy * sh[..., 4] + 93 | C2[1] * yz * sh[..., 5] + 94 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 95 | C2[3] * xz * sh[..., 7] + 96 | C2[4] * (xx - yy) * sh[..., 8]) 97 | 98 | if deg > 2: 99 | result = (result + 100 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 101 | C3[1] * xy * z * sh[..., 10] + 102 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 103 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 104 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 105 | C3[5] * z * (xx - yy) * sh[..., 14] + 106 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 107 | 108 | if deg > 3: 109 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 110 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 111 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 112 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 113 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 114 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 115 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 116 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 117 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 118 | return result 119 | 120 | def RGB2SH(rgb): 121 | return (rgb - 0.5) / C0 122 | 123 | def SH2RGB(sh): 124 | return sh * C0 + 0.5 125 | 
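# Quick check of the color convention: RGB2SH maps mid-gray 0.5 to a zero DC
# coefficient, since (0.5 - 0.5) / C0 == 0, and SH2RGB is its exact inverse,
# so SH2RGB(RGB2SH(rgb)) == rgb for any rgb.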
126 | # @torch.compile 127 | def p_eval(dirs, l): 128 | if l==1: 129 | r = torch.norm(dirs, dim=1, keepdim=True) 130 | return (C1 * dirs / r).roll(shifts=-1,dims=1) 131 | else: 132 | raise NotImplementedError("Not implemented yet") 133 | 134 | # @torch.compile 135 | def rotate_sh_by_quaternion(sh,l,q): 136 | M=build_rotation(q) 137 | return rotate_sh_by_matrix(sh,l,M) 138 | 139 | # @torch.compile 140 | def rotate_sh_by_matrix(sh,l,M): 141 | if l==0: 142 | return sh 143 | if l==1: 144 | S = p_eval(M, 1) 145 | x = torch.matmul(sh, INV_A_T) 146 | x = torch.bmm(S,x.unsqueeze(2)) 147 | return x.squeeze() 148 | else: 149 | raise NotImplementedError("Not implemented yet") -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. Equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
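For reference, a minimal sketch (paths and layout are hypothetical) of how the two helpers in system_utils.py are typically combined to locate the most recent saved iteration:

```
import os
from utils.system_utils import mkdir_p, searchForMaxIteration

model_path = "output/my_run"                      # hypothetical run directory
pc_dir = os.path.join(model_path, "point_cloud")  # assumed to hold iteration_<N> subfolders
mkdir_p(pc_dir)                                   # safe if the folder already exists

# searchForMaxIteration splits each entry name on "_" and keeps the largest
# integer suffix, so "iteration_100" beats "iteration_50":
# latest = searchForMaxIteration(pc_dir)
```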