├── .gitignore ├── .gitmodules ├── .idea ├── .gitignore ├── 4DGaussians-master.iml ├── ST_4DGS.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── LICENSE.md ├── README.md ├── arguments ├── DyNeRF.py ├── Dynamic.py ├── Outdoor.py └── __init__.py ├── asset ├── Ballon.mp4 ├── ST-4DGS.jpg └── cut_roasted_beef.mp4 ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── render.sh ├── requirements.txt ├── scene ├── KDTree.py ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset.py ├── dataset_readers.py ├── deformation.py ├── external.py ├── gaussian_model.py ├── getData.py ├── hexplane.py ├── neural_3D_dataset_NDC.py ├── regulation.py └── utils.py ├── scripts ├── convert.py └── getFlow.py ├── submodules ├── depth-diff-gaussian-rasterization │ ├── .gitignore │ ├── .gitmodules │ ├── CMakeLists.txt │ ├── LICENSE.md │ ├── README.md │ ├── cuda_rasterizer │ │ ├── auxiliary.h │ │ ├── backward.cu │ │ ├── backward.h │ │ ├── config.h │ │ ├── forward.cu │ │ ├── forward.h │ │ ├── rasterizer.h │ │ ├── rasterizer_impl.cu │ │ └── rasterizer_impl.h │ ├── diff_gaussian_rasterization │ │ └── __init__.py │ ├── ext.cpp │ ├── rasterize_points.cu │ ├── rasterize_points.h │ ├── setup.py │ └── third_party │ │ ├── glm.zip │ │ └── stbi_image_write.h └── simple-knn │ ├── dist │ └── simple_knn-0.0.0-py3.9-win-amd64.egg │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn.cu │ ├── simple_knn.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt │ ├── simple_knn.h │ ├── simple_knn │ └── .gitkeep │ ├── spatial.cu │ └── spatial.h ├── train.py ├── train.sh └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── params_utils.py ├── scene_utils.py ├── sh_utils.py ├── system_utils.py └── timer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | data/ 10 | data 11 | argument/ 12 | scripts/ 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/depth-diff-gaussian=rasterization"] 2 | path = submodules/depth-diff-gaussian=rasterization 3 | url = https://github.com/ingra14m/depth-diff-gaussian-rasterization.git 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/4DGaussians-master.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/ST_4DGS.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 
2 | 3 | 21 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. 
Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # ST-4DGS: Spatial-Temporally Consistent 4D Gaussian Splatting for Efficient Dynamic Scene Rendering 3 | *** 4 | ### SIGGRAPH 2024 | [Paper](https://dlnext.acm.org/doi/10.1145/3641519.3657520) 5 | *** 6 | Deqi Li1, Shi-Sheng Huang1, Zhiyuan Lu1, Xinran Duan1, Hua Huang1✉ 7 | 8 | 1School of Artificial Intelligence, Beijing Normal University; Corresponding Author. 9 | *** 10 | 11 | ![block](asset/ST-4DGS.jpg) 12 | 13 | 14 | Our method guarantees the compactness of the 4D Gaussians so that they adhere to the surfaces of 15 | moving objects. It achieves high-fidelity dynamic rendering quality while maintaining real-time rendering efficiency. 16 | 17 | 18 | 21 | 22 | *** 23 | 24 | ## Environmental Setups 25 | Please follow the [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting) instructions to install the required packages, and install the remaining dependencies listed in ```requirements.txt```. 
26 | 27 | ```bash 28 | git clone https://github.com/wanglids/ST-4DGS 29 | cd ST-4DGS 30 | conda create -n ST4DGS python=3.9 31 | conda activate ST4DGS 32 | 33 | pip install -r requirements.txt 34 | pip install -e submodules/depth-diff-gaussian-rasterization 35 | pip install -e submodules/simple-knn 36 | ``` 37 | 38 | ## Data Preparation 39 | We evaluate the proposed ST-4DGS on three publicly available dynamic-scene datasets, namely [DyNeRF](https://github.com/facebookresearch/Neural_3D_Video), [ENeRF-Outdoor](https://github.com/zju3dv/ENeRF/blob/master/docs/enerf_outdoor.md), and [Dynamic Scene](https://gorokee.github.io/jsyoon/dynamic_synth/). Download the datasets from these links, extract the frames of each video (a minimal frame-extraction sketch is given at the end of this README), and then organize your dataset as follows. 40 | ``` 41 | |── data 42 | | |── DyNeRF 43 | | |── cook_spinach 44 | | |── cam00 45 | | |── ... 46 | | |── cam08 47 | | |── images 48 | | |── 0.png 49 | | |── 1.png 50 | | |── 2.png 51 | | |── ... 52 | | |── flow 53 | | |── 0.npy 54 | | |── 1.npy 55 | | |── 2.npy 56 | | |── ... 57 | 58 | | |── ... 59 | | |── cam19 60 | | |── colmap 61 | | |── input 62 | | |── cam00.png 63 | | |── ... 64 | | |── cam19.png 65 | | |── ... 66 | | |── ENeRF-Outdoor 67 | | |── ... 68 | | |── Dynamic Scene 69 | | |── ... 70 | ``` 71 | The ```colmap/input``` folder contains the images captured by all cameras at the same time instant. Calculate the camera parameters and initialize the Gaussians with [COLMAP](https://github.com/colmap/colmap) (execute ```python scripts/convert.py```). The optical flow is estimated with [RAFT](https://github.com/princeton-vl/RAFT). Place **scripts/getFlow.py** in the installation root directory of [RAFT](https://github.com/princeton-vl/RAFT) (such as ./submodels/RAFT) and then estimate the optical flow by running 72 | ``` 73 | cd $ ROOT_PATH/submodels/RAFT 74 | python getFlow.py --source_path rootpath/data/DyNeRF/cook_spinach --win_size timestep 75 | ``` 76 | 77 | 78 | ## Training 79 | For the cook_spinach dataset, run 80 | ``` 81 | cd $ ROOT_PATH/ 82 | python train.py --source_path rootpath/data/DyNeRF/cook_spinach --model_path output/test --configs arguments/DyNeRF.py 83 | #The results will be saved in rootpath/data/cook_spinach/output/test 84 | ``` 85 | 86 | 87 | ## Rendering 88 | You can download the [pre-trained data and models](https://drive.google.com/drive/folders/1VdLo514HKJdQPUb5vYIXPdVW1xsXqKzB?usp=drive_link) and place them in the **output/test** folder. Run the following script to render the images. 89 | ``` 90 | cd $ ROOT_PATH/ 91 | python render.py --source_path rootpath/data/DyNeRF/cook_spinach --model_path output/test --configs arguments/DyNeRF.py 92 | ``` 93 | 94 | --- 95 | 96 | ## Citation 97 | If you find this code useful for your research, please cite the following paper: 98 | ``` 99 | @inproceedings{Li2024ST, 100 | author = {Li, Deqi and Huang, Shi-Sheng and Lu, Zhiyuan and Duan, Xinran and Huang, Hua}, 101 | title = {ST-4DGS: Spatial-Temporally Consistent 4D Gaussian Splatting for Efficient Dynamic Scene Rendering}, 102 | publisher = {Association for Computing Machinery}, 103 | address = {New York, NY, USA}, 104 | booktitle = {ACM SIGGRAPH 2024 Conference Papers}, 105 | location = {Denver, CO, USA}, 106 | } 107 | ``` 108 | ## Acknowledgments 109 | Our training code is built upon [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [4DGS](https://github.com/hustvl/4DGaussians), and [D3DGS](https://dynamic3dgaussians.github.io/). We sincerely appreciate these excellent works. 
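
## Frame Extraction Sketch
The Data Preparation section above assumes that each multi-view video has already been split into per-frame images under ```camXX/images```. The snippet below is a minimal sketch of one way to do this with OpenCV; it is not part of the original codebase, and the ```--video_dir```/```--out_dir``` arguments, the ```camXX.mp4``` file naming, and the zero-based ```N.png``` frame names are assumptions made only to match the directory layout shown above.
```python
# Hypothetical helper (not shipped with ST-4DGS): split per-camera videos into
# <out_dir>/camXX/images/0.png, 1.png, ... as expected by the layout above.
import argparse
import os
import cv2


def extract_frames(video_path, image_dir, max_frames=None):
    """Decode one video and write its frames as sequentially numbered PNGs."""
    os.makedirs(image_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok or (max_frames is not None and idx >= max_frames):
            break
        cv2.imwrite(os.path.join(image_dir, f"{idx}.png"), frame)
        idx += 1
    cap.release()
    return idx


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_dir", required=True, help="folder holding cam00.mp4 ... cam19.mp4")
    parser.add_argument("--out_dir", required=True, help="scene folder, e.g. data/DyNeRF/cook_spinach")
    parser.add_argument("--max_frames", type=int, default=None, help="optional cap on frames per camera")
    args = parser.parse_args()

    for fname in sorted(os.listdir(args.video_dir)):
        if not fname.endswith(".mp4"):
            continue
        cam_name = os.path.splitext(fname)[0]  # e.g. "cam00"
        image_dir = os.path.join(args.out_dir, cam_name, "images")
        n = extract_frames(os.path.join(args.video_dir, fname), image_dir, args.max_frames)
        print(f"{cam_name}: wrote {n} frames to {image_dir}")
```
For example, saving the sketch as a hypothetical ```extract_frames.py``` and running ```python extract_frames.py --video_dir raw_videos/cook_spinach --out_dir data/DyNeRF/cook_spinach``` would populate ```cam00/images``` through ```cam19/images``` for the DyNeRF layout; the ```flow``` folders are then produced separately by **scripts/getFlow.py** as described in Data Preparation.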
110 | 111 | -------------------------------------------------------------------------------- /arguments/DyNeRF.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 | 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2,4,8], 9 | defor_depth = 2, 10 | net_width = 256, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.001, 14 | 15 | eval_index=0, 16 | is_short=False, 17 | ) 18 | OptimizationParams = dict( 19 | dataloader=True, 20 | iterations = 30_000, 21 | coarse_iterations = 3000, 22 | densify_until_iter = 15_000, 23 | opacity_reset_interval = 6000, 24 | 25 | opacity_threshold_coarse = 0.005, 26 | opacity_threshold_fine_init = 0.005, 27 | opacity_threshold_fine_after = 0.005, 28 | 29 | coarse_neighbors=200, 30 | fine_neighbors=200, 31 | coarse_std=1, 32 | fine_std=2, 33 | ) -------------------------------------------------------------------------------- /arguments/Dynamic.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 | 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2,4,8], 9 | defor_depth = 2, 10 | net_width = 256, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.001, 14 | eval_index=3, 15 | is_short=False, 16 | 17 | ) 18 | OptimizationParams = dict( 19 | dataloader=True, 20 | iterations = 30_000, 21 | coarse_iterations = 3000, 22 | densify_until_iter = 15_000, 23 | opacity_reset_interval = 6000, 24 | 25 | opacity_threshold_coarse = 0.005, 26 | opacity_threshold_fine_init = 0.005, 27 | opacity_threshold_fine_after = 0.005, 28 | 29 | coarse_neighbors=100, 30 | fine_neighbors=100, 31 | coarse_std=1.6, 32 | fine_std=3, 33 | ) -------------------------------------------------------------------------------- /arguments/Outdoor.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 | 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2,4,8], 9 | defor_depth = 2, 10 | net_width = 256, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.001, 14 | 15 | eval_index = 8, 16 | is_short = True, 17 | 18 | ) 19 | OptimizationParams = dict( 20 | dataloader=True, 21 | iterations = 30_000, 22 | coarse_iterations = 3000, 23 | densify_until_iter = 15_000, 24 | opacity_reset_interval = 6000, 25 | 26 | opacity_threshold_coarse = 0.005, 27 | opacity_threshold_fine_init = 0.005, 28 | opacity_threshold_fine_after = 0.005, 29 | 30 | coarse_neighbors = 200, 31 | fine_neighbors = 100, 32 | coarse_std = 1, 33 | fine_std = 2, 34 | ) -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | import sys 3 | import os 4 | 5 | class GroupParams: 6 | pass 7 | 8 | class ParamGroup: 9 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 10 | group = parser.add_argument_group(name) 11 | for key, value in vars(self).items(): 12 | shorthand = False 13 | if 
key.startswith("_"): 14 | shorthand = True 15 | key = key[1:] 16 | t = type(value) 17 | value = value if not fill_none else None 18 | if shorthand: 19 | if t == bool: 20 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 21 | else: 22 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 23 | else: 24 | if t == bool: 25 | group.add_argument("--" + key, default=value, action="store_true") 26 | else: 27 | group.add_argument("--" + key, default=value, type=t) 28 | 29 | def extract(self, args): 30 | group = GroupParams() 31 | for arg in vars(args).items(): 32 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 33 | setattr(group, arg[0], arg[1]) 34 | return group 35 | 36 | class ModelParams(ParamGroup): 37 | def __init__(self, parser, sentinel=False): 38 | self.sh_degree = 3 39 | self._source_path = "" 40 | self._model_path = "" 41 | self._images = "images" 42 | self._resolution = -1 43 | self._white_background = True 44 | self.data_device = "cuda" 45 | self.eval = True 46 | self.render_process=False 47 | self.is_short = False 48 | self.eval_index = 0 49 | super().__init__(parser, "Loading Parameters", sentinel) 50 | 51 | def extract(self, args): 52 | g = super().extract(args) 53 | g.source_path = os.path.abspath(g.source_path) 54 | return g 55 | 56 | class PipelineParams(ParamGroup): 57 | def __init__(self, parser): 58 | self.convert_SHs_python = False 59 | self.compute_cov3D_python = False 60 | self.debug = False 61 | super().__init__(parser, "Pipeline Parameters") 62 | class ModelHiddenParams(ParamGroup): 63 | def __init__(self, parser): 64 | self.net_width = 64 65 | self.timebase_pe = 4 66 | self.defor_depth = 1 67 | self.posebase_pe = 10 68 | self.scale_rotation_pe = 2 69 | self.opacity_pe = 2 70 | self.timenet_width = 64 71 | self.timenet_output = 32 72 | self.bounds = 1.6 73 | self.plane_tv_weight = 0.0001 74 | self.time_smoothness_weight = 0.01 75 | self.l1_time_planes = 0.0001 76 | self.kplanes_config = { 77 | 'grid_dimensions': 2, 78 | 'input_coordinate_dim': 4, 79 | 'output_coordinate_dim': 32, 80 | 'resolution': [64, 64, 64, 25] 81 | } 82 | self.multires = [1, 2, 4, 8] 83 | self.no_grid=False 84 | self.no_ds=False 85 | self.no_dr=False 86 | self.no_do=True 87 | 88 | 89 | super().__init__(parser, "ModelHiddenParams") 90 | 91 | class OptimizationParams(ParamGroup): 92 | def __init__(self, parser): 93 | self.dataloader=False 94 | self.iterations = 30_000 95 | self.coarse_iterations = 3000 96 | self.position_lr_init = 0.00016 97 | self.position_lr_final = 0.0000016 98 | self.position_lr_delay_mult = 0.01 99 | self.position_lr_max_steps = 20_000 100 | self.deformation_lr_init = 0.00016 101 | self.deformation_lr_final = 0.000016 102 | self.deformation_lr_delay_mult = 0.01 103 | self.grid_lr_init = 0.0016 104 | self.grid_lr_final = 0.00016 105 | 106 | self.feature_lr = 0.0025 107 | self.opacity_lr = 0.05 108 | self.scaling_lr = 0.005 109 | self.rotation_lr = 0.001 110 | self.percent_dense = 0.01 111 | self.lambda_dssim = 0.2 112 | self.lambda_lpips = 0 113 | 114 | self.lambda_ani = 0.2 115 | self.lambda_loc = 0.001 116 | self.lambda_tem = 0.01 117 | 118 | self.coarse_neighbors = 200 119 | self.fine_neighbors = 200 120 | self.coarse_std = 1 121 | self.fine_std = 2 122 | 123 | 124 | self.weight_constraint_init= 1 125 | self.weight_constraint_after = 0.2 126 | self.weight_decay_iteration = 5000 127 | self.opacity_reset_interval = 3000 128 | self.densification_interval = 100 129 | self.densify_from_iter = 500 130 | 
self.densify_until_iter = 15_000 131 | self.densify_grad_threshold_coarse = 0.0002 132 | self.densify_grad_threshold_fine_init = 0.0002 133 | self.densify_grad_threshold_after = 0.0002 134 | self.pruning_from_iter = 500 135 | self.pruning_interval = 100 136 | self.opacity_threshold_coarse = 0.005 137 | self.opacity_threshold_fine_init = 0.005 138 | self.opacity_threshold_fine_after = 0.005 139 | 140 | super().__init__(parser, "Optimization Parameters") 141 | 142 | def get_combined_args(parser : ArgumentParser): 143 | cmdlne_string = sys.argv[1:] 144 | cfgfile_string = "Namespace()" 145 | args_cmdline = parser.parse_args(cmdlne_string) 146 | 147 | try: 148 | cfgfilepath = os.path.join(args_cmdline.source_path,args_cmdline.model_path, "cfg_args") 149 | print("Looking for config file in", cfgfilepath) 150 | with open(cfgfilepath) as cfg_file: 151 | print("Config file found: {}".format(cfgfilepath)) 152 | cfgfile_string = cfg_file.read() 153 | except TypeError: 154 | print("Config file not found at") 155 | pass 156 | args_cfgfile = eval(cfgfile_string) 157 | 158 | merged_dict = vars(args_cfgfile).copy() 159 | for k,v in vars(args_cmdline).items(): 160 | if v != None: 161 | merged_dict[k] = v 162 | return Namespace(**merged_dict) 163 | -------------------------------------------------------------------------------- /asset/Ballon.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglids/ST-4DGS/bf0dbb13e76bf41b2c2a4ca64063e5d346db7c74/asset/Ballon.mp4 -------------------------------------------------------------------------------- /asset/ST-4DGS.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglids/ST-4DGS/bf0dbb13e76bf41b2c2a4ca64063e5d346db7c74/asset/ST-4DGS.jpg -------------------------------------------------------------------------------- /asset/cut_roasted_beef.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglids/ST-4DGS/bf0dbb13e76bf41b2c2a4ca64063e5d346db7c74/asset/cut_roasted_beef.mp4 -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 4 | from scene.gaussian_model import GaussianModel 5 | from utils.sh_utils import eval_sh 6 | from scene.getData import get_model_data 7 | 8 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, stage="fine"): 9 | """ 10 | Render the scene. 11 | 12 | Background tensor (bg_color) must be on GPU! 13 | """ 14 | 15 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 16 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 17 | try: 18 | screenspace_points.retain_grad() 19 | except: 20 | pass 21 | 22 | # Set up rasterization configuration 23 | 24 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 25 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 26 | 27 | raster_settings = GaussianRasterizationSettings( 28 | image_height=int(viewpoint_camera.image_height), 29 | image_width=int(viewpoint_camera.image_width), 30 | tanfovx=tanfovx, 31 | tanfovy=tanfovy, 32 | bg=bg_color, 33 | scale_modifier=scaling_modifier, 34 | viewmatrix=viewpoint_camera.world_view_transform.cuda(), 35 | projmatrix=viewpoint_camera.full_proj_transform.cuda(), 36 | sh_degree=pc.active_sh_degree, 37 | campos=viewpoint_camera.camera_center.cuda(), 38 | prefiltered=False, 39 | debug=pipe.debug 40 | ) 41 | 42 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 43 | time = torch.tensor(viewpoint_camera.time).to(pc.get_xyz.device).repeat(pc.get_xyz.shape[0], 1) 44 | means3D_final, scales_final, rotations_final, opacity = get_model_data(pc, time, stage) 45 | shs = None 46 | colors_precomp = None 47 | if override_color is None: 48 | if pipe.convert_SHs_python: 49 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 50 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.cuda().repeat(pc.get_features.shape[0], 1)) 51 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 52 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 53 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 54 | else: 55 | shs = pc.get_features 56 | else: 57 | colors_precomp = override_color 58 | 59 | cov3D_precomp = None 60 | means2D = screenspace_points 61 | 62 | 63 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 64 | rendered_image, radii, depth = rasterizer( 65 | means3D = means3D_final, 66 | means2D = means2D, 67 | shs = shs, 68 | colors_precomp = colors_precomp, 69 | opacities = opacity, 70 | scales = scales_final, 71 | rotations = rotations_final, 72 | cov3D_precomp = cov3D_precomp) 73 | 74 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 75 | # They will be excluded from value updates used in the splitting criteria. 
76 | return {"render": rendered_image, 77 | "viewspace_points": screenspace_points, 78 | "visibility_filter" : radii > 0, 79 | "radii": radii, 80 | "depth":depth} 81 | 82 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import traceback 3 | import socket 4 | import json 5 | from scene.cameras import MiniCam 6 | 7 | host = "127.0.0.1" 8 | port = 6009 9 | 10 | conn = None 11 | addr = None 12 | 13 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | 15 | def init(wish_host, wish_port): 16 | global host, port, listener 17 | host = wish_host 18 | port = wish_port 19 | listener.bind((host, port)) 20 | listener.listen() 21 | listener.settimeout(0) 22 | 23 | def try_connect(): 24 | global conn, addr, listener 25 | try: 26 | conn, addr = listener.accept() 27 | print(f"\nConnected by {addr}") 28 | conn.settimeout(None) 29 | except Exception as inst: 30 | pass 31 | 32 | def read(): 33 | global conn 34 | messageLength = conn.recv(4) 35 | messageLength = int.from_bytes(messageLength, 'little') 36 | message = conn.recv(messageLength) 37 | return json.loads(message.decode("utf-8")) 38 | 39 | def send(message_bytes, verify): 40 | global conn 41 | if message_bytes != None: 42 | conn.sendall(message_bytes) 43 | conn.sendall(len(verify).to_bytes(4, 'little')) 44 | conn.sendall(bytes(verify, 'ascii')) 45 | 46 | def receive(): 47 | message = read() 48 | 49 | width = message["resolution_x"] 50 | height = message["resolution_y"] 51 | 52 | if width != 0 and height != 0: 53 | try: 54 | do_training = bool(message["train"]) 55 | fovy = message["fov_y"] 56 | fovx = message["fov_x"] 57 | znear = message["z_near"] 58 | zfar = message["z_far"] 59 | do_shs_python = bool(message["shs_python"]) 60 | do_rot_scale_python = bool(message["rot_scale_python"]) 61 | keep_alive = bool(message["keep_alive"]) 62 | scaling_modifier = message["scaling_modifier"] 63 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 64 | world_view_transform[:,1] = -world_view_transform[:,1] 65 | world_view_transform[:,2] = -world_view_transform[:,2] 66 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 67 | full_proj_transform[:,1] = -full_proj_transform[:,1] 68 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 69 | except Exception as e: 70 | print("") 71 | traceback.print_exc() 72 | raise e 73 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 74 | else: 75 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | 71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 72 | ssims.append(ssim(renders[idx], gts[idx])) 73 | psnrs.append(psnr(renders[idx], gts[idx])) 74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 75 | 76 | print("Scene: ", scene_dir, "SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 77 | print("Scene: ", scene_dir, "PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 78 | print("Scene: ", scene_dir, "LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 79 | print("") 80 | 81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 82 | "PSNR": torch.tensor(psnrs).mean().item(), 83 | "LPIPS": torch.tensor(lpipss).mean().item()}) 84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 87 | 88 | with open(scene_dir + "/results.json", 'w') as fp: 89 | json.dump(full_dict[scene_dir], fp, indent=True) 90 | with open(scene_dir + "/per_view.json", 'w') as fp: 91 | json.dump(per_view_dict[scene_dir], fp, indent=True) 92 | except: 93 | print("Unable to compute metrics for model", scene_dir) 94 | 95 | if __name__ == "__main__": 96 | device = torch.device("cuda:0") 97 | torch.cuda.set_device(device) 98 | 99 | # Set up command line argument parser 100 | parser = ArgumentParser(description="Training script parameters") 101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 102 | args = parser.parse_args() 103 | 
evaluate(args.model_paths) 104 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 3 | import imageio 4 | import numpy as np 5 | import torch 6 | from scene import Scene 7 | # import os 8 | import cv2 9 | from tqdm import tqdm 10 | from os import makedirs 11 | from gaussian_renderer import render 12 | import torchvision 13 | from utils.general_utils import safe_state 14 | from argparse import ArgumentParser 15 | from arguments import ModelParams, PipelineParams, get_combined_args, ModelHiddenParams 16 | from gaussian_renderer import GaussianModel 17 | from time import time 18 | import lpips 19 | to8b = lambda x : (255*np.clip(x.cpu().numpy(),0,1)).astype(np.uint8) 20 | 21 | 22 | from utils.loss_utils import ssim 23 | from utils.image_utils import psnr 24 | 25 | device = torch.device("cuda:0") 26 | 27 | 28 | def render_set(source_path,model_path, name, iteration, views, gaussians, pipeline, background): 29 | render_path = os.path.join(source_path,model_path, name, "ours_{}".format(iteration), "renders") 30 | gts_path = os.path.join(source_path,model_path, name, "ours_{}".format(iteration), "gt") 31 | 32 | makedirs(render_path, exist_ok=True) 33 | makedirs(gts_path, exist_ok=True) 34 | render_images = [] 35 | gt_list = [] 36 | render_list = [] 37 | PSNR = 0 38 | SSIM = 0 39 | LPIPS = 0 40 | lpips_vgg = lpips.LPIPS(net="vgg").cuda() 41 | 42 | time_all = 0 43 | view_all = [] 44 | try: 45 | frame_index = np.load(source_path + "/index.npy") 46 | except: 47 | if frame_index.shape[1] < 2: 48 | print("frame_index error!") 49 | exit() 50 | 51 | for condi in range(frame_index.shape[0]): 52 | view_all.append[views[condi]] 53 | 54 | idx = 0 55 | for view in views: 56 | time_1 = time() 57 | rendering = render(view, gaussians, pipeline, background)["render"] 58 | time_2 = time() 59 | time_all +=(time_2-time_1) 60 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 61 | if name !="video": 62 | SSIM += ssim(rendering.unsqueeze(0), view.original_image.unsqueeze(0)) 63 | PSNR += psnr(rendering, view.original_image).mean().double() 64 | LPIPS += lpips_vgg(rendering.unsqueeze(0).to(device=device), view.original_image.unsqueeze(0).to(device=device)) 65 | idx += 1 66 | 67 | print("FPS:",(len(views)-1)/time_all) 68 | print("Rendering Speed:",time_all/(len(views)-1)) 69 | count = 0 70 | 71 | print("writing training images.") 72 | if name !="video": 73 | SSIM = SSIM / len(views) 74 | PSNR = PSNR / len(views) 75 | LPIPS = LPIPS / len(views) 76 | output_text = f"{name} SIMM: {SSIM}, PSNR: {PSNR}, LPIPS: {LPIPS}, FPS:{(len(views)-1)/time_all}, Speed: {time_all/(len(views)-1)} " 77 | print(output_text) 78 | text_path = os.path.dirname(render_path) + f"/{name}_SIMM_PSNR.txt" 79 | 80 | with open(text_path, "w") as f: 81 | f.write(output_text) 82 | 83 | if len(gt_list) != 0: 84 | for image in tqdm(gt_list): 85 | torchvision.utils.save_image(image, os.path.join(gts_path, '{0:05d}'.format(count) + ".png")) 86 | count+=1 87 | count = 0 88 | print("writing rendering images.") 89 | if len(render_list) != 0: 90 | for image in tqdm(render_list): 91 | torchvision.utils.save_image(image, os.path.join(render_path, '{0:05d}'.format(count) + ".png")) 92 | count +=1 93 | if len(render_images)!=0: 94 | imageio.mimwrite(os.path.join(source_path,model_path, name, "ours_{}".format(iteration), 
'video_rgb.mp4'), render_images, fps=30, quality=8) 95 | def render_sets(dataset : ModelParams, hyperparam, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, skip_video: bool): 96 | with torch.no_grad(): 97 | 98 | gaussians = GaussianModel(dataset.sh_degree, hyperparam) 99 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 100 | 101 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 102 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 103 | 104 | 105 | # if not skip_train: 106 | # render_set(dataset.source_path,dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 107 | if not skip_test: 108 | render_set(dataset.source_path,dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 109 | # if not skip_video: 110 | # render_set(dataset.source_path,dataset.model_path,"video",scene.loaded_iter,scene.getVideoCameras(),gaussians,pipeline,background) 111 | if __name__ == "__main__": 112 | # Set up command line argument parser 113 | parser = ArgumentParser(description="Testing script parameters") 114 | model = ModelParams(parser, sentinel=True) 115 | pipeline = PipelineParams(parser) 116 | hyperparam = ModelHiddenParams(parser) 117 | parser.add_argument("--iteration", default=-1, type=int) 118 | parser.add_argument("--skip_train", action="store_true") 119 | parser.add_argument("--skip_test", action="store_true") 120 | parser.add_argument("--quiet", action="store_true") 121 | parser.add_argument("--skip_video", action="store_true") 122 | parser.add_argument("--configs", type=str) 123 | args = get_combined_args(parser) 124 | print("Rendering " , args.model_path) 125 | if args.configs: 126 | import mmengine 127 | from utils.params_utils import merge_hparams 128 | config = mmengine.Config.fromfile(args.configs) 129 | args = merge_hparams(args, config) 130 | # Initialize system state (RNG) 131 | safe_state(args.quiet) 132 | 133 | render_sets(model.extract(args), hyperparam.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.skip_video) -------------------------------------------------------------------------------- /render.sh: -------------------------------------------------------------------------------- 1 | python render.py --source_path /lideqi/gaussian/data/cut_roasted_beef_temp --model_path output/Ours_code_test/ --configs arguments/dynerf/default.py 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.33.1 2 | matplotlib==3.4.3 3 | mmengine==0.9.1 4 | numpy==1.21.2 5 | open3d==0.17.0 6 | opencv_contrib_python==4.5.1.48 7 | opencv_python_headless==4.5.1.48 8 | Pillow==9.5.0 9 | Pillow==10.3.0 10 | plyfile==1.0.3 11 | scipy==1.6.3 12 | setuptools==58.2.0 13 | torch==1.10.0a0+0aef44c 14 | torchvision==0.11.0a0 15 | tqdm==4.62.3 16 | -------------------------------------------------------------------------------- /scene/KDTree.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | 5 | class Node(): 6 | def __init__(self, pt, leftBranch, rightBranch, dimension): 7 | self.pt = pt 8 | self.leftBranch = leftBranch 9 | self.rightBranch = rightBranch 10 | self.dimension = dimension 11 | 12 | 13 | class KDTree(): 14 | def __init__(self, data): 15 | 
self.nearestPt = None 16 | self.nearestDis = math.inf 17 | 18 | def createKDTree(self, currPts, dimension): 19 | if (len(currPts) == 0): 20 | return None 21 | mid = self.calMedium(currPts) 22 | sortedData = sorted(currPts, key=lambda x: x[dimension]) 23 | leftBranch = self.createKDTree(sortedData[:mid], self.calDimension(dimension)) 24 | rightBranch = self.createKDTree(sortedData[mid + 1:], self.calDimension(dimension)) 25 | return Node(sortedData[mid], leftBranch, rightBranch, dimension) 26 | 27 | def calMedium(self, currPts): 28 | return len(currPts) // 2 29 | 30 | def calDimension(self, dimension): 31 | return (dimension + 1) % 2 32 | 33 | def calDistance(self, p0, p1): 34 | return math.sqrt((p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2) 35 | 36 | def getNearestPt(self, root, targetPt): 37 | self.search(root, targetPt) 38 | near_pt = self.nearestPt 39 | near_dis = self.nearestDis 40 | 41 | self.nearestPt = None 42 | self.nearestDis = math.inf 43 | 44 | return near_pt,near_dis 45 | 46 | 47 | def search(self, node, targetPt): 48 | if node == None: 49 | return 50 | dist = node.pt[node.dimension] - targetPt[node.dimension] 51 | if (dist > 0): 52 | self.search(node.leftBranch, targetPt) 53 | else: 54 | self.search(node.rightBranch, targetPt) 55 | tempDis = self.calDistance(node.pt, targetPt) 56 | if (tempDis < self.nearestDis): 57 | self.nearestDis = tempDis 58 | self.nearestPt = node.pt 59 | 60 | if (self.nearestDis > abs(dist)): 61 | if (dist > 0): 62 | self.search(node.rightBranch, targetPt) 63 | else: 64 | self.search(node.leftBranch, targetPt) -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.system_utils import searchForMaxIteration 3 | from scene.dataset_readers import sceneLoadTypeCallbacks 4 | from scene.gaussian_model import GaussianModel 5 | from scene.dataset import FourDGSdataset 6 | from arguments import ModelParams 7 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 8 | 9 | class Scene: 10 | 11 | gaussians : GaussianModel 12 | 13 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], load_coarse=False): 14 | """b 15 | :param path: Path to colmap scene main folder. 
16 | """ 17 | self.model_path = args.model_path 18 | self.source_path = args.source_path 19 | self.loaded_iter = None 20 | self.gaussians = gaussians 21 | 22 | if load_iteration: 23 | if load_iteration == -1: 24 | self.loaded_iter = searchForMaxIteration(os.path.join(self.source_path,self.model_path, "point_cloud")) 25 | else: 26 | self.loaded_iter = load_iteration 27 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 28 | 29 | self.train_cameras = {} 30 | self.test_cameras = {} 31 | self.video_cameras = {} 32 | scene_info = sceneLoadTypeCallbacks["dynerf"](self.source_path, args) 33 | 34 | self.maxtime = scene_info.maxtime 35 | 36 | self.cameras_extent = scene_info.nerf_normalization["radius"] 37 | 38 | 39 | print("Loading Training Cameras") 40 | self.train_camera = FourDGSdataset(scene_info.train_cameras, args) 41 | print("Loading Test Cameras") 42 | self.test_camera = FourDGSdataset(scene_info.test_cameras, args) 43 | print("Loading Video Cameras") 44 | 45 | self.video_camera = cameraList_from_camInfos(scene_info.video_cameras,-1,args) 46 | 47 | xyz_max = scene_info.point_cloud.points.max(axis=0) 48 | xyz_min = scene_info.point_cloud.points.min(axis=0) 49 | self.gaussians._deformation.deformation_net.grid.set_aabb(xyz_max,xyz_min) 50 | if self.loaded_iter: 51 | self.gaussians.load_ply(os.path.join(self.source_path,self.model_path, 52 | "point_cloud", 53 | "iteration_" + str(self.loaded_iter), 54 | "point_cloud.ply")) 55 | self.gaussians.load_model(os.path.join(self.source_path,self.model_path, 56 | "point_cloud", 57 | "iteration_" + str(self.loaded_iter), 58 | )) 59 | else: 60 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime) 61 | 62 | def save(self, iteration, stage): 63 | if stage == "coarse": 64 | point_cloud_path = os.path.join(self.source_path,self.model_path, "point_cloud/coarse_iteration_{}".format(iteration)) 65 | 66 | else: 67 | point_cloud_path = os.path.join(self.source_path,self.model_path, "point_cloud/iteration_{}".format(iteration)) 68 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 69 | self.gaussians.save_deformation(point_cloud_path) 70 | def getTrainCameras(self, scale=1.0): 71 | return self.train_camera 72 | 73 | def getTestCameras(self, scale=1.0): 74 | return self.test_camera 75 | def getVideoCameras(self, scale=1.0): 76 | return self.video_camera -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 5 | 6 | class Camera(nn.Module): 7 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image,flow,focal, gt_alpha_mask, 8 | image_name, uid, 9 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", time = 0,per_time = 0 10 | ): 11 | super(Camera, self).__init__() 12 | 13 | self.uid = uid 14 | self.colmap_id = colmap_id 15 | self.R = R 16 | self.T = T 17 | self.FoVx = FoVx 18 | self.FoVy = FoVy 19 | self.image_name = image_name 20 | self.time = time 21 | self.per_time = per_time 22 | self.focal = focal 23 | try: 24 | self.data_device = torch.device(data_device) 25 | except Exception as e: 26 | print(e) 27 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 28 | self.data_device = torch.device("cuda") 29 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 
30 | self.flow = flow 31 | 32 | self.image_width = self.original_image.shape[2] 33 | self.image_height = self.original_image.shape[1] 34 | 35 | 36 | self.zfar = 100.0 37 | self.znear = 0.01 38 | 39 | self.trans = trans 40 | self.scale = scale 41 | 42 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1) 43 | 44 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1) 45 | 46 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 47 | self.camera_center = self.world_view_transform.inverse()[3, :3] 48 | 49 | class MiniCam: 50 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform, time): 51 | self.image_width = width 52 | self.image_height = height 53 | self.FoVy = fovy 54 | self.FoVx = fovx 55 | self.znear = znear 56 | self.zfar = zfar 57 | self.world_view_transform = world_view_transform 58 | self.full_proj_transform = full_proj_transform 59 | view_inv = torch.inverse(self.world_view_transform) 60 | self.camera_center = view_inv[3][:3] 61 | self.time = time 62 | 63 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import struct 4 | import os 5 | 6 | 7 | CameraModel = collections.namedtuple( 8 | "CameraModel", ["model_id", "model_name", "num_params"]) 9 | Camera = collections.namedtuple( 10 | "Camera", ["id", "model", "width", "height", "params"]) 11 | BaseImage = collections.namedtuple( 12 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 13 | Point3D = collections.namedtuple( 14 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 15 | CAMERA_MODELS = { 16 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 17 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 18 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 19 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 20 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 21 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 22 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 23 | CameraModel(model_id=7, model_name="FOV", num_params=5), 24 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 25 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 26 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 27 | } 28 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 29 | for camera_model in CAMERA_MODELS]) 30 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 31 | for camera_model in CAMERA_MODELS]) 32 | 33 | 34 | def qvec2rotmat(qvec): 35 | return np.array([ 36 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 37 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 38 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 39 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 40 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 41 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 42 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 43 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 44 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 45 | 46 | def rotmat2qvec(R): 47 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 48 | K = 
np.array([ 49 | [Rxx - Ryy - Rzz, 0, 0, 0], 50 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 51 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 52 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 53 | eigvals, eigvecs = np.linalg.eigh(K) 54 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 55 | if qvec[0] < 0: 56 | qvec *= -1 57 | return qvec 58 | 59 | class Image(BaseImage): 60 | def qvec2rotmat(self): 61 | return qvec2rotmat(self.qvec) 62 | 63 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 64 | """Read and unpack the next bytes from a binary file. 65 | :param fid: 66 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 67 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 68 | :param endian_character: Any of {@, =, <, >, !} 69 | :return: Tuple of read and unpacked values. 70 | """ 71 | data = fid.read(num_bytes) 72 | return struct.unpack(endian_character + format_char_sequence, data) 73 | 74 | def read_points3D_text(path): 75 | """ 76 | see: src/base/reconstruction.cc 77 | void Reconstruction::ReadPoints3DText(const std::string& path) 78 | void Reconstruction::WritePoints3DText(const std::string& path) 79 | """ 80 | xyzs = None 81 | rgbs = None 82 | errors = None 83 | with open(path, "r") as fid: 84 | while True: 85 | line = fid.readline() 86 | if not line: 87 | break 88 | line = line.strip() 89 | if len(line) > 0 and line[0] != "#": 90 | elems = line.split() 91 | xyz = np.array(tuple(map(float, elems[1:4]))) 92 | rgb = np.array(tuple(map(int, elems[4:7]))) 93 | error = np.array(float(elems[7])) 94 | if xyzs is None: 95 | xyzs = xyz[None, ...] 96 | rgbs = rgb[None, ...] 97 | errors = error[None, ...] 98 | else: 99 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 100 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 101 | errors = np.append(errors, error[None, ...], axis=0) 102 | return xyzs, rgbs, errors 103 | 104 | def read_points3D_binary(path_to_model_file): 105 | """ 106 | see: src/base/reconstruction.cc 107 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 108 | void Reconstruction::WritePoints3DBinary(const std::string& path) 109 | """ 110 | 111 | 112 | with open(path_to_model_file, "rb") as fid: 113 | num_points = read_next_bytes(fid, 8, "Q")[0] 114 | 115 | xyzs = np.empty((num_points, 3)) 116 | rgbs = np.empty((num_points, 3)) 117 | errors = np.empty((num_points, 1)) 118 | 119 | for p_id in range(num_points): 120 | binary_point_line_properties = read_next_bytes( 121 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 122 | xyz = np.array(binary_point_line_properties[1:4]) 123 | rgb = np.array(binary_point_line_properties[4:7]) 124 | error = np.array(binary_point_line_properties[7]) 125 | track_length = read_next_bytes( 126 | fid, num_bytes=8, format_char_sequence="Q")[0] 127 | track_elems = read_next_bytes( 128 | fid, num_bytes=8*track_length, 129 | format_char_sequence="ii"*track_length) 130 | xyzs[p_id] = xyz 131 | rgbs[p_id] = rgb 132 | errors[p_id] = error 133 | return xyzs, rgbs, errors 134 | 135 | 136 | 137 | def read_intrinsics_text(path): 138 | """ 139 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 140 | """ 141 | cameras = {} 142 | with open(path, "r") as fid: 143 | while True: 144 | line = fid.readline() 145 | if not line: 146 | break 147 | line = line.strip() 148 | if len(line) > 0 and line[0] != "#": 149 | elems = line.split() 150 | camera_id = int(elems[0]) 151 | model = elems[1] 152 | assert model == "PINHOLE", "While 
the loader support other types, the rest of the code assumes PINHOLE" 153 | width = int(elems[2]) 154 | height = int(elems[3]) 155 | params = np.array(tuple(map(float, elems[4:]))) 156 | cameras[camera_id] = Camera(id=camera_id, model=model, 157 | width=width, height=height, 158 | params=params) 159 | return cameras 160 | 161 | def read_extrinsics_binary(path_to_model_file): 162 | """ 163 | see: src/base/reconstruction.cc 164 | void Reconstruction::ReadImagesBinary(const std::string& path) 165 | void Reconstruction::WriteImagesBinary(const std::string& path) 166 | """ 167 | images = {} 168 | with open(path_to_model_file, "rb") as fid: 169 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 170 | for _ in range(num_reg_images): 171 | binary_image_properties = read_next_bytes( 172 | fid, num_bytes=64, format_char_sequence="idddddddi") 173 | image_id = binary_image_properties[0] 174 | qvec = np.array(binary_image_properties[1:5]) 175 | tvec = np.array(binary_image_properties[5:8]) 176 | camera_id = binary_image_properties[8] 177 | image_name = "" 178 | current_char = read_next_bytes(fid, 1, "c")[0] 179 | while current_char != b"\x00": # look for the ASCII 0 entry 180 | image_name += current_char.decode("utf-8") 181 | current_char = read_next_bytes(fid, 1, "c")[0] 182 | num_points2D = read_next_bytes(fid, num_bytes=8, 183 | format_char_sequence="Q")[0] 184 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 185 | format_char_sequence="ddq"*num_points2D) 186 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 187 | tuple(map(float, x_y_id_s[1::3]))]) 188 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 189 | images[image_id] = Image( 190 | id=image_id, qvec=qvec, tvec=tvec, 191 | camera_id=camera_id, name=image_name, 192 | xys=xys, point3D_ids=point3D_ids) 193 | return images 194 | 195 | 196 | def read_pt3d_binary(path_to_model_file): 197 | """ 198 | see: src/base/reconstruction.cc 199 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 200 | void Reconstruction::WritePoints3DBinary(const std::string& path) 201 | """ 202 | points3D = {} 203 | with open(path_to_model_file, "rb") as fid: 204 | num_points = read_next_bytes(fid, 8, "Q")[0] 205 | for point_line_index in range(num_points): 206 | binary_point_line_properties = read_next_bytes( 207 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 208 | point3D_id = binary_point_line_properties[0] 209 | xyz = np.array(binary_point_line_properties[1:4]) 210 | rgb = np.array(binary_point_line_properties[4:7]) 211 | error = np.array(binary_point_line_properties[7]) 212 | track_length = read_next_bytes( 213 | fid, num_bytes=8, format_char_sequence="Q")[0] 214 | track_elems = read_next_bytes( 215 | fid, num_bytes=8*track_length, 216 | format_char_sequence="ii"*track_length) 217 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 218 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 219 | points3D[point3D_id] = Point3D( 220 | id=point3D_id, xyz=xyz, rgb=rgb, 221 | error=error, image_ids=image_ids, 222 | point2D_idxs=point2D_idxs) 223 | return points3D 224 | 225 | 226 | def read_intrinsics_binary(path_to_model_file): 227 | """ 228 | see: src/base/reconstruction.cc 229 | void Reconstruction::WriteCamerasBinary(const std::string& path) 230 | void Reconstruction::ReadCamerasBinary(const std::string& path) 231 | """ 232 | cameras = {} 233 | with open(path_to_model_file, "rb") as fid: 234 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 235 | for _ in range(num_cameras): 236 | camera_properties = 
read_next_bytes( 237 | fid, num_bytes=24, format_char_sequence="iiQQ") 238 | camera_id = camera_properties[0] 239 | model_id = camera_properties[1] 240 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 241 | width = camera_properties[2] 242 | height = camera_properties[3] 243 | num_params = CAMERA_MODEL_IDS[model_id].num_params 244 | params = read_next_bytes(fid, num_bytes=8*num_params, 245 | format_char_sequence="d"*num_params) 246 | cameras[camera_id] = Camera(id=camera_id, 247 | model=model_name, 248 | width=width, 249 | height=height, 250 | params=np.array(params)) 251 | assert len(cameras) == num_cameras 252 | return cameras 253 | 254 | 255 | def read_extrinsics_text(path): 256 | """ 257 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 258 | """ 259 | images = {} 260 | with open(path, "r") as fid: 261 | while True: 262 | line = fid.readline() 263 | if not line: 264 | break 265 | line = line.strip() 266 | if len(line) > 0 and line[0] != "#": 267 | elems = line.split() 268 | image_id = int(elems[0]) 269 | qvec = np.array(tuple(map(float, elems[1:5]))) 270 | tvec = np.array(tuple(map(float, elems[5:8]))) 271 | camera_id = int(elems[8]) 272 | image_name = elems[9] 273 | elems = fid.readline().split() 274 | xys = np.column_stack([tuple(map(float, elems[0::3])), 275 | tuple(map(float, elems[1::3]))]) 276 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 277 | images[image_id] = Image( 278 | id=image_id, qvec=qvec, tvec=tvec, 279 | camera_id=camera_id, name=image_name, 280 | xys=xys, point3D_ids=point3D_ids) 281 | return images 282 | 283 | 284 | def read_colmap_bin_array(path): 285 | """ 286 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 287 | 288 | :param path: path to the colmap binary file. 
289 | :return: nd array with the floating point values in the value 290 | """ 291 | with open(path, "rb") as fid: 292 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 293 | usecols=(0, 1, 2), dtype=int) 294 | fid.seek(0) 295 | num_delimiter = 0 296 | byte = fid.read(1) 297 | while True: 298 | if byte == b"&": 299 | num_delimiter += 1 300 | if num_delimiter >= 3: 301 | break 302 | byte = fid.read(1) 303 | array = np.fromfile(fid, np.float32) 304 | array = array.reshape((width, height, channels), order="F") 305 | return np.transpose(array, (1, 0, 2)).squeeze() 306 | 307 | 308 | def load_colmap_data(realdir): 309 | camerasfile = os.path.join(realdir, 'cameras.bin') 310 | camdata = read_intrinsics_binary(camerasfile) 311 | 312 | # cam = camdata[camdata.keys()[0]] 313 | list_of_keys = list(camdata.keys()) 314 | cam = camdata[list_of_keys[0]] 315 | print('Cameras', len(cam)) 316 | 317 | h, w, f = cam.height, cam.width, cam.params[0] 318 | # w, h, f = factor * w, factor * h, factor * f 319 | hwf = np.array([h, w, f]).reshape([3, 1]) 320 | 321 | imagesfile = os.path.join(realdir, 'images.bin') 322 | imdata = read_extrinsics_binary(imagesfile) 323 | 324 | w2c_mats = [] 325 | bottom = np.array([0, 0, 0, 1.]).reshape([1, 4]) 326 | 327 | names = [imdata[k].name for k in imdata] 328 | print('Images #', len(names)) 329 | perm = np.argsort(names) 330 | for k in imdata: 331 | im = imdata[k] 332 | R = im.qvec2rotmat() 333 | t = im.tvec.reshape([3, 1]) 334 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 335 | w2c_mats.append(m) 336 | 337 | w2c_mats = np.stack(w2c_mats, 0) 338 | c2w_mats = np.linalg.inv(w2c_mats) 339 | 340 | poses = c2w_mats[:, :3, :4].transpose([1, 2, 0]) 341 | poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis], [1, 1, poses.shape[-1]])], 1) 342 | 343 | points3dfile = os.path.join(realdir, 'points3D.bin') 344 | pts3d = read_pt3d_binary(points3dfile) 345 | 346 | # must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t] 347 | poses = np.concatenate([poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :], poses[:, 3:4, :], poses[:, 4:5, :]], 348 | 1) 349 | 350 | return poses, pts3d, perm -------------------------------------------------------------------------------- /scene/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from scene.cameras import Camera 3 | import torch 4 | from utils.graphics_utils import focal2fov 5 | class FourDGSdataset(Dataset): 6 | def __init__( 7 | self, 8 | dataset, 9 | args 10 | ): 11 | self.dataset = dataset 12 | self.args = args 13 | def __getitem__(self, index): 14 | 15 | 16 | image, flow, w2c, focal, time, per_time = self.dataset[index] 17 | R,T = w2c 18 | 19 | width = image.shape[2] 20 | height = image.shape[1] 21 | 22 | focal_length_x = focal[0] 23 | focal_length_y = focal[1] 24 | 25 | FovY = focal2fov(focal_length_y, height) 26 | FovX = focal2fov(focal_length_x, width) 27 | 28 | return Camera(colmap_id=index,R=R,T=T,FoVx=FovX,FoVy=FovY,image=image,flow=flow,focal = focal,gt_alpha_mask=None, 29 | image_name=f"{index}",uid=index,data_device=torch.device("cuda"),time=time,per_time = per_time) 30 | def __len__(self): 31 | 32 | return len(self.dataset) 33 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import NamedTuple 3 | from scene.colmap_loader import 
read_points3D_binary 4 | from scene.external import storePly 5 | import torchvision.transforms as transforms 6 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 7 | import numpy as np 8 | from plyfile import PlyData, PlyElement 9 | from scene.gaussian_model import BasicPointCloud 10 | from tqdm import tqdm 11 | 12 | 13 | class CameraInfo(NamedTuple): 14 | uid: int 15 | R: np.array 16 | T: np.array 17 | FovY: np.array 18 | FovX: np.array 19 | image: np.array 20 | flow: np.array 21 | focal: np.array 22 | image_path: str 23 | image_name: str 24 | width: int 25 | height: int 26 | time : float 27 | per_time : float 28 | 29 | class SceneInfo(NamedTuple): 30 | point_cloud: BasicPointCloud 31 | train_cameras: list 32 | test_cameras: list 33 | video_cameras: list 34 | nerf_normalization: dict 35 | ply_path: str 36 | maxtime: int 37 | 38 | def getNerfppNorm(cam_info): 39 | def get_center_and_diag(cam_centers): 40 | cam_centers = np.hstack(cam_centers) 41 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 42 | center = avg_cam_center 43 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 44 | diagonal = np.max(dist) 45 | return center.flatten(), diagonal 46 | 47 | cam_centers = [] 48 | 49 | for cam in cam_info: 50 | W2C = getWorld2View2(cam.R, cam.T) 51 | C2W = np.linalg.inv(W2C) 52 | cam_centers.append(C2W[:3, 3:4]) 53 | 54 | center, diagonal = get_center_and_diag(cam_centers) 55 | radius = diagonal * 1.1 56 | 57 | translate = -center 58 | 59 | return {"translate": translate, "radius": radius} 60 | 61 | 62 | def fetchPly(path): 63 | plydata = PlyData.read(path) 64 | vertices = plydata['vertex'] 65 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 66 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 67 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 68 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 69 | 70 | 71 | def format_infos(dataset,split): 72 | # loading 73 | cameras = [] 74 | image = dataset[0][0] 75 | flow = dataset[0][1] 76 | focal = dataset[0][2] 77 | width = image.shape[2] 78 | height = image.shape[1] 79 | if split == "train": 80 | for idx in tqdm(range(len(dataset))): 81 | image_path = None 82 | image_name = f"{idx}" 83 | time = dataset.image_times[idx] 84 | per_time = dataset.per_image_times[idx] 85 | focal = dataset.focal[idx] 86 | R, T = dataset.load_pose(idx) 87 | 88 | FovX = focal2fov(focal[0], width) 89 | FovY = focal2fov(focal[1], height) 90 | 91 | cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,flow=flow,focal = focal, 92 | image_path=image_path, image_name=image_name, width=image.shape[2], height=image.shape[1], 93 | time = time,per_time = per_time)) 94 | 95 | return cameras 96 | 97 | 98 | def format_render_poses(poses,data_infos): 99 | cameras = [] 100 | tensor_to_pil = transforms.ToPILImage() 101 | len_poses = len(poses) 102 | times = [i/len_poses for i in range(len_poses)] 103 | image = data_infos[0][0] 104 | flow = data_infos[0][1] 105 | focal = data_infos[0][3] 106 | width = image.shape[2] 107 | height = image.shape[1] 108 | 109 | for idx, p in tqdm(enumerate(poses)): 110 | 111 | image_path = None 112 | image_name = f"{idx}" 113 | time = times[idx] 114 | pose = np.eye(4) 115 | pose[:3,:] = p[:3,:] 116 | R = pose[:3,:3] 117 | R[:,0] = -R[:,0] 118 | T = pose[:3,3] 119 | FovX = focal2fov(focal[0], width) 120 | FovY = focal2fov(focal[1], height) 121 | 122 | 
cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,flow = flow,focal=focal, 123 | image_path=image_path, image_name=image_name, width=width, height=height, 124 | time = time,per_time=0)) 125 | return cameras 126 | 127 | 128 | def readdynerfInfo(datadir,args): 129 | # loading all the data follow hexplane format 130 | # ply_path = os.path.join(datadir, "points3d.ply") 131 | ply_path = os.path.join(datadir, 'colmap/sparse/0/points3D.ply') 132 | bin_path = os.path.join(datadir, 'colmap/sparse/0/points3D.bin') 133 | 134 | from scene.neural_3D_dataset_NDC import Neural3D_NDC_Dataset 135 | train_dataset = Neural3D_NDC_Dataset( 136 | datadir, 137 | "train", 138 | 1.0, 139 | time_scale=1, 140 | scene_bbox_min=[-2.5, -2.0, -1.0], 141 | scene_bbox_max=[2.5, 2.0, 1.0], 142 | eval_index=args.eval_index, 143 | is_short = args.is_short 144 | ) 145 | 146 | test_dataset = Neural3D_NDC_Dataset( 147 | datadir, 148 | "test", 149 | 1.0, 150 | time_scale=1, 151 | scene_bbox_min=[-2.5, -2.0, -1.0], 152 | scene_bbox_max=[2.5, 2.0, 1.0], 153 | eval_index=args.eval_index, 154 | is_short = args.is_short 155 | ) 156 | 157 | train_cam_infos = format_infos(train_dataset,"train") 158 | 159 | val_cam_infos = format_render_poses(test_dataset.val_poses,test_dataset) 160 | nerf_normalization = getNerfppNorm(train_cam_infos) 161 | 162 | xyz, rgb, _ = read_points3D_binary(bin_path) 163 | storePly(ply_path, xyz, rgb) 164 | 165 | try: 166 | pcd = fetchPly(ply_path) 167 | except: 168 | pcd = None 169 | scene_info = SceneInfo(point_cloud=pcd, 170 | train_cameras=train_dataset, 171 | test_cameras=test_dataset, 172 | video_cameras=val_cam_infos, 173 | nerf_normalization=nerf_normalization, 174 | ply_path=ply_path, 175 | maxtime=300 176 | ) 177 | return scene_info 178 | 179 | sceneLoadTypeCallbacks = { 180 | "dynerf" : readdynerfInfo, 181 | } -------------------------------------------------------------------------------- /scene/deformation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from scene.hexplane import HexPlaneField 5 | 6 | class Deformation(nn.Module): 7 | def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None): 8 | super(Deformation, self).__init__() 9 | self.D = D 10 | self.W = W 11 | self.input_ch = input_ch 12 | self.input_ch_time = input_ch_time 13 | self.skips = skips 14 | 15 | self.no_grid = args.no_grid 16 | self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires) 17 | self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net() 18 | self.args = args 19 | def create_net(self): 20 | 21 | mlp_out_dim = 0 22 | if self.no_grid: 23 | self.feature_out = [nn.Linear(4,self.W)] 24 | else: 25 | self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)] 26 | 27 | for i in range(self.D-1): 28 | self.feature_out.append(nn.ReLU()) 29 | self.feature_out.append(nn.Linear(self.W,self.W)) 30 | self.feature_out = nn.Sequential(*self.feature_out) 31 | output_dim = self.W 32 | return \ 33 | nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\ 34 | nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\ 35 | nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)), \ 36 | nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1)) 37 | 38 | def query_time(self, rays_pts_emb, scales_emb, 
rotations_emb, time_emb): 39 | 40 | if self.no_grid: 41 | h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1) 42 | else: 43 | grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) 44 | 45 | h = grid_feature 46 | 47 | h = self.feature_out(h) 48 | 49 | return h 50 | 51 | def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None): 52 | if time_emb is None: 53 | return self.forward_static(rays_pts_emb[:,:3]) 54 | else: 55 | return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb) 56 | 57 | def forward_static(self, rays_pts_emb): 58 | grid_feature = self.grid(rays_pts_emb[:,:3]) 59 | dx = self.static_mlp(grid_feature) 60 | return rays_pts_emb[:, :3] + dx 61 | def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb): 62 | hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float() 63 | dx = self.pos_deform(hidden) 64 | pts = rays_pts_emb[:, :3] + dx 65 | if self.args.no_ds: 66 | scales = scales_emb[:,:3] 67 | else: 68 | ds = self.scales_deform(hidden) 69 | scales = scales_emb[:,:3] + ds 70 | if self.args.no_dr: 71 | rotations = rotations_emb[:,:4] 72 | else: 73 | dr = self.rotations_deform(hidden) 74 | rotations = rotations_emb[:,:4] + dr 75 | if self.args.no_do: 76 | opacity = opacity_emb[:,:1] 77 | else: 78 | do = self.opacity_deform(hidden) 79 | opacity = opacity_emb[:,:1] + do 80 | # + do 81 | # print("deformation value:","pts:",torch.abs(dx).mean(),"rotation:",torch.abs(dr).mean()) 82 | 83 | return pts, scales, rotations, opacity 84 | def get_mlp_parameters(self): 85 | parameter_list = [] 86 | for name, param in self.named_parameters(): 87 | if "grid" not in name: 88 | parameter_list.append(param) 89 | return parameter_list 90 | def get_grid_parameters(self): 91 | return list(self.grid.parameters() ) 92 | # + list(self.timegrid.parameters()) 93 | class deform_network(nn.Module): 94 | def __init__(self, args) : 95 | super(deform_network, self).__init__() 96 | net_width = args.net_width 97 | timebase_pe = args.timebase_pe 98 | defor_depth= args.defor_depth 99 | posbase_pe= args.posebase_pe 100 | scale_rotation_pe = args.scale_rotation_pe 101 | opacity_pe = args.opacity_pe 102 | timenet_width = args.timenet_width 103 | timenet_output = args.timenet_output 104 | times_ch = 2*timebase_pe+1 105 | self.timenet = nn.Sequential( 106 | nn.Linear(times_ch, timenet_width), nn.ReLU(), 107 | nn.Linear(timenet_width, timenet_output)) 108 | self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=args) 109 | self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)])) 110 | self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)])) 111 | self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)])) 112 | self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)])) 113 | self.apply(initialize_weights) 114 | # print(self) 115 | 116 | def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None): 117 | if times_sel is not None: 118 | return self.forward_dynamic(point, scales, rotations, opacity, times_sel) 119 | else: 120 | return self.forward_static(point) 121 | 122 | 123 | def forward_static(self, points): 124 | points = self.deformation_net(points) 125 | return points 126 | def forward_dynamic(self, point, scales=None, rotations=None, 
opacity=None, times_sel=None): 127 | # times_emb = poc_fre(times_sel, self.time_poc) 128 | 129 | means3D, scales, rotations, opacity = self.deformation_net( point, 130 | scales, 131 | rotations, 132 | opacity, 133 | # times_feature, 134 | times_sel) 135 | return means3D, scales, rotations, opacity 136 | def get_mlp_parameters(self): 137 | return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters()) 138 | def get_grid_parameters(self): 139 | return self.deformation_net.get_grid_parameters() 140 | 141 | def initialize_weights(m): 142 | if isinstance(m, nn.Linear): 143 | # init.constant_(m.weight, 0) 144 | init.xavier_uniform_(m.weight,gain=1) 145 | if m.bias is not None: 146 | init.xavier_uniform_(m.weight,gain=1) 147 | # init.constant_(m.bias, 0) 148 | -------------------------------------------------------------------------------- /scene/external.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import open3d as o3d 6 | from scene.KDTree import KDTree 7 | from plyfile import PlyElement,PlyData 8 | import os 9 | 10 | 11 | def build_rotation(q): 12 | norm = torch.sqrt(q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]) 13 | q = q / norm[:, None] 14 | rot = torch.zeros((q.size(0), 3, 3), device='cuda') 15 | r = q[:, 0] 16 | x = q[:, 1] 17 | y = q[:, 2] 18 | z = q[:, 3] 19 | rot[:, 0, 0] = 1 - 2 * (y * y + z * z) 20 | rot[:, 0, 1] = 2 * (x * y - r * z) 21 | rot[:, 0, 2] = 2 * (x * z + r * y) 22 | rot[:, 1, 0] = 2 * (x * y + r * z) 23 | rot[:, 1, 1] = 1 - 2 * (x * x + z * z) 24 | rot[:, 1, 2] = 2 * (y * z - r * x) 25 | rot[:, 2, 0] = 2 * (x * z - r * y) 26 | rot[:, 2, 1] = 2 * (y * z + r * x) 27 | rot[:, 2, 2] = 1 - 2 * (x * x + y * y) 28 | return rot 29 | 30 | 31 | def world_to_view_screen(pts3D, K, RT_cam2): 32 | # print("pts3D",pts3D.max(), pts3D.min()) 33 | wrld_X = RT_cam2.bmm(pts3D) 34 | xy_proj = K.bmm(wrld_X) 35 | 36 | 37 | # And finally we project to get the final result 38 | mask = (xy_proj[:, 2:3, :].abs() < 1E-2).detach() 39 | mask = mask.to(pts3D.device) 40 | mask.requires_grad_(False) 41 | 42 | zs = xy_proj[:, 2:3, :] 43 | 44 | mask_unsq = mask.unsqueeze(0).unsqueeze(0) 45 | 46 | if True in mask_unsq: 47 | zs[mask] = 1E-2 48 | sampler = torch.cat((xy_proj[:, 0:2, :] / zs, xy_proj[:, 2:3, :]), 1) 49 | 50 | # Remove invalid zs that cause nans 51 | if True in mask_unsq: 52 | sampler[mask.repeat(1, 3, 1)] = -10 53 | return sampler 54 | 55 | def get_pixel_grids(height, width): 56 | with torch.no_grad(): 57 | # texture coordinate 58 | x_linspace = torch.linspace(0, width - 1, width).view(1, width).expand(height, width) 59 | y_linspace = torch.linspace(0, height - 1, height).view(height, 1).expand(height, width) 60 | x_coordinates = x_linspace.contiguous().view(-1) 61 | y_coordinates = y_linspace.contiguous().view(-1) 62 | ones = torch.ones(height * width) 63 | indices_grid = torch.stack([x_coordinates, y_coordinates, ones, torch.ones(height * width)], dim=0) 64 | return indices_grid 65 | 66 | def my_view_to_world_coord(pts3D, K_inv, RTinv_cam1, xyzs): 67 | # PERFORM PROJECTION 68 | # Project the world points into the new view 69 | projected_coors = xyzs * pts3D 70 | projected_coors[:, -1, :] = 1 71 | cam1_X = K_inv.bmm(projected_coors) 72 | wrld_X = RTinv_cam1.bmm(cam1_X) 73 | return wrld_X 74 | 75 | def get_add_point(view_camera,xyz): 76 | # Motion-aware Splitting. Establish a connection between candidate points and Gaussian. 
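# Summary of the steps below: build the world-to-camera matrix and pinhole intrinsics K,
# project the Gaussian centers `xyz` onto the image plane with world_to_view_screen, threshold
# and dilate the optical flow into a binary motion mask, keep the projected Gaussians that fall
# inside that mask (sampler_w), sample a sparse grid of candidate pixels (stride `wind`) inside
# the mask (sampler_mask), and match each candidate to its nearest projected Gaussian with a
# KD-tree; the returned indices (near_point_mask) establish the candidate-to-Gaussian links
# used for motion-aware splitting.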
77 | 78 | 79 | temp_R = copy.deepcopy(view_camera.R) 80 | temp_T = copy.deepcopy(view_camera.T) 81 | 82 | temp_R = np.transpose(temp_R) 83 | R = np.eye(4) 84 | R[:3, :3] = temp_R 85 | R[:3, 3] = temp_T 86 | 87 | H, W = view_camera.original_image.shape[1], view_camera.original_image.shape[2] 88 | 89 | K = np.eye(4) 90 | K[0, 2] = W / 2 91 | K[1, 2] = H / 2 92 | K[0, 0] = view_camera.focal[0] 93 | K[1, 1] = view_camera.focal[1] 94 | 95 | 96 | K = torch.FloatTensor(K).unsqueeze(0).cuda() 97 | R = torch.FloatTensor(R).unsqueeze(0).cuda() 98 | 99 | src_xyz_t = xyz 100 | src_xyz_t = src_xyz_t.unsqueeze(0).permute(0, 2, 1) 101 | tempdata = torch.ones((src_xyz_t.shape[0], 1, src_xyz_t.shape[2])).cuda() 102 | src_xyz = torch.cat((src_xyz_t, tempdata), dim=1) 103 | 104 | xyz_sampler = world_to_view_screen(src_xyz, RT_cam2=R, K=K) 105 | 106 | 107 | sampler = xyz_sampler[0, 0:2].transpose(1, 0) 108 | depth_sampler = xyz_sampler[0, 2:].transpose(1, 0) 109 | 110 | sampler_t = sampler.detach().cpu().numpy().astype(int) 111 | sampler_mask = np.ones((sampler_t.shape[0],1)) 112 | 113 | sampler_mask[sampler_t[:, 1] >= H] = 0 114 | sampler_mask[sampler_t[:, 0] >= W] = 0 115 | sampler_mask[sampler_t[:, 1] < 0] = 0 116 | sampler_mask[sampler_t[:, 0] < 0] = 0 117 | 118 | sampler_t[sampler_t[:, 1] >= H, 1] = H - 1 119 | sampler_t[sampler_t[:, 0] >= W, 0] = W - 1 120 | sampler_t[sampler_t<0] = 0 121 | 122 | sampler_w = np.zeros_like(sampler_t) 123 | sampler_w[:, 0] = sampler_t[:, 1] 124 | sampler_w[:, 1] = sampler_t[:, 0] 125 | 126 | mask = np.zeros((H,W)) 127 | mask[sampler_t[:,1],sampler_t[:,0]] = 255 128 | 129 | x_linspace = torch.linspace(0, W - 1, W).view(1, W).expand(H, W) 130 | y_linspace = torch.linspace(0, H - 1, H).view(H, 1).expand(H, W) 131 | xyzs_big = torch.stack([y_linspace, x_linspace], dim=2) # H W 2 132 | 133 | flow_t = copy.deepcopy(view_camera.flow) 134 | 135 | flow_p = np.sum(flow_t,axis=2) 136 | flow_p[flow_p>0.5] = 1 137 | flow_p[flow_p<-0.5] = 1 138 | flow_p[flow_p!=1] = 0 139 | flow_p = np.array(flow_p,dtype= np.uint8) # motion regions 140 | 141 | kernel=np.ones((5,5),np.uint8) 142 | flow_p=cv2.dilate(flow_p,kernel,iterations=1) # Morphological operation 143 | 144 | 145 | flow_mask = flow_p[sampler_t[:, 1], sampler_t[:, 0]] 146 | 147 | wind = 3 148 | 149 | image1 = np.zeros((H,W)) 150 | image2 = np.zeros((H,W)) 151 | x_linspace1 = np.linspace(0,H-1,int(H/wind)-1).astype(int) 152 | y_linspace1 = np.linspace(0,W-1,int(W/wind)-1).astype(int) 153 | 154 | 155 | image1[x_linspace1,:] = 1 156 | image2[:,y_linspace1] = 1 157 | image = image1+image2 158 | image_indx = image[flow_p!=0] 159 | 160 | xyz_mask = xyzs_big[flow_p!=0] 161 | sampler_mask = xyz_mask[image_indx==2] 162 | sampler_w = sampler_w[flow_mask!=0] 163 | 164 | sampler_w = np.array(sampler_w) 165 | sampler_mask = np.array(sampler_mask) 166 | 167 | num_train = sampler_mask.shape[0] 168 | 169 | kdtree = KDTree(sampler_w) 170 | root = kdtree.createKDTree(sampler_w, 0) 171 | 172 | near_point_mask = torch.zeros((num_train),dtype = torch.long) 173 | for i in range(num_train): 174 | pt, minDis = kdtree.getNearestPt(root, sampler_mask[i]) 175 | index1,_ = np.where(sampler_w == pt) 176 | near_point_mask[i] = index1[0] 177 | 178 | return near_point_mask 179 | 180 | 181 | 182 | def get_sample_point(view_camera,xyz): 183 | 184 | temp_R = copy.deepcopy(view_camera.R) 185 | temp_T = copy.deepcopy(view_camera.T) 186 | 187 | temp_R = np.transpose(temp_R) 188 | R = np.eye(4) 189 | R[:3, :3] = temp_R 190 | R[:3, 3] = temp_T 191 | 192 | H, W = 
view_camera.original_image.shape[1], view_camera.original_image.shape[2] 193 | 194 | K = np.eye(4) 195 | K[0, 2] = view_camera.focal[2] 196 | K[1, 2] = view_camera.focal[3] 197 | K[0, 0] = view_camera.focal[1] 198 | K[1, 1] = view_camera.focal[0] 199 | 200 | src_xyz_t = xyz 201 | src_xyz_t = src_xyz_t.unsqueeze(0).permute(0, 2, 1) 202 | tempdata = torch.ones((src_xyz_t.shape[0], 1, src_xyz_t.shape[2])).cuda() 203 | src_xyz = torch.cat((src_xyz_t, tempdata), dim=1) 204 | 205 | K = torch.FloatTensor(K).unsqueeze(0).cuda() 206 | R = torch.FloatTensor(R).unsqueeze(0).cuda() 207 | 208 | xyz_sampler = world_to_view_screen(src_xyz, RT_cam2=R, K=K) 209 | 210 | sampler = xyz_sampler[0, 0:2].transpose(1, 0) 211 | temp_depth = xyz_sampler[:,2].transpose(1, 0) 212 | 213 | 214 | 215 | 216 | sampler_t = sampler.detach().cpu().numpy().astype(int) 217 | sampler_mask = np.ones((sampler_t.shape[0],1)) 218 | sampler_mask[sampler_t[:, 1] >= H] = 0 219 | sampler_mask[sampler_t[:, 0] >= W] = 0 220 | sampler_mask[sampler_t[:, 1] < 0] = 0 221 | sampler_mask[sampler_t[:, 0] < 0] = 0 222 | 223 | sampler_t[sampler_t[:, 1] >= H, 1] = H - 1 224 | sampler_t[sampler_t[:, 0] >= W, 0] = W - 1 225 | sampler_t[sampler_t<0] = 0 226 | 227 | mask = np.zeros((H,W)) 228 | mask[sampler_t[:,1],sampler_t[:,0]] = 255 229 | 230 | x_linspace = torch.linspace(0, W - 1, W).view(1, W).expand(H, W) 231 | y_linspace = torch.linspace(0, H - 1, H).view(H, 1).expand(H, W) 232 | xyzs_big = torch.stack([y_linspace, x_linspace], dim=2) 233 | 234 | flow_t = copy.deepcopy(view_camera.flow) 235 | 236 | flow_axis = flow_t+xyzs_big.numpy() 237 | 238 | flow_p = np.sum(flow_t,axis=2) 239 | flow_p[flow_p>0.5] = 1 240 | flow_p[flow_p<-0.5] = 1 241 | flow_p[flow_p!=1] = 0 242 | flow_mask = flow_p[sampler_t[:, 1], sampler_t[:, 0]] 243 | 244 | flow_sampler = flow_axis[sampler_t[:, 1], sampler_t[:, 0]] 245 | 246 | return mask,sampler_t,sampler_mask,torch.tensor(flow_sampler,dtype=torch.float,device="cuda"),torch.tensor(flow_mask,dtype=torch.float,device="cuda") 247 | 248 | 249 | 250 | def warp_point(view_camera, xyz): 251 | temp_R = copy.deepcopy(view_camera.R) 252 | temp_T = copy.deepcopy(view_camera.T) 253 | 254 | temp_T[0] *= -1 255 | 256 | 257 | R = np.eye(4) 258 | R[:3, :3] = temp_R 259 | R[:3, 3] = temp_T 260 | R = np.linalg.inv(R) 261 | H, W = view_camera.original_image.shape[1], view_camera.original_image.shape[2] 262 | 263 | K = np.eye(4) 264 | K[0, 2] = W / 2 265 | K[1, 2] = H / 2 266 | K[0, 0] = view_camera.focal[0] 267 | K[1, 1] = view_camera.focal[1] 268 | 269 | src_xyz_t = xyz 270 | src_xyz_t = src_xyz_t.unsqueeze(0).permute(0, 2, 1) 271 | tempdata = torch.ones((src_xyz_t.shape[0], 1, src_xyz_t.shape[2])).cuda() 272 | src_xyz = torch.cat((src_xyz_t, tempdata), dim=1) 273 | 274 | K = torch.FloatTensor(K).unsqueeze(0).cuda() 275 | R = torch.FloatTensor(R).unsqueeze(0).cuda() 276 | 277 | xyz_sampler = world_to_view_screen(src_xyz, RT_cam2=R, K=K) 278 | 279 | sampler = xyz_sampler[0, 0:2].transpose(1, 0) 280 | 281 | sampler_t = torch.zeros_like(sampler) 282 | sampler_t[:,0] = sampler[:,1] 283 | sampler_t[:,1] = sampler[:,0] 284 | 285 | 286 | return sampler_t 287 | 288 | 289 | def storePly(path, xyz, rgb): 290 | # Define the dtype for the structured array 291 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 292 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 293 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 294 | 295 | normals = np.zeros_like(xyz) 296 | 297 | elements = np.empty(xyz.shape[0], dtype=dtype) 298 | attributes = np.concatenate((xyz, 
normals, rgb), axis=1) 299 | elements[:] = list(map(tuple, attributes)) 300 | 301 | # Create the PlyData object and write to file 302 | vertex_element = PlyElement.describe(elements, 'vertex') 303 | ply_data = PlyData([vertex_element]) 304 | ply_data.write(path) 305 | 306 | def quat_mult(q1, q2): 307 | w1, x1, y1, z1 = q1.T 308 | w2, x2, y2, z2 = q2.T 309 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 310 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 311 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 312 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 313 | return torch.stack([w, x, y, z]).T 314 | 315 | def o3d_knn(pts, num_knn): 316 | indices = [] 317 | sq_dists = [] 318 | pcd = o3d.geometry.PointCloud() 319 | pcd.points = o3d.utility.Vector3dVector(np.ascontiguousarray(pts, np.float64)) 320 | pcd_tree = o3d.geometry.KDTreeFlann(pcd) 321 | for p in pcd.points: 322 | [_, i, d] = pcd_tree.search_knn_vector_3d(p, num_knn + 1) 323 | indices.append(i[1:]) 324 | sq_dists.append(d[1:]) 325 | return np.array(sq_dists), np.array(indices) 326 | 327 | 328 | def calculate_total_size_of_files(folders): 329 | total_size = 0 330 | 331 | for folder_name in folders: 332 | deformation_path = os.path.join(folder_name, "./point_cloud/coarse_iteration_3000/deformation.pth") 333 | point_cloud_path = os.path.join(folder_name, "./point_cloud/coarse_iteration_3000/point_cloud.ply") 334 | # print(point_cloud_path) 335 | if os.path.exists(deformation_path): 336 | deformation_size = os.path.getsize(deformation_path) / (1024 * 1024) 337 | total_size += deformation_size 338 | 339 | if os.path.exists(point_cloud_path): 340 | point_cloud_size = os.path.getsize(point_cloud_path) / (1024 * 1024) 341 | total_size += point_cloud_size 342 | 343 | return total_size 344 | -------------------------------------------------------------------------------- /scene/getData.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scene.gaussian_model import GaussianModel 3 | 4 | def get_model_data( pc : GaussianModel, time, stage="fine"): 5 | means3D = pc.get_xyz 6 | opacity = pc._opacity 7 | 8 | scales = pc._scaling 9 | rotations = pc._rotation 10 | deformation_point = pc._deformation_table 11 | 12 | if stage == "coarse" : 13 | means3D_deform, scales_deform, rotations_deform, opacity_deform = means3D, scales, rotations, opacity 14 | else: 15 | means3D_deform, scales_deform, rotations_deform, opacity_deform = pc._deformation(means3D[deformation_point], scales[deformation_point], 16 | rotations[deformation_point], opacity[deformation_point], 17 | time[deformation_point]) 18 | 19 | if stage == "fine": 20 | with torch.no_grad(): 21 | pc._deformation_accum[deformation_point] += torch.abs(means3D_deform-means3D[deformation_point]) 22 | 23 | means3D_final = torch.zeros_like(means3D) 24 | rotations_final = torch.zeros_like(rotations) 25 | scales_final = torch.zeros_like(scales) 26 | opacity_final = torch.zeros_like(opacity) 27 | means3D_final[deformation_point] = means3D_deform 28 | rotations_final[deformation_point] = rotations_deform 29 | scales_final[deformation_point] = scales_deform 30 | opacity_final[deformation_point] = opacity_deform 31 | means3D_final[~deformation_point] = means3D[~deformation_point] 32 | rotations_final[~deformation_point] = rotations[~deformation_point] 33 | scales_final[~deformation_point] = scales[~deformation_point] 34 | opacity_final[~deformation_point] = opacity[~deformation_point] 35 | 36 | scales_final = pc.scaling_activation(scales_final) 37 | rotations_final = 
pc.rotation_activation(rotations_final) 38 | opacity = pc.opacity_activation(opacity) 39 | 40 | return means3D_final,scales_final,rotations_final,opacity -------------------------------------------------------------------------------- /scene/hexplane.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging as log 3 | from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_normalized_directions(directions): 11 | """SH encoding must be in the range [0, 1] 12 | 13 | Args: 14 | directions: batch of directions 15 | """ 16 | return (directions + 1.0) / 2.0 17 | 18 | 19 | def normalize_aabb(pts, aabb): 20 | return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0 21 | def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor: 22 | grid_dim = coords.shape[-1] 23 | 24 | if grid.dim() == grid_dim + 1: 25 | # no batch dimension present, need to add it 26 | grid = grid.unsqueeze(0) 27 | if coords.dim() == 2: 28 | coords = coords.unsqueeze(0) 29 | 30 | if grid_dim == 2 or grid_dim == 3: 31 | grid_sampler = F.grid_sample 32 | else: 33 | raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only " 34 | f"implemented for 2 and 3D data.") 35 | 36 | coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:])) 37 | B, feature_dim = grid.shape[:2] 38 | n = coords.shape[-2] 39 | interp = grid_sampler( 40 | grid, # [B, feature_dim, reso, ...] 41 | coords, # [B, 1, ..., n, grid_dim] 42 | align_corners=align_corners, 43 | mode='bilinear', padding_mode='border') 44 | interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim] 45 | interp = interp.squeeze() # [B?, n, feature_dim?] 46 | return interp 47 | 48 | def init_grid_param( 49 | grid_nd: int, 50 | in_dim: int, 51 | out_dim: int, 52 | reso: Sequence[int], 53 | a: float = 0.1, 54 | b: float = 0.5): 55 | assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension" 56 | has_time_planes = in_dim == 4 57 | assert grid_nd <= in_dim 58 | coo_combs = list(itertools.combinations(range(in_dim), grid_nd)) 59 | grid_coefs = nn.ParameterList() 60 | for ci, coo_comb in enumerate(coo_combs): 61 | new_grid_coef = nn.Parameter(torch.empty( 62 | [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]] 63 | )) 64 | if has_time_planes and 3 in coo_comb: # Initialize time planes to 1 65 | nn.init.ones_(new_grid_coef) 66 | else: 67 | nn.init.uniform_(new_grid_coef, a=a, b=b) 68 | grid_coefs.append(new_grid_coef) 69 | 70 | return grid_coefs 71 | 72 | 73 | def interpolate_ms_features(pts: torch.Tensor, 74 | ms_grids: Collection[Iterable[nn.Module]], 75 | grid_dimensions: int, 76 | concat_features: bool, 77 | num_levels: Optional[int], 78 | ) -> torch.Tensor: 79 | coo_combs = list(itertools.combinations( 80 | range(pts.shape[-1]), grid_dimensions) 81 | ) 82 | if num_levels is None: 83 | num_levels = len(ms_grids) 84 | multi_scale_interp = [] if concat_features else 0. 85 | grid: nn.ParameterList 86 | for scale_id, grid in enumerate(ms_grids[:num_levels]): 87 | interp_space = 1. 
88 | for ci, coo_comb in enumerate(coo_combs): 89 | # interpolate in plane 90 | feature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso 91 | interp_out_plane = ( 92 | grid_sample_wrapper(grid[ci], pts[..., coo_comb]) 93 | .view(-1, feature_dim) 94 | ) 95 | # compute product over planes 96 | interp_space = interp_space * interp_out_plane 97 | 98 | # combine over scales 99 | if concat_features: 100 | multi_scale_interp.append(interp_space) 101 | else: 102 | multi_scale_interp = multi_scale_interp + interp_space 103 | 104 | if concat_features: 105 | multi_scale_interp = torch.cat(multi_scale_interp, dim=-1) 106 | return multi_scale_interp 107 | 108 | 109 | class HexPlaneField(nn.Module): 110 | def __init__( 111 | self, 112 | 113 | bounds, 114 | planeconfig, 115 | multires 116 | ) -> None: 117 | super().__init__() 118 | aabb = torch.tensor([[bounds,bounds,bounds], 119 | [-bounds,-bounds,-bounds]]) 120 | self.aabb = nn.Parameter(aabb, requires_grad=False) 121 | self.grid_config = [planeconfig] 122 | self.multiscale_res_multipliers = multires 123 | self.concat_features = True 124 | 125 | # 1. Init planes 126 | self.grids = nn.ModuleList() 127 | self.feat_dim = 0 128 | for res in self.multiscale_res_multipliers: 129 | # initialize coordinate grid 130 | config = self.grid_config[0].copy() 131 | # Resolution fix: multi-res only on spatial planes 132 | config["resolution"] = [ 133 | r * res for r in config["resolution"][:3] 134 | ] + config["resolution"][3:] 135 | gp = init_grid_param( 136 | grid_nd=config["grid_dimensions"], 137 | in_dim=config["input_coordinate_dim"], 138 | out_dim=config["output_coordinate_dim"], 139 | reso=config["resolution"], 140 | ) 141 | # shape[1] is out-dim - Concatenate over feature len for each scale 142 | if self.concat_features: 143 | self.feat_dim += gp[-1].shape[1] 144 | else: 145 | self.feat_dim = gp[-1].shape[1] 146 | self.grids.append(gp) 147 | # print(f"Initialized model grids: {self.grids}") 148 | print("feature_dim:",self.feat_dim) 149 | 150 | 151 | def set_aabb(self,xyz_max, xyz_min): 152 | aabb = torch.tensor([ 153 | xyz_max, 154 | xyz_min 155 | ]) 156 | self.aabb = nn.Parameter(aabb,requires_grad=True) 157 | print("Voxel Plane: set aabb=",self.aabb) 158 | 159 | def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): 160 | """Computes and returns the densities.""" 161 | 162 | pts = normalize_aabb(pts, self.aabb) 163 | pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4] 164 | 165 | pts = pts.reshape(-1, pts.shape[-1]) 166 | features = interpolate_ms_features( 167 | pts, ms_grids=self.grids, # noqa 168 | grid_dimensions=self.grid_config[0]["grid_dimensions"], 169 | concat_features=self.concat_features, num_levels=None) 170 | if len(features) < 1: 171 | features = torch.zeros((0, 1)).to(features.device) 172 | 173 | 174 | return features 175 | 176 | def forward(self, 177 | pts: torch.Tensor, 178 | timestamps: Optional[torch.Tensor] = None): 179 | 180 | features = self.get_density(pts, timestamps) 181 | 182 | return features 183 | -------------------------------------------------------------------------------- /scene/regulation.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Sequence 3 | import torch 4 | import torch.optim.lr_scheduler 5 | from torch import nn 6 | 7 | 8 | 9 | def compute_plane_tv(t): 10 | batch_size, c, h, w = t.shape 11 | count_h = batch_size * c * (h - 1) * w 12 | count_w = batch_size * c * h * (w - 1) 13 | h_tv 
= torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum() 14 | w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum() 15 | return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg 16 | 17 | 18 | def compute_plane_smoothness(t): 19 | batch_size, c, h, w = t.shape 20 | # Convolve with a second derivative filter, in the time dimension which is dimension 2 21 | first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] 22 | second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] 23 | # Take the L2 norm of the result 24 | return torch.square(second_difference).mean() 25 | 26 | 27 | class Regularizer(): 28 | def __init__(self, reg_type, initialization): 29 | self.reg_type = reg_type 30 | self.initialization = initialization 31 | self.weight = float(self.initialization) 32 | self.last_reg = None 33 | 34 | def step(self, global_step): 35 | pass 36 | 37 | def report(self, d): 38 | if self.last_reg is not None: 39 | d[self.reg_type].update(self.last_reg.item()) 40 | 41 | def regularize(self, *args, **kwargs) -> torch.Tensor: 42 | out = self._regularize(*args, **kwargs) * self.weight 43 | self.last_reg = out.detach() 44 | return out 45 | 46 | @abc.abstractmethod 47 | def _regularize(self, *args, **kwargs) -> torch.Tensor: 48 | raise NotImplementedError() 49 | 50 | def __str__(self): 51 | return f"Regularizer({self.reg_type}, weight={self.weight})" 52 | 53 | 54 | class PlaneTV(Regularizer): 55 | def __init__(self, initial_value, what: str = 'field'): 56 | if what not in {'field', 'proposal_network'}: 57 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 58 | f'but {what} was passed.') 59 | name = f'planeTV-{what[:2]}' 60 | super().__init__(name, initial_value) 61 | self.what = what 62 | 63 | def step(self, global_step): 64 | pass 65 | 66 | def _regularize(self, model, **kwargs): 67 | multi_res_grids: Sequence[nn.ParameterList] 68 | if self.what == 'field': 69 | multi_res_grids = model.field.grids 70 | elif self.what == 'proposal_network': 71 | multi_res_grids = [p.grids for p in model.proposal_networks] 72 | else: 73 | raise NotImplementedError(self.what) 74 | total = 0 75 | # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w] 76 | for grids in multi_res_grids: 77 | if len(grids) == 3: 78 | spatial_grids = [0, 1, 2] 79 | else: 80 | spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal 81 | for grid_id in spatial_grids: 82 | total += compute_plane_tv(grids[grid_id]) 83 | for grid in grids: 84 | # grid: [1, c, h, w] 85 | total += compute_plane_tv(grid) 86 | return total 87 | 88 | 89 | class TimeSmoothness(Regularizer): 90 | def __init__(self, initial_value, what: str = 'field'): 91 | if what not in {'field', 'proposal_network'}: 92 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 93 | f'but {what} was passed.') 94 | name = f'time-smooth-{what[:2]}' 95 | super().__init__(name, initial_value) 96 | self.what = what 97 | 98 | def _regularize(self, model, **kwargs) -> torch.Tensor: 99 | multi_res_grids: Sequence[nn.ParameterList] 100 | if self.what == 'field': 101 | multi_res_grids = model.field.grids 102 | elif self.what == 'proposal_network': 103 | multi_res_grids = [p.grids for p in model.proposal_networks] 104 | else: 105 | raise NotImplementedError(self.what) 106 | total = 0 107 | # model.grids is 6 x [1, rank * F_dim, reso, reso] 108 | for grids in multi_res_grids: 109 | if len(grids) == 3: 110 | time_grids 
= [] 111 | else: 112 | time_grids = [2, 4, 5] 113 | for grid_id in time_grids: 114 | total += compute_plane_smoothness(grids[grid_id]) 115 | return torch.as_tensor(total) 116 | 117 | 118 | 119 | class L1ProposalNetwork(Regularizer): 120 | def __init__(self, initial_value): 121 | super().__init__('l1-proposal-network', initial_value) 122 | 123 | def _regularize(self, model, **kwargs) -> torch.Tensor: 124 | grids = [p.grids for p in model.proposal_networks] 125 | total = 0.0 126 | for pn_grids in grids: 127 | for grid in pn_grids: 128 | total += torch.abs(grid).mean() 129 | return torch.as_tensor(total) 130 | 131 | 132 | class DepthTV(Regularizer): 133 | def __init__(self, initial_value): 134 | super().__init__('tv-depth', initial_value) 135 | 136 | def _regularize(self, model, model_out, **kwargs) -> torch.Tensor: 137 | depth = model_out['depth'] 138 | tv = compute_plane_tv( 139 | depth.reshape(64, 64)[None, None, :, :] 140 | ) 141 | return tv 142 | 143 | 144 | class L1TimePlanes(Regularizer): 145 | def __init__(self, initial_value, what='field'): 146 | if what not in {'field', 'proposal_network'}: 147 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 148 | f'but {what} was passed.') 149 | super().__init__(f'l1-time-{what[:2]}', initial_value) 150 | self.what = what 151 | 152 | def _regularize(self, model, **kwargs) -> torch.Tensor: 153 | # model.grids is 6 x [1, rank * F_dim, reso, reso] 154 | multi_res_grids: Sequence[nn.ParameterList] 155 | if self.what == 'field': 156 | multi_res_grids = model.field.grids 157 | elif self.what == 'proposal_network': 158 | multi_res_grids = [p.grids for p in model.proposal_networks] 159 | else: 160 | raise NotImplementedError(self.what) 161 | 162 | total = 0.0 163 | for grids in multi_res_grids: 164 | if len(grids) == 3: 165 | continue 166 | else: 167 | # These are the spatiotemporal grids 168 | spatiotemporal_grids = [2, 4, 5] 169 | for grid_id in spatiotemporal_grids: 170 | total += torch.abs(1 - grids[grid_id]).mean() 171 | return torch.as_tensor(total) 172 | 173 | -------------------------------------------------------------------------------- /scripts/convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
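# Typical invocation (paths are placeholders; raw frames are expected under <source_path>/input):
#   python scripts/convert.py -s <source_path> [--resize] [--colmap_executable /path/to/colmap]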
18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /scripts/getFlow.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import glob 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | from raft import RAFT 14 | from utils import flow_viz 15 | from utils.utils import InputPadder 16 | 17 | DEVICE = 'cuda' 18 | 19 | 20 | 21 | 22 | def load_image(imfile): 23 | img = np.array(cv2.imread(imfile)).astype(np.uint8) 24 | img = torch.from_numpy(img).permute(2, 0, 1).float() 25 | return img[None].to(DEVICE) 26 | 27 | 28 | 29 | def viz(img, flo, filenames, args): 30 | img = img[0].permute(1, 2, 0).cpu().numpy() 31 | flo = flo[0].permute(1, 2, 0).cpu().numpy() 32 | 33 | flo = flow_viz.flow_to_image(flo) 34 | print(f'{args.savepath}/{filenames[-7:]}') 35 | cv2.imwrite(f'{args.savepath}/{filenames[-7:]}', flo[:, :, [2, 1, 0]]) 36 | 37 | 38 | def demo(args): 39 | model = torch.nn.DataParallel(RAFT(args)) 40 | model.load_state_dict(torch.load(args.model)) 41 | 42 | if not os.path.exists(args.savepath): 43 | os.makedirs(args.savepath) 44 | 45 | model = model.module 46 | model.to(DEVICE) 47 | model.eval() 48 | 49 | with torch.no_grad(): 50 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 51 | glob.glob(os.path.join(args.path, '*.jpg')) 52 | 53 | images = sorted(images) 54 | t = 0 55 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 56 | # image0 = cv2.imread(imfile1) 57 | 58 | image1 = load_image(imfile1) 59 | # print(image1.shape) 60 | image2 = load_image(imfile2) 61 | padder = InputPadder(image1.shape) 62 | image1, image2 = padder.pad(image1, image2) 63 | print(image1.shape) 64 | 65 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 66 | save_flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy() 67 | save_flow_low = flow_low[0].permute(1, 2, 0).cpu().numpy() 68 | if imfile1[-8] == '\\': 69 
| file_num = imfile1[-7:-4] 70 | else: 71 | file_num = imfile1[-8:-4] 72 | print('imfile1:',imfile1) 73 | print('save_flow_low',save_flow_low.max()) 74 | print('save_flow_up',save_flow_up.max()) 75 | np.save(args.savepath+'/'+file_num+'.npy',save_flow_up) 76 | 77 | # print(flow_up.shape) 78 | # break 79 | viz(image1, flow_up, imfile1, args) 80 | 81 | def demo_CW4VS(args): 82 | 83 | model = torch.nn.DataParallel(RAFT(args)) 84 | model.load_state_dict(torch.load(args.model)) 85 | 86 | model = model.module 87 | model.to(DEVICE) 88 | model.eval() 89 | 90 | with torch.no_grad(): 91 | 92 | image1 = load_image(args.imagepath1) 93 | shape = image1.shape 94 | image2 = load_image(args.imagepath2) 95 | 96 | 97 | padder = InputPadder(image1.shape) 98 | image1, image2 = padder.pad(image1, image2) 99 | 100 | 101 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 102 | save_flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy() 103 | np.save(args.savepath,save_flow_up[:shape[2], :shape[3]]) 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('--model', help="restore checkpoint", default='./models/raft-sintel.pth') 109 | parser.add_argument('--source_path', help="dataset rootpath", 110 | default="rootpath/dtaset") 111 | parser.add_argument('--win_size', help="time step",type=int, 112 | default=1) 113 | parser.add_argument('--small', action='store_true', help='use small model') 114 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 115 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 116 | args = parser.parse_args() 117 | 118 | 119 | datapath = args.source_path 120 | 121 | imagenames = os.listdir(os.path.join(args.source_path, 'cam%02d'%(1),'images')) 122 | 123 | camNum = len(glob.glob(os.path.join(args.source_path, "cam*"))) 124 | for i in range(camNum): 125 | t= 0 126 | for imagename1 in imagenames : 127 | num,_ = imagename1.split('.') 128 | imagepath = os.path.join(args.source_path, 'cam%02d' % (i)) 129 | savepath = os.path.join(args.source_path, 'cam%02d'%(i),'flow') 130 | os.makedirs(savepath,exist_ok=True) 131 | args.savepath = os.path.join(savepath) + f'/{t}.npy' 132 | if t 20 |
21 | BibTeX
22 | @Article{kerbl3Dgaussians,
23 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
24 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
25 |       journal      = {ACM Transactions on Graphics},
26 |       number       = {4},
27 |       volume       = {42},
28 |       month        = {July},
29 |       year         = {2023},
30 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
31 | }
32 |
33 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 14 | 15 | #include "config.h" 16 | #include "stdio.h" 17 | 18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) 19 | #define NUM_WARPS (BLOCK_SIZE/32) 20 | 21 | // Spherical harmonics coefficients 22 | __device__ const float SH_C0 = 0.28209479177387814f; 23 | __device__ const float SH_C1 = 0.4886025119029199f; 24 | __device__ const float SH_C2[] = { 25 | 1.0925484305920792f, 26 | -1.0925484305920792f, 27 | 0.31539156525252005f, 28 | -1.0925484305920792f, 29 | 0.5462742152960396f 30 | }; 31 | __device__ const float SH_C3[] = { 32 | -0.5900435899266435f, 33 | 2.890611442640554f, 34 | -0.4570457994644658f, 35 | 0.3731763325901154f, 36 | -0.4570457994644658f, 37 | 1.445305721320277f, 38 | -0.5900435899266435f 39 | }; 40 | 41 | __forceinline__ __device__ float ndc2Pix(float v, int S) 42 | { 43 | return ((v + 1.0) * S - 1.0) * 0.5; 44 | } 45 | 46 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) 47 | { 48 | rect_min = { 49 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), 50 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) 51 | }; 52 | rect_max = { 53 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), 54 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) 55 | }; 56 | } 57 | 58 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) 59 | { 60 | float3 transformed = { 61 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 62 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 63 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 64 | }; 65 | return transformed; 66 | } 67 | 68 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) 69 | { 70 | float4 transformed = { 71 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 72 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 73 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 74 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] 75 | }; 76 | return transformed; 77 | } 78 | 79 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) 80 | { 81 | float3 transformed = { 82 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, 83 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, 84 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, 85 | }; 86 | return transformed; 87 | } 88 | 89 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) 90 | { 91 | float3 transformed = { 92 | matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, 93 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, 94 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, 95 | }; 96 | return transformed; 97 | } 
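// The helpers below are gradient utilities for the normalization n = v / ||v||: given the
// incoming gradient dv with respect to n, dnormvdz returns the z-component and dnormvdv the
// full gradient with respect to v (used by the backward pass).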
98 | 99 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) 100 | { 101 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 102 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 103 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 104 | return dnormvdz; 105 | } 106 | 107 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) 108 | { 109 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 110 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 111 | 112 | float3 dnormvdv; 113 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; 114 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; 115 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 116 | return dnormvdv; 117 | } 118 | 119 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) 120 | { 121 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; 122 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 123 | 124 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; 125 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; 126 | float4 dnormvdv; 127 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; 128 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; 129 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; 130 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; 131 | return dnormvdv; 132 | } 133 | 134 | __forceinline__ __device__ float sigmoid(float x) 135 | { 136 | return 1.0f / (1.0f + expf(-x)); 137 | } 138 | 139 | __forceinline__ __device__ bool in_frustum(int idx, 140 | const float* orig_points, 141 | const float* viewmatrix, 142 | const float* projmatrix, 143 | bool prefiltered, 144 | float3& p_view) 145 | { 146 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 147 | 148 | // Bring points to screen space 149 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 150 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 151 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 152 | p_view = transformPoint4x3(p_orig, viewmatrix); 153 | 154 | if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) 155 | { 156 | if (prefiltered) 157 | { 158 | printf("Point is filtered although prefiltered is set. This shouldn't happen!"); 159 | __trap(); 160 | } 161 | return false; 162 | } 163 | return true; 164 | } 165 | 166 | #define CHECK_CUDA(A, debug) \ 167 | A; if(debug) { \ 168 | auto ret = cudaDeviceSynchronize(); \ 169 | if (ret != cudaSuccess) { \ 170 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 171 | throw std::runtime_error(cudaGetErrorString(ret)); \ 172 | } \ 173 | } 174 | 175 | #endif -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, 28 | const float* bg_color, 29 | const float2* means2D, 30 | const float4* conic_opacity, 31 | const float* colors, 32 | const float* final_Ts, 33 | const uint32_t* n_contrib, 34 | const float* dL_dpixels, 35 | float3* dL_dmean2D, 36 | float4* dL_dconic2D, 37 | float* dL_dopacity, 38 | float* dL_dcolors); 39 | 40 | void preprocess( 41 | int P, int D, int M, 42 | const float3* means, 43 | const int* radii, 44 | const float* shs, 45 | const bool* clamped, 46 | const glm::vec3* scales, 47 | const glm::vec4* rotations, 48 | const float scale_modifier, 49 | const float* cov3Ds, 50 | const float* view, 51 | const float* proj, 52 | const float focal_x, float focal_y, 53 | const float tan_fovx, float tan_fovy, 54 | const glm::vec3* campos, 55 | const float3* dL_dmean2D, 56 | const float* dL_dconics, 57 | glm::vec3* dL_dmeans, 58 | float* dL_dcolor, 59 | float* dL_dcov3D, 60 | float* dL_dsh, 61 | glm::vec3* dL_dscale, 62 | glm::vec4* dL_drot); 63 | } 64 | 65 | #endif -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 
24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | const glm::vec3* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* cov3D_precomp, 33 | const float* colors_precomp, 34 | const float* viewmatrix, 35 | const float* projmatrix, 36 | const glm::vec3* cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int* radii, 41 | float2* points_xy_image, 42 | float* depths, 43 | float* cov3Ds, 44 | float* colors, 45 | float4* conic_opacity, 46 | const dim3 grid, 47 | uint32_t* tiles_touched, 48 | bool prefiltered); 49 | 50 | // Main rasterization method. 51 | void render( 52 | const dim3 grid, dim3 block, 53 | const uint2* ranges, 54 | const uint32_t* point_list, 55 | int W, int H, 56 | const float2* points_xy_image, 57 | const float* features, 58 | const float* depths, 59 | const float4* conic_opacity, 60 | float* final_T, 61 | uint32_t* n_contrib, 62 | const float* bg_color, 63 | float* out_color, 64 | float* out_depth); 65 | } 66 | 67 | 68 | #endif 69 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present); 30 | 31 | static int forward( 32 | std::function geometryBuffer, 33 | std::function binningBuffer, 34 | std::function imageBuffer, 35 | const int P, int D, int M, 36 | const float* background, 37 | const int width, int height, 38 | const float* means3D, 39 | const float* shs, 40 | const float* colors_precomp, 41 | const float* opacities, 42 | const float* scales, 43 | const float scale_modifier, 44 | const float* rotations, 45 | const float* cov3D_precomp, 46 | const float* viewmatrix, 47 | const float* projmatrix, 48 | const float* cam_pos, 49 | const float tan_fovx, float tan_fovy, 50 | const bool prefiltered, 51 | float* out_color, 52 | float* out_depth, 53 | int* radii = nullptr, 54 | bool debug = false); 55 | 56 | static void backward( 57 | const int P, int D, int M, int R, 58 | const float* background, 59 | const int width, int height, 60 | const float* means3D, 61 | const float* shs, 62 | const float* colors_precomp, 63 | const float* scales, 64 | const float scale_modifier, 65 | const float* rotations, 66 | const float* cov3D_precomp, 67 | const float* viewmatrix, 68 | const float* projmatrix, 69 | const float* campos, 70 | const float tan_fovx, float tan_fovy, 71 | const int* radii, 72 | char* geom_buffer, 73 | char* binning_buffer, 74 | char* image_buffer, 75 | const float* dL_dpix, 76 | float* dL_dmean2D, 77 | float* dL_dconic, 78 | float* dL_dopacity, 79 | float* dL_dcolor, 80 | float* dL_dmean3D, 81 | float* dL_dcov3D, 82 | float* dL_dsh, 83 | 
float* dL_dscale, 84 | float* dL_drot, 85 | bool debug); 86 | }; 87 | }; 88 | 89 | #endif 90 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "rasterizer_impl.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include "cuda_runtime.h" 19 | #include "device_launch_parameters.h" 20 | #include 21 | #include 22 | #define GLM_FORCE_CUDA 23 | #include 24 | 25 | #include 26 | #include 27 | namespace cg = cooperative_groups; 28 | 29 | #include "auxiliary.h" 30 | #include "forward.h" 31 | #include "backward.h" 32 | 33 | // Helper function to find the next-highest bit of the MSB 34 | // on the CPU. 35 | uint32_t getHigherMsb(uint32_t n) 36 | { 37 | uint32_t msb = sizeof(n) * 4; 38 | uint32_t step = msb; 39 | while (step > 1) 40 | { 41 | step /= 2; 42 | if (n >> msb) 43 | msb += step; 44 | else 45 | msb -= step; 46 | } 47 | if (n >> msb) 48 | msb++; 49 | return msb; 50 | } 51 | 52 | // Wrapper method to call auxiliary coarse frustum containment test. 53 | // Mark all Gaussians that pass it. 54 | __global__ void checkFrustum(int P, 55 | const float* orig_points, 56 | const float* viewmatrix, 57 | const float* projmatrix, 58 | bool* present) 59 | { 60 | auto idx = cg::this_grid().thread_rank(); 61 | if (idx >= P) 62 | return; 63 | 64 | float3 p_view; 65 | present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view); 66 | } 67 | 68 | // Generates one key/value pair for all Gaussian / tile overlaps. 69 | // Run once per Gaussian (1:N mapping). 70 | __global__ void duplicateWithKeys( 71 | int P, 72 | const float2* points_xy, 73 | const float* depths, 74 | const uint32_t* offsets, 75 | uint64_t* gaussian_keys_unsorted, 76 | uint32_t* gaussian_values_unsorted, 77 | int* radii, 78 | dim3 grid) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | if (idx >= P) 82 | return; 83 | 84 | // Generate no key/value pair for invisible Gaussians 85 | if (radii[idx] > 0) 86 | { 87 | // Find this Gaussian's offset in buffer for writing keys/values. 88 | uint32_t off = (idx == 0) ? 0 : offsets[idx - 1]; 89 | uint2 rect_min, rect_max; 90 | 91 | getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid); 92 | 93 | // For each tile that the bounding rect overlaps, emit a 94 | // key/value pair. The key is | tile ID | depth |, 95 | // and the value is the ID of the Gaussian. Sorting the values 96 | // with this key yields Gaussian IDs in a list, such that they 97 | // are first sorted by tile and then by depth. 98 | for (int y = rect_min.y; y < rect_max.y; y++) 99 | { 100 | for (int x = rect_min.x; x < rect_max.x; x++) 101 | { 102 | uint64_t key = y * grid.x + x; 103 | key <<= 32; 104 | key |= *((uint32_t*)&depths[idx]); 105 | gaussian_keys_unsorted[off] = key; 106 | gaussian_values_unsorted[off] = idx; 107 | off++; 108 | } 109 | } 110 | } 111 | } 112 | 113 | // Check keys to see if it is at the start/end of one tile's range in 114 | // the full sorted list. If yes, write start/end of this tile. 
115 | // Run once per instanced (duplicated) Gaussian ID. 116 | __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges) 117 | { 118 | auto idx = cg::this_grid().thread_rank(); 119 | if (idx >= L) 120 | return; 121 | 122 | // Read tile ID from key. Update start/end of tile range if at limit. 123 | uint64_t key = point_list_keys[idx]; 124 | uint32_t currtile = key >> 32; 125 | if (idx == 0) 126 | ranges[currtile].x = 0; 127 | else 128 | { 129 | uint32_t prevtile = point_list_keys[idx - 1] >> 32; 130 | if (currtile != prevtile) 131 | { 132 | ranges[prevtile].y = idx; 133 | ranges[currtile].x = idx; 134 | } 135 | } 136 | if (idx == L - 1) 137 | ranges[currtile].y = L; 138 | } 139 | 140 | // Mark Gaussians as visible/invisible, based on view frustum testing 141 | void CudaRasterizer::Rasterizer::markVisible( 142 | int P, 143 | float* means3D, 144 | float* viewmatrix, 145 | float* projmatrix, 146 | bool* present) 147 | { 148 | checkFrustum << <(P + 255) / 256, 256 >> > ( 149 | P, 150 | means3D, 151 | viewmatrix, projmatrix, 152 | present); 153 | } 154 | 155 | CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P) 156 | { 157 | GeometryState geom; 158 | obtain(chunk, geom.depths, P, 128); 159 | obtain(chunk, geom.clamped, P * 3, 128); 160 | obtain(chunk, geom.internal_radii, P, 128); 161 | obtain(chunk, geom.means2D, P, 128); 162 | obtain(chunk, geom.cov3D, P * 6, 128); 163 | obtain(chunk, geom.conic_opacity, P, 128); 164 | obtain(chunk, geom.rgb, P * 3, 128); 165 | obtain(chunk, geom.tiles_touched, P, 128); 166 | cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P); 167 | obtain(chunk, geom.scanning_space, geom.scan_size, 128); 168 | obtain(chunk, geom.point_offsets, P, 128); 169 | return geom; 170 | } 171 | 172 | CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N) 173 | { 174 | ImageState img; 175 | obtain(chunk, img.accum_alpha, N, 128); 176 | obtain(chunk, img.n_contrib, N, 128); 177 | obtain(chunk, img.ranges, N, 128); 178 | return img; 179 | } 180 | 181 | CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P) 182 | { 183 | BinningState binning; 184 | obtain(chunk, binning.point_list, P, 128); 185 | obtain(chunk, binning.point_list_unsorted, P, 128); 186 | obtain(chunk, binning.point_list_keys, P, 128); 187 | obtain(chunk, binning.point_list_keys_unsorted, P, 128); 188 | cub::DeviceRadixSort::SortPairs( 189 | nullptr, binning.sorting_size, 190 | binning.point_list_keys_unsorted, binning.point_list_keys, 191 | binning.point_list_unsorted, binning.point_list, P); 192 | obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); 193 | return binning; 194 | } 195 | 196 | // Forward rendering procedure for differentiable rasterization 197 | // of Gaussians. 
198 | int CudaRasterizer::Rasterizer::forward( 199 | std::function geometryBuffer, 200 | std::function binningBuffer, 201 | std::function imageBuffer, 202 | const int P, int D, int M, 203 | const float* background, 204 | const int width, int height, 205 | const float* means3D, 206 | const float* shs, 207 | const float* colors_precomp, 208 | const float* opacities, 209 | const float* scales, 210 | const float scale_modifier, 211 | const float* rotations, 212 | const float* cov3D_precomp, 213 | const float* viewmatrix, 214 | const float* projmatrix, 215 | const float* cam_pos, 216 | const float tan_fovx, float tan_fovy, 217 | const bool prefiltered, 218 | float* out_color, 219 | float* out_depth, 220 | int* radii, 221 | bool debug) 222 | { 223 | const float focal_y = height / (2.0f * tan_fovy); 224 | const float focal_x = width / (2.0f * tan_fovx); 225 | 226 | size_t chunk_size = required(P); 227 | char* chunkptr = geometryBuffer(chunk_size); 228 | GeometryState geomState = GeometryState::fromChunk(chunkptr, P); 229 | 230 | if (radii == nullptr) 231 | { 232 | radii = geomState.internal_radii; 233 | } 234 | 235 | dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); 236 | dim3 block(BLOCK_X, BLOCK_Y, 1); 237 | 238 | // Dynamically resize image-based auxiliary buffers during training 239 | size_t img_chunk_size = required(width * height); 240 | char* img_chunkptr = imageBuffer(img_chunk_size); 241 | ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height); 242 | 243 | if (NUM_CHANNELS != 3 && colors_precomp == nullptr) 244 | { 245 | throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!"); 246 | } 247 | 248 | // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) 249 | CHECK_CUDA(FORWARD::preprocess( 250 | P, D, M, 251 | means3D, 252 | (glm::vec3*)scales, 253 | scale_modifier, 254 | (glm::vec4*)rotations, 255 | opacities, 256 | shs, 257 | geomState.clamped, 258 | cov3D_precomp, 259 | colors_precomp, 260 | viewmatrix, projmatrix, 261 | (glm::vec3*)cam_pos, 262 | width, height, 263 | focal_x, focal_y, 264 | tan_fovx, tan_fovy, 265 | radii, 266 | geomState.means2D, 267 | geomState.depths, 268 | geomState.cov3D, 269 | geomState.rgb, 270 | geomState.conic_opacity, 271 | tile_grid, 272 | geomState.tiles_touched, 273 | prefiltered 274 | ), debug) 275 | 276 | // Compute prefix sum over full list of touched tile counts by Gaussians 277 | // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] 278 | CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug) 279 | 280 | // Retrieve total number of Gaussian instances to launch and resize aux buffers 281 | int num_rendered; 282 | CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug); 283 | 284 | size_t binning_chunk_size = required(num_rendered); 285 | char* binning_chunkptr = binningBuffer(binning_chunk_size); 286 | BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered); 287 | 288 | // For each instance to be rendered, produce adequate [ tile | depth ] key 289 | // and corresponding dublicated Gaussian indices to be sorted 290 | duplicateWithKeys << <(P + 255) / 256, 256 >> > ( 291 | P, 292 | geomState.means2D, 293 | geomState.depths, 294 | geomState.point_offsets, 295 | binningState.point_list_keys_unsorted, 296 | binningState.point_list_unsorted, 297 | radii, 298 | tile_grid) 299 | 
CHECK_CUDA(, debug) 300 | 301 | int bit = getHigherMsb(tile_grid.x * tile_grid.y); 302 | 303 | // Sort complete list of (duplicated) Gaussian indices by keys 304 | CHECK_CUDA(cub::DeviceRadixSort::SortPairs( 305 | binningState.list_sorting_space, 306 | binningState.sorting_size, 307 | binningState.point_list_keys_unsorted, binningState.point_list_keys, 308 | binningState.point_list_unsorted, binningState.point_list, 309 | num_rendered, 0, 32 + bit), debug) 310 | 311 | CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug); 312 | 313 | // Identify start and end of per-tile workloads in sorted list 314 | if (num_rendered > 0) 315 | identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > ( 316 | num_rendered, 317 | binningState.point_list_keys, 318 | imgState.ranges); 319 | CHECK_CUDA(, debug) 320 | 321 | // Let each tile blend its range of Gaussians independently in parallel 322 | const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb; 323 | CHECK_CUDA(FORWARD::render( 324 | tile_grid, block, 325 | imgState.ranges, 326 | binningState.point_list, 327 | width, height, 328 | geomState.means2D, 329 | feature_ptr, 330 | geomState.depths, 331 | geomState.conic_opacity, 332 | imgState.accum_alpha, 333 | imgState.n_contrib, 334 | background, 335 | out_color, 336 | out_depth), debug) 337 | 338 | return num_rendered; 339 | } 340 | 341 | // Produce necessary gradients for optimization, corresponding 342 | // to forward render pass 343 | void CudaRasterizer::Rasterizer::backward( 344 | const int P, int D, int M, int R, 345 | const float* background, 346 | const int width, int height, 347 | const float* means3D, 348 | const float* shs, 349 | const float* colors_precomp, 350 | const float* scales, 351 | const float scale_modifier, 352 | const float* rotations, 353 | const float* cov3D_precomp, 354 | const float* viewmatrix, 355 | const float* projmatrix, 356 | const float* campos, 357 | const float tan_fovx, float tan_fovy, 358 | const int* radii, 359 | char* geom_buffer, 360 | char* binning_buffer, 361 | char* img_buffer, 362 | const float* dL_dpix, 363 | float* dL_dmean2D, 364 | float* dL_dconic, 365 | float* dL_dopacity, 366 | float* dL_dcolor, 367 | float* dL_dmean3D, 368 | float* dL_dcov3D, 369 | float* dL_dsh, 370 | float* dL_dscale, 371 | float* dL_drot, 372 | bool debug) 373 | { 374 | GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); 375 | BinningState binningState = BinningState::fromChunk(binning_buffer, R); 376 | ImageState imgState = ImageState::fromChunk(img_buffer, width * height); 377 | 378 | if (radii == nullptr) 379 | { 380 | radii = geomState.internal_radii; 381 | } 382 | 383 | const float focal_y = height / (2.0f * tan_fovy); 384 | const float focal_x = width / (2.0f * tan_fovx); 385 | 386 | const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); 387 | const dim3 block(BLOCK_X, BLOCK_Y, 1); 388 | 389 | // Compute loss gradients w.r.t. 2D mean position, conic matrix, 390 | // opacity and RGB of Gaussians from per-pixel loss gradients. 391 | // If we were given precomputed colors and not SHs, use them. 392 | const float* color_ptr = (colors_precomp != nullptr) ? 
colors_precomp : geomState.rgb; 393 | CHECK_CUDA(BACKWARD::render( 394 | tile_grid, 395 | block, 396 | imgState.ranges, 397 | binningState.point_list, 398 | width, height, 399 | background, 400 | geomState.means2D, 401 | geomState.conic_opacity, 402 | color_ptr, 403 | imgState.accum_alpha, 404 | imgState.n_contrib, 405 | dL_dpix, 406 | (float3*)dL_dmean2D, 407 | (float4*)dL_dconic, 408 | dL_dopacity, 409 | dL_dcolor), debug) 410 | 411 | // Take care of the rest of preprocessing. Was the precomputed covariance 412 | // given to us or a scales/rot pair? If precomputed, pass that. If not, 413 | // use the one we computed ourselves. 414 | const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D; 415 | CHECK_CUDA(BACKWARD::preprocess(P, D, M, 416 | (float3*)means3D, 417 | radii, 418 | shs, 419 | geomState.clamped, 420 | (glm::vec3*)scales, 421 | (glm::vec4*)rotations, 422 | scale_modifier, 423 | cov3D_ptr, 424 | viewmatrix, 425 | projmatrix, 426 | focal_x, focal_y, 427 | tan_fovx, tan_fovy, 428 | (glm::vec3*)campos, 429 | (float3*)dL_dmean2D, 430 | dL_dconic, 431 | (glm::vec3*)dL_dmean3D, 432 | dL_dcolor, 433 | dL_dcov3D, 434 | dL_dsh, 435 | (glm::vec3*)dL_dscale, 436 | (glm::vec4*)dL_drot), debug) 437 | } 438 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | int* internal_radii; 36 | float2* means2D; 37 | float* cov3D; 38 | float4* conic_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | float* accum_alpha; 51 | 52 | static ImageState fromChunk(char*& chunk, size_t N); 53 | }; 54 | 55 | struct BinningState 56 | { 57 | size_t sorting_size; 58 | uint64_t* point_list_keys_unsorted; 59 | uint64_t* point_list_keys; 60 | uint32_t* point_list_unsorted; 61 | uint32_t* point_list; 62 | char* list_sorting_space; 63 | 64 | static BinningState fromChunk(char*& chunk, size_t P); 65 | }; 66 | 67 | template 68 | size_t required(size_t P) 69 | { 70 | char* size = nullptr; 71 | T::fromChunk(size, P); 72 | return ((size_t)size) + 128; 73 | } 74 | }; -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from typing import NamedTuple 13 | import torch.nn as nn 14 | import torch 15 | from . import _C 16 | 17 | def cpu_deep_copy_tuple(input_tuple): 18 | copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] 19 | return tuple(copied_tensors) 20 | 21 | def rasterize_gaussians( 22 | means3D, 23 | means2D, 24 | sh, 25 | colors_precomp, 26 | opacities, 27 | scales, 28 | rotations, 29 | cov3Ds_precomp, 30 | raster_settings, 31 | ): 32 | return _RasterizeGaussians.apply( 33 | means3D, 34 | means2D, 35 | sh, 36 | colors_precomp, 37 | opacities, 38 | scales, 39 | rotations, 40 | cov3Ds_precomp, 41 | raster_settings, 42 | ) 43 | 44 | class _RasterizeGaussians(torch.autograd.Function): 45 | @staticmethod 46 | def forward( 47 | ctx, 48 | means3D, 49 | means2D, 50 | sh, 51 | colors_precomp, 52 | opacities, 53 | scales, 54 | rotations, 55 | cov3Ds_precomp, 56 | raster_settings, 57 | ): 58 | 59 | # Restructure arguments the way that the C++ lib expects them 60 | args = ( 61 | raster_settings.bg, 62 | means3D, 63 | colors_precomp, 64 | opacities, 65 | scales, 66 | rotations, 67 | raster_settings.scale_modifier, 68 | cov3Ds_precomp, 69 | raster_settings.viewmatrix, 70 | raster_settings.projmatrix, 71 | raster_settings.tanfovx, 72 | raster_settings.tanfovy, 73 | raster_settings.image_height, 74 | raster_settings.image_width, 75 | sh, 76 | raster_settings.sh_degree, 77 | raster_settings.campos, 78 | raster_settings.prefiltered, 79 | raster_settings.debug 80 | ) 81 | 82 | # Invoke C++/CUDA rasterizer 83 | if raster_settings.debug: 84 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 85 | try: 86 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 87 | except Exception as ex: 88 | torch.save(cpu_args, "snapshot_fw.dump") 89 | print("\nAn error occured in forward. 
Please forward snapshot_fw.dump for debugging.") 90 | raise ex 91 | else: 92 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 93 | 94 | # Keep relevant tensors for backward 95 | ctx.raster_settings = raster_settings 96 | ctx.num_rendered = num_rendered 97 | ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer) 98 | return color, radii, depth 99 | 100 | @staticmethod 101 | def backward(ctx, grad_out_color, grad_radii, grad_depth): 102 | 103 | # Restore necessary values from context 104 | num_rendered = ctx.num_rendered 105 | raster_settings = ctx.raster_settings 106 | colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors 107 | 108 | # Restructure args as C++ method expects them 109 | args = (raster_settings.bg, 110 | means3D, 111 | radii, 112 | colors_precomp, 113 | scales, 114 | rotations, 115 | raster_settings.scale_modifier, 116 | cov3Ds_precomp, 117 | raster_settings.viewmatrix, 118 | raster_settings.projmatrix, 119 | raster_settings.tanfovx, 120 | raster_settings.tanfovy, 121 | grad_out_color, 122 | sh, 123 | raster_settings.sh_degree, 124 | raster_settings.campos, 125 | geomBuffer, 126 | num_rendered, 127 | binningBuffer, 128 | imgBuffer, 129 | raster_settings.debug) 130 | 131 | # Compute gradients for relevant tensors by invoking backward method 132 | if raster_settings.debug: 133 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 134 | try: 135 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) 136 | except Exception as ex: 137 | torch.save(cpu_args, "snapshot_bw.dump") 138 | print("\nAn error occured in backward. 
Writing snapshot_bw.dump for debugging.\n") 139 | raise ex 140 | else: 141 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) 142 | 143 | grads = ( 144 | grad_means3D, 145 | grad_means2D, 146 | grad_sh, 147 | grad_colors_precomp, 148 | grad_opacities, 149 | grad_scales, 150 | grad_rotations, 151 | grad_cov3Ds_precomp, 152 | None, 153 | ) 154 | 155 | return grads 156 | 157 | class GaussianRasterizationSettings(NamedTuple): 158 | image_height: int 159 | image_width: int 160 | tanfovx : float 161 | tanfovy : float 162 | bg : torch.Tensor 163 | scale_modifier : float 164 | viewmatrix : torch.Tensor 165 | projmatrix : torch.Tensor 166 | sh_degree : int 167 | campos : torch.Tensor 168 | prefiltered : bool 169 | debug : bool 170 | 171 | class GaussianRasterizer(nn.Module): 172 | def __init__(self, raster_settings): 173 | super().__init__() 174 | self.raster_settings = raster_settings 175 | 176 | def markVisible(self, positions): 177 | # Mark visible points (based on frustum culling for camera) with a boolean 178 | with torch.no_grad(): 179 | raster_settings = self.raster_settings 180 | visible = _C.mark_visible( 181 | positions, 182 | raster_settings.viewmatrix, 183 | raster_settings.projmatrix) 184 | 185 | return visible 186 | 187 | def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): 188 | 189 | raster_settings = self.raster_settings 190 | 191 | if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): 192 | raise Exception('Please provide excatly one of either SHs or precomputed colors!') 193 | 194 | if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None): 195 | raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!') 196 | 197 | if shs is None: 198 | shs = torch.Tensor([]) 199 | if colors_precomp is None: 200 | colors_precomp = torch.Tensor([]) 201 | 202 | if scales is None: 203 | scales = torch.Tensor([]) 204 | if rotations is None: 205 | rotations = torch.Tensor([]) 206 | if cov3D_precomp is None: 207 | cov3D_precomp = torch.Tensor([]) 208 | 209 | # Invoke C++/CUDA rasterization routine 210 | return rasterize_gaussians( 211 | means3D, 212 | means2D, 213 | shs, 214 | colors_precomp, 215 | opacities, 216 | scales, 217 | rotations, 218 | cov3D_precomp, 219 | raster_settings, 220 | ) 221 | 222 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 18 | m.def("mark_visible", &markVisible); 19 | } -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/rasterize_points.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include "cuda_rasterizer/config.h" 22 | #include "cuda_rasterizer/rasterizer.h" 23 | #include 24 | #include 25 | #include 26 | 27 | std::function resizeFunctional(torch::Tensor& t) { 28 | auto lambda = [&t](size_t N) { 29 | t.resize_({(long long)N}); 30 | return reinterpret_cast(t.contiguous().data_ptr()); 31 | }; 32 | return lambda; 33 | } 34 | 35 | std::tuple 36 | RasterizeGaussiansCUDA( 37 | const torch::Tensor& background, 38 | const torch::Tensor& means3D, 39 | const torch::Tensor& colors, 40 | const torch::Tensor& opacity, 41 | const torch::Tensor& scales, 42 | const torch::Tensor& rotations, 43 | const float scale_modifier, 44 | const torch::Tensor& cov3D_precomp, 45 | const torch::Tensor& viewmatrix, 46 | const torch::Tensor& projmatrix, 47 | const float tan_fovx, 48 | const float tan_fovy, 49 | const int image_height, 50 | const int image_width, 51 | const torch::Tensor& sh, 52 | const int degree, 53 | const torch::Tensor& campos, 54 | const bool prefiltered, 55 | const bool debug) 56 | { 57 | if (means3D.ndimension() != 2 || means3D.size(1) != 3) { 58 | AT_ERROR("means3D must have dimensions (num_points, 3)"); 59 | } 60 | 61 | const int P = means3D.size(0); 62 | const int H = image_height; 63 | const int W = image_width; 64 | 65 | auto int_opts = means3D.options().dtype(torch::kInt32); 66 | auto float_opts = means3D.options().dtype(torch::kFloat32); 67 | 68 | torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); 69 | torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); 70 | torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); 71 | 72 | torch::Device device(torch::kCUDA); 73 | torch::TensorOptions options(torch::kByte); 74 | torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); 75 | torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); 76 | torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); 77 | std::function geomFunc = resizeFunctional(geomBuffer); 78 | std::function binningFunc = resizeFunctional(binningBuffer); 79 | std::function imgFunc = resizeFunctional(imgBuffer); 80 | 81 | int rendered = 0; 82 | if(P != 0) 83 | { 84 | int M = 0; 85 | if(sh.size(0) != 0) 86 | { 87 | M = sh.size(1); 88 | } 89 | 90 | rendered = CudaRasterizer::Rasterizer::forward( 91 | geomFunc, 92 | binningFunc, 93 | imgFunc, 94 | P, degree, M, 95 | background.contiguous().data(), 96 | W, H, 97 | means3D.contiguous().data(), 98 | sh.contiguous().data_ptr(), 
99 | colors.contiguous().data(), 100 | opacity.contiguous().data(), 101 | scales.contiguous().data_ptr(), 102 | scale_modifier, 103 | rotations.contiguous().data_ptr(), 104 | cov3D_precomp.contiguous().data(), 105 | viewmatrix.contiguous().data(), 106 | projmatrix.contiguous().data(), 107 | campos.contiguous().data(), 108 | tan_fovx, 109 | tan_fovy, 110 | prefiltered, 111 | out_color.contiguous().data(), 112 | out_depth.contiguous().data(), 113 | radii.contiguous().data(), 114 | debug); 115 | } 116 | return std::make_tuple(rendered, out_color, out_depth, radii, geomBuffer, binningBuffer, imgBuffer); 117 | } 118 | 119 | std::tuple 120 | RasterizeGaussiansBackwardCUDA( 121 | const torch::Tensor& background, 122 | const torch::Tensor& means3D, 123 | const torch::Tensor& radii, 124 | const torch::Tensor& colors, 125 | const torch::Tensor& scales, 126 | const torch::Tensor& rotations, 127 | const float scale_modifier, 128 | const torch::Tensor& cov3D_precomp, 129 | const torch::Tensor& viewmatrix, 130 | const torch::Tensor& projmatrix, 131 | const float tan_fovx, 132 | const float tan_fovy, 133 | const torch::Tensor& dL_dout_color, 134 | const torch::Tensor& sh, 135 | const int degree, 136 | const torch::Tensor& campos, 137 | const torch::Tensor& geomBuffer, 138 | const int R, 139 | const torch::Tensor& binningBuffer, 140 | const torch::Tensor& imageBuffer, 141 | const bool debug) 142 | { 143 | const int P = means3D.size(0); 144 | const int H = dL_dout_color.size(1); 145 | const int W = dL_dout_color.size(2); 146 | 147 | int M = 0; 148 | if(sh.size(0) != 0) 149 | { 150 | M = sh.size(1); 151 | } 152 | 153 | torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); 154 | torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); 155 | torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); 156 | torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); 157 | torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); 158 | torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); 159 | torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); 160 | torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); 161 | torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); 162 | 163 | if(P != 0) 164 | { 165 | CudaRasterizer::Rasterizer::backward(P, degree, M, R, 166 | background.contiguous().data(), 167 | W, H, 168 | means3D.contiguous().data(), 169 | sh.contiguous().data(), 170 | colors.contiguous().data(), 171 | scales.data_ptr(), 172 | scale_modifier, 173 | rotations.data_ptr(), 174 | cov3D_precomp.contiguous().data(), 175 | viewmatrix.contiguous().data(), 176 | projmatrix.contiguous().data(), 177 | campos.contiguous().data(), 178 | tan_fovx, 179 | tan_fovy, 180 | radii.contiguous().data(), 181 | reinterpret_cast(geomBuffer.contiguous().data_ptr()), 182 | reinterpret_cast(binningBuffer.contiguous().data_ptr()), 183 | reinterpret_cast(imageBuffer.contiguous().data_ptr()), 184 | dL_dout_color.contiguous().data(), 185 | dL_dmeans2D.contiguous().data(), 186 | dL_dconic.contiguous().data(), 187 | dL_dopacity.contiguous().data(), 188 | dL_dcolors.contiguous().data(), 189 | dL_dmeans3D.contiguous().data(), 190 | dL_dcov3D.contiguous().data(), 191 | dL_dsh.contiguous().data(), 192 | dL_dscales.contiguous().data(), 193 | dL_drotations.contiguous().data(), 194 | debug); 195 | } 196 | 197 | return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, 
dL_drotations); 198 | } 199 | 200 | torch::Tensor markVisible( 201 | torch::Tensor& means3D, 202 | torch::Tensor& viewmatrix, 203 | torch::Tensor& projmatrix) 204 | { 205 | const int P = means3D.size(0); 206 | 207 | torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); 208 | 209 | if(P != 0) 210 | { 211 | CudaRasterizer::Rasterizer::markVisible(P, 212 | means3D.contiguous().data(), 213 | viewmatrix.contiguous().data(), 214 | projmatrix.contiguous().data(), 215 | present.contiguous().data()); 216 | } 217 | 218 | return present; 219 | } 220 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | std::tuple 19 | RasterizeGaussiansCUDA( 20 | const torch::Tensor& background, 21 | const torch::Tensor& means3D, 22 | const torch::Tensor& colors, 23 | const torch::Tensor& opacity, 24 | const torch::Tensor& scales, 25 | const torch::Tensor& rotations, 26 | const float scale_modifier, 27 | const torch::Tensor& cov3D_precomp, 28 | const torch::Tensor& viewmatrix, 29 | const torch::Tensor& projmatrix, 30 | const float tan_fovx, 31 | const float tan_fovy, 32 | const int image_height, 33 | const int image_width, 34 | const torch::Tensor& sh, 35 | const int degree, 36 | const torch::Tensor& campos, 37 | const bool prefiltered, 38 | const bool debug); 39 | 40 | std::tuple 41 | RasterizeGaussiansBackwardCUDA( 42 | const torch::Tensor& background, 43 | const torch::Tensor& means3D, 44 | const torch::Tensor& radii, 45 | const torch::Tensor& colors, 46 | const torch::Tensor& scales, 47 | const torch::Tensor& rotations, 48 | const float scale_modifier, 49 | const torch::Tensor& cov3D_precomp, 50 | const torch::Tensor& viewmatrix, 51 | const torch::Tensor& projmatrix, 52 | const float tan_fovx, 53 | const float tan_fovy, 54 | const torch::Tensor& dL_dout_color, 55 | const torch::Tensor& sh, 56 | const int degree, 57 | const torch::Tensor& campos, 58 | const torch::Tensor& geomBuffer, 59 | const int R, 60 | const torch::Tensor& binningBuffer, 61 | const torch::Tensor& imageBuffer, 62 | const bool debug); 63 | 64 | torch::Tensor markVisible( 65 | torch::Tensor& means3D, 66 | torch::Tensor& viewmatrix, 67 | torch::Tensor& projmatrix); 68 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | os.path.dirname(os.path.abspath(__file__)) 16 | 17 | setup( 18 | name="diff_gaussian_rasterization", 19 | packages=['diff_gaussian_rasterization'], 20 | ext_modules=[ 21 | CUDAExtension( 22 | name="diff_gaussian_rasterization._C", 23 | sources=[ 24 | "cuda_rasterizer/rasterizer_impl.cu", 25 | "cuda_rasterizer/forward.cu", 26 | "cuda_rasterizer/backward.cu", 27 | "rasterize_points.cu", 28 | "ext.cpp"], 29 | extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 30 | ], 31 | cmdclass={ 32 | 'build_ext': BuildExtension 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/third_party/glm.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglids/ST-4DGS/bf0dbb13e76bf41b2c2a4ca64063e5d346db7c74/submodules/depth-diff-gaussian-rasterization/third_party/glm.zip -------------------------------------------------------------------------------- /submodules/simple-knn/dist/simple_knn-0.0.0-py3.9-win-amd64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglids/ST-4DGS/bf0dbb13e76bf41b2c2a4ca64063e5d346db7c74/submodules/simple-knn/dist/simple_knn-0.0.0-py3.9-win-amd64.egg -------------------------------------------------------------------------------- /submodules/simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /submodules/simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == 'nt': 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=[ 27 | "spatial.cu", 28 | "simple_knn.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #define __CUDACC__ 24 | #include 25 | #include 26 | 27 | namespace cg = cooperative_groups; 28 | 29 | struct CustomMin 30 | { 31 | __device__ __forceinline__ 32 | float3 operator()(const float3& a, const float3& b) const { 33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 34 | } 35 | }; 36 | 37 | struct CustomMax 38 | { 39 | __device__ __forceinline__ 40 | float3 operator()(const float3& a, const float3& b) const { 41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 42 | } 43 | }; 44 | 45 | __host__ __device__ uint32_t prepMorton(uint32_t x) 46 | { 47 | x = (x | (x << 16)) & 0x030000FF; 48 | x = (x | (x << 8)) & 0x0300F00F; 49 | x = (x | (x << 4)) & 0x030C30C3; 50 | x = (x | (x << 2)) & 0x09249249; 51 | return x; 52 | } 53 | 54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 55 | { 56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 59 | 60 | return x | (y << 1) | (z << 2); 61 | } 62 | 63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 64 | { 65 | auto idx = cg::this_grid().thread_rank(); 66 | if (idx >= P) 67 | return; 68 | 69 | codes[idx] = coord2Morton(points[idx], minn, maxx); 70 | } 71 | 72 | struct MinMax 73 | { 74 | float3 minn; 75 | float3 maxx; 76 | }; 77 | 78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | 82 | MinMax me; 83 | if (idx < P) 84 | { 85 | me.minn = points[indices[idx]]; 86 | me.maxx = points[indices[idx]]; 87 | } 88 | else 89 | { 90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 92 | } 93 | 94 | __shared__ MinMax redResult[BOX_SIZE]; 95 | 96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 97 | { 98 | if (threadIdx.x < 2 * off) 99 | redResult[threadIdx.x] = me; 100 | __syncthreads(); 101 | 102 | if (threadIdx.x < off) 103 | { 104 | MinMax other = 
redResult[threadIdx.x + off]; 105 | me.minn.x = min(me.minn.x, other.minn.x); 106 | me.minn.y = min(me.minn.y, other.minn.y); 107 | me.minn.z = min(me.minn.z, other.minn.z); 108 | me.maxx.x = max(me.maxx.x, other.maxx.x); 109 | me.maxx.y = max(me.maxx.y, other.maxx.y); 110 | me.maxx.z = max(me.maxx.z, other.maxx.z); 111 | } 112 | __syncthreads(); 113 | } 114 | 115 | if (threadIdx.x == 0) 116 | boxes[blockIdx.x] = me; 117 | } 118 | 119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 120 | { 121 | float3 diff = { 0, 0, 0 }; 122 | if (p.x < box.minn.x || p.x > box.maxx.x) 123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 124 | if (p.y < box.minn.y || p.y > box.maxx.y) 125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 126 | if (p.z < box.minn.z || p.z > box.maxx.z) 127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 129 | } 130 | 131 | template 132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 133 | { 134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 136 | for (int j = 0; j < K; j++) 137 | { 138 | if (knn[j] > dist) 139 | { 140 | float t = knn[j]; 141 | knn[j] = dist; 142 | dist = t; 143 | } 144 | } 145 | } 146 | 147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 148 | { 149 | int idx = cg::this_grid().thread_rank(); 150 | if (idx >= P) 151 | return; 152 | 153 | float3 point = points[indices[idx]]; 154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 155 | 156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 157 | { 158 | if (i == idx) 159 | continue; 160 | updateKBest<3>(point, points[indices[i]], best); 161 | } 162 | 163 | float reject = best[2]; 164 | best[0] = FLT_MAX; 165 | best[1] = FLT_MAX; 166 | best[2] = FLT_MAX; 167 | 168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 169 | { 170 | MinMax box = boxes[b]; 171 | float dist = distBoxPoint(box, point); 172 | if (dist > reject || dist > best[2]) 173 | continue; 174 | 175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 176 | { 177 | if (i == idx) 178 | continue; 179 | updateKBest<3>(point, points[indices[i]], best); 180 | } 181 | } 182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 183 | } 184 | 185 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 186 | { 187 | float3* result; 188 | cudaMalloc(&result, sizeof(float3)); 189 | size_t temp_storage_bytes; 190 | 191 | float3 init = { 0, 0, 0 }, minn, maxx; 192 | 193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 194 | thrust::device_vector temp_storage(temp_storage_bytes); 195 | 196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 198 | 199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 201 | 202 | thrust::device_vector morton(P); 203 | thrust::device_vector morton_sorted(P); 204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 205 | 206 | thrust::device_vector indices(P); 207 | thrust::sequence(indices.begin(), indices.end()); 208 | thrust::device_vector 
indices_sorted(P); 209 | 210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 211 | temp_storage.resize(temp_storage_bytes); 212 | 213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 214 | 215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 216 | thrust::device_vector boxes(num_boxes); 217 | boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); 218 | boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 219 | 220 | cudaFree(result); 221 | } -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: simple-knn 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | ext.cpp 2 | setup.py 3 | simple_knn.cu 4 | spatial.cu 5 | simple_knn.egg-info/PKG-INFO 6 | simple_knn.egg-info/SOURCES.txt 7 | simple_knn.egg-info/dependency_links.txt 8 | simple_knn.egg-info/top_level.txt -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | simple_knn 2 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglids/ST-4DGS/bf0dbb13e76bf41b2c2a4ca64063e5d346db7c74/submodules/simple-knn/simple_knn/.gitkeep -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python train.py --source_path /lideqi/gaussian/data/cut_roasted_beef_temp --model_path output/Ours_code_test --configs arguments/DyNeRF.py \ 2 | --expname "dynerf/Ours" --port 6037 3 | 4 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | from scene.cameras import Camera 2 | import numpy as np 3 | from utils.general_utils import PILtoTorch 4 | from utils.graphics_utils import fov2focal 5 | 6 | WARNED = False 7 | 8 | def loadCam(args, id, cam_info, resolution_scale): 9 | 10 | 11 | # resized_image_rgb = PILtoTorch(cam_info.image, resolution) 12 | 13 | # gt_image = resized_image_rgb[:3, ...] 14 | # loaded_mask = None 15 | 16 | # if resized_image_rgb.shape[1] == 4: 17 | # loaded_mask = resized_image_rgb[3:4, ...] 
18 | 19 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 20 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 21 | image=cam_info.image,flow=cam_info.flow,focal=cam_info.focal, gt_alpha_mask=None, 22 | image_name=cam_info.image_name, uid=id, data_device=args.data_device, 23 | time = cam_info.time,per_time=cam_info.per_time, 24 | ) 25 | 26 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 27 | camera_list = [] 28 | 29 | for id, c in enumerate(cam_infos): 30 | camera_list.append(loadCam(args, id, c, resolution_scale)) 31 | 32 | return camera_list 33 | 34 | def camera_to_JSON(id, camera : Camera): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = camera.R.transpose() 37 | Rt[:3, 3] = camera.T 38 | Rt[3, 3] = 1.0 39 | 40 | W2C = np.linalg.inv(Rt) 41 | pos = W2C[:3, 3] 42 | rot = W2C[:3, :3] 43 | serializable_array_2d = [x.tolist() for x in rot] 44 | camera_entry = { 45 | 'id' : id, 46 | 'img_name' : camera.image_name, 47 | 'width' : camera.width, 48 | 'height' : camera.height, 49 | 'position': pos.tolist(), 50 | 'rotation': serializable_array_2d, 51 | 'fy' : fov2focal(camera.FovY, camera.height), 52 | 'fx' : fov2focal(camera.FovX, camera.width) 53 | } 54 | return camera_entry 55 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | 7 | def inverse_sigmoid(x): 8 | return torch.log(x/(1-x)) 9 | 10 | def PILtoTorch(pil_image, resolution): 11 | if resolution is not None: 12 | resized_image_PIL = pil_image.resize(resolution) 13 | else: 14 | resized_image_PIL = pil_image 15 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 16 | if len(resized_image.shape) == 3: 17 | return resized_image.permute(2, 0, 1) 18 | else: 19 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 20 | 21 | def get_expon_lr_func( 22 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 23 | ): 24 | """ 25 | Copied from Plenoxels 26 | 27 | Continuous learning rate decay function. Adapted from JaxNeRF 28 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 29 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 30 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 31 | function of lr_delay_mult, such that the initial learning rate is 32 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 33 | to the normal learning rate when steps>lr_delay_steps. 34 | :param conf: config subtree 'lr' or similar 35 | :param max_steps: int, the number of steps during optimization. 36 | :return HoF which takes step as input 37 | """ 38 | 39 | def helper(step): 40 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 41 | # Disable this parameter 42 | return 0.0 43 | if lr_delay_steps > 0: 44 | # A kind of reverse cosine decay. 
45 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 46 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 47 | ) 48 | else: 49 | delay_rate = 1.0 50 | t = np.clip(step / max_steps, 0, 1) 51 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 52 | return delay_rate * log_lerp 53 | 54 | return helper 55 | 56 | def strip_lowerdiag(L): 57 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 58 | 59 | uncertainty[:, 0] = L[:, 0, 0] 60 | uncertainty[:, 1] = L[:, 0, 1] 61 | uncertainty[:, 2] = L[:, 0, 2] 62 | uncertainty[:, 3] = L[:, 1, 1] 63 | uncertainty[:, 4] = L[:, 1, 2] 64 | uncertainty[:, 5] = L[:, 2, 2] 65 | return uncertainty 66 | 67 | def strip_symmetric(sym): 68 | return strip_lowerdiag(sym) 69 | 70 | def build_rotation(r): 71 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 72 | 73 | q = r / norm[:, None] 74 | 75 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 76 | 77 | r = q[:, 0] 78 | x = q[:, 1] 79 | y = q[:, 2] 80 | z = q[:, 3] 81 | 82 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 83 | R[:, 0, 1] = 2 * (x*y - r*z) 84 | R[:, 0, 2] = 2 * (x*z + r*y) 85 | R[:, 1, 0] = 2 * (x*y + r*z) 86 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 87 | R[:, 1, 2] = 2 * (y*z - r*x) 88 | R[:, 2, 0] = 2 * (x*z - r*y) 89 | R[:, 2, 1] = 2 * (y*z + r*x) 90 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 91 | return R 92 | 93 | def build_scaling_rotation(s, r): 94 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 95 | R = build_rotation(r) 96 | 97 | L[:,0,0] = s[:,0] 98 | L[:,1,1] = s[:,1] 99 | L[:,2,2] = s[:,2] 100 | 101 | L = R @ L 102 | return L 103 | 104 | def safe_state(silent): 105 | old_f = sys.stdout 106 | class F: 107 | def __init__(self, silent): 108 | self.silent = silent 109 | 110 | def write(self, x): 111 | if not self.silent: 112 | if x.endswith("\n"): 113 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 114 | else: 115 | old_f.write(x) 116 | 117 | def flush(self): 118 | old_f.flush() 119 | 120 | sys.stdout = F(silent) 121 | 122 | random.seed(0) 123 | np.random.seed(0) 124 | torch.manual_seed(0) 125 | torch.cuda.set_device(torch.device("cuda:0")) 126 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from typing import NamedTuple 5 | 6 | class BasicPointCloud(NamedTuple): 7 | points : np.array 8 | colors : np.array 9 | normals : np.array 10 | 11 | def geom_transform_points(points, transf_matrix): 12 | P, _ = points.shape 13 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 14 | points_hom = torch.cat([points, ones], dim=1) 15 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 16 | 17 | denom = points_out[..., 3:] + 0.0000001 18 | return (points_out[..., :3] / denom).squeeze(dim=0) 19 | 20 | def getWorld2View(R, t): 21 | Rt = np.zeros((4, 4)) 22 | Rt[:3, :3] = R.transpose() 23 | Rt[:3, 3] = t 24 | Rt[3, 3] = 1.0 25 | return np.float32(Rt) 26 | 27 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 28 | Rt = np.zeros((4, 4)) 29 | Rt[:3, :3] = R.transpose() 30 | Rt[:3, 3] = t 31 | Rt[3, 3] = 1.0 32 | 33 | C2W = np.linalg.inv(Rt) 34 | cam_center = C2W[:3, 3] 35 | cam_center = (cam_center + translate) * scale 36 | C2W[:3, 3] = cam_center 37 | Rt = np.linalg.inv(C2W) 38 | return np.float32(Rt) 39 | 40 | def 
getProjectionMatrix(znear, zfar, fovX, fovY): 41 | tanHalfFovY = math.tan((fovY / 2)) 42 | tanHalfFovX = math.tan((fovX / 2)) 43 | 44 | top = tanHalfFovY * znear 45 | bottom = -top 46 | right = tanHalfFovX * znear 47 | left = -right 48 | 49 | P = torch.zeros(4, 4) 50 | 51 | z_sign = 1.0 52 | 53 | P[0, 0] = 2.0 * znear / (right - left) 54 | P[1, 1] = 2.0 * znear / (top - bottom) 55 | P[0, 2] = (right + left) / (right - left) 56 | P[1, 2] = (top + bottom) / (top - bottom) 57 | P[3, 2] = z_sign 58 | P[2, 2] = z_sign * zfar / (zfar - znear) 59 | P[2, 3] = -(zfar * znear) / (zfar - znear) 60 | return P 61 | 62 | def fov2focal(fov, pixels): 63 | return pixels / (2 * math.tan(fov / 2)) 64 | 65 | def focal2fov(focal, pixels): 66 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse(img1, img2): 4 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 5 | 6 | @torch.no_grad() 7 | def psnr(img1, img2): 8 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 9 | return 20 * torch.log10(1.0 / torch.sqrt(mse+1e-12)) 10 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | from math import exp 5 | 6 | def lpips_loss(img1, img2, lpips_model): 7 | loss = lpips_model(img1,img2) 8 | return loss.mean() 9 | def l1_loss(network_output, gt): 10 | return torch.abs((network_output - gt)).mean() 11 | 12 | 13 | def weighted_l2_loss_v2(x, y, w): 14 | return torch.sqrt(((x - y) ** 2).sum(-1) * w + 1e-8).mean() 15 | 16 | 17 | def l2_loss(network_output, gt): 18 | return ((network_output - gt) ** 2).mean() 19 | 20 | def gaussian(window_size, sigma): 21 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 22 | return gauss / gauss.sum() 23 | 24 | def create_window(window_size, channel): 25 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 26 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 27 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 28 | return window 29 | 30 | def ssim(img1, img2, window_size=11, size_average=True): 31 | channel = img1.size(-3) 32 | window = create_window(window_size, channel) 33 | 34 | if img1.is_cuda: 35 | # window = window.cuda(img1.get_device()) 36 | window = window.to(img1.device) 37 | window = window.type_as(img1) 38 | 39 | return _ssim(img1, img2, window, window_size, channel, size_average) 40 | 41 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 42 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 43 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 44 | 45 | mu1_sq = mu1.pow(2) 46 | mu2_sq = mu2.pow(2) 47 | mu1_mu2 = mu1 * mu2 48 | 49 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 50 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 51 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 52 | 53 | C1 = 0.01 ** 2 54 | C2 = 0.03 ** 2 55 | 56 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * 
sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 57 | 58 | if size_average: 59 | return ssim_map.mean() 60 | else: 61 | return ssim_map.mean(1).mean(1).mean(1) 62 | 63 | 64 | def scale_loss(point_scale,thr): 65 | 66 | return torch.mean( 67 | torch.max((torch.max(torch.abs(point_scale), dim=1).values / torch.min(torch.abs(point_scale), dim=1).values), torch.tensor([thr],device="cuda")) - thr) 68 | -------------------------------------------------------------------------------- /utils/params_utils.py: -------------------------------------------------------------------------------- 1 | def merge_hparams(args, config): 2 | params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"] 3 | for param in params: 4 | if param in config.keys(): 5 | for key, value in config[param].items(): 6 | if hasattr(args, key): 7 | setattr(args, key, value) 8 | 9 | return args -------------------------------------------------------------------------------- /utils/scene_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image, ImageDraw, ImageFont 4 | from matplotlib import pyplot as plt 5 | plt.rcParams['font.sans-serif'] = ['Times New Roman'] 6 | 7 | import numpy as np 8 | 9 | import copy 10 | @torch.no_grad() 11 | def render_training_image(scene, gaussians, viewpoints, render_func, pipe, background, stage, iteration, time_now): 12 | def render(gaussians, viewpoint, path, scaling): 13 | # scaling_copy = gaussians._scaling 14 | render_pkg = render_func(viewpoint, gaussians, pipe, background, stage=stage) 15 | label1 = f"stage:{stage},iter:{iteration}" 16 | times = time_now/60 17 | if times < 1: 18 | end = "min" 19 | else: 20 | end = "mins" 21 | label2 = "time:%.2f" % times + end 22 | image = render_pkg["render"] 23 | depth = render_pkg["depth"] 24 | image_np = image.permute(1, 2, 0).cpu().numpy() # convert channel order to (H, W, 3) 25 | depth_np = depth.permute(1, 2, 0).cpu().numpy() 26 | depth_np /= depth_np.max() 27 | depth_np = np.repeat(depth_np, 3, axis=2) 28 | image_np = np.concatenate((image_np, depth_np), axis=1) 29 | image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8')) # convert to an 8-bit image 30 | # create a drawing context on the PIL image to add the labels 31 | draw1 = ImageDraw.Draw(image_with_labels) 32 | 33 | # choose the font and font size 34 | font = ImageFont.truetype('./utils/TIMES.TTF', size=40) # replace the path with the font file of your choice 35 | 36 | # choose the text color 37 | text_color = (255, 0, 0) # red 38 | 39 | # label positions (label1 at the top-left corner) 40 | label1_position = (10, 10) 41 | label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) # top-right corner 42 | 43 | # draw the labels on the image 44 | draw1.text(label1_position, label1, fill=text_color, font=font) 45 | draw1.text(label2_position, label2, fill=text_color, font=font) 46 | 47 | image_with_labels.save(path) 48 | render_base_path = os.path.join(scene.model_path, f"{stage}_render") 49 | point_cloud_path = os.path.join(render_base_path,"pointclouds") 50 | image_path = os.path.join(render_base_path,"images") 51 | if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")): 52 | os.makedirs(render_base_path) 53 | if not os.path.exists(point_cloud_path): 54 | os.makedirs(point_cloud_path) 55 | if not os.path.exists(image_path): 56 | os.makedirs(image_path) 57 | # image:3,800,800 58 | 59 | # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg") 60 | for idx in range(len(viewpoints)): 61 | image_save_path = os.path.join(image_path,f"{iteration}_{idx}.jpg") 62 | 
render(gaussians,viewpoints[idx],image_save_path,scaling = 1) 63 | # render(gaussians,point_save_path,scaling = 0.1) 64 | # save the labeled image 65 | 66 | 67 | 68 | pc_mask = gaussians.get_opacity 69 | pc_mask = pc_mask > 0.1 70 | xyz = gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1,0).numpy() 71 | # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path) 72 | # if needed, the PIL image can be converted back to a PyTorch tensor 73 | # return image 74 | # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0 75 | def visualize_and_save_point_cloud(point_cloud, R, T, filename): 76 | # create a 3D scatter plot 77 | fig = plt.figure() 78 | ax = fig.add_subplot(111, projection='3d') 79 | R = R.T 80 | # apply the rotation and translation transform 81 | T = -R.dot(T) 82 | transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1) 83 | # pcd = o3d.geometry.PointCloud() 84 | # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T) # transpose the point cloud data to match Open3D's format 85 | # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:] 86 | # visualize the point cloud 87 | ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c='g', marker='o') 88 | ax.axis("off") 89 | # ax.set_xlabel('X Label') 90 | # ax.set_ylabel('Y Label') 91 | # ax.set_zlabel('Z Label') 92 | 93 | # save the rendering as an image 94 | plt.savefig(filename) 95 | 96 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | from errno import EEXIST 2 | from os import makedirs, path 3 | import os 4 | 5 | def mkdir_p(folder_path): 6 | # Creates a directory. 
equivalent to using mkdir -p on the command line 7 | try: 8 | makedirs(folder_path) 9 | except OSError as exc: # Python >2.5 10 | if exc.errno == EEXIST and path.isdir(folder_path): 11 | pass 12 | else: 13 | raise 14 | 15 | def searchForMaxIteration(folder): 16 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 17 | return max(saved_iters) 18 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | class Timer: 3 | def __init__(self): 4 | self.start_time = None 5 | self.elapsed = 0 6 | self.paused = False 7 | 8 | def start(self): 9 | if self.start_time is None: 10 | self.start_time = time.time() 11 | elif self.paused: 12 | self.start_time = time.time() - self.elapsed 13 | self.paused = False 14 | 15 | def pause(self): 16 | if not self.paused: 17 | self.elapsed = time.time() - self.start_time 18 | self.paused = True 19 | 20 | def get_elapsed_time(self): 21 | if self.paused: 22 | return self.elapsed 23 | else: 24 | return time.time() - self.start_time --------------------------------------------------------------------------------
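
For reference, the schedule returned by get_expon_lr_func in utils/general_utils.py is log-linear in the training step: the log learning rate is interpolated from log(lr_init) at step 0 to log(lr_final) at max_steps, with an optional sine-shaped warm-up controlled by lr_delay_steps and lr_delay_mult. A minimal sketch of how it behaves, assuming the repository root is on PYTHONPATH; the numbers below are illustrative only, not necessarily the settings used by this project's configs:

from utils.general_utils import get_expon_lr_func

# Illustrative hyperparameters (assumed for this sketch, not taken from the repo's configs).
lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30000)

print(lr_fn(0))      # 1.6e-4 at the first step
print(lr_fn(15000))  # ~1.6e-5: the geometric mean of lr_init and lr_final at the halfway point
print(lr_fn(30000))  # 1.6e-6 at the final step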