├── LICENSE.md
├── README.md
├── arguments
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-37.pyc
│   └── endonerf
│       └── default.py
├── assets
│   ├── Tab1.png
│   ├── Tab2-padding.png
│   ├── Tab2.png
│   ├── demo1.mp4
│   ├── demo_scene.mp4
│   ├── overview.png
│   └── visual_results.png
├── gaussian_renderer
│   ├── __init__.py
│   └── __pycache__
│       ├── __init__.cpython-37.pyc
│       └── network_gui.cpython-37.pyc
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── requirements.txt
├── scene
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── cameras.cpython-37.pyc
│   │   ├── colmap_loader.cpython-37.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   ├── dataset_readers.cpython-37.pyc
│   │   ├── deformation.cpython-37.pyc
│   │   ├── endo_loader.cpython-37.pyc
│   │   ├── flexible_deform_model.cpython-37.pyc
│   │   ├── gaussian_flow_model.cpython-37.pyc
│   │   ├── gaussian_gaussian_model.cpython-37.pyc
│   │   ├── gaussian_model.cpython-37.pyc
│   │   ├── hexplane.cpython-37.pyc
│   │   ├── hyper_loader.cpython-37.pyc
│   │   ├── regulation.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── cameras.py
│   ├── dataset_readers.py
│   ├── endo_loader.py
│   ├── flexible_deform_model.py
│   ├── regulation.py
│   └── utils.py
├── stereomis2endonerf.py
├── submodules
│   ├── RAFT
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── RAFT.png
│   │   ├── README.md
│   │   ├── alt_cuda_corr
│   │   │   ├── correlation.cpp
│   │   │   ├── correlation_kernel.cu
│   │   │   └── setup.py
│   │   ├── chairs_split.txt
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   ├── corr.py
│   │   │   ├── datasets.py
│   │   │   ├── extractor.py
│   │   │   ├── raft.py
│   │   │   ├── update.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── augmentor.py
│   │   │       ├── flow_viz.py
│   │   │       ├── frame_utils.py
│   │   │       └── utils.py
│   │   ├── demo-frames
│   │   │   ├── frame_0016.png
│   │   │   ├── frame_0017.png
│   │   │   ├── frame_0018.png
│   │   │   ├── frame_0019.png
│   │   │   ├── frame_0020.png
│   │   │   ├── frame_0021.png
│   │   │   ├── frame_0022.png
│   │   │   ├── frame_0023.png
│   │   │   ├── frame_0024.png
│   │   │   └── frame_0025.png
│   │   ├── demo.py
│   │   ├── download_models.sh
│   │   ├── evaluate.py
│   │   ├── pretrained
│   │   │   └── raft-things.pth
│   │   ├── train.py
│   │   ├── train_mixed.sh
│   │   └── train_standard.sh
│   ├── depth-diff-gaussian-rasterization
│   │   ├── .gitignore
│   │   ├── .gitmodules
│   │   ├── CMakeLists.txt
│   │   ├── LICENSE.md
│   │   ├── README.md
│   │   ├── cuda_rasterizer
│   │   │   ├── auxiliary.h
│   │   │   ├── backward.cu
│   │   │   ├── backward.h
│   │   │   ├── config.h
│   │   │   ├── forward.cu
│   │   │   ├── forward.h
│   │   │   ├── rasterizer.h
│   │   │   ├── rasterizer_impl.cu
│   │   │   └── rasterizer_impl.h
│   │   ├── diff_gaussian_rasterization
│   │   │   ├── _C.cpython-37m-x86_64-linux-gnu.so
│   │   │   ├── _C.cpython-38-x86_64-linux-gnu.so
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   │       └── __init__.cpython-37.pyc
│   │   ├── ext.cpp
│   │   ├── rasterize_points.cu
│   │   ├── rasterize_points.h
│   │   ├── setup.py
│   │   └── third_party
│   │       └── stbi_image_write.h
│   └── simple-knn
│       ├── build
│       │   ├── lib.linux-x86_64-cpython-37
│       │   │   └── simple_knn
│       │   │       └── _C.cpython-37m-x86_64-linux-gnu.so
│       │   └── temp.linux-x86_64-cpython-37
│       │       ├── ext.o
│       │       ├── simple_knn.o
│       │       └── spatial.o
│       ├── ext.cpp
│       ├── setup.py
│       ├── simple_knn.cu
│       ├── simple_knn.egg-info
│       │   ├── PKG-INFO
│       │   ├── SOURCES.txt
│       │   ├── dependency_links.txt
│       │   └── top_level.txt
│       ├── simple_knn.h
│       ├── simple_knn
│       │   ├── .gitkeep
│       │   ├── _C.cpython-37m-x86_64-linux-gnu.so
│       │   └── _C.cpython-38-x86_64-linux-gnu.so
│       ├── spatial.cu
│       └── spatial.h
├── train.py
└── utils
    ├── TIMES.TTF
    ├── TIMESBD.TTF
    ├── TIMESBI.TTF
    ├── TIMESI.TTF
    ├── __pycache__
    │   ├── camera_utils.cpython-37.pyc
    │   ├── general_utils.cpython-37.pyc
    │   ├── graphics_utils.cpython-37.pyc
    │   ├── image_utils.cpython-37.pyc
    │   ├── loss_utils.cpython-37.pyc
    │   ├── params_utils.cpython-37.pyc
    │   ├── scene_utils.cpython-37.pyc
    │   ├── sh_utils.cpython-37.pyc
    │   ├── system_utils.cpython-37.pyc
    │   └── timer.cpython-37.pyc
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── params_utils.py
    ├── scene_utils.py
    ├── sh_utils.py
    ├── stereo_rectify.py
    ├── system_utils.py
    └── timer.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deform3DGS: Flexible Deformation for Fast Surgical Scene Reconstruction with Gaussian Splatting
2 |
3 | Official code implementation for [Deform3DGS](https://arxiv.org/abs/2405.17835), a Gaussian Splatting based framework for surgical scene reconstruction.
4 |
5 |
6 |
7 | > [Deform3DGS: Flexible Deformation for Fast Surgical Scene Reconstruction with Gaussian Splatting](https://arxiv.org/abs/2405.17835)\
8 | > Shuojue Yang, Qian Li, Daiyun Shen, Bingchen Gong, Qi Dou, Yueming Jin\
9 | > MICCAI2024, **Early Accept**
10 |
11 | ## Demo
12 |
13 | ### Reconstruction within 1 minute
14 |
15 |
16 |
17 | https://github.com/jinlab-imvr/Deform3DGS/assets/157268160/7609bfb6-9130-488f-b893-85cc82d60d63
18 |
19 | Compared to the previous SOTA method for fast reconstruction, our method reduces the training time to **1 minute** for each clip in the EndoNeRF dataset, demonstrating a remarkable advantage in efficiency.
20 |
21 | ### Reconstruction of various scenes
22 |
23 |
27 |
28 | https://github.com/jinlab-imvr/Deform3DGS/assets/157268160/633777fa-9110-4823-b6e5-f5d338e72551
29 |
30 | ## Pipeline
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | **Deform3DGS** is composed of (a) Point cloud initialization, (b) Flexible Deformation Modeling, and (c) 3D Gaussian Splatting. Experiments on DaVinci robotic surgery videos indicate the efficacy of our approach, showcasing superior reconstruction fidelity (PSNR: 37.90) and rendering speed (338.8 FPS) while substantially reducing the training time to only 1 minute/scene.
39 |
40 |
44 |
45 | ## Environment setup
46 |
47 | Tested with NVIDIA RTX A5000 GPU.
48 |
49 | ```bash
50 | git clone https://github.com/jinlab-imvr/Deform3DGS.git
51 | cd Deform3DGS
52 | conda create -n Deform3DGS python=3.7
53 | conda activate Deform3DGS
54 |
55 | pip install -r requirements.txt
56 | pip install -e submodules/depth-diff-gaussian-rasterization
57 | pip install -e submodules/simple-knn
58 | ```
59 |
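Optionally, here is a minimal sanity check that both CUDA extensions import after installation (a sketch; `distCUDA2` is the symbol the original 3DGS `simple-knn` extension exposes, assumed unchanged here):

```python
# Verify the compiled CUDA extensions are importable in the Deform3DGS environment
import torch
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from simple_knn._C import distCUDA2

print("CUDA available:", torch.cuda.is_available())
```
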
60 | ## Datasets
61 |
62 | We use 6 clips from [EndoNeRF](https://github.com/med-air/EndoNeRF) and 3 clips manually extracted from [StereoMIS](https://zenodo.org/records/7727692) to verify our method.
63 |
64 | To use the two available examples in the [EndoNeRF](https://github.com/med-air/EndoNeRF) dataset, please download the data via [this link](https://forms.gle/1VAqDJTEgZduD6157) and organize it according to the [guideline](https://github.com/med-air/EndoNeRF.git).
65 |
66 | To use the [StereoMIS](https://zenodo.org/records/7727692) dataset, please follow this [github repo](https://github.com/aimi-lab/robust-pose-estimator) to preprocess the dataset. After that, run the provided script `stereomis2endonerf.py` to extract clips from the StereoMIS dataset and organize the depth maps, masks, images, and intrinsic and extrinsic parameters in the same format as [EndoNeRF](https://github.com/med-air/EndoNeRF). In our implementation, we used [RAFT](https://github.com/princeton-vl/RAFT) to estimate stereo depth for the [StereoMIS](https://zenodo.org/records/7727692) clips. Following the EndoNeRF dataset, this script only supports fixed-view settings.
67 |
68 |
69 |
70 | The data structure is as follows:
71 |
72 | ```
73 | data
74 | | - endonerf_full_datasets
75 | | | - cutting_tissues_twice
76 | | | | - depth/
77 | | | | - images/
78 | | | | - masks/
79 | | | | - poses_bounds.npy
80 | | | - pushing_soft_tissues
81 | | - StereoMIS
82 | | | - stereo_seq_1
83 | | | - stereo_seq_2
84 | ```
85 |
86 | ## Training
87 |
88 | To train Deform3DGS with customized hyper-parameters, please make changes in `arguments/endonerf/default.py`.
89 |
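For example, here is a minimal sketch of a customized config, keeping the same keys as the shipped `arguments/endonerf/default.py` (the file name and the changed values below are illustrative only, not recommended settings):

```python
# arguments/endonerf/my_config.py -- hypothetical customized copy of default.py
ModelParams = dict(
    extra_mark = 'endonerf',
    camera_extent = 10
)

OptimizationParams = dict(
    coarse_iterations = 0,
    iterations = 6000,             # default config uses 3000
    position_lr_max_steps = 8000,  # default config uses 4000
    opacity_reset_interval = 3000,
    prune_interval = 3000
)

ModelHiddenParams = dict(
    curve_num = 17,  # number of learnable basis functions (17 in all paper experiments)
    ch_num = 10,     # deformable channels: 3 (mean) + 3 (scale) + 4 (rotation)
    init_param = 0.01,
)
```

Such a file can then be passed to training via `--configs arguments/endonerf/my_config.py`.
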
90 | To train Deform3DGS, run the following example command:
91 |
92 | ```
93 | python train.py -s data/endonerf_full_datasets/pulling_soft_tissues --expname endonerf/pulling_fdm --configs arguments/endonerf/default.py
94 | ```
95 |
96 | ## Testing
97 |
98 | For testing, we perform rendering and evaluation separately.
99 |
100 | ### Rendering
101 |
102 | Run the following example command to render the images:
103 |
104 | ```
105 | python render.py --model_path output/endonerf/pulling_fdm --skip_train --reconstruct_test --configs arguments/endonerf/default.py
106 | ```
107 |
108 | Please follow [EndoGaussian](https://github.com/yifliu3/EndoGaussian/tree/master) to skip rendering. Of note, you can also set `--reconstruct_train`, `--reconstruct_test`, and `--reconstruct_video` to reconstruct and save the `.ply` 3D point clouds of the rendered outputs for the `train`, `test`, and `video` sets, respectively.
109 |
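For instance, assuming the flags above can be combined freely, a sketch of a render command that also saves the `.ply` reconstruction of the `video` set looks like:

```bash
# Render and additionally dump the .ply point clouds of the video set
python render.py --model_path output/endonerf/pulling_fdm \
    --skip_train --reconstruct_video \
    --configs arguments/endonerf/default.py
```
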
110 | ### Evaluation
111 |
112 | To evaluate the reconstruction quality, run the following command:
113 |
114 | ```
115 | python metrics.py --model_path output/endonerf/pulling_fdm -p test
116 | ```
117 |
118 | Note that you can set `-p video`, `-p test`, `-p train` to select the set for evaluation.
119 |
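Since `metrics.py` declares `--model_paths` (`-m`) with `nargs="+"`, several scenes can be scored in one call; a sketch (the second output directory is hypothetical):

```bash
# Evaluate two trained scenes on their test renderings in a single run
python metrics.py -m output/endonerf/pulling_fdm output/endonerf/cutting_fdm -p test
```
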
120 | ## Acknowledgements
121 |
122 | This repo borrows some source code from [EndoGaussian](https://github.com/yifliu3/EndoGaussian/tree/master), [4DGS](https://github.com/hustvl/4DGaussians), [depth-diff-gaussian-rasterizer](https://github.com/ingra14m/depth-diff-gaussian-rasterization), [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), and [EndoNeRF](https://github.com/med-air/EndoNeRF). We would like to acknowledge these great prior literatures for inspiring our work.
123 |
124 | Thanks to [EndoGaussian](https://github.com/yifliu3/EndoGaussian/tree/master) for their great and timely effort in releasing a framework that adapts Gaussian Splatting to surgical scenes.
125 |
126 | ## Citation
127 |
128 | If you find this code useful for your research, please use the following BibTeX entry:
129 |
130 | ```
131 | @misc{yang2024deform3dgs,
132 | title={Deform3DGS: Flexible Deformation for Fast Surgical Scene Reconstruction with Gaussian Splatting},
133 | author={Shuojue Yang and Qian Li and Daiyun Shen and Bingchen Gong and Qi Dou and Yueming Jin},
134 | year={2024},
135 | eprint={2405.17835},
136 | archivePrefix={arXiv},
137 | primaryClass={cs.CV}
138 | }
139 |
140 | ```
141 |
142 | ### Questions
143 |
144 | For further questions about the code or paper, feel free to create an issue, or contact 's.yang@u.nus.edu' or 'ymjin@nus.edu.sg'.
145 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self._source_path = ""
51 | self._model_path = ""
52 | self._images = "images"
53 | self._resolution = -1
54 | self._white_background = False
55 | self.data_device = "cuda"
56 | self.eval = True
57 | self.render_process=False
58 | self.extra_mark = None
59 | self.camera_extent = None
60 | super().__init__(parser, "Loading Parameters", sentinel)
61 |
62 | def extract(self, args):
63 | g = super().extract(args)
64 | g.source_path = os.path.abspath(g.source_path)
65 | return g
66 |
67 | class PipelineParams(ParamGroup):
68 | def __init__(self, parser):
69 | self.convert_SHs_python = False
70 | self.compute_cov3D_python = False
71 | self.debug = False
72 | super().__init__(parser, "Pipeline Parameters")
73 |
74 |
75 | class FDMHiddenParams(ParamGroup):
76 | def __init__(self, parser):
77 | self.net_width = 64
78 | self.timebase_pe = 4
79 | self.defor_depth = 1
80 | self.posebase_pe = 10
81 | self.scale_rotation_pe = 2
82 | self.opacity_pe = 2
83 | self.timenet_width = 64
84 | self.timenet_output = 32
85 | self.bounds = 1.6
86 |
87 | self.ch_num = 10
88 | self.curve_num = 17
89 | self.init_param = 0.01
90 |
91 | super().__init__(parser, "FDMHiddenParams")
92 |
93 |
94 | class OptimizationParams(ParamGroup):
95 | def __init__(self, parser):
96 | self.dataloader=False
97 | self.iterations = 30_000
98 | self.position_lr_init = 0.00016
99 | self.position_lr_final = 0.0000016
100 | self.position_lr_delay_mult = 0.01
101 | self.position_lr_max_steps = 20_000
102 | self.deformation_lr_init = 0.00016
103 | self.deformation_lr_final = 0.000016
104 | self.deformation_lr_delay_mult = 0.01
105 | # self.grid_lr_init = 0.0016
106 | # self.grid_lr_final = 0.00016
107 |
108 | self.feature_lr = 0.0025
109 | self.opacity_lr = 0.05
110 | self.scaling_lr = 0.005
111 | self.rotation_lr = 0.001
112 | self.percent_dense = 0.01
113 | self.weight_constraint_init= 1
114 | self.weight_constraint_after = 0.2
115 | self.weight_decay_iteration = 5000
116 | self.opacity_reset_interval = 3000
117 | self.densification_interval = 100
118 | self.densify_from_iter = 500
119 | self.densify_until_iter = 15_000
120 | self.densify_grad_threshold_coarse = 0.0002
121 | self.densify_grad_threshold_fine_init = 0.0002
122 | self.densify_grad_threshold_after = 0.0002
123 | self.pruning_from_iter = 500
124 | self.pruning_interval = 100
125 | self.opacity_threshold_coarse = 0.005
126 | self.opacity_threshold_fine_init = 0.005
127 | self.opacity_threshold_fine_after = 0.005
128 |
129 | super().__init__(parser, "Optimization Parameters")
130 |
131 | def get_combined_args(parser : ArgumentParser):
132 | cmdlne_string = sys.argv[1:]
133 | cfgfile_string = "Namespace()"
134 | args_cmdline = parser.parse_args(cmdlne_string)
135 |
136 | try:
137 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
138 | print("Looking for config file in", cfgfilepath)
139 | with open(cfgfilepath) as cfg_file:
140 | print("Config file found: {}".format(cfgfilepath))
141 | cfgfile_string = cfg_file.read()
142 | except TypeError:
143 | print("Config file not found at")
144 | pass
145 | args_cfgfile = eval(cfgfile_string)
146 |
147 | merged_dict = vars(args_cfgfile).copy()
148 | for k,v in vars(args_cmdline).items():
149 | if v != None:
150 | merged_dict[k] = v
151 | return Namespace(**merged_dict)
152 |
--------------------------------------------------------------------------------
/arguments/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/arguments/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/arguments/endonerf/default.py:
--------------------------------------------------------------------------------
1 | ModelParams = dict(
2 | extra_mark = 'endonerf',
3 | camera_extent = 10
4 | )
5 |
6 | OptimizationParams = dict(
7 | coarse_iterations = 0,
8 | deformation_lr_init = 0.00016,
9 | deformation_lr_final = 0.0000016,
10 | deformation_lr_delay_mult = 0.01,
11 | iterations = 3000,
12 | percent_dense = 0.01,
13 | opacity_reset_interval = 3000,
14 | position_lr_max_steps = 4000,
15 | prune_interval = 3000
16 | )
17 |
18 | ModelHiddenParams = dict(
19 | curve_num = 17, # number of learnable basis functions. This number was set to 17 for all the experiments in paper (https://arxiv.org/abs/2405.17835)
20 |
21 | ch_num = 10, # channel number of deformable attributes: 10 = 3 (scale) + 3 (mean) + 4 (rotation)
22 | init_param = 0.01, )
23 |
--------------------------------------------------------------------------------
/assets/Tab1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/Tab1.png
--------------------------------------------------------------------------------
/assets/Tab2-padding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/Tab2-padding.png
--------------------------------------------------------------------------------
/assets/Tab2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/Tab2.png
--------------------------------------------------------------------------------
/assets/demo1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/demo1.mp4
--------------------------------------------------------------------------------
/assets/demo_scene.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/demo_scene.mp4
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/overview.png
--------------------------------------------------------------------------------
/assets/visual_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/assets/visual_results.png
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | from scene.flexible_deform_model import GaussianModel
16 | from utils.sh_utils import eval_sh
17 |
18 | def render_flow(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
19 | """
20 | Render the scene.
21 |
22 | Background tensor (bg_color) must be on GPU!
23 | """
24 |
25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
27 | try:
28 | screenspace_points.retain_grad()
29 | except:
30 | pass
31 |
32 | # Set up rasterization configuration
33 |
34 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
35 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
36 |
37 | raster_settings = GaussianRasterizationSettings(
38 | image_height=int(viewpoint_camera.image_height),
39 | image_width=int(viewpoint_camera.image_width),
40 | tanfovx=tanfovx,
41 | tanfovy=tanfovy,
42 | bg=bg_color,
43 | scale_modifier=scaling_modifier,
44 | viewmatrix=viewpoint_camera.world_view_transform.cuda(),
45 | projmatrix=viewpoint_camera.full_proj_transform.cuda(),
46 | sh_degree=pc.active_sh_degree,
47 | campos=viewpoint_camera.camera_center.cuda(),
48 | prefiltered=False,
49 | debug=pipe.debug
50 | )
51 |
52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
53 |
54 | # means3D = pc.get_xyz
55 | # add deformation to each points
56 | # deformation = pc.get_deformation
57 | means3D = pc.get_xyz
58 | ori_time = torch.tensor(viewpoint_camera.time).to(means3D.device)
59 | means2D = screenspace_points
60 | opacity = pc._opacity
61 |
62 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
63 | # scaling / rotation by the rasterizer.
64 | scales = None
65 | rotations = None
66 | cov3D_precomp = None
67 |
68 | if pipe.compute_cov3D_python:
69 | cov3D_precomp = pc.get_covariance(scaling_modifier)
70 | else:
71 | scales = pc._scaling
72 | rotations = pc._rotation
73 | deformation_point = pc._deformation_table
74 |
75 |
76 | means3D_deform, scales_deform, rotations_deform = pc.deformation(means3D[deformation_point], scales[deformation_point],
77 | rotations[deformation_point],
78 | ori_time)
79 | opacity_deform = opacity[deformation_point]
80 |
81 | # print(time.max())
82 | with torch.no_grad():
83 | pc._deformation_accum[deformation_point] += torch.abs(means3D_deform - means3D[deformation_point])
84 |
85 | means3D_final = torch.zeros_like(means3D)
86 | rotations_final = torch.zeros_like(rotations)
87 | scales_final = torch.zeros_like(scales)
88 | opacity_final = torch.zeros_like(opacity)
89 | means3D_final[deformation_point] = means3D_deform
90 | rotations_final[deformation_point] = rotations_deform
91 | scales_final[deformation_point] = scales_deform
92 | opacity_final[deformation_point] = opacity_deform
93 | means3D_final[~deformation_point] = means3D[~deformation_point]
94 | rotations_final[~deformation_point] = rotations[~deformation_point]
95 | scales_final[~deformation_point] = scales[~deformation_point]
96 | opacity_final[~deformation_point] = opacity[~deformation_point]
97 |
98 | scales_final = pc.scaling_activation(scales_final)
99 | rotations_final = pc.rotation_activation(rotations_final)
100 | opacity = pc.opacity_activation(opacity)
101 |
102 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
103 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
104 | shs = None
105 | colors_precomp = None
106 | if override_color is None:
107 | if pipe.convert_SHs_python:
108 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
109 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.cuda().repeat(pc.get_features.shape[0], 1))
110 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
111 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
112 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
113 | else:
114 | shs = pc.get_features
115 | else:
116 | colors_precomp = override_color
117 |
118 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
119 | rendered_image, radii, depth = rasterizer(
120 | means3D = means3D_final,
121 | means2D = means2D,
122 | shs = shs,
123 | colors_precomp = colors_precomp,
124 | opacities = opacity,
125 | scales = scales_final,
126 | rotations = rotations_final,
127 | cov3D_precomp = cov3D_precomp)
128 |
129 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
130 | # They will be excluded from value updates used in the splitting criteria.
131 | return {"render": rendered_image,
132 | "depth": depth,
133 | "viewspace_points": screenspace_points,
134 | "visibility_filter" : radii > 0,
135 | "radii": radii,}
136 |
--------------------------------------------------------------------------------
/gaussian_renderer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/gaussian_renderer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/gaussian_renderer/__pycache__/network_gui.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/gaussian_renderer/__pycache__/network_gui.cpython-37.pyc
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16().features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | # from lpipsPyTorch import lpips
19 | import lpips
20 | import json
21 | from tqdm import tqdm
22 | from utils.image_utils import psnr
23 | from utils.image_utils import rmse
24 | from argparse import ArgumentParser
25 | import numpy as np
26 |
27 |
28 | def array2tensor(array, device="cuda", dtype=torch.float32):
29 | return torch.tensor(array, dtype=dtype, device=device)
30 |
31 | # Learned Perceptual Image Patch Similarity
32 | class LPIPS(object):
33 | """
34 | borrowed from https://github.com/huster-wgm/Pytorch-metrics/blob/master/metrics.py
35 | """
36 | def __init__(self, device="cuda"):
37 | self.model = lpips.LPIPS(net='alex').to(device)
38 |
39 | def __call__(self, y_pred, y_true, normalized=True):
40 | """
41 | args:
42 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
43 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
44 | normalized : change [0,1] => [-1,1] (default by LPIPS)
45 | return LPIPS, smaller the better
46 | """
47 | if normalized:
48 | y_pred = y_pred * 2.0 - 1.0
49 | y_true = y_true * 2.0 - 1.0
50 | error = self.model.forward(y_pred, y_true)
51 | return torch.mean(error)
52 |
53 | lpips = LPIPS()
54 | def cal_lpips(a, b, device="cuda", batch=2):
55 | """Compute lpips.
56 | a, b: [batch, H, W, 3]"""
57 | if not torch.is_tensor(a):
58 | a = array2tensor(a, device)
59 | if not torch.is_tensor(b):
60 | b = array2tensor(b, device)
61 |
62 | lpips_all = []
63 | for a_split, b_split in zip(a.split(split_size=batch, dim=0), b.split(split_size=batch, dim=0)):
64 | out = lpips(a_split, b_split)
65 | lpips_all.append(out)
66 | lpips_all = torch.stack(lpips_all)
67 | lpips_mean = lpips_all.mean()
68 | return lpips_mean
69 |
70 | def readImages(renders_dir, gt_dir, depth_dir, gtdepth_dir, masks_dir):
71 | renders = []
72 | gts = []
73 | image_names = []
74 | depths = []
75 | gt_depths = []
76 | masks = []
77 |
78 | for fname in os.listdir(renders_dir):
79 | render = np.array(Image.open(renders_dir / fname))
80 | gt = np.array(Image.open(gt_dir / fname))
81 | depth = np.array(Image.open(depth_dir / fname))
82 | gt_depth = np.array(Image.open(gtdepth_dir / fname))
83 | mask = np.array(Image.open(masks_dir / fname))
84 |
85 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
86 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
87 | depths.append(torch.from_numpy(depth).unsqueeze(0).unsqueeze(1)[:, :, :, :].cuda())
88 | gt_depths.append(torch.from_numpy(gt_depth).unsqueeze(0).unsqueeze(1)[:, :3, :, :].cuda())
89 | masks.append(tf.to_tensor(mask).unsqueeze(0).cuda())
90 |
91 | image_names.append(fname)
92 | return renders, gts, depths, gt_depths, masks, image_names
93 |
94 | def evaluate(model_paths):
95 |
96 | full_dict = {}
97 | per_view_dict = {}
98 | full_dict_polytopeonly = {}
99 | per_view_dict_polytopeonly = {}
100 | print("")
101 |
102 | with torch.no_grad():
103 | for scene_dir in model_paths:
104 | print("Scene:", scene_dir)
105 | full_dict[scene_dir] = {}
106 | per_view_dict[scene_dir] = {}
107 | full_dict_polytopeonly[scene_dir] = {}
108 | per_view_dict_polytopeonly[scene_dir] = {}
109 |
110 | test_dir = Path(scene_dir) / args.phase
111 |
112 | for method in os.listdir(test_dir):
113 | print("Method:", method)
114 |
115 | full_dict[scene_dir][method] = {}
116 | per_view_dict[scene_dir][method] = {}
117 | full_dict_polytopeonly[scene_dir][method] = {}
118 | per_view_dict_polytopeonly[scene_dir][method] = {}
119 |
120 | method_dir = test_dir / method
121 | gt_dir = method_dir/ "gt"
122 | renders_dir = method_dir / "renders"
123 | depth_dir = method_dir / "depth"
124 | gt_depth_dir = method_dir / "gt_depth"
125 | masks_dir = method_dir / "masks"
126 |
127 | renders, gts, depths, gt_depths, masks, image_names = readImages(renders_dir, gt_dir, depth_dir, gt_depth_dir, masks_dir)
128 |
129 | ssims = []
130 | psnrs = []
131 | lpipss = []
132 | rmses = []
133 |
134 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
135 | render, gt, depth, gt_depth, mask = renders[idx], gts[idx], depths[idx], gt_depths[idx], masks[idx]
136 | render = render * mask
137 | gt = gt * mask
138 | psnrs.append(psnr(render, gt))
139 | ssims.append(ssim(render, gt))
140 | lpipss.append(cal_lpips(render, gt))
141 | if (gt_depth!=0).sum() < 10:
142 | continue
143 | rmses.append(rmse(depth, gt_depth, mask))
144 |
145 | print("Scene: ", scene_dir, "SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
146 | print("Scene: ", scene_dir, "PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
147 | print("Scene: ", scene_dir, "LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
148 | print("Scene: ", scene_dir, "RMSE: {:>12.7f}".format(torch.tensor(rmses).mean(), ".5"))
149 | print("")
150 |
151 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
152 | "PSNR": torch.tensor(psnrs).mean().item(),
153 | "LPIPS": torch.tensor(lpipss).mean().item(),
154 | "RMSE": torch.tensor(rmses).mean().item()})
155 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
156 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
157 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)},
158 | "RMSES": {name: lp for lp, name in zip(torch.tensor(rmses).tolist(), image_names)}})
159 |
160 | with open(scene_dir + "/results.json", 'w') as fp:
161 | json.dump(full_dict[scene_dir], fp, indent=True)
162 | with open(scene_dir + "/per_view.json", 'w') as fp:
163 | json.dump(per_view_dict[scene_dir], fp, indent=True)
164 |
165 |
166 | if __name__ == "__main__":
167 | device = torch.device("cuda:0")
168 | torch.cuda.set_device(device)
169 |
170 | # Set up command line argument parser
171 | parser = ArgumentParser(description="Training script parameters")
172 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
173 | parser.add_argument('--phase', '-p', type=str, default='test')
174 | args = parser.parse_args()
175 | evaluate(args.model_paths)
176 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | torchvision==0.14.1
3 | torchaudio==0.13.1
4 | mmcv==1.6.0
5 | matplotlib
6 | argparse
7 | lpips
8 | plyfile
9 | imageio-ffmpeg
10 | open3d
11 | imageio
12 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.flexible_deform_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 | from torch.utils.data import Dataset
21 |
22 | class Scene:
23 |
24 | gaussians : GaussianModel
25 |
26 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None):
27 | """b
28 | :param path: Path to colmap scene main folder.
29 | """
30 | self.model_path = args.model_path
31 | self.loaded_iter = None
32 | self.gaussians = gaussians
33 |
34 | if load_iteration:
35 | if load_iteration == -1:
36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
37 | else:
38 | self.loaded_iter = load_iteration
39 | print("Loading trained model at iteration {}".format(self.loaded_iter))
40 |
41 | if os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")) and args.extra_mark == 'endonerf':
42 | scene_info = sceneLoadTypeCallbacks["endonerf"](args.source_path)
43 |             print("Found poses_bounds.npy and extra marks with EndoNeRF")
44 | elif os.path.exists(os.path.join(args.source_path, "point_cloud.obj")) or os.path.exists(os.path.join(args.source_path, "left_point_cloud.obj")):
45 | scene_info = sceneLoadTypeCallbacks["scared"](args.source_path, args.white_background, args.eval)
46 | print("Found point_cloud.obj, assuming SCARED data!")
47 | else:
48 | assert False, "Could not recognize scene type!"
49 |
50 | self.maxtime = scene_info.maxtime
51 | self.cameras_extent = scene_info.nerf_normalization["radius"]
52 | # self.cameras_extent = args.camera_extent
53 | print("self.cameras_extent is ", self.cameras_extent)
54 |
55 | print("Loading Training Cameras")
56 | self.train_camera = scene_info.train_cameras
57 | print("Loading Test Cameras")
58 | self.test_camera = scene_info.test_cameras
59 | print("Loading Video Cameras")
60 | self.video_camera = scene_info.video_cameras
61 |
62 | xyz_max = scene_info.point_cloud.points.max(axis=0)
63 | xyz_min = scene_info.point_cloud.points.min(axis=0)
64 | # self.gaussians._deformation.deformation_net.grid.set_aabb(xyz_max,xyz_min)
65 |
66 | if self.loaded_iter:
67 | self.gaussians.load_ply(os.path.join(self.model_path,
68 | "point_cloud",
69 | "iteration_" + str(self.loaded_iter),
70 | "point_cloud.ply"))
71 | self.gaussians.load_model(os.path.join(self.model_path,
72 | "point_cloud",
73 | "iteration_" + str(self.loaded_iter),
74 | ))
75 | else:
76 | self.gaussians.create_from_pcd(scene_info.point_cloud, args.camera_extent, self.maxtime)
77 |
78 | def save(self, iteration, stage):
79 | if stage == "coarse":
80 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_iteration_{}".format(iteration))
81 | else:
82 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
83 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
84 | # self.gaussians.save_deformation(point_cloud_path)
85 |
86 | def getTrainCameras(self, scale=1.0):
87 | return self.train_camera
88 |
89 | def getTestCameras(self, scale=1.0):
90 | return self.test_camera
91 |
92 | def getVideoCameras(self, scale=1.0):
93 | return self.video_camera
94 |
95 |
--------------------------------------------------------------------------------
/scene/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/cameras.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/cameras.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/colmap_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/colmap_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/dataset_readers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/dataset_readers.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/deformation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/deformation.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/endo_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/endo_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/flexible_deform_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/flexible_deform_model.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/gaussian_flow_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/gaussian_flow_model.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/gaussian_gaussian_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/gaussian_gaussian_model.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/gaussian_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/gaussian_model.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/hexplane.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/hexplane.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/hyper_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/hyper_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/regulation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/regulation.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/scene/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getProjectionMatrix2
16 |
17 |
18 |
19 | class Camera(nn.Module):
20 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, depth, mask, gt_alpha_mask,
21 | image_name, uid,
22 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0,
23 | data_device = "cuda", time = 0, Znear=None, Zfar=None,
24 | K=None, h=None, w=None
25 | ):
26 | super(Camera, self).__init__()
27 |
28 | self.uid = uid
29 | self.colmap_id = colmap_id
30 | self.R = R
31 | self.T = T
32 | self.FoVx = FoVx
33 | self.FoVy = FoVy
34 | self.image_name = image_name
35 | self.time = time
36 | self.mask = mask
37 | try:
38 | self.data_device = torch.device(data_device)
39 | except Exception as e:
40 | print(e)
41 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
42 | self.data_device = torch.device("cuda")
43 |
44 | self.original_image = image.clamp(0.0, 1.0)
45 | self.original_depth = depth
46 | self.image_width = self.original_image.shape[2]
47 | self.image_height = self.original_image.shape[1]
48 | if gt_alpha_mask is not None:
49 | self.original_image *= gt_alpha_mask
50 | else:
51 | self.original_image *= torch.ones((1, self.image_height, self.image_width))
52 |
53 | if Zfar is not None and Znear is not None:
54 | self.zfar = Zfar
55 | self.znear = Znear
56 | else:
57 | # ENDONERF
58 | self.zfar = 120.0
59 | self.znear = 0.01
60 |
61 | # StereoMIS
62 | self.zfar = 250
63 | self.znear= 0.03
64 |
65 | self.trans = trans
66 | self.scale = scale
67 |
68 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1)
69 | if K is None or h is None or w is None:
70 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1)
71 | else:
72 | self.projection_matrix = getProjectionMatrix2(znear=self.znear, zfar=self.zfar, K=K, h = h, w=w).transpose(0,1)
73 |
74 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
75 | self.camera_center = self.world_view_transform.inverse()[3, :3]
76 |
77 | class MiniCam:
78 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform, time):
79 | self.image_width = width
80 | self.image_height = height
81 | self.FoVy = fovy
82 | self.FoVx = fovx
83 | self.znear = znear
84 | self.zfar = zfar
85 | self.world_view_transform = world_view_transform
86 | self.full_proj_transform = full_proj_transform
87 | view_inv = torch.inverse(self.world_view_transform)
88 | self.camera_center = view_inv[3][:3]
89 | self.time = time
90 |
91 |
--------------------------------------------------------------------------------
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from typing import NamedTuple
16 | import torchvision.transforms as transforms
17 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
18 | import numpy as np
19 | import torch
20 | import json
21 | from pathlib import Path
22 | from plyfile import PlyData, PlyElement
23 | from scene.flexible_deform_model import BasicPointCloud
24 | from utils.general_utils import PILtoTorch
25 | from tqdm import tqdm
26 |
27 | class CameraInfo(NamedTuple):
28 | uid: int
29 | R: np.array
30 | T: np.array
31 | FovY: np.array
32 | FovX: np.array
33 | image: np.array
34 | image_path: str
35 | image_name: str
36 | width: int
37 | height: int
38 | time : float
39 |
40 | class SceneInfo(NamedTuple):
41 | point_cloud: BasicPointCloud
42 | train_cameras: list
43 | test_cameras: list
44 | video_cameras: list
45 | nerf_normalization: dict
46 | ply_path: str
47 | maxtime: int
48 |
49 | def getNerfppNorm(cam_info):
50 | def get_center_and_diag(cam_centers):
51 | cam_centers = np.hstack(cam_centers)
52 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
53 | center = avg_cam_center
54 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
55 | diagonal = np.max(dist)
56 | return center.flatten(), diagonal
57 |
58 | cam_centers = []
59 | for cam in cam_info:
60 | W2C = getWorld2View2(cam.R, cam.T)
61 | C2W = np.linalg.inv(W2C)
62 | cam_centers.append(C2W[:3, 3:4])
63 |
64 | center, diagonal = get_center_and_diag(cam_centers)
65 | radius = diagonal * 1.1
66 | translate = -center
67 |
68 | return {"translate": translate, "radius": radius}
69 |
70 |
71 |
72 | def fetchPly(path):
73 | plydata = PlyData.read(path)
74 | vertices = plydata['vertex']
75 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
76 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
77 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
78 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
79 |
80 | def storePly(path, xyz, rgb):
81 | # Define the dtype for the structured array
82 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
83 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
84 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
85 |
86 | normals = np.zeros_like(xyz)
87 | elements = np.empty(xyz.shape[0], dtype=dtype)
88 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
89 | elements[:] = list(map(tuple, attributes))
90 |
91 | # Create the PlyData object and write to file
92 | vertex_element = PlyElement.describe(elements, 'vertex')
93 | ply_data = PlyData([vertex_element])
94 | ply_data.write(path)
95 |
96 |
97 | def generateCamerasFromTransforms(path, template_transformsfile, extension, maxtime):
98 | trans_t = lambda t : torch.Tensor([
99 | [1,0,0,0],
100 | [0,1,0,0],
101 | [0,0,1,t],
102 | [0,0,0,1]]).float()
103 |
104 | rot_phi = lambda phi : torch.Tensor([
105 | [1,0,0,0],
106 | [0,np.cos(phi),-np.sin(phi),0],
107 | [0,np.sin(phi), np.cos(phi),0],
108 | [0,0,0,1]]).float()
109 |
110 | rot_theta = lambda th : torch.Tensor([
111 | [np.cos(th),0,-np.sin(th),0],
112 | [0,1,0,0],
113 | [np.sin(th),0, np.cos(th),0],
114 | [0,0,0,1]]).float()
115 |
116 | def pose_spherical(theta, phi, radius):
117 | c2w = trans_t(radius)
118 | c2w = rot_phi(phi/180.*np.pi) @ c2w
119 | c2w = rot_theta(theta/180.*np.pi) @ c2w
120 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
121 | return c2w
122 |
123 | cam_infos = []
124 | # generate render poses and times
125 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
126 | render_times = torch.linspace(0,maxtime,render_poses.shape[0])
127 | with open(os.path.join(path, template_transformsfile)) as json_file:
128 | template_json = json.load(json_file)
129 | fovx = template_json["camera_angle_x"]
130 | # load a single image to get image info.
131 | for idx, frame in enumerate(template_json["frames"]):
132 | cam_name = os.path.join(path, frame["file_path"] + extension)
133 | image_path = os.path.join(path, cam_name)
134 | image_name = Path(cam_name).stem
135 | image = Image.open(image_path)
136 | im_data = np.array(image.convert("RGBA"))
137 | image = PILtoTorch(image,(800,800))
138 | break
139 | # format information
140 | for idx, (time, poses) in enumerate(zip(render_times,render_poses)):
141 | time = time/maxtime
142 | matrix = np.linalg.inv(np.array(poses))
143 | R = -np.transpose(matrix[:3,:3])
144 | R[:,0] = -R[:,0]
145 | T = -matrix[:3, 3]
146 | fovy = focal2fov(fov2focal(fovx, image.shape[1]), image.shape[2])
147 | FovY = fovy
148 | FovX = fovx
149 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
150 | image_path=None, image_name=None, width=image.shape[1], height=image.shape[2],
151 | time = time))
152 | return cam_infos
153 |
154 | def readEndoNeRFInfo(datadir):
155 | # load camera infos
156 | from scene.endo_loader import EndoNeRF_Dataset
157 | endo_dataset = EndoNeRF_Dataset(
158 | datadir=datadir,
159 | downsample=1.0,
160 | )
161 | train_cam_infos = endo_dataset.format_infos(split="train")
162 | test_cam_infos = endo_dataset.format_infos(split="test")
163 | video_cam_infos = endo_dataset.format_infos(split="video")
164 |
165 | # get normalizations
166 | nerf_normalization = getNerfppNorm(train_cam_infos)
167 |
168 | # initialize sparse point clouds
169 | ply_path = os.path.join(datadir, "points3d.ply")
170 | xyz, rgb, normals = endo_dataset.get_sparse_pts()
171 |
172 | normals = np.random.random((xyz.shape[0], 3))
173 | pcd = BasicPointCloud(points=xyz, colors=rgb, normals=normals)
174 | storePly(ply_path, xyz,rgb*255)
175 |
176 | try:
177 | pcd = fetchPly(ply_path)
178 | except:
179 | pcd = None
180 |
181 | # get the maximum time
182 | maxtime = endo_dataset.get_maxtime()
183 |
184 | scene_info = SceneInfo(point_cloud=pcd,
185 | train_cameras=train_cam_infos,
186 | test_cameras=test_cam_infos,
187 | video_cameras=video_cam_infos,
188 | nerf_normalization=nerf_normalization,
189 | ply_path=ply_path,
190 | maxtime=maxtime)
191 |
192 | return scene_info
193 |
194 |
195 |
196 | sceneLoadTypeCallbacks = {
197 | "endonerf": readEndoNeRFInfo,
198 | }
199 |
--------------------------------------------------------------------------------
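
To make the vertex layout of `storePly`/`fetchPly` above concrete, here is a small self-contained round-trip sketch using the same `plyfile` calls; the file name `points3d_demo.ply` and the random points are arbitrary and not part of the repo.

```python
# Sketch: round-trip a random point cloud through the storePly/fetchPly vertex layout.
import numpy as np
from plyfile import PlyData, PlyElement

xyz = np.random.rand(100, 3).astype(np.float32)
rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)

dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
         ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
         ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
elements = np.empty(xyz.shape[0], dtype=dtype)
attributes = np.concatenate((xyz, np.zeros_like(xyz), rgb), axis=1)
elements[:] = list(map(tuple, attributes))
PlyData([PlyElement.describe(elements, 'vertex')]).write('points3d_demo.ply')

vertices = PlyData.read('points3d_demo.ply')['vertex']
positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
print(positions.shape)  # (100, 3)
```
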
/scene/regulation.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import os
3 | from typing import Sequence
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.optim.lr_scheduler
9 | from torch import nn
10 |
11 |
12 |
13 | def compute_plane_tv(t):
14 | batch_size, c, h, w = t.shape
15 | count_h = batch_size * c * (h - 1) * w
16 | count_w = batch_size * c * h * (w - 1)
17 | h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum()
18 | w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum()
19 | return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg
20 |
21 |
22 | def compute_plane_smoothness(t):
23 | batch_size, c, h, w = t.shape
24 | # Convolve with a second derivative filter, in the time dimension which is dimension 2
25 | first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w]
26 | second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w]
27 | # Take the L2 norm of the result
28 | return torch.square(second_difference).mean()
29 |
30 |
31 | class Regularizer():
32 | def __init__(self, reg_type, initialization):
33 | self.reg_type = reg_type
34 | self.initialization = initialization
35 | self.weight = float(self.initialization)
36 | self.last_reg = None
37 |
38 | def step(self, global_step):
39 | pass
40 |
41 | def report(self, d):
42 | if self.last_reg is not None:
43 | d[self.reg_type].update(self.last_reg.item())
44 |
45 | def regularize(self, *args, **kwargs) -> torch.Tensor:
46 | out = self._regularize(*args, **kwargs) * self.weight
47 | self.last_reg = out.detach()
48 | return out
49 |
50 | @abc.abstractmethod
51 | def _regularize(self, *args, **kwargs) -> torch.Tensor:
52 | raise NotImplementedError()
53 |
54 | def __str__(self):
55 | return f"Regularizer({self.reg_type}, weight={self.weight})"
56 |
57 |
58 | class PlaneTV(Regularizer):
59 | def __init__(self, initial_value, what: str = 'field'):
60 | if what not in {'field', 'proposal_network'}:
61 | raise ValueError(f'what must be one of "field" or "proposal_network" '
62 | f'but {what} was passed.')
63 | name = f'planeTV-{what[:2]}'
64 | super().__init__(name, initial_value)
65 | self.what = what
66 |
67 | def step(self, global_step):
68 | pass
69 |
70 | def _regularize(self, model, **kwargs):
71 | multi_res_grids: Sequence[nn.ParameterList]
72 | if self.what == 'field':
73 | multi_res_grids = model.field.grids
74 | elif self.what == 'proposal_network':
75 | multi_res_grids = [p.grids for p in model.proposal_networks]
76 | else:
77 | raise NotImplementedError(self.what)
78 | total = 0
79 | # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w]
80 | for grids in multi_res_grids:
81 | if len(grids) == 3:
82 | spatial_grids = [0, 1, 2]
83 | else:
84 | spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal
85 | for grid_id in spatial_grids:
86 | total += compute_plane_tv(grids[grid_id])
87 | for grid in grids:
88 | # grid: [1, c, h, w]
89 | total += compute_plane_tv(grid)
90 | return total
91 |
92 |
93 | class TimeSmoothness(Regularizer):
94 | def __init__(self, initial_value, what: str = 'field'):
95 | if what not in {'field', 'proposal_network'}:
96 | raise ValueError(f'what must be one of "field" or "proposal_network" '
97 | f'but {what} was passed.')
98 | name = f'time-smooth-{what[:2]}'
99 | super().__init__(name, initial_value)
100 | self.what = what
101 |
102 | def _regularize(self, model, **kwargs) -> torch.Tensor:
103 | multi_res_grids: Sequence[nn.ParameterList]
104 | if self.what == 'field':
105 | multi_res_grids = model.field.grids
106 | elif self.what == 'proposal_network':
107 | multi_res_grids = [p.grids for p in model.proposal_networks]
108 | else:
109 | raise NotImplementedError(self.what)
110 | total = 0
111 | # model.grids is 6 x [1, rank * F_dim, reso, reso]
112 | for grids in multi_res_grids:
113 | if len(grids) == 3:
114 | time_grids = []
115 | else:
116 | time_grids = [2, 4, 5]
117 | for grid_id in time_grids:
118 | total += compute_plane_smoothness(grids[grid_id])
119 | return torch.as_tensor(total)
120 |
121 |
122 |
123 | class L1ProposalNetwork(Regularizer):
124 | def __init__(self, initial_value):
125 | super().__init__('l1-proposal-network', initial_value)
126 |
127 | def _regularize(self, model, **kwargs) -> torch.Tensor:
128 | grids = [p.grids for p in model.proposal_networks]
129 | total = 0.0
130 | for pn_grids in grids:
131 | for grid in pn_grids:
132 | total += torch.abs(grid).mean()
133 | return torch.as_tensor(total)
134 |
135 |
136 | class DepthTV(Regularizer):
137 | def __init__(self, initial_value):
138 | super().__init__('tv-depth', initial_value)
139 |
140 | def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
141 | depth = model_out['depth']
142 | tv = compute_plane_tv(
143 | depth.reshape(64, 64)[None, None, :, :]
144 | )
145 | return tv
146 |
147 |
148 | class L1TimePlanes(Regularizer):
149 | def __init__(self, initial_value, what='field'):
150 | if what not in {'field', 'proposal_network'}:
151 | raise ValueError(f'what must be one of "field" or "proposal_network" '
152 | f'but {what} was passed.')
153 | super().__init__(f'l1-time-{what[:2]}', initial_value)
154 | self.what = what
155 |
156 | def _regularize(self, model, **kwargs) -> torch.Tensor:
157 | # model.grids is 6 x [1, rank * F_dim, reso, reso]
158 | multi_res_grids: Sequence[nn.ParameterList]
159 | if self.what == 'field':
160 | multi_res_grids = model.field.grids
161 | elif self.what == 'proposal_network':
162 | multi_res_grids = [p.grids for p in model.proposal_networks]
163 | else:
164 | raise NotImplementedError(self.what)
165 |
166 | total = 0.0
167 | for grids in multi_res_grids:
168 | if len(grids) == 3:
169 | continue
170 | else:
171 | # These are the spatiotemporal grids
172 | spatiotemporal_grids = [2, 4, 5]
173 | for grid_id in spatiotemporal_grids:
174 | total += torch.abs(1 - grids[grid_id]).mean()
175 | return torch.as_tensor(total)
176 |
177 |
--------------------------------------------------------------------------------
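
A shape sketch (random planes, not repo data) of the two penalties defined above: `compute_plane_tv` expects `[batch, c, h, w]` planes, and the spatiotemporal planes referenced by `TimeSmoothness` sit at indices 2, 4 and 5 of the 6-plane layout mentioned in the comments. The grid sizes below are made up.

```python
# Sketch: TV and time-smoothness penalties on dummy [1, c, h, w] planes.
import torch

def plane_tv(t):
    b, c, h, w = t.shape
    count_h = b * c * (h - 1) * w
    count_w = b * c * h * (w - 1)
    h_tv = torch.square(t[..., 1:, :] - t[..., :h - 1, :]).sum()
    w_tv = torch.square(t[..., :, 1:] - t[..., :, :w - 1]).sum()
    return 2 * (h_tv / count_h + w_tv / count_w)

def plane_smoothness(t):
    h = t.shape[-2]
    first = t[..., 1:, :] - t[..., :h - 1, :]
    second = first[..., 1:, :] - first[..., :h - 2, :]
    return torch.square(second).mean()

# 6 planes per resolution; indices 2, 4, 5 are the spatiotemporal ones
grids = [torch.randn(1, 32, 64, 64, requires_grad=True) for _ in range(6)]
tv = sum(plane_tv(g) for g in grids)
time_smooth = sum(plane_smoothness(grids[i]) for i in (2, 4, 5))
(tv + 0.001 * time_smooth).backward()  # both terms are differentiable w.r.t. the planes
print(tv.item(), time_smooth.item())
```
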
/stereomis2endonerf.py:
--------------------------------------------------------------------------------
1 | from utils.stereo_rectify import StereoRectifier
2 | from submodules.RAFT.core.raft import RAFT
3 | from argparse import ArgumentParser, Action
4 | from torchvision.transforms import Resize, InterpolationMode
5 | from collections import OrderedDict
6 | import os
7 | import numpy as np
8 | import torch
9 | import cv2
10 | import shutil
11 |
12 | RAFT_config = {
13 | "pretrained": "submodules/RAFT/pretrained/raft-things.pth",
14 | "iters": 12,
15 | "dropout": 0.0,
16 | "small": False,
17 | "pose_scale": 1.0,
18 | "lbgfs_iters": 100,
19 | "use_weights": True,
20 | "dbg": False
21 | }
22 |
23 | def check_arg_limits(arg_name, n):
24 | class CheckArgLimits(Action):
25 | def __call__(self, parser, args, values, option_string=None):
26 | if len(values) > n:
27 | parser.error("Too many arguments for " + arg_name + ". Maximum is {0}.".format(n))
28 | if len(values) < n:
29 | parser.error("Too few arguments for " + arg_name + ". Minimum is {0}.".format(n))
30 | setattr(args, self.dest, values)
31 | return CheckArgLimits
32 |
33 | def read_mask(path):
34 | mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
35 | mask = mask > 0
36 | mask = torch.from_numpy(mask).unsqueeze(0)
37 | return mask
38 |
39 |
40 | class DepthEstimator(torch.nn.Module):
41 | def __init__(self, config):
42 | super(DepthEstimator, self).__init__()
43 | self.model = RAFT(config).to('cuda')
44 | self.model.freeze_bn()
45 | new_state_dict = OrderedDict()
46 | raft_ckp = config['pretrained']
47 | try:
48 | state_dict = torch.load(raft_ckp)
49 | except RuntimeError:
50 | state_dict = torch.load(raft_ckp, map_location='cpu')
51 | for k, v in state_dict.items():
52 | name = k.replace('module.','') # remove `module.`
53 | new_state_dict[name] = v
54 | self.model.load_state_dict(new_state_dict)
55 |
56 | def forward(self, imagel, imager, baseline, upsample=True):
57 | n, _, h, w = imagel.shape
58 | flow = self.model(imagel.to('cuda'), imager.to('cuda'), upsample=upsample)[0][-1]
59 | baseline = torch.from_numpy(baseline).to('cuda')
60 | depth = baseline[:, None, None] / -flow[:, 0]
61 | if not upsample:
62 | depth /= 8.0  # compensate for the missing factor-8 upsampling of the flow
63 | valid = (depth > 0) & (depth <= 250.0)
64 | depth[~valid] = 0.0
65 | return depth.unsqueeze(1)
66 |
67 | def reformat_dataset(data_dir, start_frame, end_frame, img_size=(512, 640)):
68 | """
69 | Reformat the StereoMIS data into the same format as the EndoNeRF dataset via stereo depth estimation.
70 | """
71 | # Load parameters after rectification
72 | calib_file = os.path.join(data_dir, 'StereoCalibration.ini')
73 | assert os.path.exists(calib_file), "Calibration file not found."
74 | rect = StereoRectifier(calib_file, img_size_new=(img_size[1], img_size[0]), mode='conventional')
75 | calib = rect.get_rectified_calib()
76 | baseline = calib['bf'].astype(np.float32)
77 | intrinsics = calib['intrinsics']['left'].astype(np.float32)
78 |
79 | # Sort images and masks according to the start and end frame indexes
80 | frames = sorted(os.listdir(os.path.join(data_dir, 'masks')))
81 | frames = [f for f in frames if 'l.png' in f and int(f.split('l.')[0]) >= start_frame and int(f.split('l.')[0]) <= end_frame]
82 | assert len(frames) > 0, "No frames found."
83 | resize = Resize(img_size)
84 | resize_msk = Resize(img_size, interpolation=InterpolationMode.NEAREST)
85 |
86 | # Configure the depth estimator. We follow the RAFT settings used in robust-pose-estimator (https://github.com/aimi-lab/robust-pose-estimator)
87 | depth_estimator = DepthEstimator(RAFT_config)
88 |
89 | # Create folders
90 | output_dir = os.path.join(data_dir, 'stereo_'+ os.path.basename(data_dir)+'_'+str(start_frame)+'_'+str(end_frame))
91 | image_dir = os.path.join(output_dir, 'images')
92 | mask_dir = os.path.join(output_dir, 'masks')
93 | depth_dir = os.path.join(output_dir, 'depth')
94 | if not os.path.exists(image_dir):
95 | os.makedirs(image_dir)
96 | if not os.path.exists(mask_dir):
97 | os.makedirs(mask_dir)
98 | if not os.path.exists(depth_dir):
99 | os.makedirs(depth_dir)
100 | poses_bounds = []
101 | for i, frame in enumerate(frames):
102 | left_img = torch.from_numpy(cv2.cvtColor(cv2.imread(os.path.join(data_dir, 'video_frames', frame)), cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float()
103 | right_img = torch.from_numpy(cv2.cvtColor(cv2.imread(os.path.join(data_dir, 'video_frames', frame.replace('l', 'r'))), cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float()
104 | left_img = resize(left_img)
105 | right_img = resize(right_img)
106 | with torch.no_grad():
107 | depth = depth_estimator(left_img[None], right_img[None], baseline[None])
108 | try:
109 | mask = read_mask(os.path.join(data_dir, 'masks', frame))
110 | mask = resize_msk(mask)
111 | except:
112 | mask = torch.ones(1, img_size[0], img_size[1])
113 |
114 | # Save the data. Of note, the output folder name should start with 'stereo_' to be compatible with the dataloader in Deform3DGS.
115 | left_img_np = left_img.permute(1, 2, 0).numpy()
116 | mask_np = mask.permute(1, 2, 0).numpy()
117 | left_img_bgr = cv2.cvtColor(left_img_np, cv2.COLOR_RGB2BGR)
118 |
119 | # Save left_img, right_img, and mask to output_dir
120 | name = 'frame-'+str(i).zfill(6)+'.color.png'
121 | cv2.imwrite(os.path.join(image_dir, name), left_img_bgr)
122 | cv2.imwrite(os.path.join(mask_dir, name.replace('color','mask')), mask_np.astype(np.uint8) * 255)
123 | cv2.imwrite(os.path.join(depth_dir, name.replace('color','depth')), depth[0, 0].cpu().numpy())
124 |
125 | # Save poses_bounds.npy. Only static view is considered, i.e., R = I and T = 0.
126 | R = np.eye(3)
127 | T = np.zeros(3)
128 | extr = np.concatenate([R, T[:, None]], axis=1)
129 | cy, cx, focal = intrinsics[1, 2], intrinsics[0, 2], intrinsics[0, 0]
130 | param = np.concatenate([extr, np.array([[cy, cx, focal]]).T], axis=1)
131 | param = param.reshape(1, 15)
132 | param = np.concatenate([param, np.array([[0.03, 250.]])], axis=1)
133 | poses_bounds.append(param[0])
134 |
135 | np.save(os.path.join(output_dir, 'poses_bounds.npy'), np.array(poses_bounds))
136 |
137 | if __name__ == "__main__":
138 | torch.manual_seed(1234)
139 | np.random.seed(1234)
140 | # Set up command line argument parser
141 | parser = ArgumentParser(description="parameters for dataset format conversions")
142 | parser.add_argument('--data_dir', '-d', type=str, default='data/StereoMIS/P3')
143 | # Frame IDs of the start and end of the sequence. Exactly 2 arguments (start and end) must be given.
144 | parser.add_argument('--frame_id', '-f',nargs="+", action=check_arg_limits('frame_id', 2), type=int, default=[9100, 9467])
145 | args = parser.parse_args()
146 | frame_id = args.frame_id
147 | reformat_dataset(args.data_dir, frame_id[0], frame_id[1])
--------------------------------------------------------------------------------
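
The depth recovered in `DepthEstimator.forward` is the standard rectified-stereo relation depth = bf / disparity: `baseline` already holds the focal-length-times-baseline product `bf` from the calibration, and the horizontal RAFT flow from the left to the right image acts as negative disparity, clipped to (0, 250]. A minimal numeric sketch with made-up values (not from any calibration file):

```python
# Sketch: disparity-to-depth conversion as used in DepthEstimator.forward.
# bf (= focal_px * baseline) and the flow values below are invented for illustration.
import torch

bf = torch.tensor([280.0])                 # calib['bf'] for one image in the batch
flow_x = torch.full((1, 480, 640), -14.0)  # horizontal left->right flow = negative disparity
depth = bf[:, None, None] / -flow_x        # 280 / 14 = 20.0 everywhere
valid = (depth > 0) & (depth <= 250.0)     # same validity range as the script
depth[~valid] = 0.0
print(depth.shape, depth[0, 0, 0].item())  # torch.Size([1, 480, 640]) 20.0
```
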
/submodules/RAFT/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | dist
4 | datasets
5 | pytorch_env
6 | models
7 | build
8 | correlation.egg-info
9 |
--------------------------------------------------------------------------------
/submodules/RAFT/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/submodules/RAFT/RAFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/RAFT.png
--------------------------------------------------------------------------------
/submodules/RAFT/README.md:
--------------------------------------------------------------------------------
1 | # RAFT
2 | This repository contains the source code for our paper:
3 |
4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 |
8 |
9 |
10 | ## Requirements
11 | The code has been tested with PyTorch 1.6 and Cuda 10.1.
12 | ```Shell
13 | conda create --name raft
14 | conda activate raft
15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
16 | ```
17 |
18 | ## Demos
19 | Pretrained models can be downloaded by running
20 | ```Shell
21 | ./download_models.sh
22 | ```
23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
24 |
25 | You can demo a trained model on a sequence of frames
26 | ```Shell
27 | python demo.py --model=models/raft-things.pth --path=demo-frames
28 | ```
29 |
30 | ## Required Data
31 | To evaluate/train RAFT, you will need to download the required datasets.
32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
34 | * [Sintel](http://sintel.is.tue.mpg.de/)
35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
37 |
38 |
39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder to wherever the datasets were downloaded.
40 |
41 | ```Shell
42 | ├── datasets
43 | ├── Sintel
44 | ├── test
45 | ├── training
46 | ├── KITTI
47 | ├── testing
48 | ├── training
49 | ├── devkit
50 | ├── FlyingChairs_release
51 | ├── data
52 | ├── FlyingThings3D
53 | ├── frames_cleanpass
54 | ├── frames_finalpass
55 | ├── optical_flow
56 | ```
57 |
58 | ## Evaluation
59 | You can evaluate a trained model using `evaluate.py`
60 | ```Shell
61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
62 | ```
63 |
64 | ## Training
65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory and can be visualized using tensorboard.
66 | ```Shell
67 | ./train_standard.sh
68 | ```
69 |
70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU).
71 | ```Shell
72 | ./train_mixed.sh
73 | ```
74 |
75 | ## (Optional) Efficient Implementation
76 | You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension
77 | ```Shell
78 | cd alt_cuda_corr && python setup.py install && cd ..
79 | ```
80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
81 |
--------------------------------------------------------------------------------
/submodules/RAFT/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/submodules/RAFT/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
/submodules/RAFT/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/core/__init__.py
--------------------------------------------------------------------------------
/submodules/RAFT/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
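
A shape-only sketch (random features, not real RAFT activations) of the all-pairs correlation that `CorrBlock.corr` computes and the pooled pyramid built from it; the 46x62 feature resolution corresponds to a 368x496 input at 1/8 scale and is only an example.

```python
# Sketch: all-pairs correlation volume and its pooled pyramid, shapes only.
import torch
import torch.nn.functional as F

batch, dim, ht, wd = 1, 256, 46, 62          # features at 1/8 resolution
fmap1 = torch.randn(batch, dim, ht, wd)
fmap2 = torch.randn(batch, dim, ht, wd)

corr = torch.matmul(fmap1.view(batch, dim, ht * wd).transpose(1, 2),
                    fmap2.view(batch, dim, ht * wd))
corr = corr.view(batch, ht, wd, 1, ht, wd) / dim ** 0.5
print(corr.shape)                            # torch.Size([1, 46, 62, 1, 46, 62])

pyramid = [corr.reshape(batch * ht * wd, 1, ht, wd)]
for _ in range(3):                           # num_levels - 1 average poolings
    pyramid.append(F.avg_pool2d(pyramid[-1], 2, stride=2))
print([tuple(p.shape[-2:]) for p in pyramid])  # [(46, 62), (23, 31), (11, 15), (5, 7)]
```
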
/submodules/RAFT/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, config):
26 | super(RAFT, self).__init__()
27 | self.config = config
28 |
29 | if config['small']:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | config['corr_levels'] = 4
33 | config['corr_radius'] = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | config['corr_levels'] = 4
39 | config['corr_radius'] = 4
40 |
41 | # feature network, context network, and update block
42 | if config['small']:
43 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=config['dropout'])
44 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=config['dropout'])
45 | self.update_block = SmallUpdateBlock(config, hidden_dim=hdim)
46 |
47 | else:
48 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=config['dropout'])
49 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=config['dropout'])
50 | self.update_block = BasicUpdateBlock(config, hidden_dim=hdim)
51 |
52 | def freeze_bn(self):
53 | for m in self.modules():
54 | if isinstance(m, nn.BatchNorm2d):
55 | m.eval()
56 |
57 | def initialize_flow(self, img):
58 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
59 | N, C, H, W = img.shape
60 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
61 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
62 |
63 | # optical flow computed as difference: flow = coords1 - coords0
64 | return coords0, coords1
65 |
66 | def upsample_flow(self, flow, mask):
67 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
68 | N, _, H, W = flow.shape
69 | mask = mask.view(N, 1, 9, 8, 8, H, W)
70 | mask = torch.softmax(mask, dim=2)
71 |
72 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
73 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
74 |
75 | up_flow = torch.sum(mask * up_flow, dim=2)
76 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
77 | return up_flow.reshape(N, 2, 8*H, 8*W)
78 |
79 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False, upsample=True):
80 | """ Estimate optical flow between pair of frames """
81 |
82 | image1 = 2 * (image1 / 255.0) - 1.0
83 | image2 = 2 * (image2 / 255.0) - 1.0
84 |
85 | image1 = image1.contiguous()
86 | image2 = image2.contiguous()
87 |
88 | hdim = self.hidden_dim
89 | cdim = self.context_dim
90 |
91 | # run the feature network
92 | with autocast():
93 | fmap1, fmap2 = self.fnet([image1, image2])
94 |
95 | fmap1 = fmap1.float()
96 | fmap2 = fmap2.float()
97 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.config['corr_radius'])
98 |
99 | # run the context network
100 | with autocast():
101 | cnet = self.cnet(image1)
102 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
103 | net = torch.tanh(net)
104 | inp = torch.relu(inp)
105 |
106 | coords0, coords1 = self.initialize_flow(image1)
107 |
108 | if flow_init is not None:
109 | coords1 = coords1 + flow_init
110 |
111 | flow_predictions = []
112 | for itr in range(iters):
113 | coords1 = coords1.detach()
114 | corr = corr_fn(coords1) # index correlation volume
115 |
116 | flow = coords1 - coords0
117 | with autocast():
118 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
119 |
120 | # F(t+1) = F(t) + \Delta(t)
121 | coords1 = coords1 + delta_flow
122 |
123 | # upsample predictions
124 | if upsample:
125 | if up_mask is None:
126 | flow_up = upflow8(coords1 - coords0)
127 | else:
128 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
129 | else:
130 | flow_up = coords1 - coords0
131 |
132 | flow_predictions.append(flow_up)
133 |
134 | if test_mode:
135 | return coords1 - coords0, flow_up
136 |
137 | return flow_predictions, net, inp
138 |
--------------------------------------------------------------------------------
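
A shape sketch (random inputs) of the convex upsampling performed in `upsample_flow` above: every full-resolution flow vector is a softmax-weighted combination of a 3x3 neighbourhood of the coarse flow, so an [N, 2, H/8, W/8] field becomes [N, 2, H, W]. The coarse resolution below is made up.

```python
# Sketch: convex upsampling of a coarse flow field, mirroring RAFT.upsample_flow.
import torch
import torch.nn.functional as F

N, H, W = 1, 46, 62                      # coarse (1/8) resolution
flow = torch.randn(N, 2, H, W)
mask = torch.randn(N, 64 * 9, H, W)      # stand-in for the update block's mask head output

mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)        # convex weights over the 3x3 neighbourhood

up_flow = F.unfold(8 * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
print(up_flow.reshape(N, 2, 8 * H, 8 * W).shape)  # torch.Size([1, 2, 368, 496])
```
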
/submodules/RAFT/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 | # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
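
A quick shape sketch (random tensors) of one `ConvGRU` step as defined above, using the class's default dimensions: the hidden state keeps its [N, hidden_dim, H, W] shape while the input carries the concatenated context and motion features.

```python
# Sketch: one ConvGRU update step with the default hidden/input dims above.
import torch
import torch.nn as nn

hidden_dim, input_dim = 128, 192 + 128
convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

h = torch.randn(1, hidden_dim, 46, 62)   # hidden state
x = torch.randn(1, input_dim, 46, 62)    # context + motion features

hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(convz(hx))             # update gate
r = torch.sigmoid(convr(hx))             # reset gate
q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
print(h.shape)                           # torch.Size([1, 128, 46, 62])
```
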
/submodules/RAFT/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/core/utils/__init__.py
--------------------------------------------------------------------------------
/submodules/RAFT/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 | Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 | Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
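
For reference, a small usage sketch for `flow_to_image` on a synthetic flow field (not repo data); a pure-translation field maps to a single hue. The import follows the `sys.path.append('core')` convention used by `demo.py`.

```python
# Sketch: visualizing a synthetic, constant flow field with flow_to_image.
import numpy as np
from utils import flow_viz  # as in demo.py, with 'core' on sys.path

flow = np.zeros((120, 160, 2), dtype=np.float32)
flow[..., 0] = 3.0   # uniform horizontal motion
flow[..., 1] = -1.0  # slight upward motion
rgb = flow_viz.flow_to_image(flow)
print(rgb.shape, rgb.dtype)  # (120, 160, 3) uint8
```
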
/submodules/RAFT/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
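
A round-trip sketch (random data, arbitrary file name) for the Middlebury `.flo` helpers above; `readFlow` returns the field as [H, W, 2]. The import follows the `sys.path.append('core')` convention used by `evaluate.py`.

```python
# Sketch: write a random flow field with writeFlow and read it back with readFlow.
import numpy as np
from utils import frame_utils  # as in evaluate.py, with 'core' on sys.path

flow = np.random.randn(100, 200, 2).astype(np.float32)
frame_utils.writeFlow('demo.flo', flow)
back = frame_utils.readFlow('demo.flo')
print(back.shape, np.allclose(back, flow))  # (100, 200, 2) True
```
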
/submodules/RAFT/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
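
A shape sketch for `InputPadder` above: a 436x1024 (Sintel-sized) image is padded to the next multiple of 8 in height, and the prediction is cropped back with `unpad`. The tensor sizes are only an example.

```python
# Sketch: pad a pair of images to dimensions divisible by 8, then unpad a prediction.
import torch
from utils.utils import InputPadder  # with 'core' on sys.path, as in demo.py

img1 = torch.randn(1, 3, 436, 1024)
img2 = torch.randn(1, 3, 436, 1024)
padder = InputPadder(img1.shape)           # 'sintel' mode: symmetric padding
img1_p, img2_p = padder.pad(img1, img2)
print(img1_p.shape)                        # torch.Size([1, 3, 440, 1024])

flow = torch.randn(1, 2, 440, 1024)        # pretend network output at the padded size
print(padder.unpad(flow).shape)            # torch.Size([1, 2, 436, 1024])
```
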
/submodules/RAFT/demo-frames/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0016.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0017.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0018.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0019.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0020.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0021.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0022.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0023.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0024.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo-frames/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/demo-frames/frame_0025.png
--------------------------------------------------------------------------------
/submodules/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 | from raft import RAFT
13 | from utils import flow_viz
14 | from utils.utils import InputPadder
15 |
16 |
17 |
18 | DEVICE = 'cuda'
19 |
20 | def load_image(imfile):
21 | img = np.array(Image.open(imfile)).astype(np.uint8)
22 | img = torch.from_numpy(img).permute(2, 0, 1).float()
23 | return img[None].to(DEVICE)
24 |
25 |
26 | def viz(img, flo):
27 | img = img[0].permute(1,2,0).cpu().numpy()
28 | flo = flo[0].permute(1,2,0).cpu().numpy()
29 |
30 | # map flow to rgb image
31 | flo = flow_viz.flow_to_image(flo)
32 | img_flo = np.concatenate([img, flo], axis=0)
33 |
34 | # import matplotlib.pyplot as plt
35 | # plt.imshow(img_flo / 255.0)
36 | # plt.show()
37 |
38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
39 | cv2.waitKey()
40 |
41 |
42 | def demo(args):
43 | model = torch.nn.DataParallel(RAFT(args))
44 | model.load_state_dict(torch.load(args.model))
45 |
46 | model = model.module
47 | model.to(DEVICE)
48 | model.eval()
49 |
50 | with torch.no_grad():
51 | images = glob.glob(os.path.join(args.path, '*.png')) + \
52 | glob.glob(os.path.join(args.path, '*.jpg'))
53 |
54 | images = sorted(images)
55 | for imfile1, imfile2 in zip(images[:-1], images[1:]):
56 | image1 = load_image(imfile1)
57 | image2 = load_image(imfile2)
58 |
59 | padder = InputPadder(image1.shape)
60 | image1, image2 = padder.pad(image1, image2)
61 |
62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
63 | viz(image1, flow_up)
64 |
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--model', help="restore checkpoint")
69 | parser.add_argument('--path', help="dataset for evaluation")
70 | parser.add_argument('--small', action='store_true', help='use small model')
71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
73 | args = parser.parse_args()
74 |
75 | demo(args)
76 |
--------------------------------------------------------------------------------
/submodules/RAFT/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
3 | unzip models.zip
4 |
--------------------------------------------------------------------------------
/submodules/RAFT/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | from PIL import Image
5 | import argparse
6 | import os
7 | import time
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 |
13 | import datasets
14 | from utils import flow_viz
15 | from utils import frame_utils
16 |
17 | from raft import RAFT
18 | from utils.utils import InputPadder, forward_interpolate
19 |
20 |
21 | @torch.no_grad()
22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
23 | """ Create submission for the Sintel leaderboard """
24 | model.eval()
25 | for dstype in ['clean', 'final']:
26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
27 |
28 | flow_prev, sequence_prev = None, None
29 | for test_id in range(len(test_dataset)):
30 | image1, image2, (sequence, frame) = test_dataset[test_id]
31 | if sequence != sequence_prev:
32 | flow_prev = None
33 |
34 | padder = InputPadder(image1.shape)
35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
36 |
37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
39 |
40 | if warm_start:
41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
42 |
43 | output_dir = os.path.join(output_path, dstype, sequence)
44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
45 |
46 | if not os.path.exists(output_dir):
47 | os.makedirs(output_dir)
48 |
49 | frame_utils.writeFlow(output_file, flow)
50 | sequence_prev = sequence
51 |
52 |
53 | @torch.no_grad()
54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
55 | """ Create submission for the Sintel leaderboard """
56 | model.eval()
57 | test_dataset = datasets.KITTI(split='testing', aug_params=None)
58 |
59 | if not os.path.exists(output_path):
60 | os.makedirs(output_path)
61 |
62 | for test_id in range(len(test_dataset)):
63 | image1, image2, (frame_id, ) = test_dataset[test_id]
64 | padder = InputPadder(image1.shape, mode='kitti')
65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
66 |
67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
69 |
70 | output_filename = os.path.join(output_path, frame_id)
71 | frame_utils.writeFlowKITTI(output_filename, flow)
72 |
73 |
74 | @torch.no_grad()
75 | def validate_chairs(model, iters=24):
76 | """ Perform evaluation on the FlyingChairs (test) split """
77 | model.eval()
78 | epe_list = []
79 |
80 | val_dataset = datasets.FlyingChairs(split='validation')
81 | for val_id in range(len(val_dataset)):
82 | image1, image2, flow_gt, _ = val_dataset[val_id]
83 | image1 = image1[None].cuda()
84 | image2 = image2[None].cuda()
85 |
86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
88 | epe_list.append(epe.view(-1).numpy())
89 |
90 | epe = np.mean(np.concatenate(epe_list))
91 | print("Validation Chairs EPE: %f" % epe)
92 | return {'chairs': epe}
93 |
94 |
95 | @torch.no_grad()
96 | def validate_sintel(model, iters=32):
97 | """ Peform validation using the Sintel (train) split """
98 | model.eval()
99 | results = {}
100 | for dstype in ['clean', 'final']:
101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
102 | epe_list = []
103 |
104 | for val_id in range(len(val_dataset)):
105 | image1, image2, flow_gt, _ = val_dataset[val_id]
106 | image1 = image1[None].cuda()
107 | image2 = image2[None].cuda()
108 |
109 | padder = InputPadder(image1.shape)
110 | image1, image2 = padder.pad(image1, image2)
111 |
112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
113 | flow = padder.unpad(flow_pr[0]).cpu()
114 |
115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
116 | epe_list.append(epe.view(-1).numpy())
117 |
118 | epe_all = np.concatenate(epe_list)
119 | epe = np.mean(epe_all)
120 | px1 = np.mean(epe_all<1)
121 | px3 = np.mean(epe_all<3)
122 | px5 = np.mean(epe_all<5)
123 |
124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
125 | results[dstype] = epe  # entries of epe_list have different lengths; use the per-pixel mean computed above
126 |
127 | return results
128 |
129 |
130 | @torch.no_grad()
131 | def validate_kitti(model, iters=24):
132 | """ Peform validation using the KITTI-2015 (train) split """
133 | model.eval()
134 | val_dataset = datasets.KITTI(split='training')
135 |
136 | out_list, epe_list = [], []
137 | for val_id in range(len(val_dataset)):
138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
139 | image1 = image1[None].cuda()
140 | image2 = image2[None].cuda()
141 |
142 | padder = InputPadder(image1.shape, mode='kitti')
143 | image1, image2 = padder.pad(image1, image2)
144 |
145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
146 | flow = padder.unpad(flow_pr[0]).cpu()
147 |
148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
149 | mag = torch.sum(flow_gt**2, dim=0).sqrt()
150 |
151 | epe = epe.view(-1)
152 | mag = mag.view(-1)
153 | val = valid_gt.view(-1) >= 0.5
154 |
155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
156 | epe_list.append(epe[val].mean().item())
157 | out_list.append(out[val].cpu().numpy())
158 |
159 | epe_list = np.array(epe_list)
160 | out_list = np.concatenate(out_list)
161 |
162 | epe = np.mean(epe_list)
163 | f1 = 100 * np.mean(out_list)
164 |
165 | print("Validation KITTI: %f, %f" % (epe, f1))
166 | return {'kitti-epe': epe, 'kitti-f1': f1}
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument('--model', help="restore checkpoint")
172 | parser.add_argument('--dataset', help="dataset for evaluation")
173 | parser.add_argument('--small', action='store_true', help='use small model')
174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
175 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
176 | args = parser.parse_args()
177 |
178 | model = torch.nn.DataParallel(RAFT(args))
179 | model.load_state_dict(torch.load(args.model))
180 |
181 | model.cuda()
182 | model.eval()
183 |
184 | # create_sintel_submission(model.module, warm_start=True)
185 | # create_kitti_submission(model.module)
186 |
187 | with torch.no_grad():
188 | if args.dataset == 'chairs':
189 | validate_chairs(model.module)
190 |
191 | elif args.dataset == 'sintel':
192 | validate_sintel(model.module)
193 |
194 | elif args.dataset == 'kitti':
195 | validate_kitti(model.module)
196 |
197 |
198 |
--------------------------------------------------------------------------------
/submodules/RAFT/pretrained/raft-things.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/RAFT/pretrained/raft-things.pth
--------------------------------------------------------------------------------
/submodules/RAFT/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 |
17 | from torch.utils.data import DataLoader
18 | from raft import RAFT
19 | import evaluate
20 | import datasets
21 |
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | try:
25 | from torch.cuda.amp import GradScaler
26 | except ImportError:
27 | # dummy GradScaler for PyTorch < 1.6; accepts the same 'enabled' flag as the real one
28 | class GradScaler:
29 | def __init__(self, enabled=False):
30 | pass
31 | def scale(self, loss):
32 | return loss
33 | def unscale_(self, optimizer):
34 | pass
35 | def step(self, optimizer):
36 | optimizer.step()
37 | def update(self):
38 | pass
39 |
40 |
41 | # exclude extremely large displacements
42 | MAX_FLOW = 400
43 | SUM_FREQ = 100
44 | VAL_FREQ = 5000
45 |
46 |
47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
48 | """ Loss function defined over sequence of flow predictions """
49 |
50 | n_predictions = len(flow_preds)
51 | flow_loss = 0.0
52 |
53 | # exclude invalid pixels and extremely large displacements
54 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
55 | valid = (valid >= 0.5) & (mag < max_flow)
56 |
57 | for i in range(n_predictions):
58 | i_weight = gamma**(n_predictions - i - 1)
59 | i_loss = (flow_preds[i] - flow_gt).abs()
60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
61 |
62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
63 | epe = epe.view(-1)[valid.view(-1)]
64 |
65 | metrics = {
66 | 'epe': epe.mean().item(),
67 | '1px': (epe < 1).float().mean().item(),
68 | '3px': (epe < 3).float().mean().item(),
69 | '5px': (epe < 5).float().mean().item(),
70 | }
71 |
72 | return flow_loss, metrics
73 |
74 |
75 | def count_parameters(model):
76 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
77 |
78 |
79 | def fetch_optimizer(args, model):
80 | """ Create the optimizer and learning rate scheduler """
81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
82 |
83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
85 |
86 | return optimizer, scheduler
87 |
88 |
89 | class Logger:
90 | def __init__(self, model, scheduler):
91 | self.model = model
92 | self.scheduler = scheduler
93 | self.total_steps = 0
94 | self.running_loss = {}
95 | self.writer = None
96 |
97 | def _print_training_status(self):
98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
101 |
102 | # print the training status
103 | print(training_str + metrics_str)
104 |
105 | if self.writer is None:
106 | self.writer = SummaryWriter()
107 |
108 | for k in self.running_loss:
109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
110 | self.running_loss[k] = 0.0
111 |
112 | def push(self, metrics):
113 | self.total_steps += 1
114 |
115 | for key in metrics:
116 | if key not in self.running_loss:
117 | self.running_loss[key] = 0.0
118 |
119 | self.running_loss[key] += metrics[key]
120 |
121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
122 | self._print_training_status()
123 | self.running_loss = {}
124 |
125 | def write_dict(self, results):
126 | if self.writer is None:
127 | self.writer = SummaryWriter()
128 |
129 | for key in results:
130 | self.writer.add_scalar(key, results[key], self.total_steps)
131 |
132 | def close(self):
133 | self.writer.close()
134 |
135 |
136 | def train(args):
137 |
138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
139 | print("Parameter Count: %d" % count_parameters(model))
140 |
141 | if args.restore_ckpt is not None:
142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
143 |
144 | model.cuda()
145 | model.train()
146 |
147 | if args.stage != 'chairs':
148 | model.module.freeze_bn()
149 |
150 | train_loader = datasets.fetch_dataloader(args)
151 | optimizer, scheduler = fetch_optimizer(args, model)
152 |
153 | total_steps = 0
154 | scaler = GradScaler(enabled=args.mixed_precision)
155 | logger = Logger(model, scheduler)
156 |
157 | VAL_FREQ = 5000
158 | add_noise = True
159 |
160 | should_keep_training = True
161 | while should_keep_training:
162 |
163 | for i_batch, data_blob in enumerate(train_loader):
164 | optimizer.zero_grad()
165 | image1, image2, flow, valid = [x.cuda() for x in data_blob]
166 |
167 | if args.add_noise:
168 | stdv = np.random.uniform(0.0, 5.0)
169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
171 |
172 | flow_predictions = model(image1, image2, iters=args.iters)
173 |
174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
175 | scaler.scale(loss).backward()
176 | scaler.unscale_(optimizer)
177 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
178 |
179 | scaler.step(optimizer)
180 | scheduler.step()
181 | scaler.update()
182 |
183 | logger.push(metrics)
184 |
185 | if total_steps % VAL_FREQ == VAL_FREQ - 1:
186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
187 | torch.save(model.state_dict(), PATH)
188 |
189 | results = {}
190 | for val_dataset in args.validation:
191 | if val_dataset == 'chairs':
192 | results.update(evaluate.validate_chairs(model.module))
193 | elif val_dataset == 'sintel':
194 | results.update(evaluate.validate_sintel(model.module))
195 | elif val_dataset == 'kitti':
196 | results.update(evaluate.validate_kitti(model.module))
197 |
198 | logger.write_dict(results)
199 |
200 | model.train()
201 | if args.stage != 'chairs':
202 | model.module.freeze_bn()
203 |
204 | total_steps += 1
205 |
206 | if total_steps > args.num_steps:
207 | should_keep_training = False
208 | break
209 |
210 | logger.close()
211 | PATH = 'checkpoints/%s.pth' % args.name
212 | torch.save(model.state_dict(), PATH)
213 |
214 | return PATH
215 |
216 |
217 | if __name__ == '__main__':
218 | parser = argparse.ArgumentParser()
219 | parser.add_argument('--name', default='raft', help="name your experiment")
220 | parser.add_argument('--stage', help="determines which dataset to use for training")
221 | parser.add_argument('--restore_ckpt', help="restore checkpoint")
222 | parser.add_argument('--small', action='store_true', help='use small model')
223 | parser.add_argument('--validation', type=str, nargs='+')
224 |
225 | parser.add_argument('--lr', type=float, default=0.00002)
226 | parser.add_argument('--num_steps', type=int, default=100000)
227 | parser.add_argument('--batch_size', type=int, default=6)
228 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
229 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
230 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
231 |
232 | parser.add_argument('--iters', type=int, default=12)
233 | parser.add_argument('--wdecay', type=float, default=.00005)
234 | parser.add_argument('--epsilon', type=float, default=1e-8)
235 | parser.add_argument('--clip', type=float, default=1.0)
236 | parser.add_argument('--dropout', type=float, default=0.0)
237 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
238 | parser.add_argument('--add_noise', action='store_true')
239 | args = parser.parse_args()
240 |
241 | torch.manual_seed(1234)
242 | np.random.seed(1234)
243 |
244 | if not os.path.isdir('checkpoints'):
245 | os.mkdir('checkpoints')
246 |
247 | train(args)
--------------------------------------------------------------------------------
/submodules/RAFT/train_mixed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
7 |
--------------------------------------------------------------------------------
/submodules/RAFT/train_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
7 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | diff_gaussian_rasterization.egg-info/
3 | dist/
4 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/glm"]
2 | path = third_party/glm
3 | url = https://github.com/g-truc/glm.git
4 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | cmake_minimum_required(VERSION 3.20)
13 |
14 | project(DiffRast LANGUAGES CUDA CXX)
15 |
16 | set(CMAKE_CXX_STANDARD 17)
17 | set(CMAKE_CXX_EXTENSIONS OFF)
18 | set(CMAKE_CUDA_STANDARD 17)
19 |
20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
21 |
22 | add_library(CudaRasterizer
23 | cuda_rasterizer/backward.h
24 | cuda_rasterizer/backward.cu
25 | cuda_rasterizer/forward.h
26 | cuda_rasterizer/forward.cu
27 | cuda_rasterizer/auxiliary.h
28 | cuda_rasterizer/rasterizer_impl.cu
29 | cuda_rasterizer/rasterizer_impl.h
30 | cuda_rasterizer/rasterizer.h
31 | )
32 |
33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
34 |
35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
37 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/LICENSE.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/README.md:
--------------------------------------------------------------------------------
1 | # Differential Gaussian Rasterization
2 |
3 | Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us.
4 |
5 | 
6 | ## BibTeX
7 | 
8 | @Article{kerbl3Dgaussians,
9 | author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
10 | title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
11 | journal = {ACM Transactions on Graphics},
12 | number = {4},
13 | volume = {42},
14 | month = {July},
15 | year = {2023},
16 | url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
17 | }
18 |
19 |
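
For orientation, below is a minimal sketch of how this extension is typically driven from Python, using the `GaussianRasterizationSettings` and `GaussianRasterizer` classes defined in `diff_gaussian_rasterization/__init__.py` later in this repository. It is illustrative only: the tensor shapes, camera matrices, and Gaussian count are placeholder assumptions, and running it requires a CUDA device with the extension built.

```python
import math
import torch
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer

device = "cuda"
N = 1000  # number of Gaussians (placeholder)

raster_settings = GaussianRasterizationSettings(
    image_height=512, image_width=640,
    tanfovx=math.tan(0.5), tanfovy=math.tan(0.4),   # half-FoV tangents (placeholder values)
    bg=torch.zeros(3, device=device),               # black background
    scale_modifier=1.0,
    viewmatrix=torch.eye(4, device=device),         # placeholder world-to-camera matrix
    projmatrix=torch.eye(4, device=device),         # placeholder full projection matrix
    sh_degree=3,
    campos=torch.zeros(3, device=device),
    prefiltered=False,
    debug=False,
)
rasterizer = GaussianRasterizer(raster_settings)

means3D = torch.rand(N, 3, device=device, requires_grad=True)
means2D = torch.zeros(N, 3, device=device, requires_grad=True)  # receives screen-space gradients
shs = torch.rand(N, 16, 3, device=device, requires_grad=True)   # (sh_degree + 1)^2 SH coefficients
opacities = torch.rand(N, 1, device=device, requires_grad=True)
scales = torch.rand(N, 3, device=device, requires_grad=True)
rotations = torch.rand(N, 4, device=device, requires_grad=True)

# Forward pass returns the rendered image, per-Gaussian screen radii, and a depth map.
color, radii, depth = rasterizer(means3D, means2D, opacities,
                                 shs=shs, scales=scales, rotations=rotations)
```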
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
14 |
15 | #include "config.h"
16 | #include "stdio.h"
17 |
18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
19 | #define NUM_WARPS (BLOCK_SIZE/32)
20 |
21 | // Spherical harmonics coefficients
22 | __device__ const float SH_C0 = 0.28209479177387814f;
23 | __device__ const float SH_C1 = 0.4886025119029199f;
24 | __device__ const float SH_C2[] = {
25 | 1.0925484305920792f,
26 | -1.0925484305920792f,
27 | 0.31539156525252005f,
28 | -1.0925484305920792f,
29 | 0.5462742152960396f
30 | };
31 | __device__ const float SH_C3[] = {
32 | -0.5900435899266435f,
33 | 2.890611442640554f,
34 | -0.4570457994644658f,
35 | 0.3731763325901154f,
36 | -0.4570457994644658f,
37 | 1.445305721320277f,
38 | -0.5900435899266435f
39 | };
40 |
41 | __forceinline__ __device__ float ndc2Pix(float v, int S)
42 | {
43 | return ((v + 1.0) * S - 1.0) * 0.5;
44 | }
45 |
46 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
47 | {
48 | rect_min = {
49 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
50 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
51 | };
52 | rect_max = {
53 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
54 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
55 | };
56 | }
57 |
58 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix)
59 | {
60 | float3 transformed = {
61 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
62 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
63 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
64 | };
65 | return transformed;
66 | }
67 |
68 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix)
69 | {
70 | float4 transformed = {
71 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
72 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
73 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
74 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]
75 | };
76 | return transformed;
77 | }
78 |
79 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix)
80 | {
81 | float3 transformed = {
82 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
83 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
84 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
85 | };
86 | return transformed;
87 | }
88 |
89 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
90 | {
91 | float3 transformed = {
92 | matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
93 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
94 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
95 | };
96 | return transformed;
97 | }
98 |
99 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
100 | {
101 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
102 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
103 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
104 | return dnormvdz;
105 | }
106 |
107 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
108 | {
109 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
110 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
111 |
112 | float3 dnormvdv;
113 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
114 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32;
115 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
116 | return dnormvdv;
117 | }
118 |
119 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
120 | {
121 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
122 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
123 |
124 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
125 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
126 | float4 dnormvdv;
127 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
128 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
129 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
130 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
131 | return dnormvdv;
132 | }
133 |
134 | __forceinline__ __device__ float sigmoid(float x)
135 | {
136 | return 1.0f / (1.0f + expf(-x));
137 | }
138 |
139 | __forceinline__ __device__ bool in_frustum(int idx,
140 | const float* orig_points,
141 | const float* viewmatrix,
142 | const float* projmatrix,
143 | bool prefiltered,
144 | float3& p_view)
145 | {
146 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
147 |
148 | // Bring points to screen space
149 | float4 p_hom = transformPoint4x4(p_orig, projmatrix);
150 | float p_w = 1.0f / (p_hom.w + 0.0000001f);
151 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
152 | p_view = transformPoint4x3(p_orig, viewmatrix);
153 |
154 | if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
155 | {
156 | if (prefiltered)
157 | {
158 | printf("Point is filtered although prefiltered is set. This shouldn't happen!");
159 | __trap();
160 | }
161 | return false;
162 | }
163 | return true;
164 | }
165 |
166 | #define CHECK_CUDA(A, debug) \
167 | A; if(debug) { \
168 | auto ret = cudaDeviceSynchronize(); \
169 | if (ret != cudaSuccess) { \
170 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
171 | throw std::runtime_error(cudaGetErrorString(ret)); \
172 | } \
173 | }
174 |
175 | #endif
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/backward.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
14 |
15 | #include <cuda.h>
16 | #include "cuda_runtime.h"
17 | #include "device_launch_parameters.h"
18 | #define GLM_FORCE_CUDA
19 | #include <glm/glm.hpp>
20 |
21 | namespace BACKWARD
22 | {
23 | void render(
24 | const dim3 grid, dim3 block,
25 | const uint2* ranges,
26 | const uint32_t* point_list,
27 | int W, int H,
28 | const float* bg_color,
29 | const float2* means2D,
30 | const float4* conic_opacity,
31 | const float* colors,
32 | const float* depths,
33 | const float* final_Ts,
34 | const uint32_t* n_contrib,
35 | const float* dL_dpixels,
36 | const float* dL_depths,
37 | float3* dL_dmean2D,
38 | float4* dL_dconic2D,
39 | float* dL_dopacity,
40 | float* dL_dcolors);
41 |
42 | void preprocess(
43 | int P, int D, int M,
44 | const float3* means,
45 | const int* radii,
46 | const float* shs,
47 | const bool* clamped,
48 | const glm::vec3* scales,
49 | const glm::vec4* rotations,
50 | const float scale_modifier,
51 | const float* cov3Ds,
52 | const float* view,
53 | const float* proj,
54 | const float focal_x, float focal_y,
55 | const float tan_fovx, float tan_fovy,
56 | const glm::vec3* campos,
57 | const float3* dL_dmean2D,
58 | const float* dL_dconics,
59 | glm::vec3* dL_dmeans,
60 | float* dL_dcolor,
61 | float* dL_dcov3D,
62 | float* dL_dsh,
63 | glm::vec3* dL_dscale,
64 | glm::vec4* dL_drot);
65 | }
66 |
67 | #endif
68 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/config.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED
14 |
15 | #define NUM_CHANNELS 3 // Default 3, RGB
16 | #define BLOCK_X 16
17 | #define BLOCK_Y 16
18 |
19 | #endif
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/forward.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED
14 |
15 | #include <cuda.h>
16 | #include "cuda_runtime.h"
17 | #include "device_launch_parameters.h"
18 | #define GLM_FORCE_CUDA
19 | #include <glm/glm.hpp>
20 |
21 | namespace FORWARD
22 | {
23 | // Perform initial steps for each Gaussian prior to rasterization.
24 | void preprocess(int P, int D, int M,
25 | const float* orig_points,
26 | const glm::vec3* scales,
27 | const float scale_modifier,
28 | const glm::vec4* rotations,
29 | const float* opacities,
30 | const float* shs,
31 | bool* clamped,
32 | const float* cov3D_precomp,
33 | const float* colors_precomp,
34 | const float* viewmatrix,
35 | const float* projmatrix,
36 | const glm::vec3* cam_pos,
37 | const int W, int H,
38 | const float focal_x, float focal_y,
39 | const float tan_fovx, float tan_fovy,
40 | int* radii,
41 | float2* points_xy_image,
42 | float* depths,
43 | float* cov3Ds,
44 | float* colors,
45 | float4* conic_opacity,
46 | const dim3 grid,
47 | uint32_t* tiles_touched,
48 | bool prefiltered);
49 |
50 | // Main rasterization method.
51 | void render(
52 | const dim3 grid, dim3 block,
53 | const uint2* ranges,
54 | const uint32_t* point_list,
55 | int W, int H,
56 | const float2* points_xy_image,
57 | const float* features,
58 | const float* depths,
59 | const float4* conic_opacity,
60 | float* final_T,
61 | uint32_t* n_contrib,
62 | const float* bg_color,
63 | float* out_color,
64 | float* out_depth);
65 | }
66 |
67 |
68 | #endif
69 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef CUDA_RASTERIZER_H_INCLUDED
13 | #define CUDA_RASTERIZER_H_INCLUDED
14 |
15 | #include <vector>
16 | #include <functional>
17 |
18 | namespace CudaRasterizer
19 | {
20 | class Rasterizer
21 | {
22 | public:
23 |
24 | static void markVisible(
25 | int P,
26 | float* means3D,
27 | float* viewmatrix,
28 | float* projmatrix,
29 | bool* present);
30 |
31 | static int forward(
32 | std::function<char* (size_t)> geometryBuffer,
33 | std::function<char* (size_t)> binningBuffer,
34 | std::function<char* (size_t)> imageBuffer,
35 | const int P, int D, int M,
36 | const float* background,
37 | const int width, int height,
38 | const float* means3D,
39 | const float* shs,
40 | const float* colors_precomp,
41 | const float* opacities,
42 | const float* scales,
43 | const float scale_modifier,
44 | const float* rotations,
45 | const float* cov3D_precomp,
46 | const float* viewmatrix,
47 | const float* projmatrix,
48 | const float* cam_pos,
49 | const float tan_fovx, float tan_fovy,
50 | const bool prefiltered,
51 | float* out_color,
52 | float* out_depth,
53 | int* radii = nullptr,
54 | bool debug = false);
55 |
56 | static void backward(
57 | const int P, int D, int M, int R,
58 | const float* background,
59 | const int width, int height,
60 | const float* means3D,
61 | const float* shs,
62 | const float* colors_precomp,
63 | const float* scales,
64 | const float scale_modifier,
65 | const float* rotations,
66 | const float* cov3D_precomp,
67 | const float* viewmatrix,
68 | const float* projmatrix,
69 | const float* campos,
70 | const float tan_fovx, float tan_fovy,
71 | const int* radii,
72 | char* geom_buffer,
73 | char* binning_buffer,
74 | char* image_buffer,
75 | const float* dL_dpix,
76 | const float* dL_depths,
77 | float* dL_dmean2D,
78 | float* dL_dconic,
79 | float* dL_dopacity,
80 | float* dL_dcolor,
81 | float* dL_dmean3D,
82 | float* dL_dcov3D,
83 | float* dL_dsh,
84 | float* dL_dscale,
85 | float* dL_drot,
86 | bool debug);
87 | };
88 | };
89 |
90 | #endif
91 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #pragma once
13 |
14 | #include <iostream>
15 | #include <vector>
16 | #include "rasterizer.h"
17 | #include <cuda_runtime_api.h>
18 |
19 | namespace CudaRasterizer
20 | {
21 | template <typename T>
22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
23 | {
24 | std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
25 | ptr = reinterpret_cast<T*>(offset);
26 | chunk = reinterpret_cast<char*>(ptr + count);
27 | }
28 |
29 | struct GeometryState
30 | {
31 | size_t scan_size;
32 | float* depths;
33 | char* scanning_space;
34 | bool* clamped;
35 | int* internal_radii;
36 | float2* means2D;
37 | float* cov3D;
38 | float4* conic_opacity;
39 | float* rgb;
40 | uint32_t* point_offsets;
41 | uint32_t* tiles_touched;
42 |
43 | static GeometryState fromChunk(char*& chunk, size_t P);
44 | };
45 |
46 | struct ImageState
47 | {
48 | uint2* ranges;
49 | uint32_t* n_contrib;
50 | float* accum_alpha;
51 |
52 | static ImageState fromChunk(char*& chunk, size_t N);
53 | };
54 |
55 | struct BinningState
56 | {
57 | size_t sorting_size;
58 | uint64_t* point_list_keys_unsorted;
59 | uint64_t* point_list_keys;
60 | uint32_t* point_list_unsorted;
61 | uint32_t* point_list;
62 | char* list_sorting_space;
63 |
64 | static BinningState fromChunk(char*& chunk, size_t P);
65 | };
66 |
67 | template <typename T>
68 | size_t required(size_t P)
69 | {
70 | char* size = nullptr;
71 | T::fromChunk(size, P);
72 | return ((size_t)size) + 128;
73 | }
74 | };
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/_C.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from typing import NamedTuple
13 | import torch.nn as nn
14 | import torch
15 | from . import _C
16 |
17 | def cpu_deep_copy_tuple(input_tuple):
18 | copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
19 | return tuple(copied_tensors)
20 |
21 | def rasterize_gaussians(
22 | means3D,
23 | means2D,
24 | sh,
25 | colors_precomp,
26 | opacities,
27 | scales,
28 | rotations,
29 | cov3Ds_precomp,
30 | raster_settings,
31 | ):
32 | return _RasterizeGaussians.apply(
33 | means3D,
34 | means2D,
35 | sh,
36 | colors_precomp,
37 | opacities,
38 | scales,
39 | rotations,
40 | cov3Ds_precomp,
41 | raster_settings,
42 | )
43 |
44 | class _RasterizeGaussians(torch.autograd.Function):
45 | @staticmethod
46 | def forward(
47 | ctx,
48 | means3D,
49 | means2D,
50 | sh,
51 | colors_precomp,
52 | opacities,
53 | scales,
54 | rotations,
55 | cov3Ds_precomp,
56 | raster_settings,
57 | ):
58 |
59 | # Restructure arguments the way that the C++ lib expects them
60 | args = (
61 | raster_settings.bg,
62 | means3D,
63 | colors_precomp,
64 | opacities,
65 | scales,
66 | rotations,
67 | raster_settings.scale_modifier,
68 | cov3Ds_precomp,
69 | raster_settings.viewmatrix,
70 | raster_settings.projmatrix,
71 | raster_settings.tanfovx,
72 | raster_settings.tanfovy,
73 | raster_settings.image_height,
74 | raster_settings.image_width,
75 | sh,
76 | raster_settings.sh_degree,
77 | raster_settings.campos,
78 | raster_settings.prefiltered,
79 | raster_settings.debug
80 | )
81 |
82 | # Invoke C++/CUDA rasterizer
83 | if raster_settings.debug:
84 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
85 | try:
86 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
87 | except Exception as ex:
88 | torch.save(cpu_args, "snapshot_fw.dump")
89 | print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
90 | raise ex
91 | else:
92 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
93 |
94 | # Keep relevant tensors for backward
95 | ctx.raster_settings = raster_settings
96 | ctx.num_rendered = num_rendered
97 | ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
98 | return color, radii, depth
99 |
100 | @staticmethod
101 | def backward(ctx, grad_out_color, grad_radii, grad_depth):
102 |
103 | # Restore necessary values from context
104 | num_rendered = ctx.num_rendered
105 | raster_settings = ctx.raster_settings
106 | colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
107 |
108 | # Restructure args as C++ method expects them
109 | args = (raster_settings.bg,
110 | means3D,
111 | radii,
112 | colors_precomp,
113 | scales,
114 | rotations,
115 | raster_settings.scale_modifier,
116 | cov3Ds_precomp,
117 | raster_settings.viewmatrix,
118 | raster_settings.projmatrix,
119 | raster_settings.tanfovx,
120 | raster_settings.tanfovy,
121 | grad_out_color,
122 | grad_depth,
123 | sh,
124 | raster_settings.sh_degree,
125 | raster_settings.campos,
126 | geomBuffer,
127 | num_rendered,
128 | binningBuffer,
129 | imgBuffer,
130 | raster_settings.debug)
131 |
132 | # Compute gradients for relevant tensors by invoking backward method
133 | if raster_settings.debug:
134 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
135 | try:
136 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
137 | except Exception as ex:
138 | torch.save(cpu_args, "snapshot_bw.dump")
139 | print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
140 | raise ex
141 | else:
142 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
143 |
144 | grads = (
145 | grad_means3D,
146 | grad_means2D,
147 | grad_sh,
148 | grad_colors_precomp,
149 | grad_opacities,
150 | grad_scales,
151 | grad_rotations,
152 | grad_cov3Ds_precomp,
153 | None,
154 | )
155 |
156 | return grads
157 |
158 | class GaussianRasterizationSettings(NamedTuple):
159 | image_height: int
160 | image_width: int
161 | tanfovx : float
162 | tanfovy : float
163 | bg : torch.Tensor
164 | scale_modifier : float
165 | viewmatrix : torch.Tensor
166 | projmatrix : torch.Tensor
167 | sh_degree : int
168 | campos : torch.Tensor
169 | prefiltered : bool
170 | debug : bool
171 |
172 | class GaussianRasterizer(nn.Module):
173 | def __init__(self, raster_settings):
174 | super().__init__()
175 | self.raster_settings = raster_settings
176 |
177 | def markVisible(self, positions):
178 | # Mark visible points (based on frustum culling for camera) with a boolean
179 | with torch.no_grad():
180 | raster_settings = self.raster_settings
181 | visible = _C.mark_visible(
182 | positions,
183 | raster_settings.viewmatrix,
184 | raster_settings.projmatrix)
185 |
186 | return visible
187 |
188 | def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
189 |
190 | raster_settings = self.raster_settings
191 |
192 | if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
193 | raise Exception('Please provide exactly one of either SHs or precomputed colors!')
194 |
195 | if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
196 | raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')
197 |
198 | if shs is None:
199 | shs = torch.Tensor([])
200 | if colors_precomp is None:
201 | colors_precomp = torch.Tensor([])
202 |
203 | if scales is None:
204 | scales = torch.Tensor([])
205 | if rotations is None:
206 | rotations = torch.Tensor([])
207 | if cov3D_precomp is None:
208 | cov3D_precomp = torch.Tensor([])
209 |
210 | # Invoke C++/CUDA rasterization routine
211 | return rasterize_gaussians(
212 | means3D,
213 | means2D,
214 | shs,
215 | colors_precomp,
216 | opacities,
217 | scales,
218 | rotations,
219 | cov3D_precomp,
220 | raster_settings,
221 | )
222 |
223 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/depth-diff-gaussian-rasterization/diff_gaussian_rasterization/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/ext.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include <torch/extension.h>
13 | #include "rasterize_points.h"
14 |
15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
18 | m.def("mark_visible", &markVisible);
19 | }
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/rasterize_points.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include <math.h>
13 | #include <torch/extension.h>
14 | #include <cstdio>
15 | #include <sstream>
16 | #include <iostream>
17 | #include <tuple>
18 | #include <stdio.h>
19 | #include <cuda_runtime_api.h>
20 | #include <memory>
21 | #include "cuda_rasterizer/config.h"
22 | #include "cuda_rasterizer/rasterizer.h"
23 | #include <fstream>
24 | #include <string>
25 | #include <functional>
26 |
27 | std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
28 | auto lambda = [&t](size_t N) {
29 | t.resize_({(long long)N});
30 | return reinterpret_cast<char*>(t.contiguous().data_ptr());
31 | };
32 | return lambda;
33 | }
34 |
35 | std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
36 | RasterizeGaussiansCUDA(
37 | const torch::Tensor& background,
38 | const torch::Tensor& means3D,
39 | const torch::Tensor& colors,
40 | const torch::Tensor& opacity,
41 | const torch::Tensor& scales,
42 | const torch::Tensor& rotations,
43 | const float scale_modifier,
44 | const torch::Tensor& cov3D_precomp,
45 | const torch::Tensor& viewmatrix,
46 | const torch::Tensor& projmatrix,
47 | const float tan_fovx,
48 | const float tan_fovy,
49 | const int image_height,
50 | const int image_width,
51 | const torch::Tensor& sh,
52 | const int degree,
53 | const torch::Tensor& campos,
54 | const bool prefiltered,
55 | const bool debug)
56 | {
57 | if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
58 | AT_ERROR("means3D must have dimensions (num_points, 3)");
59 | }
60 |
61 | const int P = means3D.size(0);
62 | const int H = image_height;
63 | const int W = image_width;
64 |
65 | auto int_opts = means3D.options().dtype(torch::kInt32);
66 | auto float_opts = means3D.options().dtype(torch::kFloat32);
67 |
68 | torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
69 | torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts);
70 | torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
71 |
72 | torch::Device device(torch::kCUDA);
73 | torch::TensorOptions options(torch::kByte);
74 | torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
75 | torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
76 | torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
77 | std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
78 | std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
79 | std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);
80 |
81 | int rendered = 0;
82 | if(P != 0)
83 | {
84 | int M = 0;
85 | if(sh.size(0) != 0)
86 | {
87 | M = sh.size(1);
88 | }
89 |
90 | rendered = CudaRasterizer::Rasterizer::forward(
91 | geomFunc,
92 | binningFunc,
93 | imgFunc,
94 | P, degree, M,
95 | background.contiguous().data<float>(),
96 | W, H,
97 | means3D.contiguous().data<float>(),
98 | sh.contiguous().data_ptr<float>(),
99 | colors.contiguous().data<float>(),
100 | opacity.contiguous().data<float>(),
101 | scales.contiguous().data_ptr<float>(),
102 | scale_modifier,
103 | rotations.contiguous().data_ptr<float>(),
104 | cov3D_precomp.contiguous().data<float>(),
105 | viewmatrix.contiguous().data<float>(),
106 | projmatrix.contiguous().data<float>(),
107 | campos.contiguous().data<float>(),
108 | tan_fovx,
109 | tan_fovy,
110 | prefiltered,
111 | out_color.contiguous().data<float>(),
112 | out_depth.contiguous().data<float>(),
113 | radii.contiguous().data<int>(),
114 | debug);
115 | }
116 | return std::make_tuple(rendered, out_color, out_depth, radii, geomBuffer, binningBuffer, imgBuffer);
117 | }
118 |
119 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
120 | RasterizeGaussiansBackwardCUDA(
121 | const torch::Tensor& background,
122 | const torch::Tensor& means3D,
123 | const torch::Tensor& radii,
124 | const torch::Tensor& colors,
125 | const torch::Tensor& scales,
126 | const torch::Tensor& rotations,
127 | const float scale_modifier,
128 | const torch::Tensor& cov3D_precomp,
129 | const torch::Tensor& viewmatrix,
130 | const torch::Tensor& projmatrix,
131 | const float tan_fovx,
132 | const float tan_fovy,
133 | const torch::Tensor& dL_dout_color,
134 | const torch::Tensor& dL_dout_depth,
135 | const torch::Tensor& sh,
136 | const int degree,
137 | const torch::Tensor& campos,
138 | const torch::Tensor& geomBuffer,
139 | const int R,
140 | const torch::Tensor& binningBuffer,
141 | const torch::Tensor& imageBuffer,
142 | const bool debug)
143 | {
144 | const int P = means3D.size(0);
145 | const int H = dL_dout_color.size(1);
146 | const int W = dL_dout_color.size(2);
147 |
148 | int M = 0;
149 | if(sh.size(0) != 0)
150 | {
151 | M = sh.size(1);
152 | }
153 |
154 | torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
155 | torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
156 | torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
157 | torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
158 | torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
159 | torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
160 | torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
161 | torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
162 | torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
163 |
164 | if(P != 0)
165 | {
166 | CudaRasterizer::Rasterizer::backward(P, degree, M, R,
167 | background.contiguous().data<float>(),
168 | W, H,
169 | means3D.contiguous().data<float>(),
170 | sh.contiguous().data<float>(),
171 | colors.contiguous().data<float>(),
172 | scales.data_ptr<float>(),
173 | scale_modifier,
174 | rotations.data_ptr<float>(),
175 | cov3D_precomp.contiguous().data<float>(),
176 | viewmatrix.contiguous().data<float>(),
177 | projmatrix.contiguous().data<float>(),
178 | campos.contiguous().data<float>(),
179 | tan_fovx,
180 | tan_fovy,
181 | radii.contiguous().data<int>(),
182 | reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
183 | reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
184 | reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
185 | dL_dout_color.contiguous().data<float>(),
186 | dL_dout_depth.contiguous().data<float>(),
187 | dL_dmeans2D.contiguous().data<float>(),
188 | dL_dconic.contiguous().data<float>(),
189 | dL_dopacity.contiguous().data<float>(),
190 | dL_dcolors.contiguous().data<float>(),
191 | dL_dmeans3D.contiguous().data<float>(),
192 | dL_dcov3D.contiguous().data<float>(),
193 | dL_dsh.contiguous().data<float>(),
194 | dL_dscales.contiguous().data<float>(),
195 | dL_drotations.contiguous().data<float>(),
196 | debug);
197 | }
198 |
199 | return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
200 | }
201 |
202 | torch::Tensor markVisible(
203 | torch::Tensor& means3D,
204 | torch::Tensor& viewmatrix,
205 | torch::Tensor& projmatrix)
206 | {
207 | const int P = means3D.size(0);
208 |
209 | torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
210 |
211 | if(P != 0)
212 | {
213 | CudaRasterizer::Rasterizer::markVisible(P,
214 | means3D.contiguous().data<float>(),
215 | viewmatrix.contiguous().data<float>(),
216 | projmatrix.contiguous().data<float>(),
217 | present.contiguous().data<bool>());
218 | }
219 |
220 | return present;
221 | }
222 |
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/rasterize_points.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #pragma once
13 | #include <torch/extension.h>
14 | #include <cstdio>
15 | #include <tuple>
16 | #include <string>
17 |
18 | std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
19 | RasterizeGaussiansCUDA(
20 | const torch::Tensor& background,
21 | const torch::Tensor& means3D,
22 | const torch::Tensor& colors,
23 | const torch::Tensor& opacity,
24 | const torch::Tensor& scales,
25 | const torch::Tensor& rotations,
26 | const float scale_modifier,
27 | const torch::Tensor& cov3D_precomp,
28 | const torch::Tensor& viewmatrix,
29 | const torch::Tensor& projmatrix,
30 | const float tan_fovx,
31 | const float tan_fovy,
32 | const int image_height,
33 | const int image_width,
34 | const torch::Tensor& sh,
35 | const int degree,
36 | const torch::Tensor& campos,
37 | const bool prefiltered,
38 | const bool debug);
39 |
40 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
41 | RasterizeGaussiansBackwardCUDA(
42 | const torch::Tensor& background,
43 | const torch::Tensor& means3D,
44 | const torch::Tensor& radii,
45 | const torch::Tensor& colors,
46 | const torch::Tensor& scales,
47 | const torch::Tensor& rotations,
48 | const float scale_modifier,
49 | const torch::Tensor& cov3D_precomp,
50 | const torch::Tensor& viewmatrix,
51 | const torch::Tensor& projmatrix,
52 | const float tan_fovx,
53 | const float tan_fovy,
54 | const torch::Tensor& dL_dout_color,
55 | const torch::Tensor& dL_dout_depth,
56 | const torch::Tensor& sh,
57 | const int degree,
58 | const torch::Tensor& campos,
59 | const torch::Tensor& geomBuffer,
60 | const int R,
61 | const torch::Tensor& binningBuffer,
62 | const torch::Tensor& imageBuffer,
63 | const bool debug);
64 |
65 | torch::Tensor markVisible(
66 | torch::Tensor& means3D,
67 | torch::Tensor& viewmatrix,
68 | torch::Tensor& projmatrix);
69 |
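Note: the three entry points declared above are exposed to Python through the pybind11 bindings in ext.cpp. The sketch below only illustrates how the forward binding could be driven directly; the Python-side name rasterize_gaussians on the compiled _C module, the toy tensor shapes, and the identity camera matrices are assumptions for illustration, not part of this header.

import math
import torch
from diff_gaussian_rasterization import _C  # assumes the extension has been built and installed

P, H, W = 1000, 512, 640
device = "cuda"
means3D = torch.rand(P, 3, device=device)
sh = torch.zeros(P, 16, 3, device=device)             # degree-3 SH coefficients
colors = torch.empty(0, device=device)                # empty: colors are computed from SH
opacity = torch.rand(P, 1, device=device)
scales = torch.rand(P, 3, device=device) * 0.01
rotations = torch.zeros(P, 4, device=device)
rotations[:, 0] = 1.0                                  # identity quaternions
cov3D_precomp = torch.empty(0, device=device)          # empty: covariances are built from scales/rotations
background = torch.zeros(3, device=device)
viewmatrix = torch.eye(4, device=device)
projmatrix = torch.eye(4, device=device)
campos = torch.zeros(3, device=device)
fovx = fovy = math.radians(60)

# Assumed binding name; argument order follows RasterizeGaussiansCUDA above, and the result
# is the declared 7-tuple (num_rendered, color, depth, radii, geom/binning/img buffers).
out = _C.rasterize_gaussians(
    background, means3D, colors, opacity, scales, rotations, 1.0,
    cov3D_precomp, viewmatrix, projmatrix,
    math.tan(fovx * 0.5), math.tan(fovy * 0.5), H, W,
    sh, 3, campos, False, False)
num_rendered, color, depth, radii = out[0], out[1], out[2], out[3]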
--------------------------------------------------------------------------------
/submodules/depth-diff-gaussian-rasterization/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from setuptools import setup
13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
14 | import os
15 | os.path.dirname(os.path.abspath(__file__))
16 |
17 | setup(
18 | name="diff_gaussian_rasterization",
19 | packages=['diff_gaussian_rasterization'],
20 | ext_modules=[
21 | CUDAExtension(
22 | name="diff_gaussian_rasterization._C",
23 | sources=[
24 | "cuda_rasterizer/rasterizer_impl.cu",
25 | "cuda_rasterizer/forward.cu",
26 | "cuda_rasterizer/backward.cu",
27 | "rasterize_points.cu",
28 | "ext.cpp"],
29 | extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]})
30 | ],
31 | cmdclass={
32 | 'build_ext': BuildExtension
33 | }
34 | )
35 |
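Note: a quick way to confirm the extension built and linked correctly is to import the compiled module after running pip install ./submodules/depth-diff-gaussian-rasterization; the snippet below is illustrative and only checks that the CUDA runtime and the bound symbols are present.

import torch
import diff_gaussian_rasterization._C as _C

print(torch.cuda.is_available())                       # the rasterizer requires a CUDA device at runtime
print([n for n in dir(_C) if not n.startswith("__")])  # lists the entry points registered in ext.cpp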
--------------------------------------------------------------------------------
/submodules/simple-knn/build/lib.linux-x86_64-cpython-37/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/lib.linux-x86_64-cpython-37/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/ext.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/ext.o
--------------------------------------------------------------------------------
/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/simple_knn.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/simple_knn.o
--------------------------------------------------------------------------------
/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/spatial.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/build/temp.linux-x86_64-cpython-37/spatial.o
--------------------------------------------------------------------------------
/submodules/simple-knn/ext.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include <torch/extension.h>
13 | #include "spatial.h"
14 |
15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
16 | m.def("distCUDA2", &distCUDA2);
17 | }
18 |
--------------------------------------------------------------------------------
/submodules/simple-knn/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from setuptools import setup
13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
14 | import os
15 |
16 | cxx_compiler_flags = []
17 |
18 | if os.name == 'nt':
19 | cxx_compiler_flags.append("/wd4624")
20 |
21 | setup(
22 | name="simple_knn",
23 | ext_modules=[
24 | CUDAExtension(
25 | name="simple_knn._C",
26 | sources=[
27 | "spatial.cu",
28 | "simple_knn.cu",
29 | "ext.cpp"],
30 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags})
31 | ],
32 | cmdclass={
33 | 'build_ext': BuildExtension
34 | }
35 | )
36 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #define BOX_SIZE 1024
13 |
14 | #include "cuda_runtime.h"
15 | #include "device_launch_parameters.h"
16 | #include "simple_knn.h"
17 | #include <cub/cub.cuh>
18 | #include <cub/device/device_radix_sort.cuh>
19 | #include <vector>
20 | #include <cuda_runtime_api.h>
21 | #include <thrust/device_vector.h>
22 | #include <thrust/sequence.h>
23 | #define __CUDACC__
24 | #include <cooperative_groups.h>
25 | #include <cooperative_groups/reduce.h>
26 |
27 | namespace cg = cooperative_groups;
28 |
29 | struct CustomMin
30 | {
31 | __device__ __forceinline__
32 | float3 operator()(const float3& a, const float3& b) const {
33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };
34 | }
35 | };
36 |
37 | struct CustomMax
38 | {
39 | __device__ __forceinline__
40 | float3 operator()(const float3& a, const float3& b) const {
41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };
42 | }
43 | };
44 |
45 | __host__ __device__ uint32_t prepMorton(uint32_t x)
46 | {
47 | x = (x | (x << 16)) & 0x030000FF;
48 | x = (x | (x << 8)) & 0x0300F00F;
49 | x = (x | (x << 4)) & 0x030C30C3;
50 | x = (x | (x << 2)) & 0x09249249;
51 | return x;
52 | }
53 |
54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)
55 | {
56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));
57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));
58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));
59 |
60 | return x | (y << 1) | (z << 2);
61 | }
62 |
63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)
64 | {
65 | auto idx = cg::this_grid().thread_rank();
66 | if (idx >= P)
67 | return;
68 |
69 | codes[idx] = coord2Morton(points[idx], minn, maxx);
70 | }
71 |
72 | struct MinMax
73 | {
74 | float3 minn;
75 | float3 maxx;
76 | };
77 |
78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)
79 | {
80 | auto idx = cg::this_grid().thread_rank();
81 |
82 | MinMax me;
83 | if (idx < P)
84 | {
85 | me.minn = points[indices[idx]];
86 | me.maxx = points[indices[idx]];
87 | }
88 | else
89 | {
90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX };
91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX };
92 | }
93 |
94 | __shared__ MinMax redResult[BOX_SIZE];
95 |
96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2)
97 | {
98 | if (threadIdx.x < 2 * off)
99 | redResult[threadIdx.x] = me;
100 | __syncthreads();
101 |
102 | if (threadIdx.x < off)
103 | {
104 | MinMax other = redResult[threadIdx.x + off];
105 | me.minn.x = min(me.minn.x, other.minn.x);
106 | me.minn.y = min(me.minn.y, other.minn.y);
107 | me.minn.z = min(me.minn.z, other.minn.z);
108 | me.maxx.x = max(me.maxx.x, other.maxx.x);
109 | me.maxx.y = max(me.maxx.y, other.maxx.y);
110 | me.maxx.z = max(me.maxx.z, other.maxx.z);
111 | }
112 | __syncthreads();
113 | }
114 |
115 | if (threadIdx.x == 0)
116 | boxes[blockIdx.x] = me;
117 | }
118 |
119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p)
120 | {
121 | float3 diff = { 0, 0, 0 };
122 | if (p.x < box.minn.x || p.x > box.maxx.x)
123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x));
124 | if (p.y < box.minn.y || p.y > box.maxx.y)
125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y));
126 | if (p.z < box.minn.z || p.z > box.maxx.z)
127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z));
128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;
129 | }
130 |
131 | template<int K>
132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn)
133 | {
134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z };
135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z;
136 | for (int j = 0; j < K; j++)
137 | {
138 | if (knn[j] > dist)
139 | {
140 | float t = knn[j];
141 | knn[j] = dist;
142 | dist = t;
143 | }
144 | }
145 | }
146 |
147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists)
148 | {
149 | int idx = cg::this_grid().thread_rank();
150 | if (idx >= P)
151 | return;
152 |
153 | float3 point = points[indices[idx]];
154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX };
155 |
156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++)
157 | {
158 | if (i == idx)
159 | continue;
160 | updateKBest<3>(point, points[indices[i]], best);
161 | }
162 |
163 | float reject = best[2];
164 | best[0] = FLT_MAX;
165 | best[1] = FLT_MAX;
166 | best[2] = FLT_MAX;
167 |
168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++)
169 | {
170 | MinMax box = boxes[b];
171 | float dist = distBoxPoint(box, point);
172 | if (dist > reject || dist > best[2])
173 | continue;
174 |
175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++)
176 | {
177 | if (i == idx)
178 | continue;
179 | updateKBest<3>(point, points[indices[i]], best);
180 | }
181 | }
182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f;
183 | }
184 |
185 | void SimpleKNN::knn(int P, float3* points, float* meanDists)
186 | {
187 | float3* result;
188 | cudaMalloc(&result, sizeof(float3));
189 | size_t temp_storage_bytes;
190 |
191 | float3 init = { 0, 0, 0 }, minn, maxx;
192 |
193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init);
194 | thrust::device_vector<char> temp_storage(temp_storage_bytes);
195 |
196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init);
197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost);
198 |
199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init);
200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost);
201 |
202 | thrust::device_vector<uint32_t> morton(P);
203 | thrust::device_vector<uint32_t> morton_sorted(P);
204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get());
205 |
206 | thrust::device_vector<uint32_t> indices(P);
207 | thrust::sequence(indices.begin(), indices.end());
208 | thrust::device_vector<uint32_t> indices_sorted(P);
209 |
210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
211 | temp_storage.resize(temp_storage_bytes);
212 |
213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
214 |
215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE;
216 | thrust::device_vector<MinMax> boxes(num_boxes);
217 | boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get());
218 | boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists);
219 |
220 | cudaFree(result);
221 | }
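Note: SimpleKNN::knn computes, for every point, the mean of the squared distances to its 3 nearest neighbours, using Morton ordering and per-box bounds to prune the search. A brute-force NumPy reference of the same quantity (O(P^2), so useful only as a small-scale sanity check) might look like this:

import numpy as np

def knn_mean_sq_dist(points, k=3):
    # Brute-force reference: mean squared distance to the k nearest neighbours of each point.
    d2 = ((points[:, None, :] - points[None, :, :]) ** 2).sum(-1)  # (P, P) pairwise squared distances
    np.fill_diagonal(d2, np.inf)                                   # exclude the point itself
    nearest = np.sort(d2, axis=1)[:, :k]                           # k smallest squared distances
    return nearest.mean(axis=1)

pts = np.random.rand(256, 3).astype(np.float32)
print(knn_mean_sq_dist(pts)[:5])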
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: simple-knn
3 | Version: 0.0.0
4 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | ext.cpp
2 | setup.py
3 | simple_knn.cu
4 | spatial.cu
5 | simple_knn.egg-info/PKG-INFO
6 | simple_knn.egg-info/SOURCES.txt
7 | simple_knn.egg-info/dependency_links.txt
8 | simple_knn.egg-info/top_level.txt
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | simple_knn
2 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef SIMPLEKNN_H_INCLUDED
13 | #define SIMPLEKNN_H_INCLUDED
14 |
15 | class SimpleKNN
16 | {
17 | public:
18 | static void knn(int P, float3* points, float* meanDists);
19 | };
20 |
21 | #endif
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/simple_knn/.gitkeep
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/simple_knn/_C.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn/_C.cpython-38-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/submodules/simple-knn/simple_knn/_C.cpython-38-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/submodules/simple-knn/spatial.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include "spatial.h"
13 | #include "simple_knn.h"
14 |
15 | torch::Tensor
16 | distCUDA2(const torch::Tensor& points)
17 | {
18 | const int P = points.size(0);
19 |
20 | auto float_opts = points.options().dtype(torch::kFloat32);
21 | torch::Tensor means = torch::full({P}, 0.0, float_opts);
22 |
23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>());
24 |
25 | return means;
26 | }
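Note: after building, simple_knn._C.distCUDA2 takes an (N, 3) float tensor of CUDA points and returns the per-point mean squared 3-nearest-neighbour distance computed by SimpleKNN::knn. The clamp-and-log step below mirrors how Gaussian Splatting pipelines commonly turn this into initial log-scales; it is shown only as an example, not as this repository's exact initialization.

import torch
from simple_knn._C import distCUDA2

points = torch.rand(10000, 3, device="cuda")
dist2 = torch.clamp_min(distCUDA2(points.float()), 1e-7)        # mean squared 3-NN distance per point
scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)   # example: isotropic initial log-scales
print(scales.shape)                                              # torch.Size([10000, 3])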
--------------------------------------------------------------------------------
/submodules/simple-knn/spatial.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include <torch/extension.h>
13 |
14 | torch::Tensor distCUDA2(const torch::Tensor& points);
--------------------------------------------------------------------------------
/utils/TIMES.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMES.TTF
--------------------------------------------------------------------------------
/utils/TIMESBD.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMESBD.TTF
--------------------------------------------------------------------------------
/utils/TIMESBI.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMESBI.TTF
--------------------------------------------------------------------------------
/utils/TIMESI.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/TIMESI.TTF
--------------------------------------------------------------------------------
/utils/__pycache__/camera_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/camera_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/general_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/general_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/graphics_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/graphics_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/image_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/image_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loss_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/loss_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/params_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/params_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/scene_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/scene_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/sh_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/sh_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/system_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/system_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/timer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinlab-imvr/Deform3DGS/b06f0a501bd1c71aa70d9a48a8dd4810ccddb2b7/utils/__pycache__/timer.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from scene.cameras import Camera
13 | import numpy as np
14 | from utils.general_utils import PILtoTorch
15 | from utils.graphics_utils import fov2focal
16 |
17 | WARNED = False
18 |
19 | def loadCam(args, id, cam_info, resolution_scale):
20 |
21 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
22 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
23 | image=cam_info.image, mask=None, gt_alpha_mask=None, depth=None,
24 | image_name=cam_info.image_name, uid=id, data_device=args.data_device,
25 | time = cam_info.time)
26 |
27 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
28 | camera_list = []
29 |
30 | for id, c in enumerate(cam_infos):
31 | camera_list.append(loadCam(args, id, c, resolution_scale))
32 |
33 | return camera_list
34 |
35 | def camera_to_JSON(id, camera : Camera):
36 | Rt = np.zeros((4, 4))
37 | Rt[:3, :3] = camera.R.transpose()
38 | Rt[:3, 3] = camera.T
39 | Rt[3, 3] = 1.0
40 |
41 | W2C = np.linalg.inv(Rt)
42 | pos = W2C[:3, 3]
43 | rot = W2C[:3, :3]
44 | serializable_array_2d = [x.tolist() for x in rot]
45 | camera_entry = {
46 | 'id' : id,
47 | 'img_name' : camera.image_name,
48 | 'width' : camera.width,
49 | 'height' : camera.height,
50 | 'position': pos.tolist(),
51 | 'rotation': serializable_array_2d,
52 | 'fy' : fov2focal(camera.FovY, camera.height),
53 | 'fx' : fov2focal(camera.FovX, camera.width)
54 | }
55 | return camera_entry
56 |
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 | import cv2
18 |
19 | def inpaint_rgb(rgb_image, mask):
20 | # Convert mask to uint8
21 | mask_uint8 = (mask * 255).astype(np.uint8)
22 | # Inpaint missing regions
23 | inpainted_image = cv2.inpaint(rgb_image, mask_uint8, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
24 |
25 | return inpainted_image
26 |
27 | def inpaint_depth(depth_image, mask):
28 | # Convert mask to uint8
29 | mask_uint8 = (mask * 255).astype(np.uint8)
30 |
31 | # Inpaint missing regions
32 | inpainted_depth_image = cv2.inpaint((depth_image).astype(np.uint8), mask_uint8, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
33 |
34 | return inpainted_depth_image
35 |
36 | def inverse_sigmoid(x):
37 | return torch.log(x/(1-x))
38 |
39 | def PILtoTorch(pil_image, resolution):
40 | if resolution is not None:
41 | resized_image_PIL = pil_image.resize(resolution)
42 | else:
43 | resized_image_PIL = pil_image
44 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
45 | if len(resized_image.shape) == 3:
46 | return resized_image.permute(2, 0, 1)
47 | else:
48 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
49 |
50 | def get_expon_lr_func(
51 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
52 | ):
53 | """
54 | Copied from Plenoxels
55 |
56 | Continuous learning rate decay function. Adapted from JaxNeRF
57 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
58 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
59 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
60 | function of lr_delay_mult, such that the initial learning rate is
61 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
62 | to the normal learning rate when steps>lr_delay_steps.
63 | :param conf: config subtree 'lr' or similar
64 | :param max_steps: int, the number of steps during optimization.
65 | :return HoF which takes step as input
66 | """
67 |
68 | def helper(step):
69 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
70 | # Disable this parameter
71 | return 0.0
72 | if lr_delay_steps > 0:
73 | # A kind of reverse cosine decay.
74 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
75 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
76 | )
77 | else:
78 | delay_rate = 1.0
79 | t = np.clip(step / max_steps, 0, 1)
80 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
81 | return delay_rate * log_lerp
82 |
83 | return helper
84 |
85 | def strip_lowerdiag(L):
86 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
87 |
88 | uncertainty[:, 0] = L[:, 0, 0]
89 | uncertainty[:, 1] = L[:, 0, 1]
90 | uncertainty[:, 2] = L[:, 0, 2]
91 | uncertainty[:, 3] = L[:, 1, 1]
92 | uncertainty[:, 4] = L[:, 1, 2]
93 | uncertainty[:, 5] = L[:, 2, 2]
94 | return uncertainty
95 |
96 | def strip_symmetric(sym):
97 | return strip_lowerdiag(sym)
98 |
99 | def build_rotation(r):
100 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
101 | q = r / norm[:, None]
102 | r = q[:, 0]
103 | x = q[:, 1]
104 | y = q[:, 2]
105 | z = q[:, 3]
106 |
107 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
108 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
109 | R[:, 0, 1] = 2 * (x*y - r*z)
110 | R[:, 0, 2] = 2 * (x*z + r*y)
111 | R[:, 1, 0] = 2 * (x*y + r*z)
112 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
113 | R[:, 1, 2] = 2 * (y*z - r*x)
114 | R[:, 2, 0] = 2 * (x*z - r*y)
115 | R[:, 2, 1] = 2 * (y*z + r*x)
116 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
117 | return R
118 |
119 | def build_scaling_rotation(s, r):
120 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
121 | R = build_rotation(r)
122 |
123 | L[:,0,0] = s[:,0]
124 | L[:,1,1] = s[:,1]
125 | L[:,2,2] = s[:,2]
126 |
127 | L = R @ L
128 | return L
129 |
130 | def safe_state(silent):
131 | old_f = sys.stdout
132 | class F:
133 | def __init__(self, silent):
134 | self.silent = silent
135 |
136 | def write(self, x):
137 | if not self.silent:
138 | if x.endswith("\n"):
139 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
140 | else:
141 | old_f.write(x)
142 |
143 | def flush(self):
144 | old_f.flush()
145 |
146 | sys.stdout = F(silent)
147 |
148 | random.seed(0)
149 | np.random.seed(0)
150 | torch.manual_seed(0)
151 | torch.cuda.set_device(torch.device("cuda:0"))
152 |
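Note: a small usage sketch of get_expon_lr_func, illustrating the log-linear (exponential) decay described in its docstring; the numbers are arbitrary and only for illustration.

from utils.general_utils import get_expon_lr_func

lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
print(lr_fn(0))        # 1.6e-4 at step 0
print(lr_fn(15_000))   # ~1.6e-5, the geometric mean of init and final at the midpoint
print(lr_fn(30_000))   # 1.6e-6 at the final step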
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points : np.array
19 | colors : np.array
20 | normals : np.array
21 |
22 | def geom_transform_points(points, transf_matrix):
23 | P, _ = points.shape
24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25 | points_hom = torch.cat([points, ones], dim=1)
26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27 |
28 | denom = points_out[..., 3:] + 0.0000001
29 | return (points_out[..., :3] / denom).squeeze(dim=0)
30 |
31 | def getWorld2View(R, t):
32 | Rt = np.zeros((4, 4))
33 | Rt[:3, :3] = R.transpose()
34 | Rt[:3, 3] = t
35 | Rt[3, 3] = 1.0
36 | return np.float32(Rt)
37 |
38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
39 | Rt = np.zeros((4, 4))
40 | Rt[:3, :3] = R.transpose() #w2c
41 | Rt[:3, 3] = t
42 | Rt[3, 3] = 1.0
43 |
44 | C2W = np.linalg.inv(Rt)
45 | cam_center = C2W[:3, 3]
46 | cam_center = (cam_center + translate) * scale
47 | C2W[:3, 3] = cam_center
48 | Rt = np.linalg.inv(C2W)
49 | return np.float32(Rt)
50 |
51 | def getProjectionMatrix(znear, zfar, fovX, fovY):
52 | tanHalfFovY = math.tan((fovY / 2))
53 | tanHalfFovX = math.tan((fovX / 2))
54 |
55 | top = tanHalfFovY * znear
56 | bottom = -top
57 | right = tanHalfFovX * znear
58 | left = -right
59 |
60 | P = torch.zeros(4, 4)
61 |
62 | z_sign = 1.0
63 |
64 | P[0, 0] = 2.0 * znear / (right - left)
65 | P[1, 1] = 2.0 * znear / (top - bottom)
66 | P[0, 2] = (right + left) / (right - left)
67 | P[1, 2] = (top + bottom) / (top - bottom)
68 | P[3, 2] = z_sign
69 | P[2, 2] = z_sign * zfar / (zfar - znear)
70 | P[2, 3] = -(zfar * znear) / (zfar - znear)
71 | return P
72 |
73 | def getProjectionMatrix2(znear, zfar, K, h, w):
74 | near_fx = znear / K[0, 0]
75 | near_fy = znear / K[1, 1]
76 | left = - (w - K[0, 2]) * near_fx
77 | right = K[0, 2] * near_fx
78 | bottom = (K[1, 2] - h) * near_fy
79 | top = K[1, 2] * near_fy
80 |
81 | P = torch.zeros(4, 4)
82 | z_sign = 1.0
83 | P[0, 0] = 2.0 * znear / (right - left)
84 | P[1, 1] = 2.0 * znear / (top - bottom)
85 | P[0, 2] = (right + left) / (right - left)
86 | P[1, 2] = (top + bottom) / (top - bottom)
87 | P[3, 2] = z_sign
88 | P[2, 2] = z_sign * zfar / (zfar - znear)
89 | P[2, 3] = -(zfar * znear) / (zfar - znear)
90 | return P
91 |
92 | def fov2focal(fov, pixels):
93 | return pixels / (2 * math.tan(fov / 2))
94 |
95 | def focal2fov(focal, pixels):
96 | return 2*math.atan(pixels/(2*focal))
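Note: fov2focal and focal2fov are exact inverses, with focal = pixels / (2 * tan(fov / 2)); for a 640-pixel-wide image and a 60-degree horizontal FoV this gives roughly 554.3 px. A quick illustrative check:

import math
from utils.graphics_utils import fov2focal, focal2fov

fov = math.radians(60)
fx = fov2focal(fov, 640)                        # ~554.26 pixels
assert abs(focal2fov(fx, 640) - fov) < 1e-9     # round trip recovers the FoV
print(fx)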
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import torch
14 |
15 |
16 | def tensor2array(tensor):
17 | if torch.is_tensor(tensor):
18 | return tensor.detach().cpu().numpy()
19 | else:
20 | return tensor
21 |
22 | def mse(img1, img2):
23 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
24 |
25 | @torch.no_grad()
26 | def psnr(img1, img2, mask=None):
27 | if mask is None:
28 | mse_mask = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
29 | else:
30 | if mask.shape[1] == 3:
31 | mse_mask = (((img1-img2)**2)*mask).sum() / ((mask.sum()+1e-10))
32 | else:
33 | mse_mask = (((img1-img2)**2)*mask).sum() / ((mask.sum()+1e-10)*3.0)
34 |
35 | return 20 * torch.log10(1.0 / torch.sqrt(mse_mask))
36 |
37 | def rmse(a, b, mask):
38 | """Compute rmse.
39 | """
40 | if torch.is_tensor(a):
41 | a = tensor2array(a)
42 | if torch.is_tensor(b):
43 | b = tensor2array(b)
44 | if torch.is_tensor(mask):
45 | mask = tensor2array(mask)
46 | if len(mask.shape) == len(a.shape) - 1:
47 | mask = mask[..., None]
48 | mask_sum = np.sum(mask) + 1e-10
49 | rmse = (((a - b)**2 * mask).sum() / (mask_sum))**0.5
50 | return rmse
51 |
52 |
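Note: psnr above implements PSNR = 20 * log10(1 / sqrt(MSE)) for images in [0, 1], computed per channel when no mask is given; a constant error of 0.01 therefore yields 40 dB. A toy check (illustrative only):

import torch
from utils.image_utils import psnr

gt = torch.rand(3, 64, 64)
pred = gt + 0.01                 # constant 0.01 error everywhere, so MSE = 1e-4
print(psnr(pred, gt))            # ~40 dB for each channel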
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 |
18 | def TV_loss(x, mask):
19 | B, C, H, W = x.shape
20 | tv_h = torch.abs(x[:,:,1:,:] - x[:,:,:-1,:]).sum()
21 | tv_w = torch.abs(x[:,:,:,1:] - x[:,:,:,:-1]).sum()
22 | return (tv_h + tv_w) / (B * C * H * W)
23 |
24 |
25 | def lpips_loss(img1, img2, lpips_model):
26 | loss = lpips_model(img1,img2)
27 | return loss.mean()
28 |
29 | def l1_loss(network_output, gt, mask=None):
30 | loss = torch.abs((network_output - gt))
31 | if mask is not None:
32 | if mask.ndim == 4:
33 | mask = mask.repeat(1, network_output.shape[1], 1, 1)
34 | elif mask.ndim == 3:
35 | mask = mask.repeat(network_output.shape[1], 1, 1)
36 | else:
37 | raise ValueError('the dimension of mask should be either 3 or 4')
38 |
39 | try:
40 | loss = loss[mask!=0]
41 | except:
42 | print(loss.shape)
43 | print(mask.shape)
44 | print(loss.dtype)
45 | print(mask.dtype)
46 | return loss.mean()
47 |
48 | def l2_loss(network_output, gt):
49 | return ((network_output - gt) ** 2).mean()
50 |
51 | def gaussian(window_size, sigma):
52 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
53 | return gauss / gauss.sum()
54 |
55 | def create_window(window_size, channel):
56 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
57 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
58 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
59 | return window
60 |
61 | def ssim(img1, img2, window_size=11, size_average=True):
62 | channel = img1.size(-3)
63 | window = create_window(window_size, channel)
64 |
65 | if img1.is_cuda:
66 | window = window.cuda(img1.get_device())
67 | window = window.type_as(img1)
68 |
69 | return _ssim(img1, img2, window, window_size, channel, size_average)
70 |
71 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
72 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
73 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
74 |
75 | mu1_sq = mu1.pow(2)
76 | mu2_sq = mu2.pow(2)
77 | mu1_mu2 = mu1 * mu2
78 |
79 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
80 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
81 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
82 |
83 | C1 = 0.01 ** 2
84 | C2 = 0.03 ** 2
85 |
86 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
87 |
88 | if size_average:
89 | return ssim_map.mean()
90 | else:
91 | return ssim_map.mean(1).mean(1).mean(1)
92 |
93 |
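Note: ssim above evaluates the standard SSIM index with an 11x11 Gaussian window, so identical images score (numerically) 1.0, and l1_loss of identical images is 0. A minimal illustrative check:

import torch
from utils.loss_utils import ssim, l1_loss

pred = torch.rand(1, 3, 64, 64)
gt = pred.clone()
print(ssim(pred, gt))        # ~1.0 for identical images
print(l1_loss(pred, gt))     # 0.0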
--------------------------------------------------------------------------------
/utils/params_utils.py:
--------------------------------------------------------------------------------
1 | def merge_hparams(args, config):
2 | params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"]
3 | for param in params:
4 | if param in config.keys():
5 | for key, value in config[param].items():
6 | if hasattr(args, key):
7 | setattr(args, key, value)
8 | return args
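Note: merge_hparams copies values from the listed config sections onto an existing argparse namespace, but only for attributes the namespace already defines. A toy example (the keys are illustrative, not actual training options):

from argparse import Namespace
from utils.params_utils import merge_hparams

args = Namespace(iterations=3000, sh_degree=3)
config = {"OptimizationParams": {"iterations": 6000, "unknown_key": 1}}
args = merge_hparams(args, config)
print(args.iterations)   # 6000 (overridden); unknown_key is ignored because args has no such attribute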
--------------------------------------------------------------------------------
/utils/scene_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from PIL import Image, ImageDraw, ImageFont
4 | from matplotlib import pyplot as plt
5 | plt.rcParams['font.sans-serif'] = ['Times New Roman']
6 |
7 | import numpy as np
8 |
9 | import copy
10 | @torch.no_grad()
11 | def render_training_image(scene, gaussians, viewpoints, render_func, pipe, background, stage, iteration, time_now):
12 | def render(gaussians, viewpoint, path, scaling):
13 | # scaling_copy = gaussians._scaling
14 | render_pkg = render_func(viewpoint, gaussians, pipe, background, stage=stage)
15 | label1 = f"stage:{stage},iter:{iteration}"
16 | times = time_now/60
17 | if times < 1:
18 | end = "min"
19 | else:
20 | end = "mins"
21 | label2 = "time:%.2f" % times + end
22 | image = render_pkg["render"]
23 | depth = render_pkg["depth"]
24 | image_np = image.permute(1, 2, 0).cpu().numpy() # change channel order to (H, W, 3)
25 | depth_np = depth.permute(1, 2, 0).cpu().numpy()
26 | depth_np /= depth_np.max()
27 | depth_np = np.repeat(depth_np, 3, axis=2)
28 | image_np = np.concatenate((image_np, depth_np), axis=1)
29 | image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8')) # convert to an 8-bit image
30 | # create a drawing context on the PIL image to render the labels
31 | draw1 = ImageDraw.Draw(image_with_labels)
32 |
33 | # choose the font and font size
34 | font = ImageFont.truetype('./utils/TIMES.TTF', size=40) # replace the path with the font file of your choice
35 |
36 | # choose the text color
37 | text_color = (255, 0, 0) # red
38 |
39 | # choose the label positions (top-left corner coordinates)
40 | label1_position = (10, 10)
41 | label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) # top-right corner coordinates
42 |
43 | # draw the labels on the image
44 | draw1.text(label1_position, label1, fill=text_color, font=font)
45 | draw1.text(label2_position, label2, fill=text_color, font=font)
46 |
47 | image_with_labels.save(path)
48 | render_base_path = os.path.join(scene.model_path, f"{stage}_render")
49 | point_cloud_path = os.path.join(render_base_path,"pointclouds")
50 | image_path = os.path.join(render_base_path,"images")
51 | if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")):
52 | os.makedirs(render_base_path)
53 | if not os.path.exists(point_cloud_path):
54 | os.makedirs(point_cloud_path)
55 | if not os.path.exists(image_path):
56 | os.makedirs(image_path)
57 | # image:3,800,800
58 |
59 | # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg")
60 | for idx in range(len(viewpoints)):
61 | image_save_path = os.path.join(image_path,f"{iteration}_{idx}.jpg")
62 | render(gaussians,viewpoints[idx],image_save_path,scaling = 1)
63 | # render(gaussians,point_save_path,scaling = 0.1)
64 | # save the labeled image
65 |
66 | pc_mask = gaussians.get_opacity
67 | pc_mask = pc_mask > 0.1
68 | xyz = gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1,0).numpy()
69 | # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path)
70 | # if needed, the PIL image can be converted back to a PyTorch tensor
71 | # return image
72 | # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0
73 |
74 | def visualize_and_save_point_cloud(point_cloud, R, T, filename):
75 | # create a 3D scatter plot
76 | fig = plt.figure()
77 | ax = fig.add_subplot(111, projection='3d')
78 | R = R.T
79 | # apply the rotation and translation transform
80 | T = -R.dot(T)
81 | transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1)
82 | # pcd = o3d.geometry.PointCloud()
83 | # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T) # transpose the point cloud to match Open3D's format
84 | # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:]
85 | # visualize the point cloud
86 | ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c='g', marker='o')
87 | ax.axis("off")
88 | # ax.set_xlabel('X Label')
89 | # ax.set_ylabel('Y Label')
90 | # ax.set_zlabel('Z Label')
91 |
92 | # save the rendering as an image
93 | plt.savefig(filename)
94 |
95 |
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 | deg: int SH deg. Currently, 0-3 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
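Note: for degree 0, eval_sh reduces to C0 * sh[..., 0], so RGB2SH and SH2RGB invert each other on the DC band and the view direction is ignored. A small illustrative check:

import torch
from utils.sh_utils import eval_sh, RGB2SH

rgb = torch.rand(5, 3)
sh = RGB2SH(rgb)                                                 # DC coefficients, shape (5, 3)
dirs = torch.nn.functional.normalize(torch.rand(5, 3), dim=-1)   # unused at degree 0, but required
out = eval_sh(0, sh[..., None], dirs)                            # shape (5, 3)
print(torch.allclose(out + 0.5, rgb))                            # True: eval_sh(0, RGB2SH(rgb)) = rgb - 0.5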
--------------------------------------------------------------------------------
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 | # Creates a directory. equivalent to using mkdir -p on the command line
18 | try:
19 | makedirs(folder_path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == EEXIST and path.isdir(folder_path):
22 | pass
23 | else:
24 | raise
25 |
26 | def searchForMaxIteration(folder):
27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 | return max(saved_iters)
29 |
--------------------------------------------------------------------------------
/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | class Timer:
3 | def __init__(self):
4 | self.start_time = None
5 | self.elapsed = 0
6 | self.paused = False
7 |
8 | def start(self):
9 | if self.start_time is None:
10 | self.start_time = time.time()
11 | elif self.paused:
12 | self.start_time = time.time() - self.elapsed
13 | self.paused = False
14 |
15 | def pause(self):
16 | if not self.paused:
17 | self.elapsed = time.time() - self.start_time
18 | self.paused = True
19 |
20 | def get_elapsed_time(self):
21 | if self.paused:
22 | return self.elapsed
23 | else:
24 | return time.time() - self.start_time
--------------------------------------------------------------------------------