├── LICENSE.md ├── README.md ├── arguments ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc └── endonerf │ └── default.py ├── assets ├── Tab1.png ├── Tab2-padding.png ├── Tab2.png ├── demo1.mp4 ├── demo_scene.mp4 ├── overview.png └── visual_results.png ├── gaussian_renderer ├── __init__.py └── __pycache__ │ ├── __init__.cpython-37.pyc │ └── network_gui.cpython-37.pyc ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── requirements.txt ├── scene ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── cameras.cpython-37.pyc │ ├── colmap_loader.cpython-37.pyc │ ├── dataset.cpython-37.pyc │ ├── dataset_readers.cpython-37.pyc │ ├── deformation.cpython-37.pyc │ ├── endo_loader.cpython-37.pyc │ ├── flexible_deform_model.cpython-37.pyc │ ├── gaussian_flow_model.cpython-37.pyc │ ├── gaussian_gaussian_model.cpython-37.pyc │ ├── gaussian_model.cpython-37.pyc │ ├── hexplane.cpython-37.pyc │ ├── hyper_loader.cpython-37.pyc │ ├── regulation.cpython-37.pyc │ └── utils.cpython-37.pyc ├── cameras.py ├── dataset_readers.py ├── endo_loader.py ├── flexible_deform_model.py ├── regulation.py └── utils.py ├── stereomis2endonerf.py ├── submodules ├── RAFT │ ├── .gitignore │ ├── LICENSE │ ├── RAFT.png │ ├── README.md │ ├── alt_cuda_corr │ │ ├── correlation.cpp │ │ ├── correlation_kernel.cu │ │ └── setup.py │ ├── chairs_split.txt │ ├── core │ │ ├── __init__.py │ │ ├── corr.py │ │ ├── datasets.py │ │ ├── extractor.py │ │ ├── raft.py │ │ ├── update.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── augmentor.py │ │ │ ├── flow_viz.py │ │ │ ├── frame_utils.py │ │ │ └── utils.py │ ├── demo-frames │ │ ├── frame_0016.png │ │ ├── frame_0017.png │ │ ├── frame_0018.png │ │ ├── frame_0019.png │ │ ├── frame_0020.png │ │ ├── frame_0021.png │ │ ├── frame_0022.png │ │ ├── frame_0023.png │ │ ├── frame_0024.png │ │ └── frame_0025.png │ ├── demo.py │ ├── download_models.sh │ ├── evaluate.py │ ├── pretrained │ │ └── raft-things.pth │ ├── train.py │ ├── train_mixed.sh │ └── train_standard.sh ├── depth-diff-gaussian-rasterization │ ├── .gitignore │ ├── .gitmodules │ ├── CMakeLists.txt │ ├── LICENSE.md │ ├── README.md │ ├── cuda_rasterizer │ │ ├── auxiliary.h │ │ ├── backward.cu │ │ ├── backward.h │ │ ├── config.h │ │ ├── forward.cu │ │ ├── forward.h │ │ ├── rasterizer.h │ │ ├── rasterizer_impl.cu │ │ └── rasterizer_impl.h │ ├── diff_gaussian_rasterization │ │ ├── _C.cpython-37m-x86_64-linux-gnu.so │ │ ├── _C.cpython-38-x86_64-linux-gnu.so │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ └── __init__.cpython-37.pyc │ ├── ext.cpp │ ├── rasterize_points.cu │ ├── rasterize_points.h │ ├── setup.py │ └── third_party │ │ └── stbi_image_write.h └── simple-knn │ ├── build │ ├── lib.linux-x86_64-cpython-37 │ │ └── simple_knn │ │ │ └── _C.cpython-37m-x86_64-linux-gnu.so │ └── temp.linux-x86_64-cpython-37 │ │ ├── ext.o │ │ ├── simple_knn.o │ │ └── spatial.o │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn.cu │ ├── simple_knn.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt │ ├── simple_knn.h │ ├── simple_knn │ ├── .gitkeep │ ├── _C.cpython-37m-x86_64-linux-gnu.so │ └── _C.cpython-38-x86_64-linux-gnu.so │ ├── spatial.cu │ └── spatial.h ├── train.py └── utils ├── TIMES.TTF ├── TIMESBD.TTF ├── TIMESBI.TTF ├── TIMESI.TTF ├── __pycache__ ├── camera_utils.cpython-37.pyc ├── general_utils.cpython-37.pyc ├── graphics_utils.cpython-37.pyc ├── image_utils.cpython-37.pyc ├── loss_utils.cpython-37.pyc ├── 
params_utils.cpython-37.pyc ├── scene_utils.cpython-37.pyc ├── sh_utils.cpython-37.pyc ├── system_utils.cpython-37.pyc └── timer.cpython-37.pyc ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── params_utils.py ├── scene_utils.py ├── sh_utils.py ├── stereo_rectify.py ├── system_utils.py └── timer.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 
63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deform3DGS: Flexible Deformation for Fast Surgical Scene Reconstruction with Gaussian Splatting 2 | 3 | Official code implementation for [Deform3DGS](https://arxiv.org/abs/2405.17835), a Gaussian Splatting-based framework for surgical scene reconstruction. 4 | 5 | 6 | 7 | > [Deform3DGS: Flexible Deformation for Fast Surgical Scene Reconstruction with Gaussian Splatting](https://arxiv.org/abs/2405.17835)\ 8 | > Shuojue Yang, Qian Li, Daiyun Shen, Bingchen Gong, Qi Dou, Yueming Jin\ 9 | > MICCAI 2024, **Early Accept** 10 | 11 | ## Demo 12 | 13 | ### Reconstruction within 1 minute 14 | 15 | 16 | 17 | https://github.com/jinlab-imvr/Deform3DGS/assets/157268160/7609bfb6-9130-488f-b893-85cc82d60d63 18 | 19 | Compared to the previous SOTA method for fast reconstruction, our method reduces the training time to **1 minute** per clip in the EndoNeRF dataset, demonstrating a remarkable advantage in efficiency. 20 | 21 | ### Reconstruction of various scenes 22 | 23 | 27 | 28 | https://github.com/jinlab-imvr/Deform3DGS/assets/157268160/633777fa-9110-4823-b6e5-f5d338e72551 29 | 30 | ## Pipeline 31 | 32 | 33 | 

35 | 36 |

37 | 38 | **Deform3DGS** is composed of (a) Point cloud initialization, (b) Flexible Deformation Modeling, and (c) 3D Gaussian Splatting. Experiments on DaVinci robotic surgery videos indicate the efficacy of our approach, showcasing superior reconstruction fidelity (PSNR: 37.90) and rendering speed (338.8 FPS) while substantially reducing training time to only 1 minute/scene. 39 | 40 | 44 | 45 | ## Environment setup 46 | 47 | Tested with an NVIDIA RTX A5000 GPU. 48 | 49 | ```bash 50 | git clone https://github.com/jinlab-imvr/Deform3DGS.git 51 | cd Deform3DGS 52 | conda create -n Deform3DGS python=3.7 53 | conda activate Deform3DGS 54 | 55 | pip install -r requirements.txt 56 | pip install -e submodules/depth-diff-gaussian-rasterization 57 | pip install -e submodules/simple-knn 58 | ``` 59 | 60 | ## Datasets 61 | 62 | We use 6 clips from [EndoNeRF](https://github.com/med-air/EndoNeRF) and 3 clips manually extracted from [StereoMIS](https://zenodo.org/records/7727692) to verify our method. 63 | 64 | To use the two available examples from the [EndoNeRF](https://github.com/med-air/EndoNeRF) dataset, please download the data via [this link](https://forms.gle/1VAqDJTEgZduD6157) and organize the data according to the [guideline](https://github.com/med-air/EndoNeRF.git). 65 | 66 | To use the [StereoMIS](https://zenodo.org/records/7727692) dataset, please follow this [GitHub repo](https://github.com/aimi-lab/robust-pose-estimator) to preprocess the dataset. After that, run the provided script `stereomis2endonerf.py` to extract clips from the StereoMIS dataset and organize the depth maps, masks, images, and intrinsic and extrinsic parameters in the same format as [EndoNeRF](https://github.com/med-air/EndoNeRF). In our implementation, we used [RAFT](https://github.com/princeton-vl/RAFT) to estimate the stereo depth for [StereoMIS](https://zenodo.org/records/7727692) clips. Following the EndoNeRF dataset, this script only supports fixed-view settings. 67 | 68 | 69 | 70 | The data structure is as follows: 71 | 72 | ``` 73 | data 74 | | - endonerf_full_datasets 75 | | | - cutting_tissues_twice 76 | | | | - depth/ 77 | | | | - images/ 78 | | | | - masks/ 79 | | | | - poses_bounds.npy 80 | | | - pushing_soft_tissues 81 | | - StereoMIS 82 | | | - stereo_seq_1 83 | | | - stereo_seq_2 84 | ``` 85 | 86 | ## Training 87 | 88 | To train Deform3DGS with customized hyper-parameters, please make changes in `arguments/endonerf/default.py`. 89 | 90 | To train Deform3DGS, run the following example command: 91 | 92 | ``` 93 | python train.py -s data/endonerf_full_datasets/pulling_soft_tissues --expname endonerf/pulling_fdm --configs arguments/endonerf/default.py 94 | ``` 95 | 96 | ## Testing 97 | 98 | For testing, we perform rendering and evaluation separately. 99 | 100 | ### Rendering 101 | 102 | Run the following example command to render the images: 103 | 104 | ``` 105 | python render.py --model_path output/endonerf/pulling_fdm --skip_train --reconstruct_test --configs arguments/endonerf/default.py 106 | ``` 107 | 108 | Please follow [EndoGaussian](https://github.com/yifliu3/EndoGaussian/tree/master) to skip rendering. Of note, you can also set `--reconstruct_train`, `--reconstruct_test`, and `--reconstruct_video` to reconstruct and save the `.ply` 3D point cloud of the rendered outputs for the `train`, `test`, and `video` sets, respectively. 
109 | 110 | ### Evaluation 111 | 112 | To evaluate the reconstruction quality, run the following command: 113 | 114 | ``` 115 | python metrics.py --model_path output/endonerf/pulling_fdm -p test 116 | ``` 117 | 118 | Note that you can set `-p video`, `-p test`, or `-p train` to select the set for evaluation. 119 | 120 | ## Acknowledgements 121 | 122 | This repo borrows some source code from [EndoGaussian](https://github.com/yifliu3/EndoGaussian/tree/master), [4DGS](https://github.com/hustvl/4DGaussians), [depth-diff-gaussian-rasterizer](https://github.com/ingra14m/depth-diff-gaussian-rasterization), [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), and [EndoNeRF](https://github.com/med-air/EndoNeRF). We would like to acknowledge these great prior works for inspiring ours. 123 | 124 | Thanks to [EndoGaussian](https://github.com/yifliu3/EndoGaussian/tree/master) for their great and timely effort in releasing a framework that adapts Gaussian Splatting to surgical scenes. 125 | 126 | ## Citation 127 | 128 | If you find this code useful for your research, please use the following BibTeX entry: 129 | 130 | ``` 131 | @misc{yang2024deform3dgs, 132 | title={Deform3DGS: Flexible Deformation for Fast Surgical Scene Reconstruction with Gaussian Splatting}, 133 | author={Shuojue Yang and Qian Li and Daiyun Shen and Bingchen Gong and Qi Dou and Yueming Jin}, 134 | year={2024}, 135 | eprint={2405.17835}, 136 | archivePrefix={arXiv}, 137 | primaryClass={cs.CV} 138 | } 139 | 140 | ``` 141 | 142 | ### Questions 143 | 144 | For further questions about the code or paper, feel free to open an issue or contact 's.yang@u.nus.edu' or 'ymjin@nus.edu.sg'. 145 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = False 55 | self.data_device = "cuda" 56 | self.eval = True 57 | self.render_process=False 58 | self.extra_mark = None 59 | self.camera_extent = None 60 | super().__init__(parser, "Loading Parameters", sentinel) 61 | 62 | def extract(self, args): 63 | g = super().extract(args) 64 | g.source_path = os.path.abspath(g.source_path) 65 | return g 66 | 67 | class PipelineParams(ParamGroup): 68 | def __init__(self, parser): 69 | self.convert_SHs_python = False 70 | self.compute_cov3D_python = False 71 | self.debug = False 72 | super().__init__(parser, "Pipeline Parameters") 73 | 74 | 75 | class FDMHiddenParams(ParamGroup): 76 | def __init__(self, parser): 77 | self.net_width = 64 78 | self.timebase_pe = 4 79 | self.defor_depth = 1 80 | self.posebase_pe = 10 81 | self.scale_rotation_pe = 2 82 | self.opacity_pe = 2 83 | self.timenet_width = 64 84 | self.timenet_output = 32 85 | self.bounds = 1.6 86 | 87 | self.ch_num = 10 88 | self.curve_num = 17 89 | self.init_param = 0.01 90 | 91 | super().__init__(parser, "FDMHiddenParams") 92 | 93 | 94 | class OptimizationParams(ParamGroup): 95 | def __init__(self, parser): 96 | self.dataloader=False 97 | self.iterations = 30_000 98 | self.position_lr_init = 0.00016 99 | self.position_lr_final = 0.0000016 100 | self.position_lr_delay_mult = 0.01 101 | self.position_lr_max_steps = 20_000 102 | self.deformation_lr_init = 0.00016 103 | self.deformation_lr_final = 0.000016 104 | self.deformation_lr_delay_mult = 0.01 105 | # self.grid_lr_init = 0.0016 106 | # self.grid_lr_final = 0.00016 107 | 108 | self.feature_lr = 0.0025 109 | self.opacity_lr = 0.05 110 | self.scaling_lr = 0.005 111 | self.rotation_lr = 0.001 112 | self.percent_dense = 0.01 113 | self.weight_constraint_init= 1 114 | self.weight_constraint_after = 0.2 115 | self.weight_decay_iteration = 5000 116 | self.opacity_reset_interval = 3000 117 | self.densification_interval = 100 118 | self.densify_from_iter = 500 119 | self.densify_until_iter = 15_000 120 | self.densify_grad_threshold_coarse = 0.0002 121 | self.densify_grad_threshold_fine_init = 0.0002 122 | self.densify_grad_threshold_after = 0.0002 123 | 
self.pruning_from_iter = 500 124 | self.pruning_interval = 100 125 | self.opacity_threshold_coarse = 0.005 126 | self.opacity_threshold_fine_init = 0.005 127 | self.opacity_threshold_fine_after = 0.005 128 | 129 | super().__init__(parser, "Optimization Parameters") 130 | 131 | def get_combined_args(parser : ArgumentParser): 132 | cmdlne_string = sys.argv[1:] 133 | cfgfile_string = "Namespace()" 134 | args_cmdline = parser.parse_args(cmdlne_string) 135 | 136 | try: 137 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 138 | print("Looking for config file in", cfgfilepath) 139 | with open(cfgfilepath) as cfg_file: 140 | print("Config file found: {}".format(cfgfilepath)) 141 | cfgfile_string = cfg_file.read() 142 | except TypeError: 143 | print("Config file not found at") 144 | pass 145 | args_cfgfile = eval(cfgfile_string) 146 | 147 | merged_dict = vars(args_cfgfile).copy() 148 | for k,v in vars(args_cmdline).items(): 149 | if v != None: 150 | merged_dict[k] = v 151 | return Namespace(**merged_dict) 152 | -------------------------------------------------------------------------------- /arguments/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/arguments/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /arguments/endonerf/default.py: -------------------------------------------------------------------------------- 1 | ModelParams = dict( 2 | extra_mark = 'endonerf', 3 | camera_extent = 10 4 | ) 5 | 6 | OptimizationParams = dict( 7 | coarse_iterations = 0, 8 | deformation_lr_init = 0.00016, 9 | deformation_lr_final = 0.0000016, 10 | deformation_lr_delay_mult = 0.01, 11 | iterations = 3000, 12 | percent_dense = 0.01, 13 | opacity_reset_interval = 3000, 14 | position_lr_max_steps = 4000, 15 | prune_interval = 3000 16 | ) 17 | 18 | ModelHiddenParams = dict( 19 | curve_num = 17, # number of learnable basis functions. 
This number was set to 17 for all the experiments in paper (https://arxiv.org/abs/2405.17835) 20 | 21 | ch_num = 10, # channel number of deformable attributes: 10 = 3 (scale) + 3 (mean) + 4 (rotation) 22 | init_param = 0.01, ) 23 | -------------------------------------------------------------------------------- /assets/Tab1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/Tab1.png -------------------------------------------------------------------------------- /assets/Tab2-padding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/Tab2-padding.png -------------------------------------------------------------------------------- /assets/Tab2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/Tab2.png -------------------------------------------------------------------------------- /assets/demo1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/demo1.mp4 -------------------------------------------------------------------------------- /assets/demo_scene.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/demo_scene.mp4 -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/overview.png -------------------------------------------------------------------------------- /assets/visual_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/visual_results.png -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.flexible_deform_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | 18 | def render_flow(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | 34 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 35 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 36 | 37 | raster_settings = GaussianRasterizationSettings( 38 | image_height=int(viewpoint_camera.image_height), 39 | image_width=int(viewpoint_camera.image_width), 40 | tanfovx=tanfovx, 41 | tanfovy=tanfovy, 42 | bg=bg_color, 43 | scale_modifier=scaling_modifier, 44 | viewmatrix=viewpoint_camera.world_view_transform.cuda(), 45 | projmatrix=viewpoint_camera.full_proj_transform.cuda(), 46 | sh_degree=pc.active_sh_degree, 47 | campos=viewpoint_camera.camera_center.cuda(), 48 | prefiltered=False, 49 | debug=pipe.debug 50 | ) 51 | 52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 53 | 54 | # means3D = pc.get_xyz 55 | # add deformation to each points 56 | # deformation = pc.get_deformation 57 | means3D = pc.get_xyz 58 | ori_time = torch.tensor(viewpoint_camera.time).to(means3D.device) 59 | means2D = screenspace_points 60 | opacity = pc._opacity 61 | 62 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 63 | # scaling / rotation by the rasterizer. 64 | scales = None 65 | rotations = None 66 | cov3D_precomp = None 67 | 68 | if pipe.compute_cov3D_python: 69 | cov3D_precomp = pc.get_covariance(scaling_modifier) 70 | else: 71 | scales = pc._scaling 72 | rotations = pc._rotation 73 | deformation_point = pc._deformation_table 74 | 75 | 76 | means3D_deform, scales_deform, rotations_deform = pc.deformation(means3D[deformation_point], scales[deformation_point], 77 | rotations[deformation_point], 78 | ori_time) 79 | opacity_deform = opacity[deformation_point] 80 | 81 | # print(time.max()) 82 | with torch.no_grad(): 83 | pc._deformation_accum[deformation_point] += torch.abs(means3D_deform - means3D[deformation_point]) 84 | 85 | means3D_final = torch.zeros_like(means3D) 86 | rotations_final = torch.zeros_like(rotations) 87 | scales_final = torch.zeros_like(scales) 88 | opacity_final = torch.zeros_like(opacity) 89 | means3D_final[deformation_point] = means3D_deform 90 | rotations_final[deformation_point] = rotations_deform 91 | scales_final[deformation_point] = scales_deform 92 | opacity_final[deformation_point] = opacity_deform 93 | means3D_final[~deformation_point] = means3D[~deformation_point] 94 | rotations_final[~deformation_point] = rotations[~deformation_point] 95 | scales_final[~deformation_point] = scales[~deformation_point] 96 | opacity_final[~deformation_point] = opacity[~deformation_point] 97 | 98 | scales_final = pc.scaling_activation(scales_final) 99 | rotations_final = pc.rotation_activation(rotations_final) 100 | opacity = pc.opacity_activation(opacity) 101 | 102 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 103 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
104 | shs = None 105 | colors_precomp = None 106 | if override_color is None: 107 | if pipe.convert_SHs_python: 108 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 109 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.cuda().repeat(pc.get_features.shape[0], 1)) 110 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 111 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 112 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 113 | else: 114 | shs = pc.get_features 115 | else: 116 | colors_precomp = override_color 117 | 118 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 119 | rendered_image, radii, depth = rasterizer( 120 | means3D = means3D_final, 121 | means2D = means2D, 122 | shs = shs, 123 | colors_precomp = colors_precomp, 124 | opacities = opacity, 125 | scales = scales_final, 126 | rotations = rotations_final, 127 | cov3D_precomp = cov3D_precomp) 128 | 129 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 130 | # They will be excluded from value updates used in the splitting criteria. 131 | return {"render": rendered_image, 132 | "depth": depth, 133 | "viewspace_points": screenspace_points, 134 | "visibility_filter" : radii > 0, 135 | "radii": radii,} 136 | -------------------------------------------------------------------------------- /gaussian_renderer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/gaussian_renderer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /gaussian_renderer/__pycache__/network_gui.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/gaussian_renderer/__pycache__/network_gui.cpython-37.pyc -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16().features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | 
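# --- Usage sketch: scoring two images with the backbones above ---
# A minimal, illustrative example of how these networks are consumed through the
# lpips() helper defined in lpipsPyTorch/__init__.py. The random tensors, image size,
# and the 'alex' backbone are assumptions for illustration only; inputs are assumed to
# be scaled to [-1, 1], matching the convention of the LPIPS wrapper in metrics.py
# (which itself uses the external `lpips` pip package rather than this bundled module).
# Assumes the repository root is on PYTHONPATH.
if __name__ == "__main__":
    import torch
    from lpipsPyTorch import lpips

    render = torch.rand(1, 3, 256, 256) * 2.0 - 1.0  # rendered image, scaled to [-1, 1]
    gt = torch.rand(1, 3, 256, 256) * 2.0 - 1.0      # ground-truth image, scaled to [-1, 1]
    score = lpips(render, gt, net_type='alex')       # perceptual distance; lower is better
    print(score.item())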
-------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | # from lpipsPyTorch import lpips 19 | import lpips 20 | import json 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from utils.image_utils import rmse 24 | from argparse import ArgumentParser 25 | import numpy as np 26 | 27 | 28 | def array2tensor(array, device="cuda", dtype=torch.float32): 29 | return torch.tensor(array, dtype=dtype, device=device) 30 | 31 | # Learned Perceptual Image Patch Similarity 32 | class LPIPS(object): 33 | """ 34 | borrowed from https://github.com/huster-wgm/Pytorch-metrics/blob/master/metrics.py 35 | """ 36 | def __init__(self, device="cuda"): 37 | self.model = lpips.LPIPS(net='alex').to(device) 38 | 39 | def __call__(self, y_pred, y_true, normalized=True): 40 | """ 41 | args: 42 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 43 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 44 | normalized : change [0,1] => [-1,1] (default by LPIPS) 45 | return LPIPS, smaller the better 46 | """ 47 | if normalized: 48 | y_pred = y_pred * 2.0 - 1.0 49 | y_true = y_true * 2.0 - 1.0 50 | error = self.model.forward(y_pred, y_true) 51 | return torch.mean(error) 52 | 53 | lpips = LPIPS() 54 | def cal_lpips(a, b, device="cuda", batch=2): 55 | """Compute lpips. 
56 | a, b: [batch, H, W, 3]""" 57 | if not torch.is_tensor(a): 58 | a = array2tensor(a, device) 59 | if not torch.is_tensor(b): 60 | b = array2tensor(b, device) 61 | 62 | lpips_all = [] 63 | for a_split, b_split in zip(a.split(split_size=batch, dim=0), b.split(split_size=batch, dim=0)): 64 | out = lpips(a_split, b_split) 65 | lpips_all.append(out) 66 | lpips_all = torch.stack(lpips_all) 67 | lpips_mean = lpips_all.mean() 68 | return lpips_mean 69 | 70 | def readImages(renders_dir, gt_dir, depth_dir, gtdepth_dir, masks_dir): 71 | renders = [] 72 | gts = [] 73 | image_names = [] 74 | depths = [] 75 | gt_depths = [] 76 | masks = [] 77 | 78 | for fname in os.listdir(renders_dir): 79 | render = np.array(Image.open(renders_dir / fname)) 80 | gt = np.array(Image.open(gt_dir / fname)) 81 | depth = np.array(Image.open(depth_dir / fname)) 82 | gt_depth = np.array(Image.open(gtdepth_dir / fname)) 83 | mask = np.array(Image.open(masks_dir / fname)) 84 | 85 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 86 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 87 | depths.append(torch.from_numpy(depth).unsqueeze(0).unsqueeze(1)[:, :, :, :].cuda()) 88 | gt_depths.append(torch.from_numpy(gt_depth).unsqueeze(0).unsqueeze(1)[:, :3, :, :].cuda()) 89 | masks.append(tf.to_tensor(mask).unsqueeze(0).cuda()) 90 | 91 | image_names.append(fname) 92 | return renders, gts, depths, gt_depths, masks, image_names 93 | 94 | def evaluate(model_paths): 95 | 96 | full_dict = {} 97 | per_view_dict = {} 98 | full_dict_polytopeonly = {} 99 | per_view_dict_polytopeonly = {} 100 | print("") 101 | 102 | with torch.no_grad(): 103 | for scene_dir in model_paths: 104 | print("Scene:", scene_dir) 105 | full_dict[scene_dir] = {} 106 | per_view_dict[scene_dir] = {} 107 | full_dict_polytopeonly[scene_dir] = {} 108 | per_view_dict_polytopeonly[scene_dir] = {} 109 | 110 | test_dir = Path(scene_dir) / args.phase 111 | 112 | for method in os.listdir(test_dir): 113 | print("Method:", method) 114 | 115 | full_dict[scene_dir][method] = {} 116 | per_view_dict[scene_dir][method] = {} 117 | full_dict_polytopeonly[scene_dir][method] = {} 118 | per_view_dict_polytopeonly[scene_dir][method] = {} 119 | 120 | method_dir = test_dir / method 121 | gt_dir = method_dir/ "gt" 122 | renders_dir = method_dir / "renders" 123 | depth_dir = method_dir / "depth" 124 | gt_depth_dir = method_dir / "gt_depth" 125 | masks_dir = method_dir / "masks" 126 | 127 | renders, gts, depths, gt_depths, masks, image_names = readImages(renders_dir, gt_dir, depth_dir, gt_depth_dir, masks_dir) 128 | 129 | ssims = [] 130 | psnrs = [] 131 | lpipss = [] 132 | rmses = [] 133 | 134 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 135 | render, gt, depth, gt_depth, mask = renders[idx], gts[idx], depths[idx], gt_depths[idx], masks[idx] 136 | render = render * mask 137 | gt = gt * mask 138 | psnrs.append(psnr(render, gt)) 139 | ssims.append(ssim(render, gt)) 140 | lpipss.append(cal_lpips(render, gt)) 141 | if (gt_depth!=0).sum() < 10: 142 | continue 143 | rmses.append(rmse(depth, gt_depth, mask)) 144 | 145 | print("Scene: ", scene_dir, "SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 146 | print("Scene: ", scene_dir, "PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 147 | print("Scene: ", scene_dir, "LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 148 | print("Scene: ", scene_dir, "RMSE: {:>12.7f}".format(torch.tensor(rmses).mean(), ".5")) 149 | print("") 150 | 151 | 
full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 152 | "PSNR": torch.tensor(psnrs).mean().item(), 153 | "LPIPS": torch.tensor(lpipss).mean().item(), 154 | "RMSE": torch.tensor(rmses).mean().item()}) 155 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 156 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 157 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}, 158 | "RMSES": {name: lp for lp, name in zip(torch.tensor(rmses).tolist(), image_names)}}) 159 | 160 | with open(scene_dir + "/results.json", 'w') as fp: 161 | json.dump(full_dict[scene_dir], fp, indent=True) 162 | with open(scene_dir + "/per_view.json", 'w') as fp: 163 | json.dump(per_view_dict[scene_dir], fp, indent=True) 164 | 165 | 166 | if __name__ == "__main__": 167 | device = torch.device("cuda:0") 168 | torch.cuda.set_device(device) 169 | 170 | # Set up command line argument parser 171 | parser = ArgumentParser(description="Training script parameters") 172 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 173 | parser.add_argument('--phase', '-p', type=str, default='test') 174 | args = parser.parse_args() 175 | evaluate(args.model_paths) 176 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | torchaudio==0.13.1 4 | mmcv==1.6.0 5 | matplotlib 6 | argparse 7 | lpips 8 | plyfile 9 | imageio-ffmpeg 10 | open3d 11 | imageio 12 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.flexible_deform_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | from torch.utils.data import Dataset 21 | 22 | class Scene: 23 | 24 | gaussians : GaussianModel 25 | 26 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None): 27 | """b 28 | :param path: Path to colmap scene main folder. 
29 | """ 30 | self.model_path = args.model_path 31 | self.loaded_iter = None 32 | self.gaussians = gaussians 33 | 34 | if load_iteration: 35 | if load_iteration == -1: 36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 37 | else: 38 | self.loaded_iter = load_iteration 39 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 40 | 41 | if os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")) and args.extra_mark == 'endonerf': 42 | scene_info = sceneLoadTypeCallbacks["endonerf"](args.source_path) 43 | print("Found poses_bounds.npy and extra_mark == 'endonerf', assuming EndoNeRF data!") 44 | elif os.path.exists(os.path.join(args.source_path, "point_cloud.obj")) or os.path.exists(os.path.join(args.source_path, "left_point_cloud.obj")): 45 | scene_info = sceneLoadTypeCallbacks["scared"](args.source_path, args.white_background, args.eval) 46 | print("Found point_cloud.obj, assuming SCARED data!") 47 | else: 48 | assert False, "Could not recognize scene type!" 49 | 50 | self.maxtime = scene_info.maxtime 51 | self.cameras_extent = scene_info.nerf_normalization["radius"] 52 | # self.cameras_extent = args.camera_extent 53 | print("self.cameras_extent is ", self.cameras_extent) 54 | 55 | print("Loading Training Cameras") 56 | self.train_camera = scene_info.train_cameras 57 | print("Loading Test Cameras") 58 | self.test_camera = scene_info.test_cameras 59 | print("Loading Video Cameras") 60 | self.video_camera = scene_info.video_cameras 61 | 62 | xyz_max = scene_info.point_cloud.points.max(axis=0) 63 | xyz_min = scene_info.point_cloud.points.min(axis=0) 64 | # self.gaussians._deformation.deformation_net.grid.set_aabb(xyz_max,xyz_min) 65 | 66 | if self.loaded_iter: 67 | self.gaussians.load_ply(os.path.join(self.model_path, 68 | "point_cloud", 69 | "iteration_" + str(self.loaded_iter), 70 | "point_cloud.ply")) 71 | self.gaussians.load_model(os.path.join(self.model_path, 72 | "point_cloud", 73 | "iteration_" + str(self.loaded_iter), 74 | )) 75 | else: 76 | self.gaussians.create_from_pcd(scene_info.point_cloud, args.camera_extent, self.maxtime) 77 | 78 | def save(self, iteration, stage): 79 | if stage == "coarse": 80 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_iteration_{}".format(iteration)) 81 | else: 82 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 83 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 84 | # self.gaussians.save_deformation(point_cloud_path) 85 | 86 | def getTrainCameras(self, scale=1.0): 87 | return self.train_camera 88 | 89 | def getTestCameras(self, scale=1.0): 90 | return self.test_camera 91 | 92 | def getVideoCameras(self, scale=1.0): 93 | return self.video_camera 94 | 95 | -------------------------------------------------------------------------------- /scene/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/cameras.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/cameras.cpython-37.pyc --------------------------------------------------------------------------------
/scene/__pycache__/colmap_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/colmap_loader.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/dataset_readers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/dataset_readers.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/deformation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/deformation.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/endo_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/endo_loader.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/flexible_deform_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/flexible_deform_model.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/gaussian_flow_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/gaussian_flow_model.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/gaussian_gaussian_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/gaussian_gaussian_model.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/gaussian_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/gaussian_model.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/hexplane.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/hexplane.cpython-37.pyc 
-------------------------------------------------------------------------------- /scene/__pycache__/hyper_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/hyper_loader.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/regulation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/regulation.cpython-37.pyc -------------------------------------------------------------------------------- /scene/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getProjectionMatrix2 16 | 17 | 18 | 19 | class Camera(nn.Module): 20 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, depth, mask, gt_alpha_mask, 21 | image_name, uid, 22 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, 23 | data_device = "cuda", time = 0, Znear=None, Zfar=None, 24 | K=None, h=None, w=None 25 | ): 26 | super(Camera, self).__init__() 27 | 28 | self.uid = uid 29 | self.colmap_id = colmap_id 30 | self.R = R 31 | self.T = T 32 | self.FoVx = FoVx 33 | self.FoVy = FoVy 34 | self.image_name = image_name 35 | self.time = time 36 | self.mask = mask 37 | try: 38 | self.data_device = torch.device(data_device) 39 | except Exception as e: 40 | print(e) 41 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 42 | self.data_device = torch.device("cuda") 43 | 44 | self.original_image = image.clamp(0.0, 1.0) 45 | self.original_depth = depth 46 | self.image_width = self.original_image.shape[2] 47 | self.image_height = self.original_image.shape[1] 48 | if gt_alpha_mask is not None: 49 | self.original_image *= gt_alpha_mask 50 | else: 51 | self.original_image *= torch.ones((1, self.image_height, self.image_width)) 52 | 53 | if Zfar is not None and Znear is not None: 54 | self.zfar = Zfar 55 | self.znear = Znear 56 | else: 57 | # ENDONERF 58 | self.zfar = 120.0 59 | self.znear = 0.01 60 | 61 | # StereoMIS 62 | self.zfar = 250 63 | self.znear= 0.03 64 | 65 | self.trans = trans 66 | self.scale = scale 67 | 68 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1) 69 | if K is None or h is None or w is None: 70 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1) 71 | else: 72 | self.projection_matrix = getProjectionMatrix2(znear=self.znear, zfar=self.zfar, K=K, h = h, 
w=w).transpose(0,1) 73 | 74 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 75 | self.camera_center = self.world_view_transform.inverse()[3, :3] 76 | 77 | class MiniCam: 78 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform, time): 79 | self.image_width = width 80 | self.image_height = height 81 | self.FoVy = fovy 82 | self.FoVx = fovx 83 | self.znear = znear 84 | self.zfar = zfar 85 | self.world_view_transform = world_view_transform 86 | self.full_proj_transform = full_proj_transform 87 | view_inv = torch.inverse(self.world_view_transform) 88 | self.camera_center = view_inv[3][:3] 89 | self.time = time 90 | 91 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | import torchvision.transforms as transforms 17 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 18 | import numpy as np 19 | import torch 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from scene.flexible_deform_model import BasicPointCloud 24 | from utils.general_utils import PILtoTorch 25 | from tqdm import tqdm 26 | 27 | class CameraInfo(NamedTuple): 28 | uid: int 29 | R: np.array 30 | T: np.array 31 | FovY: np.array 32 | FovX: np.array 33 | image: np.array 34 | image_path: str 35 | image_name: str 36 | width: int 37 | height: int 38 | time : float 39 | 40 | class SceneInfo(NamedTuple): 41 | point_cloud: BasicPointCloud 42 | train_cameras: list 43 | test_cameras: list 44 | video_cameras: list 45 | nerf_normalization: dict 46 | ply_path: str 47 | maxtime: int 48 | 49 | def getNerfppNorm(cam_info): 50 | def get_center_and_diag(cam_centers): 51 | cam_centers = np.hstack(cam_centers) 52 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 53 | center = avg_cam_center 54 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 55 | diagonal = np.max(dist) 56 | return center.flatten(), diagonal 57 | 58 | cam_centers = [] 59 | for cam in cam_info: 60 | W2C = getWorld2View2(cam.R, cam.T) 61 | C2W = np.linalg.inv(W2C) 62 | cam_centers.append(C2W[:3, 3:4]) 63 | 64 | center, diagonal = get_center_and_diag(cam_centers) 65 | radius = diagonal * 1.1 66 | translate = -center 67 | 68 | return {"translate": translate, "radius": radius} 69 | 70 | 71 | 72 | def fetchPly(path): 73 | plydata = PlyData.read(path) 74 | vertices = plydata['vertex'] 75 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 76 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 77 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 78 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 79 | 80 | def storePly(path, xyz, rgb): 81 | # Define the dtype for the structured array 82 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 83 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 84 | ('red', 'u1'), ('green', 
'u1'), ('blue', 'u1')] 85 | 86 | normals = np.zeros_like(xyz) 87 | elements = np.empty(xyz.shape[0], dtype=dtype) 88 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 89 | elements[:] = list(map(tuple, attributes)) 90 | 91 | # Create the PlyData object and write to file 92 | vertex_element = PlyElement.describe(elements, 'vertex') 93 | ply_data = PlyData([vertex_element]) 94 | ply_data.write(path) 95 | 96 | 97 | def generateCamerasFromTransforms(path, template_transformsfile, extension, maxtime): 98 | trans_t = lambda t : torch.Tensor([ 99 | [1,0,0,0], 100 | [0,1,0,0], 101 | [0,0,1,t], 102 | [0,0,0,1]]).float() 103 | 104 | rot_phi = lambda phi : torch.Tensor([ 105 | [1,0,0,0], 106 | [0,np.cos(phi),-np.sin(phi),0], 107 | [0,np.sin(phi), np.cos(phi),0], 108 | [0,0,0,1]]).float() 109 | 110 | rot_theta = lambda th : torch.Tensor([ 111 | [np.cos(th),0,-np.sin(th),0], 112 | [0,1,0,0], 113 | [np.sin(th),0, np.cos(th),0], 114 | [0,0,0,1]]).float() 115 | 116 | def pose_spherical(theta, phi, radius): 117 | c2w = trans_t(radius) 118 | c2w = rot_phi(phi/180.*np.pi) @ c2w 119 | c2w = rot_theta(theta/180.*np.pi) @ c2w 120 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 121 | return c2w 122 | 123 | cam_infos = [] 124 | # generate render poses and times 125 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 126 | render_times = torch.linspace(0,maxtime,render_poses.shape[0]) 127 | with open(os.path.join(path, template_transformsfile)) as json_file: 128 | template_json = json.load(json_file) 129 | fovx = template_json["camera_angle_x"] 130 | # load a single image to get image info. 131 | for idx, frame in enumerate(template_json["frames"]): 132 | cam_name = os.path.join(path, frame["file_path"] + extension) 133 | image_path = os.path.join(path, cam_name) 134 | image_name = Path(cam_name).stem 135 | image = Image.open(image_path) 136 | im_data = np.array(image.convert("RGBA")) 137 | image = PILtoTorch(image,(800,800)) 138 | break 139 | # format information 140 | for idx, (time, poses) in enumerate(zip(render_times,render_poses)): 141 | time = time/maxtime 142 | matrix = np.linalg.inv(np.array(poses)) 143 | R = -np.transpose(matrix[:3,:3]) 144 | R[:,0] = -R[:,0] 145 | T = -matrix[:3, 3] 146 | fovy = focal2fov(fov2focal(fovx, image.shape[1]), image.shape[2]) 147 | FovY = fovy 148 | FovX = fovx 149 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 150 | image_path=None, image_name=None, width=image.shape[1], height=image.shape[2], 151 | time = time)) 152 | return cam_infos 153 | 154 | def readEndoNeRFInfo(datadir): 155 | # load camera infos 156 | from scene.endo_loader import EndoNeRF_Dataset 157 | endo_dataset = EndoNeRF_Dataset( 158 | datadir=datadir, 159 | downsample=1.0, 160 | ) 161 | train_cam_infos = endo_dataset.format_infos(split="train") 162 | test_cam_infos = endo_dataset.format_infos(split="test") 163 | video_cam_infos = endo_dataset.format_infos(split="video") 164 | 165 | # get normalizations 166 | nerf_normalization = getNerfppNorm(train_cam_infos) 167 | 168 | # initialize sparse point clouds 169 | ply_path = os.path.join(datadir, "points3d.ply") 170 | xyz, rgb, normals = endo_dataset.get_sparse_pts() 171 | 172 | normals = np.random.random((xyz.shape[0], 3)) 173 | pcd = BasicPointCloud(points=xyz, colors=rgb, normals=normals) 174 | storePly(ply_path, xyz,rgb*255) 175 | 176 | try: 177 | pcd = fetchPly(ply_path) 178 | except: 179 | pcd = None 180 | 181 | # get 
the maximum time 182 | maxtime = endo_dataset.get_maxtime() 183 | 184 | scene_info = SceneInfo(point_cloud=pcd, 185 | train_cameras=train_cam_infos, 186 | test_cameras=test_cam_infos, 187 | video_cameras=video_cam_infos, 188 | nerf_normalization=nerf_normalization, 189 | ply_path=ply_path, 190 | maxtime=maxtime) 191 | 192 | return scene_info 193 | 194 | 195 | 196 | sceneLoadTypeCallbacks = { 197 | "endonerf": readEndoNeRFInfo, 198 | } 199 | -------------------------------------------------------------------------------- /scene/regulation.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | from typing import Sequence 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.optim.lr_scheduler 9 | from torch import nn 10 | 11 | 12 | 13 | def compute_plane_tv(t): 14 | batch_size, c, h, w = t.shape 15 | count_h = batch_size * c * (h - 1) * w 16 | count_w = batch_size * c * h * (w - 1) 17 | h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum() 18 | w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum() 19 | return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg 20 | 21 | 22 | def compute_plane_smoothness(t): 23 | batch_size, c, h, w = t.shape 24 | # Convolve with a second derivative filter, in the time dimension which is dimension 2 25 | first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] 26 | second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] 27 | # Take the L2 norm of the result 28 | return torch.square(second_difference).mean() 29 | 30 | 31 | class Regularizer(): 32 | def __init__(self, reg_type, initialization): 33 | self.reg_type = reg_type 34 | self.initialization = initialization 35 | self.weight = float(self.initialization) 36 | self.last_reg = None 37 | 38 | def step(self, global_step): 39 | pass 40 | 41 | def report(self, d): 42 | if self.last_reg is not None: 43 | d[self.reg_type].update(self.last_reg.item()) 44 | 45 | def regularize(self, *args, **kwargs) -> torch.Tensor: 46 | out = self._regularize(*args, **kwargs) * self.weight 47 | self.last_reg = out.detach() 48 | return out 49 | 50 | @abc.abstractmethod 51 | def _regularize(self, *args, **kwargs) -> torch.Tensor: 52 | raise NotImplementedError() 53 | 54 | def __str__(self): 55 | return f"Regularizer({self.reg_type}, weight={self.weight})" 56 | 57 | 58 | class PlaneTV(Regularizer): 59 | def __init__(self, initial_value, what: str = 'field'): 60 | if what not in {'field', 'proposal_network'}: 61 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 62 | f'but {what} was passed.') 63 | name = f'planeTV-{what[:2]}' 64 | super().__init__(name, initial_value) 65 | self.what = what 66 | 67 | def step(self, global_step): 68 | pass 69 | 70 | def _regularize(self, model, **kwargs): 71 | multi_res_grids: Sequence[nn.ParameterList] 72 | if self.what == 'field': 73 | multi_res_grids = model.field.grids 74 | elif self.what == 'proposal_network': 75 | multi_res_grids = [p.grids for p in model.proposal_networks] 76 | else: 77 | raise NotImplementedError(self.what) 78 | total = 0 79 | # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w] 80 | for grids in multi_res_grids: 81 | if len(grids) == 3: 82 | spatial_grids = [0, 1, 2] 83 | else: 84 | spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal 85 | for grid_id in spatial_grids: 86 | total += 
compute_plane_tv(grids[grid_id]) 87 | for grid in grids: 88 | # grid: [1, c, h, w] 89 | total += compute_plane_tv(grid) 90 | return total 91 | 92 | 93 | class TimeSmoothness(Regularizer): 94 | def __init__(self, initial_value, what: str = 'field'): 95 | if what not in {'field', 'proposal_network'}: 96 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 97 | f'but {what} was passed.') 98 | name = f'time-smooth-{what[:2]}' 99 | super().__init__(name, initial_value) 100 | self.what = what 101 | 102 | def _regularize(self, model, **kwargs) -> torch.Tensor: 103 | multi_res_grids: Sequence[nn.ParameterList] 104 | if self.what == 'field': 105 | multi_res_grids = model.field.grids 106 | elif self.what == 'proposal_network': 107 | multi_res_grids = [p.grids for p in model.proposal_networks] 108 | else: 109 | raise NotImplementedError(self.what) 110 | total = 0 111 | # model.grids is 6 x [1, rank * F_dim, reso, reso] 112 | for grids in multi_res_grids: 113 | if len(grids) == 3: 114 | time_grids = [] 115 | else: 116 | time_grids = [2, 4, 5] 117 | for grid_id in time_grids: 118 | total += compute_plane_smoothness(grids[grid_id]) 119 | return torch.as_tensor(total) 120 | 121 | 122 | 123 | class L1ProposalNetwork(Regularizer): 124 | def __init__(self, initial_value): 125 | super().__init__('l1-proposal-network', initial_value) 126 | 127 | def _regularize(self, model, **kwargs) -> torch.Tensor: 128 | grids = [p.grids for p in model.proposal_networks] 129 | total = 0.0 130 | for pn_grids in grids: 131 | for grid in pn_grids: 132 | total += torch.abs(grid).mean() 133 | return torch.as_tensor(total) 134 | 135 | 136 | class DepthTV(Regularizer): 137 | def __init__(self, initial_value): 138 | super().__init__('tv-depth', initial_value) 139 | 140 | def _regularize(self, model, model_out, **kwargs) -> torch.Tensor: 141 | depth = model_out['depth'] 142 | tv = compute_plane_tv( 143 | depth.reshape(64, 64)[None, None, :, :] 144 | ) 145 | return tv 146 | 147 | 148 | class L1TimePlanes(Regularizer): 149 | def __init__(self, initial_value, what='field'): 150 | if what not in {'field', 'proposal_network'}: 151 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 152 | f'but {what} was passed.') 153 | super().__init__(f'l1-time-{what[:2]}', initial_value) 154 | self.what = what 155 | 156 | def _regularize(self, model, **kwargs) -> torch.Tensor: 157 | # model.grids is 6 x [1, rank * F_dim, reso, reso] 158 | multi_res_grids: Sequence[nn.ParameterList] 159 | if self.what == 'field': 160 | multi_res_grids = model.field.grids 161 | elif self.what == 'proposal_network': 162 | multi_res_grids = [p.grids for p in model.proposal_networks] 163 | else: 164 | raise NotImplementedError(self.what) 165 | 166 | total = 0.0 167 | for grids in multi_res_grids: 168 | if len(grids) == 3: 169 | continue 170 | else: 171 | # These are the spatiotemporal grids 172 | spatiotemporal_grids = [2, 4, 5] 173 | for grid_id in spatiotemporal_grids: 174 | total += torch.abs(1 - grids[grid_id]).mean() 175 | return torch.as_tensor(total) 176 | 177 | -------------------------------------------------------------------------------- /stereomis2endonerf.py: -------------------------------------------------------------------------------- 1 | from utils.stereo_rectify import StereoRectifier 2 | from submodules.RAFT.core.raft import RAFT 3 | from argparse import ArgumentParser, Action 4 | from torchvision.transforms import Resize, InterpolationMode 5 | from collections import OrderedDict 6 | import os 7 | import numpy as 
np 8 | import torch 9 | import cv2 10 | import shutil 11 | 12 | RAFT_config = { 13 | "pretrained": "submodules/RAFT/pretrained/raft-things.pth", 14 | "iters": 12, 15 | "dropout": 0.0, 16 | "small": False, 17 | "pose_scale": 1.0, 18 | "lbgfs_iters": 100, 19 | "use_weights": True, 20 | "dbg": False 21 | } 22 | 23 | def check_arg_limits(arg_name, n): 24 | class CheckArgLimits(Action): 25 | def __call__(self, parser, args, values, option_string=None): 26 | if len(values) > n: 27 | parser.error("Too many arguments for " + arg_name + ". Maximum is {0}.".format(n)) 28 | if len(values) < n: 29 | parser.error("Too few arguments for " + arg_name + ". Minimum is {0}.".format(n)) 30 | setattr(args, self.dest, values) 31 | return CheckArgLimits 32 | 33 | def read_mask(path): 34 | mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 35 | mask = mask > 0 36 | mask = torch.from_numpy(mask).unsqueeze(0) 37 | return mask 38 | 39 | 40 | class DepthEstimator(torch.nn.Module): 41 | def __init__(self, config): 42 | super(DepthEstimator, self).__init__() 43 | self.model = RAFT(config).to('cuda') 44 | self.model.freeze_bn() 45 | new_state_dict = OrderedDict() 46 | raft_ckp = config['pretrained'] 47 | try: 48 | state_dict = torch.load(raft_ckp) 49 | except RuntimeError: 50 | state_dict = torch.load(raft_ckp, map_location='cpu') 51 | for k, v in state_dict.items(): 52 | name = k.replace('module.','') # remove `module.` 53 | new_state_dict[name] = v 54 | self.model.load_state_dict(new_state_dict) 55 | 56 | def forward(self, imagel, imager, baseline, upsample=True): 57 | n, _, h, w = imagel.shape 58 | flow = self.model(imagel.to('cuda'), imager.to('cuda'), upsample=upsample)[0][-1] 59 | baseline = torch.from_numpy(baseline).to('cuda') 60 | depth = baseline[:, None, None] / -flow[:, 0] 61 | if not upsample: 62 | depth/= 8.0 # factor 8 of upsampling 63 | valid = (depth > 0) & (depth <= 250.0) 64 | depth[~valid] = 0.0 65 | return depth.unsqueeze(1) 66 | 67 | def reformat_dataset(data_dir, start_frame, end_frame, img_size=(512, 640)): 68 | """ 69 | Reformat the StereoMIS to the same format as EndoNeRF dataset by stereo depth estimation. 70 | """ 71 | # Load parameters after rectification 72 | calib_file = os.path.join(data_dir, 'StereoCalibration.ini') 73 | assert os.path.exists(calib_file), "Calibration file not found." 74 | rect = StereoRectifier(calib_file, img_size_new=(img_size[1], img_size[0]), mode='conventional') 75 | calib = rect.get_rectified_calib() 76 | baseline = calib['bf'].astype(np.float32) 77 | intrinsics = calib['intrinsics']['left'].astype(np.float32) 78 | 79 | # Sort images and masks according to the start and end frame indexes 80 | frames = sorted(os.listdir(os.path.join(data_dir, 'masks'))) 81 | frames = [f for f in frames if 'l.png' in f and int(f.split('l.')[0]) >= start_frame and int(f.split('l.')[0]) <= end_frame] 82 | assert len(frames) > 0, "No frames found." 83 | resize = Resize(img_size) 84 | resize_msk = Resize(img_size, interpolation=InterpolationMode.NEAREST) 85 | 86 | # Configurate depth estimator. 
We follow the settings of RAFT in robust-pose-estimator(https://github.com/aimi-lab/robust-pose-estimator) 87 | depth_estimator = DepthEstimator(RAFT_config) 88 | 89 | # Create folders 90 | output_dir = os.path.join(data_dir, 'stereo_'+ os.path.basename(data_dir)+'_'+str(start_frame)+'_'+str(end_frame)) 91 | image_dir = os.path.join(output_dir, 'images') 92 | mask_dir = os.path.join(output_dir, 'masks') 93 | depth_dir = os.path.join(output_dir, 'depth') 94 | if not os.path.exists(image_dir): 95 | os.makedirs(image_dir) 96 | if not os.path.exists(mask_dir): 97 | os.makedirs(mask_dir) 98 | if not os.path.exists(depth_dir): 99 | os.makedirs(depth_dir) 100 | poses_bounds = [] 101 | for i, frame in enumerate(frames): 102 | left_img = torch.from_numpy(cv2.cvtColor(cv2.imread(os.path.join(data_dir, 'video_frames', frame)), cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float() 103 | right_img = torch.from_numpy(cv2.cvtColor(cv2.imread(os.path.join(data_dir, 'video_frames', frame.replace('l', 'r'))), cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float() 104 | left_img = resize(left_img) 105 | right_img = resize(right_img) 106 | with torch.no_grad(): 107 | depth = depth_estimator(left_img[None], right_img[None], baseline[None]) 108 | try: 109 | mask = read_mask(os.path.join(data_dir, 'masks', frame)) 110 | mask = resize_msk(mask) 111 | except: 112 | mask = torch.ones(1, img_size[0], img_size[1]) 113 | 114 | # Save the data. Of note, the file should start with 'stereo_' to be compatible with the dataloader in Deform3DGS. 115 | left_img_np = left_img.permute(1, 2, 0).numpy() 116 | mask_np = mask.permute(1, 2, 0).numpy() 117 | left_img_bgr = cv2.cvtColor(left_img_np, cv2.COLOR_RGB2BGR) 118 | 119 | # Save left_img, right_img, and mask to output_dir 120 | name = 'frame-'+str(i).zfill(6)+'.color.png' 121 | cv2.imwrite(os.path.join(image_dir, name), left_img_bgr) 122 | cv2.imwrite(os.path.join(mask_dir, name.replace('color','mask')), mask_np.astype(np.uint8) * 255) 123 | cv2.imwrite(os.path.join(depth_dir, name.replace('color','depth')), depth[0, 0].cpu().numpy()) 124 | 125 | # Save poses_bounds.npy. Only static view is considered, i.e., R = I and T = 0. 126 | R = np.eye(3) 127 | T = np.zeros(3) 128 | extr = np.concatenate([R, T[:, None]], axis=1) 129 | cy, cx, focal = intrinsics[1, 2], intrinsics[0, 2], intrinsics[0, 0] 130 | param = np.concatenate([extr, np.array([[cy, cx, focal]]).T], axis=1) 131 | param = param.reshape(1, 15) 132 | param = np.concatenate([param, np.array([[0.03, 250.]])], axis=1) 133 | poses_bounds.append(param[0]) 134 | 135 | np.save(os.path.join(output_dir, 'poses_bounds.npy'), np.array(poses_bounds)) 136 | 137 | if __name__ == "__main__": 138 | torch.manual_seed(1234) 139 | np.random.seed(1234) 140 | # Set up command line argument parser 141 | parser = ArgumentParser(description="parameters for dataset format conversions") 142 | parser.add_argument('--data_dir', '-d', type=str, default='data/StereoMIS/P3') 143 | # Frame ID of the start and end of the sequence. Of note, only 2 arguments (start and end) are required. 
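# Example invocation using the shipped defaults (the data_dir and frame range here are just the defaults, not requirements):
#   python stereomis2endonerf.py -d data/StereoMIS/P3 -f 9100 9467
# This writes images/, masks/, depth/ and poses_bounds.npy under data/StereoMIS/P3/stereo_P3_9100_9467/,
# i.e. the 'stereo_*' layout expected by the Deform3DGS dataloader.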
144 | parser.add_argument('--frame_id', '-f',nargs="+", action=check_arg_limits('frame_id', 2), type=int, default=[9100, 9467]) 145 | args = parser.parse_args() 146 | frame_id = args.frame_id 147 | reformat_dataset(args.data_dir, frame_id[0], frame_id[1]) -------------------------------------------------------------------------------- /submodules/RAFT/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | dist 4 | datasets 5 | pytorch_env 6 | models 7 | build 8 | correlation.egg-info 9 | -------------------------------------------------------------------------------- /submodules/RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /submodules/RAFT/RAFT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/RAFT.png -------------------------------------------------------------------------------- /submodules/RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 | 9 | 10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and Cuda 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch 16 | ``` 17 | 18 | ## Demos 19 | Pretrained models can be downloaded by running 20 | ```Shell 21 | ./download_models.sh 22 | ``` 23 | or downloaded from [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 24 | 25 | You can demo a trained model on a sequence of frames 26 | ```Shell 27 | python demo.py --model=models/raft-things.pth --path=demo-frames 28 | ``` 29 | 30 | ## Required Data 31 | To evaluate/train RAFT, you will need to download the required datasets. 32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 34 | * [Sintel](http://sintel.is.tue.mpg.de/) 35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 37 | 38 | 39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder 40 | 41 | ```Shell 42 | ├── datasets 43 | ├── Sintel 44 | ├── test 45 | ├── training 46 | ├── KITTI 47 | ├── testing 48 | ├── training 49 | ├── devkit 50 | ├── FlyingChairs_release 51 | ├── data 52 | ├── FlyingThings3D 53 | ├── frames_cleanpass 54 | ├── frames_finalpass 55 | ├── optical_flow 56 | ``` 57 | 58 | ## Evaluation 59 | You can evaluate a trained model using `evaluate.py` 60 | ```Shell 61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 62 | ``` 63 | 64 | ## Training 65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard 66 | ```Shell 67 | ./train_standard.sh 68 | ``` 69 | 70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) 71 | ```Shell 72 | ./train_mixed.sh 73 | ``` 74 | 75 | ## (Optional) Efficient Implementation 76 | You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension 77 | ```Shell 78 | cd alt_cuda_corr && python setup.py install && cd .. 79 | ``` 80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
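
## Usage in this repository (Deform3DGS)

The parent repository does not go through `demo.py`/`evaluate.py`; it instantiates `RAFT` with a plain config dict and loads `pretrained/raft-things.pth` for stereo disparity estimation (see `stereomis2endonerf.py`). Below is a minimal sketch of that loading step, assuming the repository root as working directory; the two-key config is a trimmed assumption, the full dict lives in `stereomis2endonerf.py`.

```python
from collections import OrderedDict
import torch
from submodules.RAFT.core.raft import RAFT

config = {"small": False, "dropout": 0.0}  # minimal keys read by RAFT.__init__ in this fork
model = RAFT(config).cuda()
state_dict = torch.load("submodules/RAFT/pretrained/raft-things.pth", map_location="cpu")
# the checkpoint was saved from a DataParallel wrapper, so strip the 'module.' prefix
model.load_state_dict(OrderedDict((k.replace("module.", ""), v) for k, v in state_dict.items()))
model.eval()
```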
81 | -------------------------------------------------------------------------------- /submodules/RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /submodules/RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /submodules/RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/core/__init__.py -------------------------------------------------------------------------------- /submodules/RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r =
self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /submodules/RAFT/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, config): 26 | super(RAFT, self).__init__() 27 | self.config = config 28 | 29 | if config['small']: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | config['corr_levels'] = 4 33 | config['corr_radius'] = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | config['corr_levels'] = 4 39 | config['corr_radius'] = 4 40 | 41 | # feature network, context network, and update block 42 | if config['small']: 43 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', 
dropout=config['dropout']) 44 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=config['dropout']) 45 | self.update_block = SmallUpdateBlock(config, hidden_dim=hdim) 46 | 47 | else: 48 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=config['dropout']) 49 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=config['dropout']) 50 | self.update_block = BasicUpdateBlock(config, hidden_dim=hdim) 51 | 52 | def freeze_bn(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.BatchNorm2d): 55 | m.eval() 56 | 57 | def initialize_flow(self, img): 58 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 59 | N, C, H, W = img.shape 60 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 61 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 62 | 63 | # optical flow computed as difference: flow = coords1 - coords0 64 | return coords0, coords1 65 | 66 | def upsample_flow(self, flow, mask): 67 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 68 | N, _, H, W = flow.shape 69 | mask = mask.view(N, 1, 9, 8, 8, H, W) 70 | mask = torch.softmax(mask, dim=2) 71 | 72 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 73 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 74 | 75 | up_flow = torch.sum(mask * up_flow, dim=2) 76 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 77 | return up_flow.reshape(N, 2, 8*H, 8*W) 78 | 79 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False, upsample=True): 80 | """ Estimate optical flow between pair of frames """ 81 | 82 | image1 = 2 * (image1 / 255.0) - 1.0 83 | image2 = 2 * (image2 / 255.0) - 1.0 84 | 85 | image1 = image1.contiguous() 86 | image2 = image2.contiguous() 87 | 88 | hdim = self.hidden_dim 89 | cdim = self.context_dim 90 | 91 | # run the feature network 92 | with autocast(): 93 | fmap1, fmap2 = self.fnet([image1, image2]) 94 | 95 | fmap1 = fmap1.float() 96 | fmap2 = fmap2.float() 97 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.config['corr_radius']) 98 | 99 | # run the context network 100 | with autocast(): 101 | cnet = self.cnet(image1) 102 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 103 | net = torch.tanh(net) 104 | inp = torch.relu(inp) 105 | 106 | coords0, coords1 = self.initialize_flow(image1) 107 | 108 | if flow_init is not None: 109 | coords1 = coords1 + flow_init 110 | 111 | flow_predictions = [] 112 | for itr in range(iters): 113 | coords1 = coords1.detach() 114 | corr = corr_fn(coords1) # index correlation volume 115 | 116 | flow = coords1 - coords0 117 | with autocast(): 118 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 119 | 120 | # F(t+1) = F(t) + \Delta(t) 121 | coords1 = coords1 + delta_flow 122 | 123 | # upsample predictions 124 | if upsample: 125 | if up_mask is None: 126 | flow_up = upflow8(coords1 - coords0) 127 | else: 128 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 129 | else: 130 | flow_up = coords1 - coords0 131 | 132 | flow_predictions.append(flow_up) 133 | 134 | if test_mode: 135 | return coords1 - coords0, flow_up 136 | 137 | return flow_predictions, net, inp 138 | -------------------------------------------------------------------------------- /submodules/RAFT/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, 
hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | 
class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /submodules/RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /submodules/RAFT/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
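    The wheel concatenates six hue transitions (RY, YG, GC, CB, BM, MR), giving 55 colors in total;
    flow_uv_to_colors indexes into it by flow angle and scales saturation by flow magnitude.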
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
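    Flow vectors are normalized by the maximum magnitude in the image (plus a small epsilon) before
    coloring, so the colors encode direction and relative, not absolute, magnitude.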
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /submodules/RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
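    Layout written below: float32 magic 202021.25 (TAG_CHAR), int32 width, int32 height,
    then u and v interleaved per pixel (u0, v0, u1, v1, ...) as float32, row by row.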
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /submodules/RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, 
y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij') 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0016.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0017.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0018.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0019.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0020.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0021.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0022.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0022.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0023.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0024.png -------------------------------------------------------------------------------- /submodules/RAFT/demo-frames/frame_0025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0025.png -------------------------------------------------------------------------------- /submodules/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from raft import RAFT 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo[0].permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | # import matplotlib.pyplot as plt 35 | # plt.imshow(img_flo / 255.0) 36 | # plt.show() 37 | 38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 39 | cv2.waitKey() 40 | 41 | 42 | def demo(args): 43 | model = torch.nn.DataParallel(RAFT(args)) 44 | model.load_state_dict(torch.load(args.model)) 45 | 46 | model = model.module 47 | model.to(DEVICE) 48 | model.eval() 49 | 50 | with torch.no_grad(): 51 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 52 | glob.glob(os.path.join(args.path, '*.jpg')) 53 | 54 | images = sorted(images) 55 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 56 | image1 = load_image(imfile1) 57 | image2 = load_image(imfile2) 58 | 59 | padder = InputPadder(image1.shape) 60 | image1, image2 = padder.pad(image1, image2) 61 | 62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 63 | viz(image1, flow_up) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--model', help="restore checkpoint") 69 | parser.add_argument('--path', help="dataset for evaluation") 70 | parser.add_argument('--small', action='store_true', help='use small model') 71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 73 | args = parser.parse_args() 74 | 75 | 
demo(args) 76 | -------------------------------------------------------------------------------- /submodules/RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /submodules/RAFT/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | 13 | import datasets 14 | from utils import flow_viz 15 | from utils import frame_utils 16 | 17 | from raft import RAFT 18 | from utils.utils import InputPadder, forward_interpolate 19 | 20 | 21 | @torch.no_grad() 22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'): 23 | """ Create submission for the Sintel leaderboard """ 24 | model.eval() 25 | for dstype in ['clean', 'final']: 26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) 27 | 28 | flow_prev, sequence_prev = None, None 29 | for test_id in range(len(test_dataset)): 30 | image1, image2, (sequence, frame) = test_dataset[test_id] 31 | if sequence != sequence_prev: 32 | flow_prev = None 33 | 34 | padder = InputPadder(image1.shape) 35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 36 | 37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) 38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 39 | 40 | if warm_start: 41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 42 | 43 | output_dir = os.path.join(output_path, dstype, sequence) 44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 45 | 46 | if not os.path.exists(output_dir): 47 | os.makedirs(output_dir) 48 | 49 | frame_utils.writeFlow(output_file, flow) 50 | sequence_prev = sequence 51 | 52 | 53 | @torch.no_grad() 54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'): 55 | """ Create submission for the Sintel leaderboard """ 56 | model.eval() 57 | test_dataset = datasets.KITTI(split='testing', aug_params=None) 58 | 59 | if not os.path.exists(output_path): 60 | os.makedirs(output_path) 61 | 62 | for test_id in range(len(test_dataset)): 63 | image1, image2, (frame_id, ) = test_dataset[test_id] 64 | padder = InputPadder(image1.shape, mode='kitti') 65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 66 | 67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 69 | 70 | output_filename = os.path.join(output_path, frame_id) 71 | frame_utils.writeFlowKITTI(output_filename, flow) 72 | 73 | 74 | @torch.no_grad() 75 | def validate_chairs(model, iters=24): 76 | """ Perform evaluation on the FlyingChairs (test) split """ 77 | model.eval() 78 | epe_list = [] 79 | 80 | val_dataset = datasets.FlyingChairs(split='validation') 81 | for val_id in range(len(val_dataset)): 82 | image1, image2, flow_gt, _ = val_dataset[val_id] 83 | image1 = image1[None].cuda() 84 | image2 = image2[None].cuda() 85 | 86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 88 | 
epe_list.append(epe.view(-1).numpy()) 89 | 90 | epe = np.mean(np.concatenate(epe_list)) 91 | print("Validation Chairs EPE: %f" % epe) 92 | return {'chairs': epe} 93 | 94 | 95 | @torch.no_grad() 96 | def validate_sintel(model, iters=32): 97 | """ Peform validation using the Sintel (train) split """ 98 | model.eval() 99 | results = {} 100 | for dstype in ['clean', 'final']: 101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype) 102 | epe_list = [] 103 | 104 | for val_id in range(len(val_dataset)): 105 | image1, image2, flow_gt, _ = val_dataset[val_id] 106 | image1 = image1[None].cuda() 107 | image2 = image2[None].cuda() 108 | 109 | padder = InputPadder(image1.shape) 110 | image1, image2 = padder.pad(image1, image2) 111 | 112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 113 | flow = padder.unpad(flow_pr[0]).cpu() 114 | 115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 116 | epe_list.append(epe.view(-1).numpy()) 117 | 118 | epe_all = np.concatenate(epe_list) 119 | epe = np.mean(epe_all) 120 | px1 = np.mean(epe_all<1) 121 | px3 = np.mean(epe_all<3) 122 | px5 = np.mean(epe_all<5) 123 | 124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 125 | results[dstype] = np.mean(epe_list) 126 | 127 | return results 128 | 129 | 130 | @torch.no_grad() 131 | def validate_kitti(model, iters=24): 132 | """ Peform validation using the KITTI-2015 (train) split """ 133 | model.eval() 134 | val_dataset = datasets.KITTI(split='training') 135 | 136 | out_list, epe_list = [], [] 137 | for val_id in range(len(val_dataset)): 138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 139 | image1 = image1[None].cuda() 140 | image2 = image2[None].cuda() 141 | 142 | padder = InputPadder(image1.shape, mode='kitti') 143 | image1, image2 = padder.pad(image1, image2) 144 | 145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 146 | flow = padder.unpad(flow_pr[0]).cpu() 147 | 148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 149 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 150 | 151 | epe = epe.view(-1) 152 | mag = mag.view(-1) 153 | val = valid_gt.view(-1) >= 0.5 154 | 155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 156 | epe_list.append(epe[val].mean().item()) 157 | out_list.append(out[val].cpu().numpy()) 158 | 159 | epe_list = np.array(epe_list) 160 | out_list = np.concatenate(out_list) 161 | 162 | epe = np.mean(epe_list) 163 | f1 = 100 * np.mean(out_list) 164 | 165 | print("Validation KITTI: %f, %f" % (epe, f1)) 166 | return {'kitti-epe': epe, 'kitti-f1': f1} 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--model', help="restore checkpoint") 172 | parser.add_argument('--dataset', help="dataset for evaluation") 173 | parser.add_argument('--small', action='store_true', help='use small model') 174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 175 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 176 | args = parser.parse_args() 177 | 178 | model = torch.nn.DataParallel(RAFT(args)) 179 | model.load_state_dict(torch.load(args.model)) 180 | 181 | model.cuda() 182 | model.eval() 183 | 184 | # create_sintel_submission(model.module, warm_start=True) 185 | # create_kitti_submission(model.module) 186 | 187 | with torch.no_grad(): 188 | if args.dataset == 'chairs': 189 | validate_chairs(model.module) 190 | 191 | elif args.dataset == 
'sintel': 192 | validate_sintel(model.module) 193 | 194 | elif args.dataset == 'kitti': 195 | validate_kitti(model.module) 196 | 197 | 198 | -------------------------------------------------------------------------------- /submodules/RAFT/pretrained/raft-things.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/pretrained/raft-things.pth -------------------------------------------------------------------------------- /submodules/RAFT/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import sys 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | 17 | from torch.utils.data import DataLoader 18 | from raft import RAFT 19 | import evaluate 20 | import datasets 21 | 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | try: 25 | from torch.cuda.amp import GradScaler 26 | except: 27 | # dummy GradScaler for PyTorch < 1.6 28 | class GradScaler: 29 | def __init__(self): 30 | pass 31 | def scale(self, loss): 32 | return loss 33 | def unscale_(self, optimizer): 34 | pass 35 | def step(self, optimizer): 36 | optimizer.step() 37 | def update(self): 38 | pass 39 | 40 | 41 | # exclude extremly large displacements 42 | MAX_FLOW = 400 43 | SUM_FREQ = 100 44 | VAL_FREQ = 5000 45 | 46 | 47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): 48 | """ Loss function defined over sequence of flow predictions """ 49 | 50 | n_predictions = len(flow_preds) 51 | flow_loss = 0.0 52 | 53 | # exlude invalid pixels and extremely large diplacements 54 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 55 | valid = (valid >= 0.5) & (mag < max_flow) 56 | 57 | for i in range(n_predictions): 58 | i_weight = gamma**(n_predictions - i - 1) 59 | i_loss = (flow_preds[i] - flow_gt).abs() 60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 61 | 62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 63 | epe = epe.view(-1)[valid.view(-1)] 64 | 65 | metrics = { 66 | 'epe': epe.mean().item(), 67 | '1px': (epe < 1).float().mean().item(), 68 | '3px': (epe < 3).float().mean().item(), 69 | '5px': (epe < 5).float().mean().item(), 70 | } 71 | 72 | return flow_loss, metrics 73 | 74 | 75 | def count_parameters(model): 76 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 77 | 78 | 79 | def fetch_optimizer(args, model): 80 | """ Create the optimizer and learning rate scheduler """ 81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 82 | 83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 85 | 86 | return optimizer, scheduler 87 | 88 | 89 | class Logger: 90 | def __init__(self, model, scheduler): 91 | self.model = model 92 | self.scheduler = scheduler 93 | self.total_steps = 0 94 | self.running_loss = {} 95 | self.writer = None 96 | 97 | def _print_training_status(self): 98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())] 99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, 
self.scheduler.get_last_lr()[0]) 100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 101 | 102 | # print the training status 103 | print(training_str + metrics_str) 104 | 105 | if self.writer is None: 106 | self.writer = SummaryWriter() 107 | 108 | for k in self.running_loss: 109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps) 110 | self.running_loss[k] = 0.0 111 | 112 | def push(self, metrics): 113 | self.total_steps += 1 114 | 115 | for key in metrics: 116 | if key not in self.running_loss: 117 | self.running_loss[key] = 0.0 118 | 119 | self.running_loss[key] += metrics[key] 120 | 121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 122 | self._print_training_status() 123 | self.running_loss = {} 124 | 125 | def write_dict(self, results): 126 | if self.writer is None: 127 | self.writer = SummaryWriter() 128 | 129 | for key in results: 130 | self.writer.add_scalar(key, results[key], self.total_steps) 131 | 132 | def close(self): 133 | self.writer.close() 134 | 135 | 136 | def train(args): 137 | 138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus) 139 | print("Parameter Count: %d" % count_parameters(model)) 140 | 141 | if args.restore_ckpt is not None: 142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False) 143 | 144 | model.cuda() 145 | model.train() 146 | 147 | if args.stage != 'chairs': 148 | model.module.freeze_bn() 149 | 150 | train_loader = datasets.fetch_dataloader(args) 151 | optimizer, scheduler = fetch_optimizer(args, model) 152 | 153 | total_steps = 0 154 | scaler = GradScaler(enabled=args.mixed_precision) 155 | logger = Logger(model, scheduler) 156 | 157 | VAL_FREQ = 5000 158 | add_noise = True 159 | 160 | should_keep_training = True 161 | while should_keep_training: 162 | 163 | for i_batch, data_blob in enumerate(train_loader): 164 | optimizer.zero_grad() 165 | image1, image2, flow, valid = [x.cuda() for x in data_blob] 166 | 167 | if args.add_noise: 168 | stdv = np.random.uniform(0.0, 5.0) 169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0) 170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0) 171 | 172 | flow_predictions = model(image1, image2, iters=args.iters) 173 | 174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma) 175 | scaler.scale(loss).backward() 176 | scaler.unscale_(optimizer) 177 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 178 | 179 | scaler.step(optimizer) 180 | scheduler.step() 181 | scaler.update() 182 | 183 | logger.push(metrics) 184 | 185 | if total_steps % VAL_FREQ == VAL_FREQ - 1: 186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name) 187 | torch.save(model.state_dict(), PATH) 188 | 189 | results = {} 190 | for val_dataset in args.validation: 191 | if val_dataset == 'chairs': 192 | results.update(evaluate.validate_chairs(model.module)) 193 | elif val_dataset == 'sintel': 194 | results.update(evaluate.validate_sintel(model.module)) 195 | elif val_dataset == 'kitti': 196 | results.update(evaluate.validate_kitti(model.module)) 197 | 198 | logger.write_dict(results) 199 | 200 | model.train() 201 | if args.stage != 'chairs': 202 | model.module.freeze_bn() 203 | 204 | total_steps += 1 205 | 206 | if total_steps > args.num_steps: 207 | should_keep_training = False 208 | break 209 | 210 | logger.close() 211 | PATH = 'checkpoints/%s.pth' % args.name 212 | torch.save(model.state_dict(), PATH) 213 | 214 | return PATH 215 | 216 | 217 | if __name__ == '__main__': 218 | 
parser = argparse.ArgumentParser() 219 | parser.add_argument('--name', default='raft', help="name your experiment") 220 | parser.add_argument('--stage', help="determines which dataset to use for training") 221 | parser.add_argument('--restore_ckpt', help="restore checkpoint") 222 | parser.add_argument('--small', action='store_true', help='use small model') 223 | parser.add_argument('--validation', type=str, nargs='+') 224 | 225 | parser.add_argument('--lr', type=float, default=0.00002) 226 | parser.add_argument('--num_steps', type=int, default=100000) 227 | parser.add_argument('--batch_size', type=int, default=6) 228 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512]) 229 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1]) 230 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 231 | 232 | parser.add_argument('--iters', type=int, default=12) 233 | parser.add_argument('--wdecay', type=float, default=.00005) 234 | parser.add_argument('--epsilon', type=float, default=1e-8) 235 | parser.add_argument('--clip', type=float, default=1.0) 236 | parser.add_argument('--dropout', type=float, default=0.0) 237 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') 238 | parser.add_argument('--add_noise', action='store_true') 239 | args = parser.parse_args() 240 | 241 | torch.manual_seed(1234) 242 | np.random.seed(1234) 243 | 244 | if not os.path.isdir('checkpoints'): 245 | os.mkdir('checkpoints') 246 | 247 | train(args) -------------------------------------------------------------------------------- /submodules/RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision 7 | -------------------------------------------------------------------------------- /submodules/RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 6 | 
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 7 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | diff_gaussian_rasterization.egg-info/ 3 | dist/ 4 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/glm"] 2 | path = third_party/glm 3 | url = https://github.com/g-truc/glm.git 4 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(DiffRast LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(CudaRasterizer 23 | cuda_rasterizer/backward.h 24 | cuda_rasterizer/backward.cu 25 | cuda_rasterizer/forward.h 26 | cuda_rasterizer/forward.cu 27 | cuda_rasterizer/auxiliary.h 28 | cuda_rasterizer/rasterizer_impl.cu 29 | cuda_rasterizer/rasterizer_impl.h 30 | cuda_rasterizer/rasterizer.h 31 | ) 32 | 33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") 34 | 35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) 36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. 
Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 
84 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/README.md: -------------------------------------------------------------------------------- 1 | # Differential Gaussian Rasterization 2 | 3 | Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us. 4 | 5 |
6 | 
7 | ## BibTeX
8 | @Article{kerbl3Dgaussians,
9 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
10 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
11 |       journal      = {ACM Transactions on Graphics},
12 |       number       = {4},
13 |       volume       = {42},
14 |       month        = {July},
15 |       year         = {2023},
16 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
17 | }
18 | 
19 | 
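Editor's note: the README above describes this submodule only as the rasterization engine, so the following is a minimal, hypothetical usage sketch of the Python wrapper it ships (`GaussianRasterizationSettings` and `GaussianRasterizer` from `diff_gaussian_rasterization/__init__.py`, listed later in this dump). All tensor values here, the point count `N`, the field of view, and the identity camera matrices are placeholders invented for illustration; a real caller would supply its own camera model, and the compiled CUDA extension plus a GPU are required. In this depth-enabled fork the rasterizer returns a rendered image, per-Gaussian screen radii, and a depth map.

```python
# Hypothetical sketch (not part of the upstream README): shows only the call signature.
import math
import torch
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer

device = "cuda"
N = 1000  # placeholder number of Gaussians

means3D   = torch.rand(N, 3, device=device)                    # 3D centers
means2D   = torch.zeros(N, 3, device=device, requires_grad=True)  # receives screen-space gradients
opacities = torch.sigmoid(torch.randn(N, 1, device=device))
scales    = torch.rand(N, 3, device=device) * 0.01
rotations = torch.nn.functional.normalize(torch.randn(N, 4, device=device))  # unit quaternions
colors    = torch.rand(N, 3, device=device)                    # precomputed RGB instead of SH coefficients

fov = math.pi / 3  # placeholder field of view
settings = GaussianRasterizationSettings(
    image_height=512, image_width=512,
    tanfovx=math.tan(fov * 0.5), tanfovy=math.tan(fov * 0.5),
    bg=torch.zeros(3, device=device),
    scale_modifier=1.0,
    viewmatrix=torch.eye(4, device=device),   # placeholder world-to-view matrix
    projmatrix=torch.eye(4, device=device),   # placeholder full projection matrix
    sh_degree=0,
    campos=torch.zeros(3, device=device),
    prefiltered=False,
    debug=False,
)

rasterizer = GaussianRasterizer(raster_settings=settings)
# Exactly one of shs / colors_precomp and one of (scales, rotations) / cov3D_precomp may be given.
# With these placeholder matrices most points are frustum-culled; the call is only illustrative.
color, radii, depth = rasterizer(
    means3D=means3D, means2D=means2D, opacities=opacities,
    colors_precomp=colors, scales=scales, rotations=rotations)
```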
-------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 14 | 15 | #include "config.h" 16 | #include "stdio.h" 17 | 18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) 19 | #define NUM_WARPS (BLOCK_SIZE/32) 20 | 21 | // Spherical harmonics coefficients 22 | __device__ const float SH_C0 = 0.28209479177387814f; 23 | __device__ const float SH_C1 = 0.4886025119029199f; 24 | __device__ const float SH_C2[] = { 25 | 1.0925484305920792f, 26 | -1.0925484305920792f, 27 | 0.31539156525252005f, 28 | -1.0925484305920792f, 29 | 0.5462742152960396f 30 | }; 31 | __device__ const float SH_C3[] = { 32 | -0.5900435899266435f, 33 | 2.890611442640554f, 34 | -0.4570457994644658f, 35 | 0.3731763325901154f, 36 | -0.4570457994644658f, 37 | 1.445305721320277f, 38 | -0.5900435899266435f 39 | }; 40 | 41 | __forceinline__ __device__ float ndc2Pix(float v, int S) 42 | { 43 | return ((v + 1.0) * S - 1.0) * 0.5; 44 | } 45 | 46 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) 47 | { 48 | rect_min = { 49 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), 50 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) 51 | }; 52 | rect_max = { 53 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), 54 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) 55 | }; 56 | } 57 | 58 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) 59 | { 60 | float3 transformed = { 61 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 62 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 63 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 64 | }; 65 | return transformed; 66 | } 67 | 68 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) 69 | { 70 | float4 transformed = { 71 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 72 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 73 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 74 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] 75 | }; 76 | return transformed; 77 | } 78 | 79 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) 80 | { 81 | float3 transformed = { 82 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, 83 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, 84 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, 85 | }; 86 | return transformed; 87 | } 88 | 89 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) 90 | { 91 | float3 transformed = { 92 | matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, 93 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, 94 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, 95 | }; 96 | return transformed; 97 | } 98 | 
99 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) 100 | { 101 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 102 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 103 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 104 | return dnormvdz; 105 | } 106 | 107 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) 108 | { 109 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 110 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 111 | 112 | float3 dnormvdv; 113 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; 114 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; 115 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 116 | return dnormvdv; 117 | } 118 | 119 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) 120 | { 121 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; 122 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 123 | 124 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; 125 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; 126 | float4 dnormvdv; 127 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; 128 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; 129 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; 130 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; 131 | return dnormvdv; 132 | } 133 | 134 | __forceinline__ __device__ float sigmoid(float x) 135 | { 136 | return 1.0f / (1.0f + expf(-x)); 137 | } 138 | 139 | __forceinline__ __device__ bool in_frustum(int idx, 140 | const float* orig_points, 141 | const float* viewmatrix, 142 | const float* projmatrix, 143 | bool prefiltered, 144 | float3& p_view) 145 | { 146 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 147 | 148 | // Bring points to screen space 149 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 150 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 151 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 152 | p_view = transformPoint4x3(p_orig, viewmatrix); 153 | 154 | if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) 155 | { 156 | if (prefiltered) 157 | { 158 | printf("Point is filtered although prefiltered is set. This shouldn't happen!"); 159 | __trap(); 160 | } 161 | return false; 162 | } 163 | return true; 164 | } 165 | 166 | #define CHECK_CUDA(A, debug) \ 167 | A; if(debug) { \ 168 | auto ret = cudaDeviceSynchronize(); \ 169 | if (ret != cudaSuccess) { \ 170 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 171 | throw std::runtime_error(cudaGetErrorString(ret)); \ 172 | } \ 173 | } 174 | 175 | #endif -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, 28 | const float* bg_color, 29 | const float2* means2D, 30 | const float4* conic_opacity, 31 | const float* colors, 32 | const float* depths, 33 | const float* final_Ts, 34 | const uint32_t* n_contrib, 35 | const float* dL_dpixels, 36 | const float* dL_depths, 37 | float3* dL_dmean2D, 38 | float4* dL_dconic2D, 39 | float* dL_dopacity, 40 | float* dL_dcolors); 41 | 42 | void preprocess( 43 | int P, int D, int M, 44 | const float3* means, 45 | const int* radii, 46 | const float* shs, 47 | const bool* clamped, 48 | const glm::vec3* scales, 49 | const glm::vec4* rotations, 50 | const float scale_modifier, 51 | const float* cov3Ds, 52 | const float* view, 53 | const float* proj, 54 | const float focal_x, float focal_y, 55 | const float tan_fovx, float tan_fovy, 56 | const glm::vec3* campos, 57 | const float3* dL_dmean2D, 58 | const float* dL_dconics, 59 | glm::vec3* dL_dmeans, 60 | float* dL_dcolor, 61 | float* dL_dcov3D, 62 | float* dL_dsh, 63 | glm::vec3* dL_dscale, 64 | glm::vec4* dL_drot); 65 | } 66 | 67 | #endif 68 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 
24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | const glm::vec3* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* cov3D_precomp, 33 | const float* colors_precomp, 34 | const float* viewmatrix, 35 | const float* projmatrix, 36 | const glm::vec3* cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int* radii, 41 | float2* points_xy_image, 42 | float* depths, 43 | float* cov3Ds, 44 | float* colors, 45 | float4* conic_opacity, 46 | const dim3 grid, 47 | uint32_t* tiles_touched, 48 | bool prefiltered); 49 | 50 | // Main rasterization method. 51 | void render( 52 | const dim3 grid, dim3 block, 53 | const uint2* ranges, 54 | const uint32_t* point_list, 55 | int W, int H, 56 | const float2* points_xy_image, 57 | const float* features, 58 | const float* depths, 59 | const float4* conic_opacity, 60 | float* final_T, 61 | uint32_t* n_contrib, 62 | const float* bg_color, 63 | float* out_color, 64 | float* out_depth); 65 | } 66 | 67 | 68 | #endif 69 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present); 30 | 31 | static int forward( 32 | std::function geometryBuffer, 33 | std::function binningBuffer, 34 | std::function imageBuffer, 35 | const int P, int D, int M, 36 | const float* background, 37 | const int width, int height, 38 | const float* means3D, 39 | const float* shs, 40 | const float* colors_precomp, 41 | const float* opacities, 42 | const float* scales, 43 | const float scale_modifier, 44 | const float* rotations, 45 | const float* cov3D_precomp, 46 | const float* viewmatrix, 47 | const float* projmatrix, 48 | const float* cam_pos, 49 | const float tan_fovx, float tan_fovy, 50 | const bool prefiltered, 51 | float* out_color, 52 | float* out_depth, 53 | int* radii = nullptr, 54 | bool debug = false); 55 | 56 | static void backward( 57 | const int P, int D, int M, int R, 58 | const float* background, 59 | const int width, int height, 60 | const float* means3D, 61 | const float* shs, 62 | const float* colors_precomp, 63 | const float* scales, 64 | const float scale_modifier, 65 | const float* rotations, 66 | const float* cov3D_precomp, 67 | const float* viewmatrix, 68 | const float* projmatrix, 69 | const float* campos, 70 | const float tan_fovx, float tan_fovy, 71 | const int* radii, 72 | char* geom_buffer, 73 | char* binning_buffer, 74 | char* image_buffer, 75 | const float* dL_dpix, 76 | const float* dL_depths, 77 | float* dL_dmean2D, 78 | float* dL_dconic, 79 | float* dL_dopacity, 80 | float* dL_dcolor, 81 | float* dL_dmean3D, 82 | float* 
dL_dcov3D, 83 | float* dL_dsh, 84 | float* dL_dscale, 85 | float* dL_drot, 86 | bool debug); 87 | }; 88 | }; 89 | 90 | #endif 91 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | int* internal_radii; 36 | float2* means2D; 37 | float* cov3D; 38 | float4* conic_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | float* accum_alpha; 51 | 52 | static ImageState fromChunk(char*& chunk, size_t N); 53 | }; 54 | 55 | struct BinningState 56 | { 57 | size_t sorting_size; 58 | uint64_t* point_list_keys_unsorted; 59 | uint64_t* point_list_keys; 60 | uint32_t* point_list_unsorted; 61 | uint32_t* point_list; 62 | char* list_sorting_space; 63 | 64 | static BinningState fromChunk(char*& chunk, size_t P); 65 | }; 66 | 67 | template 68 | size_t required(size_t P) 69 | { 70 | char* size = nullptr; 71 | T::fromChunk(size, P); 72 | return ((size_t)size) + 128; 73 | } 74 | }; -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from typing import NamedTuple 13 | import torch.nn as nn 14 | import torch 15 | from . import _C 16 | 17 | def cpu_deep_copy_tuple(input_tuple): 18 | copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] 19 | return tuple(copied_tensors) 20 | 21 | def rasterize_gaussians( 22 | means3D, 23 | means2D, 24 | sh, 25 | colors_precomp, 26 | opacities, 27 | scales, 28 | rotations, 29 | cov3Ds_precomp, 30 | raster_settings, 31 | ): 32 | return _RasterizeGaussians.apply( 33 | means3D, 34 | means2D, 35 | sh, 36 | colors_precomp, 37 | opacities, 38 | scales, 39 | rotations, 40 | cov3Ds_precomp, 41 | raster_settings, 42 | ) 43 | 44 | class _RasterizeGaussians(torch.autograd.Function): 45 | @staticmethod 46 | def forward( 47 | ctx, 48 | means3D, 49 | means2D, 50 | sh, 51 | colors_precomp, 52 | opacities, 53 | scales, 54 | rotations, 55 | cov3Ds_precomp, 56 | raster_settings, 57 | ): 58 | 59 | # Restructure arguments the way that the C++ lib expects them 60 | args = ( 61 | raster_settings.bg, 62 | means3D, 63 | colors_precomp, 64 | opacities, 65 | scales, 66 | rotations, 67 | raster_settings.scale_modifier, 68 | cov3Ds_precomp, 69 | raster_settings.viewmatrix, 70 | raster_settings.projmatrix, 71 | raster_settings.tanfovx, 72 | raster_settings.tanfovy, 73 | raster_settings.image_height, 74 | raster_settings.image_width, 75 | sh, 76 | raster_settings.sh_degree, 77 | raster_settings.campos, 78 | raster_settings.prefiltered, 79 | raster_settings.debug 80 | ) 81 | 82 | # Invoke C++/CUDA rasterizer 83 | if raster_settings.debug: 84 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 85 | try: 86 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 87 | except Exception as ex: 88 | torch.save(cpu_args, "snapshot_fw.dump") 89 | print("\nAn error occured in forward. 
Please forward snapshot_fw.dump for debugging.") 90 | raise ex 91 | else: 92 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 93 | 94 | # Keep relevant tensors for backward 95 | ctx.raster_settings = raster_settings 96 | ctx.num_rendered = num_rendered 97 | ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer) 98 | return color, radii, depth 99 | 100 | @staticmethod 101 | def backward(ctx, grad_out_color, grad_radii, grad_depth): 102 | 103 | # Restore necessary values from context 104 | num_rendered = ctx.num_rendered 105 | raster_settings = ctx.raster_settings 106 | colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors 107 | 108 | # Restructure args as C++ method expects them 109 | args = (raster_settings.bg, 110 | means3D, 111 | radii, 112 | colors_precomp, 113 | scales, 114 | rotations, 115 | raster_settings.scale_modifier, 116 | cov3Ds_precomp, 117 | raster_settings.viewmatrix, 118 | raster_settings.projmatrix, 119 | raster_settings.tanfovx, 120 | raster_settings.tanfovy, 121 | grad_out_color, 122 | grad_depth, 123 | sh, 124 | raster_settings.sh_degree, 125 | raster_settings.campos, 126 | geomBuffer, 127 | num_rendered, 128 | binningBuffer, 129 | imgBuffer, 130 | raster_settings.debug) 131 | 132 | # Compute gradients for relevant tensors by invoking backward method 133 | if raster_settings.debug: 134 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 135 | try: 136 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) 137 | except Exception as ex: 138 | torch.save(cpu_args, "snapshot_bw.dump") 139 | print("\nAn error occured in backward. 
Writing snapshot_bw.dump for debugging.\n") 140 | raise ex 141 | else: 142 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) 143 | 144 | grads = ( 145 | grad_means3D, 146 | grad_means2D, 147 | grad_sh, 148 | grad_colors_precomp, 149 | grad_opacities, 150 | grad_scales, 151 | grad_rotations, 152 | grad_cov3Ds_precomp, 153 | None, 154 | ) 155 | 156 | return grads 157 | 158 | class GaussianRasterizationSettings(NamedTuple): 159 | image_height: int 160 | image_width: int 161 | tanfovx : float 162 | tanfovy : float 163 | bg : torch.Tensor 164 | scale_modifier : float 165 | viewmatrix : torch.Tensor 166 | projmatrix : torch.Tensor 167 | sh_degree : int 168 | campos : torch.Tensor 169 | prefiltered : bool 170 | debug : bool 171 | 172 | class GaussianRasterizer(nn.Module): 173 | def __init__(self, raster_settings): 174 | super().__init__() 175 | self.raster_settings = raster_settings 176 | 177 | def markVisible(self, positions): 178 | # Mark visible points (based on frustum culling for camera) with a boolean 179 | with torch.no_grad(): 180 | raster_settings = self.raster_settings 181 | visible = _C.mark_visible( 182 | positions, 183 | raster_settings.viewmatrix, 184 | raster_settings.projmatrix) 185 | 186 | return visible 187 | 188 | def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): 189 | 190 | raster_settings = self.raster_settings 191 | 192 | if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): 193 | raise Exception('Please provide excatly one of either SHs or precomputed colors!') 194 | 195 | if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None): 196 | raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!') 197 | 198 | if shs is None: 199 | shs = torch.Tensor([]) 200 | if colors_precomp is None: 201 | colors_precomp = torch.Tensor([]) 202 | 203 | if scales is None: 204 | scales = torch.Tensor([]) 205 | if rotations is None: 206 | rotations = torch.Tensor([]) 207 | if cov3D_precomp is None: 208 | cov3D_precomp = torch.Tensor([]) 209 | 210 | # Invoke C++/CUDA rasterization routine 211 | return rasterize_gaussians( 212 | means3D, 213 | means2D, 214 | shs, 215 | colors_precomp, 216 | opacities, 217 | scales, 218 | rotations, 219 | cov3D_precomp, 220 | raster_settings, 221 | ) 222 | 223 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 
5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 18 | m.def("mark_visible", &markVisible); 19 | } -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/rasterize_points.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include "cuda_rasterizer/config.h" 22 | #include "cuda_rasterizer/rasterizer.h" 23 | #include 24 | #include 25 | #include 26 | 27 | std::function resizeFunctional(torch::Tensor& t) { 28 | auto lambda = [&t](size_t N) { 29 | t.resize_({(long long)N}); 30 | return reinterpret_cast(t.contiguous().data_ptr()); 31 | }; 32 | return lambda; 33 | } 34 | 35 | std::tuple 36 | RasterizeGaussiansCUDA( 37 | const torch::Tensor& background, 38 | const torch::Tensor& means3D, 39 | const torch::Tensor& colors, 40 | const torch::Tensor& opacity, 41 | const torch::Tensor& scales, 42 | const torch::Tensor& rotations, 43 | const float scale_modifier, 44 | const torch::Tensor& cov3D_precomp, 45 | const torch::Tensor& viewmatrix, 46 | const torch::Tensor& projmatrix, 47 | const float tan_fovx, 48 | const float tan_fovy, 49 | const int image_height, 50 | const int image_width, 51 | const torch::Tensor& sh, 52 | const int degree, 53 | const torch::Tensor& campos, 54 | const bool prefiltered, 55 | const bool debug) 56 | { 57 | if (means3D.ndimension() != 2 || means3D.size(1) != 3) { 58 | AT_ERROR("means3D must have dimensions (num_points, 3)"); 59 | } 60 | 61 | const int P = means3D.size(0); 62 | const int H = image_height; 63 | const int W = image_width; 64 | 65 | auto int_opts = means3D.options().dtype(torch::kInt32); 66 | auto float_opts = means3D.options().dtype(torch::kFloat32); 67 | 68 | torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); 69 | torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); 70 | torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); 71 | 72 | torch::Device device(torch::kCUDA); 73 | torch::TensorOptions options(torch::kByte); 74 | torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); 75 | torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); 76 | torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); 77 | std::function geomFunc = resizeFunctional(geomBuffer); 78 | std::function binningFunc = resizeFunctional(binningBuffer); 79 | std::function imgFunc = resizeFunctional(imgBuffer); 80 | 81 | int rendered = 0; 82 | if(P != 0) 83 | { 84 | int M = 0; 85 | if(sh.size(0) != 0) 86 | { 87 | M = sh.size(1); 88 | } 89 | 90 | rendered = CudaRasterizer::Rasterizer::forward( 91 | geomFunc, 92 | binningFunc, 93 | imgFunc, 94 | 
P, degree, M, 95 | background.contiguous().data(), 96 | W, H, 97 | means3D.contiguous().data(), 98 | sh.contiguous().data_ptr(), 99 | colors.contiguous().data(), 100 | opacity.contiguous().data(), 101 | scales.contiguous().data_ptr(), 102 | scale_modifier, 103 | rotations.contiguous().data_ptr(), 104 | cov3D_precomp.contiguous().data(), 105 | viewmatrix.contiguous().data(), 106 | projmatrix.contiguous().data(), 107 | campos.contiguous().data(), 108 | tan_fovx, 109 | tan_fovy, 110 | prefiltered, 111 | out_color.contiguous().data(), 112 | out_depth.contiguous().data(), 113 | radii.contiguous().data(), 114 | debug); 115 | } 116 | return std::make_tuple(rendered, out_color, out_depth, radii, geomBuffer, binningBuffer, imgBuffer); 117 | } 118 | 119 | std::tuple 120 | RasterizeGaussiansBackwardCUDA( 121 | const torch::Tensor& background, 122 | const torch::Tensor& means3D, 123 | const torch::Tensor& radii, 124 | const torch::Tensor& colors, 125 | const torch::Tensor& scales, 126 | const torch::Tensor& rotations, 127 | const float scale_modifier, 128 | const torch::Tensor& cov3D_precomp, 129 | const torch::Tensor& viewmatrix, 130 | const torch::Tensor& projmatrix, 131 | const float tan_fovx, 132 | const float tan_fovy, 133 | const torch::Tensor& dL_dout_color, 134 | const torch::Tensor& dL_dout_depth, 135 | const torch::Tensor& sh, 136 | const int degree, 137 | const torch::Tensor& campos, 138 | const torch::Tensor& geomBuffer, 139 | const int R, 140 | const torch::Tensor& binningBuffer, 141 | const torch::Tensor& imageBuffer, 142 | const bool debug) 143 | { 144 | const int P = means3D.size(0); 145 | const int H = dL_dout_color.size(1); 146 | const int W = dL_dout_color.size(2); 147 | 148 | int M = 0; 149 | if(sh.size(0) != 0) 150 | { 151 | M = sh.size(1); 152 | } 153 | 154 | torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); 155 | torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); 156 | torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); 157 | torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); 158 | torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); 159 | torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); 160 | torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); 161 | torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); 162 | torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); 163 | 164 | if(P != 0) 165 | { 166 | CudaRasterizer::Rasterizer::backward(P, degree, M, R, 167 | background.contiguous().data(), 168 | W, H, 169 | means3D.contiguous().data(), 170 | sh.contiguous().data(), 171 | colors.contiguous().data(), 172 | scales.data_ptr(), 173 | scale_modifier, 174 | rotations.data_ptr(), 175 | cov3D_precomp.contiguous().data(), 176 | viewmatrix.contiguous().data(), 177 | projmatrix.contiguous().data(), 178 | campos.contiguous().data(), 179 | tan_fovx, 180 | tan_fovy, 181 | radii.contiguous().data(), 182 | reinterpret_cast(geomBuffer.contiguous().data_ptr()), 183 | reinterpret_cast(binningBuffer.contiguous().data_ptr()), 184 | reinterpret_cast(imageBuffer.contiguous().data_ptr()), 185 | dL_dout_color.contiguous().data(), 186 | dL_dout_depth.contiguous().data(), 187 | dL_dmeans2D.contiguous().data(), 188 | dL_dconic.contiguous().data(), 189 | dL_dopacity.contiguous().data(), 190 | dL_dcolors.contiguous().data(), 191 | dL_dmeans3D.contiguous().data(), 192 | dL_dcov3D.contiguous().data(), 193 | dL_dsh.contiguous().data(), 194 | 
dL_dscales.contiguous().data(), 195 | dL_drotations.contiguous().data(), 196 | debug); 197 | } 198 | 199 | return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations); 200 | } 201 | 202 | torch::Tensor markVisible( 203 | torch::Tensor& means3D, 204 | torch::Tensor& viewmatrix, 205 | torch::Tensor& projmatrix) 206 | { 207 | const int P = means3D.size(0); 208 | 209 | torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); 210 | 211 | if(P != 0) 212 | { 213 | CudaRasterizer::Rasterizer::markVisible(P, 214 | means3D.contiguous().data(), 215 | viewmatrix.contiguous().data(), 216 | projmatrix.contiguous().data(), 217 | present.contiguous().data()); 218 | } 219 | 220 | return present; 221 | } 222 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | std::tuple 19 | RasterizeGaussiansCUDA( 20 | const torch::Tensor& background, 21 | const torch::Tensor& means3D, 22 | const torch::Tensor& colors, 23 | const torch::Tensor& opacity, 24 | const torch::Tensor& scales, 25 | const torch::Tensor& rotations, 26 | const float scale_modifier, 27 | const torch::Tensor& cov3D_precomp, 28 | const torch::Tensor& viewmatrix, 29 | const torch::Tensor& projmatrix, 30 | const float tan_fovx, 31 | const float tan_fovy, 32 | const int image_height, 33 | const int image_width, 34 | const torch::Tensor& sh, 35 | const int degree, 36 | const torch::Tensor& campos, 37 | const bool prefiltered, 38 | const bool debug); 39 | 40 | std::tuple 41 | RasterizeGaussiansBackwardCUDA( 42 | const torch::Tensor& background, 43 | const torch::Tensor& means3D, 44 | const torch::Tensor& radii, 45 | const torch::Tensor& colors, 46 | const torch::Tensor& scales, 47 | const torch::Tensor& rotations, 48 | const float scale_modifier, 49 | const torch::Tensor& cov3D_precomp, 50 | const torch::Tensor& viewmatrix, 51 | const torch::Tensor& projmatrix, 52 | const float tan_fovx, 53 | const float tan_fovy, 54 | const torch::Tensor& dL_dout_color, 55 | const torch::Tensor& dL_dout_depth, 56 | const torch::Tensor& sh, 57 | const int degree, 58 | const torch::Tensor& campos, 59 | const torch::Tensor& geomBuffer, 60 | const int R, 61 | const torch::Tensor& binningBuffer, 62 | const torch::Tensor& imageBuffer, 63 | const bool debug); 64 | 65 | torch::Tensor markVisible( 66 | torch::Tensor& means3D, 67 | torch::Tensor& viewmatrix, 68 | torch::Tensor& projmatrix); 69 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | os.path.dirname(os.path.abspath(__file__)) 16 | 17 | setup( 18 | name="diff_gaussian_rasterization", 19 | packages=['diff_gaussian_rasterization'], 20 | ext_modules=[ 21 | CUDAExtension( 22 | name="diff_gaussian_rasterization._C", 23 | sources=[ 24 | "cuda_rasterizer/rasterizer_impl.cu", 25 | "cuda_rasterizer/forward.cu", 26 | "cuda_rasterizer/backward.cu", 27 | "rasterize_points.cu", 28 | "ext.cpp"], 29 | extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 30 | ], 31 | cmdclass={ 32 | 'build_ext': BuildExtension 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /submodules/simple-knn/build/lib.linux-x86_64-cpython-37/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/lib.linux-x86_64-cpython-37/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /submodules/simple-knn/build/temp.linux-x86_64-cpython-37/ext.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/ext.o -------------------------------------------------------------------------------- /submodules/simple-knn/build/temp.linux-x86_64-cpython-37/simple_knn.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/simple_knn.o -------------------------------------------------------------------------------- /submodules/simple-knn/build/temp.linux-x86_64-cpython-37/spatial.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/spatial.o -------------------------------------------------------------------------------- /submodules/simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /submodules/simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
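# (Illustrative build note, not taken from this file: the two CUDA extensions under
# submodules/ are typically installed into the active environment before training,
# e.g. `pip install ./submodules/depth-diff-gaussian-rasterization ./submodules/simple-knn`
# run from the repository root with a CUDA-enabled PyTorch; see the project README
# for the authoritative command.)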
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == 'nt': 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=[ 27 | "spatial.cu", 28 | "simple_knn.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #define __CUDACC__ 24 | #include 25 | #include 26 | 27 | namespace cg = cooperative_groups; 28 | 29 | struct CustomMin 30 | { 31 | __device__ __forceinline__ 32 | float3 operator()(const float3& a, const float3& b) const { 33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 34 | } 35 | }; 36 | 37 | struct CustomMax 38 | { 39 | __device__ __forceinline__ 40 | float3 operator()(const float3& a, const float3& b) const { 41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 42 | } 43 | }; 44 | 45 | __host__ __device__ uint32_t prepMorton(uint32_t x) 46 | { 47 | x = (x | (x << 16)) & 0x030000FF; 48 | x = (x | (x << 8)) & 0x0300F00F; 49 | x = (x | (x << 4)) & 0x030C30C3; 50 | x = (x | (x << 2)) & 0x09249249; 51 | return x; 52 | } 53 | 54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 55 | { 56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 59 | 60 | return x | (y << 1) | (z << 2); 61 | } 62 | 63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 64 | { 65 | auto idx = cg::this_grid().thread_rank(); 66 | if (idx >= P) 67 | return; 68 | 69 | codes[idx] = coord2Morton(points[idx], minn, maxx); 70 | } 71 | 72 | struct MinMax 73 | { 74 | float3 minn; 75 | float3 maxx; 76 | }; 77 | 78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | 82 | MinMax me; 83 | if (idx < P) 84 | { 85 | me.minn = points[indices[idx]]; 86 | me.maxx = points[indices[idx]]; 87 | } 88 | else 89 | { 90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 92 | } 93 | 94 | __shared__ MinMax redResult[BOX_SIZE]; 95 | 96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 97 | { 98 | if (threadIdx.x < 2 * 
off) 99 | redResult[threadIdx.x] = me; 100 | __syncthreads(); 101 | 102 | if (threadIdx.x < off) 103 | { 104 | MinMax other = redResult[threadIdx.x + off]; 105 | me.minn.x = min(me.minn.x, other.minn.x); 106 | me.minn.y = min(me.minn.y, other.minn.y); 107 | me.minn.z = min(me.minn.z, other.minn.z); 108 | me.maxx.x = max(me.maxx.x, other.maxx.x); 109 | me.maxx.y = max(me.maxx.y, other.maxx.y); 110 | me.maxx.z = max(me.maxx.z, other.maxx.z); 111 | } 112 | __syncthreads(); 113 | } 114 | 115 | if (threadIdx.x == 0) 116 | boxes[blockIdx.x] = me; 117 | } 118 | 119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 120 | { 121 | float3 diff = { 0, 0, 0 }; 122 | if (p.x < box.minn.x || p.x > box.maxx.x) 123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 124 | if (p.y < box.minn.y || p.y > box.maxx.y) 125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 126 | if (p.z < box.minn.z || p.z > box.maxx.z) 127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 129 | } 130 | 131 | template 132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 133 | { 134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 136 | for (int j = 0; j < K; j++) 137 | { 138 | if (knn[j] > dist) 139 | { 140 | float t = knn[j]; 141 | knn[j] = dist; 142 | dist = t; 143 | } 144 | } 145 | } 146 | 147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 148 | { 149 | int idx = cg::this_grid().thread_rank(); 150 | if (idx >= P) 151 | return; 152 | 153 | float3 point = points[indices[idx]]; 154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 155 | 156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 157 | { 158 | if (i == idx) 159 | continue; 160 | updateKBest<3>(point, points[indices[i]], best); 161 | } 162 | 163 | float reject = best[2]; 164 | best[0] = FLT_MAX; 165 | best[1] = FLT_MAX; 166 | best[2] = FLT_MAX; 167 | 168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 169 | { 170 | MinMax box = boxes[b]; 171 | float dist = distBoxPoint(box, point); 172 | if (dist > reject || dist > best[2]) 173 | continue; 174 | 175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 176 | { 177 | if (i == idx) 178 | continue; 179 | updateKBest<3>(point, points[indices[i]], best); 180 | } 181 | } 182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 183 | } 184 | 185 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 186 | { 187 | float3* result; 188 | cudaMalloc(&result, sizeof(float3)); 189 | size_t temp_storage_bytes; 190 | 191 | float3 init = { 0, 0, 0 }, minn, maxx; 192 | 193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 194 | thrust::device_vector temp_storage(temp_storage_bytes); 195 | 196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 198 | 199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 201 | 202 | thrust::device_vector morton(P); 203 | thrust::device_vector morton_sorted(P); 204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 205 
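	// Summary of the remaining steps in this routine (descriptive only): the point
	// indices are radix-sorted by Morton code so that spatially nearby points fall
	// into the same BOX_SIZE-sized chunk, boxMinMax reduces each chunk to an
	// axis-aligned bounding box, and boxMeanDist writes, for every point, the mean
	// of its 3 smallest squared neighbour distances, skipping chunks whose box is
	// already farther away than the current candidates.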
| 206 | thrust::device_vector indices(P); 207 | thrust::sequence(indices.begin(), indices.end()); 208 | thrust::device_vector indices_sorted(P); 209 | 210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 211 | temp_storage.resize(temp_storage_bytes); 212 | 213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 214 | 215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 216 | thrust::device_vector boxes(num_boxes); 217 | boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); 218 | boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 219 | 220 | cudaFree(result); 221 | } -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: simple-knn 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | ext.cpp 2 | setup.py 3 | simple_knn.cu 4 | spatial.cu 5 | simple_knn.egg-info/PKG-INFO 6 | simple_knn.egg-info/SOURCES.txt 7 | simple_knn.egg-info/dependency_links.txt 8 | simple_knn.egg-info/top_level.txt -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | simple_knn 2 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/simple_knn/.gitkeep -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/_C.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/simple_knn/_C.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data(), means.contiguous().data()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
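// (Illustrative usage note, following common 3D Gaussian Splatting practice rather
// than anything shown in this file: on the Python side this binding is usually
// imported as `from simple_knn._C import distCUDA2` and applied to the initial
// point cloud, e.g. `dist2 = distCUDA2(points)`, whose square root is a typical
// choice for initialising per-Gaussian scales.)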
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /utils/TIMES.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMES.TTF -------------------------------------------------------------------------------- /utils/TIMESBD.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMESBD.TTF -------------------------------------------------------------------------------- /utils/TIMESBI.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMESBI.TTF -------------------------------------------------------------------------------- /utils/TIMESI.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMESI.TTF -------------------------------------------------------------------------------- /utils/__pycache__/camera_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/camera_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/general_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/general_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/graphics_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/graphics_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/image_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/loss_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/params_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/params_utils.cpython-37.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/scene_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/scene_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/sh_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/sh_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/system_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/system_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/timer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/timer.cpython-37.pyc -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | 21 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 22 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 23 | image=cam_info.image, mask=None, gt_alpha_mask=None, depth=None, 24 | image_name=cam_info.image_name, uid=id, data_device=args.data_device, 25 | time = cam_info.time) 26 | 27 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 28 | camera_list = [] 29 | 30 | for id, c in enumerate(cam_infos): 31 | camera_list.append(loadCam(args, id, c, resolution_scale)) 32 | 33 | return camera_list 34 | 35 | def camera_to_JSON(id, camera : Camera): 36 | Rt = np.zeros((4, 4)) 37 | Rt[:3, :3] = camera.R.transpose() 38 | Rt[:3, 3] = camera.T 39 | Rt[3, 3] = 1.0 40 | 41 | W2C = np.linalg.inv(Rt) 42 | pos = W2C[:3, 3] 43 | rot = W2C[:3, :3] 44 | serializable_array_2d = [x.tolist() for x in rot] 45 | camera_entry = { 46 | 'id' : id, 47 | 'img_name' : camera.image_name, 48 | 'width' : camera.width, 49 | 'height' : camera.height, 50 | 'position': pos.tolist(), 51 | 'rotation': serializable_array_2d, 52 | 'fy' : fov2focal(camera.FovY, camera.height), 53 | 'fx' : fov2focal(camera.FovX, camera.width) 54 | } 55 | return camera_entry 56 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
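# (Descriptive note on utils/camera_utils.camera_to_JSON above: Rt assembled from
# R^T and T is the world-to-view matrix, so the variable named W2C is actually its
# inverse, i.e. camera-to-world; pos and rot are therefore the camera centre and
# orientation in world coordinates, and fx/fy are recovered from the FoVs via
# graphics_utils.fov2focal.)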
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | import cv2 18 | 19 | def inpaint_rgb(rgb_image, mask): 20 | # Convert mask to uint8 21 | mask_uint8 = (mask * 255).astype(np.uint8) 22 | # Inpaint missing regions 23 | inpainted_image = cv2.inpaint(rgb_image, mask_uint8, inpaintRadius=5, flags=cv2.INPAINT_TELEA) 24 | 25 | return inpainted_image 26 | 27 | def inpaint_depth(depth_image, mask): 28 | # Convert mask to uint8 29 | mask_uint8 = (mask * 255).astype(np.uint8) 30 | 31 | # Inpaint missing regions 32 | inpainted_depth_image = cv2.inpaint((depth_image).astype(np.uint8), mask_uint8, inpaintRadius=5, flags=cv2.INPAINT_TELEA) 33 | 34 | return inpainted_depth_image 35 | 36 | def inverse_sigmoid(x): 37 | return torch.log(x/(1-x)) 38 | 39 | def PILtoTorch(pil_image, resolution): 40 | if resolution is not None: 41 | resized_image_PIL = pil_image.resize(resolution) 42 | else: 43 | resized_image_PIL = pil_image 44 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 45 | if len(resized_image.shape) == 3: 46 | return resized_image.permute(2, 0, 1) 47 | else: 48 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 49 | 50 | def get_expon_lr_func( 51 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 52 | ): 53 | """ 54 | Copied from Plenoxels 55 | 56 | Continuous learning rate decay function. Adapted from JaxNeRF 57 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 58 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 59 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 60 | function of lr_delay_mult, such that the initial learning rate is 61 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 62 | to the normal learning rate when steps>lr_delay_steps. 63 | :param conf: config subtree 'lr' or similar 64 | :param max_steps: int, the number of steps during optimization. 65 | :return HoF which takes step as input 66 | """ 67 | 68 | def helper(step): 69 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 70 | # Disable this parameter 71 | return 0.0 72 | if lr_delay_steps > 0: 73 | # A kind of reverse cosine decay. 
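            # The sine ramp below eases the rate from lr_init * lr_delay_mult up to
            # lr_init over the first lr_delay_steps steps.
            # Worked example of the main schedule (illustrative values only): with
            # lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30000 and no delay, the
            # log-linear interpolation gives lr(0)=1.6e-4,
            # lr(15000)=sqrt(1.6e-4 * 1.6e-6)≈1.6e-5 and lr(30000)=1.6e-6.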
74 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 75 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 76 | ) 77 | else: 78 | delay_rate = 1.0 79 | t = np.clip(step / max_steps, 0, 1) 80 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 81 | return delay_rate * log_lerp 82 | 83 | return helper 84 | 85 | def strip_lowerdiag(L): 86 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 87 | 88 | uncertainty[:, 0] = L[:, 0, 0] 89 | uncertainty[:, 1] = L[:, 0, 1] 90 | uncertainty[:, 2] = L[:, 0, 2] 91 | uncertainty[:, 3] = L[:, 1, 1] 92 | uncertainty[:, 4] = L[:, 1, 2] 93 | uncertainty[:, 5] = L[:, 2, 2] 94 | return uncertainty 95 | 96 | def strip_symmetric(sym): 97 | return strip_lowerdiag(sym) 98 | 99 | def build_rotation(r): 100 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 101 | q = r / norm[:, None] 102 | r = q[:, 0] 103 | x = q[:, 1] 104 | y = q[:, 2] 105 | z = q[:, 3] 106 | 107 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 108 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 109 | R[:, 0, 1] = 2 * (x*y - r*z) 110 | R[:, 0, 2] = 2 * (x*z + r*y) 111 | R[:, 1, 0] = 2 * (x*y + r*z) 112 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 113 | R[:, 1, 2] = 2 * (y*z - r*x) 114 | R[:, 2, 0] = 2 * (x*z - r*y) 115 | R[:, 2, 1] = 2 * (y*z + r*x) 116 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 117 | return R 118 | 119 | def build_scaling_rotation(s, r): 120 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 121 | R = build_rotation(r) 122 | 123 | L[:,0,0] = s[:,0] 124 | L[:,1,1] = s[:,1] 125 | L[:,2,2] = s[:,2] 126 | 127 | L = R @ L 128 | return L 129 | 130 | def safe_state(silent): 131 | old_f = sys.stdout 132 | class F: 133 | def __init__(self, silent): 134 | self.silent = silent 135 | 136 | def write(self, x): 137 | if not self.silent: 138 | if x.endswith("\n"): 139 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 140 | else: 141 | old_f.write(x) 142 | 143 | def flush(self): 144 | old_f.flush() 145 | 146 | sys.stdout = F(silent) 147 | 148 | random.seed(0) 149 | np.random.seed(0) 150 | torch.manual_seed(0) 151 | torch.cuda.set_device(torch.device("cuda:0")) 152 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
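# (Illustrative note on general_utils.build_scaling_rotation / strip_symmetric above,
# following standard 3D Gaussian Splatting usage: the two are typically combined as
#     L = build_scaling_rotation(scaling_modifier * scaling, rotation)
#     covariance = L @ L.transpose(1, 2)
#     symm = strip_symmetric(covariance)
# so each Gaussian's 3D covariance R S S^T R^T is stored as its 6 unique entries.)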
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() #w2c 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def getProjectionMatrix2(znear, zfar, K, h, w): 74 | near_fx = znear / K[0, 0] 75 | near_fy = znear / K[1, 1] 76 | left = - (w - K[0, 2]) * near_fx 77 | right = K[0, 2] * near_fx 78 | bottom = (K[1, 2] - h) * near_fy 79 | top = K[1, 2] * near_fy 80 | 81 | P = torch.zeros(4, 4) 82 | z_sign = 1.0 83 | P[0, 0] = 2.0 * znear / (right - left) 84 | P[1, 1] = 2.0 * znear / (top - bottom) 85 | P[0, 2] = (right + left) / (right - left) 86 | P[1, 2] = (top + bottom) / (top - bottom) 87 | P[3, 2] = z_sign 88 | P[2, 2] = z_sign * zfar / (zfar - znear) 89 | P[2, 3] = -(zfar * znear) / (zfar - znear) 90 | return P 91 | 92 | def fov2focal(fov, pixels): 93 | return pixels / (2 * math.tan(fov / 2)) 94 | 95 | def focal2fov(focal, pixels): 96 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
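# (Worked example for graphics_utils.fov2focal above, with illustrative numbers:
# a 60-degree horizontal FoV over a 512-pixel-wide image gives
# focal = 512 / (2 * tan(30°)) ≈ 443.4 pixels, and focal2fov inverts the mapping.)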
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import torch 14 | 15 | 16 | def tensor2array(tensor): 17 | if torch.is_tensor(tensor): 18 | return tensor.detach().cpu().numpy() 19 | else: 20 | return tensor 21 | 22 | def mse(img1, img2): 23 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 24 | 25 | @torch.no_grad() 26 | def psnr(img1, img2, mask=None): 27 | if mask is None: 28 | mse_mask = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 29 | else: 30 | if mask.shape[1] == 3: 31 | mse_mask = (((img1-img2)**2)*mask).sum() / ((mask.sum()+1e-10)) 32 | else: 33 | mse_mask = (((img1-img2)**2)*mask).sum() / ((mask.sum()+1e-10)*3.0) 34 | 35 | return 20 * torch.log10(1.0 / torch.sqrt(mse_mask)) 36 | 37 | def rmse(a, b, mask): 38 | """Compute rmse. 39 | """ 40 | if torch.is_tensor(a): 41 | a = tensor2array(a) 42 | if torch.is_tensor(b): 43 | b = tensor2array(b) 44 | if torch.is_tensor(mask): 45 | mask = tensor2array(mask) 46 | if len(mask.shape) == len(a.shape) - 1: 47 | mask = mask[..., None] 48 | mask_sum = np.sum(mask) + 1e-10 49 | rmse = (((a - b)**2 * mask).sum() / (mask_sum))**0.5 50 | return rmse 51 | 52 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def TV_loss(x, mask): 19 | B, C, H, W = x.shape 20 | tv_h = torch.abs(x[:,:,1:,:] - x[:,:,:-1,:]).sum() 21 | tv_w = torch.abs(x[:,:,:,1:] - x[:,:,:,:-1]).sum() 22 | return (tv_h + tv_w) / (B * C * H * W) 23 | 24 | 25 | def lpips_loss(img1, img2, lpips_model): 26 | loss = lpips_model(img1,img2) 27 | return loss.mean() 28 | 29 | def l1_loss(network_output, gt, mask=None): 30 | loss = torch.abs((network_output - gt)) 31 | if mask is not None: 32 | if mask.ndim == 4: 33 | mask = mask.repeat(1, network_output.shape[1], 1, 1) 34 | elif mask.ndim == 3: 35 | mask = mask.repeat(network_output.shape[1], 1, 1) 36 | else: 37 | raise ValueError('the dimension of mask should be either 3 or 4') 38 | 39 | try: 40 | loss = loss[mask!=0] 41 | except: 42 | print(loss.shape) 43 | print(mask.shape) 44 | print(loss.dtype) 45 | print(mask.dtype) 46 | return loss.mean() 47 | 48 | def l2_loss(network_output, gt): 49 | return ((network_output - gt) ** 2).mean() 50 | 51 | def gaussian(window_size, sigma): 52 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 53 | return gauss / gauss.sum() 54 | 55 | def create_window(window_size, channel): 56 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 57 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 58 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 59 | return window 60 | 61 | def ssim(img1, img2, window_size=11, size_average=True): 62 | channel = img1.size(-3) 63 | window = create_window(window_size, channel) 64 | 65 | if img1.is_cuda: 66 | window = window.cuda(img1.get_device()) 
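    # The window built above is an 11x11 (by default) Gaussian kernel with sigma 1.5,
    # normalised to sum to 1 and replicated once per channel, so the grouped conv2d
    # calls in _ssim compute per-channel local means and variances.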
67 | window = window.type_as(img1) 68 | 69 | return _ssim(img1, img2, window, window_size, channel, size_average) 70 | 71 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 72 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 73 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 74 | 75 | mu1_sq = mu1.pow(2) 76 | mu2_sq = mu2.pow(2) 77 | mu1_mu2 = mu1 * mu2 78 | 79 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 80 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 81 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 82 | 83 | C1 = 0.01 ** 2 84 | C2 = 0.03 ** 2 85 | 86 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 87 | 88 | if size_average: 89 | return ssim_map.mean() 90 | else: 91 | return ssim_map.mean(1).mean(1).mean(1) 92 | 93 | -------------------------------------------------------------------------------- /utils/params_utils.py: -------------------------------------------------------------------------------- 1 | def merge_hparams(args, config): 2 | params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"] 3 | for param in params: 4 | if param in config.keys(): 5 | for key, value in config[param].items(): 6 | if hasattr(args, key): 7 | setattr(args, key, value) 8 | return args -------------------------------------------------------------------------------- /utils/scene_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image, ImageDraw, ImageFont 4 | from matplotlib import pyplot as plt 5 | plt.rcParams['font.sans-serif'] = ['Times New Roman'] 6 | 7 | import numpy as np 8 | 9 | import copy 10 | @torch.no_grad() 11 | def render_training_image(scene, gaussians, viewpoints, render_func, pipe, background, stage, iteration, time_now): 12 | def render(gaussians, viewpoint, path, scaling): 13 | # scaling_copy = gaussians._scaling 14 | render_pkg = render_func(viewpoint, gaussians, pipe, background, stage=stage) 15 | label1 = f"stage:{stage},iter:{iteration}" 16 | times = time_now/60 17 | if times < 1: 18 | end = "min" 19 | else: 20 | end = "mins" 21 | label2 = "time:%.2f" % times + end 22 | image = render_pkg["render"] 23 | depth = render_pkg["depth"] 24 | image_np = image.permute(1, 2, 0).cpu().numpy() # reorder channels to (H, W, 3) 25 | depth_np = depth.permute(1, 2, 0).cpu().numpy() 26 | depth_np /= depth_np.max() 27 | depth_np = np.repeat(depth_np, 3, axis=2) 28 | image_np = np.concatenate((image_np, depth_np), axis=1) 29 | image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8')) # convert to an 8-bit image 30 | # create a drawing handle on the PIL image so labels can be added 31 | draw1 = ImageDraw.Draw(image_with_labels) 32 | 33 | # choose the font and font size 34 | font = ImageFont.truetype('./utils/TIMES.TTF', size=40) # replace the path with the font file of your choice 35 | 36 | # choose the text color 37 | text_color = (255, 0, 0) # red 38 | 39 | # label positions (top-left coordinates) 40 | label1_position = (10, 10) 41 | label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) # top-right corner 42 | 43 | # draw the labels on the image 44 | draw1.text(label1_position, label1, fill=text_color, font=font) 45 | draw1.text(label2_position, label2, fill=text_color, font=font) 46 | 47 | image_with_labels.save(path) 48 | render_base_path = os.path.join(scene.model_path, f"{stage}_render") 49 | point_cloud_path = os.path.join(render_base_path,"pointclouds")
50 | image_path = os.path.join(render_base_path,"images") 51 | if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")): 52 | os.makedirs(render_base_path) 53 | if not os.path.exists(point_cloud_path): 54 | os.makedirs(point_cloud_path) 55 | if not os.path.exists(image_path): 56 | os.makedirs(image_path) 57 | # image:3,800,800 58 | 59 | # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg") 60 | for idx in range(len(viewpoints)): 61 | image_save_path = os.path.join(image_path,f"{iteration}_{idx}.jpg") 62 | render(gaussians,viewpoints[idx],image_save_path,scaling = 1) 63 | # render(gaussians,point_save_path,scaling = 0.1) 64 | # save the image with labels 65 | 66 | pc_mask = gaussians.get_opacity 67 | pc_mask = pc_mask > 0.1 68 | xyz = gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1,0).numpy() 69 | # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path) 70 | # if needed, you can convert the PIL image back into a PyTorch tensor 71 | # return image 72 | # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0 73 | 74 | def visualize_and_save_point_cloud(point_cloud, R, T, filename): 75 | # create a 3D scatter plot 76 | fig = plt.figure() 77 | ax = fig.add_subplot(111, projection='3d') 78 | R = R.T 79 | # apply the rotation and translation 80 | T = -R.dot(T) 81 | transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1) 82 | # pcd = o3d.geometry.PointCloud() 83 | # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T) # transpose the point cloud to match Open3D's format 84 | # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:] 85 | # visualize the point cloud 86 | ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c='g', marker='o') 87 | ax.axis("off") 88 | # ax.set_xlabel('X Label') 89 | # ax.set_ylabel('Y Label') 90 | # ax.set_zlabel('Z Label') 91 | 92 | # save the rendering as an image 93 | plt.savefig(filename) 94 | 95 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE.
23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. 
equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | class Timer: 3 | def __init__(self): 4 | self.start_time = None 5 | self.elapsed = 0 6 | self.paused = False 7 | 8 | def start(self): 9 | if self.start_time is None: 10 | self.start_time = time.time() 11 | elif self.paused: 12 | self.start_time = time.time() - self.elapsed 13 | self.paused = False 14 | 15 | def pause(self): 16 | if not self.paused: 17 | self.elapsed = time.time() - self.start_time 18 | self.paused = True 19 | 20 | def get_elapsed_time(self): 21 | if self.paused: 22 | return self.elapsed 23 | else: 24 | return time.time() - self.start_time --------------------------------------------------------------------------------
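Usage sketch (illustrative only, assuming the module layout above): the Timer class in utils/timer.py is typically driven from a training loop, pausing around logging so that logging time is excluded from the reported wall clock.

import time
from utils.timer import Timer

timer = Timer()
timer.start()
for iteration in range(1, 4):
    time.sleep(0.1)                # stand-in for one optimization step
    timer.pause()                  # stop the clock while logging
    print(f"iter {iteration}: {timer.get_elapsed_time():.2f}s of training time")
    timer.start()                  # resume timing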