├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── arguments
│   └── __init__.py
├── assets
│   ├── pipeline.png
│   ├── results
│   │   ├── D-NeRF
│   │   │   ├── Quantitative.jpg
│   │   │   ├── bouncing.gif
│   │   │   ├── hell.gif
│   │   │   ├── hook.gif
│   │   │   ├── jump.gif
│   │   │   ├── lego.gif
│   │   │   ├── mutant.gif
│   │   │   ├── stand.gif
│   │   │   └── trex.gif
│   │   └── NeRF-DS
│   │       └── Quantitative.jpg
│   └── teaser.png
├── convert.py
├── full_eval.py
├── gaussian_renderer
│   ├── __init__.py
│   └── network_gui.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── requirements.txt
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   ├── deform_model.py
│   └── gaussian_model.py
├── train.py
├── train_gui.py
└── utils
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── gui_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── pose_utils.py
    ├── rigid_utils.py
    ├── sh_utils.py
    ├── system_utils.py
    └── time_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | output
4 | build
5 | diff_rasterization/diff_rast.egg-info
6 | diff_rasterization/dist
7 | tensorboard_3d
8 | screenshots
9 | .idea
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/simple-knn"]
2 | path = submodules/simple-knn
3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
4 | [submodule "submodules/depth-diff-gaussian-rasterization"]
5 | path = submodules/depth-diff-gaussian-rasterization
6 | url = https://github.com/ingra14m/diff-gaussian-rasterization-extentions
7 | branch = filter-norm
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ziyi Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction
2 |
3 | ## [Project page](https://ingra14m.github.io/Deformable-Gaussians/) | [Paper](https://arxiv.org/abs/2309.13101)
4 |
5 | ![Teaser image](assets/teaser.png)
6 |
7 | This repository contains the official implementation associated with the paper "Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction".
8 |
9 |
10 |
11 | ## News
12 |
13 | - **[5/26/2024]** [Lightweight-Deformable-GS](https://github.com/ingra14m/Lightweight-Deformable-GS) has been integrated into this repo. For the original version aligned with the paper, please check the [paper](https://github.com/ingra14m/Deformable-3D-Gaussians/tree/paper) branch.
14 | - **[5/24/2024]** An optimized version, [Lightweight-Deformable-GS](https://github.com/ingra14m/Lightweight-Deformable-GS), has been released. It offers 50% less storage, 200% higher FPS, and no decrease in rendering metrics.
15 | - **[2/27/2024]** Deformable-GS has been accepted to CVPR 2024. Another work of ours, [SC-GS](https://yihua7.github.io/SC-GS-web/) (higher quality, fewer points, and faster FPS than vanilla 3D-GS), has also been accepted. See you in Seattle.
16 | - **[11/16/2023]** Full code and real-time viewer released.
17 | - **[11/4/2023]** Updated the computation of LPIPS in metrics.py. Previously, `lpipsPyTorch` could not run on CUDA, prompting us to switch to the `lpips` library (~20x faster).
18 | - **[10/25/2023]** Updated the **real-time viewer** on the project page. Many, many thanks to @[yihua7](https://github.com/yihua7) for implementing the real-time viewer adapted for Deformable-GS. Also, thanks to @[ashawkey](https://github.com/ashawkey) for releasing the original GUI.
19 |
20 |
21 |
22 | ## Dataset
23 |
24 | In our paper, we use:
25 |
26 | - the synthetic dataset from [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html);
27 | - real-world datasets from [NeRF-DS](https://jokeryan.github.io/projects/nerf-ds/) and [HyperNeRF](https://hypernerf.github.io/);
28 | - the dataset in the supplementary materials, which comes from [DeVRF](https://jia-wei-liu.github.io/DeVRF/).
29 |
30 | We organize the datasets as follows:
31 |
32 | ```shell
33 | ├── data
34 | │   ├── D-NeRF
35 | │   │   ├── hook
36 | │   │   ├── standup
37 | │   │   ├── ...
38 | │   ├── NeRF-DS
39 | │   │   ├── as
40 | │   │   ├── basin
41 | │   │   ├── ...
42 | │   └── HyperNeRF
43 | │       ├── interp
44 | │       ├── misc
45 | │       └── vrig
46 | ```
47 |
48 | > I have identified an **inconsistency in D-NeRF's Lego dataset**. Specifically, the scenes in the training set differ from those in the test set, which can be verified by observing the angle of the flipped Lego shovel. To meaningfully evaluate our method on this dataset, I recommend using the **validation set of the Lego dataset** as the test set. See more details in [D-NeRF dataset used in Deformable-GS](https://github.com/ingra14m/Deformable-3D-Gaussians/releases/tag/v0.1-pre-released).
49 |
50 |
51 |
52 | ## Pipeline
53 |
54 | ![Pipeline](assets/pipeline.png)
55 |
56 |
57 |
58 | ## Run
59 |
60 | ### Environment
61 |
62 | ```shell
63 | git clone https://github.com/ingra14m/Deformable-3D-Gaussians --recursive
64 | cd Deformable-3D-Gaussians
65 |
66 | conda create -n deformable_gaussian_env python=3.7
67 | conda activate deformable_gaussian_env
68 |
69 | # install pytorch
70 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
71 |
72 | # install dependencies
73 | pip install -r requirements.txt
74 | ```
75 |
76 |
77 |
78 | ### Train
79 |
80 | **D-NeRF:**
81 |
82 | ```shell
83 | python train.py -s path/to/your/d-nerf/dataset -m output/exp-name --eval --is_blender
84 | ```
85 |
86 | **NeRF-DS/HyperNeRF:**
87 |
88 | ```shell
89 | python train.py -s path/to/your/real-world/dataset -m output/exp-name --eval --iterations 20000
90 | ```
91 |
92 | **6DoF Transformation:**
93 |
94 | We have also implemented the 6DoF transformation of 3D-GS, which may lead to an improvement in metrics but will reduce the speed of training and inference.
95 |
96 | ```shell
97 | # D-NeRF
98 | python train.py -s path/to/your/d-nerf/dataset -m output/exp-name --eval --is_blender --is_6dof
99 |
100 | # NeRF-DS & HyperNeRF
101 | python train.py -s path/to/your/real-world/dataset -m output/exp-name --eval --is_6dof --iterations 20000
102 | ```
103 |
104 | You can also **train with the GUI:**
105 |
106 | ```shell
107 | python train_gui.py -s path/to/your/dataset -m output/exp-name --eval --is_blender
108 | ```
109 |
110 | - Click `start` to start training and `stop` to stop it.
111 | - The GUI viewer is still under development; many buttons currently have no corresponding functions. We plan to:
112 |   - [ ] Reload checkpoints from the pre-trained model.
113 |   - [ ] Complete the functions of the other vacant buttons in the GUI.
114 |
115 |
116 |
117 | ### Render & Evaluation
118 |
119 | ```shell
120 | python render.py -m output/exp-name --mode render
121 | python metrics.py -m output/exp-name
122 | ```
123 |
124 | We provide several modes for rendering:
125 |
126 | - `render`: render all the test images
127 | - `time`: time interpolation tasks for the D-NeRF dataset
128 | - `all`: time and view synthesis tasks for the D-NeRF dataset
129 | - `view`: view synthesis tasks for the D-NeRF dataset
130 | - `original`: time and view synthesis tasks for real-world datasets
131 |
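For example (hypothetical experiment name; the mode names map to the functions in `render_sets` of render.py, which also implements an additional `pose` mode that interpolates between the first and last training cameras):

```shell
python render.py -m output/exp-name --mode time      # time interpolation (D-NeRF)
python render.py -m output/exp-name --mode view      # view synthesis (D-NeRF)
python render.py -m output/exp-name --mode all       # time + view synthesis (D-NeRF)
python render.py -m output/exp-name --mode original  # time + view synthesis (real-world)
```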
132 |
133 |
134 | ## Results
135 |
136 | ### D-NeRF Dataset
137 |
138 | **Quantitative Results**
139 |
140 | ![D-NeRF quantitative results](assets/results/D-NeRF/Quantitative.jpg)
141 |
142 | **Qualitative Results**
143 |
144 |
145 |
146 |
147 |
148 | **400x400 Resolution**
149 |
150 | | Scene    | PSNR  | SSIM   | LPIPS (VGG) | FPS  | Mem (MB) | Num. of GS |
151 | | -------- | ----- | ------ | ----------- | ---- | ----- | -------- |
152 | | bouncing | 41.46 | 0.9958 | 0.0046 | 112 | 13.16 | 55622 |
153 | | hell | 42.11 | 0.9885 | 0.0153 | 375 | 3.72 | 15733 |
154 | | hook | 37.77 | 0.9897 | 0.0103 | 128 | 11.74 | 49613 |
155 | | jump | 39.10 | 0.9930 | 0.0090 | 217 | 6.81 | 28808 |
156 | | mutant | 43.73 | 0.9969 | 0.0029 | 124 | 11.45 | 48423 |
157 | | standup | 45.38 | 0.9967 | 0.0032 | 210 | 5.94 | 25102 |
158 | | trex | 38.40 | 0.9959 | 0.0041 | 85 | 18.6 | 78624 |
159 | | Average | 41.14 | 0.9938 | 0.0070 | 179 | 10.20 | 43132 |
160 |
161 | ### NeRF-DS Dataset
162 |
163 | ![NeRF-DS quantitative results](assets/results/NeRF-DS/Quantitative.jpg)
164 |
165 | See more visualizations on our [project page](https://ingra14m.github.io/Deformable-Gaussians/).
166 |
167 |
168 |
169 | ### HyperNeRF Dataset
170 |
171 | Since the **camera poses** in HyperNeRF are less precise than those in NeRF-DS, we use HyperNeRF only for partial visualization and for showing failure cases, and do not include it in the quantitative metrics. The results on the HyperNeRF dataset can be viewed on the [project page](https://ingra14m.github.io/Deformable-Gaussians/).
172 |
173 |
174 |
175 | ### Real-Time Viewer
176 |
177 | https://github.com/ingra14m/Deformable-3D-Gaussians/assets/63096187/ec26d0b9-c126-4e23-b773-dcedcf386f36
178 |
179 |
180 |
181 | ## Acknowledgments
182 |
183 | We sincerely thank the authors of [3D-GS](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/), [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html), [HyperNeRF](https://hypernerf.github.io/), [NeRF-DS](https://jokeryan.github.io/projects/nerf-ds/), and [DeVRF](https://jia-wei-liu.github.io/DeVRF/), whose code and datasets were used in our work. We thank [Zihao Wang](https://github.com/Alen-Wong) for debugging in the early stages, which prevented this work from sinking. We also thank the reviewers and ACs for not being swayed by PR and for evaluating our work fairly. This work was mainly supported by ByteDance MMLab.
184 |
185 |
186 |
187 |
188 | ## BibTex
189 |
190 | ```
191 | @article{yang2023deformable3dgs,
192 | title={Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction},
193 | author={Yang, Ziyi and Gao, Xinyu and Zhou, Wen and Jiao, Shaohui and Zhang, Yuqing and Jin, Xiaogang},
194 | journal={arXiv preprint arXiv:2309.13101},
195 | year={2023}
196 | }
197 | ```
198 |
199 | Thanks also to the authors of [3D Gaussians](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) for their excellent code; please consider citing their repository as well:
200 |
201 | ```
202 | @Article{kerbl3Dgaussians,
203 | author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
204 | title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
205 | journal = {ACM Transactions on Graphics},
206 | number = {4},
207 | volume = {42},
208 | month = {July},
209 | year = {2023},
210 | url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
211 | }
212 | ```
213 |
214 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 |
17 | class GroupParams:
18 | pass
19 |
20 |
21 | class ParamGroup:
22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
23 | group = parser.add_argument_group(name)
24 | for key, value in vars(self).items():
25 | shorthand = False
26 | if key.startswith("_"):
27 | shorthand = True
28 | key = key[1:]
29 | t = type(value)
30 | value = value if not fill_none else None
31 | if shorthand:
32 | if t == bool:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
34 | else:
35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
36 | else:
37 | if t == bool:
38 | group.add_argument("--" + key, default=value, action="store_true")
39 | else:
40 | group.add_argument("--" + key, default=value, type=t)
41 |
42 | def extract(self, args):
43 | group = GroupParams()
44 | for arg in vars(args).items():
45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
46 | setattr(group, arg[0], arg[1])
47 | return group
48 |
49 |
50 | class ModelParams(ParamGroup):
51 | def __init__(self, parser, sentinel=False):
52 | self.sh_degree = 3
53 | self._source_path = ""
54 | self._model_path = ""
55 | self._images = "images"
56 | self._resolution = -1
57 | self._white_background = False
58 | self.data_device = "cuda"
59 | self.eval = False
60 | self.load2gpu_on_the_fly = False
61 | self.is_blender = False
62 | self.is_6dof = False
63 | super().__init__(parser, "Loading Parameters", sentinel)
64 |
65 | def extract(self, args):
66 | g = super().extract(args)
67 | g.source_path = os.path.abspath(g.source_path)
68 | return g
69 |
70 |
71 | class PipelineParams(ParamGroup):
72 | def __init__(self, parser):
73 | self.convert_SHs_python = False
74 | self.compute_cov3D_python = False
75 | self.debug = False
76 | super().__init__(parser, "Pipeline Parameters")
77 |
78 |
79 | class OptimizationParams(ParamGroup):
80 | def __init__(self, parser):
81 | self.iterations = 40_000
82 | self.warm_up = 3_000
83 | self.position_lr_init = 0.00016
84 | self.position_lr_final = 0.0000016
85 | self.position_lr_delay_mult = 0.01
86 | self.position_lr_max_steps = 30_000
87 | self.deform_lr_max_steps = 40_000
88 | self.feature_lr = 0.0025
89 | self.opacity_lr = 0.05
90 | self.scaling_lr = 0.001
91 | self.rotation_lr = 0.001
92 | self.percent_dense = 0.01
93 | self.lambda_dssim = 0.2
94 | self.densification_interval = 100
95 | self.opacity_reset_interval = 3000
96 | self.densify_from_iter = 500
97 | self.densify_until_iter = 15_000
98 | self.densify_grad_threshold = 0.0007
99 | super().__init__(parser, "Optimization Parameters")
100 |
101 |
102 | def get_combined_args(parser: ArgumentParser):
103 | cmdlne_string = sys.argv[1:]
104 | cfgfile_string = "Namespace()"
105 | args_cmdline = parser.parse_args(cmdlne_string)
106 |
107 | try:
108 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
109 | print("Looking for config file in", cfgfilepath)
110 | with open(cfgfilepath) as cfg_file:
111 | print("Config file found: {}".format(cfgfilepath))
112 | cfgfile_string = cfg_file.read()
113 | except TypeError:
114 | print("Config file not found at")
115 | pass
116 | args_cfgfile = eval(cfgfile_string)
117 |
118 | merged_dict = vars(args_cfgfile).copy()
119 | for k, v in vars(args_cmdline).items():
120 | if v != None:
121 | merged_dict[k] = v
122 | return Namespace(**merged_dict)
123 |
--------------------------------------------------------------------------------
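For reference, a minimal sketch (not part of the repository) of how the parameter groups above are typically wired into a script entry point, following the 3D-GS-style pattern used by the training and rendering scripts; the arguments passed to `parse_args` are hypothetical and the import assumes the repo root is on `sys.path`:

```python
from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Example parameters")
lp = ModelParams(parser)         # registers --source_path/-s, --model_path/-m, --eval, ...
op = OptimizationParams(parser)  # registers --iterations, --warm_up, ...
pp = PipelineParams(parser)      # registers --convert_SHs_python, --debug, ...

# Hypothetical arguments, for illustration only.
args = parser.parse_args(["-s", "data/d-nerf/lego", "-m", "output/exp-name", "--eval"])

dataset = lp.extract(args)  # a GroupParams holding only the ModelParams fields
opt = op.extract(args)
pipe = pp.extract(args)
print(dataset.source_path, dataset.eval, opt.iterations, pipe.debug)
```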
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/results/D-NeRF/Quantitative.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/Quantitative.jpg
--------------------------------------------------------------------------------
/assets/results/D-NeRF/bouncing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/bouncing.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/hell.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/hell.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/hook.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/hook.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/jump.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/jump.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/lego.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/lego.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/mutant.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/mutant.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/stand.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/stand.gif
--------------------------------------------------------------------------------
/assets/results/D-NeRF/trex.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/D-NeRF/trex.gif
--------------------------------------------------------------------------------
/assets/results/NeRF-DS/Quantitative.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/results/NeRF-DS/Quantitative.jpg
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ingra14m/Deformable-3D-Gaussians/ac423728df780ee57a82025a6d7c4c7312fdf9bb/assets/teaser.png
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | from argparse import ArgumentParser
14 | import shutil
15 |
16 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
17 | parser = ArgumentParser("Colmap converter")
18 | parser.add_argument("--no_gpu", action='store_true')
19 | parser.add_argument("--skip_matching", action='store_true')
20 | parser.add_argument("--source_path", "-s", required=True, type=str)
21 | parser.add_argument("--camera", default="OPENCV", type=str)
22 | parser.add_argument("--colmap_executable", default="", type=str)
23 | parser.add_argument("--resize", action="store_true")
24 | parser.add_argument("--magick_executable", default="", type=str)
25 | args = parser.parse_args()
26 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
27 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
28 | use_gpu = 1 if not args.no_gpu else 0
29 |
30 | if not args.skip_matching:
31 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
32 |
33 | ## Feature extraction
34 | os.system(colmap_command + " feature_extractor "\
35 | "--database_path " + args.source_path + "/distorted/database.db \
36 | --image_path " + args.source_path + "/input \
37 | --ImageReader.single_camera 1 \
38 | --ImageReader.camera_model " + args.camera + " \
39 | --SiftExtraction.use_gpu " + str(use_gpu))
40 |
41 | ## Feature matching
42 | os.system(colmap_command + " exhaustive_matcher \
43 | --database_path " + args.source_path + "/distorted/database.db \
44 | --SiftMatching.use_gpu " + str(use_gpu))
45 |
46 | ### Bundle adjustment
47 | # The default Mapper tolerance is unnecessarily large,
48 | # decreasing it speeds up bundle adjustment steps.
49 | os.system(colmap_command + " mapper \
50 | --database_path " + args.source_path + "/distorted/database.db \
51 | --image_path " + args.source_path + "/input \
52 | --output_path " + args.source_path + "/distorted/sparse \
53 | --Mapper.ba_global_function_tolerance=0.000001")
54 |
55 | ### Image undistortion
56 | ## We need to undistort our images into ideal pinhole intrinsics.
57 | os.system(colmap_command + " image_undistorter \
58 | --image_path " + args.source_path + "/input \
59 | --input_path " + args.source_path + "/distorted/sparse/0 \
60 | --output_path " + args.source_path + "\
61 | --output_type COLMAP")
62 |
63 | files = os.listdir(args.source_path + "/sparse")
64 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
65 | # Copy each file from the source directory to the destination directory
66 | for file in files:
67 | if file == '0':
68 | continue
69 | source_file = os.path.join(args.source_path, "sparse", file)
70 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
71 | shutil.move(source_file, destination_file)
72 |
73 | if(args.resize):
74 | print("Copying and resizing...")
75 |
76 | # Resize images.
77 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
78 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
79 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
80 | # Get the list of files in the source directory
81 | files = os.listdir(args.source_path + "/images")
82 | # Copy each file from the source directory to the destination directory
83 | for file in files:
84 | source_file = os.path.join(args.source_path, "images", file)
85 |
86 | destination_file = os.path.join(args.source_path, "images_2", file)
87 | shutil.copy2(source_file, destination_file)
88 | os.system(magick_command + " mogrify -resize 50% " + destination_file)
89 |
90 | destination_file = os.path.join(args.source_path, "images_4", file)
91 | shutil.copy2(source_file, destination_file)
92 | os.system(magick_command + " mogrify -resize 25% " + destination_file)
93 |
94 | destination_file = os.path.join(args.source_path, "images_8", file)
95 | shutil.copy2(source_file, destination_file)
96 | os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
97 |
98 | print("Done.")
--------------------------------------------------------------------------------
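For reference, a typical invocation of convert.py above (hypothetical scene folder with the raw photos in `data/my-scene/input`; COLMAP and ImageMagick must be installed, or passed explicitly via `--colmap_executable` / `--magick_executable`):

```shell
python convert.py -s data/my-scene --resize
```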
/full_eval.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | from argparse import ArgumentParser
14 |
15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
17 | tanks_and_temples_scenes = ["truck", "train"]
18 | deep_blending_scenes = ["drjohnson", "playroom"]
19 |
20 | parser = ArgumentParser(description="Full evaluation script parameters")
21 | parser.add_argument("--skip_training", action="store_true")
22 | parser.add_argument("--skip_rendering", action="store_true")
23 | parser.add_argument("--skip_metrics", action="store_true")
24 | parser.add_argument("--output_path", default="./eval")
25 | args, _ = parser.parse_known_args()
26 |
27 | all_scenes = []
28 | all_scenes.extend(mipnerf360_outdoor_scenes)
29 | all_scenes.extend(mipnerf360_indoor_scenes)
30 | all_scenes.extend(tanks_and_temples_scenes)
31 | all_scenes.extend(deep_blending_scenes)
32 |
33 | if not args.skip_training or not args.skip_rendering:
34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
36 | parser.add_argument("--deepblending", "-db", required=True, type=str)
37 | args = parser.parse_args()
38 |
39 | if not args.skip_training:
40 | common_args = " --quiet --eval --test_iterations -1 "
41 | for scene in mipnerf360_outdoor_scenes:
42 | source = args.mipnerf360 + "/" + scene
43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
44 | for scene in mipnerf360_indoor_scenes:
45 | source = args.mipnerf360 + "/" + scene
46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
47 | for scene in tanks_and_temples_scenes:
48 | source = args.tanksandtemples + "/" + scene
49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
50 | for scene in deep_blending_scenes:
51 | source = args.deepblending + "/" + scene
52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
53 |
54 | if not args.skip_rendering:
55 | all_sources = []
56 | for scene in mipnerf360_outdoor_scenes:
57 | all_sources.append(args.mipnerf360 + "/" + scene)
58 | for scene in mipnerf360_indoor_scenes:
59 | all_sources.append(args.mipnerf360 + "/" + scene)
60 | for scene in tanks_and_temples_scenes:
61 | all_sources.append(args.tanksandtemples + "/" + scene)
62 | for scene in deep_blending_scenes:
63 | all_sources.append(args.deepblending + "/" + scene)
64 |
65 | common_args = " --quiet --eval --skip_train"
66 | for scene, source in zip(all_scenes, all_sources):
67 | os.system(
68 | "python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
69 | os.system(
70 | "python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
71 |
72 | if not args.skip_metrics:
73 | scenes_string = ""
74 | for scene in all_scenes:
75 | scenes_string += "\"" + args.output_path + "/" + scene + "\" "
76 |
77 | os.system("python metrics.py -m " + scenes_string)
78 |
--------------------------------------------------------------------------------
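For reference, a typical invocation of full_eval.py above (hypothetical dataset roots; the script trains, renders, and evaluates every scene listed at the top of the file):

```shell
python full_eval.py -m360 /data/mipnerf360 -tat /data/tanksandtemples -db /data/deepblending --output_path ./eval
```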
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | from scene.gaussian_model import GaussianModel
16 | from utils.sh_utils import eval_sh
17 | from utils.rigid_utils import from_homogenous, to_homogenous
18 |
19 |
20 | def quaternion_multiply(q1, q2):
21 | w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
22 | w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]
23 |
24 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
25 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
26 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
27 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
28 |
29 | return torch.stack((w, x, y, z), dim=-1)
30 |
31 |
32 | def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, d_xyz, d_rotation, d_scaling, is_6dof=False,
33 | scaling_modifier=1.0, override_color=None):
34 | """
35 | Render the scene.
36 |
37 | Background tensor (bg_color) must be on GPU!
38 | """
39 |
40 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
41 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
42 | screenspace_points_densify = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
43 | try:
44 | screenspace_points.retain_grad()
45 | screenspace_points_densify.retain_grad()
46 | except:
47 | pass
48 |
49 | # Set up rasterization configuration
50 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
51 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
52 |
53 | raster_settings = GaussianRasterizationSettings(
54 | image_height=int(viewpoint_camera.image_height),
55 | image_width=int(viewpoint_camera.image_width),
56 | tanfovx=tanfovx,
57 | tanfovy=tanfovy,
58 | bg=bg_color,
59 | scale_modifier=scaling_modifier,
60 | viewmatrix=viewpoint_camera.world_view_transform,
61 | projmatrix=viewpoint_camera.full_proj_transform,
62 | sh_degree=pc.active_sh_degree,
63 | campos=viewpoint_camera.camera_center,
64 | prefiltered=False,
65 | debug=pipe.debug,
66 | )
67 |
68 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
69 |
70 | if is_6dof:
71 | if torch.is_tensor(d_xyz) is False:
72 | means3D = pc.get_xyz
73 | else:
74 | means3D = from_homogenous(
75 | torch.bmm(d_xyz, to_homogenous(pc.get_xyz).unsqueeze(-1)).squeeze(-1))
76 | else:
77 | means3D = pc.get_xyz + d_xyz
78 | opacity = pc.get_opacity
79 |
80 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
81 | # scaling / rotation by the rasterizer.
82 | scales = None
83 | rotations = None
84 | cov3D_precomp = None
85 | if pipe.compute_cov3D_python:
86 | cov3D_precomp = pc.get_covariance(scaling_modifier)
87 | else:
88 | scales = pc.get_scaling + d_scaling
89 | rotations = pc.get_rotation + d_rotation
90 |
91 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
92 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
93 | shs = None
94 | colors_precomp = None
95 | if colors_precomp is None:
96 | if pipe.convert_SHs_python:
97 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
98 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
99 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
100 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
101 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
102 | else:
103 | shs = pc.get_features
104 | else:
105 | colors_precomp = override_color
106 |
107 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
108 | rendered_image, radii, depth = rasterizer(
109 | means3D=means3D,
110 | means2D=screenspace_points,
111 | means2D_densify=screenspace_points_densify,
112 | shs=shs,
113 | colors_precomp=colors_precomp,
114 | opacities=opacity,
115 | scales=scales,
116 | rotations=rotations,
117 | cov3D_precomp=cov3D_precomp)
118 |
119 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
120 | # They will be excluded from value updates used in the splitting criteria.
121 | return {"render": rendered_image,
122 | "viewspace_points": screenspace_points,
123 | "viewspace_points_densify": screenspace_points_densify,
124 | "visibility_filter": radii > 0,
125 | "radii": radii,
126 | "depth": depth}
127 |
--------------------------------------------------------------------------------
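A quick sanity check (not part of the repository) of the w-first quaternion convention used by `quaternion_multiply` above; importing `gaussian_renderer` assumes the rasterization submodule is installed, since it is pulled in at module load:

```python
import math
import torch
from gaussian_renderer import quaternion_multiply

identity = torch.tensor([[1.0, 0.0, 0.0, 0.0]])                                    # (w, x, y, z)
q_z90 = torch.tensor([[math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)]])   # 90 deg about Z

print(quaternion_multiply(identity, q_z90))  # unchanged: ~[0.7071, 0, 0, 0.7071]
print(quaternion_multiply(q_z90, q_z90))     # two 90-deg rotations: ~[0, 0, 0, 1] (180 deg about Z)
```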
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 |
27 | def init(wish_host, wish_port):
28 | global host, port, listener
29 | host = wish_host
30 | port = wish_port
31 | listener.bind((host, port))
32 | listener.listen()
33 | listener.settimeout(0)
34 |
35 |
36 | def try_connect():
37 | global conn, addr, listener
38 | try:
39 | conn, addr = listener.accept()
40 | print(f"\nConnected by {addr}")
41 | conn.settimeout(None)
42 | except Exception as inst:
43 | pass
44 |
45 |
46 | def read():
47 | global conn
48 | messageLength = conn.recv(4)
49 | messageLength = int.from_bytes(messageLength, 'little')
50 | message = conn.recv(messageLength)
51 | return json.loads(message.decode("utf-8"))
52 |
53 |
54 | def send(message_bytes, verify):
55 | global conn
56 | if message_bytes != None:
57 | conn.sendall(message_bytes)
58 | conn.sendall(len(verify).to_bytes(4, 'little'))
59 | conn.sendall(bytes(verify, 'ascii'))
60 |
61 |
62 | def receive():
63 | message = read()
64 |
65 | width = message["resolution_x"]
66 | height = message["resolution_y"]
67 |
68 | if width != 0 and height != 0:
69 | try:
70 | do_training = bool(message["train"])
71 | fovy = message["fov_y"]
72 | fovx = message["fov_x"]
73 | znear = message["z_near"]
74 | zfar = message["z_far"]
75 | do_shs_python = bool(message["shs_python"])
76 | do_rot_scale_python = bool(message["rot_scale_python"])
77 | keep_alive = bool(message["keep_alive"])
78 | scaling_modifier = message["scaling_modifier"]
79 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
80 | world_view_transform[:, 1] = -world_view_transform[:, 1]
81 | world_view_transform[:, 2] = -world_view_transform[:, 2]
82 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
83 | full_proj_transform[:, 1] = -full_proj_transform[:, 1]
84 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
85 | except Exception as e:
86 | print("")
87 | traceback.print_exc()
88 | raise e
89 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
90 | else:
91 | return None, None, None, None, None, None
92 |
--------------------------------------------------------------------------------
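A minimal client-side sketch (not part of the repository) of the wire format used by network_gui above: a 4-byte little-endian length prefix followed by a UTF-8 JSON body. Field names are the ones `receive()` expects; the fov and matrix values are placeholders, and the connection only succeeds if a training process has called `network_gui.init("127.0.0.1", 6009)` and is polling `try_connect()`:

```python
import json
import socket
import struct

def send_json(sock: socket.socket, payload: dict) -> None:
    body = json.dumps(payload).encode("utf-8")
    sock.sendall(struct.pack("<I", len(body)))  # matches int.from_bytes(..., 'little') in read()
    sock.sendall(body)

# Flattened 4x4 identity as a placeholder; a real viewer sends the actual camera matrices.
identity = [1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0]

message = {
    "resolution_x": 800, "resolution_y": 800,
    "train": True, "fov_y": 0.7, "fov_x": 0.7,
    "z_near": 0.01, "z_far": 100.0,
    "shs_python": False, "rot_scale_python": False,
    "keep_alive": True, "scaling_modifier": 1.0,
    "view_matrix": identity, "view_projection_matrix": identity,
}

if __name__ == "__main__":
    with socket.create_connection(("127.0.0.1", 6009)) as sock:
        send_json(sock, message)
```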
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
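A minimal usage sketch (not part of the repository) for the `lpips` function above. LPIPS conventionally expects inputs scaled to [-1, 1], and the backbone and linear-layer weights are downloaded on first use:

```python
import torch
from lpipsPyTorch import lpips

render = torch.rand(1, 3, 64, 64) * 2 - 1  # fake rendered image in [-1, 1]
gt = torch.rand(1, 3, 64, 64) * 2 - 1      # fake ground-truth image in [-1, 1]
score = lpips(render, gt, net_type='vgg')  # lower = perceptually closer
print(score.item())
```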
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 |
18 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | # from lpipsPyTorch import lpips
19 | import lpips
20 | import json
21 | from tqdm import tqdm
22 | from utils.image_utils import psnr
23 | from argparse import ArgumentParser
24 |
25 |
26 | def readImages(renders_dir, gt_dir):
27 | renders = []
28 | gts = []
29 | image_names = []
30 | for fname in os.listdir(renders_dir):
31 | render = Image.open(renders_dir / fname)
32 | gt = Image.open(gt_dir / fname)
33 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
34 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
35 | image_names.append(fname)
36 | return renders, gts, image_names
37 |
38 |
39 | def evaluate(model_paths):
40 | full_dict = {}
41 | per_view_dict = {}
42 | full_dict_polytopeonly = {}
43 | per_view_dict_polytopeonly = {}
44 | print("")
45 |
46 | for scene_dir in model_paths:
47 | try:
48 | print("Scene:", scene_dir)
49 | full_dict[scene_dir] = {}
50 | per_view_dict[scene_dir] = {}
51 | full_dict_polytopeonly[scene_dir] = {}
52 | per_view_dict_polytopeonly[scene_dir] = {}
53 |
54 | test_dir = Path(scene_dir) / "test"
55 |
56 | for method in os.listdir(test_dir):
57 | if not method.startswith("ours"):
58 | continue
59 | print("Method:", method)
60 |
61 | full_dict[scene_dir][method] = {}
62 | per_view_dict[scene_dir][method] = {}
63 | full_dict_polytopeonly[scene_dir][method] = {}
64 | per_view_dict_polytopeonly[scene_dir][method] = {}
65 |
66 | method_dir = test_dir / method
67 | gt_dir = method_dir / "gt"
68 | renders_dir = method_dir / "renders"
69 | renders, gts, image_names = readImages(renders_dir, gt_dir)
70 |
71 | ssims = []
72 | psnrs = []
73 | lpipss = []
74 |
75 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
76 | ssims.append(ssim(renders[idx], gts[idx]))
77 | psnrs.append(psnr(renders[idx], gts[idx]))
78 | lpipss.append(lpips_fn(renders[idx], gts[idx]).detach())
79 |
80 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
81 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
82 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
83 | print("")
84 |
85 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
86 | "PSNR": torch.tensor(psnrs).mean().item(),
87 | "LPIPS": torch.tensor(lpipss).mean().item()})
88 | per_view_dict[scene_dir][method].update(
89 | {"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
90 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
91 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
92 |
93 | with open(scene_dir + "/results.json", 'w') as fp:
94 | json.dump(full_dict[scene_dir], fp, indent=True)
95 | with open(scene_dir + "/per_view.json", 'w') as fp:
96 | json.dump(per_view_dict[scene_dir], fp, indent=True)
97 | except:
98 | print("Unable to compute metrics for model", scene_dir)
99 |
100 |
101 | if __name__ == "__main__":
102 | device = torch.device("cuda:0")
103 | torch.cuda.set_device(device)
104 | lpips_fn = lpips.LPIPS(net='vgg').to(device)
105 |
106 | # Set up command line argument parser
107 | parser = ArgumentParser(description="Training script parameters")
108 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
109 | args = parser.parse_args()
110 | evaluate(args.model_paths)
111 |
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene, DeformModel
14 | import os
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | import torchvision
19 | from utils.general_utils import safe_state
20 | from utils.pose_utils import pose_spherical, render_wander_path
21 | from argparse import ArgumentParser
22 | from arguments import ModelParams, PipelineParams, get_combined_args
23 | from gaussian_renderer import GaussianModel
24 | import imageio
25 | import numpy as np
26 | import time
27 |
28 |
29 | def render_set(model_path, load2gpu_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, deform):
30 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
31 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
32 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth")
33 |
34 | makedirs(render_path, exist_ok=True)
35 | makedirs(gts_path, exist_ok=True)
36 | makedirs(depth_path, exist_ok=True)
37 |
38 | t_list = []
39 |
40 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
41 | if load2gpu_on_the_fly:
42 | view.load2device()
43 | fid = view.fid
44 | xyz = gaussians.get_xyz
45 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
46 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input)
47 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
48 | rendering = results["render"]
49 | depth = results["depth"]
50 | depth = depth / (depth.max() + 1e-5)
51 |
52 | gt = view.original_image[0:3, :, :]
53 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
54 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
55 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png"))
56 |
57 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
58 | fid = view.fid
59 | xyz = gaussians.get_xyz
60 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
61 |
62 | torch.cuda.synchronize()
63 | t_start = time.time()
64 |
65 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input)
66 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
67 |
68 | torch.cuda.synchronize()
69 | t_end = time.time()
70 | t_list.append(t_end - t_start)
71 |
72 | t = np.array(t_list[5:])
73 | fps = 1.0 / t.mean()
74 | print(f'Test FPS: \033[1;35m{fps:.5f}\033[0m, Num. of GS: {xyz.shape[0]}')
75 |
76 |
77 | def interpolate_time(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, deform):
78 | render_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "renders")
79 | depth_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "depth")
80 |
81 | makedirs(render_path, exist_ok=True)
82 | makedirs(depth_path, exist_ok=True)
83 |
84 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
85 |
86 | frame = 150
87 | idx = torch.randint(0, len(views), (1,)).item()
88 | view = views[idx]
89 | renderings = []
90 | for t in tqdm(range(0, frame, 1), desc="Rendering progress"):
91 | fid = torch.Tensor([t / (frame - 1)]).cuda()
92 | xyz = gaussians.get_xyz
93 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
94 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input)
95 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
96 | rendering = results["render"]
97 | renderings.append(to8b(rendering.cpu().numpy()))
98 | depth = results["depth"]
99 | depth = depth / (depth.max() + 1e-5)
100 |
101 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(t) + ".png"))
102 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(t) + ".png"))
103 |
104 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
105 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8)
106 |
107 |
108 | def interpolate_view(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, timer):
109 | render_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "renders")
110 | depth_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "depth")
111 | # acc_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "acc")
112 |
113 | makedirs(render_path, exist_ok=True)
114 | makedirs(depth_path, exist_ok=True)
115 | # makedirs(acc_path, exist_ok=True)
116 |
117 | frame = 150
118 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
119 |
120 | idx = torch.randint(0, len(views), (1,)).item()
121 | view = views[idx] # Choose a specific time for rendering
122 |
123 | render_poses = torch.stack(render_wander_path(view), 0)
124 | # render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]],
125 | # 0)
126 |
127 | renderings = []
128 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")):
129 | fid = view.fid
130 |
131 | matrix = np.linalg.inv(np.array(pose))
132 | R = -np.transpose(matrix[:3, :3])
133 | R[:, 0] = -R[:, 0]
134 | T = -matrix[:3, 3]
135 |
136 | view.reset_extrinsic(R, T)
137 |
138 | xyz = gaussians.get_xyz
139 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
140 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input)
141 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
142 | rendering = results["render"]
143 | renderings.append(to8b(rendering.cpu().numpy()))
144 | depth = results["depth"]
145 | depth = depth / (depth.max() + 1e-5)
146 | # acc = results["acc"]
147 |
148 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png"))
149 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png"))
150 | # torchvision.utils.save_image(acc, os.path.join(acc_path, '{0:05d}'.format(i) + ".png"))
151 |
152 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
153 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8)
154 |
155 |
156 | def interpolate_all(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, deform):
157 | render_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "renders")
158 | depth_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "depth")
159 |
160 | makedirs(render_path, exist_ok=True)
161 | makedirs(depth_path, exist_ok=True)
162 |
163 | frame = 150
164 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]],
165 | 0)
166 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
167 |
168 | idx = torch.randint(0, len(views), (1,)).item()
169 | view = views[idx] # Choose a specific time for rendering
170 |
171 | renderings = []
172 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")):
173 | fid = torch.Tensor([i / (frame - 1)]).cuda()
174 |
175 | matrix = np.linalg.inv(np.array(pose))
176 | R = -np.transpose(matrix[:3, :3])
177 | R[:, 0] = -R[:, 0]
178 | T = -matrix[:3, 3]
179 |
180 | view.reset_extrinsic(R, T)
181 |
182 | xyz = gaussians.get_xyz
183 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
184 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input)
185 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
186 | rendering = results["render"]
187 | renderings.append(to8b(rendering.cpu().numpy()))
188 | depth = results["depth"]
189 | depth = depth / (depth.max() + 1e-5)
190 |
191 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png"))
192 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png"))
193 |
194 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
195 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8)
196 |
197 |
198 | def interpolate_poses(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background, timer):
199 | render_path = os.path.join(model_path, name, "interpolate_pose_{}".format(iteration), "renders")
200 | depth_path = os.path.join(model_path, name, "interpolate_pose_{}".format(iteration), "depth")
201 |
202 | makedirs(render_path, exist_ok=True)
203 | makedirs(depth_path, exist_ok=True)
204 | # makedirs(acc_path, exist_ok=True)
205 | frame = 520
206 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
207 |
208 | idx = torch.randint(0, len(views), (1,)).item()
209 | view_begin = views[0] # Choose a specific time for rendering
210 | view_end = views[-1]
211 | view = views[idx]
212 |
213 | R_begin = view_begin.R
214 | R_end = view_end.R
215 | t_begin = view_begin.T
216 | t_end = view_end.T
217 |
218 | renderings = []
219 | for i in tqdm(range(frame), desc="Rendering progress"):
220 | fid = view.fid
221 |
222 | ratio = i / (frame - 1)
223 |
224 | R_cur = (1 - ratio) * R_begin + ratio * R_end
225 | T_cur = (1 - ratio) * t_begin + ratio * t_end
226 |
227 | view.reset_extrinsic(R_cur, T_cur)
228 |
229 | xyz = gaussians.get_xyz
230 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
231 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input)
232 |
233 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
234 | rendering = results["render"]
235 | renderings.append(to8b(rendering.cpu().numpy()))
236 | depth = results["depth"]
237 | depth = depth / (depth.max() + 1e-5)
238 |
239 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
240 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8)
241 |
242 |
243 | def interpolate_view_original(model_path, load2gpt_on_the_fly, is_6dof, name, iteration, views, gaussians, pipeline, background,
244 | timer):
245 | render_path = os.path.join(model_path, name, "interpolate_hyper_view_{}".format(iteration), "renders")
246 | depth_path = os.path.join(model_path, name, "interpolate_hyper_view_{}".format(iteration), "depth")
247 | # acc_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "acc")
248 |
249 | makedirs(render_path, exist_ok=True)
250 | makedirs(depth_path, exist_ok=True)
251 |
252 | frame = 1000
253 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
254 |
255 | R = []
256 | T = []
257 | for view in views:
258 | R.append(view.R)
259 | T.append(view.T)
260 |
261 | view = views[0]
262 | renderings = []
263 | for i in tqdm(range(frame), desc="Rendering progress"):
264 | fid = torch.Tensor([i / (frame - 1)]).cuda()
265 |
266 | query_idx = i / frame * len(views)
267 | begin_idx = int(np.floor(query_idx))
268 | end_idx = int(np.ceil(query_idx))
269 | if end_idx == len(views):
270 | break
271 | view_begin = views[begin_idx]
272 | view_end = views[end_idx]
273 | R_begin = view_begin.R
274 | R_end = view_end.R
275 | t_begin = view_begin.T
276 | t_end = view_end.T
277 |
278 | ratio = query_idx - begin_idx
279 |
280 | R_cur = (1 - ratio) * R_begin + ratio * R_end
281 | T_cur = (1 - ratio) * t_begin + ratio * t_end
282 |
283 | view.reset_extrinsic(R_cur, T_cur)
284 |
285 | xyz = gaussians.get_xyz
286 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
287 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input)
288 |
289 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, is_6dof)
290 | rendering = results["render"]
291 | renderings.append(to8b(rendering.cpu().numpy()))
292 | depth = results["depth"]
293 | depth = depth / (depth.max() + 1e-5)
294 |
295 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
296 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8)
297 |
298 |
299 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool,
300 | mode: str):
301 | with torch.no_grad():
302 | gaussians = GaussianModel(dataset.sh_degree)
303 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
304 | deform = DeformModel(dataset.is_blender, dataset.is_6dof)
305 | deform.load_weights(dataset.model_path)
306 |
307 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
308 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
309 |
310 | if mode == "render":
311 | render_func = render_set
312 | elif mode == "time":
313 | render_func = interpolate_time
314 | elif mode == "view":
315 | render_func = interpolate_view
316 | elif mode == "pose":
317 | render_func = interpolate_poses
318 | elif mode == "original":
319 | render_func = interpolate_view_original
320 | else:
321 | render_func = interpolate_all
322 |
323 | if not skip_train:
324 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, dataset.is_6dof, "train", scene.loaded_iter,
325 | scene.getTrainCameras(), gaussians, pipeline,
326 | background, deform)
327 |
328 | if not skip_test:
329 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, dataset.is_6dof, "test", scene.loaded_iter,
330 | scene.getTestCameras(), gaussians, pipeline,
331 | background, deform)
332 |
333 |
334 | if __name__ == "__main__":
335 | # Set up command line argument parser
336 | parser = ArgumentParser(description="Testing script parameters")
337 | model = ModelParams(parser, sentinel=True)
338 | pipeline = PipelineParams(parser)
339 | parser.add_argument("--iteration", default=-1, type=int)
340 | parser.add_argument("--skip_train", action="store_true")
341 | parser.add_argument("--skip_test", action="store_true")
342 | parser.add_argument("--quiet", action="store_true")
343 | parser.add_argument("--mode", default='render', choices=['render', 'time', 'view', 'all', 'pose', 'original'])
344 | args = get_combined_args(parser)
345 | print("Rendering " + args.model_path)
346 |
347 | # Initialize system state (RNG)
348 | safe_state(args.quiet)
349 |
350 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.mode)
351 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | submodules/depth-diff-gaussian-rasterization
2 | submodules/simple-knn
3 | plyfile==0.8.1
4 | tqdm
5 | imageio==2.27.0
6 | opencv-python
7 | imageio-ffmpeg
8 | scipy
9 | dearpygui
10 | lpips
11 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from scene.deform_model import DeformModel
19 | from arguments import ModelParams
20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
21 |
22 |
23 | class Scene:
24 | gaussians: GaussianModel
25 |
26 | def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True,
27 | resolution_scales=[1.0]):
28 | """b
29 | :param path: Path to colmap scene main folder.
30 | """
31 | self.model_path = args.model_path
32 | self.loaded_iter = None
33 | self.gaussians = gaussians
34 |
35 | if load_iteration:
36 | if load_iteration == -1:
37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
38 | else:
39 | self.loaded_iter = load_iteration
40 | print("Loading trained model at iteration {}".format(self.loaded_iter))
41 |
42 | self.train_cameras = {}
43 | self.test_cameras = {}
44 |
45 | if os.path.exists(os.path.join(args.source_path, "sparse")):
46 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
47 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
48 | print("Found transforms_train.json file, assuming Blender data set!")
49 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
50 | elif os.path.exists(os.path.join(args.source_path, "cameras_sphere.npz")):
51 | print("Found cameras_sphere.npz file, assuming DTU data set!")
52 | scene_info = sceneLoadTypeCallbacks["DTU"](args.source_path, "cameras_sphere.npz", "cameras_sphere.npz")
53 | elif os.path.exists(os.path.join(args.source_path, "dataset.json")):
54 | print("Found dataset.json file, assuming Nerfies data set!")
55 | scene_info = sceneLoadTypeCallbacks["nerfies"](args.source_path, args.eval)
56 | elif os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")):
57 | print("Found calibration_full.json, assuming Neu3D data set!")
58 | scene_info = sceneLoadTypeCallbacks["plenopticVideo"](args.source_path, args.eval, 24)
59 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")):
60 | print("Found calibration_full.json, assuming Dynamic-360 data set!")
61 | scene_info = sceneLoadTypeCallbacks["dynamic360"](args.source_path)
62 | else:
63 | assert False, "Could not recognize scene type!"
64 |
65 | if not self.loaded_iter:
66 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"),
67 | 'wb') as dest_file:
68 | dest_file.write(src_file.read())
69 | json_cams = []
70 | camlist = []
71 | if scene_info.test_cameras:
72 | camlist.extend(scene_info.test_cameras)
73 | if scene_info.train_cameras:
74 | camlist.extend(scene_info.train_cameras)
75 | for id, cam in enumerate(camlist):
76 | json_cams.append(camera_to_JSON(id, cam))
77 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
78 | json.dump(json_cams, file)
79 |
80 | if shuffle:
81 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
82 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
83 |
84 | self.cameras_extent = scene_info.nerf_normalization["radius"]
85 |
86 | for resolution_scale in resolution_scales:
87 | print("Loading Training Cameras")
88 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale,
89 | args)
90 | print("Loading Test Cameras")
91 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale,
92 | args)
93 |
94 | if self.loaded_iter:
95 | self.gaussians.load_ply(os.path.join(self.model_path,
96 | "point_cloud",
97 | "iteration_" + str(self.loaded_iter),
98 | "point_cloud.ply"),
99 | og_number_points=len(scene_info.point_cloud.points))
100 | else:
101 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
102 |
103 | def save(self, iteration):
104 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
105 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
106 |
107 | def getTrainCameras(self, scale=1.0):
108 | return self.train_cameras[scale]
109 |
110 | def getTestCameras(self, scale=1.0):
111 | return self.test_cameras[scale]
112 |
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 |
17 |
18 | class Camera(nn.Module):
19 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid,
20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", fid=None, depth=None):
21 | super(Camera, self).__init__()
22 |
23 | self.uid = uid
24 | self.colmap_id = colmap_id
25 | self.R = R
26 | self.T = T
27 | self.FoVx = FoVx
28 | self.FoVy = FoVy
29 | self.image_name = image_name
30 |
31 | try:
32 | self.data_device = torch.device(data_device)
33 | except Exception as e:
34 | print(e)
35 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
36 | self.data_device = torch.device("cuda")
37 |
38 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
39 | self.fid = torch.Tensor(np.array([fid])).to(self.data_device)
40 | self.image_width = self.original_image.shape[2]
41 | self.image_height = self.original_image.shape[1]
42 | self.depth = torch.Tensor(depth).to(self.data_device) if depth is not None else None
43 |
44 | if gt_alpha_mask is not None:
45 | self.original_image *= gt_alpha_mask.to(self.data_device)
46 | else:
47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
48 |
49 | self.zfar = 100.0
50 | self.znear = 0.01
51 |
52 | self.trans = trans
53 | self.scale = scale
54 |
55 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to(
56 | self.data_device)
57 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx,
58 | fovY=self.FoVy).transpose(0, 1).to(self.data_device)
59 | self.full_proj_transform = (
60 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
61 | self.camera_center = self.world_view_transform.inverse()[3, :3]
62 |
63 | def reset_extrinsic(self, R, T):
64 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, self.scale)).transpose(0, 1).cuda()
65 | self.full_proj_transform = (
66 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
67 | self.camera_center = self.world_view_transform.inverse()[3, :3]
68 |
69 | def load2device(self, data_device='cuda'):
70 | self.original_image = self.original_image.to(data_device)
71 | self.world_view_transform = self.world_view_transform.to(data_device)
72 | self.projection_matrix = self.projection_matrix.to(data_device)
73 | self.full_proj_transform = self.full_proj_transform.to(data_device)
74 | self.camera_center = self.camera_center.to(data_device)
75 | self.fid = self.fid.to(data_device)
76 |
77 |
78 | class MiniCam:
79 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
80 | self.image_width = width
81 | self.image_height = height
82 | self.FoVy = fovy
83 | self.FoVx = fovx
84 | self.znear = znear
85 | self.zfar = zfar
86 | self.world_view_transform = world_view_transform
87 | self.full_proj_transform = full_proj_transform
88 | view_inv = torch.inverse(self.world_view_transform)
89 | self.camera_center = view_inv[3][:3]
90 |
--------------------------------------------------------------------------------
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]])
54 |
55 |
56 | def rotmat2qvec(R):
57 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
58 | K = np.array([
59 | [Rxx - Ryy - Rzz, 0, 0, 0],
60 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
61 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
62 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
63 | eigvals, eigvecs = np.linalg.eigh(K)
64 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
65 | if qvec[0] < 0:
66 | qvec *= -1
67 | return qvec
68 |
69 |
70 | class Image(BaseImage):
71 | def qvec2rotmat(self):
72 | return qvec2rotmat(self.qvec)
73 |
74 |
75 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
76 | """Read and unpack the next bytes from a binary file.
77 | :param fid:
78 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
79 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
80 | :param endian_character: Any of {@, =, <, >, !}
81 | :return: Tuple of read and unpacked values.
82 | """
83 | data = fid.read(num_bytes)
84 | return struct.unpack(endian_character + format_char_sequence, data)
85 |
86 |
87 | def read_points3D_text(path):
88 | """
89 | see: src/base/reconstruction.cc
90 | void Reconstruction::ReadPoints3DText(const std::string& path)
91 | void Reconstruction::WritePoints3DText(const std::string& path)
92 | """
93 | xyzs = None
94 | rgbs = None
95 | errors = None
96 | with open(path, "r") as fid:
97 | while True:
98 | line = fid.readline()
99 | if not line:
100 | break
101 | line = line.strip()
102 | if len(line) > 0 and line[0] != "#":
103 | elems = line.split()
104 | xyz = np.array(tuple(map(float, elems[1:4])))
105 | rgb = np.array(tuple(map(int, elems[4:7])))
106 | error = np.array(float(elems[7]))
107 | if xyzs is None:
108 | xyzs = xyz[None, ...]
109 | rgbs = rgb[None, ...]
110 | errors = error[None, ...]
111 | else:
112 | xyzs = np.append(xyzs, xyz[None, ...], axis=0)
113 | rgbs = np.append(rgbs, rgb[None, ...], axis=0)
114 | errors = np.append(errors, error[None, ...], axis=0)
115 | return xyzs, rgbs, errors
116 |
117 |
118 | def read_points3D_binary(path_to_model_file):
119 | """
120 | see: src/base/reconstruction.cc
121 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
122 | void Reconstruction::WritePoints3DBinary(const std::string& path)
123 | """
124 |
125 | with open(path_to_model_file, "rb") as fid:
126 | num_points = read_next_bytes(fid, 8, "Q")[0]
127 |
128 | xyzs = np.empty((num_points, 3))
129 | rgbs = np.empty((num_points, 3))
130 | errors = np.empty((num_points, 1))
131 |
132 | for p_id in range(num_points):
133 | binary_point_line_properties = read_next_bytes(
134 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
135 | xyz = np.array(binary_point_line_properties[1:4])
136 | rgb = np.array(binary_point_line_properties[4:7])
137 | error = np.array(binary_point_line_properties[7])
138 | track_length = read_next_bytes(
139 | fid, num_bytes=8, format_char_sequence="Q")[0]
140 | track_elems = read_next_bytes(
141 | fid, num_bytes=8 * track_length,
142 | format_char_sequence="ii" * track_length)
143 | xyzs[p_id] = xyz
144 | rgbs[p_id] = rgb
145 | errors[p_id] = error
146 | return xyzs, rgbs, errors
147 |
148 |
149 | def read_intrinsics_text(path):
150 | """
151 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
152 | """
153 | cameras = {}
154 | with open(path, "r") as fid:
155 | while True:
156 | line = fid.readline()
157 | if not line:
158 | break
159 | line = line.strip()
160 | if len(line) > 0 and line[0] != "#":
161 | elems = line.split()
162 | camera_id = int(elems[0])
163 | model = elems[1]
164 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
165 | width = int(elems[2])
166 | height = int(elems[3])
167 | params = np.array(tuple(map(float, elems[4:])))
168 | cameras[camera_id] = Camera(id=camera_id, model=model,
169 | width=width, height=height,
170 | params=params)
171 | return cameras
172 |
173 |
174 | def read_extrinsics_binary(path_to_model_file):
175 | """
176 | see: src/base/reconstruction.cc
177 | void Reconstruction::ReadImagesBinary(const std::string& path)
178 | void Reconstruction::WriteImagesBinary(const std::string& path)
179 | """
180 | images = {}
181 | with open(path_to_model_file, "rb") as fid:
182 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
183 | for _ in range(num_reg_images):
184 | binary_image_properties = read_next_bytes(
185 | fid, num_bytes=64, format_char_sequence="idddddddi")
186 | image_id = binary_image_properties[0]
187 | qvec = np.array(binary_image_properties[1:5])
188 | tvec = np.array(binary_image_properties[5:8])
189 | camera_id = binary_image_properties[8]
190 | image_name = ""
191 | current_char = read_next_bytes(fid, 1, "c")[0]
192 | while current_char != b"\x00": # look for the ASCII 0 entry
193 | image_name += current_char.decode("utf-8")
194 | current_char = read_next_bytes(fid, 1, "c")[0]
195 | num_points2D = read_next_bytes(fid, num_bytes=8,
196 | format_char_sequence="Q")[0]
197 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
198 | format_char_sequence="ddq" * num_points2D)
199 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
200 | tuple(map(float, x_y_id_s[1::3]))])
201 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
202 | images[image_id] = Image(
203 | id=image_id, qvec=qvec, tvec=tvec,
204 | camera_id=camera_id, name=image_name,
205 | xys=xys, point3D_ids=point3D_ids)
206 | return images
207 |
208 |
209 | def read_intrinsics_binary(path_to_model_file):
210 | """
211 | see: src/base/reconstruction.cc
212 | void Reconstruction::WriteCamerasBinary(const std::string& path)
213 | void Reconstruction::ReadCamerasBinary(const std::string& path)
214 | """
215 | cameras = {}
216 | with open(path_to_model_file, "rb") as fid:
217 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
218 | for _ in range(num_cameras):
219 | camera_properties = read_next_bytes(
220 | fid, num_bytes=24, format_char_sequence="iiQQ")
221 | camera_id = camera_properties[0]
222 | model_id = camera_properties[1]
223 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
224 | width = camera_properties[2]
225 | height = camera_properties[3]
226 | num_params = CAMERA_MODEL_IDS[model_id].num_params
227 | params = read_next_bytes(fid, num_bytes=8 * num_params,
228 | format_char_sequence="d" * num_params)
229 | cameras[camera_id] = Camera(id=camera_id,
230 | model=model_name,
231 | width=width,
232 | height=height,
233 | params=np.array(params))
234 | assert len(cameras) == num_cameras
235 | return cameras
236 |
237 |
238 | def read_extrinsics_text(path):
239 | """
240 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
241 | """
242 | images = {}
243 | with open(path, "r") as fid:
244 | while True:
245 | line = fid.readline()
246 | if not line:
247 | break
248 | line = line.strip()
249 | if len(line) > 0 and line[0] != "#":
250 | elems = line.split()
251 | image_id = int(elems[0])
252 | qvec = np.array(tuple(map(float, elems[1:5])))
253 | tvec = np.array(tuple(map(float, elems[5:8])))
254 | camera_id = int(elems[8])
255 | image_name = elems[9]
256 | elems = fid.readline().split()
257 | xys = np.column_stack([tuple(map(float, elems[0::3])),
258 | tuple(map(float, elems[1::3]))])
259 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
260 | images[image_id] = Image(
261 | id=image_id, qvec=qvec, tvec=tvec,
262 | camera_id=camera_id, name=image_name,
263 | xys=xys, point3D_ids=point3D_ids)
264 | return images
265 |
266 |
267 | def read_colmap_bin_array(path):
268 | """
269 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
270 |
271 | :param path: path to the colmap binary file.
272 |     :return: nd array with the floating point values read from the file
273 | """
274 | with open(path, "rb") as fid:
275 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
276 | usecols=(0, 1, 2), dtype=int)
277 | fid.seek(0)
278 | num_delimiter = 0
279 | byte = fid.read(1)
280 | while True:
281 | if byte == b"&":
282 | num_delimiter += 1
283 | if num_delimiter >= 3:
284 | break
285 | byte = fid.read(1)
286 | array = np.fromfile(fid, np.float32)
287 | array = array.reshape((width, height, channels), order="F")
288 | return np.transpose(array, (1, 0, 2)).squeeze()
289 |
--------------------------------------------------------------------------------
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from typing import NamedTuple, Optional
16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19 | import numpy as np
20 | import json
21 | import imageio
22 | from glob import glob
23 | import cv2 as cv
24 | from pathlib import Path
25 | from plyfile import PlyData, PlyElement
26 | from utils.sh_utils import SH2RGB
27 | from scene.gaussian_model import BasicPointCloud
28 | from utils.camera_utils import camera_nerfies_from_JSON
29 |
30 |
31 | class CameraInfo(NamedTuple):
32 | uid: int
33 | R: np.array
34 | T: np.array
35 | FovY: np.array
36 | FovX: np.array
37 | image: np.array
38 | image_path: str
39 | image_name: str
40 | width: int
41 | height: int
42 | fid: float
43 | depth: Optional[np.array] = None
44 |
45 |
46 | class SceneInfo(NamedTuple):
47 | point_cloud: BasicPointCloud
48 | train_cameras: list
49 | test_cameras: list
50 | nerf_normalization: dict
51 | ply_path: str
52 |
53 |
54 | def load_K_Rt_from_P(filename, P=None):
55 | if P is None:
56 | lines = open(filename).read().splitlines()
57 | if len(lines) == 4:
58 | lines = lines[1:]
59 | lines = [[x[0], x[1], x[2], x[3]]
60 | for x in (x.split(" ") for x in lines)]
61 | P = np.asarray(lines).astype(np.float32).squeeze()
62 |
63 | out = cv.decomposeProjectionMatrix(P)
64 | K = out[0]
65 | R = out[1]
66 | t = out[2]
67 |
68 | K = K / K[2, 2]
69 |
70 | pose = np.eye(4, dtype=np.float32)
71 | pose[:3, :3] = R.transpose()
72 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
73 |
74 | return K, pose
75 |
76 |
77 | def getNerfppNorm(cam_info):
78 | def get_center_and_diag(cam_centers):
79 | cam_centers = np.hstack(cam_centers)
80 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
81 | center = avg_cam_center
82 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
83 | diagonal = np.max(dist)
84 | return center.flatten(), diagonal
85 |
86 | cam_centers = []
87 |
88 | for cam in cam_info:
89 | W2C = getWorld2View2(cam.R, cam.T)
90 | C2W = np.linalg.inv(W2C)
91 | cam_centers.append(C2W[:3, 3:4])
92 |
93 | center, diagonal = get_center_and_diag(cam_centers)
94 | radius = diagonal * 1.1
95 |
96 | translate = -center
97 |
98 | return {"translate": translate, "radius": radius}
99 |
100 |
101 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
102 | cam_infos = []
103 | num_frames = len(cam_extrinsics)
104 | for idx, key in enumerate(cam_extrinsics):
105 | sys.stdout.write('\r')
106 |         # Overwrite the previous progress line in place
107 | sys.stdout.write(
108 | "Reading camera {}/{}".format(idx + 1, len(cam_extrinsics)))
109 | sys.stdout.flush()
110 |
111 | extr = cam_extrinsics[key]
112 | intr = cam_intrinsics[extr.camera_id]
113 | height = intr.height
114 | width = intr.width
115 |
116 | uid = intr.id
117 | R = np.transpose(qvec2rotmat(extr.qvec))
118 | T = np.array(extr.tvec)
119 |
120 | if intr.model == "SIMPLE_PINHOLE":
121 | focal_length_x = intr.params[0]
122 | FovY = focal2fov(focal_length_x, height)
123 | FovX = focal2fov(focal_length_x, width)
124 | elif intr.model == "PINHOLE":
125 | focal_length_x = intr.params[0]
126 | focal_length_y = intr.params[1]
127 | FovY = focal2fov(focal_length_y, height)
128 | FovX = focal2fov(focal_length_x, width)
129 | else:
130 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
131 |
132 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
133 | image_name = os.path.basename(image_path).split(".")[0]
134 | image = Image.open(image_path)
135 |
136 | fid = int(image_name) / (num_frames - 1)
137 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
138 | image_path=image_path, image_name=image_name, width=width, height=height, fid=fid)
139 | cam_infos.append(cam_info)
140 | sys.stdout.write('\n')
141 | return cam_infos
142 |
143 |
144 | def fetchPly(path):
145 | plydata = PlyData.read(path)
146 | vertices = plydata['vertex']
147 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
148 | colors = np.vstack([vertices['red'], vertices['green'],
149 | vertices['blue']]).T / 255.0
150 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
151 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
152 |
153 |
154 | def storePly(path, xyz, rgb):
155 | # Define the dtype for the structured array
156 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
157 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
158 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
159 |
160 | normals = np.zeros_like(xyz)
161 |
162 | elements = np.empty(xyz.shape[0], dtype=dtype)
163 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
164 | elements[:] = list(map(tuple, attributes))
165 |
166 | # Create the PlyData object and write to file
167 | vertex_element = PlyElement.describe(elements, 'vertex')
168 | ply_data = PlyData([vertex_element])
169 | ply_data.write(path)
170 |
171 |
172 | def readColmapSceneInfo(path, images, eval, llffhold=8):
173 | try:
174 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
175 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
176 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
177 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
178 | except:
179 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
180 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
181 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
182 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
183 |
184 | reading_dir = "images" if images == None else images
185 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics,
186 | images_folder=os.path.join(path, reading_dir))
187 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
188 |
189 | if eval:
190 | train_cam_infos = [c for idx, c in enumerate(
191 | cam_infos) if idx % llffhold != 0]
192 | test_cam_infos = [c for idx, c in enumerate(
193 | cam_infos) if idx % llffhold == 0]
194 | else:
195 | train_cam_infos = cam_infos
196 | test_cam_infos = []
197 |
198 | nerf_normalization = getNerfppNorm(train_cam_infos)
199 |
200 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
201 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
202 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
203 | if not os.path.exists(ply_path):
204 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
205 | try:
206 | xyz, rgb, _ = read_points3D_binary(bin_path)
207 | except:
208 | xyz, rgb, _ = read_points3D_text(txt_path)
209 | storePly(ply_path, xyz, rgb)
210 | try:
211 | pcd = fetchPly(ply_path)
212 | except:
213 | pcd = None
214 |
215 | scene_info = SceneInfo(point_cloud=pcd,
216 | train_cameras=train_cam_infos,
217 | test_cameras=test_cam_infos,
218 | nerf_normalization=nerf_normalization,
219 | ply_path=ply_path)
220 | return scene_info
221 |
222 |
223 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
224 | cam_infos = []
225 |
226 | with open(os.path.join(path, transformsfile)) as json_file:
227 | contents = json.load(json_file)
228 | fovx = contents["camera_angle_x"]
229 |
230 | frames = contents["frames"]
231 | for idx, frame in enumerate(frames):
232 | cam_name = os.path.join(path, frame["file_path"] + extension)
233 | frame_time = frame['time']
234 |
235 | matrix = np.linalg.inv(np.array(frame["transform_matrix"]))
236 | R = -np.transpose(matrix[:3, :3])
237 | R[:, 0] = -R[:, 0]
238 | T = -matrix[:3, 3]
239 |
240 | image_path = os.path.join(path, cam_name)
241 | image_name = Path(cam_name).stem
242 | image = Image.open(image_path)
243 |
244 | im_data = np.array(image.convert("RGBA"))
245 |
246 | bg = np.array(
247 | [1, 1, 1]) if white_background else np.array([0, 0, 0])
248 |
249 | norm_data = im_data / 255.0
250 | mask = norm_data[..., 3:4]
251 |
252 | arr = norm_data[:, :, :3] * norm_data[:, :,
253 | 3:4] + bg * (1 - norm_data[:, :, 3:4])
254 | image = Image.fromarray(
255 | np.array(arr * 255.0, dtype=np.byte), "RGB")
256 |
257 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
258 |             FovY = fovy
259 |             FovX = fovx
260 |
261 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
262 | image_path=image_path, image_name=image_name, width=image.size[
263 | 0],
264 | height=image.size[1], fid=frame_time))
265 |
266 | return cam_infos
267 |
268 |
269 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
270 | print("Reading Training Transforms")
271 | train_cam_infos = readCamerasFromTransforms(
272 | path, "transforms_train.json", white_background, extension)
273 | print("Reading Test Transforms")
274 | test_cam_infos = readCamerasFromTransforms(
275 | path, "transforms_test.json", white_background, extension)
276 |
277 | if not eval:
278 | train_cam_infos.extend(test_cam_infos)
279 | test_cam_infos = []
280 |
281 | nerf_normalization = getNerfppNorm(train_cam_infos)
282 |
283 | ply_path = os.path.join(path, "points3d.ply")
284 | if not os.path.exists(ply_path):
285 | # Since this data set has no colmap data, we start with random points
286 | num_pts = 100_000
287 | print(f"Generating random point cloud ({num_pts})...")
288 |
289 | # We create random points inside the bounds of the synthetic Blender scenes
290 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
291 | shs = np.random.random((num_pts, 3)) / 255.0
292 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(
293 | shs), normals=np.zeros((num_pts, 3)))
294 |
295 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
296 | try:
297 | pcd = fetchPly(ply_path)
298 | except:
299 | pcd = None
300 |
301 | scene_info = SceneInfo(point_cloud=pcd,
302 | train_cameras=train_cam_infos,
303 | test_cameras=test_cam_infos,
304 | nerf_normalization=nerf_normalization,
305 | ply_path=ply_path)
306 | return scene_info
307 |
308 |
309 | def readDTUCameras(path, render_camera, object_camera):
310 | camera_dict = np.load(os.path.join(path, render_camera))
311 | images_lis = sorted(glob(os.path.join(path, 'image/*.png')))
312 | masks_lis = sorted(glob(os.path.join(path, 'mask/*.png')))
313 | n_images = len(images_lis)
314 | cam_infos = []
315 | cam_idx = 0
316 | for idx in range(0, n_images):
317 | image_path = images_lis[idx]
318 | image = np.array(Image.open(image_path))
319 | mask = np.array(imageio.imread(masks_lis[idx])) / 255.0
320 | image = Image.fromarray((image * mask).astype(np.uint8))
321 | world_mat = camera_dict['world_mat_%d' % idx].astype(np.float32)
322 | fid = camera_dict['fid_%d' % idx] / (n_images / 12 - 1)
323 | image_name = Path(image_path).stem
324 | scale_mat = camera_dict['scale_mat_%d' % idx].astype(np.float32)
325 | P = world_mat @ scale_mat
326 | P = P[:3, :4]
327 |
328 | K, pose = load_K_Rt_from_P(None, P)
329 | a = pose[0:1, :]
330 | b = pose[1:2, :]
331 | c = pose[2:3, :]
332 |
333 | pose = np.concatenate([a, -c, -b, pose[3:, :]], 0)
334 |
335 | S = np.eye(3)
336 | S[1, 1] = -1
337 | S[2, 2] = -1
338 | pose[1, 3] = -pose[1, 3]
339 | pose[2, 3] = -pose[2, 3]
340 | pose[:3, :3] = S @ pose[:3, :3] @ S
341 |
342 | a = pose[0:1, :]
343 | b = pose[1:2, :]
344 | c = pose[2:3, :]
345 |
346 | pose = np.concatenate([a, c, b, pose[3:, :]], 0)
347 |
348 | pose[:, 3] *= 0.5
349 |
350 | matrix = np.linalg.inv(pose)
351 | R = -np.transpose(matrix[:3, :3])
352 | R[:, 0] = -R[:, 0]
353 | T = -matrix[:3, 3]
354 |
355 | FovY = focal2fov(K[0, 0], image.size[1])
356 | FovX = focal2fov(K[0, 0], image.size[0])
357 | cam_info = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
358 | image_path=image_path, image_name=image_name, width=image.size[
359 | 0], height=image.size[1],
360 | fid=fid)
361 | cam_infos.append(cam_info)
362 | sys.stdout.write('\n')
363 | return cam_infos
364 |
365 |
366 | def readNeuSDTUInfo(path, render_camera, object_camera):
367 | print("Reading DTU Info")
368 | train_cam_infos = readDTUCameras(path, render_camera, object_camera)
369 |
370 | nerf_normalization = getNerfppNorm(train_cam_infos)
371 |
372 | ply_path = os.path.join(path, "points3d.ply")
373 | if not os.path.exists(ply_path):
374 | # Since this data set has no colmap data, we start with random points
375 | num_pts = 100_000
376 | print(f"Generating random point cloud ({num_pts})...")
377 |
378 | # We create random points inside the bounds of the synthetic Blender scenes
379 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
380 | shs = np.random.random((num_pts, 3)) / 255.0
381 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(
382 | shs), normals=np.zeros((num_pts, 3)))
383 |
384 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
385 | try:
386 | pcd = fetchPly(ply_path)
387 | except:
388 | pcd = None
389 |
390 | scene_info = SceneInfo(point_cloud=pcd,
391 | train_cameras=train_cam_infos,
392 | test_cameras=[],
393 | nerf_normalization=nerf_normalization,
394 | ply_path=ply_path)
395 | return scene_info
396 |
397 |
398 | def readNerfiesCameras(path):
399 | with open(f'{path}/scene.json', 'r') as f:
400 | scene_json = json.load(f)
401 | with open(f'{path}/metadata.json', 'r') as f:
402 | meta_json = json.load(f)
403 | with open(f'{path}/dataset.json', 'r') as f:
404 | dataset_json = json.load(f)
405 |
406 | coord_scale = scene_json['scale']
407 | scene_center = scene_json['center']
408 |
409 | name = path.split('/')[-2]
410 | if name.startswith('vrig'):
411 | train_img = dataset_json['train_ids']
412 | val_img = dataset_json['val_ids']
413 | all_img = train_img + val_img
414 | ratio = 0.25
415 | elif name.startswith('NeRF'):
416 | train_img = dataset_json['train_ids']
417 | val_img = dataset_json['val_ids']
418 | all_img = train_img + val_img
419 | ratio = 1.0
420 | elif name.startswith('interp'):
421 | all_id = dataset_json['ids']
422 | train_img = all_id[::4]
423 | val_img = all_id[2::4]
424 | all_img = train_img + val_img
425 | ratio = 0.5
426 | else: # for hypernerf
427 | train_img = dataset_json['ids'][::4]
428 | all_img = train_img
429 | ratio = 0.5
430 |
431 | train_num = len(train_img)
432 |
433 | all_cam = [meta_json[i]['camera_id'] for i in all_img]
434 | all_time = [meta_json[i]['time_id'] for i in all_img]
435 | max_time = max(all_time)
436 | all_time = [meta_json[i]['time_id'] / max_time for i in all_img]
437 | selected_time = set(all_time)
438 |
439 | # all poses
440 | all_cam_params = []
441 | for im in all_img:
442 | camera = camera_nerfies_from_JSON(f'{path}/camera/{im}.json', ratio)
443 | camera['position'] = camera['position'] - scene_center
444 | camera['position'] = camera['position'] * coord_scale
445 | all_cam_params.append(camera)
446 |
447 | all_img = [f'{path}/rgb/{int(1 / ratio)}x/{i}.png' for i in all_img]
448 |
449 | cam_infos = []
450 | for idx in range(len(all_img)):
451 | image_path = all_img[idx]
452 | image = np.array(Image.open(image_path))
453 | image = Image.fromarray((image).astype(np.uint8))
454 | image_name = Path(image_path).stem
455 |
456 | orientation = all_cam_params[idx]['orientation'].T
457 | position = -all_cam_params[idx]['position'] @ orientation
458 | focal = all_cam_params[idx]['focal_length']
459 | fid = all_time[idx]
460 | T = position
461 | R = orientation
462 |
463 | FovY = focal2fov(focal, image.size[1])
464 | FovX = focal2fov(focal, image.size[0])
465 | cam_info = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
466 | image_path=image_path, image_name=image_name, width=image.size[
467 | 0], height=image.size[1],
468 | fid=fid)
469 | cam_infos.append(cam_info)
470 |
471 | sys.stdout.write('\n')
472 | return cam_infos, train_num, scene_center, coord_scale
473 |
474 |
475 | def readNerfiesInfo(path, eval):
476 | print("Reading Nerfies Info")
477 | cam_infos, train_num, scene_center, scene_scale = readNerfiesCameras(path)
478 |
479 | if eval:
480 | train_cam_infos = cam_infos[:train_num]
481 | test_cam_infos = cam_infos[train_num:]
482 | else:
483 | train_cam_infos = cam_infos
484 | test_cam_infos = []
485 |
486 | nerf_normalization = getNerfppNorm(train_cam_infos)
487 |
488 | ply_path = os.path.join(path, "points3d.ply")
489 | if not os.path.exists(ply_path):
490 | print(f"Generating point cloud from nerfies...")
491 |
492 | xyz = np.load(os.path.join(path, "points.npy"))
493 | xyz = (xyz - scene_center) * scene_scale
494 | num_pts = xyz.shape[0]
495 | shs = np.random.random((num_pts, 3)) / 255.0
496 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(
497 | shs), normals=np.zeros((num_pts, 3)))
498 |
499 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
500 | try:
501 | pcd = fetchPly(ply_path)
502 | except:
503 | pcd = None
504 |
505 | scene_info = SceneInfo(point_cloud=pcd,
506 | train_cameras=train_cam_infos,
507 | test_cameras=test_cam_infos,
508 | nerf_normalization=nerf_normalization,
509 | ply_path=ply_path)
510 | return scene_info
511 |
512 |
513 | def readCamerasFromNpy(path, npy_file, split, hold_id, num_images):
514 | cam_infos = []
515 | video_paths = sorted(glob(os.path.join(path, 'frames/*')))
516 | poses_bounds = np.load(os.path.join(path, npy_file))
517 |
518 | poses = poses_bounds[:, :15].reshape(-1, 3, 5)
519 | H, W, focal = poses[0, :, -1]
520 |
521 | n_cameras = poses.shape[0]
522 | poses = np.concatenate(
523 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
524 | bottoms = np.array([0, 0, 0, 1]).reshape(
525 | 1, -1, 4).repeat(poses.shape[0], axis=0)
526 | poses = np.concatenate([poses, bottoms], axis=1)
527 | poses = poses @ np.diag([1, -1, -1, 1])
528 |
529 | i_test = np.array(hold_id)
530 | video_list = i_test if split != 'train' else list(
531 | set(np.arange(n_cameras)) - set(i_test))
532 |
533 | for i in video_list:
534 | video_path = video_paths[i]
535 | c2w = poses[i]
536 | images_names = sorted(os.listdir(video_path))
537 | n_frames = num_images
538 |
539 | matrix = np.linalg.inv(np.array(c2w))
540 | R = np.transpose(matrix[:3, :3])
541 | T = matrix[:3, 3]
542 |
543 | for idx, image_name in enumerate(images_names[:num_images]):
544 | image_path = os.path.join(video_path, image_name)
545 | image = Image.open(image_path)
546 | frame_time = idx / (n_frames - 1)
547 |
548 | FovX = focal2fov(focal, image.size[0])
549 | FovY = focal2fov(focal, image.size[1])
550 |
551 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovX=FovX, FovY=FovY,
552 | image=image,
553 | image_path=image_path, image_name=image_name,
554 | width=image.size[0], height=image.size[1], fid=frame_time))
555 |
556 | idx += 1
557 | return cam_infos
558 |
559 |
560 | def readPlenopticVideoDataset(path, eval, num_images, hold_id=[0]):
561 | print("Reading Training Camera")
562 | train_cam_infos = readCamerasFromNpy(path, 'poses_bounds.npy', split="train", hold_id=hold_id,
563 | num_images=num_images)
564 |
565 | print("Reading Training Camera")
566 | test_cam_infos = readCamerasFromNpy(
567 | path, 'poses_bounds.npy', split="test", hold_id=hold_id, num_images=num_images)
568 |
569 | if not eval:
570 | train_cam_infos.extend(test_cam_infos)
571 | test_cam_infos = []
572 |
573 | nerf_normalization = getNerfppNorm(train_cam_infos)
574 | ply_path = os.path.join(path, 'points3D.ply')
575 | if not os.path.exists(ply_path):
576 | num_pts = 100_000
577 | print(f"Generating random point cloud ({num_pts})...")
578 |
579 | # We create random points inside the bounds of the synthetic Blender scenes
580 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
581 | shs = np.random.random((num_pts, 3)) / 255.0
582 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(
583 | shs), normals=np.zeros((num_pts, 3)))
584 |
585 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
586 |
587 | try:
588 | pcd = fetchPly(ply_path)
589 | except:
590 | pcd = None
591 |
592 | scene_info = SceneInfo(point_cloud=pcd,
593 | train_cameras=train_cam_infos,
594 | test_cameras=test_cam_infos,
595 | nerf_normalization=nerf_normalization,
596 | ply_path=ply_path)
597 | return scene_info
598 |
599 |
600 | sceneLoadTypeCallbacks = {
601 | "Colmap": readColmapSceneInfo, # colmap dataset reader from official 3D Gaussian [https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/]
602 | "Blender": readNerfSyntheticInfo, # D-NeRF dataset [https://drive.google.com/file/d/1uHVyApwqugXTFuIRRlE4abTW8_rrVeIK/view?usp=sharing]
603 | "DTU": readNeuSDTUInfo, # DTU dataset used in Tensor4D [https://github.com/DSaurus/Tensor4D]
604 | "nerfies": readNerfiesInfo, # NeRFies & HyperNeRF dataset proposed by [https://github.com/google/hypernerf/releases/tag/v0.1]
605 | "plenopticVideo": readPlenopticVideoDataset, # Neural 3D dataset in [https://github.com/facebookresearch/Neural_3D_Video]
606 | }
607 |
--------------------------------------------------------------------------------
/scene/deform_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.time_utils import DeformNetwork
5 | import os
6 | from utils.system_utils import searchForMaxIteration
7 | from utils.general_utils import get_expon_lr_func
8 |
9 |
10 | class DeformModel:
11 | def __init__(self, is_blender=False, is_6dof=False):
12 | self.deform = DeformNetwork(is_blender=is_blender, is_6dof=is_6dof).cuda()
13 | self.optimizer = None
14 | self.spatial_lr_scale = 5
15 |
16 | def step(self, xyz, time_emb):
17 | return self.deform(xyz, time_emb)
18 |
19 | def train_setting(self, training_args):
20 | l = [
21 | {'params': list(self.deform.parameters()),
22 | 'lr': training_args.position_lr_init * self.spatial_lr_scale,
23 | "name": "deform"}
24 | ]
25 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
26 |
27 | self.deform_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
28 | lr_final=training_args.position_lr_final,
29 | lr_delay_mult=training_args.position_lr_delay_mult,
30 | max_steps=training_args.deform_lr_max_steps)
31 |
32 | def save_weights(self, model_path, iteration):
33 | out_weights_path = os.path.join(model_path, "deform/iteration_{}".format(iteration))
34 | os.makedirs(out_weights_path, exist_ok=True)
35 | torch.save(self.deform.state_dict(), os.path.join(out_weights_path, 'deform.pth'))
36 |
37 | def load_weights(self, model_path, iteration=-1):
38 | if iteration == -1:
39 | loaded_iter = searchForMaxIteration(os.path.join(model_path, "deform"))
40 | else:
41 | loaded_iter = iteration
42 | weights_path = os.path.join(model_path, "deform/iteration_{}/deform.pth".format(loaded_iter))
43 | self.deform.load_state_dict(torch.load(weights_path))
44 |
45 | def update_learning_rate(self, iteration):
46 | for param_group in self.optimizer.param_groups:
47 | if param_group["name"] == "deform":
48 | lr = self.deform_scheduler_args(iteration)
49 | param_group['lr'] = lr
50 | return lr
51 |
--------------------------------------------------------------------------------
/scene/gaussian_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import numpy as np
14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
15 | from torch import nn
16 | import os
17 | from utils.system_utils import mkdir_p
18 | from plyfile import PlyData, PlyElement
19 | from utils.sh_utils import RGB2SH
20 | from simple_knn._C import distCUDA2
21 | from utils.graphics_utils import BasicPointCloud
22 | from utils.general_utils import strip_symmetric, build_scaling_rotation
23 |
24 |
25 | class GaussianModel:
26 | def __init__(self, sh_degree: int):
27 |
28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation)
30 | actual_covariance = L @ L.transpose(1, 2)
31 | symm = strip_symmetric(actual_covariance)
32 | return symm
33 |
34 | self.active_sh_degree = 0
35 | self.max_sh_degree = sh_degree
36 |
37 | self._xyz = torch.empty(0)
38 | self._features_dc = torch.empty(0)
39 | self._features_rest = torch.empty(0)
40 | self._scaling = torch.empty(0)
41 | self._rotation = torch.empty(0)
42 | self._opacity = torch.empty(0)
43 | self.max_radii2D = torch.empty(0)
44 | self.xyz_gradient_accum = torch.empty(0)
45 |
46 | self.optimizer = None
47 |
48 | self.scaling_activation = torch.exp
49 | self.scaling_inverse_activation = torch.log
50 |
51 | self.covariance_activation = build_covariance_from_scaling_rotation
52 |
53 | self.opacity_activation = torch.sigmoid
54 | self.inverse_opacity_activation = inverse_sigmoid
55 |
56 | self.rotation_activation = torch.nn.functional.normalize
57 |
58 | @property
59 | def get_scaling(self):
60 | return self.scaling_activation(self._scaling)
61 |
62 | @property
63 | def get_rotation(self):
64 | return self.rotation_activation(self._rotation)
65 |
66 | @property
67 | def get_xyz(self):
68 | return self._xyz
69 |
70 | @property
71 | def get_features(self):
72 | features_dc = self._features_dc
73 | features_rest = self._features_rest
74 | return torch.cat((features_dc, features_rest), dim=1)
75 |
76 | @property
77 | def get_opacity(self):
78 | return self.opacity_activation(self._opacity)
79 |
80 | def get_covariance(self, scaling_modifier=1):
81 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
82 |
83 | def oneupSHdegree(self):
84 | if self.active_sh_degree < self.max_sh_degree:
85 | self.active_sh_degree += 1
86 |
87 | def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
88 | self.spatial_lr_scale = 5
89 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
90 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
91 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
92 | features[:, :3, 0] = fused_color
93 | features[:, 3:, 1:] = 0.0
94 |
95 | print("Number of points at initialisation : ", fused_point_cloud.shape[0])
96 |
97 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
98 | scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
99 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
100 | rots[:, 0] = 1
101 |
102 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
103 |
104 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
105 | self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
106 | self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
107 | self._scaling = nn.Parameter(scales.requires_grad_(True))
108 | self._rotation = nn.Parameter(rots.requires_grad_(True))
109 | self._opacity = nn.Parameter(opacities.requires_grad_(True))
110 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
111 |
112 | def training_setup(self, training_args):
113 | self.percent_dense = training_args.percent_dense
114 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
115 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
116 |
117 | self.spatial_lr_scale = 5
118 |
119 | l = [
120 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
121 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
122 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
123 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
124 | {'params': [self._scaling], 'lr': training_args.scaling_lr * self.spatial_lr_scale, "name": "scaling"},
125 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
126 | ]
127 |
128 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
129 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
130 | lr_final=training_args.position_lr_final * self.spatial_lr_scale,
131 | lr_delay_mult=training_args.position_lr_delay_mult,
132 | max_steps=training_args.position_lr_max_steps)
133 |
134 | def update_learning_rate(self, iteration):
135 | ''' Learning rate scheduling per step '''
136 | for param_group in self.optimizer.param_groups:
137 | if param_group["name"] == "xyz":
138 | lr = self.xyz_scheduler_args(iteration)
139 | param_group['lr'] = lr
140 | return lr
141 |
142 | def construct_list_of_attributes(self):
143 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
144 | # All channels except the 3 DC
145 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
146 | l.append('f_dc_{}'.format(i))
147 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
148 | l.append('f_rest_{}'.format(i))
149 | l.append('opacity')
150 | for i in range(self._scaling.shape[1]):
151 | l.append('scale_{}'.format(i))
152 | for i in range(self._rotation.shape[1]):
153 | l.append('rot_{}'.format(i))
154 | return l
155 |
156 | def save_ply(self, path):
157 | mkdir_p(os.path.dirname(path))
158 |
159 | xyz = self._xyz.detach().cpu().numpy()
160 | normals = np.zeros_like(xyz)
161 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
162 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
163 | opacities = self._opacity.detach().cpu().numpy()
164 | scale = self._scaling.detach().cpu().numpy()
165 | rotation = self._rotation.detach().cpu().numpy()
166 |
167 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
168 |
169 | elements = np.empty(xyz.shape[0], dtype=dtype_full)
170 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
171 | elements[:] = list(map(tuple, attributes))
172 | el = PlyElement.describe(elements, 'vertex')
173 | PlyData([el]).write(path)
174 |
175 | def reset_opacity(self):
176 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01))
177 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
178 | self._opacity = optimizable_tensors["opacity"]
179 |
180 | def load_ply(self, path, og_number_points=-1):
181 | self.og_number_points = og_number_points
182 | plydata = PlyData.read(path)
183 |
184 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
185 | np.asarray(plydata.elements[0]["y"]),
186 | np.asarray(plydata.elements[0]["z"])), axis=1)
187 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
188 |
189 | features_dc = np.zeros((xyz.shape[0], 3, 1))
190 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
191 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
192 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
193 |
194 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
195 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
196 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
197 | for idx, attr_name in enumerate(extra_f_names):
198 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
199 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
200 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
201 |
202 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
203 | scales = np.zeros((xyz.shape[0], len(scale_names)))
204 | for idx, attr_name in enumerate(scale_names):
205 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
206 |
207 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
208 | rots = np.zeros((xyz.shape[0], len(rot_names)))
209 | for idx, attr_name in enumerate(rot_names):
210 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
211 |
212 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
213 | self._features_dc = nn.Parameter(
214 | torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(
215 | True))
216 | self._features_rest = nn.Parameter(
217 | torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(
218 | True))
219 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
220 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
221 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
222 |
223 | self.active_sh_degree = self.max_sh_degree
224 |
225 | def replace_tensor_to_optimizer(self, tensor, name):
226 | optimizable_tensors = {}
227 | for group in self.optimizer.param_groups:
228 | if group["name"] == name:
229 | stored_state = self.optimizer.state.get(group['params'][0], None)
230 | stored_state["exp_avg"] = torch.zeros_like(tensor)
231 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
232 |
233 | del self.optimizer.state[group['params'][0]]
234 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
235 | self.optimizer.state[group['params'][0]] = stored_state
236 |
237 | optimizable_tensors[group["name"]] = group["params"][0]
238 | return optimizable_tensors
239 |
240 | def _prune_optimizer(self, mask):
241 | optimizable_tensors = {}
242 | for group in self.optimizer.param_groups:
243 | stored_state = self.optimizer.state.get(group['params'][0], None)
244 | if stored_state is not None:
245 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
246 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
247 |
248 | del self.optimizer.state[group['params'][0]]
249 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
250 | self.optimizer.state[group['params'][0]] = stored_state
251 |
252 | optimizable_tensors[group["name"]] = group["params"][0]
253 | else:
254 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
255 | optimizable_tensors[group["name"]] = group["params"][0]
256 | return optimizable_tensors
257 |
258 | def prune_points(self, mask):
259 | valid_points_mask = ~mask
260 | optimizable_tensors = self._prune_optimizer(valid_points_mask)
261 |
262 | self._xyz = optimizable_tensors["xyz"]
263 | self._features_dc = optimizable_tensors["f_dc"]
264 | self._features_rest = optimizable_tensors["f_rest"]
265 | self._opacity = optimizable_tensors["opacity"]
266 | self._scaling = optimizable_tensors["scaling"]
267 | self._rotation = optimizable_tensors["rotation"]
268 |
269 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
270 |
271 | self.denom = self.denom[valid_points_mask]
272 | self.max_radii2D = self.max_radii2D[valid_points_mask]
273 |
274 | def cat_tensors_to_optimizer(self, tensors_dict):
275 | optimizable_tensors = {}
276 | for group in self.optimizer.param_groups:
277 | assert len(group["params"]) == 1
278 | extension_tensor = tensors_dict[group["name"]]
279 | stored_state = self.optimizer.state.get(group['params'][0], None)
280 | if stored_state is not None:
281 |
282 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)),
283 | dim=0)
284 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)),
285 | dim=0)
286 |
287 | del self.optimizer.state[group['params'][0]]
288 | group["params"][0] = nn.Parameter(
289 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
290 | self.optimizer.state[group['params'][0]] = stored_state
291 |
292 | optimizable_tensors[group["name"]] = group["params"][0]
293 | else:
294 | group["params"][0] = nn.Parameter(
295 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
296 | optimizable_tensors[group["name"]] = group["params"][0]
297 |
298 | return optimizable_tensors
299 |
300 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
301 | new_rotation):
302 | d = {"xyz": new_xyz,
303 | "f_dc": new_features_dc,
304 | "f_rest": new_features_rest,
305 | "opacity": new_opacities,
306 | "scaling": new_scaling,
307 | "rotation": new_rotation}
308 |
309 | optimizable_tensors = self.cat_tensors_to_optimizer(d)
310 | self._xyz = optimizable_tensors["xyz"]
311 | self._features_dc = optimizable_tensors["f_dc"]
312 | self._features_rest = optimizable_tensors["f_rest"]
313 | self._opacity = optimizable_tensors["opacity"]
314 | self._scaling = optimizable_tensors["scaling"]
315 | self._rotation = optimizable_tensors["rotation"]
316 |
317 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
318 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
319 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
320 |
321 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
322 | n_init_points = self.get_xyz.shape[0]
323 | # Extract points that satisfy the gradient condition
324 | padded_grad = torch.zeros((n_init_points), device="cuda")
325 | padded_grad[:grads.shape[0]] = grads.squeeze()
326 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
327 | selected_pts_mask = torch.logical_and(selected_pts_mask,
328 | torch.max(self.get_scaling,
329 | dim=1).values > self.percent_dense * scene_extent)
330 |
331 | stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
332 | means = torch.zeros((stds.size(0), 3), device="cuda")
333 | samples = torch.normal(mean=means, std=stds)
334 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
335 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
336 |         new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N))  # shrink split Gaussians by a factor of 0.8 * N (1.6 for the default N=2)
337 | new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
338 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
339 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
340 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
341 |
342 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
343 |
344 | prune_filter = torch.cat(
345 | (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
346 | self.prune_points(prune_filter)
347 |
348 | def densify_and_clone(self, grads, grad_threshold, scene_extent):
349 | # Extract points that satisfy the gradient condition
350 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
351 | selected_pts_mask = torch.logical_and(selected_pts_mask,
352 | torch.max(self.get_scaling,
353 | dim=1).values <= self.percent_dense * scene_extent)
354 |
355 | new_xyz = self._xyz[selected_pts_mask]
356 | new_features_dc = self._features_dc[selected_pts_mask]
357 | new_features_rest = self._features_rest[selected_pts_mask]
358 | new_opacities = self._opacity[selected_pts_mask]
359 | new_scaling = self._scaling[selected_pts_mask]
360 | new_rotation = self._rotation[selected_pts_mask]
361 |
362 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
363 | new_rotation)
364 |
365 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
366 | grads = self.xyz_gradient_accum / self.denom
367 | grads[grads.isnan()] = 0.0
368 |
369 | self.densify_and_clone(grads, max_grad, extent)
370 | self.densify_and_split(grads, max_grad, extent)
371 |
372 | prune_mask = (self.get_opacity < min_opacity).squeeze()
373 | if max_screen_size:
374 | big_points_vs = self.max_radii2D > max_screen_size
375 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
376 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
377 | self.prune_points(prune_mask)
378 |
379 | torch.cuda.empty_cache()
380 |
381 | def add_densification_stats(self, viewspace_point_tensor, update_filter):
382 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1,
383 | keepdim=True)
384 | self.denom[update_filter] += 1
385 |
--------------------------------------------------------------------------------
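A note on the checkpoint format used by `save_ply` / `load_ply` above: each vertex stores the position, placeholder (zero) normals, the DC and higher-order SH color coefficients, the opacity, the raw per-axis scale parameters, and a rotation quaternion, in exactly the order returned by `construct_list_of_attributes`. A minimal sketch of inspecting such a file with `plyfile` (the library already used above); the path below is only a placeholder:

```python
from plyfile import PlyData

# hypothetical path to a checkpoint written by GaussianModel.save_ply
ply = PlyData.read("output/experiment/point_cloud/iteration_40000/point_cloud.ply")
vertex = ply.elements[0]

# property order: x, y, z, nx, ny, nz, f_dc_0..2, f_rest_*, opacity, scale_*, rot_*
print([p.name for p in vertex.properties])
print("number of Gaussians:", vertex.count)
```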
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from random import randint
15 | from utils.loss_utils import l1_loss, ssim, kl_divergence
16 | from gaussian_renderer import render, network_gui
17 | import sys
18 | from scene import Scene, GaussianModel, DeformModel
19 | from utils.general_utils import safe_state, get_linear_noise_func
20 | import uuid
21 | from tqdm import tqdm
22 | from utils.image_utils import psnr
23 | from argparse import ArgumentParser, Namespace
24 | from arguments import ModelParams, PipelineParams, OptimizationParams
25 |
26 | try:
27 | from torch.utils.tensorboard import SummaryWriter
28 |
29 | TENSORBOARD_FOUND = True
30 | except ImportError:
31 | TENSORBOARD_FOUND = False
32 |
33 |
34 | def training(dataset, opt, pipe, testing_iterations, saving_iterations):
35 | tb_writer = prepare_output_and_logger(dataset)
36 | gaussians = GaussianModel(dataset.sh_degree)
37 | deform = DeformModel(dataset.is_blender, dataset.is_6dof)
38 | deform.train_setting(opt)
39 |
40 | scene = Scene(dataset, gaussians)
41 | gaussians.training_setup(opt)
42 |
43 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
44 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
45 |
46 | iter_start = torch.cuda.Event(enable_timing=True)
47 | iter_end = torch.cuda.Event(enable_timing=True)
48 |
49 | viewpoint_stack = None
50 | ema_loss_for_log = 0.0
51 | best_psnr = 0.0
52 | best_iteration = 0
53 | progress_bar = tqdm(range(opt.iterations), desc="Training progress")
54 | smooth_term = get_linear_noise_func(lr_init=0.1, lr_final=1e-15, lr_delay_mult=0.01, max_steps=20000)
55 | for iteration in range(1, opt.iterations + 1):
56 | if network_gui.conn == None:
57 | network_gui.try_connect()
58 | while network_gui.conn != None:
59 | try:
60 | net_image_bytes = None
61 | custom_cam, do_training, pipe.do_shs_python, pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive()
62 | if custom_cam != None:
63 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
64 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2,
65 | 0).contiguous().cpu().numpy())
66 | network_gui.send(net_image_bytes, dataset.source_path)
67 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
68 | break
69 | except Exception as e:
70 | network_gui.conn = None
71 |
72 | iter_start.record()
73 |
74 |         # Every 1000 iterations, increase the SH degree by one level up to the maximum
75 | if iteration % 1000 == 0:
76 | gaussians.oneupSHdegree()
77 |
78 | # Pick a random Camera
79 | if not viewpoint_stack:
80 | viewpoint_stack = scene.getTrainCameras().copy()
81 |
82 | total_frame = len(viewpoint_stack)
83 | time_interval = 1 / total_frame
84 |
85 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
86 | if dataset.load2gpu_on_the_fly:
87 | viewpoint_cam.load2device()
88 | fid = viewpoint_cam.fid
89 |
90 | if iteration < opt.warm_up:
91 | d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0
92 | else:
93 | N = gaussians.get_xyz.shape[0]
94 | time_input = fid.unsqueeze(0).expand(N, -1)
95 |             # Annealed smooth training (AST): perturb the time input with noise that decays via smooth_term; disabled for synthetic (Blender) scenes
96 | ast_noise = 0 if dataset.is_blender else torch.randn(1, 1, device='cuda').expand(N, -1) * time_interval * smooth_term(iteration)
97 | d_xyz, d_rotation, d_scaling = deform.step(gaussians.get_xyz.detach(), time_input + ast_noise)
98 |
99 | # Render
100 | render_pkg_re = render(viewpoint_cam, gaussians, pipe, background, d_xyz, d_rotation, d_scaling, dataset.is_6dof)
101 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg_re["render"], render_pkg_re[
102 | "viewspace_points"], render_pkg_re["visibility_filter"], render_pkg_re["radii"]
103 | # depth = render_pkg_re["depth"]
104 |
105 | # Loss
106 | gt_image = viewpoint_cam.original_image.cuda()
107 | Ll1 = l1_loss(image, gt_image)
108 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
109 | loss.backward()
110 |
111 | iter_end.record()
112 |
113 | if dataset.load2gpu_on_the_fly:
114 | viewpoint_cam.load2device('cpu')
115 |
116 | with torch.no_grad():
117 | # Progress bar
118 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
119 | if iteration % 10 == 0:
120 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
121 | progress_bar.update(10)
122 | if iteration == opt.iterations:
123 | progress_bar.close()
124 |
125 | # Keep track of max radii in image-space for pruning
126 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter],
127 | radii[visibility_filter])
128 |
129 | # Log and save
130 | cur_psnr = training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
131 | testing_iterations, scene, render, (pipe, background), deform,
132 | dataset.load2gpu_on_the_fly, dataset.is_6dof)
133 | if iteration in testing_iterations:
134 | if cur_psnr.item() > best_psnr:
135 | best_psnr = cur_psnr.item()
136 | best_iteration = iteration
137 |
138 | if iteration in saving_iterations:
139 | print("\n[ITER {}] Saving Gaussians".format(iteration))
140 | scene.save(iteration)
141 | deform.save_weights(args.model_path, iteration)
142 |
143 | # Densification
144 | if iteration < opt.densify_until_iter:
145 | viewspace_point_tensor_densify = render_pkg_re["viewspace_points_densify"]
146 | gaussians.add_densification_stats(viewspace_point_tensor_densify, visibility_filter)
147 |
148 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
149 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
150 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
151 |
152 | if iteration % opt.opacity_reset_interval == 0 or (
153 | dataset.white_background and iteration == opt.densify_from_iter):
154 | gaussians.reset_opacity()
155 |
156 | # Optimizer step
157 | if iteration < opt.iterations:
158 | gaussians.optimizer.step()
159 | gaussians.update_learning_rate(iteration)
160 | deform.optimizer.step()
161 | gaussians.optimizer.zero_grad(set_to_none=True)
162 | deform.optimizer.zero_grad()
163 | deform.update_learning_rate(iteration)
164 |
165 | print("Best PSNR = {} in Iteration {}".format(best_psnr, best_iteration))
166 |
167 |
168 | def prepare_output_and_logger(args):
169 | if not args.model_path:
170 | if os.getenv('OAR_JOB_ID'):
171 | unique_str = os.getenv('OAR_JOB_ID')
172 | else:
173 | unique_str = str(uuid.uuid4())
174 | args.model_path = os.path.join("./output/", unique_str[0:10])
175 |
176 | # Set up output folder
177 | print("Output folder: {}".format(args.model_path))
178 | os.makedirs(args.model_path, exist_ok=True)
179 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
180 | cfg_log_f.write(str(Namespace(**vars(args))))
181 |
182 | # Create Tensorboard writer
183 | tb_writer = None
184 | if TENSORBOARD_FOUND:
185 | tb_writer = SummaryWriter(args.model_path)
186 | else:
187 | print("Tensorboard not available: not logging progress")
188 | return tb_writer
189 |
190 |
191 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
192 | renderArgs, deform, load2gpu_on_the_fly, is_6dof=False):
193 | if tb_writer:
194 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
195 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
196 | tb_writer.add_scalar('iter_time', elapsed, iteration)
197 |
198 | test_psnr = 0.0
199 | # Report test and samples of training set
200 | if iteration in testing_iterations:
201 | torch.cuda.empty_cache()
202 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
203 | {'name': 'train',
204 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
205 | range(5, 30, 5)]})
206 |
207 | for config in validation_configs:
208 | if config['cameras'] and len(config['cameras']) > 0:
209 | images = torch.tensor([], device="cuda")
210 | gts = torch.tensor([], device="cuda")
211 | for idx, viewpoint in enumerate(config['cameras']):
212 | if load2gpu_on_the_fly:
213 | viewpoint.load2device()
214 | fid = viewpoint.fid
215 | xyz = scene.gaussians.get_xyz
216 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
217 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input)
218 | image = torch.clamp(
219 | renderFunc(viewpoint, scene.gaussians, *renderArgs, d_xyz, d_rotation, d_scaling, is_6dof)["render"],
220 | 0.0, 1.0)
221 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
222 | images = torch.cat((images, image.unsqueeze(0)), dim=0)
223 | gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
224 |
225 | if load2gpu_on_the_fly:
226 | viewpoint.load2device('cpu')
227 | if tb_writer and (idx < 5):
228 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
229 | image[None], global_step=iteration)
230 | if iteration == testing_iterations[0]:
231 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
232 | gt_image[None], global_step=iteration)
233 |
234 | l1_test = l1_loss(images, gts)
235 | psnr_test = psnr(images, gts).mean()
236 | if config['name'] == 'test' or len(validation_configs[0]['cameras']) == 0:
237 | test_psnr = psnr_test
238 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
239 | if tb_writer:
240 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
241 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
242 |
243 | if tb_writer:
244 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
245 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
246 | torch.cuda.empty_cache()
247 |
248 | return test_psnr
249 |
250 |
251 | if __name__ == "__main__":
252 | # Set up command line argument parser
253 | parser = ArgumentParser(description="Training script parameters")
254 | lp = ModelParams(parser)
255 | op = OptimizationParams(parser)
256 | pp = PipelineParams(parser)
257 | parser.add_argument('--ip', type=str, default="127.0.0.1")
258 | parser.add_argument('--port', type=int, default=6009)
259 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
260 | parser.add_argument("--test_iterations", nargs="+", type=int,
261 | default=[5000, 6000, 7_000] + list(range(10000, 40001, 1000)))
262 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 20_000, 30_000, 40000])
263 | parser.add_argument("--quiet", action="store_true")
264 | args = parser.parse_args(sys.argv[1:])
265 | args.save_iterations.append(args.iterations)
266 |
267 | print("Optimizing " + args.model_path)
268 |
269 | # Initialize system state (RNG)
270 | safe_state(args.quiet)
271 |
272 | # Start GUI server, configure and run training
273 | # network_gui.init(args.ip, args.port)
274 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
275 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
276 |
277 | # All done
278 | print("\nTraining complete.")
279 |
--------------------------------------------------------------------------------
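The photometric objective in the training loop above combines an L1 term with a DSSIM term. A self-contained toy example of that weighting, reusing the repo's own loss helpers (the 0.2 weight is a stand-in for `opt.lambda_dssim`, not necessarily this repo's default):

```python
import torch
from utils.loss_utils import l1_loss, ssim  # same helpers imported by train.py

pred = torch.rand(3, 64, 64)  # rendered image (C, H, W)
gt = torch.rand(3, 64, 64)    # ground-truth image

lambda_dssim = 0.2            # stand-in for opt.lambda_dssim
loss = (1.0 - lambda_dssim) * l1_loss(pred, gt) + lambda_dssim * (1.0 - ssim(pred, gt))
print(loss.item())
```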
/train_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import time
14 | import torch
15 | from random import randint
16 | from utils.loss_utils import l1_loss, ssim, kl_divergence
17 | from gaussian_renderer import render, network_gui
18 | import sys
19 | from scene import Scene, GaussianModel, DeformModel
20 | from utils.general_utils import safe_state, get_linear_noise_func
21 | import uuid
22 | import tqdm
23 | from utils.image_utils import psnr
24 | from argparse import ArgumentParser, Namespace
25 | from arguments import ModelParams, PipelineParams, OptimizationParams
26 | from train import training_report
27 | import math
28 | from utils.gui_utils import orbit_camera, OrbitCamera
29 | import numpy as np
30 | import dearpygui.dearpygui as dpg
31 |
32 |
33 | try:
34 | from torch.utils.tensorboard import SummaryWriter
35 |
36 | TENSORBOARD_FOUND = True
37 | except ImportError:
38 | TENSORBOARD_FOUND = False
39 |
40 |
41 | def getProjectionMatrix(znear, zfar, fovX, fovY):
42 | tanHalfFovY = math.tan((fovY / 2))
43 | tanHalfFovX = math.tan((fovX / 2))
44 |
45 | P = torch.zeros(4, 4)
46 |
47 | z_sign = 1.0
48 |
49 | P[0, 0] = 1 / tanHalfFovX
50 | P[1, 1] = 1 / tanHalfFovY
51 | P[3, 2] = z_sign
52 | P[2, 2] = z_sign * zfar / (zfar - znear)
53 | P[2, 3] = -(zfar * znear) / (zfar - znear)
54 | return P
55 |
56 |
57 | class MiniCam:
58 | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, fid):
59 | # c2w (pose) should be in NeRF convention.
60 |
61 | self.image_width = width
62 | self.image_height = height
63 | self.FoVy = fovy
64 | self.FoVx = fovx
65 | self.znear = znear
66 | self.zfar = zfar
67 | self.fid = fid
68 |
69 | w2c = np.linalg.inv(c2w)
70 |
71 |         # convert from the NeRF (OpenGL) camera convention: flip the y/z rotation rows and negate the translation
72 | w2c[1:3, :3] *= -1
73 | w2c[:3, 3] *= -1
74 |
75 | self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
76 | self.projection_matrix = (
77 | getProjectionMatrix(
78 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
79 | )
80 | .transpose(0, 1)
81 | .cuda()
82 | )
83 | self.full_proj_transform = self.world_view_transform @ self.projection_matrix
84 | self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()
85 |
86 |
87 | class GUI:
88 | def __init__(self, args, dataset, opt, pipe, testing_iterations, saving_iterations) -> None:
89 | self.dataset = dataset
90 | self.args = args
91 | self.opt = opt
92 | self.pipe = pipe
93 | self.testing_iterations = testing_iterations
94 | self.saving_iterations = saving_iterations
95 |
96 | self.tb_writer = prepare_output_and_logger(dataset)
97 | self.gaussians = GaussianModel(dataset.sh_degree)
98 | self.deform = DeformModel(is_blender=dataset.is_blender, is_6dof=dataset.is_6dof)
99 | self.deform.train_setting(opt)
100 |
101 | self.scene = Scene(dataset, self.gaussians)
102 | self.gaussians.training_setup(opt)
103 |
104 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
105 | self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
106 |
107 | self.iter_start = torch.cuda.Event(enable_timing=True)
108 | self.iter_end = torch.cuda.Event(enable_timing=True)
109 | self.iteration = 1
110 |
111 | self.viewpoint_stack = None
112 | self.ema_loss_for_log = 0.0
113 | self.best_psnr = 0.0
114 | self.best_iteration = 0
115 | self.progress_bar = tqdm.tqdm(range(opt.iterations), desc="Training progress")
116 | self.smooth_term = get_linear_noise_func(lr_init=0.1, lr_final=1e-15, lr_delay_mult=0.01, max_steps=20000)
117 |
118 | # For UI
119 | self.visualization_mode = 'RGB'
120 |
121 | self.gui = args.gui # enable gui
122 | self.W = args.W
123 | self.H = args.H
124 | self.cam = OrbitCamera(args.W, args.H, r=args.radius, fovy=args.fovy)
125 |
126 | self.mode = "render"
127 | self.seed = "random"
128 | self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
129 | self.training = False
130 |
131 | if self.gui:
132 | dpg.create_context()
133 | self.register_dpg()
134 | self.test_step()
135 |
136 | def __del__(self):
137 | if self.gui:
138 | dpg.destroy_context()
139 |
140 | def register_dpg(self):
141 | ### register texture
142 | with dpg.texture_registry(show=False):
143 | dpg.add_raw_texture(
144 | self.W,
145 | self.H,
146 | self.buffer_image,
147 | format=dpg.mvFormat_Float_rgb,
148 | tag="_texture",
149 | )
150 |
151 | ### register window
152 | # the rendered image, as the primary window
153 | with dpg.window(
154 | tag="_primary_window",
155 | width=self.W,
156 | height=self.H,
157 | pos=[0, 0],
158 | no_move=True,
159 | no_title_bar=True,
160 | no_scrollbar=True,
161 | ):
162 | # add the texture
163 | dpg.add_image("_texture")
164 |
165 | # dpg.set_primary_window("_primary_window", True)
166 |
167 | # control window
168 | with dpg.window(
169 | label="Control",
170 | tag="_control_window",
171 | width=600,
172 | height=self.H,
173 | pos=[self.W, 0],
174 | no_move=True,
175 | no_title_bar=True,
176 | ):
177 | # button theme
178 | with dpg.theme() as theme_button:
179 | with dpg.theme_component(dpg.mvButton):
180 | dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
181 | dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
182 | dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
183 | dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
184 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
185 |
186 | # timer stuff
187 | with dpg.group(horizontal=True):
188 | dpg.add_text("Infer time: ")
189 | dpg.add_text("no data", tag="_log_infer_time")
190 |
191 | def callback_setattr(sender, app_data, user_data):
192 | setattr(self, user_data, app_data)
193 |
194 | # init stuff
195 | with dpg.collapsing_header(label="Initialize", default_open=True):
196 |
197 | # seed stuff
198 | def callback_set_seed(sender, app_data):
199 | self.seed = app_data
200 | self.seed_everything()
201 |
202 | dpg.add_input_text(
203 | label="seed",
204 | default_value=self.seed,
205 | on_enter=True,
206 | callback=callback_set_seed,
207 | )
208 |
209 | # input stuff
210 | def callback_select_input(sender, app_data):
211 | # only one item
212 | for k, v in app_data["selections"].items():
213 | dpg.set_value("_log_input", k)
214 | self.load_input(v)
215 |
216 | self.need_update = True
217 |
218 | with dpg.file_dialog(
219 | directory_selector=False,
220 | show=False,
221 | callback=callback_select_input,
222 | file_count=1,
223 | tag="file_dialog_tag",
224 | width=700,
225 | height=400,
226 | ):
227 | dpg.add_file_extension("Images{.jpg,.jpeg,.png}")
228 |
229 | with dpg.group(horizontal=True):
230 | dpg.add_button(
231 | label="input",
232 | callback=lambda: dpg.show_item("file_dialog_tag"),
233 | )
234 | dpg.add_text("", tag="_log_input")
235 |
236 |             # visualization mode
237 | with dpg.group(horizontal=True):
238 | dpg.add_text("Visualization: ")
239 |
240 | def callback_vismode(sender, app_data, user_data):
241 | self.visualization_mode = user_data
242 | if user_data == 'Node':
243 | self.node_vis_fea = True if not hasattr(self, 'node_vis_fea') else not self.node_vis_fea
244 | print("Visualize node features" if self.node_vis_fea else "Visualize node importance")
245 | if self.node_vis_fea or True:
246 | from motion import visualize_featuremap
247 | if True: #self.renderer.gaussians.motion_model.soft_edge:
248 | if hasattr(self.renderer.gaussians.motion_model, 'nodes_fea'):
249 | node_rgb = visualize_featuremap(self.renderer.gaussians.motion_model.nodes_fea.detach().cpu().numpy())
250 | self.node_rgb = torch.from_numpy(node_rgb).cuda()
251 | else:
252 | self.node_rgb = None
253 | else:
254 | self.node_rgb = None
255 | else:
256 | node_imp = self.renderer.gaussians.motion_model.cal_node_importance(x=self.renderer.gaussians.get_xyz)
257 | node_imp = (node_imp - node_imp.min()) / (node_imp.max() - node_imp.min())
258 | node_rgb = torch.zeros([node_imp.shape[0], 3], dtype=torch.float32).cuda()
259 | node_rgb[..., 0] = node_imp
260 | node_rgb[..., -1] = 1 - node_imp
261 | self.node_rgb = node_rgb
262 |
263 | dpg.add_button(
264 | label="RGB",
265 | tag="_button_vis_rgb",
266 | callback=callback_vismode,
267 | user_data='RGB',
268 | )
269 | dpg.bind_item_theme("_button_vis_rgb", theme_button)
270 |
271 | dpg.add_button(
272 | label="UV_COOR",
273 | tag="_button_vis_uv",
274 | callback=callback_vismode,
275 | user_data='UV_COOR',
276 | )
277 | dpg.bind_item_theme("_button_vis_uv", theme_button)
278 | dpg.add_button(
279 | label="MotionMask",
280 | tag="_button_vis_motion_mask",
281 | callback=callback_vismode,
282 | user_data='MotionMask',
283 | )
284 | dpg.bind_item_theme("_button_vis_motion_mask", theme_button)
285 |
286 | dpg.add_button(
287 | label="Node",
288 | tag="_button_vis_node",
289 | callback=callback_vismode,
290 | user_data='Node',
291 | )
292 | dpg.bind_item_theme("_button_vis_node", theme_button)
293 |
294 | def callback_use_const_var(sender, app_data):
295 |                     self.use_const_var = not getattr(self, 'use_const_var', False)
296 | dpg.add_button(
297 | label="Const Var",
298 | tag="_button_const_var",
299 | callback=callback_use_const_var
300 | )
301 | dpg.bind_item_theme("_button_const_var", theme_button)
302 |
303 | with dpg.group(horizontal=True):
304 | dpg.add_text("Scale Const: ")
305 | def callback_vis_scale_const(sender):
306 | self.vis_scale_const = 10 ** dpg.get_value(sender)
307 | self.need_update = True
308 | dpg.add_slider_float(
309 | label="Log vis_scale_const (For debugging)",
310 | default_value=-3,
311 | max_value=-.5,
312 | min_value=-5,
313 | callback=callback_vis_scale_const,
314 | )
315 |
316 |                 # temporal playback speed
317 | with dpg.group(horizontal=True):
318 | dpg.add_text("Temporal Speed: ")
319 | self.video_speed = 1.
320 | def callback_speed_control(sender):
321 | self.video_speed = dpg.get_value(sender)
322 | self.need_update = True
323 | dpg.add_slider_float(
324 | label="Play speed",
325 | default_value=1.,
326 | max_value=2.,
327 | min_value=0.0,
328 | callback=callback_speed_control,
329 | )
330 |
331 | # save current model
332 | with dpg.group(horizontal=True):
333 | dpg.add_text("Save: ")
334 |
335 | def callback_save(sender, app_data, user_data):
336 | self.save_model(mode=user_data)
337 |
338 | dpg.add_button(
339 | label="model",
340 | tag="_button_save_model",
341 | callback=callback_save,
342 | user_data='model',
343 | )
344 | dpg.bind_item_theme("_button_save_model", theme_button)
345 |
346 | dpg.add_button(
347 | label="geo",
348 | tag="_button_save_mesh",
349 | callback=callback_save,
350 | user_data='geo',
351 | )
352 | dpg.bind_item_theme("_button_save_mesh", theme_button)
353 |
354 | dpg.add_button(
355 | label="geo+tex",
356 | tag="_button_save_mesh_with_tex",
357 | callback=callback_save,
358 | user_data='geo+tex',
359 | )
360 | dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button)
361 |
362 | dpg.add_button(
363 | label="pcl",
364 | tag="_button_save_pcl",
365 | callback=callback_save,
366 | user_data='pcl',
367 | )
368 | dpg.bind_item_theme("_button_save_pcl", theme_button)
369 |
370 | def call_back_save_train(sender, app_data, user_data):
371 | self.render_all_train_data()
372 | dpg.add_button(
373 | label="save_train",
374 | tag="_button_save_train",
375 | callback=call_back_save_train,
376 | )
377 |
378 | # training stuff
379 | with dpg.collapsing_header(label="Train", default_open=True):
380 | # lr and train button
381 | with dpg.group(horizontal=True):
382 | dpg.add_text("Train: ")
383 |
384 | def callback_train(sender, app_data):
385 | if self.training:
386 | self.training = False
387 | dpg.configure_item("_button_train", label="start")
388 | else:
389 | # self.prepare_train()
390 | self.training = True
391 | dpg.configure_item("_button_train", label="stop")
392 |
393 | dpg.add_button(
394 | label="start", tag="_button_train", callback=callback_train
395 | )
396 | dpg.bind_item_theme("_button_train", theme_button)
397 |
398 | with dpg.group(horizontal=True):
399 | dpg.add_text("", tag="_log_train_psnr")
400 | dpg.add_text("", tag="_log_train_log")
401 |
402 | # rendering options
403 | with dpg.collapsing_header(label="Rendering", default_open=True):
404 | # mode combo
405 | def callback_change_mode(sender, app_data):
406 | self.mode = app_data
407 | self.need_update = True
408 |
409 | dpg.add_combo(
410 | ("render", "depth"),
411 | label="mode",
412 | default_value=self.mode,
413 | callback=callback_change_mode,
414 | )
415 |
416 | # fov slider
417 | def callback_set_fovy(sender, app_data):
418 | self.cam.fovy = np.deg2rad(app_data)
419 | self.need_update = True
420 |
421 | dpg.add_slider_int(
422 | label="FoV (vertical)",
423 | min_value=1,
424 | max_value=120,
425 | format="%d deg",
426 | default_value=np.rad2deg(self.cam.fovy),
427 | callback=callback_set_fovy,
428 | )
429 |
430 | ### register camera handler
431 |
432 | def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
433 | if not dpg.is_item_focused("_primary_window"):
434 | return
435 |
436 | dx = app_data[1]
437 | dy = app_data[2]
438 |
439 | self.cam.orbit(dx, dy)
440 | self.need_update = True
441 |
442 | def callback_camera_wheel_scale(sender, app_data):
443 | if not dpg.is_item_focused("_primary_window"):
444 | return
445 |
446 | delta = app_data
447 |
448 | self.cam.scale(delta)
449 | self.need_update = True
450 |
451 | def callback_camera_drag_pan(sender, app_data):
452 | if not dpg.is_item_focused("_primary_window"):
453 | return
454 |
455 | dx = app_data[1]
456 | dy = app_data[2]
457 |
458 | self.cam.pan(dx, dy)
459 | self.need_update = True
460 |
461 | with dpg.handler_registry():
462 | # for camera moving
463 | dpg.add_mouse_drag_handler(
464 | button=dpg.mvMouseButton_Left,
465 | callback=callback_camera_drag_rotate_or_draw_mask,
466 | )
467 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
468 | dpg.add_mouse_drag_handler(
469 | button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
470 | )
471 |
472 | dpg.create_viewport(
473 | title="Deformable-Gaussian",
474 | width=self.W + 600,
475 | height=self.H + (45 if os.name == "nt" else 0),
476 | resizable=False,
477 | )
478 |
479 | ### global theme
480 | with dpg.theme() as theme_no_padding:
481 | with dpg.theme_component(dpg.mvAll):
482 | # set all padding to 0 to avoid scroll bar
483 | dpg.add_theme_style(
484 | dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
485 | )
486 | dpg.add_theme_style(
487 | dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
488 | )
489 | dpg.add_theme_style(
490 | dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
491 | )
492 |
493 | dpg.bind_item_theme("_primary_window", theme_no_padding)
494 |
495 | dpg.setup_dearpygui()
496 |
497 | ### register a larger font
498 | # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
499 | if os.path.exists("LXGWWenKai-Regular.ttf"):
500 | with dpg.font_registry():
501 | with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
502 | dpg.bind_font(default_font)
503 |
504 | # dpg.show_metrics()
505 |
506 | dpg.show_viewport()
507 |
508 | def render(self):
509 | assert self.gui
510 | while dpg.is_dearpygui_running():
511 | # update texture every frame
512 | if self.training:
513 | self.train_step()
514 | self.test_step()
515 | dpg.render_dearpygui_frame()
516 |
517 | # no gui mode
518 | def train(self, iters=5000):
519 | if iters > 0:
520 | for i in tqdm.trange(iters):
521 | self.train_step()
522 |
523 |
524 | def train_step(self):
525 | if network_gui.conn == None:
526 | network_gui.try_connect()
527 | while network_gui.conn != None:
528 | try:
529 | net_image_bytes = None
530 | custom_cam, do_training, self.pipe.do_shs_python, self.pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive()
531 | if custom_cam != None:
532 | net_image = render(custom_cam, self.gaussians, self.pipe, self.background, scaling_modifer)["render"]
533 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2,
534 | 0).contiguous().cpu().numpy())
535 | network_gui.send(net_image_bytes, self.dataset.source_path)
536 | if do_training and ((self.iteration < int(self.opt.iterations)) or not keep_alive):
537 | break
538 | except Exception as e:
539 | network_gui.conn = None
540 |
541 | self.iter_start.record()
542 |
543 |         # Every 1000 iterations, increase the SH degree by one level up to the maximum
544 | if self.iteration % 1000 == 0:
545 | self.gaussians.oneupSHdegree()
546 |
547 | # Pick a random Camera
548 | if not self.viewpoint_stack:
549 | self.viewpoint_stack = self.scene.getTrainCameras().copy()
550 |
551 | total_frame = len(self.viewpoint_stack)
552 | time_interval = 1 / total_frame
553 |
554 | viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
555 | if self.dataset.load2gpu_on_the_fly:
556 | viewpoint_cam.load2device()
557 | fid = viewpoint_cam.fid
558 |
559 | if self.iteration < self.opt.warm_up:
560 | d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0
561 | else:
562 | N = self.gaussians.get_xyz.shape[0]
563 | time_input = fid.unsqueeze(0).expand(N, -1)
564 | ast_noise = 0 if self.dataset.is_blender else torch.randn(1, 1, device='cuda').expand(N, -1) * time_interval * self.smooth_term(self.iteration)
565 | d_xyz, d_rotation, d_scaling = self.deform.step(self.gaussians.get_xyz.detach(), time_input + ast_noise)
566 |
567 | # Render
568 | render_pkg_re = render(viewpoint_cam, self.gaussians, self.pipe, self.background, d_xyz, d_rotation, d_scaling, self.dataset.is_6dof)
569 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg_re["render"], render_pkg_re[
570 | "viewspace_points"], render_pkg_re["visibility_filter"], render_pkg_re["radii"]
571 | # depth = render_pkg_re["depth"]
572 |
573 | # Loss
574 | gt_image = viewpoint_cam.original_image.cuda()
575 | Ll1 = l1_loss(image, gt_image)
576 | loss = (1.0 - self.opt.lambda_dssim) * Ll1 + self.opt.lambda_dssim * (1.0 - ssim(image, gt_image))
577 | loss.backward()
578 |
579 | self.iter_end.record()
580 |
581 | if self.dataset.load2gpu_on_the_fly:
582 | viewpoint_cam.load2device('cpu')
583 |
584 | with torch.no_grad():
585 | # Progress bar
586 | self.ema_loss_for_log = 0.4 * loss.item() + 0.6 * self.ema_loss_for_log
587 | if self.iteration % 10 == 0:
588 | self.progress_bar.set_postfix({"Loss": f"{self.ema_loss_for_log:.{7}f}"})
589 | self.progress_bar.update(10)
590 | if self.iteration == self.opt.iterations:
591 | self.progress_bar.close()
592 |
593 | # Keep track of max radii in image-space for pruning
594 | self.gaussians.max_radii2D[visibility_filter] = torch.max(self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
595 |
596 | # Log and save
597 | cur_psnr = training_report(self.tb_writer, self.iteration, Ll1, loss, l1_loss, self.iter_start.elapsed_time(self.iter_end), self.testing_iterations, self.scene, render, (self.pipe, self.background), self.deform, self.dataset.load2gpu_on_the_fly, self.dataset.is_6dof)
598 | if self.iteration in self.testing_iterations:
599 | if cur_psnr.item() > self.best_psnr:
600 | self.best_psnr = cur_psnr.item()
601 | self.best_iteration = self.iteration
602 |
603 | if self.iteration in self.saving_iterations:
604 | print("\n[ITER {}] Saving Gaussians".format(self.iteration))
605 | self.scene.save(self.iteration)
606 | self.deform.save_weights(args.model_path, self.iteration)
607 |
608 | # Densification
609 | if self.iteration < self.opt.densify_until_iter:
610 | self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
611 |
612 | if self.iteration > self.opt.densify_from_iter and self.iteration % self.opt.densification_interval == 0:
613 | size_threshold = 20 if self.iteration > self.opt.opacity_reset_interval else None
614 | self.gaussians.densify_and_prune(self.opt.densify_grad_threshold, 0.005, self.scene.cameras_extent, size_threshold)
615 |
616 | if self.iteration % self.opt.opacity_reset_interval == 0 or (
617 | self.dataset.white_background and self.iteration == self.opt.densify_from_iter):
618 | self.gaussians.reset_opacity()
619 |
620 | # Optimizer step
621 | if self.iteration < self.opt.iterations:
622 | self.gaussians.optimizer.step()
623 | self.gaussians.update_learning_rate(self.iteration)
624 | self.gaussians.optimizer.zero_grad(set_to_none=True)
625 | self.deform.optimizer.step()
626 | self.deform.optimizer.zero_grad()
627 | self.deform.update_learning_rate(self.iteration)
628 |
629 | if self.gui:
630 | dpg.set_value(
631 | "_log_train_psnr",
632 | "Best PSNR = {} in Iteration {}".format(self.best_psnr, self.best_iteration)
633 | )
634 | else:
635 | print("Best PSNR = {} in Iteration {}".format(self.best_psnr, self.best_iteration))
636 | self.iteration += 1
637 |
638 | if self.gui:
639 | dpg.set_value(
640 | "_log_train_log",
641 | f"step = {self.iteration: 5d} loss = {loss.item():.4f}",
642 | )
643 |
644 | @torch.no_grad()
645 | def test_step(self):
646 |
647 | starter = torch.cuda.Event(enable_timing=True)
648 | ender = torch.cuda.Event(enable_timing=True)
649 | starter.record()
650 |
651 | if not hasattr(self, 't0'):
652 | self.t0 = time.time()
653 | self.fps_of_fid = 10
654 |
655 | cur_cam = MiniCam(
656 | self.cam.pose,
657 | self.W,
658 | self.H,
659 | self.cam.fovy,
660 | self.cam.fovx,
661 | self.cam.near,
662 | self.cam.far,
663 | fid=torch.remainder(torch.tensor((time.time()-self.t0) * self.fps_of_fid).float().cuda() / len(self.scene.getTrainCameras()), 1.)
664 | )
665 | fid = cur_cam.fid
666 |
667 | if self.iteration < self.opt.warm_up:
668 | d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0
669 | else:
670 | N = self.gaussians.get_xyz.shape[0]
671 | time_input = fid.unsqueeze(0).expand(N, -1)
672 | d_xyz, d_rotation, d_scaling = self.deform.step(self.gaussians.get_xyz.detach(), time_input)
673 |
674 | out = render(viewpoint_camera=cur_cam, pc=self.gaussians, pipe=self.pipe, bg_color=self.background, d_xyz=d_xyz, d_rotation=d_rotation, d_scaling=d_scaling, is_6dof=self.dataset.is_6dof)
675 |
676 | buffer_image = out[self.mode] # [3, H, W]
677 |
678 | if self.mode in ['depth', 'alpha']:
679 | buffer_image = buffer_image.repeat(3, 1, 1)
680 | if self.mode == 'depth':
681 | buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)
682 |
683 | buffer_image = torch.nn.functional.interpolate(
684 | buffer_image.unsqueeze(0),
685 | size=(self.H, self.W),
686 | mode="bilinear",
687 | align_corners=False,
688 | ).squeeze(0)
689 |
690 | self.buffer_image = (
691 | buffer_image.permute(1, 2, 0)
692 | .contiguous()
693 | .clamp(0, 1)
694 | .contiguous()
695 | .detach()
696 | .cpu()
697 | .numpy()
698 | )
699 |
700 | self.need_update = True
701 |
702 | ender.record()
703 | torch.cuda.synchronize()
704 | t = starter.elapsed_time(ender)
705 |
706 | if self.gui:
707 | dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS FID: {fid.item()})")
708 | dpg.set_value(
709 | "_texture", self.buffer_image
710 | ) # buffer must be contiguous, else seg fault!
711 |
712 | # no gui mode
713 | def train(self, iters=5000):
714 | if iters > 0:
715 | for i in tqdm.trange(iters):
716 | self.train_step()
717 |
718 | def prepare_output_and_logger(args):
719 | if not args.model_path:
720 | if os.getenv('OAR_JOB_ID'):
721 | unique_str = os.getenv('OAR_JOB_ID')
722 | else:
723 | unique_str = str(uuid.uuid4())
724 | args.model_path = os.path.join("./output/", unique_str[0:10])
725 |
726 | # Set up output folder
727 | print("Output folder: {}".format(args.model_path))
728 | os.makedirs(args.model_path, exist_ok=True)
729 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
730 | cfg_log_f.write(str(Namespace(**vars(args))))
731 |
732 | # Create Tensorboard writer
733 | tb_writer = None
734 | if TENSORBOARD_FOUND:
735 | tb_writer = SummaryWriter(args.model_path)
736 | else:
737 | print("Tensorboard not available: not logging progress")
738 | return tb_writer
739 |
740 |
741 | if __name__ == "__main__":
742 | # Set up command line argument parser
743 | parser = ArgumentParser(description="Training script parameters")
744 | lp = ModelParams(parser)
745 | op = OptimizationParams(parser)
746 | pp = PipelineParams(parser)
747 |
748 |     parser.add_argument('--gui', action='store_false', help="GUI is enabled by default; pass --gui to disable it")
749 | parser.add_argument('--W', type=int, default=800, help="GUI width")
750 | parser.add_argument('--H', type=int, default=800, help="GUI height")
751 | parser.add_argument('--elevation', type=float, default=0, help="default GUI camera elevation")
752 | parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center")
753 | parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy")
754 |
755 | parser.add_argument('--ip', type=str, default="127.0.0.1")
756 | parser.add_argument('--port', type=int, default=6009)
757 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
758 | parser.add_argument("--test_iterations", nargs="+", type=int,
759 | default=[5000, 6000, 7_000] + list(range(10000, 40001, 1000)))
760 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 20_000, 30_000, 40000])
761 | parser.add_argument("--quiet", action="store_true")
762 | args = parser.parse_args(sys.argv[1:])
763 | args.save_iterations.append(args.iterations)
764 |
765 | print("Optimizing " + args.model_path)
766 |
767 | # Initialize system state (RNG)
768 | safe_state(args.quiet)
769 |
770 | # Start GUI server, configure and run training
771 | # network_gui.init(args.ip, args.port)
772 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
773 | gui = GUI(args=args, dataset=lp.extract(args), opt=op.extract(args), pipe=pp.extract(args),testing_iterations=args.test_iterations, saving_iterations=args.save_iterations)
774 |
775 | if args.gui:
776 | gui.render()
777 | # else:
778 | # gui.train(args.iterations)
779 |
780 | # All done
781 | print("\nTraining complete.")
782 |
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from scene.cameras import Camera
13 | import numpy as np
14 | from utils.general_utils import PILtoTorch, ArrayToTorch
15 | from utils.graphics_utils import fov2focal
16 | import json
17 |
18 | WARNED = False
19 |
20 |
21 | def loadCam(args, id, cam_info, resolution_scale):
22 | orig_w, orig_h = cam_info.image.size
23 |
24 | if args.resolution in [1, 2, 4, 8]:
25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round(
26 | orig_h / (resolution_scale * args.resolution))
27 | else: # should be a type that converts to float
28 | if args.resolution == -1:
29 | if orig_w > 1600:
30 | global WARNED
31 | if not WARNED:
32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
34 | WARNED = True
35 | global_down = orig_w / 1600
36 | else:
37 | global_down = 1
38 | else:
39 | global_down = orig_w / args.resolution
40 |
41 | scale = float(global_down) * float(resolution_scale)
42 | resolution = (int(orig_w / scale), int(orig_h / scale))
43 |
44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
45 |
46 | gt_image = resized_image_rgb[:3, ...]
47 | loaded_mask = None
48 |
49 | if resized_image_rgb.shape[1] == 4:
50 | loaded_mask = resized_image_rgb[3:4, ...]
51 |
52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
54 | image=gt_image, gt_alpha_mask=loaded_mask,
55 | image_name=cam_info.image_name, uid=id,
56 | data_device=args.data_device if not args.load2gpu_on_the_fly else 'cpu', fid=cam_info.fid,
57 | depth=cam_info.depth)
58 |
59 |
60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
61 | camera_list = []
62 |
63 | for id, c in enumerate(cam_infos):
64 | camera_list.append(loadCam(args, id, c, resolution_scale))
65 |
66 | return camera_list
67 |
68 |
69 | def camera_to_JSON(id, camera: Camera):
70 | Rt = np.zeros((4, 4))
71 | Rt[:3, :3] = camera.R.transpose()
72 | Rt[:3, 3] = camera.T
73 | Rt[3, 3] = 1.0
74 |
75 | W2C = np.linalg.inv(Rt)
76 | pos = W2C[:3, 3]
77 | rot = W2C[:3, :3]
78 | serializable_array_2d = [x.tolist() for x in rot]
79 | camera_entry = {
80 | 'id': id,
81 | 'img_name': camera.image_name,
82 | 'width': camera.width,
83 | 'height': camera.height,
84 | 'position': pos.tolist(),
85 | 'rotation': serializable_array_2d,
86 | 'fy': fov2focal(camera.FovY, camera.height),
87 | 'fx': fov2focal(camera.FovX, camera.width)
88 | }
89 | return camera_entry
90 |
91 |
92 | def camera_nerfies_from_JSON(path, scale):
93 | """Loads a JSON camera into memory."""
94 | with open(path, 'r') as fp:
95 | camera_json = json.load(fp)
96 |
97 | # Fix old camera JSON.
98 | if 'tangential' in camera_json:
99 | camera_json['tangential_distortion'] = camera_json['tangential']
100 |
101 | return dict(
102 | orientation=np.array(camera_json['orientation']),
103 | position=np.array(camera_json['position']),
104 | focal_length=camera_json['focal_length'] * scale,
105 | principal_point=np.array(camera_json['principal_point']) * scale,
106 | skew=camera_json['skew'],
107 | pixel_aspect_ratio=camera_json['pixel_aspect_ratio'],
108 | radial_distortion=np.array(camera_json['radial_distortion']),
109 | tangential_distortion=np.array(camera_json['tangential_distortion']),
110 | image_size=np.array((int(round(camera_json['image_size'][0] * scale)),
111 | int(round(camera_json['image_size'][1] * scale)))),
112 | )
113 |
--------------------------------------------------------------------------------
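A small worked example of the resolution selection in `loadCam` above (the numbers are illustrative):

```python
orig_w, orig_h, resolution_scale = 1920, 1080, 1.0

# Case 1: explicit downscale factor, i.e. args.resolution in [1, 2, 4, 8]
args_resolution = 2
print(round(orig_w / (resolution_scale * args_resolution)),
      round(orig_h / (resolution_scale * args_resolution)))  # -> 960 540

# Case 2: args.resolution == -1 caps the width of large inputs at 1600 px
global_down = orig_w / 1600 if orig_w > 1600 else 1
scale = float(global_down) * resolution_scale
print(int(orig_w / scale), int(orig_h / scale))               # -> 1600 900
```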
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 |
18 |
19 | def inverse_sigmoid(x):
20 | return torch.log(x / (1 - x))
21 |
22 |
23 | def PILtoTorch(pil_image, resolution):
24 | resized_image_PIL = pil_image.resize(resolution)
25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
26 | if len(resized_image.shape) == 3:
27 | return resized_image.permute(2, 0, 1)
28 | else:
29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
30 |
31 |
32 | def ArrayToTorch(array, resolution):
33 | # resized_image = np.resize(array, resolution)
34 | resized_image_torch = torch.from_numpy(array)
35 |
36 | if len(resized_image_torch.shape) == 3:
37 | return resized_image_torch.permute(2, 0, 1)
38 | else:
39 | return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1)
40 |
41 |
42 | def get_expon_lr_func(
43 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
44 | ):
45 | """
46 | Copied from Plenoxels
47 |
48 | Continuous learning rate decay function. Adapted from JaxNeRF
49 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
50 |     is log-linearly interpolated elsewhere (exponential decay): lr(step) = exp((1 - t) * ln(lr_init) + t * ln(lr_final)), where t = clip(step / max_steps, 0, 1).
51 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
52 | function of lr_delay_mult, such that the initial learning rate is
53 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
54 | to the normal learning rate when steps>lr_delay_steps.
55 |     :param lr_init, lr_final: learning rates at step 0 and at step max_steps
56 | :param max_steps: int, the number of steps during optimization.
57 | :return HoF which takes step as input
58 | """
59 |
60 | def helper(step):
61 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
62 | # Disable this parameter
63 | return 0.0
64 | if lr_delay_steps > 0:
65 | # A kind of reverse cosine decay.
66 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
67 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
68 | )
69 | else:
70 | delay_rate = 1.0
71 | t = np.clip(step / max_steps, 0, 1)
72 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
73 | return delay_rate * log_lerp
74 |
75 | return helper
76 |
77 |
78 | def get_linear_noise_func(
79 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
80 | ):
81 | """
82 |     Adapted from the Plenoxels schedule; used in train.py to anneal the AST time-input noise.
83 |
84 | Continuous learning rate decay function. Adapted from JaxNeRF
85 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
86 |     is linearly interpolated elsewhere: lr(step) = (1 - t) * lr_init + t * lr_final, where t = clip(step / max_steps, 0, 1).
87 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
88 | function of lr_delay_mult, such that the initial learning rate is
89 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
90 | to the normal learning rate when steps>lr_delay_steps.
92 |     :param lr_init, lr_final: values returned at step 0 and at step max_steps
92 | :param max_steps: int, the number of steps during optimization.
93 | :return HoF which takes step as input
94 | """
95 |
96 | def helper(step):
97 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
98 | # Disable this parameter
99 | return 0.0
100 | if lr_delay_steps > 0:
101 | # A kind of reverse cosine decay.
102 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
103 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
104 | )
105 | else:
106 | delay_rate = 1.0
107 | t = np.clip(step / max_steps, 0, 1)
108 |         lerp = lr_init * (1 - t) + lr_final * t
109 |         return delay_rate * lerp
110 |
111 | return helper
112 |
113 |
114 | def strip_lowerdiag(L):
115 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
116 |
117 | uncertainty[:, 0] = L[:, 0, 0]
118 | uncertainty[:, 1] = L[:, 0, 1]
119 | uncertainty[:, 2] = L[:, 0, 2]
120 | uncertainty[:, 3] = L[:, 1, 1]
121 | uncertainty[:, 4] = L[:, 1, 2]
122 | uncertainty[:, 5] = L[:, 2, 2]
123 | return uncertainty
124 |
125 |
126 | def strip_symmetric(sym):
127 | return strip_lowerdiag(sym)
128 |
129 |
130 | def build_rotation(r):
131 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
132 |
133 | q = r / norm[:, None]
134 |
135 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
136 |
137 | r = q[:, 0]
138 | x = q[:, 1]
139 | y = q[:, 2]
140 | z = q[:, 3]
141 |
142 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
143 | R[:, 0, 1] = 2 * (x * y - r * z)
144 | R[:, 0, 2] = 2 * (x * z + r * y)
145 | R[:, 1, 0] = 2 * (x * y + r * z)
146 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
147 | R[:, 1, 2] = 2 * (y * z - r * x)
148 | R[:, 2, 0] = 2 * (x * z - r * y)
149 | R[:, 2, 1] = 2 * (y * z + r * x)
150 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
151 | return R
152 |
153 |
154 | def build_scaling_rotation(s, r):
155 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
156 | R = build_rotation(r)
157 |
158 | L[:, 0, 0] = s[:, 0]
159 | L[:, 1, 1] = s[:, 1]
160 | L[:, 2, 2] = s[:, 2]
161 |
162 | L = R @ L
163 | return L
164 |
165 |
166 | def safe_state(silent):
167 | old_f = sys.stdout
168 |
169 | class F:
170 | def __init__(self, silent):
171 | self.silent = silent
172 |
173 | def write(self, x):
174 | if not self.silent:
175 | if x.endswith("\n"):
176 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
177 | else:
178 | old_f.write(x)
179 |
180 | def flush(self):
181 | old_f.flush()
182 |
183 | sys.stdout = F(silent)
184 |
185 | random.seed(0)
186 | np.random.seed(0)
187 | torch.manual_seed(0)
188 | torch.cuda.set_device(torch.device("cuda:0"))
189 |
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 |
18 | class BasicPointCloud(NamedTuple):
19 | points: np.array
20 | colors: np.array
21 | normals: np.array
22 |
23 |
24 | def geom_transform_points(points, transf_matrix):
25 | P, _ = points.shape
26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
27 | points_hom = torch.cat([points, ones], dim=1)
28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
29 |
30 | denom = points_out[..., 3:] + 0.0000001
31 | return (points_out[..., :3] / denom).squeeze(dim=0)
32 |
33 |
34 | def getWorld2View(R, t):
35 | Rt = np.zeros((4, 4))
36 | Rt[:3, :3] = R.transpose()
37 | Rt[:3, 3] = t
38 | Rt[3, 3] = 1.0
39 | return np.float32(Rt)
40 |
41 |
42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
43 | Rt = np.zeros((4, 4))
44 | Rt[:3, :3] = R.transpose()
45 | Rt[:3, 3] = t
46 | Rt[3, 3] = 1.0
47 |
48 | C2W = np.linalg.inv(Rt)
49 | cam_center = C2W[:3, 3]
50 | cam_center = (cam_center + translate) * scale
51 | C2W[:3, 3] = cam_center
52 | Rt = np.linalg.inv(C2W)
53 | return np.float32(Rt)
54 |
55 |
56 | def getProjectionMatrix(znear, zfar, fovX, fovY):
57 | tanHalfFovY = math.tan((fovY / 2))
58 | tanHalfFovX = math.tan((fovX / 2))
59 |
60 | top = tanHalfFovY * znear
61 | bottom = -top
62 | right = tanHalfFovX * znear
63 | left = -right
64 |
65 | P = torch.zeros(4, 4)
66 |
67 | z_sign = 1.0
68 |
69 | P[0, 0] = 2.0 * znear / (right - left)
70 | P[1, 1] = 2.0 * znear / (top - bottom)
71 | P[0, 2] = (right + left) / (right - left)
72 | P[1, 2] = (top + bottom) / (top - bottom)
73 | P[3, 2] = z_sign
74 | P[2, 2] = z_sign * zfar / (zfar - znear)
75 | P[2, 3] = -(zfar * znear) / (zfar - znear)
76 | return P
77 |
78 |
79 | def fov2focal(fov, pixels):
80 | return pixels / (2 * math.tan(fov / 2))
81 |
82 |
83 | def focal2fov(focal, pixels):
84 | return 2 * math.atan(pixels / (2 * focal))
85 |
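86 |
87 | if __name__ == "__main__":
88 |     # Minimal usage sketch with illustrative values (not repository defaults):
89 |     # fov2focal and focal2fov are inverses of each other, and getProjectionMatrix
90 |     # expects the fields of view in radians.
91 |     fovx = focal2fov(1111.0, 800)
92 |     assert abs(fov2focal(fovx, 800) - 1111.0) < 1e-4
93 |     P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovx, fovY=fovx)
94 |     print(P)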
--------------------------------------------------------------------------------
/utils/gui_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.transform import Rotation as R
3 |
4 | import torch
5 |
6 | def dot(x, y):
7 | if isinstance(x, np.ndarray):
8 | return np.sum(x * y, -1, keepdims=True)
9 | else:
10 | return torch.sum(x * y, -1, keepdim=True)
11 |
12 |
13 | def length(x, eps=1e-20):
14 | if isinstance(x, np.ndarray):
15 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
16 | else:
17 | return torch.sqrt(torch.clamp(dot(x, x), min=eps))
18 |
19 |
20 | def safe_normalize(x, eps=1e-20):
21 | return x / length(x, eps)
22 |
23 |
24 | def look_at(campos, target, opengl=True):
25 | # campos: [N, 3], camera/eye position
26 | # target: [N, 3], object to look at
27 | # return: [N, 3, 3], rotation matrix
28 | if not opengl:
29 | # camera forward aligns with -z
30 | forward_vector = safe_normalize(target - campos)
31 | up_vector = np.array([0, 1, 0], dtype=np.float32)
32 | right_vector = safe_normalize(np.cross(forward_vector, up_vector))
33 | up_vector = safe_normalize(np.cross(right_vector, forward_vector))
34 | else:
35 | # camera forward aligns with +z
36 | forward_vector = safe_normalize(campos - target)
37 | up_vector = np.array([0, 1, 0], dtype=np.float32)
38 | right_vector = safe_normalize(np.cross(up_vector, forward_vector))
39 | up_vector = safe_normalize(np.cross(forward_vector, right_vector))
40 | R = np.stack([right_vector, up_vector, forward_vector], axis=1)
41 | return R
42 |
43 |
44 | # elevation & azimuth to pose (cam2world) matrix
45 | def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
46 | # radius: scalar
47 | # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
48 | # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
49 | # return: [4, 4], camera pose matrix
50 | if is_degree:
51 | elevation = np.deg2rad(elevation)
52 | azimuth = np.deg2rad(azimuth)
53 | x = radius * np.cos(elevation) * np.sin(azimuth)
54 | y = - radius * np.sin(elevation)
55 | z = radius * np.cos(elevation) * np.cos(azimuth)
56 | if target is None:
57 | target = np.zeros([3], dtype=np.float32)
58 | campos = np.array([x, y, z]) + target # [3]
59 | T = np.eye(4, dtype=np.float32)
60 | T[:3, :3] = look_at(campos, target, opengl)
61 | T[:3, 3] = campos
62 | return T
63 |
64 |
65 | class OrbitCamera:
66 | def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
67 | self.W = W
68 | self.H = H
69 | self.radius = r # camera distance from center
70 | self.fovy = np.deg2rad(fovy) # deg 2 rad
71 | self.near = near
72 | self.far = far
73 | self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
74 | # self.rot = R.from_matrix(np.eye(3))
75 | self.rot = R.from_matrix(np.array([[1., 0., 0.,],
76 | [0., 0., -1.],
77 | [0., 1., 0.]]))
78 | self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
79 | self.side = np.array([1, 0, 0], dtype=np.float32)
80 |
81 | @property
82 | def fovx(self):
83 | return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)
84 |
85 | @property
86 | def campos(self):
87 | return self.pose[:3, 3]
88 |
89 | # pose (c2w)
90 | @property
91 | def pose(self):
92 | # first move camera to radius
93 | res = np.eye(4, dtype=np.float32)
94 | res[2, 3] = self.radius # opengl convention...
95 | # rotate
96 | rot = np.eye(4, dtype=np.float32)
97 | rot[:3, :3] = self.rot.as_matrix()
98 | res = rot @ res
99 | # translate
100 | res[:3, 3] -= self.center
101 | return res
102 |
103 | # view (w2c)
104 | @property
105 | def view(self):
106 | return np.linalg.inv(self.pose)
107 |
108 | # projection (perspective)
109 | @property
110 | def perspective(self):
111 | y = np.tan(self.fovy / 2)
112 | aspect = self.W / self.H
113 | return np.array(
114 | [
115 | [1 / (y * aspect), 0, 0, 0],
116 | [0, -1 / y, 0, 0],
117 | [
118 | 0,
119 | 0,
120 | -(self.far + self.near) / (self.far - self.near),
121 | -(2 * self.far * self.near) / (self.far - self.near),
122 | ],
123 | [0, 0, -1, 0],
124 | ],
125 | dtype=np.float32,
126 | )
127 |
128 | # intrinsics
129 | @property
130 | def intrinsics(self):
131 | focal = self.H / (2 * np.tan(self.fovy / 2))
132 | return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)
133 |
134 | @property
135 | def mvp(self):
136 | return self.perspective @ np.linalg.inv(self.pose) # [4, 4]
137 |
138 | def orbit(self, dx, dy):
139 | # rotate along camera up/side axis!
140 | side = self.rot.as_matrix()[:3, 0]
141 | up = self.rot.as_matrix()[:3, 1]
142 | rotvec_x = up * np.radians(-0.05 * dx)
143 | rotvec_y = side * np.radians(-0.05 * dy)
144 | self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
145 |
146 | def scale(self, delta):
147 | self.radius *= 1.1 ** (-delta)
148 |
149 | def pan(self, dx, dy, dz=0):
150 | # pan in camera coordinate system (careful on the sensitivity!)
151 | self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])
152 |
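153 |
154 | if __name__ == "__main__":
155 |     # Minimal usage sketch of the viewer camera helpers (values are illustrative):
156 |     # orbit_camera builds a cam2world pose from spherical angles, and OrbitCamera
157 |     # exposes the pose / view / projection matrices consumed by the interactive GUI.
158 |     pose = orbit_camera(elevation=-30, azimuth=45, radius=2.0)
159 |     cam = OrbitCamera(W=800, H=600, r=2.0, fovy=60)
160 |     cam.orbit(dx=10, dy=0)  # drag horizontally by 10 pixels
161 |     cam.scale(delta=1)      # zoom in one notch
162 |     print(pose.shape, cam.mvp.shape, cam.intrinsics)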
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 |
15 | def mse(img1, img2):
16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
17 |
18 |
19 | def psnr(img1, img2):
20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
21 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
22 |
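23 |
24 | if __name__ == "__main__":
25 |     # Minimal usage sketch: both metrics expect images in [0, 1] shaped (N, C, H, W)
26 |     # and return one value per image. The random inputs below are purely illustrative.
27 |     a = torch.rand(2, 3, 64, 64)
28 |     b = torch.clamp(a + 0.05 * torch.randn_like(a), 0.0, 1.0)
29 |     print(mse(a, b).squeeze(), psnr(a, b).squeeze())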
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 |
18 | def l1_loss(network_output, gt):
19 | return torch.abs((network_output - gt)).mean()
20 |
21 |
22 | def kl_divergence(rho, rho_hat):
23 | rho_hat = torch.mean(torch.sigmoid(rho_hat), 0)
24 | rho = torch.tensor([rho] * len(rho_hat)).cuda()
25 | return torch.mean(
26 | rho * torch.log(rho / (rho_hat + 1e-5)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + 1e-5)))
27 |
28 |
29 | def l2_loss(network_output, gt):
30 | return ((network_output - gt) ** 2).mean()
31 |
32 |
33 | def gaussian(window_size, sigma):
34 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
35 | return gauss / gauss.sum()
36 |
37 |
38 | def create_window(window_size, channel):
39 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
40 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
41 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
42 | return window
43 |
44 |
45 | def ssim(img1, img2, window_size=11, size_average=True):
46 | channel = img1.size(-3)
47 | window = create_window(window_size, channel)
48 |
49 | if img1.is_cuda:
50 | window = window.cuda(img1.get_device())
51 | window = window.type_as(img1)
52 |
53 | return _ssim(img1, img2, window, window_size, channel, size_average)
54 |
55 |
56 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
59 |
60 | mu1_sq = mu1.pow(2)
61 | mu2_sq = mu2.pow(2)
62 | mu1_mu2 = mu1 * mu2
63 |
64 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
65 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
66 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
67 |
68 | C1 = 0.01 ** 2
69 | C2 = 0.03 ** 2
70 |
71 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
72 |
73 | if size_average:
74 | return ssim_map.mean()
75 | else:
76 | return ssim_map.mean(1).mean(1).mean(1)
77 |
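78 |
79 | if __name__ == "__main__":
80 |     # Minimal usage sketch: the photometric training loss in this codebase combines an
81 |     # L1 term with a D-SSIM term, roughly (1 - lambda) * L1 + lambda * (1 - SSIM).
82 |     # The weight and the random inputs below are illustrative, not the configured defaults.
83 |     lambda_dssim = 0.2
84 |     pred = torch.rand(1, 3, 128, 128)
85 |     gt = torch.rand(1, 3, 128, 128)
86 |     loss = (1.0 - lambda_dssim) * l1_loss(pred, gt) + lambda_dssim * (1.0 - ssim(pred, gt))
87 |     print(loss.item())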
--------------------------------------------------------------------------------
/utils/pose_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils.graphics_utils import fov2focal
4 |
5 | trans_t = lambda t: torch.Tensor([
6 | [1, 0, 0, 0],
7 | [0, 1, 0, 0],
8 | [0, 0, 1, t],
9 | [0, 0, 0, 1]]).float()
10 |
11 | rot_phi = lambda phi: torch.Tensor([
12 | [1, 0, 0, 0],
13 | [0, np.cos(phi), -np.sin(phi), 0],
14 | [0, np.sin(phi), np.cos(phi), 0],
15 | [0, 0, 0, 1]]).float()
16 |
17 | rot_theta = lambda th: torch.Tensor([
18 | [np.cos(th), 0, -np.sin(th), 0],
19 | [0, 1, 0, 0],
20 | [np.sin(th), 0, np.cos(th), 0],
21 | [0, 0, 0, 1]]).float()
22 |
23 |
24 | def rodrigues_mat_to_rot(R):
25 | eps = 1e-16
26 | trc = np.trace(R)
27 | trc2 = (trc - 1.) / 2.
28 | # sinacostrc2 = np.sqrt(1 - trc2 * trc2)
29 | s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
30 | if (1 - trc2 * trc2) >= eps:
31 | tHeta = np.arccos(trc2)
32 | tHetaf = tHeta / (2 * (np.sin(tHeta)))
33 | else:
34 | tHeta = np.real(np.arccos(trc2))
35 | tHetaf = 0.5 / (1 - tHeta / 6)
36 | omega = tHetaf * s
37 | return omega
38 |
39 |
40 | def rodrigues_rot_to_mat(r):
41 | wx, wy, wz = r
42 | theta = np.sqrt(wx * wx + wy * wy + wz * wz)
43 | a = np.cos(theta)
44 | b = (1 - np.cos(theta)) / (theta * theta)
45 | c = np.sin(theta) / theta
46 | R = np.zeros([3, 3])
47 | R[0, 0] = a + b * (wx * wx)
48 | R[0, 1] = b * wx * wy - c * wz
49 | R[0, 2] = b * wx * wz + c * wy
50 | R[1, 0] = b * wx * wy + c * wz
51 | R[1, 1] = a + b * (wy * wy)
52 | R[1, 2] = b * wy * wz - c * wx
53 | R[2, 0] = b * wx * wz - c * wy
54 | R[2, 1] = b * wz * wy + c * wx
55 | R[2, 2] = a + b * (wz * wz)
56 | return R
57 |
58 |
59 | def pose_spherical(theta, phi, radius):
60 | c2w = trans_t(radius)
61 | c2w = rot_phi(phi / 180. * np.pi) @ c2w
62 | c2w = rot_theta(theta / 180. * np.pi) @ c2w
63 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
64 | return c2w
65 |
66 |
67 | def render_wander_path(view):
68 | focal_length = fov2focal(view.FoVy, view.image_height)
69 | R = view.R
70 | R[:, 1] = -R[:, 1]
71 | R[:, 2] = -R[:, 2]
72 | T = -view.T.reshape(-1, 1)
73 | pose = np.concatenate([R, T], -1)
74 |
75 | num_frames = 60
76 | max_disp = 5000.0 # 64 , 48
77 |
78 | max_trans = max_disp / focal_length # Maximum camera translation to satisfy max_disp parameter
79 | output_poses = []
80 |
81 | for i in range(num_frames):
82 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
83 | y_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 # * 3.0 / 4.0
84 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0
85 |
86 | i_pose = np.concatenate([
87 | np.concatenate(
88 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
89 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
90 | ], axis=0) # [np.newaxis, :, :]
91 |
92 | i_pose = np.linalg.inv(i_pose) # torch.tensor(np.linalg.inv(i_pose)).float()
93 |
94 | ref_pose = np.concatenate([pose, np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
95 |
96 | render_pose = np.dot(ref_pose, i_pose)
97 | output_poses.append(torch.Tensor(render_pose))
98 |
99 | return output_poses
100 |
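101 |
102 | if __name__ == "__main__":
103 |     # Minimal usage sketch (run from the repository root, e.g. `python -m utils.pose_utils`,
104 |     # because of the package-level import above): pose_spherical builds a cam2world matrix
105 |     # from an azimuth angle theta (degrees), an elevation angle phi (degrees) and a radius.
106 |     # The 40-frame orbit below is purely illustrative.
107 |     render_poses = torch.stack(
108 |         [pose_spherical(theta, -30.0, 4.0) for theta in np.linspace(-180, 180, 40, endpoint=False)])
109 |     print(render_poses.shape)  # torch.Size([40, 4, 4])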
--------------------------------------------------------------------------------
/utils/rigid_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def skew(w: torch.Tensor) -> torch.Tensor:
5 | """Build a skew matrix ("cross product matrix") for vector w.
6 |
7 | Modern Robotics Eqn 3.30.
8 |
9 | Args:
10 | w: (N, 3) A 3-vector
11 |
12 | Returns:
13 | W: (N, 3, 3) A skew matrix such that W @ v == w x v
14 | """
15 | zeros = torch.zeros(w.shape[0], device=w.device)
16 | w_skew_list = [zeros, -w[:, 2], w[:, 1],
17 | w[:, 2], zeros, -w[:, 0],
18 | -w[:, 1], w[:, 0], zeros]
19 | w_skew = torch.stack(w_skew_list, dim=-1).reshape(-1, 3, 3)
20 | return w_skew
21 |
22 |
23 | def rp_to_se3(R: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
24 | """Rotation and translation to homogeneous transform.
25 |
26 | Args:
27 | R: (N, 3, 3) A batch of orthonormal rotation matrices.
28 | p: (N, 3, 1) A batch of translation vectors.
29 |
30 | Returns:
31 | X: (N, 4, 4) The homogeneous transformation matrices described by rotating by R
32 | and translating by p.
33 | """
34 | bottom_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=R.device).repeat(R.shape[0], 1, 1)
35 | transform = torch.cat([torch.cat([R, p], dim=-1), bottom_row], dim=1)
36 |
37 | return transform
38 |
39 |
40 | def exp_so3(w: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
41 | """Exponential map from Lie algebra so3 to Lie group SO3.
42 |
43 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula.
44 |
45 | Args:
46 | w: (N, 3) Axes of rotation.
47 | theta: (N, 1) Angles of rotation.
48 |
49 | Returns:
50 | R: (N, 3, 3) Orthonormal rotation matrices representing rotations of
51 | magnitude theta about axis w.
52 | """
53 | W = skew(w)
54 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device)
55 | W_sqr = torch.bmm(W, W) # batch matrix multiplication
56 | R = identity + torch.sin(theta.unsqueeze(-1)) * W + (1.0 - torch.cos(theta.unsqueeze(-1))) * W_sqr
57 | return R
58 |
59 |
60 | def exp_se3(S: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
61 | """Exponential map from Lie algebra se(3) to Lie group SE(3).
62 |
63 | Modern Robotics Eqn 3.88.
64 |
65 | Args:
66 | S: (N, 6) Screw axes of motion.
67 | theta: (N, 1) Magnitudes of motion.
68 |
69 | Returns:
70 | a_X_b: (N, 4, 4) The homogeneous transformation matrices attained by integrating
71 | motion of magnitude theta about S for one second.
72 | """
73 | w, v = torch.split(S, 3, dim=-1)
74 | W = skew(w)
75 | R = exp_so3(w, theta)
76 |
77 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device)
78 | W_sqr = torch.bmm(W, W)
79 | theta = theta.view(-1, 1, 1)
80 |
81 | p = torch.bmm((theta * identity + (1.0 - torch.cos(theta)) * W + (theta - torch.sin(theta)) * W_sqr),
82 | v.unsqueeze(-1))
83 | return rp_to_se3(R, p)
84 |
85 |
86 | def to_homogenous(v: torch.Tensor) -> torch.Tensor:
87 | """Converts a vector to a homogeneous coordinate vector by appending a 1.
88 |
89 | Args:
90 | v: A tensor representing a vector or batch of vectors.
91 |
92 | Returns:
93 | A tensor with an additional dimension set to 1.
94 | """
95 | return torch.cat([v, torch.ones_like(v[..., :1])], dim=-1)
96 |
97 |
98 | def from_homogenous(v: torch.Tensor) -> torch.Tensor:
99 | """Converts a homogeneous coordinate vector to a standard vector by dividing by the last element.
100 |
101 | Args:
102 | v: A tensor representing a homogeneous coordinate vector or batch of homogeneous coordinate vectors.
103 |
104 | Returns:
105 | A tensor with the last dimension removed.
106 | """
107 | return v[..., :3] / v[..., -1:]
108 |
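109 |
110 | if __name__ == "__main__":
111 |     # Minimal usage sketch: exp_se3 maps a batch of screw axes S = (w, v) with magnitudes
112 |     # theta to 4x4 rigid transforms. The random inputs below are purely illustrative.
113 |     w = torch.randn(5, 3)
114 |     theta = torch.norm(w, dim=-1, keepdim=True)                      # (5, 1)
115 |     S = torch.cat([w / (theta + 1e-5), torch.randn(5, 3)], dim=-1)   # unit axis + translation part
116 |     T = exp_se3(S, theta)
117 |     print(T.shape)  # torch.Size([5, 4, 4])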
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 | deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 |
115 | def RGB2SH(rgb):
116 | return (rgb - 0.5) / C0
117 |
118 |
119 | def SH2RGB(sh):
120 | return sh * C0 + 0.5
121 |
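122 |
123 | if __name__ == "__main__":
124 |     # Minimal usage sketch: at degree 0 the evaluation is view-independent, so
125 |     # eval_sh(0, RGB2SH(rgb), dirs) + 0.5 recovers rgb. The inputs below are illustrative.
126 |     rgb = torch.rand(4, 3)
127 |     sh = RGB2SH(rgb).unsqueeze(-1)   # [..., C, 1] DC coefficients
128 |     dirs = torch.randn(4, 3)         # unused at degree 0
129 |     print(torch.allclose(eval_sh(0, sh, dirs) + 0.5, rgb))  # True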
--------------------------------------------------------------------------------
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 |
17 | def mkdir_p(folder_path):
18 | # Creates a directory; equivalent to using mkdir -p on the command line.
19 | try:
20 | makedirs(folder_path)
21 | except OSError as exc: # Python >2.5
22 | if exc.errno == EEXIST and path.isdir(folder_path):
23 | pass
24 | else:
25 | raise
26 |
27 |
28 | def searchForMaxIteration(folder):
29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
30 | return max(saved_iters)
31 |
--------------------------------------------------------------------------------
/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.rigid_utils import exp_se3
5 |
6 |
7 | def get_embedder(multires, i=1):
8 | if i == -1:
9 | return nn.Identity(), 3
10 |
11 | embed_kwargs = {
12 | 'include_input': True,
13 | 'input_dims': i,
14 | 'max_freq_log2': multires - 1,
15 | 'num_freqs': multires,
16 | 'log_sampling': True,
17 | 'periodic_fns': [torch.sin, torch.cos],
18 | }
19 |
20 | embedder_obj = Embedder(**embed_kwargs)
21 | embed = lambda x, eo=embedder_obj: eo.embed(x)
22 | return embed, embedder_obj.out_dim
23 |
24 |
25 | class Embedder:
26 | def __init__(self, **kwargs):
27 | self.kwargs = kwargs
28 | self.create_embedding_fn()
29 |
30 | def create_embedding_fn(self):
31 | embed_fns = []
32 | d = self.kwargs['input_dims']
33 | out_dim = 0
34 | if self.kwargs['include_input']:
35 | embed_fns.append(lambda x: x)
36 | out_dim += d
37 |
38 | max_freq = self.kwargs['max_freq_log2']
39 | N_freqs = self.kwargs['num_freqs']
40 |
41 | if self.kwargs['log_sampling']:
42 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
43 | else:
44 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
45 |
46 | for freq in freq_bands:
47 | for p_fn in self.kwargs['periodic_fns']:
48 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
49 | out_dim += d
50 |
51 | self.embed_fns = embed_fns
52 | self.out_dim = out_dim
53 |
54 | def embed(self, inputs):
55 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
56 |
57 |
58 | class DeformNetwork(nn.Module):
59 | def __init__(self, D=8, W=256, input_ch=3, output_ch=59, multires=10, is_blender=False, is_6dof=False):
60 | super(DeformNetwork, self).__init__()
61 | self.D = D
62 | self.W = W
63 | self.input_ch = input_ch
64 | self.output_ch = output_ch
65 | self.t_multires = 6 if is_blender else 10
66 | self.skips = [D // 2]
67 |
68 | self.embed_time_fn, time_input_ch = get_embedder(self.t_multires, 1)
69 | self.embed_fn, xyz_input_ch = get_embedder(multires, 3)
70 | self.input_ch = xyz_input_ch + time_input_ch
71 |
72 | if is_blender:
73 | # Better for D-NeRF Dataset
74 | self.time_out = 30
75 |
76 | self.timenet = nn.Sequential(
77 | nn.Linear(time_input_ch, 256), nn.ReLU(inplace=True),
78 | nn.Linear(256, self.time_out))
79 |
80 | self.linear = nn.ModuleList(
81 | [nn.Linear(xyz_input_ch + self.time_out, W)] + [
82 | nn.Linear(W, W) if i not in self.skips else nn.Linear(W + xyz_input_ch + self.time_out, W)
83 | for i in range(D - 1)]
84 | )
85 |
86 | else:
87 | self.linear = nn.ModuleList(
88 | [nn.Linear(self.input_ch, W)] + [
89 | nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W)
90 | for i in range(D - 1)]
91 | )
92 |
93 | self.is_blender = is_blender
94 | self.is_6dof = is_6dof
95 |
96 | if is_6dof:
97 | self.branch_w = nn.Linear(W, 3)
98 | self.branch_v = nn.Linear(W, 3)
99 | else:
100 | self.gaussian_warp = nn.Linear(W, 3)
101 | self.gaussian_rotation = nn.Linear(W, 4)
102 | self.gaussian_scaling = nn.Linear(W, 3)
103 |
104 | def forward(self, x, t):
105 | t_emb = self.embed_time_fn(t)
106 | if self.is_blender:
107 | t_emb = self.timenet(t_emb) # better for D-NeRF Dataset
108 | x_emb = self.embed_fn(x)
109 | h = torch.cat([x_emb, t_emb], dim=-1)
110 | for i, l in enumerate(self.linear):
111 | h = self.linear[i](h)
112 | h = F.relu(h)
113 | if i in self.skips:
114 | h = torch.cat([x_emb, t_emb, h], -1)
115 |
116 | if self.is_6dof:
117 | w = self.branch_w(h)
118 | v = self.branch_v(h)
119 | theta = torch.norm(w, dim=-1, keepdim=True)
120 | w = w / (theta + 1e-5)  # normalize the rotation axis; epsilon guards against theta == 0
121 | v = v / (theta + 1e-5)
122 | screw_axis = torch.cat([w, v], dim=-1)
123 | d_xyz = exp_se3(screw_axis, theta)
124 | else:
125 | d_xyz = self.gaussian_warp(h)
126 | scaling = self.gaussian_scaling(h)
127 | rotation = self.gaussian_rotation(h)
128 |
129 | return d_xyz, rotation, scaling
130 |
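131 |
132 | if __name__ == "__main__":
133 |     # Minimal usage sketch (run from the repository root, e.g. `python -m utils.time_utils`,
134 |     # because of the package-level import above): the deformation MLP maps canonical Gaussian
135 |     # centers plus a scalar timestamp to per-Gaussian offsets in position, rotation and scaling.
136 |     # The shapes below are illustrative.
137 |     net = DeformNetwork(is_blender=False, is_6dof=False)
138 |     xyz = torch.rand(1024, 3)        # canonical Gaussian centers
139 |     t = torch.full((1024, 1), 0.5)   # normalized timestamp, repeated per point
140 |     d_xyz, d_rotation, d_scaling = net(xyz, t)
141 |     print(d_xyz.shape, d_rotation.shape, d_scaling.shape)  # (1024, 3) (1024, 4) (1024, 3)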
--------------------------------------------------------------------------------