├── .gitmodules
├── LICENSE.md
├── LICENSE_BBSplat.txt
├── README.md
├── arguments
└── __init__.py
├── assets
├── alpha_init_gaussian.png
├── alpha_init_gaussian_small.png
├── control_panel.png
└── readme_images
│ ├── blender_preset.jpg
│ ├── scull.gif
│ ├── teaser.png
│ ├── train.gif
│ └── visualizer.png
├── bbsplat_install.sh
├── convert.py
├── docker
├── Dockerfile
├── build.sh
├── environment.yml
├── push.sh
├── run.sh
└── source.sh
├── docker_colmap
├── run.sh
└── source.sh
├── gaussian_renderer
├── __init__.py
└── network_gui.py
├── lpipsPyTorch
├── __init__.py
└── modules
│ ├── lpips.py
│ ├── networks.py
│ └── utils.py
├── metrics.py
├── render.py
├── scene
├── __init__.py
├── cameras.py
├── colmap_loader.py
├── dataset_readers.py
└── gaussian_model.py
├── scripts
├── average_error.py
├── colmap_all.sh
├── dtu_eval.py
├── eval_dtu
│ ├── eval.py
│ ├── evaluate_single_scene.py
│ └── render_utils.py
├── metrics_all.sh
├── render_all.sh
└── train_all.sh
├── train.py
├── utils
├── camera_utils.py
├── general_utils.py
├── graphics_utils.py
├── image_utils.py
├── loss_utils.py
├── mcube_utils.py
├── mesh_utils.py
├── point_utils.py
├── reconstruction_utils.py
├── render_utils.py
├── sh_utils.py
└── system_utils.py
└── visualize.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/simple-knn"]
2 | path = submodules/simple-knn
3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
4 | [submodule "submodules/diff-bbsplat-rasterization"]
5 | path = submodules/diff-bbsplat-rasterization
6 | url = https://github.com/david-svitov/diff-bbsplat-rasterization.git
7 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
--------------------------------------------------------------------------------
/LICENSE_BBSplat.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, David Svitov
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BillBoard Splatting (BBSplat): Learnable Textured Primitives for Novel View Synthesis
2 |
3 | [Project page](https://david-svitov.github.io/BBSplat_project_page/) | [Paper](https://arxiv.org/pdf/2411.08508) | [Video](https://youtu.be/ZnIOZHBJ4wM) | [BBSplat Rasterizer (CUDA)](https://github.com/david-svitov/diff-bbsplat-rasterization/) | [Scenes example (1.5GB)](https://drive.google.com/file/d/1gu_bDFXx38KJtwIrXo8lMVtuY-P2PFXX/view?usp=sharing) |
4 |
5 | 
6 |
7 | ## Abstract
8 | We present billboard Splatting (BBSplat) - a novel approach for novel view synthesis based on textured geometric primitives.
9 | BBSplat represents the scene as a set of optimizable textured planar primitives with learnable RGB textures and alpha-maps to
10 | control their shape. BBSplat primitives can be used in any Gaussian Splatting pipeline as drop-in replacements for Gaussians.
11 | The proposed primitives close the rendering quality gap between 2D and 3D Gaussian Splatting (GS), enabling the accurate extraction
12 | of a 3D mesh as in the 2DGS framework. Additionally, the explicit nature of planar primitives enables the use of ray-tracing effects in rasterization.
13 | Our novel regularization term encourages textures to have a sparser structure, enabling an efficient compression that reduces the storage
14 | space of the model by up to $\times17$ compared to 3DGS. Our experiments show the efficiency of BBSplat on standard datasets of real indoor and outdoor
15 | scenes such as Tanks\&Temples, DTU, and Mip-NeRF-360. Namely, we achieve a state-of-the-art PSNR of 29.72 for DTU at Full HD resolution.
16 |
17 | ## Updates
18 |
19 | * 10/02/2025 - We fixed a bug in the FPS measurement function and updated the preprint accordingly.
20 | * 13/03/2025 - We released the mesh extraction code.
21 |
22 | ## Repository structure
23 |
24 | Here, we briefly describe the key elements of the project. All main Python scripts are in the root ```./``` directory,
25 | and the bash scripts to reproduce the experiments are in the ```scripts``` folder. For a quick start, please use the
26 | Docker images provided in the ```docker``` folder.
27 |
28 | ```bash
29 | .
30 | ├── scripts # Bash scripts to process datasets
31 | │ ├── colmap_all.sh # > Extract point clouds with COLMAP
32 | │ ├── dtu_eval.py # Script to run DTU Chamfer distance evaluation
33 | │ ├── train_all.sh # > Fit all scenes
34 | │ ├── render_all.sh # > Render all scenes
35 | │ └── metrics_all.sh # > Calculate metrics for all scenes
36 | ├── submodules
37 | │ ├── diff-bbsplat-rasterization # CUDA implementation of the BBSplat rasterizer
38 | │ └── simple-knn # CUDA implementation of KNN
39 | ├── docker # Scripts to build and run Docker image
40 | ├── docker_colmap # Scripts to download and run Docker image for COLMAP
41 | ├── bbsplat_install.sh # Build and install submodules
42 | ├── convert.py # Extract point cloud with COLMAP
43 | ├── train.py # Train BBSplat scene representation
44 | ├── render.py # Novel view synthesis
45 | ├── metrics.py # Metrics calculation
46 | └── visualize.py # Interactive scene visualizer
47 | ```
48 |
49 |
50 | ## Installation
51 |
52 | We prepared a Docker image for quick and easy installation. Please follow these steps:
53 |
54 | ```bash
55 | # Download
56 | git clone https://github.com/david-svitov/BBSplat.git --recursive
57 | # Go to the "docker" subfolder
58 | cd BBSplat/docker
59 |
60 | # Build Docker image
61 | bash build.sh
62 | # Optionally adjust mounting folder paths in source.sh
63 | # Run Docker container
64 | bash run.sh
65 |
66 | # In the container, please install the submodules
67 | bash bbsplat_install.sh
68 | ```
69 |
70 |
71 | ### Docker container for COLMAP
72 |
73 | To use COLMAP, you can also use the Docker image provided in the ```docker_colmap``` folder as follows:
74 |
75 | ```bash
76 | cd BBSplat/docker_colmap
77 | # Optionally adjust mounting folder paths in source.sh
78 | # Run Docker container
79 | bash run.sh
80 |
81 | # OpenCV has to be installed inside this container, because the "jsantisi/colmap-gpu" image does not ship with it
82 | add-apt-repository universe
83 | apt-get update
84 | apt install python3-pip
85 | python3 -m pip install opencv-python
86 | ```
87 |
88 |
89 | ## Data preprocessing
90 |
91 | An example of using ```convert.py``` can be found in ```scripts/colmap_all.sh```.
92 | Please note that for the different datasets in the paper we used different ```images_N``` folders from the COLMAP output folder.
93 | The instructions on how to install COLMAP can be found above.
94 |
95 | We use the same COLMAP loader as 3DGS and 2DGS; you can find a detailed description of it [here](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes).
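
For reference, a minimal, hypothetical invocation is sketched below (the paths are placeholders); it assumes the raw images are placed in ```<path to scene>/input``` and that COLMAP and ImageMagick are available, e.g. inside the ```docker_colmap``` container:

```bash
# Hypothetical example: extract the point cloud and camera poses with COLMAP;
# --resize additionally produces the downscaled images_2 / images_4 / images_8 folders.
python convert.py -s <path to scene> --camera OPENCV --resize
```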
96 |
97 |
98 | ## Training
99 | To train a scene, please use the following command:
100 | ```bash
101 | python train.py -s <path to dataset> --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval
102 | ```
103 | Commandline arguments description:
104 | ```bash
105 | --cap_max # maximum number of Billboards
106 | --max_read_points # maximum number of SfM points for initialization
107 | --add_sky_box # flag to create additional points for far objects
108 | --eval # hold out every N-th image for evaluation
109 |
110 | # 2DGS normal-depth regularization can be beneficial for some datasets
111 | --lambda_normal # hyperparameter for normal consistency
112 | --lambda_dist # hyperparameter for depth distortion
113 | ```
114 |
115 | Examples of training commands for different datasets can be found in ```scripts/train_all.sh```.
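
As an additional illustration, a hypothetical command that also enables the 2DGS normal-depth regularization could look like the sketch below; the path and regularization weights are placeholders, not tuned recommendations:

```bash
# Hypothetical sketch; adjust the regularization weights per dataset
python train.py -s <path to dataset> --cap_max=160_000 --max_read_points=150_000 --eval \
    --lambda_normal 0.05 --lambda_dist 100
```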
116 |
117 | ## Testing
118 | ### Novel view synthesis evaluation
119 | For novel view synthesis use:
120 | ```bash
121 | python render.py -m <path to trained model> -s <path to dataset>
122 | ```
123 |
124 | Commandline arguments description:
125 | ```bash
126 | --skip_mesh # flag to disable mesh extraction to accelerate NVS evaluation
127 | --save_planes # flag to save BBSplat as a set of textured planes
128 | ```
129 |
130 | To calculate metric values, use:
131 | ```bash
132 | python metrics.py -m <path to trained model>
133 | ```
134 | Examples for the datasets used in the paper can be found in ```scripts/render_all.sh``` and ```scripts/metrics_all.sh```.
135 |
136 | ---
137 | ❗ **Faster inference**
138 |
139 | Inference can be accelerated by using tighter bounding boxes in the rasterization. To do this, follow these steps:
140 | * Open ```submodules/diff-bbsplat-rasterization/cuda_rasterizer/auxiliary.h```
141 | * Modify ```#define FAST_INFERENCE 0``` to be ```#define FAST_INFERENCE 1```
142 | * Rebuild the code with ```./bbsplat_install.sh```
143 |
144 | This will give you up to $\times 2$ acceleration at the cost of a slight degradation in metrics.
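
The manual edit above can also be scripted; a minimal sketch, assuming GNU sed inside the provided Docker container:

```bash
# Hypothetical one-liner equivalent of the manual edit described above
sed -i 's/#define FAST_INFERENCE 0/#define FAST_INFERENCE 1/' \
    submodules/diff-bbsplat-rasterization/cuda_rasterizer/auxiliary.h
bash bbsplat_install.sh  # rebuild and reinstall the rasterizer
```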
145 |
146 | ---
147 |
148 | ### DTU Chamfer distance evaluation
149 |
150 | To calculate the Chamfer distance metric for the DTU dataset, simply run ```scripts/dtu_eval.py``` as follows:
151 | ```bash
152 | python scripts/dtu_eval.py --dtu=<path to DTU dataset> --output_path=<path to trained models> --DTU_Official=<path to official DTU evaluation data>
153 | ```
154 |
155 | ## Exporting to Blender
156 |
157 | The newest feature of the code is the conversion of BBSplat into a set of textured planes for rasterization in Blender:
158 |
159 |
160 |
161 |
162 |
163 |
164 | To do this, follow these instructions. First, you have to enable StopThePop sorting of billboards:
165 | * Open ```submodules/diff-bbsplat-rasterization/cuda_rasterizer/auxiliary.h```
166 | * Switch ```#define TILE_SORTING 0``` to ```#define TILE_SORTING 1```
167 | * Switch ```#define PIXEL_RESORTING 0``` to ```#define PIXEL_RESORTING 1```
168 | * Make sure that ```#define FAST_INFERENCE``` is still set to ```0```
169 | * Rebuild the code with ```./bbsplat_install.sh```
170 |
171 | Next, simply run ```render.py``` with the ```--save_planes``` flag. In the output folder you will find ```planes_mesh.obj```; import it into Blender.
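
A hypothetical end-to-end sketch of this export sequence (assuming GNU sed; the paths are placeholders):

```bash
# Hypothetical sketch: enable StopThePop sorting, rebuild, then export the textured planes
RASTER_HEADER=submodules/diff-bbsplat-rasterization/cuda_rasterizer/auxiliary.h
sed -i 's/#define TILE_SORTING 0/#define TILE_SORTING 1/' $RASTER_HEADER
sed -i 's/#define PIXEL_RESORTING 0/#define PIXEL_RESORTING 1/' $RASTER_HEADER
grep FAST_INFERENCE $RASTER_HEADER   # verify that FAST_INFERENCE is still set to 0
bash bbsplat_install.sh              # rebuild the rasterizer
python render.py -m <path to trained model> -s <path to dataset> --save_planes
```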
172 |
173 | As the final step, use the alpha textures for the alpha channel in the Blender shader settings. Enable ray tracing and adjust the number of samples for the EEVEE renderer:
174 |
175 |
176 | ## Interactive visualization
177 |
178 | 
179 |
180 | To control the camera position dynamically, use ```visualize.py``` with the same ```-m``` and ```-s``` parameters as ```render.py```.
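
For example (the paths below are placeholders):

```bash
python visualize.py -m <path to trained model> -s <path to dataset>
```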
181 |
182 | ## Citation
183 | If you find our code or paper helpful, please consider citing:
184 | ```bibtex
185 | @article{svitov2024billboard,
186 | title={BillBoard Splatting (BBSplat): Learnable Textured Primitives for Novel View Synthesis},
187 | author={Svitov, David and Morerio, Pietro and Agapito, Lourdes and Del Bue, Alessio},
188 | journal={arXiv preprint arXiv:2411.08508},
189 | year={2024}
190 | }
191 | ```
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self._source_path = ""
51 | self._model_path = ""
52 | self._images = "images"
53 | self._resolution = -1
54 | self._white_background = False
55 | self.data_device = "cuda"
56 | self.eval = False
57 | super().__init__(parser, "Loading Parameters", sentinel)
58 |
59 | def extract(self, args):
60 | g = super().extract(args)
61 | g.source_path = os.path.abspath(g.source_path)
62 | return g
63 |
64 | class PipelineParams(ParamGroup):
65 | def __init__(self, parser):
66 | self.convert_SHs_python = False
67 | self.compute_cov3D_python = False
68 | self.depth_ratio = 0.0
69 | self.debug = False
70 | super().__init__(parser, "Pipeline Parameters")
71 |
72 | class OptimizationParams(ParamGroup):
73 | def __init__(self, parser):
74 | self.iterations = 32_000
75 | self.position_lr_init = 0.00016
76 | self.position_lr_final = 0.0000016
77 | self.position_lr_delay_mult = 0.01
78 | self.position_lr_max_steps = 30_000
79 | self.feature_lr = 0.005
80 | self.scaling_lr = 0.005
81 | self.rotation_lr = 0.001
82 | self.texture_opacity_lr = 0.001
83 | self.texture_color_lr = 0.0025
84 | self.percent_dense = 0.1
85 | self.lambda_dssim = 0.2
86 | self.lambda_dist = 0.0
87 | self.lambda_normal = 0.0
88 | self.lambda_texture_value = 0.0001
89 | self.lambda_alpha_value = 0.0001
90 | self.max_impact_threshold = 100
91 | self.sphere_point = 10000
92 |
93 | # Densification policy
94 | self.texture_from_iter = 500
95 | self.texture_to_iter = 30000
96 | self.densification_interval = 100
97 | self.densify_from_iter = 500
98 | self.densify_until_iter = 25000
99 | self.dead_opacity = 0.005
100 |
101 | # MCMC
102 | self.noise_lr = 5e5
103 | self.opacity_reg = 0.01
104 | self.cap_max = 160_000
105 |
106 | # Data
107 | self.max_read_points = self.cap_max - 20_000
108 | self.add_sky_box = False
109 |
110 | super().__init__(parser, "Optimization Parameters")
111 |
112 | def get_combined_args(parser : ArgumentParser):
113 | cmdlne_string = sys.argv[1:]
114 | cfgfile_string = "Namespace()"
115 | args_cmdline = parser.parse_args(cmdlne_string)
116 |
117 | try:
118 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
119 | print("Looking for config file in", cfgfilepath)
120 | with open(cfgfilepath) as cfg_file:
121 | print("Config file found: {}".format(cfgfilepath))
122 | cfgfile_string = cfg_file.read()
123 | except TypeError:
124 | print("Config file not found at")
125 | pass
126 | args_cfgfile = eval(cfgfile_string)
127 |
128 | merged_dict = vars(args_cfgfile).copy()
129 | for k,v in vars(args_cmdline).items():
130 | if v != None:
131 | merged_dict[k] = v
132 | return Namespace(**merged_dict)
133 |
--------------------------------------------------------------------------------
/assets/alpha_init_gaussian.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/alpha_init_gaussian.png
--------------------------------------------------------------------------------
/assets/alpha_init_gaussian_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/alpha_init_gaussian_small.png
--------------------------------------------------------------------------------
/assets/control_panel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/control_panel.png
--------------------------------------------------------------------------------
/assets/readme_images/blender_preset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/blender_preset.jpg
--------------------------------------------------------------------------------
/assets/readme_images/scull.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/scull.gif
--------------------------------------------------------------------------------
/assets/readme_images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/teaser.png
--------------------------------------------------------------------------------
/assets/readme_images/train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/train.gif
--------------------------------------------------------------------------------
/assets/readme_images/visualizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-svitov/BBSplat/cff8d6a7bce27c56b6938482dbdc72b882cd53bb/assets/readme_images/visualizer.png
--------------------------------------------------------------------------------
/bbsplat_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip install ./submodules/diff-bbsplat-rasterization
4 | pip install ./submodules/simple-knn
5 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import logging
14 | from argparse import ArgumentParser
15 | import shutil
16 |
17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
18 | parser = ArgumentParser("Colmap converter")
19 | parser.add_argument("--no_gpu", action='store_true')
20 | parser.add_argument("--skip_matching", action='store_true')
21 | parser.add_argument("--source_path", "-s", required=True, type=str)
22 | parser.add_argument("--camera", default="OPENCV", type=str)
23 | parser.add_argument("--colmap_executable", default="", type=str)
24 | parser.add_argument("--resize", action="store_true")
25 | parser.add_argument("--magick_executable", default="", type=str)
26 | args = parser.parse_args()
27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
29 | use_gpu = 1 if not args.no_gpu else 0
30 |
31 | input_folder = "/input"
32 | if not args.skip_matching:
33 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
34 |
35 | ## Feature extraction
36 | feat_extracton_cmd = colmap_command + " feature_extractor "\
37 | "--database_path " + args.source_path + "/distorted/database.db \
38 | --image_path " + args.source_path + input_folder + " \
39 | --ImageReader.single_camera 1 \
40 | --ImageReader.camera_model " + args.camera + " \
41 | --SiftExtraction.use_gpu " + str(use_gpu)
42 | exit_code = os.system(feat_extracton_cmd)
43 | if exit_code != 0:
44 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
45 | exit(exit_code)
46 |
47 | ## Feature matching
48 | feat_matching_cmd = colmap_command + " exhaustive_matcher \
49 | --database_path " + args.source_path + "/distorted/database.db \
50 | --SiftMatching.use_gpu " + str(use_gpu)
51 | exit_code = os.system(feat_matching_cmd)
52 | if exit_code != 0:
53 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
54 | exit(exit_code)
55 |
56 | ### Bundle adjustment
57 | # The default Mapper tolerance is unnecessarily large,
58 | # decreasing it speeds up bundle adjustment steps.
59 | mapper_cmd = (colmap_command + " mapper \
60 | --database_path " + args.source_path + "/distorted/database.db \
61 | --image_path " + args.source_path + input_folder + " \
62 | --output_path " + args.source_path + "/distorted/sparse \
63 | --Mapper.ba_global_function_tolerance=0.000001")
64 | exit_code = os.system(mapper_cmd)
65 | if exit_code != 0:
66 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
67 | exit(exit_code)
68 |
69 | ### Image undistortion
70 | ## We need to undistort our images into ideal pinhole intrinsics.
71 | img_undist_cmd = (colmap_command + " image_undistorter \
72 | --image_path " + args.source_path + input_folder + " \
73 | --input_path " + args.source_path + "/distorted/sparse/0 \
74 | --output_path " + args.source_path + "\
75 | --output_type COLMAP")
76 | exit_code = os.system(img_undist_cmd)
77 | if exit_code != 0:
78 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
79 | exit(exit_code)
80 |
81 | files = os.listdir(args.source_path + "/sparse")
82 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
83 | # Copy each file from the source directory to the destination directory
84 | for file in files:
85 | if file == '0':
86 | continue
87 | source_file = os.path.join(args.source_path, "sparse", file)
88 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
89 | shutil.move(source_file, destination_file)
90 |
91 | if(args.resize):
92 | print("Copying and resizing...")
93 |
94 | # Resize images.
95 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
96 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
97 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
98 | # Get the list of files in the source directory
99 | files = os.listdir(args.source_path + "/images")
100 | # Copy each file from the source directory to the destination directory
101 | for file in files:
102 | source_file = os.path.join(args.source_path, "images", file)
103 |
104 | destination_file = os.path.join(args.source_path, "images_2", file)
105 | shutil.copy2(source_file, destination_file)
106 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
107 | if exit_code != 0:
108 | logging.error(f"50% resize failed with code {exit_code}. Exiting.")
109 | exit(exit_code)
110 |
111 | destination_file = os.path.join(args.source_path, "images_4", file)
112 | shutil.copy2(source_file, destination_file)
113 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
114 | if exit_code != 0:
115 | logging.error(f"25% resize failed with code {exit_code}. Exiting.")
116 | exit(exit_code)
117 |
118 | destination_file = os.path.join(args.source_path, "images_8", file)
119 | shutil.copy2(source_file, destination_file)
120 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
121 | if exit_code != 0:
122 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
123 | exit(exit_code)
124 |
125 | print("Done.")
126 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
2 |
3 | ENV TZ=Europe/Rome
4 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
5 |
6 | SHELL ["/bin/bash", "--login", "-c"]
7 |
8 | RUN apt-get update && apt-get install -y \
9 | wget \
10 | htop \
11 | git \
12 | nano \
13 | cmake \
14 | unzip \
15 | zip \
16 | vim \
17 | libglu1-mesa-dev freeglut3-dev mesa-common-dev \
18 | libopencv-dev \
19 | libglew-dev \
20 | assimp-utils libassimp-dev \
21 | libboost-all-dev \
22 | libglfw3-dev \
23 | libgtk-3-dev \
24 | ffmpeg libavcodec-dev libavdevice-dev libavfilter-dev libavformat-dev libavutil-dev \
25 | libeigen3-dev \
26 | libgl1-mesa-dev xorg-dev \
27 | libembree-dev
28 |
29 | RUN ln -s /lib/x86_64-linux-gnu/libembree3.so /lib/x86_64-linux-gnu/libembree.so
30 |
31 | ENV PYTHONDONTWRITEBYTECODE=1
32 | ENV PYTHONUNBUFFERED=1
33 |
34 | ENV LD_LIBRARY_PATH /usr/lib64:$LD_LIBRARY_PATH
35 |
36 | ENV NVIDIA_VISIBLE_DEVICES all
37 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics
38 |
39 | ENV PYOPENGL_PLATFORM egl
40 |
41 | #RUN ls /usr/share/glvnd/egl_vendor.d/
42 | #COPY docker/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
43 |
44 | # fixuid
45 | ARG USERNAME=user
46 | RUN apt-get update && apt-get install -y sudo curl && \
47 | addgroup --gid 1000 $USERNAME && \
48 | adduser --uid 1000 --gid 1000 --disabled-password --gecos '' $USERNAME && \
49 | adduser $USERNAME sudo && \
50 | echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \
51 | USER=$USERNAME && \
52 | GROUP=$USERNAME && \
53 | curl -SsL https://github.com/boxboat/fixuid/releases/download/v0.4/fixuid-0.4-linux-amd64.tar.gz | tar -C /usr/local/bin -xzf - && \
54 | chown root:root /usr/local/bin/fixuid && \
55 | chmod 4755 /usr/local/bin/fixuid && \
56 | mkdir -p /etc/fixuid && \
57 | printf "user: $USER\ngroup: $GROUP\n" > /etc/fixuid/config.yml
58 | USER $USERNAME:$USERNAME
59 |
60 | # miniforge
61 | WORKDIR /home/$USERNAME
62 | ENV CONDA_AUTO_UPDATE_CONDA=false
63 | ENV PATH=/home/$USERNAME/miniforge/bin:$PATH
64 |
65 | RUN wget --quiet https://github.com/conda-forge/miniforge/releases/download/24.11.3-2/Miniforge3-24.11.3-2-Linux-x86_64.sh -O ~/miniforge.sh && \
66 | chmod +x ~/miniforge.sh && \
67 | ~/miniforge.sh -b -p ~/miniforge
68 |
69 | #RUN echo 112
70 | COPY docker/environment.yml /home/$USERNAME/environment.yml
71 | RUN conda env create -f /home/$USERNAME/environment.yml
72 | ENV PATH=/home/$USERNAME/miniforge/envs/bbsplat/bin:$PATH
73 |
74 | RUN echo "source activate bbsplat" > ~/.bashrc
75 | ENV PATH /opt/conda/envs/bbsplat/bin:$PATH
76 |
77 | # python libs
78 | RUN pip install --upgrade pip
79 |
80 |
81 | # docker setup
82 | WORKDIR /
83 | ENTRYPOINT ["fixuid", "-q"]
84 | CMD ["fixuid", "-q", "bash"]
85 |
--------------------------------------------------------------------------------
/docker/build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
4 | source ${CURRENT_DIR}/source.sh
5 |
6 | DOCKER_BUILDKIT=0 docker build -t $NAME --build-arg ssh_prv_key="$(cat ~/.ssh/id_rsa)" --build-arg ssh_pub_key="$(cat ~/.ssh/id_rsa.pub)" -f ${CURRENT_DIR}/Dockerfile ${CURRENT_DIR}/..
7 |
--------------------------------------------------------------------------------
/docker/environment.yml:
--------------------------------------------------------------------------------
1 | name: bbsplat
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - nvidia
6 | - defaults
7 | - open3d-admin
8 | - anaconda
9 | dependencies:
10 | - pip=23.3.1
11 | - ffmpeg=6.1.1=h4c62175_0
12 | - jpeg=9e=h5eee18b_3
13 | - ncurses=6.4=h6a678d5_0
14 | - networkx=3.1=py38h06a4308_0
15 | - numpy=1.24.3=py38h14f4228_0
16 | - numpy-base=1.24.3=py38h31eccc5_0
17 | - openh264=2.1.1=h4ff587b_0
18 | - openjpeg=2.5.2=he7f1fd0_0
19 | - pillow=10.2.0=py38h5eee18b_0
20 | - plyfile=1.0.3=pyhd8ed1ab_0
21 | - python=3.8.18=h955ad1f_0
22 | - pytorch=2.0.0=py3.8_cuda11.8_cudnn8.7.0_0
23 | - pytorch-cuda=11.8=h7e8668a_5
24 | - torchaudio=2.0.0=py38_cu118
25 | - torchtriton=2.0.0=py38
26 | - torchvision=0.15.0=py38_cu118
27 | - typing_extensions=4.9.0=py38h06a4308_1
28 | - open3d=0.11.2
29 | - scikit-learn=1.3.0
30 | - addict=2.4.0
31 | - pandas=2.0.3
32 | - ninja=1.12.1
33 | - pip:
34 | - mediapy==1.1.2
35 | - opencv-python==4.9.0.80
36 | - scikit-image==0.21.0
37 | - tqdm==4.66.2
38 | - trimesh
39 | - xatlas
40 | - git+https://github.com/facebookresearch/pytorch3d.git
41 | - git+https://github.com/NVlabs/nvdiffrast.git
42 |
--------------------------------------------------------------------------------
/docker/push.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
4 | source ${CURRENT_DIR}/source.sh
5 |
6 | docker tag $NAME $HEAD_NAME
7 | docker push $HEAD_NAME
--------------------------------------------------------------------------------
/docker/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
4 | source ${CURRENT_DIR}/source.sh
5 |
6 | docker run -e DISPLAY=unix$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -ti --gpus all $VOLUMES $NAME $@
7 |
8 |
--------------------------------------------------------------------------------
/docker/source.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | NAME="bbsplat"
4 | VOLUMES="-v ./..:/home/bbsplat -v /media/dsvitov/DATA:/media/dsvitov/DATA -w /home/bbsplat"
5 |
--------------------------------------------------------------------------------
/docker_colmap/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
4 | source ${CURRENT_DIR}/source.sh
5 |
6 | docker run -e DISPLAY=unix$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -ti --gpus all $VOLUMES $NAME $@
7 |
8 |
--------------------------------------------------------------------------------
/docker_colmap/source.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | NAME="jsantisi/colmap-gpu"
4 | VOLUMES="-v /home/dsvitov/Code/textured-splatting:/home/dsvitov/Code/textured-splatting -v /home/dsvitov/Datasets:/home/dsvitov/Datasets -v /media/dsvitov/DATA1:/media/dsvitov/DATA -w /home/dsvitov/Code/textured-splatting"
5 |
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_bbsplat_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | from scene.gaussian_model import GaussianModel
16 | from utils.sh_utils import eval_sh
17 | from utils.point_utils import depth_to_normal
18 |
19 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, additional_return=True):
20 | """
21 | Render the scene.
22 |
23 | Background tensor (bg_color) must be on GPU!
24 | """
25 |
26 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
27 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
28 | try:
29 | screenspace_points.retain_grad()
30 | except:
31 | pass
32 |
33 | # Set up rasterization configuration
34 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
35 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
36 |
37 | raster_settings = GaussianRasterizationSettings(
38 | image_height=int(viewpoint_camera.image_height),
39 | image_width=int(viewpoint_camera.image_width),
40 | tanfovx=tanfovx,
41 | tanfovy=tanfovy,
42 | bg=bg_color,
43 | scale_modifier=scaling_modifier,
44 | viewmatrix=viewpoint_camera.world_view_transform,
45 | projmatrix=viewpoint_camera.full_proj_transform,
46 | sh_degree=pc.active_sh_degree,
47 | campos=viewpoint_camera.camera_center,
48 | prefiltered=False,
49 | debug=False,
50 | # pipe.debug
51 | )
52 |
53 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
54 |
55 | means3D = pc.get_xyz
56 | means2D = screenspace_points
57 |
58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
59 | # scaling / rotation by the rasterizer.
60 | scales = None
61 | rotations = None
62 | cov3D_precomp = None
63 | if pipe.compute_cov3D_python:
64 | cov3D_precomp = pc.get_covariance(scaling_modifier)
65 | else:
66 | scales = pc.get_scaling
67 | rotations = pc.get_rotation
68 |
69 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
70 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
71 | pipe.convert_SHs_python = False
72 | shs = None
73 | colors_precomp = None
74 | if override_color is None:
75 | if pipe.convert_SHs_python:
76 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
77 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
78 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
79 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
80 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
81 | else:
82 | shs = pc.get_features
83 | else:
84 | colors_precomp = override_color
85 |
86 | try:
87 | means3D.retain_grad()
88 | except:
89 | pass
90 |
91 | texture_alpha = pc.get_texture_alpha
92 | texture_color = pc.get_texture_color
93 |
94 | start_timer = torch.cuda.Event(enable_timing=True)
95 | end_timer = torch.cuda.Event(enable_timing=True)
96 | start_timer.record()
97 |
98 | rendered_image, radii, impact, allmap = rasterizer(
99 | means3D = means3D,
100 | means2D = means2D,
101 | shs = shs,
102 | colors_precomp = colors_precomp,
103 | texture_alpha = texture_alpha,
104 | texture_color = texture_color,
105 | scales = scales,
106 | rotations = rotations,
107 | cov3D_precomp = cov3D_precomp,
108 | )
109 |
110 | end_timer.record()
111 | torch.cuda.synchronize()
112 | start_timer.elapsed_time(end_timer)
113 | fps = 1000 / start_timer.elapsed_time(end_timer)
114 |
115 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
116 | # They will be excluded from value updates used in the splitting criteria.
117 | rets = {"render": rendered_image,
118 | "viewspace_points": means2D,
119 | "visibility_filter" : impact > 0,
120 | "radii": radii,
121 | "impact": impact,
122 | "fps": fps,
123 | }
124 |
125 | if additional_return:
126 | # additional regularizations
127 | render_alpha = allmap[1:2]
128 |
129 | # get normal map
130 | render_normal = allmap[2:5]
131 | render_normal = (render_normal.permute(1,2,0) @ (viewpoint_camera.world_view_transform[:3,:3].T)).permute(2,0,1)
132 |
133 | # get median depth map
134 | render_depth_median = allmap[5:6]
135 | render_depth_median = torch.nan_to_num(render_depth_median, 0, 0)
136 |
137 | # get expected depth map
138 | render_depth_expected = allmap[0:1]
139 | render_depth_expected = (render_depth_expected / render_alpha)
140 | render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0)
141 |
142 | # get depth distortion map
143 | render_dist = allmap[6:7]
144 |
145 | # pseudo surface attributes
146 | # surf depth is either median or expected by setting depth_ratio to 1 or 0
147 | # for bounded scene, use median depth, i.e., depth_ratio = 1;
148 | # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing.
149 | surf_depth = render_depth_expected * (1-pipe.depth_ratio) + (pipe.depth_ratio) * render_depth_median
150 |
151 | # assume the depth points form the 'surface' and generate pseudo surface normals for regularizations.
152 | surf_normal = depth_to_normal(viewpoint_camera, surf_depth)
153 | surf_normal = surf_normal.permute(2,0,1)
154 | # remember to multiply with accum_alpha since render_normal is unnormalized.
155 | surf_normal = surf_normal * (render_alpha).detach()
156 |
157 |
158 | rets.update({
159 | 'rend_alpha': render_alpha,
160 | 'rend_normal': render_normal,
161 | 'rend_dist': render_dist,
162 | 'surf_depth': surf_depth,
163 | 'surf_normal': surf_normal,
164 | })
165 |
166 | return rets
167 |
--------------------------------------------------------------------------------
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 | def init(wish_host, wish_port):
27 | global host, port, listener
28 | host = wish_host
29 | port = wish_port
30 | listener.bind((host, port))
31 | listener.listen()
32 | listener.settimeout(0)
33 |
34 | def try_connect():
35 | global conn, addr, listener
36 | try:
37 | conn, addr = listener.accept()
38 | print(f"\nConnected by {addr}")
39 | conn.settimeout(None)
40 | except Exception as inst:
41 | pass
42 |
43 | def read():
44 | global conn
45 | messageLength = conn.recv(4)
46 | messageLength = int.from_bytes(messageLength, 'little')
47 | message = conn.recv(messageLength)
48 | return json.loads(message.decode("utf-8"))
49 |
50 | def send(message_bytes, verify):
51 | global conn
52 | if message_bytes != None:
53 | conn.sendall(message_bytes)
54 | conn.sendall(len(verify).to_bytes(4, 'little'))
55 | conn.sendall(bytes(verify, 'ascii'))
56 |
57 | def receive():
58 | message = read()
59 |
60 | width = message["resolution_x"]
61 | height = message["resolution_y"]
62 |
63 | if width != 0 and height != 0:
64 | try:
65 | do_training = bool(message["train"])
66 | fovy = message["fov_y"]
67 | fovx = message["fov_x"]
68 | znear = message["z_near"]
69 | zfar = message["z_far"]
70 | do_shs_python = bool(message["shs_python"])
71 | do_rot_scale_python = bool(message["rot_scale_python"])
72 | keep_alive = bool(message["keep_alive"])
73 | scaling_modifier = message["scaling_modifier"]
74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75 | world_view_transform[:,1] = -world_view_transform[:,1]
76 | world_view_transform[:,2] = -world_view_transform[:,2]
77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78 | full_proj_transform[:,1] = -full_proj_transform[:,1]
79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80 | except Exception as e:
81 | print("")
82 | traceback.print_exc()
83 | raise e
84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85 | else:
86 | return None, None, None, None, None, None
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1',
10 | size_average: bool = True):
11 | r"""Function that measures
12 | Learned Perceptual Image Patch Similarity (LPIPS).
13 |
14 | Arguments:
15 | x, y (torch.Tensor): the input tensors to compare.
16 | net_type (str): the network type to compare the features:
17 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
18 | version (str): the version of LPIPS. Default: 0.1.
19 | """
20 | device = x.device
21 | criterion = LPIPS(net_type, version).to(device)
22 | return criterion(x, y, size_average)
23 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as nnf
4 |
5 | from .networks import get_network, LinLayers
6 | from .utils import get_state_dict
7 |
8 |
9 | class LPIPS(nn.Module):
10 | r"""Creates a criterion that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | net_type (str): the network type to compare the features:
15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
16 | version (str): the version of LPIPS. Default: 0.1.
17 | """
18 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
19 |
20 | assert version in ['0.1'], 'v0.1 is only supported now'
21 |
22 | super(LPIPS, self).__init__()
23 |
24 | # pretrained network
25 | self.net = get_network(net_type)
26 |
27 | # linear layers
28 | self.lin = LinLayers(self.net.n_channels_list)
29 | self.lin.load_state_dict(get_state_dict(net_type, version))
30 |
31 | def forward(self, x: torch.Tensor, y: torch.Tensor, size_average: bool = True):
32 | _, _, H, W = x.shape
33 | feat_x, feat_y = self.net(x), self.net(y)
34 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
35 |
36 | if size_average:
37 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
38 | return torch.sum(torch.cat(res, 0), 0, True)
39 | else:
40 | res = [l(d) for d, l in zip(diff, self.lin)]
41 | res = [nnf.interpolate(f, size=(H, W), mode='bicubic', align_corners=False) for f in res]
42 | return torch.sum(torch.cat(res, 0), 0)
43 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | from lpipsPyTorch import lpips
19 | import json
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser
23 |
24 | def readImages(renders_dir, gt_dir):
25 | renders = []
26 | gts = []
27 | image_names = []
28 | for fname in os.listdir(renders_dir):
29 | render = Image.open(renders_dir / fname)
30 | gt = Image.open(gt_dir / fname)
31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
33 | image_names.append(fname)
34 | return renders, gts, image_names
35 |
36 | def evaluate(model_paths):
37 |
38 | full_dict = {}
39 | per_view_dict = {}
40 | full_dict_polytopeonly = {}
41 | per_view_dict_polytopeonly = {}
42 |
43 | for scene_dir in model_paths:
44 | try:
45 | print("Scene:", scene_dir)
46 | full_dict[scene_dir] = {}
47 | per_view_dict[scene_dir] = {}
48 | full_dict_polytopeonly[scene_dir] = {}
49 | per_view_dict_polytopeonly[scene_dir] = {}
50 |
51 | test_dir = Path(scene_dir) / "test"
52 |
53 | for method in os.listdir(test_dir):
54 | print("Method:", method)
55 |
56 | full_dict[scene_dir][method] = {}
57 | per_view_dict[scene_dir][method] = {}
58 | full_dict_polytopeonly[scene_dir][method] = {}
59 | per_view_dict_polytopeonly[scene_dir][method] = {}
60 |
61 | method_dir = test_dir / method
62 | gt_dir = method_dir/ "gt"
63 | renders_dir = method_dir / "renders"
64 | renders, gts, image_names = readImages(renders_dir, gt_dir)
65 |
66 | ssims = []
67 | psnrs = []
68 | lpipss = []
69 |
70 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
71 | ssims.append(ssim(renders[idx], gts[idx]))
72 | psnrs.append(psnr(renders[idx], gts[idx]))
73 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
74 |
75 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
76 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
77 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
78 | print("")
79 |
80 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
81 | "PSNR": torch.tensor(psnrs).mean().item(),
82 | "LPIPS": torch.tensor(lpipss).mean().item()})
83 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
84 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
85 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
86 |
87 | with open(scene_dir + "/results.json", 'w') as fp:
88 | json.dump(full_dict[scene_dir], fp, indent=True)
89 | with open(scene_dir + "/per_view.json", 'w') as fp:
90 | json.dump(per_view_dict[scene_dir], fp, indent=True)
91 | except:
92 | print("Unable to compute metrics for model", scene_dir)
93 |
94 | if __name__ == "__main__":
95 | device = torch.device("cuda:0")
96 | torch.cuda.set_device(device)
97 |
98 | # Set up command line argument parser
99 | parser = ArgumentParser(description="Training script parameters")
100 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
101 | args = parser.parse_args()
102 | evaluate(args.model_paths)
103 |
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import json
12 | import math
13 | import os
14 | from argparse import ArgumentParser
15 |
16 | import cv2
17 | import numpy as np
18 | import nvdiffrast.torch as dr
19 | import open3d as o3d
20 | import torch
21 | import torch.nn.functional as F
22 | import xatlas
23 | from pytorch3d.io import save_obj
24 | from tqdm import tqdm
25 |
26 | from arguments import ModelParams, PipelineParams, get_combined_args
27 | from gaussian_renderer import GaussianModel
28 | from gaussian_renderer import render
29 | from scene import Scene
30 | from utils.general_utils import build_scaling_rotation
31 | from utils.loss_utils import l1_loss
32 | from utils.mesh_utils import GaussianExtractor, post_process_mesh
33 | from utils.render_utils import generate_path, create_videos, save_img_u8
34 | from utils.sh_utils import SH2RGB
35 |
36 |
37 | def unwrap_uvmap(mesh, device="cuda"):
38 | v_np = np.asarray(mesh.vertices) # [N, 3]
39 | f_np = np.asarray(mesh.triangles) # [M, 3]
40 |
41 | print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
42 |
43 | # unwrap uv in contracted space
44 | atlas = xatlas.Atlas()
45 | atlas.add_mesh(v_np, f_np)
46 | chart_options = xatlas.ChartOptions()
47 | chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
48 | pack_options = xatlas.PackOptions()
49 | # pack_options.blockAlign = True
50 | # pack_options.bruteForce = False
51 | atlas.generate(chart_options=chart_options, pack_options=pack_options)
52 | vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
53 |
54 | vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
55 | ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
56 |
57 | print("UV shape:", vt.shape)
58 |
59 | v_torch = torch.from_numpy(v_np.astype(np.float32)).to(device)
60 | f_torch = torch.from_numpy(f_np).to(device)
61 |
62 | return v_torch, f_torch, vt, ft
63 |
64 | def render_mesh(v_torch, f_torch, uv, uv_idx, cudactx, texture, cam):
65 | mvp = cam.full_proj_transform
66 | vertices_clip = torch.matmul(F.pad(v_torch, pad=(0, 1), mode='constant', value=1.0), mvp).float().unsqueeze(0)
67 | rast, _ = dr.rasterize(cudactx, vertices_clip, f_torch, resolution=[cam.image_height, cam.image_width])
68 | texc, _ = dr.interpolate(uv[None, ...], rast, uv_idx)
69 | color = dr.texture(texture[None, ...], texc, filter_mode='linear')[0]
70 | return color
71 |
72 | def train_texture(v_torch, f_torch, uv, uv_idx, cudactx, texture, scene):
73 | optimizer = torch.optim.Adam([texture], lr=0.01)
74 | for epoch in tqdm(range(300)):
75 | for cam in scene.getTrainCameras():
76 | optimizer.zero_grad()
77 | color = render_mesh(v_torch, f_torch, uv, uv_idx, cudactx, F.sigmoid(texture), cam)
78 |
79 | gt = torch.permute(cam.original_image.cuda(), (1, 2, 0))
80 | Ll1 = l1_loss(color, gt)
81 | #ssim_map = ssim(color, gt, size_average=False).mean()
82 | loss = Ll1 # * 0.8 + ssim_map * 0.2
83 | loss.backward()
84 | optimizer.step()
85 |
86 | def billboard_to_plane(xyz, transform, rgb, alpha, texture_size, num_textures_x, vertices, faces, stitched_texture, uv, uv_idx):
87 | vertices_local = torch.tensor([[-1, -1, 0], [1, 1, 0], [1, -1, 0], [-1, 1, 0]], dtype=torch.float32).cuda()
88 | faces_local = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int32).cuda()
89 |
90 | # Scaling + Rotation
91 | vertices_local = vertices_local @ transform.T
92 | # Offset
93 | vertices_local += xyz
94 |
95 | # Add to the "mesh"
96 | faces_local += 4 * len(faces)
97 | faces.append(faces_local)
98 | vertices.append(vertices_local)
99 |
100 | # Add tile to the texture
101 | num = len(vertices) - 1
102 | y = num // num_textures_x
103 | x = num % num_textures_x
104 | h, w = alpha.shape
105 | stitched_texture[:3, y*texture_size: y*texture_size + h, x*texture_size: x*texture_size + w] = rgb
106 | stitched_texture[3:, y*texture_size: y*texture_size + h, x*texture_size: x*texture_size + w] = alpha[None]
107 |
108 | u = x*texture_size / stitched_texture.shape[2]
109 | v = y*texture_size / stitched_texture.shape[1]
110 |     offset_u = w / stitched_texture.shape[2]
111 |     offset_v = h / stitched_texture.shape[1]
112 | uv_local = torch.tensor([[u, v], [u + offset_u, v + offset_v], [u + offset_u, v], [u, v + offset_v]], dtype=torch.float32).cuda()
113 | uv.append(uv_local)
114 | uv_idx.append(faces_local)
115 |
116 | def billboards_to_mesh(gaussians, save_folder):
117 | num_points = len(gaussians.get_xyz)
118 | gaps = 2
119 | texture_size = gaussians.get_texture_alpha.shape[-1] + gaps
120 | num_textures_x = int(math.sqrt(num_points))
121 |     global_texture_size = num_textures_x * texture_size
122 |     global_rgba = torch.zeros([4, global_texture_size + texture_size*2, global_texture_size]).cuda()
123 |
124 | transform = build_scaling_rotation(gaussians.get_scaling, gaussians.get_rotation)
125 |
126 | vertices = []
127 | faces = []
128 | uv = []
129 | uv_idx = []
130 | for i in tqdm(range(num_points)):
131 | #if gaussians.get_scaling[i].min() > 1:
132 | # continue
133 | billboard_to_plane(
134 | gaussians.get_xyz[i], transform[i], gaussians.get_texture_color[i] + SH2RGB(gaussians.get_features_first[i])[0, :, None, None],
135 | gaussians.get_texture_alpha[i], texture_size, num_textures_x,
136 | vertices, faces, global_rgba, uv, uv_idx,
137 | )
138 | vertices = torch.concat(vertices)
139 | faces = torch.concat(faces)
140 | uv = torch.concat(uv)
141 | uv_idx = torch.concat(uv_idx)
142 |
143 | print(vertices.shape, faces.shape)
144 |
145 | global_rgba = torch.permute(global_rgba, (1, 2, 0))
146 | global_rgba = torch.flip(global_rgba, [0])
147 | save_obj(
148 | os.path.join(save_folder, "planes_mesh.obj"),
149 | verts=vertices,
150 | faces=faces,
151 | verts_uvs=uv,
152 | faces_uvs=uv_idx,
153 | texture_map=global_rgba[..., :3],
154 | )
155 | print(global_rgba.shape)
156 | global_rgba = global_rgba.detach().cpu().numpy()
157 | global_rgba[..., :3] = cv2.cvtColor(global_rgba[..., :3], cv2.COLOR_BGR2RGB)
158 | cv2.imwrite(os.path.join(save_folder, "planes_mesh.png"), global_rgba * 255)
159 |
160 | def prune_based_on_visibility(scene, gaussians, pipe, background):
161 | with torch.no_grad():
162 | # Calculate impact
163 | acc_impact = None
164 | for camera in scene.getTrainCameras():
165 | render_pkg = render(camera, gaussians, pipe, background)
166 | impact = render_pkg["impact"]
167 | if acc_impact is None:
168 | acc_impact = impact
169 | else:
170 | acc_impact += impact
171 |
172 | prob = acc_impact / acc_impact.sum()
173 | mask = prob > 1e-6
174 |
175 | mask = mask & (torch.amax(gaussians.get_texture_alpha, dim=(1, 2)) > 0.2)
176 | gaussians.prune_postproc(mask)
177 |
178 | if __name__ == "__main__":
179 | # Set up command line argument parser
180 | parser = ArgumentParser(description="Testing script parameters")
181 | model = ModelParams(parser, sentinel=True)
182 | pipeline = PipelineParams(parser)
183 | parser.add_argument("--iteration", default=-1, type=int)
184 | parser.add_argument("--skip_train", action="store_true")
185 | parser.add_argument("--skip_test", action="store_true")
186 | parser.add_argument("--skip_mesh", action="store_true")
187 | parser.add_argument("--save_planes", action="store_true")
188 | parser.add_argument("--quiet", action="store_true")
189 | parser.add_argument("--render_path", action="store_true")
190 | parser.add_argument("--voxel_size", default=0.004, type=float, help='Mesh: voxel size for TSDF')
191 | parser.add_argument("--depth_trunc", default=3.0, type=float, help='Mesh: Max depth range for TSDF')
192 | parser.add_argument("--sdf_trunc", default=-1.0, type=float, help='Mesh: truncation value for TSDF')
193 | parser.add_argument("--num_cluster", default=1000, type=int, help='Mesh: number of connected clusters to export')
194 | parser.add_argument("--unbounded", action="store_true", help='Mesh: using unbounded mode for meshing')
195 | parser.add_argument("--mesh_res", default=1024, type=int, help='Mesh: resolution for unbounded mesh extraction')
196 | args = get_combined_args(parser)
197 | print("Rendering " + args.model_path)
198 |
199 |
200 | dataset, iteration, pipe = model.extract(args), args.iteration, pipeline.extract(args)
201 | gaussians = GaussianModel(dataset.sh_degree, texture_preproc=True)
202 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
203 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
204 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
205 |
206 | train_dir = os.path.join(args.model_path, 'train', "ours_{}".format(scene.loaded_iter))
207 | test_dir = os.path.join(args.model_path, 'test', "ours_{}".format(scene.loaded_iter))
208 | gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color, additional_return=True)
209 |
210 | speed_data = {"points": len(gaussians.get_xyz)}
211 |
212 | if not args.skip_train:
213 | print("export training images ...")
214 | os.makedirs(train_dir, exist_ok=True)
215 | mean_time, std_time = gaussExtractor.reconstruction(scene.getTrainCameras())
216 | speed_data["train_time"] = mean_time
217 | speed_data["train_time_std"] = std_time
218 | gaussExtractor.export_image(train_dir)
219 |
220 |
221 | if (not args.skip_test) and (len(scene.getTestCameras()) > 0):
222 | print("export rendered testing images ...")
223 | os.makedirs(test_dir, exist_ok=True)
224 | mean_time, std_time = gaussExtractor.reconstruction(scene.getTestCameras())
225 | speed_data["test_time"] = mean_time
226 | speed_data["test_time_std"] = std_time
227 | gaussExtractor.export_image(test_dir)
228 |
229 | with open(os.path.join(args.model_path, "speed.json"), "w") as f:
230 | json.dump(speed_data, f)
231 |
232 | if args.render_path:
233 | print("render videos ...")
234 | traj_dir = os.path.join(args.model_path, 'traj', "ours_{}".format(scene.loaded_iter))
235 | os.makedirs(traj_dir, exist_ok=True)
236 |         n_frames = 480
237 |         cam_traj = generate_path(scene.getTrainCameras(), n_frames=n_frames)
238 | gaussExtractor.reconstruction(cam_traj)
239 | gaussExtractor.export_image(traj_dir, export_gt=False) #, print_fps=True
240 | create_videos(base_dir=traj_dir,
241 | input_dir=traj_dir,
242 | out_name='render_traj',
243 |                     num_frames=n_frames)
244 |
245 | if args.save_planes:
246 | # CONVERT TO SET OF PLANES
247 | prune_based_on_visibility(scene, gaussians, pipe, background)
248 | billboards_to_mesh(gaussians, args.model_path)
249 |
250 | if not args.skip_mesh:
251 | print("export mesh ...")
252 | os.makedirs(train_dir, exist_ok=True)
253 | # set the active_sh to 0 to export only diffuse texture
254 | gaussExtractor.gaussians.active_sh_degree = 0
255 | gaussExtractor.reconstruction(scene.getTrainCameras())
256 | print("ckpt 1 ...")
257 | # extract the mesh and save
258 | if args.unbounded:
259 | name = 'fuse_unbounded.ply'
260 | mesh = gaussExtractor.extract_mesh_unbounded(resolution=args.mesh_res)
261 | else:
262 | name = 'fuse.ply'
263 | #mesh = gaussExtractor.extract_mesh_bounded(voxel_size=args.voxel_size, sdf_trunc=5*args.voxel_size, depth_trunc=args.depth_trunc)
264 | depth_trunc = (gaussExtractor.radius * 2.0) if args.depth_trunc < 0 else args.depth_trunc
265 | voxel_size = (depth_trunc / args.mesh_res) if args.voxel_size < 0 else args.voxel_size
266 | sdf_trunc = 5.0 * voxel_size if args.sdf_trunc < 0 else args.sdf_trunc
267 | mesh = gaussExtractor.extract_mesh_bounded(voxel_size=voxel_size, sdf_trunc=sdf_trunc, depth_trunc=depth_trunc)
268 |
269 | print("ckpt 2 ...")
270 | o3d.io.write_triangle_mesh(os.path.join(train_dir, name), mesh)
271 | print("mesh saved at {}".format(os.path.join(train_dir, name)))
272 | # post-process the mesh and save, saving the largest N clusters
273 | mesh_post = post_process_mesh(mesh, cluster_to_keep=args.num_cluster)
274 | o3d.io.write_triangle_mesh(os.path.join(train_dir, name.replace('.ply', '_post.ply')), mesh_post)
275 | print("mesh post processed saved at {}".format(os.path.join(train_dir, name.replace('.ply', '_post.ply'))))
276 |
277 | # TEXTURE EXTRACTION
278 | device = "cuda"
279 | # Unwrap the uv-map for the mesh
280 | v_cuda, f_cuda, uv, uv_idx = unwrap_uvmap(mesh, device)
281 |
282 | texture = 0.5 + torch.randn((1024, 1024, 3), dtype=torch.float32, device=device) * 0.001
283 | texture = torch.nn.Parameter(texture, requires_grad=True)
284 |
285 | cudactx = dr.RasterizeCudaContext()
286 |
287 | # Train texture from input images
288 | train_texture(v_cuda, f_cuda, uv, uv_idx, cudactx, texture, scene)
289 | texture = F.sigmoid(texture)
290 |
291 | # Render textured mesh to the folder
292 | mesh_path = os.path.join(train_dir, "mesh")
293 | os.makedirs(mesh_path, exist_ok=True)
294 | for idx, cam in enumerate(scene.getTrainCameras()):
295 | mvp = cam.full_proj_transform
296 | color = render_mesh(v_cuda, f_cuda, uv, uv_idx, cudactx, texture, cam)
297 | color = torch.permute(color, (2, 0, 1))
298 | save_img_u8(color, os.path.join(mesh_path, '{0:05d}'.format(idx) + ".png"))
299 |
300 | save_obj(
301 | os.path.join(args.model_path, "textured_mesh.obj"),
302 | verts=v_cuda,
303 | faces=f_cuda,
304 | verts_uvs=uv,
305 | faces_uvs=uv_idx,
306 | texture_map=torch.flip(texture, [0]),
307 | )
308 |
--------------------------------------------------------------------------------
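
A hedged sketch of how the repository's own scripts drive render.py (cf. scripts/dtu_eval.py further
below); the source and output paths are placeholders:

    import os
    cmd = ("python render.py --iteration 32000 -s /data/DTU/selected/scan24_COLMAP "
           "-m ./eval/dtu/scan24 --quiet --depth_ratio 1.0 --num_cluster 1 "
           "--voxel_size 0.004 --sdf_trunc 0.016 --depth_trunc 3.0")
    os.system(cmd)  # exports renders, timing (speed.json), and the TSDF mesh with its trained texture
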
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 | import numpy as np
21 | import cv2
22 |
23 | class Scene:
24 |
25 | gaussians : GaussianModel
26 |
27 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0],
28 | add_sky_box=False, max_read_points=60_000, sphere_point=10_000):
29 |         """
30 | :param path: Path to colmap scene main folder.
31 | """
32 | self.model_path = args.model_path
33 | self.loaded_iter = None
34 | self.gaussians = gaussians
35 |
36 | if load_iteration:
37 | if load_iteration == -1:
38 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
39 | else:
40 | self.loaded_iter = load_iteration
41 | print("Loading trained model at iteration {}".format(self.loaded_iter))
42 |
43 | self.train_cameras = {}
44 | self.test_cameras = {}
45 |
46 | if os.path.exists(os.path.join(args.source_path, "sparse")):
47 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, max_points=max_read_points)
48 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
49 | print("Found transforms_train.json file, assuming Blender data set!")
50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, max_points=max_read_points)
51 | elif os.path.exists(os.path.join(args.source_path, "inputs/sfm_scene.json")):
52 | print("Found sfm_scene.json file, assuming NeILF data set!")
53 | scene_info = sceneLoadTypeCallbacks["NeILF"](args.source_path, args.white_background, args.eval)
54 | else:
55 | assert False, "Could not recognize scene type!"
56 |
57 | if not self.loaded_iter:
58 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
59 | dest_file.write(src_file.read())
60 | json_cams = []
61 | camlist = []
62 | if scene_info.test_cameras:
63 | camlist.extend(scene_info.test_cameras)
64 | if scene_info.train_cameras:
65 | camlist.extend(scene_info.train_cameras)
66 | for id, cam in enumerate(camlist):
67 | json_cams.append(camera_to_JSON(id, cam))
68 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
69 | json.dump(json_cams, file)
70 |
71 | if shuffle:
72 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
73 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
74 |
75 | self.cameras_extent = scene_info.nerf_normalization["radius"]
76 |
77 | for resolution_scale in resolution_scales:
78 | print("Loading Training Cameras")
79 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
80 | print("Loading Test Cameras")
81 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
82 |
83 | if self.loaded_iter:
84 | folder_path = os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter))
85 | self.gaussians.load_ply(os.path.join(folder_path, "point_cloud.ply"))
86 | self.gaussians.load_texture(folder_path)
87 | else:
88 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, add_sky_box=add_sky_box, sphere_point=sphere_point)
89 |
90 | def save(self, iteration):
91 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
92 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
93 | self.gaussians.save_texture(point_cloud_path)
94 |
95 | def getTrainCameras(self, scale=1.0):
96 | return self.train_cameras[scale]
97 |
98 | def getTestCameras(self, scale=1.0):
99 | return self.test_cameras[scale]
100 |
--------------------------------------------------------------------------------
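
A minimal sketch of the load path used by render.py above, assuming a trained model folder is passed
via -m/--model_path; `dataset` comes from ModelParams.extract exactly as in render.py:

    from argparse import ArgumentParser
    from arguments import ModelParams, get_combined_args
    from gaussian_renderer import GaussianModel
    from scene import Scene

    parser = ArgumentParser()
    model = ModelParams(parser, sentinel=True)
    args = get_combined_args(parser)             # requires -m/--model_path pointing at a trained model
    dataset = model.extract(args)
    gaussians = GaussianModel(dataset.sh_degree, texture_preproc=True)
    scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)  # loads point_cloud.ply + textures
    print(len(scene.getTrainCameras()), "train /", len(scene.getTestCameras()), "test cameras")
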
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 |
17 | class Camera(nn.Module):
18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
19 | image_name, uid,
20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
21 | ):
22 | super(Camera, self).__init__()
23 |
24 | self.uid = uid
25 | self.colmap_id = colmap_id
26 | self.R = R
27 | self.T = T
28 | self.FoVx = FoVx
29 | self.FoVy = FoVy
30 | self.image_name = image_name
31 |
32 | try:
33 | self.data_device = torch.device(data_device)
34 | except Exception as e:
35 | print(e)
36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
37 | self.data_device = torch.device("cuda")
38 |
39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
40 | self.image_width = self.original_image.shape[2]
41 | self.image_height = self.original_image.shape[1]
42 |
43 | if gt_alpha_mask is not None:
44 | # self.original_image *= gt_alpha_mask.to(self.data_device)
45 | self.gt_alpha_mask = gt_alpha_mask.to(self.data_device)
46 | else:
47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
48 | self.gt_alpha_mask = None
49 |
50 | self.zfar = 100.0
51 | self.znear = 0.01
52 |
53 | self.trans = trans
54 | self.scale = scale
55 |
56 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
57 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
58 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
59 | self.camera_center = self.world_view_transform.inverse()[3, :3]
60 |
61 | def update_proj_matrix(self):
62 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
63 | self.camera_center = self.world_view_transform.inverse()[3, :3]
64 |
65 | class MiniCam:
66 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
67 | self.image_width = width
68 | self.image_height = height
69 | self.FoVy = fovy
70 | self.FoVx = fovx
71 | self.znear = znear
72 | self.zfar = zfar
73 | self.world_view_transform = world_view_transform
74 | self.full_proj_transform = full_proj_transform
75 | view_inv = torch.inverse(self.world_view_transform)
76 | self.camera_center = view_inv[3][:3]
77 |
78 |
--------------------------------------------------------------------------------
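
A small sketch of building a MiniCam by hand with the same matrix conventions as Camera above (both
matrices are stored transposed, so points project as row vectors); all values are placeholders:

    import math
    import torch
    from scene.cameras import MiniCam
    from utils.graphics_utils import getProjectionMatrix

    world_view = torch.eye(4, device="cuda")     # identity pose
    proj = getProjectionMatrix(znear=0.01, zfar=100.0,
                               fovX=math.radians(60), fovY=math.radians(45)).transpose(0, 1).cuda()
    full_proj = world_view.unsqueeze(0).bmm(proj.unsqueeze(0)).squeeze(0)
    cam = MiniCam(800, 600, math.radians(45), math.radians(60), 0.01, 100.0, world_view, full_proj)
    print(cam.camera_center)                     # view_inv[3, :3]; the origin for the identity pose
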
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54 |
55 | def rotmat2qvec(R):
56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57 | K = np.array([
58 | [Rxx - Ryy - Rzz, 0, 0, 0],
59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62 | eigvals, eigvecs = np.linalg.eigh(K)
63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64 | if qvec[0] < 0:
65 | qvec *= -1
66 | return qvec
67 |
68 | class Image(BaseImage):
69 | def qvec2rotmat(self):
70 | return qvec2rotmat(self.qvec)
71 |
72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73 | """Read and unpack the next bytes from a binary file.
74 | :param fid:
75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77 | :param endian_character: Any of {@, =, <, >, !}
78 | :return: Tuple of read and unpacked values.
79 | """
80 | data = fid.read(num_bytes)
81 | return struct.unpack(endian_character + format_char_sequence, data)
82 |
83 | def read_points3D_text(path):
84 | """
85 | see: src/base/reconstruction.cc
86 | void Reconstruction::ReadPoints3DText(const std::string& path)
87 | void Reconstruction::WritePoints3DText(const std::string& path)
88 | """
89 | xyzs = None
90 | rgbs = None
91 | errors = None
92 | num_points = 0
93 | with open(path, "r") as fid:
94 | while True:
95 | line = fid.readline()
96 | if not line:
97 | break
98 | line = line.strip()
99 | if len(line) > 0 and line[0] != "#":
100 | num_points += 1
101 |
102 |
103 | xyzs = np.empty((num_points, 3))
104 | rgbs = np.empty((num_points, 3))
105 | errors = np.empty((num_points, 1))
106 | count = 0
107 | with open(path, "r") as fid:
108 | while True:
109 | line = fid.readline()
110 | if not line:
111 | break
112 | line = line.strip()
113 | if len(line) > 0 and line[0] != "#":
114 | elems = line.split()
115 | xyz = np.array(tuple(map(float, elems[1:4])))
116 | rgb = np.array(tuple(map(int, elems[4:7])))
117 | error = np.array(float(elems[7]))
118 | xyzs[count] = xyz
119 | rgbs[count] = rgb
120 | errors[count] = error
121 | count += 1
122 |
123 | return xyzs, rgbs, errors
124 |
125 | def read_points3D_binary(path_to_model_file):
126 | """
127 | see: src/base/reconstruction.cc
128 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
129 | void Reconstruction::WritePoints3DBinary(const std::string& path)
130 | """
131 |
132 |
133 | with open(path_to_model_file, "rb") as fid:
134 | num_points = read_next_bytes(fid, 8, "Q")[0]
135 |
136 | xyzs = np.empty((num_points, 3))
137 | rgbs = np.empty((num_points, 3))
138 | errors = np.empty((num_points, 1))
139 |
140 | for p_id in range(num_points):
141 | binary_point_line_properties = read_next_bytes(
142 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
143 | xyz = np.array(binary_point_line_properties[1:4])
144 | rgb = np.array(binary_point_line_properties[4:7])
145 | error = np.array(binary_point_line_properties[7])
146 | track_length = read_next_bytes(
147 | fid, num_bytes=8, format_char_sequence="Q")[0]
148 | track_elems = read_next_bytes(
149 | fid, num_bytes=8*track_length,
150 | format_char_sequence="ii"*track_length)
151 | xyzs[p_id] = xyz
152 | rgbs[p_id] = rgb
153 | errors[p_id] = error
154 | return xyzs, rgbs, errors
155 |
156 | def read_intrinsics_text(path):
157 | """
158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159 | """
160 | cameras = {}
161 | with open(path, "r") as fid:
162 | while True:
163 | line = fid.readline()
164 | if not line:
165 | break
166 | line = line.strip()
167 | if len(line) > 0 and line[0] != "#":
168 | elems = line.split()
169 | camera_id = int(elems[0])
170 | model = elems[1]
171 |                 assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
172 | width = int(elems[2])
173 | height = int(elems[3])
174 | params = np.array(tuple(map(float, elems[4:])))
175 | cameras[camera_id] = Camera(id=camera_id, model=model,
176 | width=width, height=height,
177 | params=params)
178 | return cameras
179 |
180 | def read_extrinsics_binary(path_to_model_file):
181 | """
182 | see: src/base/reconstruction.cc
183 | void Reconstruction::ReadImagesBinary(const std::string& path)
184 | void Reconstruction::WriteImagesBinary(const std::string& path)
185 | """
186 | images = {}
187 | with open(path_to_model_file, "rb") as fid:
188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189 | for _ in range(num_reg_images):
190 | binary_image_properties = read_next_bytes(
191 | fid, num_bytes=64, format_char_sequence="idddddddi")
192 | image_id = binary_image_properties[0]
193 | qvec = np.array(binary_image_properties[1:5])
194 | tvec = np.array(binary_image_properties[5:8])
195 | camera_id = binary_image_properties[8]
196 | image_name = ""
197 | current_char = read_next_bytes(fid, 1, "c")[0]
198 | while current_char != b"\x00": # look for the ASCII 0 entry
199 | image_name += current_char.decode("utf-8")
200 | current_char = read_next_bytes(fid, 1, "c")[0]
201 | num_points2D = read_next_bytes(fid, num_bytes=8,
202 | format_char_sequence="Q")[0]
203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204 | format_char_sequence="ddq"*num_points2D)
205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206 | tuple(map(float, x_y_id_s[1::3]))])
207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208 | images[image_id] = Image(
209 | id=image_id, qvec=qvec, tvec=tvec,
210 | camera_id=camera_id, name=image_name,
211 | xys=xys, point3D_ids=point3D_ids)
212 | return images
213 |
214 |
215 | def read_intrinsics_binary(path_to_model_file):
216 | """
217 | see: src/base/reconstruction.cc
218 | void Reconstruction::WriteCamerasBinary(const std::string& path)
219 | void Reconstruction::ReadCamerasBinary(const std::string& path)
220 | """
221 | cameras = {}
222 | with open(path_to_model_file, "rb") as fid:
223 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
224 | for _ in range(num_cameras):
225 | camera_properties = read_next_bytes(
226 | fid, num_bytes=24, format_char_sequence="iiQQ")
227 | camera_id = camera_properties[0]
228 | model_id = camera_properties[1]
229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230 | width = camera_properties[2]
231 | height = camera_properties[3]
232 | num_params = CAMERA_MODEL_IDS[model_id].num_params
233 | params = read_next_bytes(fid, num_bytes=8*num_params,
234 | format_char_sequence="d"*num_params)
235 | cameras[camera_id] = Camera(id=camera_id,
236 | model=model_name,
237 | width=width,
238 | height=height,
239 | params=np.array(params))
240 | assert len(cameras) == num_cameras
241 | return cameras
242 |
243 |
244 | def read_extrinsics_text(path):
245 | """
246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247 | """
248 | images = {}
249 | with open(path, "r") as fid:
250 | while True:
251 | line = fid.readline()
252 | if not line:
253 | break
254 | line = line.strip()
255 | if len(line) > 0 and line[0] != "#":
256 | elems = line.split()
257 | image_id = int(elems[0])
258 | qvec = np.array(tuple(map(float, elems[1:5])))
259 | tvec = np.array(tuple(map(float, elems[5:8])))
260 | camera_id = int(elems[8])
261 | image_name = elems[9]
262 | elems = fid.readline().split()
263 | xys = np.column_stack([tuple(map(float, elems[0::3])),
264 | tuple(map(float, elems[1::3]))])
265 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
266 | images[image_id] = Image(
267 | id=image_id, qvec=qvec, tvec=tvec,
268 | camera_id=camera_id, name=image_name,
269 | xys=xys, point3D_ids=point3D_ids)
270 | return images
271 |
272 |
273 | def read_colmap_bin_array(path):
274 | """
275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276 |
277 | :param path: path to the colmap binary file.
278 |     :return: nd array with the floating point values in the file.
279 | """
280 | with open(path, "rb") as fid:
281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282 | usecols=(0, 1, 2), dtype=int)
283 | fid.seek(0)
284 | num_delimiter = 0
285 | byte = fid.read(1)
286 | while True:
287 | if byte == b"&":
288 | num_delimiter += 1
289 | if num_delimiter >= 3:
290 | break
291 | byte = fid.read(1)
292 | array = np.fromfile(fid, np.float32)
293 | array = array.reshape((width, height, channels), order="F")
294 | return np.transpose(array, (1, 0, 2)).squeeze()
295 |
--------------------------------------------------------------------------------
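
A quick self-check sketch for the quaternion helpers above: the unit quaternion [1, 0, 0, 0] should map
to the identity rotation and back.

    import numpy as np
    from scene.colmap_loader import qvec2rotmat, rotmat2qvec

    R = qvec2rotmat(np.array([1.0, 0.0, 0.0, 0.0]))
    assert np.allclose(R, np.eye(3))
    assert np.allclose(rotmat2qvec(R), [1.0, 0.0, 0.0, 0.0])
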
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import json
13 | import os
14 | import sys
15 | from pathlib import Path
16 | from typing import NamedTuple
17 |
18 | import numpy as np
19 | import torch
20 | from PIL import Image
21 | from plyfile import PlyData, PlyElement
22 | from pytorch3d.ops import sample_farthest_points
23 | import imageio.v2 as imageio
24 | import glob
25 | import re
26 |
27 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
28 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
29 | from scene.gaussian_model import BasicPointCloud
30 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
31 | from utils.sh_utils import SH2RGB
32 |
33 |
34 | class CameraInfo(NamedTuple):
35 | uid: int
36 | R: np.array
37 | T: np.array
38 | FovY: np.array
39 | FovX: np.array
40 | image: np.array
41 | image_path: str
42 | image_name: str
43 | width: int
44 | height: int
45 | image_id: int = None
46 | normal: Image.Image = None
47 | alpha: Image.Image = None
48 | depth: np.array = None
49 |
50 | class SceneInfo(NamedTuple):
51 | point_cloud: BasicPointCloud
52 | train_cameras: list
53 | test_cameras: list
54 | nerf_normalization: dict
55 | ply_path: str
56 |
57 | def getNerfppNorm(cam_info):
58 | def get_center_and_diag(cam_centers):
59 | cam_centers = np.hstack(cam_centers)
60 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
61 | center = avg_cam_center
62 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
63 | diagonal = np.max(dist)
64 | return center.flatten(), diagonal
65 |
66 | cam_centers = []
67 |
68 | for cam in cam_info:
69 | W2C = getWorld2View2(cam.R, cam.T)
70 | C2W = np.linalg.inv(W2C)
71 | cam_centers.append(C2W[:3, 3:4])
72 |
73 | center, diagonal = get_center_and_diag(cam_centers)
74 | radius = diagonal * 1.1
75 |
76 | translate = -center
77 |
78 | return {"translate": translate, "radius": radius}
79 |
80 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
81 | cam_infos = []
82 | for idx, key in enumerate(cam_extrinsics):
83 | sys.stdout.write('\r')
84 |         # report reading progress on a single line
85 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
86 | sys.stdout.flush()
87 |
88 | extr = cam_extrinsics[key]
89 | intr = cam_intrinsics[extr.camera_id]
90 | height = intr.height
91 | width = intr.width
92 |
93 | uid = intr.id
94 | R = np.transpose(qvec2rotmat(extr.qvec))
95 | T = np.array(extr.tvec)
96 |
97 | if intr.model=="SIMPLE_PINHOLE":
98 | focal_length_x = intr.params[0]
99 | FovY = focal2fov(focal_length_x, height)
100 | FovX = focal2fov(focal_length_x, width)
101 | elif intr.model=="PINHOLE":
102 | focal_length_x = intr.params[0]
103 | focal_length_y = intr.params[1]
104 | FovY = focal2fov(focal_length_y, height)
105 | FovX = focal2fov(focal_length_x, width)
106 | else:
107 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
108 |
109 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
110 | image_name = os.path.basename(image_path).split(".")[0]
111 | image = Image.open(image_path)
112 |
113 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
114 | image_path=image_path, image_name=image_name, width=width, height=height)
115 | cam_infos.append(cam_info)
116 | sys.stdout.write('\n')
117 | return cam_infos
118 |
119 | def fetchPly(path, max_points=60_000):
120 | plydata = PlyData.read(path)
121 | vertices = plydata['vertex']
122 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
123 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
124 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
125 |
126 | if len(positions) >= max_points:
127 | #indices = np.random.randint(0, len(positions), size=max_points
128 | _, indices = sample_farthest_points(torch.tensor(positions[None]), K=max_points)
129 | indices = indices[0]
130 | positions = positions[indices]
131 | colors = colors[indices]
132 | normals = normals[indices]
133 |
134 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
135 |
136 | def storePly(path, xyz, rgb, normals=None):
137 | # Define the dtype for the structured array
138 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
139 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
140 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
141 |
142 | if normals is None:
143 | normals = np.zeros_like(xyz)
144 |
145 | elements = np.empty(xyz.shape[0], dtype=dtype)
146 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
147 | elements[:] = list(map(tuple, attributes))
148 |
149 | # Create the PlyData object and write to file
150 | vertex_element = PlyElement.describe(elements, 'vertex')
151 | ply_data = PlyData([vertex_element])
152 | ply_data.write(path)
153 |
154 | def readColmapSceneInfo(path, images, eval, llffhold=8, max_points=60_000):
155 | try:
156 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
157 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
158 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
159 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
160 | except:
161 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
162 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
163 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
164 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
165 |
166 | reading_dir = "images" if images == None else images
167 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
168 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
169 |
170 | if eval:
171 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
172 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
173 | else:
174 | train_cam_infos = cam_infos
175 | test_cam_infos = []
176 |
177 | nerf_normalization = getNerfppNorm(train_cam_infos)
178 |
179 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
180 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
181 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
182 | if not os.path.exists(ply_path):
183 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
184 | try:
185 | xyz, rgb, _ = read_points3D_binary(bin_path)
186 | except:
187 | xyz, rgb, _ = read_points3D_text(txt_path)
188 | storePly(ply_path, xyz, rgb)
189 | try:
190 | pcd = fetchPly(ply_path, max_points)
191 | except:
192 | pcd = None
193 |
194 | scene_info = SceneInfo(point_cloud=pcd,
195 | train_cameras=train_cam_infos,
196 | test_cameras=test_cam_infos,
197 | nerf_normalization=nerf_normalization,
198 | ply_path=ply_path)
199 | return scene_info
200 |
201 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
202 | cam_infos = []
203 |
204 | with open(os.path.join(path, transformsfile)) as json_file:
205 | contents = json.load(json_file)
206 | fovx = contents["camera_angle_x"]
207 |
208 | frames = contents["frames"]
209 | for idx, frame in enumerate(frames):
210 | cam_name = os.path.join(path, frame["file_path"] + extension)
211 |
212 | # NeRF 'transform_matrix' is a camera-to-world transform
213 | c2w = np.array(frame["transform_matrix"])
214 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
215 | c2w[:3, 1:3] *= -1
216 |
217 | # get the world-to-camera transform and set R, T
218 | w2c = np.linalg.inv(c2w)
219 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
220 | T = w2c[:3, 3]
221 |
222 | image_path = os.path.join(path, cam_name)
223 | image_name = Path(cam_name).stem
224 | image = Image.open(image_path)
225 |
226 | im_data = np.array(image.convert("RGBA"))
227 |
228 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
229 |
230 | norm_data = im_data / 255.0
231 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
232 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
233 |
234 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
235 | FovY = fovy
236 | FovX = fovx
237 |
238 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
239 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
240 |
241 | return cam_infos
242 |
243 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png", max_points=60_000):
244 | print("Reading Training Transforms")
245 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
246 | print("Reading Test Transforms")
247 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
248 |
249 | if not eval:
250 | train_cam_infos.extend(test_cam_infos)
251 | test_cam_infos = []
252 |
253 | nerf_normalization = getNerfppNorm(train_cam_infos)
254 |
255 | ply_path = os.path.join(path, "points3d.ply")
256 | if not os.path.exists(ply_path):
257 | # Since this data set has no colmap data, we start with random points
258 | num_pts = 100_000
259 | print(f"Generating random point cloud ({num_pts})...")
260 |
261 | # We create random points inside the bounds of the synthetic Blender scenes
262 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
263 | shs = np.random.random((num_pts, 3)) / 255.0
264 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
265 |
266 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
267 | try:
268 | pcd = fetchPly(ply_path, max_points)
269 | except:
270 | pcd = None
271 |
272 | scene_info = SceneInfo(point_cloud=pcd,
273 | train_cameras=train_cam_infos,
274 | test_cameras=test_cam_infos,
275 | nerf_normalization=nerf_normalization,
276 | ply_path=ply_path)
277 | return scene_info
278 |
279 |
280 | def load_img(path):
281 | if not "." in os.path.basename(path):
282 | files = glob.glob(path + '.*')
283 | assert len(files) > 0, "Tried to find image file for: %s, but found 0 files" % (path)
284 | path = files[0]
285 | if path.endswith(".exr"):
286 | assert False
287 | if pyexr is not None:
288 | exr_file = pyexr.open(path)
289 | # print(exr_file.channels)
290 | all_data = exr_file.get()
291 | img = all_data[..., 0:3]
292 | if "A" in exr_file.channels:
293 | mask = np.clip(all_data[..., 3:4], 0, 1)
294 | img = img * mask
295 | else:
296 | img = imageio.imread(path)
297 | import pdb;
298 | pdb.set_trace()
299 | img = np.nan_to_num(img)
300 | hdr = True
301 | else: # LDR image
302 | img = imageio.imread(path)
303 | img = img / 255
304 | # img[..., 0:3] = srgb_to_rgb_np(img[..., 0:3])
305 | hdr = False
306 | return img, hdr
307 |
308 |
309 | def load_pfm(file: str):
310 | color = None
311 | width = None
312 | height = None
313 | scale = None
314 | endian = None
315 | with open(file, 'rb') as f:
316 | header = f.readline().rstrip()
317 | if header == b'PF':
318 | color = True
319 | elif header == b'Pf':
320 | color = False
321 | else:
322 | raise Exception('Not a PFM file.')
323 | dim_match = re.match(br'^(\d+)\s(\d+)\s$', f.readline())
324 | if dim_match:
325 | width, height = map(int, dim_match.groups())
326 | else:
327 | raise Exception('Malformed PFM header.')
328 | scale = float(f.readline().rstrip())
329 | if scale < 0: # little-endian
330 | endian = '<'
331 | scale = -scale
332 | else:
333 | endian = '>' # big-endian
334 | data = np.fromfile(f, endian + 'f')
335 | shape = (height, width, 3) if color else (height, width)
336 | data = np.reshape(data, shape)
337 | data = data[::-1, ...] # cv2.flip(data, 0)
338 |
339 | return np.ascontiguousarray(data)
340 |
341 |
342 | def load_depth(tiff_path):
343 | return imageio.imread(tiff_path)
344 |
345 |
346 | def load_mask(mask_file):
347 | mask = imageio.imread(mask_file, mode='L')
348 | mask = mask.astype(np.float32)
349 | mask[mask > 0.5] = 1.0
350 |
351 | return mask
352 |
353 |
354 | def loadCamsFromScene(path, valid_list, background, debug):
355 | with open(f'{path}/sfm_scene.json') as f:
356 | sfm_scene = json.load(f)
357 |
358 | # load bbox transform
359 | bbox_transform = np.array(sfm_scene['bbox']['transform']).reshape(4, 4)
360 | bbox_transform = bbox_transform.copy()
361 | bbox_transform[[0, 1, 2], [0, 1, 2]] = bbox_transform[[0, 1, 2], [0, 1, 2]].max() / 2
362 | bbox_inv = np.linalg.inv(bbox_transform)
363 |
364 | # meta info
365 | image_list = sfm_scene['image_path']['file_paths']
366 |
367 | # camera parameters
368 | train_cam_infos = []
369 | test_cam_infos = []
370 | camera_info_list = sfm_scene['camera_track_map']['images']
371 | for i, (index, camera_info) in enumerate(camera_info_list.items()):
372 | if debug and i >= 5: break
373 | if camera_info['flg'] == 2:
374 | intrinsic = np.zeros((4, 4))
375 | intrinsic[0, 0] = camera_info['camera']['intrinsic']['focal'][0]
376 | intrinsic[1, 1] = camera_info['camera']['intrinsic']['focal'][1]
377 | intrinsic[0, 2] = camera_info['camera']['intrinsic']['ppt'][0]
378 | intrinsic[1, 2] = camera_info['camera']['intrinsic']['ppt'][1]
379 | intrinsic[2, 2] = intrinsic[3, 3] = 1
380 |
381 | extrinsic = np.array(camera_info['camera']['extrinsic']).reshape(4, 4)
382 | c2w = np.linalg.inv(extrinsic)
383 | c2w[:3, 3] = (c2w[:4, 3] @ bbox_inv.T)[:3]
384 | extrinsic = np.linalg.inv(c2w)
385 |
386 | R = np.transpose(extrinsic[:3, :3])
387 | T = extrinsic[:3, 3]
388 |
389 | focal_length_x = camera_info['camera']['intrinsic']['focal'][0]
390 | focal_length_y = camera_info['camera']['intrinsic']['focal'][1]
391 | ppx = camera_info['camera']['intrinsic']['ppt'][0]
392 | ppy = camera_info['camera']['intrinsic']['ppt'][1]
393 |
394 | image_path = os.path.join(path, image_list[index])
395 | image_name = Path(image_path).stem
396 |
397 | image, is_hdr = load_img(image_path)
398 |
399 | depth_path = os.path.join(path + "/depths/", os.path.basename(
400 | image_list[index]).replace(os.path.splitext(image_list[index])[-1], ".tiff"))
401 |
402 | if os.path.exists(depth_path):
403 | depth = load_depth(depth_path)
404 | depth *= bbox_inv[0, 0]
405 | else:
406 | print("No depth map for test view.")
407 | depth = None
408 |
409 | normal_path = os.path.join(path + "/normals/", os.path.basename(
410 | image_list[index]).replace(os.path.splitext(image_list[index])[-1], ".pfm"))
411 | if os.path.exists(normal_path):
412 | normal = load_pfm(normal_path)
413 | else:
414 | print("No normal map for test view.")
415 | normal = None
416 |
417 | mask_path = os.path.join(path + "/pmasks/", os.path.basename(
418 | image_list[index]).replace(os.path.splitext(image_list[index])[-1], ".png"))
419 | if os.path.exists(mask_path):
420 | img_mask = (imageio.imread(mask_path, pilmode='L') > 0.1).astype(np.float32)
421 | # if pmask is available, mask the image for PSNR
422 | image *= img_mask[..., np.newaxis]
423 | else:
424 | img_mask = np.ones_like(image[:, :, 0])
425 |
426 | fovx = focal2fov(focal_length_x, image.shape[1])
427 | fovy = focal2fov(focal_length_y, image.shape[0])
428 | if int(index) in valid_list:
429 | image *= img_mask[..., np.newaxis]
430 | image = Image.fromarray(np.array(image * 255.0, dtype=np.byte), "RGB")
431 | alpha = Image.fromarray(np.array(np.tile(img_mask[..., np.newaxis], (1, 1, 3)) * 255.0, dtype=np.byte),
432 | "RGB")
433 | if normal is not None:
434 | normal = Image.fromarray(np.array((normal + 1) / 2 * 255.0, dtype=np.byte), "RGB")
435 | test_cam_infos.append(CameraInfo(
436 | uid=index, R=R, T=T, FovY=fovy, FovX=fovx, image=image,
437 | image_path=image_path, image_name=image_name,
438 | alpha=alpha, normal=normal, depth=depth,
439 | width=image.size[0], height=image.size[1]))
440 | else:
441 | image *= img_mask[..., np.newaxis]
442 | depth *= img_mask
443 | normal *= img_mask[..., np.newaxis]
444 | image = Image.fromarray(np.array(image * 255.0, dtype=np.byte), "RGB")
445 | alpha = Image.fromarray(np.array(np.tile(img_mask[..., np.newaxis], (1, 1, 3)) * 255.0, dtype=np.byte),
446 | "RGB")
447 | if normal is not None:
448 | normal = Image.fromarray(np.array((normal + 1) / 2 * 255.0, dtype=np.byte), "RGB")
449 | train_cam_infos.append(CameraInfo(
450 | uid=index, R=R, T=T, FovY=fovy, FovX=fovx, image=image,
451 | image_path=image_path, image_name=image_name,
452 | alpha=alpha, normal=normal, depth=depth,
453 | width=image.size[0], height=image.size[1]))
454 |
455 | return train_cam_infos, test_cam_infos, bbox_transform
456 |
457 |
458 | def readNeILFInfo(path, background, eval, log=None, debug=False):
459 | validation_indexes = []
460 | if eval:
461 | if "dtu" in path.lower():
462 | validation_indexes = [6, 13, 30, 35] # same as neuTex
463 | else:
464 | raise NotImplementedError
465 |
466 | train_cam_infos, test_cam_infos, bbx_trans = loadCamsFromScene(
467 | f'{path}/inputs', validation_indexes, background, debug)
468 |
469 | nerf_normalization = getNerfppNorm(train_cam_infos)
470 |
471 | ply_path = f'{path}/inputs/model/sparse_bbx_scale.ply'
472 | if not os.path.exists(ply_path):
473 | org_ply_path = f'{path}/inputs/model/sparse.ply'
474 |
475 | # scale sparse.ply
476 | pcd = fetchPly(org_ply_path)
477 | inv_scale_mat = np.linalg.inv(bbx_trans) # [4, 4]
478 | points = pcd.points
479 | xyz = (np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) @ inv_scale_mat.T)[:, :3]
480 | normals = pcd.normals
481 | colors = pcd.colors
482 |
483 | storePly(ply_path, xyz, colors * 255, normals)
484 |
485 | try:
486 | pcd = fetchPly(ply_path)
487 | except:
488 | pcd = None
489 |
490 | scene_info = SceneInfo(point_cloud=pcd,
491 | train_cameras=train_cam_infos,
492 | test_cameras=test_cam_infos,
493 | nerf_normalization=nerf_normalization,
494 | ply_path=ply_path)
495 | return scene_info
496 |
497 | sceneLoadTypeCallbacks = {
498 | "Colmap": readColmapSceneInfo,
499 | "Blender" : readNerfSyntheticInfo,
500 | "NeILF": readNeILFInfo,
501 | }
--------------------------------------------------------------------------------
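
A hedged sketch of calling one of the loaders registered above directly, the same way Scene.__init__
dispatches on the folder layout; the COLMAP path is a placeholder:

    from scene.dataset_readers import sceneLoadTypeCallbacks

    info = sceneLoadTypeCallbacks["Colmap"]("/data/Truck_COLMAP", "images", True, max_points=60_000)
    print(len(info.train_cameras), "train /", len(info.test_cameras), "test cameras")
    print("initial points:", info.point_cloud.points.shape)  # farthest-point sampled down to max_points
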
/scripts/average_error.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import defaultdict
4 | from glob import glob
5 | from statistics import mean
6 | from argparse import ArgumentParser
7 |
8 | if __name__ == '__main__':
9 |     parser = ArgumentParser(description='Script to average metric values over the dataset')
10 | parser.add_argument('-f', '--folder', help='Path to the target folder', type=str, required=True)
11 | args = parser.parse_args()
12 |
13 | #======================================================
14 | results = glob(os.path.join(args.folder, "*/results.json"))
15 |
16 | metrics_statistic = defaultdict(list)
17 | for path in results:
18 | with open(path, "r") as f:
19 | metrics = json.load(f)
20 | metrics_statistic["PSNR"].append(metrics["ours_32000"]["PSNR"])
21 | metrics_statistic["SSIM"].append(metrics["ours_32000"]["SSIM"])
22 | metrics_statistic["LPIPS"].append(metrics["ours_32000"]["LPIPS"])
23 |
24 | print("PSNR:", mean(metrics_statistic["PSNR"]))
25 | print("SSIM:", mean(metrics_statistic["SSIM"]))
26 | print("LPIPS:", mean(metrics_statistic["LPIPS"]))
27 |
28 | #======================================================
29 | results = glob(os.path.join(args.folder, "*/speed.json"))
30 |
31 | metrics_statistic = defaultdict(list)
32 | for path in results:
33 | with open(path, "r") as f:
34 | metrics = json.load(f)
35 | metrics_statistic["points"].append(metrics["points"])
36 | metrics_statistic["train_time"].append(metrics["train_time"])
37 | metrics_statistic["train_time_std"].append(metrics["train_time_std"])
38 |
39 | print("Points:", mean(metrics_statistic["points"]))
40 | print("FPS:", mean(metrics_statistic["train_time"]), "±", mean(metrics_statistic["train_time_std"]))
41 |
42 | #======================================================
43 | files_1 = glob(os.path.join(args.folder, "*/point_cloud/iteration_32000/texture_color.npz"))
44 | files_2 = glob(os.path.join(args.folder, "*/point_cloud/iteration_32000/texture_alpha.npz"))
45 | files_3 = glob(os.path.join(args.folder, "*/point_cloud/iteration_32000/point_cloud.ply"))
46 |
47 | size_statistic = []
48 | for f1, f2, f3 in zip(files_1, files_2, files_3):
49 | total_size = 0
50 | file_stats = os.stat(f1)
51 | total_size += file_stats.st_size / (1024 * 1024)
52 | file_stats = os.stat(f2)
53 | total_size += file_stats.st_size / (1024 * 1024)
54 | file_stats = os.stat(f3)
55 | total_size += file_stats.st_size / (1024 * 1024)
56 | size_statistic.append(total_size)
57 |
58 | print("Size:", mean(size_statistic), " MB")
59 |
--------------------------------------------------------------------------------
/scripts/colmap_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ..
3 | DATA_FOLDER=/media/dsvitov/DATA/
4 |
5 | # Process Tanks&Temples
6 | # First put images in /*_COLMAP/input folder
7 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Train_COLMAP/
8 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Training/Truck_COLMAP/
9 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Francis_COLMAP/
10 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Horse_COLMAP/
11 | python3 convert.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Lighthouse_COLMAP/
12 |
13 | # Mip-NeRF-360
14 | # First put images in /*/COLMAP/input folder
15 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bonsai/COLMAP
16 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/counter/COLMAP
17 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/kitchen/COLMAP
18 | python3 convert.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/room/COLMAP
19 |
20 | # Process DTU
21 | # First put images in /*_COLMAP/input folder
22 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan24_COLMAP
23 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan37_COLMAP
24 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan40_COLMAP
25 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan55_COLMAP
26 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan63_COLMAP
27 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan65_COLMAP
28 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan69_COLMAP
29 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan83_COLMAP
30 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan97_COLMAP
31 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan105_COLMAP
32 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan106_COLMAP
33 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan110_COLMAP
34 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan114_COLMAP
35 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan118_COLMAP
36 | python3 convert.py -s ${DATA_FOLDER}/DTU/selected/scan122_COLMAP
37 |
38 |
--------------------------------------------------------------------------------
/scripts/dtu_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from argparse import ArgumentParser
4 | from glob import glob
5 | from statistics import mean
6 |
7 | dtu_scenes = ['scan24', 'scan37', 'scan40', 'scan55', 'scan63', 'scan65', 'scan69', 'scan83', 'scan97', 'scan105', 'scan106', 'scan110', 'scan114', 'scan118', 'scan122']
8 |
9 | points = {
10 | 'scan24': 30_000,
11 | 'scan37': 30_000,
12 | 'scan40': 30_000,
13 | 'scan55': 60_000,
14 | 'scan63': 60_000,
15 | 'scan65': 60_000,
16 | 'scan69': 60_000,
17 | 'scan83': 60_000,
18 | 'scan97': 60_000,
19 | 'scan105': 30_000,
20 | 'scan106': 60_000,
21 | 'scan110': 60_000,
22 | 'scan114': 60_000,
23 | 'scan118': 60_000,
24 | 'scan122': 60_000,
25 | }
26 |
27 | parser = ArgumentParser(description="Full evaluation script parameters")
28 | parser.add_argument("--skip_training", action="store_true")
29 | parser.add_argument("--skip_rendering", action="store_true")
30 | parser.add_argument("--skip_metrics", action="store_true")
31 | parser.add_argument("--output_path", default="./eval/dtu")
32 | parser.add_argument('--dtu', "-dtu", required=True, type=str)
33 | args, _ = parser.parse_known_args()
34 |
35 | all_scenes = []
36 | all_scenes.extend(dtu_scenes)
37 |
38 | if not args.skip_metrics:
39 | parser.add_argument('--DTU_Official', "-DTU", required=True, type=str)
40 | args = parser.parse_args()
41 |
42 |
43 | if not args.skip_training:
44 | for scene in dtu_scenes:
45 | common_args = " --quiet --test_iterations -1 --depth_ratio 1.0 -r 2 --lambda_dist 1000 --lambda_normal=0.05 --cap_max=" + str(points[scene]) + " --max_read_points=" + str(points[scene])
46 | source = args.dtu + "/" + scene
47 | print("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
48 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
49 |
50 |
51 | if not args.skip_rendering:
52 | all_sources = []
53 | common_args = " --quiet --depth_ratio 1.0 --num_cluster 1 --voxel_size 0.004 --sdf_trunc 0.016 --depth_trunc 3.0"
54 | for scene in dtu_scenes:
55 | source = args.dtu + "/" + scene
56 |         print("python render.py --iteration 32000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
57 |         os.system("python render.py --iteration 32000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
58 |
59 |
60 | if not args.skip_metrics:
61 | script_dir = os.path.dirname(os.path.abspath(__file__))
62 | for scene in dtu_scenes:
63 | scan_id = scene[4:]
64 | ply_file = f"{args.output_path}/{scene}/train/ours_32000/"
65 | iteration = 32000
66 | string = f"python {script_dir}/eval_dtu/evaluate_single_scene.py " + \
67 | f"--input_mesh {args.output_path}/{scene}/train/ours_32000/fuse_post.ply " + \
68 | f"--scan_id {scan_id} --output_dir {script_dir}/tmp/scan{scan_id} " + \
69 | f"--mask_dir {args.dtu} " + \
70 | f"--DTU {args.DTU_Official}"
71 |
72 | os.system(string)
73 |
74 | results = glob(f"{script_dir}/tmp/*/results.json")
75 | overall = []
76 | for path in results:
77 | with open(path, "r") as f:
78 | metrics = json.load(f)
79 | overall.append(metrics["overall"])
80 |
81 | print("Mean CD:", mean(overall))
--------------------------------------------------------------------------------
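
A minimal invocation sketch for the evaluation driver above, launched from the repository root and assuming --dtu points at a folder holding the preprocessed scan24 … scan122 scenes (together with the per-scan mask/ and cameras.npz used later for culling) and -DTU at the official DTU ground-truth point clouds (paths are illustrative):

python scripts/dtu_eval.py --dtu /path/to/DTU_scenes -DTU /path/to/Official_DTU_Dataset --output_path ./eval/dtu

For every scan this trains with the per-scene primitive budget from the points dictionary, renders and fuses a mesh at iteration 32000, culls and scores it with eval_dtu/evaluate_single_scene.py, and finally prints the mean Chamfer distance over all tmp/*/results.json files.
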
/scripts/eval_dtu/eval.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/jzhangbs/DTUeval-python
2 | import numpy as np
3 | import open3d as o3d
4 | import sklearn.neighbors as skln
5 | from tqdm import tqdm
6 | from scipy.io import loadmat
7 | import multiprocessing as mp
8 | import argparse
9 |
10 | def sample_single_tri(input_):
11 | n1, n2, v1, v2, tri_vert = input_
12 | c = np.mgrid[:n1+1, :n2+1]
13 | c += 0.5
14 | c[0] /= max(n1, 1e-7)
15 | c[1] /= max(n2, 1e-7)
16 | c = np.transpose(c, (1,2,0))
17 | k = c[c.sum(axis=-1) < 1] # m2
18 | q = v1 * k[:,:1] + v2 * k[:,1:] + tri_vert
19 | return q
20 |
21 | def write_vis_pcd(file, points, colors):
22 | pcd = o3d.geometry.PointCloud()
23 | pcd.points = o3d.utility.Vector3dVector(points)
24 | pcd.colors = o3d.utility.Vector3dVector(colors)
25 | o3d.io.write_point_cloud(file, pcd)
26 |
27 | if __name__ == '__main__':
28 | mp.freeze_support()
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--data', type=str, default='data_in.ply')
32 | parser.add_argument('--scan', type=int, default=1)
33 | parser.add_argument('--mode', type=str, default='mesh', choices=['mesh', 'pcd'])
34 | parser.add_argument('--dataset_dir', type=str, default='.')
35 | parser.add_argument('--vis_out_dir', type=str, default='.')
36 | parser.add_argument('--downsample_density', type=float, default=0.2)
37 | parser.add_argument('--patch_size', type=float, default=60)
38 | parser.add_argument('--max_dist', type=float, default=20)
39 | parser.add_argument('--visualize_threshold', type=float, default=10)
40 | args = parser.parse_args()
41 |
42 | thresh = args.downsample_density
43 | if args.mode == 'mesh':
44 | pbar = tqdm(total=9)
45 | pbar.set_description('read data mesh')
46 | data_mesh = o3d.io.read_triangle_mesh(args.data)
47 |
48 | vertices = np.asarray(data_mesh.vertices)
49 | triangles = np.asarray(data_mesh.triangles)
50 | tri_vert = vertices[triangles]
51 |
52 | pbar.update(1)
53 | pbar.set_description('sample pcd from mesh')
54 | v1 = tri_vert[:,1] - tri_vert[:,0]
55 | v2 = tri_vert[:,2] - tri_vert[:,0]
56 | l1 = np.linalg.norm(v1, axis=-1, keepdims=True)
57 | l2 = np.linalg.norm(v2, axis=-1, keepdims=True)
58 | area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True)
59 | non_zero_area = (area2 > 0)[:,0]
60 | l1, l2, area2, v1, v2, tri_vert = [
61 | arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert]
62 | ]
63 | thr = thresh * np.sqrt(l1 * l2 / area2)
64 | n1 = np.floor(l1 / thr)
65 | n2 = np.floor(l2 / thr)
66 |
67 | with mp.Pool() as mp_pool:
68 | new_pts = mp_pool.map(sample_single_tri, ((n1[i,0], n2[i,0], v1[i:i+1], v2[i:i+1], tri_vert[i:i+1,0]) for i in range(len(n1))), chunksize=1024)
69 |
70 | new_pts = np.concatenate(new_pts, axis=0)
71 | data_pcd = np.concatenate([vertices, new_pts], axis=0)
72 |
73 | elif args.mode == 'pcd':
74 | pbar = tqdm(total=8)
75 | pbar.set_description('read data pcd')
76 | data_pcd_o3d = o3d.io.read_point_cloud(args.data)
77 | data_pcd = np.asarray(data_pcd_o3d.points)
78 |
79 | pbar.update(1)
80 | pbar.set_description('random shuffle pcd index')
81 | shuffle_rng = np.random.default_rng()
82 | shuffle_rng.shuffle(data_pcd, axis=0)
83 |
84 | pbar.update(1)
85 | pbar.set_description('downsample pcd')
86 | nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1)
87 | nn_engine.fit(data_pcd)
88 | rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False)
89 | mask = np.ones(data_pcd.shape[0], dtype=np.bool_)
90 | for curr, idxs in enumerate(rnn_idxs):
91 | if mask[curr]:
92 | mask[idxs] = 0
93 | mask[curr] = 1
94 | data_down = data_pcd[mask]
95 |
96 | pbar.update(1)
97 | pbar.set_description('masking data pcd')
98 | obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat')
99 | ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']]
100 | BB = BB.astype(np.float32)
101 |
102 | patch = args.patch_size
103 | inbound = ((data_down >= BB[:1]-patch) & (data_down < BB[1:]+patch*2)).sum(axis=-1) ==3
104 | data_in = data_down[inbound]
105 |
106 | data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32)
107 | grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) ==3
108 | data_grid_in = data_grid[grid_inbound]
109 | in_obs = ObsMask[data_grid_in[:,0], data_grid_in[:,1], data_grid_in[:,2]].astype(np.bool_)
110 | data_in_obs = data_in[grid_inbound][in_obs]
111 |
112 | pbar.update(1)
113 | pbar.set_description('read STL pcd')
114 | stl_pcd = o3d.io.read_point_cloud(f'{args.dataset_dir}/Points/stl/stl{args.scan:03}_total.ply')
115 | stl = np.asarray(stl_pcd.points)
116 |
117 | pbar.update(1)
118 | pbar.set_description('compute data2stl')
119 | nn_engine.fit(stl)
120 | dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True)
121 | max_dist = args.max_dist
122 | mean_d2s = dist_d2s[dist_d2s < max_dist].mean()
123 |
124 | pbar.update(1)
125 | pbar.set_description('compute stl2data')
126 | ground_plane = loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P']
127 |
128 | stl_hom = np.concatenate([stl, np.ones_like(stl[:,:1])], -1)
129 | above = (ground_plane.reshape((1,4)) * stl_hom).sum(-1) > 0
130 | stl_above = stl[above]
131 |
132 | nn_engine.fit(data_in)
133 | dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True)
134 | mean_s2d = dist_s2d[dist_s2d < max_dist].mean()
135 |
136 | pbar.update(1)
137 | pbar.set_description('visualize error')
138 | vis_dist = args.visualize_threshold
139 | R = np.array([[1,0,0]], dtype=np.float64)
140 | G = np.array([[0,1,0]], dtype=np.float64)
141 | B = np.array([[0,0,1]], dtype=np.float64)
142 | W = np.array([[1,1,1]], dtype=np.float64)
143 | data_color = np.tile(B, (data_down.shape[0], 1))
144 | data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist
145 | data_color[ np.where(inbound)[0][grid_inbound][in_obs] ] = R * data_alpha + W * (1-data_alpha)
146 | data_color[ np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:,0] >= max_dist] ] = G
147 | write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2s.ply', data_down, data_color)
148 | stl_color = np.tile(B, (stl.shape[0], 1))
149 | stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist
150 | stl_color[ np.where(above)[0] ] = R * stl_alpha + W * (1-stl_alpha)
151 | stl_color[ np.where(above)[0][dist_s2d[:,0] >= max_dist] ] = G
152 | write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_s2d.ply', stl, stl_color)
153 |
154 | pbar.update(1)
155 | pbar.set_description('done')
156 | pbar.close()
157 | over_all = (mean_d2s + mean_s2d) / 2
158 | print(mean_d2s, mean_s2d, over_all)
159 |
160 | import json
161 | with open(f'{args.vis_out_dir}/results.json', 'w') as fp:
162 | json.dump({
163 | 'mean_d2s': mean_d2s,
164 | 'mean_s2d': mean_s2d,
165 | 'overall': over_all,
166 | }, fp, indent=True)
--------------------------------------------------------------------------------
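
The score reported by the script above is the symmetric average overall = (mean_d2s + mean_s2d) / 2, where d2s measures accuracy (reconstruction to ground-truth STL points) and s2d measures completeness (ground truth to reconstruction), with distances beyond max_dist discarded. A toy sketch of the same nearest-neighbour computation on synthetic points (shapes and values are illustrative):

import numpy as np
import sklearn.neighbors as skln

pred = np.random.rand(1000, 3)   # stand-in for the downsampled reconstruction points
gt = np.random.rand(1200, 3)     # stand-in for the ground-truth STL points
max_dist = 20

nn = skln.NearestNeighbors(n_neighbors=1, algorithm='kd_tree', n_jobs=-1)
nn.fit(gt)
d2s, _ = nn.kneighbors(pred, n_neighbors=1, return_distance=True)
nn.fit(pred)
s2d, _ = nn.kneighbors(gt, n_neighbors=1, return_distance=True)

mean_d2s = d2s[d2s < max_dist].mean()
mean_s2d = s2d[s2d < max_dist].mean()
print(mean_d2s, mean_s2d, (mean_d2s + mean_s2d) / 2)
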
/scripts/eval_dtu/evaluate_single_scene.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import cv2
5 | import numpy as np
6 | import os
7 | import glob
8 | from skimage.morphology import binary_dilation, disk
9 | import argparse
10 |
11 | import trimesh
12 | from pathlib import Path
13 | import subprocess
14 |
15 | import sys
16 | import render_utils as rend_util
17 | from tqdm import tqdm
18 |
19 | def cull_scan(scan, mesh_path, result_mesh_file, instance_dir):
20 |
21 | # load poses
22 | image_dir = '{0}/images'.format(instance_dir)
23 | image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
24 | n_images = len(image_paths)
25 | cam_file = '{0}/cameras.npz'.format(instance_dir)
26 | camera_dict = np.load(cam_file)
27 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
28 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
29 |
30 | intrinsics_all = []
31 | pose_all = []
32 | for scale_mat, world_mat in zip(scale_mats, world_mats):
33 | P = world_mat @ scale_mat
34 | P = P[:3, :4]
35 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
36 | intrinsics_all.append(torch.from_numpy(intrinsics).float())
37 | pose_all.append(torch.from_numpy(pose).float())
38 |
39 | # load mask
40 | mask_dir = '{0}/mask'.format(instance_dir)
41 | mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.png")))
42 | masks = []
43 | for p in mask_paths:
44 | mask = cv2.imread(p)
45 | masks.append(mask)
46 |
47 | # hard-coded image shape
48 | W, H = 1600, 1200
49 |
50 | # load mesh
51 | mesh = trimesh.load(mesh_path)
52 |
53 | # load transformation matrix
54 |
55 | vertices = mesh.vertices
56 |
57 | # project and filter
58 | vertices = torch.from_numpy(vertices).cuda()
59 | vertices = torch.cat((vertices, torch.ones_like(vertices[:, :1])), dim=-1)
60 | vertices = vertices.permute(1, 0)
61 | vertices = vertices.float()
62 |
63 | sampled_masks = []
64 | for i in tqdm(range(n_images), desc="Culling mesh given masks"):
65 | pose = pose_all[i]
66 | w2c = torch.inverse(pose).cuda()
67 | intrinsic = intrinsics_all[i].cuda()
68 |
69 | with torch.no_grad():
70 | # transform and project
71 | cam_points = intrinsic @ w2c @ vertices
72 | pix_coords = cam_points[:2, :] / (cam_points[2, :].unsqueeze(0) + 1e-6)
73 | pix_coords = pix_coords.permute(1, 0)
74 | pix_coords[..., 0] /= W - 1
75 | pix_coords[..., 1] /= H - 1
76 | pix_coords = (pix_coords - 0.5) * 2
77 | valid = ((pix_coords > -1. ) & (pix_coords < 1.)).all(dim=-1).float()
78 |
79 |             # dilate mask similar to unisurf
80 | maski = masks[i][:, :, 0].astype(np.float32) / 256.
81 | maski = torch.from_numpy(binary_dilation(maski, disk(24))).float()[None, None].cuda()
82 |
83 | sampled_mask = F.grid_sample(maski, pix_coords[None, None], mode='nearest', padding_mode='zeros', align_corners=True)[0, -1, 0]
84 |
85 | sampled_mask = sampled_mask + (1. - valid)
86 | sampled_masks.append(sampled_mask)
87 |
88 | sampled_masks = torch.stack(sampled_masks, -1)
89 | # filter
90 |
91 | mask = (sampled_masks > 0.).all(dim=-1).cpu().numpy()
92 | face_mask = mask[mesh.faces].all(axis=1)
93 |
94 | mesh.update_vertices(mask)
95 | mesh.update_faces(face_mask)
96 |
97 | # transform vertices to world
98 | scale_mat = scale_mats[0]
99 | mesh.vertices = mesh.vertices * scale_mat[0, 0] + scale_mat[:3, 3][None]
100 | mesh.export(result_mesh_file)
101 | del mesh
102 |
103 |
104 | if __name__ == "__main__":
105 |
106 | parser = argparse.ArgumentParser(
107 | description='Arguments to evaluate the mesh.'
108 | )
109 |
110 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be evaluated')
111 | parser.add_argument('--scan_id', type=str, help='scan id of the input mesh')
112 | parser.add_argument('--output_dir', type=str, default='evaluation_results_single', help='path to the output folder')
113 | parser.add_argument('--mask_dir', type=str, default='mask', help='path to uncropped mask')
114 | parser.add_argument('--DTU', type=str, default='Offical_DTU_Dataset', help='path to the GT DTU point clouds')
115 | args = parser.parse_args()
116 |
117 | Offical_DTU_Dataset = args.DTU
118 | out_dir = args.output_dir
119 | Path(out_dir).mkdir(parents=True, exist_ok=True)
120 |
121 | scan = args.scan_id
122 | ply_file = args.input_mesh
123 | print("cull mesh ....")
124 | result_mesh_file = os.path.join(out_dir, "culled_mesh.ply")
125 | cull_scan(scan, ply_file, result_mesh_file, instance_dir=os.path.join(args.mask_dir, f'scan{args.scan_id}'))
126 |
127 | script_dir = os.path.dirname(os.path.abspath(__file__))
128 | cmd = f"python {script_dir}/eval.py --data {result_mesh_file} --scan {scan} --mode mesh --dataset_dir {Offical_DTU_Dataset} --vis_out_dir {out_dir}"
129 | os.system(cmd)
--------------------------------------------------------------------------------
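
The culling script above is normally driven by dtu_eval.py; a stand-alone invocation sketch with illustrative paths (run from the repository root):

python scripts/eval_dtu/evaluate_single_scene.py --input_mesh ./eval/dtu/scan24/train/ours_32000/fuse_post.ply --scan_id 24 --output_dir ./tmp/scan24 --mask_dir /path/to/DTU_scenes --DTU /path/to/Official_DTU_Dataset

The mask directory must contain scan<id>/images, scan<id>/mask and scan<id>/cameras.npz; these are used to project the mesh into every view and drop faces that fall outside the dilated object masks before eval.py computes the Chamfer distance.
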
/scripts/eval_dtu/render_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import skimage
4 | import cv2
5 | import torch
6 | from torch.nn import functional as F
7 |
8 |
9 | def get_psnr(img1, img2, normalize_rgb=False):
10 | if normalize_rgb: # [-1,1] --> [0,1]
11 | img1 = (img1 + 1.) / 2.
12 | img2 = (img2 + 1. ) / 2.
13 |
14 | mse = torch.mean((img1 - img2) ** 2)
15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
16 |
17 | return psnr
18 |
19 |
20 | def load_rgb(path, normalize_rgb = False):
21 | img = imageio.imread(path)
22 | img = skimage.img_as_float32(img)
23 |
24 | if normalize_rgb: # [-1,1] --> [0,1]
25 | img -= 0.5
26 | img *= 2.
27 | img = img.transpose(2, 0, 1)
28 | return img
29 |
30 |
31 | def load_K_Rt_from_P(filename, P=None):
32 | if P is None:
33 | lines = open(filename).read().splitlines()
34 | if len(lines) == 4:
35 | lines = lines[1:]
36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
37 | P = np.asarray(lines).astype(np.float32).squeeze()
38 |
39 | out = cv2.decomposeProjectionMatrix(P)
40 | K = out[0]
41 | R = out[1]
42 | t = out[2]
43 |
44 | K = K/K[2,2]
45 | intrinsics = np.eye(4)
46 | intrinsics[:3, :3] = K
47 |
48 | pose = np.eye(4, dtype=np.float32)
49 | pose[:3, :3] = R.transpose()
50 | pose[:3,3] = (t[:3] / t[3])[:,0]
51 |
52 | return intrinsics, pose
53 |
54 |
55 | def get_camera_params(uv, pose, intrinsics):
56 | if pose.shape[1] == 7: #In case of quaternion vector representation
57 | cam_loc = pose[:, 4:]
58 | R = quat_to_rot(pose[:,:4])
59 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
60 | p[:, :3, :3] = R
61 | p[:, :3, 3] = cam_loc
62 | else: # In case of pose matrix representation
63 | cam_loc = pose[:, :3, 3]
64 | p = pose
65 |
66 | batch_size, num_samples, _ = uv.shape
67 |
68 | depth = torch.ones((batch_size, num_samples)).cuda()
69 | x_cam = uv[:, :, 0].view(batch_size, -1)
70 | y_cam = uv[:, :, 1].view(batch_size, -1)
71 | z_cam = depth.view(batch_size, -1)
72 |
73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
74 |
75 | # permute for batch matrix product
76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
77 |
78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
79 | ray_dirs = world_coords - cam_loc[:, None, :]
80 | ray_dirs = F.normalize(ray_dirs, dim=2)
81 |
82 | return ray_dirs, cam_loc
83 |
84 |
85 | def get_camera_for_plot(pose):
86 | if pose.shape[1] == 7: #In case of quaternion vector representation
87 | cam_loc = pose[:, 4:].detach()
88 | R = quat_to_rot(pose[:,:4].detach())
89 | else: # In case of pose matrix representation
90 | cam_loc = pose[:, :3, 3]
91 | R = pose[:, :3, :3]
92 | cam_dir = R[:, :3, 2]
93 | return cam_loc, cam_dir
94 |
95 |
96 | def lift(x, y, z, intrinsics):
97 | # parse intrinsics
98 | intrinsics = intrinsics.cuda()
99 | fx = intrinsics[:, 0, 0]
100 | fy = intrinsics[:, 1, 1]
101 | cx = intrinsics[:, 0, 2]
102 | cy = intrinsics[:, 1, 2]
103 | sk = intrinsics[:, 0, 1]
104 |
105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
107 |
108 | # homogeneous
109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
110 |
111 |
112 | def quat_to_rot(q):
113 | batch_size, _ = q.shape
114 | q = F.normalize(q, dim=1)
115 | R = torch.ones((batch_size, 3,3)).cuda()
116 | qr=q[:,0]
117 | qi = q[:, 1]
118 | qj = q[:, 2]
119 | qk = q[:, 3]
120 | R[:, 0, 0]=1-2 * (qj**2 + qk**2)
121 | R[:, 0, 1] = 2 * (qj *qi -qk*qr)
122 | R[:, 0, 2] = 2 * (qi * qk + qr * qj)
123 | R[:, 1, 0] = 2 * (qj * qi + qk * qr)
124 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
125 | R[:, 1, 2] = 2*(qj*qk - qi*qr)
126 | R[:, 2, 0] = 2 * (qk * qi-qj * qr)
127 | R[:, 2, 1] = 2 * (qj*qk + qi*qr)
128 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
129 | return R
130 |
131 |
132 | def rot_to_quat(R):
133 | batch_size, _,_ = R.shape
134 | q = torch.ones((batch_size, 4)).cuda()
135 |
136 | R00 = R[:, 0,0]
137 | R01 = R[:, 0, 1]
138 | R02 = R[:, 0, 2]
139 | R10 = R[:, 1, 0]
140 | R11 = R[:, 1, 1]
141 | R12 = R[:, 1, 2]
142 | R20 = R[:, 2, 0]
143 | R21 = R[:, 2, 1]
144 | R22 = R[:, 2, 2]
145 |
146 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2
147 | q[:, 1]=(R21-R12)/(4*q[:,0])
148 | q[:, 2] = (R02 - R20) / (4 * q[:, 0])
149 | q[:, 3] = (R10 - R01) / (4 * q[:, 0])
150 | return q
151 |
152 |
153 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
154 | # Input: n_rays x 3 ; n_rays x 3
155 | # Output: n_rays x 1, n_rays x 1 (close and far)
156 |
157 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
158 | cam_loc.view(-1, 3, 1)).squeeze(-1)
159 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
160 |
161 | # sanity check
162 | if (under_sqrt <= 0).sum() > 0:
163 | print('BOUNDING SPHERE PROBLEM!')
164 | exit()
165 |
166 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
167 | sphere_intersections = sphere_intersections.clamp_min(0.0)
168 |
169 | return sphere_intersections
--------------------------------------------------------------------------------
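
A small sketch of how load_K_Rt_from_P above recovers intrinsics and a camera-to-world pose from a 3x4 projection matrix P = K[R|t] (the matrix here is synthetic; run from scripts/eval_dtu/ so render_utils is importable):

import numpy as np
from render_utils import load_K_Rt_from_P

K = np.array([[800., 0., 400.],
              [0., 800., 300.],
              [0., 0., 1.]])
Rt = np.hstack([np.eye(3), np.array([[0.1], [0.0], [2.0]])])  # identity rotation, small translation
P = (K @ Rt).astype(np.float32)

intrinsics, pose = load_K_Rt_from_P(None, P)
print(intrinsics[:3, :3])  # recovers K (up to the K[2,2] normalisation)
print(pose[:3, 3])         # camera centre -R^T t, here approximately (-0.1, 0, -2)
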
/scripts/metrics_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ..
3 | OUTPUT_FOLDER=/media/dsvitov/DATA/output/Ours
4 |
5 | # Process Tanks&Temples
6 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Train
7 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Truck
8 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Francis
9 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Horse
10 | python metrics.py --model_path=${OUTPUT_FOLDER}/TnT/Lighthouse
11 |
12 | # Mip-NeRF-360
13 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Bonsai
14 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Counter
15 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Kitchen
16 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Room
17 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Bicycle
18 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Stump
19 | python metrics.py --model_path=${OUTPUT_FOLDER}/MipNerf/Garden
20 |
21 | # Process DTU
22 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan24
23 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan37
24 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan40
25 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan55
26 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan63
27 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan65
28 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan69
29 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan83
30 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan97
31 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan105
32 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan106
33 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan110
34 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan114
35 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan118
36 | python metrics.py --model_path=${OUTPUT_FOLDER}/DTU/scan122
37 |
38 | # Average metrics for each dataset
39 | python scripts/average_error.py --folder ${OUTPUT_FOLDER}/TnT
40 | python scripts/average_error.py --folder ${OUTPUT_FOLDER}/MipNerf
41 | python scripts/average_error.py --folder ${OUTPUT_FOLDER}/DTU
42 |
--------------------------------------------------------------------------------
/scripts/render_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ..
3 | DATA_FOLDER=/media/dsvitov/DATA/
4 | OUTPUT_FOLDER=/media/dsvitov/DATA/output/Ours
5 |
6 | # Process Tanks&Temples
7 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Train_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Train --skip_mesh
8 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Training/Truck_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Truck --skip_mesh
9 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Francis_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Francis --skip_mesh
10 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Horse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Horse --skip_mesh
11 | python render.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Lighthouse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Lighthouse --skip_mesh
12 |
13 | # Mip-NeRF-360
14 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bonsai --model_path=${OUTPUT_FOLDER}/MipNerf/Bonsai --skip_mesh
15 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/counter --model_path=${OUTPUT_FOLDER}/MipNerf/Counter --skip_mesh
16 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/kitchen --model_path=${OUTPUT_FOLDER}/MipNerf/Kitchen --skip_mesh
17 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/room --model_path=${OUTPUT_FOLDER}/MipNerf/Room --skip_mesh
18 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bicycle --model_path=${OUTPUT_FOLDER}/MipNerf/Bicycle --skip_mesh
19 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/stump --model_path=${OUTPUT_FOLDER}/MipNerf/Stump --skip_mesh
20 | python render.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/garden --model_path=${OUTPUT_FOLDER}/MipNerf/Garden --skip_mesh
21 |
22 | # Process DTU
23 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan24 --model_path=${OUTPUT_FOLDER}/DTU/scan24 --skip_mesh
24 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan37 --model_path=${OUTPUT_FOLDER}/DTU/scan37 --skip_mesh
25 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan40 --model_path=${OUTPUT_FOLDER}/DTU/scan40 --skip_mesh
26 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan55 --model_path=${OUTPUT_FOLDER}/DTU/scan55 --skip_mesh
27 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan63 --model_path=${OUTPUT_FOLDER}/DTU/scan63 --skip_mesh
28 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan65 --model_path=${OUTPUT_FOLDER}/DTU/scan65 --skip_mesh
29 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan69 --model_path=${OUTPUT_FOLDER}/DTU/scan69 --skip_mesh
30 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan83 --model_path=${OUTPUT_FOLDER}/DTU/scan83 --skip_mesh
31 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan97 --model_path=${OUTPUT_FOLDER}/DTU/scan97 --skip_mesh
32 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan105 --model_path=${OUTPUT_FOLDER}/DTU/scan105 --skip_mesh
33 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan106 --model_path=${OUTPUT_FOLDER}/DTU/scan106 --skip_mesh
34 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan110 --model_path=${OUTPUT_FOLDER}/DTU/scan110 --skip_mesh
35 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan114 --model_path=${OUTPUT_FOLDER}/DTU/scan114 --skip_mesh
36 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan118 --model_path=${OUTPUT_FOLDER}/DTU/scan118 --skip_mesh
37 | python render.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan122 --model_path=${OUTPUT_FOLDER}/DTU/scan122 --skip_mesh
38 |
--------------------------------------------------------------------------------
/scripts/train_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ..
3 | DATA_FOLDER=/media/dsvitov/DATA/
4 | OUTPUT_FOLDER=/media/dsvitov/DATA/output/Ours
5 |
6 | # Process Tanks&Temples
7 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Train_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Train --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
8 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Training/Truck_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Truck --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
9 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Francis_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Francis --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
10 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Horse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Horse --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
11 | python train.py -s ${DATA_FOLDER}/Tanks_and_Temples/Intermediate/Lighthouse_COLMAP_big --model_path=${OUTPUT_FOLDER}/TnT/Lighthouse --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
12 |
13 | # Mip-NeRF-360
14 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bonsai --model_path=${OUTPUT_FOLDER}/MipNerf/Bonsai --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval
15 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/counter --model_path=${OUTPUT_FOLDER}/MipNerf/Counter --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval
16 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/kitchen --model_path=${OUTPUT_FOLDER}/MipNerf/Kitchen --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval
17 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/room --model_path=${OUTPUT_FOLDER}/MipNerf/Room --cap_max=160_000 --max_read_points=150_000 --add_sky_box --eval
18 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/bicycle --model_path=${OUTPUT_FOLDER}/MipNerf/Bicycle --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
19 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/stump --model_path=${OUTPUT_FOLDER}/MipNerf/Stump --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
20 | python train.py -s ${DATA_FOLDER}/Mip-NeRF-360/360_v2/garden --model_path=${OUTPUT_FOLDER}/MipNerf/Garden --cap_max=300_000 --max_read_points=290_000 --add_sky_box --eval
21 |
22 | # Process DTU
23 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan24 --model_path=${OUTPUT_FOLDER}/DTU/scan24 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
24 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan37 --model_path=${OUTPUT_FOLDER}/DTU/scan37 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
25 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan40 --model_path=${OUTPUT_FOLDER}/DTU/scan40 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
26 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan55 --model_path=${OUTPUT_FOLDER}/DTU/scan55 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
27 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan63 --model_path=${OUTPUT_FOLDER}/DTU/scan63 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
28 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan65 --model_path=${OUTPUT_FOLDER}/DTU/scan65 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
29 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan69 --model_path=${OUTPUT_FOLDER}/DTU/scan69 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
30 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan83 --model_path=${OUTPUT_FOLDER}/DTU/scan83 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
31 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan97 --model_path=${OUTPUT_FOLDER}/DTU/scan97 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
32 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan105 --model_path=${OUTPUT_FOLDER}/DTU/scan105 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
33 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan106 --model_path=${OUTPUT_FOLDER}/DTU/scan106 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
34 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan110 --model_path=${OUTPUT_FOLDER}/DTU/scan110 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
35 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan114 --model_path=${OUTPUT_FOLDER}/DTU/scan114 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
36 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan118 --model_path=${OUTPUT_FOLDER}/DTU/scan118 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
37 | python train.py -s ${DATA_FOLDER}DTU/dtu/DTU/scan122 --model_path=${OUTPUT_FOLDER}/DTU/scan122 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval
38 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
14 |
15 | import torch
16 | from random import randint
17 | from utils.loss_utils import l1_loss, ssim
18 | from gaussian_renderer import render
19 | import sys
20 | from scene import Scene, GaussianModel
21 | from utils.general_utils import safe_state, build_scaling_rotation
22 | import uuid
23 | from tqdm import tqdm
24 | from utils.image_utils import psnr
25 | from argparse import ArgumentParser, Namespace
26 | from arguments import ModelParams, PipelineParams, OptimizationParams
27 |
28 | try:
29 | from torch.utils.tensorboard import SummaryWriter
30 | TENSORBOARD_FOUND = True
31 | except ImportError:
32 | TENSORBOARD_FOUND = False
33 |
34 |
35 | def total_variation_loss(img):
36 | bs_img, c_img, h_img, w_img = img.size()
37 | tv_h = torch.pow(img[:, :, 1:, :] - img[:, :, :-1, :], 2).sum()
38 | tv_w = torch.pow(img[:, :, :, 1:] - img[:, :, :, :-1], 2).sum()
39 | return (tv_h + tv_w) / (bs_img * c_img * h_img * w_img)
40 |
41 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
42 | first_iter = 0
43 | tb_writer = prepare_output_and_logger(dataset)
44 | gaussians = GaussianModel(dataset.sh_degree)
45 | scene = Scene(dataset, gaussians, add_sky_box=opt.add_sky_box, max_read_points=opt.max_read_points, sphere_point=opt.sphere_point)
46 | gaussians.training_setup(opt)
47 | if checkpoint:
48 | (model_params, first_iter) = torch.load(checkpoint)
49 | gaussians.restore(model_params, opt)
50 |
51 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
52 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
53 |
54 | iter_start = torch.cuda.Event(enable_timing = True)
55 | iter_end = torch.cuda.Event(enable_timing = True)
56 |
57 | viewpoint_stack = None
58 | ema_loss_for_log = 0.0
59 | ema_dist_for_log = 0.0
60 | ema_normal_for_log = 0.0
61 | ema_texture_for_log = 0.0
62 |
63 | initial_texture_alpha = gaussians.get_texture_alpha[0:1].detach().clone()
64 |
65 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
66 | first_iter += 1
67 | for iteration in range(first_iter, opt.iterations + 1):
68 |
69 | iter_start.record()
70 |
71 | xyz_lr = gaussians.update_learning_rate(iteration)
72 |
73 |         # Every 1000 iterations we increase the levels of SH up to a maximum degree
74 | if iteration % 1000 == 0:
75 | gaussians.oneupSHdegree()
76 |
77 | # Pick a random Camera
78 | if not viewpoint_stack:
79 | viewpoint_stack = scene.getTrainCameras().copy()
80 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
81 |
82 | render_pkg = render(viewpoint_cam, gaussians, pipe, background)
83 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
84 | impact = render_pkg["impact"]
85 |
86 | gt_image = viewpoint_cam.original_image.cuda()
87 |
88 | Ll1 = l1_loss(image, gt_image)
89 | ssim_map = ssim(image, gt_image, size_average=False)
90 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_map.mean())
91 |
92 | # regularization
93 | lambda_normal = opt.lambda_normal if iteration > 7000 else 0.0
94 | lambda_dist = opt.lambda_dist if iteration > 3000 else 0.0
95 |
96 | rend_dist = render_pkg["rend_dist"]
97 | rend_normal = render_pkg['rend_normal']
98 | surf_normal = render_pkg['surf_normal']
99 | normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None]
100 | normal_loss = lambda_normal * (normal_error).mean()
101 | dist_loss = lambda_dist * (rend_dist).mean()
102 |
103 | weights = opt.max_impact_threshold - torch.clamp(impact[visibility_filter], 0, opt.max_impact_threshold)
104 | textures_reg = (gaussians.get_texture_color[visibility_filter].mean(dim=[1, 2, 3]) * weights).mean() * opt.lambda_texture_value
105 | textures_reg += torch.abs((gaussians.get_texture_alpha[visibility_filter] - initial_texture_alpha).mean(dim=[1, 2]) * weights).mean() * opt.lambda_alpha_value
106 |
107 | # loss
108 | total_loss = loss + dist_loss + normal_loss + textures_reg
109 | # For MCMC sampler
110 | total_loss += opt.opacity_reg * gaussians.get_texture_alpha.mean()
111 | total_loss.backward()
112 |
113 | iter_end.record()
114 |
115 | with torch.no_grad():
116 | # Progress bar
117 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
118 | ema_dist_for_log = 0.4 * dist_loss.item() + 0.6 * ema_dist_for_log
119 | ema_normal_for_log = 0.4 * normal_loss.item() + 0.6 * ema_normal_for_log
120 | ema_texture_for_log = 0.4 * textures_reg.item() + 0.6 * ema_texture_for_log
121 |
122 |
123 | if iteration % 10 == 0:
124 | loss_dict = {
125 | "Loss": f"{ema_loss_for_log:.{5}f}",
126 | "distort": f"{ema_dist_for_log:.{5}f}",
127 | "normal": f"{ema_normal_for_log:.{5}f}",
128 | "texture": f"{ema_texture_for_log:.{5}f}",
129 | "Points": f"{len(gaussians.get_xyz)}"
130 | }
131 | progress_bar.set_postfix(loss_dict)
132 |
133 | progress_bar.update(10)
134 | if iteration == opt.iterations:
135 | progress_bar.close()
136 |
137 | # Log and save
138 | if tb_writer is not None:
139 | tb_writer.add_scalar('train_loss_patches/dist_loss', ema_dist_for_log, iteration)
140 | tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration)
141 |
142 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
143 | if (iteration in saving_iterations):
144 | print("\n[ITER {}] Saving Gaussians".format(iteration))
145 | scene.save(iteration)
146 |
147 | if opt.texture_from_iter <= iteration < opt.texture_to_iter:
148 | gaussians.activate_texture_training()
149 |
150 | if iteration >= opt.texture_to_iter:
151 | gaussians.deactivate_texture_training()
152 |
153 | if iteration > opt.position_lr_max_steps:
154 | gaussians.deactivate_gaussians_training()
155 |
156 | # Densification
157 | if iteration < opt.densify_until_iter and iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
158 | size = len(gaussians.get_texture_alpha)
159 | dead_mask = (gaussians.get_texture_alpha.view(size, -1).mean(1) <= opt.dead_opacity).squeeze(-1)
160 |
161 | gaussians.relocate_gs(dead_mask=dead_mask)
162 | gaussians.add_new_gs(cap_max=opt.cap_max)
163 |
164 | # Optimizer step
165 | if iteration < opt.iterations:
166 | gaussians.optimizer.step()
167 | gaussians.optimizer.zero_grad(set_to_none = True)
168 |
169 | L = build_scaling_rotation(gaussians.get_scaling, gaussians.get_rotation)
170 | actual_covariance = L @ L.transpose(1, 2)
171 |
172 | def op_sigmoid(x, k=100, x0=0.995):
173 | return 1 / (1 + torch.exp(-k * (x - x0)))
174 |
175 | #size = len(gaussians.get_texture_alpha)
176 | #opacity = gaussians.get_texture_alpha.view(size, -1).mean(1, keepdim=True) * 10 # Rescale to get maximum = 1
177 |                 opacity = torch.ones([gaussians.get_texture_alpha.shape[0], 1], dtype=torch.float32, device="cuda") # Fix opacity to 1 (the results in the paper were obtained this way)
178 | noise = torch.randn_like(gaussians._xyz) * (op_sigmoid(1 - opacity)) * opt.noise_lr * xyz_lr
179 | noise = torch.bmm(actual_covariance, noise.unsqueeze(-1)).squeeze(-1)
180 | gaussians._xyz.add_(noise)
181 |
182 | if (iteration in checkpoint_iterations):
183 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
184 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
185 |
186 | def prepare_output_and_logger(args):
187 | if not args.model_path:
188 | if os.getenv('OAR_JOB_ID'):
189 | unique_str=os.getenv('OAR_JOB_ID')
190 | else:
191 | unique_str = str(uuid.uuid4())
192 | args.model_path = os.path.join("./output/", unique_str[0:10])
193 |
194 | # Set up output folder
195 | print("Output folder: {}".format(args.model_path))
196 | os.makedirs(args.model_path, exist_ok = True)
197 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
198 | cfg_log_f.write(str(Namespace(**vars(args))))
199 |
200 | # Create Tensorboard writer
201 | tb_writer = None
202 | if TENSORBOARD_FOUND:
203 | tb_writer = SummaryWriter(args.model_path)
204 | else:
205 | print("Tensorboard not available: not logging progress")
206 | return tb_writer
207 |
208 | @torch.no_grad()
209 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
210 | if tb_writer:
211 | tb_writer.add_scalar('train_loss_patches/reg_loss', Ll1.item(), iteration)
212 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
213 | tb_writer.add_scalar('iter_time', elapsed, iteration)
214 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
215 |
216 | # Report test and samples of training set
217 | if iteration in testing_iterations:
218 | torch.cuda.empty_cache()
219 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
220 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
221 |
222 | for config in validation_configs:
223 | if config['cameras'] and len(config['cameras']) > 0:
224 | l1_test = 0.0
225 | psnr_test = 0.0
226 | for idx, viewpoint in enumerate(config['cameras']):
227 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs)
228 | image = torch.clamp(render_pkg["render"], 0.0, 1.0)
229 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
230 | if tb_writer and (idx < 5):
231 | from utils.general_utils import colormap
232 | depth = render_pkg["surf_depth"]
233 | norm = depth.max()
234 | depth = depth / norm
235 | depth = colormap(depth.cpu().numpy()[0], cmap='turbo')
236 | tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth[None], global_step=iteration)
237 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
238 |
239 | try:
240 | rend_alpha = render_pkg['rend_alpha']
241 | rend_normal = render_pkg["rend_normal"] * 0.5 + 0.5
242 | surf_normal = render_pkg["surf_normal"] * 0.5 + 0.5
243 | tb_writer.add_images(config['name'] + "_view_{}/rend_normal".format(viewpoint.image_name), rend_normal[None], global_step=iteration)
244 | tb_writer.add_images(config['name'] + "_view_{}/surf_normal".format(viewpoint.image_name), surf_normal[None], global_step=iteration)
245 | tb_writer.add_images(config['name'] + "_view_{}/rend_alpha".format(viewpoint.image_name), rend_alpha[None], global_step=iteration)
246 |
247 | rend_dist = render_pkg["rend_dist"]
248 | rend_dist = colormap(rend_dist.cpu().numpy()[0])
249 | tb_writer.add_images(config['name'] + "_view_{}/rend_dist".format(viewpoint.image_name), rend_dist[None], global_step=iteration)
250 | except:
251 | pass
252 |
253 | if iteration == testing_iterations[0]:
254 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
255 |
256 | l1_test += l1_loss(image, gt_image).mean().double()
257 | psnr_test += psnr(image, gt_image).mean().double()
258 |
259 | psnr_test /= len(config['cameras'])
260 | l1_test /= len(config['cameras'])
261 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
262 | if tb_writer:
263 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
264 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
265 |
266 | torch.cuda.empty_cache()
267 |
268 | if __name__ == "__main__":
269 | # Set up command line argument parser
270 | parser = ArgumentParser(description="Training script parameters")
271 | lp = ModelParams(parser)
272 | op = OptimizationParams(parser)
273 | pp = PipelineParams(parser)
274 | parser.add_argument('--ip', type=str, default="127.0.0.1")
275 | parser.add_argument('--port', type=int, default=6009)
276 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
277 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1_000, 7_000, 10_000, 15_000, 20_000, 25_000, 30_000, 32_000])
278 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1_000, 7_000, 30_000, 32_000])
279 | parser.add_argument("--quiet", action="store_true")
280 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
281 | parser.add_argument("--start_checkpoint", type=str, default = None)
282 | args = parser.parse_args(sys.argv[1:])
283 | args.save_iterations.append(args.iterations)
284 |
285 | print("Optimizing " + args.model_path)
286 |
287 | # Initialize system state (RNG)
288 | safe_state(args.quiet)
289 |
290 | # Start GUI server, configure and run training
291 | # network_gui.init(args.ip, args.port)
292 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
293 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
294 |
295 | # All done
296 | print("\nTraining complete.")
297 |
--------------------------------------------------------------------------------
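
A single-scene training sketch matching the DTU settings in scripts/train_all.sh (paths are illustrative; run from the repository root):

python train.py -s /path/to/DTU/scan24 --model_path ./output/DTU/scan24 --cap_max=60_000 --max_read_points=60_000 --lambda_normal=0.05 --lambda_dist 100 --eval

--cap_max bounds the number of primitives grown by the MCMC-style relocation/densification step in the loop above, while --lambda_normal and --lambda_dist weight the normal-consistency and depth-distortion regularisers, which are switched on after iterations 7000 and 3000 respectively.
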
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import torch
14 |
15 | from scene.cameras import Camera
16 | from utils.general_utils import PILtoTorch
17 | from utils.graphics_utils import fov2focal
18 |
19 | WARNED = False
20 |
21 | def loadCam(args, id, cam_info, resolution_scale):
22 | orig_w, orig_h = cam_info.image.size
23 |
24 | if args.resolution in [1, 2, 4, 8]:
25 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
26 | else: # should be a type that converts to float
27 | if args.resolution == -1:
28 | if orig_w > 1600:
29 | global WARNED
30 | if not WARNED:
31 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
32 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
33 | WARNED = True
34 | global_down = orig_w / 1600
35 | else:
36 | global_down = 1
37 | else:
38 | global_down = orig_w / args.resolution
39 |
40 | scale = float(global_down) * float(resolution_scale)
41 | resolution = (int(orig_w / scale), int(orig_h / scale))
42 |
43 | if len(cam_info.image.split()) > 3:
44 | resized_image_rgb = torch.cat([PILtoTorch(im, resolution) for im in cam_info.image.split()[:3]], dim=0)
45 | loaded_mask = PILtoTorch(cam_info.image.split()[3], resolution)
46 | gt_image = resized_image_rgb
47 | else:
48 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
49 | loaded_mask = None
50 | gt_image = resized_image_rgb
51 |
52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
54 | image=gt_image, gt_alpha_mask=loaded_mask,
55 | image_name=cam_info.image_name, uid=id, data_device=args.data_device)
56 |
57 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
58 | camera_list = []
59 |
60 | for id, c in enumerate(cam_infos):
61 | camera_list.append(loadCam(args, id, c, resolution_scale))
62 |
63 | return camera_list
64 |
65 | def camera_to_JSON(id, camera : Camera):
66 | Rt = np.zeros((4, 4))
67 | Rt[:3, :3] = camera.R.transpose()
68 | Rt[:3, 3] = camera.T
69 | Rt[3, 3] = 1.0
70 |
71 | W2C = np.linalg.inv(Rt)
72 | pos = W2C[:3, 3]
73 | rot = W2C[:3, :3]
74 | serializable_array_2d = [x.tolist() for x in rot]
75 | camera_entry = {
76 | 'id' : id,
77 | 'img_name' : camera.image_name,
78 | 'width' : camera.width,
79 | 'height' : camera.height,
80 | 'position': pos.tolist(),
81 | 'rotation': serializable_array_2d,
82 | 'fy' : fov2focal(camera.FovY, camera.height),
83 | 'fx' : fov2focal(camera.FovX, camera.width)
84 | }
85 | return camera_entry
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 |
18 | def inverse_sigmoid(x):
19 | return torch.log(x/(1-x))
20 |
21 | def PILtoTorch(pil_image, resolution):
22 | resized_image_PIL = pil_image.resize(resolution)
23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24 | if len(resized_image.shape) == 3:
25 | return resized_image.permute(2, 0, 1)
26 | else:
27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28 |
29 | def get_expon_lr_func(
30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31 | ):
32 | """
33 | Copied from Plenoxels
34 |
35 | Continuous learning rate decay function. Adapted from JaxNeRF
36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
37 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
39 | function of lr_delay_mult, such that the initial learning rate is
40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
41 | to the normal learning rate when steps>lr_delay_steps.
42 |     :param lr_init, lr_final: initial and final learning rates (at step 0 and step max_steps).
43 |     :param max_steps: int, the number of steps during optimization.
44 |     :return: helper function that maps a step index to a learning rate.
45 | """
46 |
47 | def helper(step):
48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
49 | # Disable this parameter
50 | return 0.0
51 | if lr_delay_steps > 0:
52 | # A kind of reverse cosine decay.
53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
55 | )
56 | else:
57 | delay_rate = 1.0
58 | t = np.clip(step / max_steps, 0, 1)
59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
60 | return delay_rate * log_lerp
61 |
62 | return helper
63 |
64 | def strip_lowerdiag(L):
65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66 |
67 | uncertainty[:, 0] = L[:, 0, 0]
68 | uncertainty[:, 1] = L[:, 0, 1]
69 | uncertainty[:, 2] = L[:, 0, 2]
70 | uncertainty[:, 3] = L[:, 1, 1]
71 | uncertainty[:, 4] = L[:, 1, 2]
72 | uncertainty[:, 5] = L[:, 2, 2]
73 | return uncertainty
74 |
75 | def strip_symmetric(sym):
76 | return strip_lowerdiag(sym)
77 |
78 | def build_rotation(r):
79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
80 |
81 | q = r / norm[:, None]
82 |
83 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
84 |
85 | r = q[:, 0]
86 | x = q[:, 1]
87 | y = q[:, 2]
88 | z = q[:, 3]
89 |
90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91 | R[:, 0, 1] = 2 * (x*y - r*z)
92 | R[:, 0, 2] = 2 * (x*z + r*y)
93 | R[:, 1, 0] = 2 * (x*y + r*z)
94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95 | R[:, 1, 2] = 2 * (y*z - r*x)
96 | R[:, 2, 0] = 2 * (x*z - r*y)
97 | R[:, 2, 1] = 2 * (y*z + r*x)
98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99 | return R
100 |
101 | def build_scaling_rotation(s, r):
102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103 | R = build_rotation(r)
104 |
105 | L[:,0,0] = s[:,0]
106 | L[:,1,1] = s[:,1]
107 |     L[:,2,2] = 0  # s[:,2] unused: the third scale is zeroed so each primitive stays flat (planar billboard)
108 |
109 | L = R @ L
110 | return L
111 |
112 | def safe_state(silent):
113 | old_f = sys.stdout
114 | class F:
115 | def __init__(self, silent):
116 | self.silent = silent
117 |
118 | def write(self, x):
119 | if not self.silent:
120 | if x.endswith("\n"):
121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122 | else:
123 | old_f.write(x)
124 |
125 | def flush(self):
126 | old_f.flush()
127 |
128 | sys.stdout = F(silent)
129 |
130 | random.seed(0)
131 | np.random.seed(0)
132 | torch.manual_seed(0)
133 | torch.cuda.set_device(torch.device("cuda:0"))
134 |
135 |
136 |
137 |
138 | def create_rotation_matrix_from_direction_vector_batch(direction_vectors):
139 | # Normalize the batch of direction vectors
140 | direction_vectors = direction_vectors / torch.norm(direction_vectors, dim=-1, keepdim=True)
141 | # Create a batch of arbitrary vectors that are not collinear with the direction vectors
142 | v1 = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32).to(direction_vectors.device).expand(direction_vectors.shape[0], -1).clone()
143 | is_collinear = torch.all(torch.abs(direction_vectors - v1) < 1e-5, dim=-1)
144 | v1[is_collinear] = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32).to(direction_vectors.device)
145 |
146 | # Calculate the first orthogonal vectors
147 | v1 = torch.cross(direction_vectors, v1)
148 | v1 = v1 / (torch.norm(v1, dim=-1, keepdim=True))
149 | # Calculate the second orthogonal vectors by taking the cross product
150 | v2 = torch.cross(direction_vectors, v1)
151 | v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True))
152 | # Create the batch of rotation matrices with the direction vectors as the last columns
153 | rotation_matrices = torch.stack((v1, v2, direction_vectors), dim=-1)
154 | return rotation_matrices
155 |
156 | # from kornia.geometry import conversions
157 | # def normal_to_rotation(normals):
158 | # rotations = create_rotation_matrix_from_direction_vector_batch(normals)
159 | # rotations = conversions.rotation_matrix_to_quaternion(rotations,eps=1e-5, order=conversions.QuaternionCoeffOrder.WXYZ)
160 | # return rotations
161 |
162 |
163 | def colormap(img, cmap='jet'):
164 | import matplotlib.pyplot as plt
165 | W, H = img.shape[:2]
166 | dpi = 300
167 | fig, ax = plt.subplots(1, figsize=(H/dpi, W/dpi), dpi=dpi)
168 | im = ax.imshow(img, cmap=cmap)
169 | ax.set_axis_off()
170 | fig.colorbar(im, ax=ax)
171 | fig.tight_layout()
172 | fig.canvas.draw()
173 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
174 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
175 | img = torch.from_numpy(data / 255.).float().permute(2,0,1)
176 | plt.close()
177 | return img
--------------------------------------------------------------------------------
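
A quick sketch of the schedule produced by get_expon_lr_func above (run from the repository root; the rates are illustrative):

from utils.general_utils import get_expon_lr_func

lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
print(lr_fn(0))        # 1.6e-4 at the first step
print(lr_fn(15_000))   # ~1.6e-5: the geometric mean of lr_init and lr_final
print(lr_fn(30_000))   # 1.6e-6 at the last step
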
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points : np.array
19 | colors : np.array
20 | normals : np.array
21 |
22 | def geom_transform_points(points, transf_matrix):
23 | P, _ = points.shape
24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25 | points_hom = torch.cat([points, ones], dim=1)
26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27 |
28 | denom = points_out[..., 3:] + 0.0000001
29 | return (points_out[..., :3] / denom).squeeze(dim=0)
30 |
31 | def getWorld2View(R, t):
32 | Rt = np.zeros((4, 4))
33 | Rt[:3, :3] = R.transpose()
34 | Rt[:3, 3] = t
35 | Rt[3, 3] = 1.0
36 | return np.float32(Rt)
37 |
38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
39 | Rt = np.zeros((4, 4))
40 | Rt[:3, :3] = R.transpose()
41 | Rt[:3, 3] = t
42 | Rt[3, 3] = 1.0
43 |
44 | C2W = np.linalg.inv(Rt)
45 | cam_center = C2W[:3, 3]
46 | cam_center = (cam_center + translate) * scale
47 | C2W[:3, 3] = cam_center
48 | Rt = np.linalg.inv(C2W)
49 | return np.float32(Rt)
50 |
51 | def getProjectionMatrix(znear, zfar, fovX, fovY):
52 | tanHalfFovY = math.tan((fovY / 2))
53 | tanHalfFovX = math.tan((fovX / 2))
54 |
55 | top = tanHalfFovY * znear
56 | bottom = -top
57 | right = tanHalfFovX * znear
58 | left = -right
59 |
60 | P = torch.zeros(4, 4)
61 |
62 | z_sign = 1.0
63 |
64 | P[0, 0] = 2.0 * znear / (right - left)
65 | P[1, 1] = 2.0 * znear / (top - bottom)
66 | P[0, 2] = (right + left) / (right - left)
67 | P[1, 2] = (top + bottom) / (top - bottom)
68 | P[3, 2] = z_sign
69 | P[2, 2] = z_sign * zfar / (zfar - znear)
70 | P[2, 3] = -(zfar * znear) / (zfar - znear)
71 | return P
72 |
73 | def fov2focal(fov, pixels):
74 | return pixels / (2 * math.tan(fov / 2))
75 |
76 | def focal2fov(focal, pixels):
77 | return 2*math.atan(pixels/(2*focal))
--------------------------------------------------------------------------------
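
fov2focal and focal2fov above are exact inverses of each other; a quick numeric check (numbers are illustrative):

import math
from utils.graphics_utils import fov2focal, focal2fov

width = 1600                           # image width in pixels
fov_x = math.radians(60)               # 60 degree horizontal field of view
focal = fov2focal(fov_x, width)        # 1600 / (2 * tan(30 deg)) ~= 1385.6
print(math.degrees(focal2fov(focal, width)))  # back to ~60.0
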
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 | def mse(img1, img2):
15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16 |
17 | def psnr(img1, img2):
18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
20 |
--------------------------------------------------------------------------------
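
A minimal usage sketch for the two metrics above, assuming float images in [0, 1] with a leading batch dimension; the tensors here are random stand-ins, not repository data.

import torch
from utils.image_utils import mse, psnr

pred = torch.rand(1, 3, 64, 64)                               # hypothetical render
gt = (pred + 0.01 * torch.randn_like(pred)).clamp(0.0, 1.0)   # slightly perturbed target

print(mse(pred, gt))    # per-image mean squared error, shape (1, 1)
print(psnr(pred, gt))   # 20 * log10(1 / sqrt(MSE)), in dB
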
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 | def l1_loss(network_output, gt):
18 | return torch.abs((network_output - gt)).mean()
19 |
20 | def l2_loss(network_output, gt):
21 | return ((network_output - gt) ** 2).mean()
22 |
23 | def gaussian(window_size, sigma):
24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 | return gauss / gauss.sum()
26 |
27 |
28 | def smooth_loss(disp, img):
29 | grad_disp_x = torch.abs(disp[:,1:-1, :-2] + disp[:,1:-1,2:] - 2 * disp[:,1:-1,1:-1])
30 | grad_disp_y = torch.abs(disp[:,:-2, 1:-1] + disp[:,2:,1:-1] - 2 * disp[:,1:-1,1:-1])
31 | grad_img_x = torch.mean(torch.abs(img[:, 1:-1, :-2] - img[:, 1:-1, 2:]), 0, keepdim=True) * 0.5
32 | grad_img_y = torch.mean(torch.abs(img[:, :-2, 1:-1] - img[:, 2:, 1:-1]), 0, keepdim=True) * 0.5
33 | grad_disp_x *= torch.exp(-grad_img_x)
34 | grad_disp_y *= torch.exp(-grad_img_y)
35 | return grad_disp_x.mean() + grad_disp_y.mean()
36 |
37 | def create_window(window_size, channel):
38 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
39 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
40 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
41 | return window
42 |
43 | def ssim(img1, img2, window_size=11, size_average=True):
44 | channel = img1.size(-3)
45 | window = create_window(window_size, channel)
46 |
47 | if img1.is_cuda:
48 | window = window.cuda(img1.get_device())
49 | window = window.type_as(img1)
50 |
51 | return _ssim(img1, img2, window, window_size, channel, size_average)
52 |
53 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
54 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
55 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
56 |
57 | mu1_sq = mu1.pow(2)
58 | mu2_sq = mu2.pow(2)
59 | mu1_mu2 = mu1 * mu2
60 |
61 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
62 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
63 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
64 |
65 | C1 = 0.01 ** 2
66 | C2 = 0.03 ** 2
67 |
68 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
69 |
70 | if size_average:
71 | return ssim_map.mean()
72 | else:
73 | return ssim_map.mean(0)
74 |
75 |
--------------------------------------------------------------------------------
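
For context, 3DGS-style training typically combines the L1 and D-SSIM terms above; the weighting below (lambda = 0.2) is a common choice and an assumption here, not something read off this repository's training script.

import torch
from utils.loss_utils import l1_loss, ssim

lambda_dssim = 0.2                                  # hypothetical weight, common in 3DGS-style code
pred = torch.rand(3, 128, 128, requires_grad=True)  # rendered image, (C, H, W)
gt = torch.rand(3, 128, 128)

loss = (1.0 - lambda_dssim) * l1_loss(pred, gt) + lambda_dssim * (1.0 - ssim(pred, gt))
loss.backward()                                     # both terms are differentiable
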
/utils/mcube_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2024, ShanghaiTech
3 | # SVIP research group, https://github.com/svip-lab
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact huangbb@shanghaitech.edu.cn
10 | #
11 |
12 | import numpy as np
13 | import torch
14 | import trimesh
15 | from skimage import measure
16 | # modified from here https://github.com/autonomousvision/sdfstudio/blob/370902a10dbef08cb3fe4391bd3ed1e227b5c165/nerfstudio/utils/marching_cubes.py#L201
17 | def marching_cubes_with_contraction(
18 | sdf,
19 | resolution=512,
20 | bounding_box_min=(-1.0, -1.0, -1.0),
21 | bounding_box_max=(1.0, 1.0, 1.0),
22 | return_mesh=False,
23 | level=0,
24 | simplify_mesh=True,
25 | inv_contraction=None,
26 | max_range=32.0,
27 | ):
28 | assert resolution % 512 == 0
29 |
30 | resN = resolution
31 | cropN = 512
32 | level = 0
33 | N = resN // cropN
34 |
35 | grid_min = bounding_box_min
36 | grid_max = bounding_box_max
37 | xs = np.linspace(grid_min[0], grid_max[0], N + 1)
38 | ys = np.linspace(grid_min[1], grid_max[1], N + 1)
39 | zs = np.linspace(grid_min[2], grid_max[2], N + 1)
40 |
41 | meshes = []
42 | for i in range(N):
43 | for j in range(N):
44 | for k in range(N):
45 | print(i, j, k)
46 | x_min, x_max = xs[i], xs[i + 1]
47 | y_min, y_max = ys[j], ys[j + 1]
48 | z_min, z_max = zs[k], zs[k + 1]
49 |
50 | x = np.linspace(x_min, x_max, cropN)
51 | y = np.linspace(y_min, y_max, cropN)
52 | z = np.linspace(z_min, z_max, cropN)
53 |
54 | xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
55 | points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
56 |
57 | @torch.no_grad()
58 | def evaluate(points):
59 | z = []
60 | for _, pnts in enumerate(torch.split(points, 256**3, dim=0)):
61 | z.append(sdf(pnts))
62 | z = torch.cat(z, axis=0)
63 | return z
64 |
65 | # construct point pyramids
66 | points = points.reshape(cropN, cropN, cropN, 3)
67 | points = points.reshape(-1, 3)
68 | pts_sdf = evaluate(points.contiguous())
69 | z = pts_sdf.detach().cpu().numpy()
70 | if not (np.min(z) > level or np.max(z) < level):
71 | z = z.astype(np.float32)
72 | verts, faces, normals, _ = measure.marching_cubes(
73 | volume=z.reshape(cropN, cropN, cropN),
74 | level=level,
75 | spacing=(
76 | (x_max - x_min) / (cropN - 1),
77 | (y_max - y_min) / (cropN - 1),
78 | (z_max - z_min) / (cropN - 1),
79 | ),
80 | )
81 | verts = verts + np.array([x_min, y_min, z_min])
82 | meshcrop = trimesh.Trimesh(verts, faces, normals)
83 | meshes.append(meshcrop)
84 |
85 | print("finished one block")
86 |
87 | combined = trimesh.util.concatenate(meshes)
88 | combined.merge_vertices(digits_vertex=6)
89 |
90 | # inverse contraction and clipping the points range
91 | if inv_contraction is not None:
92 | combined.vertices = inv_contraction(torch.from_numpy(combined.vertices).float().cuda()).cpu().numpy()
93 | combined.vertices = np.clip(combined.vertices, -max_range, max_range)
94 |
95 | return combined
--------------------------------------------------------------------------------
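
A hedged, self-contained sketch of calling marching_cubes_with_contraction on a toy SDF (a radius-0.5 sphere). Note the function asserts that the resolution is a multiple of 512 and moves points to CUDA, so a GPU is required; the output path is hypothetical.

import torch
from utils.mcube_utils import marching_cubes_with_contraction

def sphere_sdf(points):                      # toy signed distance: ||p|| - 0.5
    return torch.linalg.norm(points, dim=-1) - 0.5

mesh = marching_cubes_with_contraction(
    sdf=sphere_sdf,
    resolution=512,                          # must be a multiple of 512 (see the assert above)
    bounding_box_min=(-1.0, -1.0, -1.0),
    bounding_box_max=(1.0, 1.0, 1.0),
    inv_contraction=None,                    # no scene contraction for this toy example
)
mesh.export("sphere.ply")                    # trimesh mesh with the extracted surface
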
/utils/mesh_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2024, ShanghaiTech
3 | # SVIP research group, https://github.com/svip-lab
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact huangbb@shanghaitech.edu.cn
10 | #
11 | import os
12 | from functools import partial
13 | from statistics import mean, stdev
14 |
15 | import cv2
16 | import numpy as np
17 | import open3d as o3d
18 | import torch
19 | from tqdm import tqdm
20 |
21 | from utils.render_utils import save_img_u8
22 |
23 |
24 | def post_process_mesh(mesh, cluster_to_keep=1000):
25 | """
26 | Post-process a mesh to filter out floaters and disconnected parts
27 | """
28 | import copy
29 |     print("post processing the mesh to have {} clusters".format(cluster_to_keep))
30 | mesh_0 = copy.deepcopy(mesh)
31 | with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
32 | triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles())
33 |
34 | triangle_clusters = np.asarray(triangle_clusters)
35 | cluster_n_triangles = np.asarray(cluster_n_triangles)
36 | cluster_area = np.asarray(cluster_area)
37 | n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep]
38 | n_cluster = max(n_cluster, 50) # filter meshes smaller than 50
39 | triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
40 | mesh_0.remove_triangles_by_mask(triangles_to_remove)
41 | mesh_0.remove_unreferenced_vertices()
42 | mesh_0.remove_degenerate_triangles()
43 | print("num vertices raw {}".format(len(mesh.vertices)))
44 | print("num vertices post {}".format(len(mesh_0.vertices)))
45 | return mesh_0
46 |
47 |
48 | def to_cam_open3d(viewpoint_stack):
49 | camera_traj = []
50 | for i, viewpoint_cam in enumerate(viewpoint_stack):
51 | W = viewpoint_cam.image_width
52 | H = viewpoint_cam.image_height
53 | ndc2pix = torch.tensor([
54 | [W / 2, 0, 0, (W - 1) / 2],
55 | [0, H / 2, 0, (H - 1) / 2],
56 | [0, 0, 0, 1]]).float().cuda().T
57 | intrins = (viewpoint_cam.projection_matrix @ ndc2pix)[:3, :3].T
58 | intrinsic = o3d.camera.PinholeCameraIntrinsic(
59 | width=viewpoint_cam.image_width,
60 | height=viewpoint_cam.image_height,
61 | cx=intrins[0, 2].item(),
62 | cy=intrins[1, 2].item(),
63 | fx=intrins[0, 0].item(),
64 | fy=intrins[1, 1].item()
65 | )
66 |
67 | extrinsic = np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy())
68 | camera = o3d.camera.PinholeCameraParameters()
69 | camera.extrinsic = extrinsic
70 | camera.intrinsic = intrinsic
71 | camera_traj.append(camera)
72 |
73 | return camera_traj
74 |
75 |
76 | class GaussianExtractor(object):
77 | def __init__(self, gaussians, render, pipe, bg_color=None, additional_return=True):
78 | """
79 |         a class that extracts attributes from a scene represented by 2DGS
80 | 
81 |         Usage example:
82 |         >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe)
83 |         >>> gaussExtractor.reconstruction(view_points)
84 |         >>> mesh = gaussExtractor.extract_mesh_bounded(...)
85 | """
86 | if bg_color is None:
87 | bg_color = [0, 0, 0]
88 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
89 | self.gaussians = gaussians
90 | self.render = partial(render, pipe=pipe, bg_color=background, additional_return=additional_return)
91 | self._additional_return = additional_return
92 | self.clean()
93 |
94 | @torch.no_grad()
95 | def clean(self):
96 | self.depthmaps = []
97 | # self.alphamaps = []
98 | self.rgbmaps = []
99 | # self.normals = []
100 | # self.depth_normals = []
101 | self.viewpoint_stack = []
102 |
103 | @torch.no_grad()
104 | def reconstruction(self, viewpoint_stack):
105 | """
106 | reconstruct radiance field given cameras
107 | """
108 | self.clean()
109 | self.viewpoint_stack = viewpoint_stack
110 | times = []
111 | if len(self.viewpoint_stack) > 1:
112 | iterator = tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields")
113 | else:
114 | iterator = enumerate(self.viewpoint_stack)
115 |
116 | for i, viewpoint_cam in iterator:
117 | render_pkg = self.render(viewpoint_cam, self.gaussians)
118 | times.append(render_pkg['fps'])
119 | rgb = render_pkg['render']
120 | self.rgbmaps.append(rgb.cpu())
121 | if self._additional_return:
122 | alpha = render_pkg['rend_alpha']
123 | normal = torch.nn.functional.normalize(render_pkg['rend_normal'], dim=0)
124 | depth = render_pkg['surf_depth']
125 | depth_normal = render_pkg['surf_normal']
126 | self.depthmaps.append(depth.cpu())
127 | # self.alphamaps.append(alpha.cpu())
128 | # self.normals.append(normal.cpu())
129 | # self.depth_normals.append(depth_normal.cpu())
130 |
131 | self.times = times
132 | mean_time = mean(times)
133 | std_time = 0
134 | if len(times) > 1:
135 | std_time = stdev(times)
136 | print("FPS:", mean_time, " std:", std_time)
137 | # self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
138 | # self.depthmaps = torch.stack(self.depthmaps, dim=0)
139 | # self.alphamaps = torch.stack(self.alphamaps, dim=0)
140 | # self.depth_normals = torch.stack(self.depth_normals, dim=0)
141 | self.estimate_bounding_sphere()
142 |
143 | return mean_time, std_time
144 |
145 | def estimate_bounding_sphere(self):
146 | """
147 |         Estimate the bounding sphere given the camera poses
148 | """
149 | from utils.render_utils import focus_point_fn
150 | torch.cuda.empty_cache()
151 | c2ws = np.array(
152 | [np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack])
153 | poses = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1])
154 | center = (focus_point_fn(poses))
155 | self.radius = np.linalg.norm(c2ws[:, :3, 3] - center, axis=-1).min()
156 | self.center = torch.from_numpy(center).float().cuda()
157 | print(f"The estimated bounding radius is {self.radius:.2f}")
158 | print(f"Use at least {2.0 * self.radius:.2f} for depth_trunc")
159 |
160 | @torch.no_grad()
161 | def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True):
162 | """
163 | Perform TSDF fusion given a fixed depth range, used in the paper.
164 |
165 | voxel_size: the voxel size of the volume
166 | sdf_trunc: truncation value
167 |         depth_trunc: maximum depth range, should depend on the scene's scale
168 |         mask_backgrond: whether to mask the background, only works when the dataset has masks
169 |
170 | return o3d.mesh
171 | """
172 | print("Running tsdf volume integration ...")
173 | print(f'voxel_size: {voxel_size}')
174 | print(f'sdf_trunc: {sdf_trunc}')
175 |         print(f'depth_trunc: {depth_trunc}')
176 |
177 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
178 | voxel_length=voxel_size,
179 | sdf_trunc=sdf_trunc,
180 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
181 | )
182 |
183 | for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"):
184 | rgb = self.rgbmaps[i]
185 | depth = self.depthmaps[i]
186 |
187 |             # if a mask is provided, use it
188 | if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None):
189 | depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0
190 |
191 | # make open3d rgbd
192 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
193 | o3d.geometry.Image(
194 | np.asarray(np.clip(rgb.permute(1, 2, 0).cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)),
195 | o3d.geometry.Image(np.asarray(depth.permute(1, 2, 0).cpu().numpy(), order="C")),
196 | depth_trunc=depth_trunc, convert_rgb_to_intensity=False,
197 | depth_scale=1.0
198 | )
199 |
200 | volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)
201 |
202 | mesh = volume.extract_triangle_mesh()
203 | return mesh
204 |
205 | @torch.no_grad()
206 | def extract_mesh_unbounded(self, resolution=1024):
207 | """
208 |         Experimental feature: extracting meshes from unbounded scenes, not fully tested across datasets.
209 | return o3d.mesh
210 | """
211 |
212 | def contract(x):
213 | mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None]
214 | return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))
215 |
216 | def uncontract(y):
217 | mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None]
218 | return torch.where(mag < 1, y, (1 / (2 - mag) * (y / mag)))
219 |
220 | def compute_sdf_perframe(i, points, depthmap, rgbmap, viewpoint_cam):
221 | """
222 | compute per frame sdf
223 | """
224 | new_points = torch.cat([points, torch.ones_like(points[..., :1])],
225 | dim=-1) @ viewpoint_cam.full_proj_transform
226 | z = new_points[..., -1:]
227 | pix_coords = (new_points[..., :2] / new_points[..., -1:])
228 | mask_proj = ((pix_coords > -1.) & (pix_coords < 1.) & (z > 0)).all(dim=-1)
229 | sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None],
230 | mode='bilinear', padding_mode='border',
231 | align_corners=True).reshape(-1, 1)
232 | sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear',
233 | padding_mode='border', align_corners=True).reshape(3, -1).T
234 | sdf = (sampled_depth - z)
235 | return sdf, sampled_rgb, mask_proj
236 |
237 | def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False):
238 | """
239 |             Fuse all frames, performing adaptive SDF truncation in the contracted space.
240 | """
241 | if inv_contraction is not None:
242 | mask = torch.linalg.norm(samples, dim=-1) > 1
243 | # adaptive sdf_truncation
244 | sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0])
245 | sdf_trunc[mask] *= 1 / (2 - torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9))
246 | samples = inv_contraction(samples)
247 | else:
248 | sdf_trunc = 5 * voxel_size
249 |
250 | tsdfs = torch.ones_like(samples[:, 0]) * 1
251 | rgbs = torch.zeros((samples.shape[0], 3)).cuda()
252 |
253 | weights = torch.ones_like(samples[:, 0])
254 | for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="TSDF integration progress"):
255 | sdf, rgb, mask_proj = compute_sdf_perframe(i, samples,
256 | depthmap=self.depthmaps[i],
257 | rgbmap=self.rgbmaps[i],
258 | viewpoint_cam=self.viewpoint_stack[i],
259 | )
260 |
261 | # volume integration
262 | sdf = sdf.flatten()
263 | mask_proj = mask_proj & (sdf > -sdf_trunc)
264 | sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj]
265 | w = weights[mask_proj]
266 | wp = w + 1
267 | tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp
268 | rgbs[mask_proj] = (rgbs[mask_proj] * w[:, None] + rgb[mask_proj]) / wp[:, None]
269 | # update weight
270 | weights[mask_proj] = wp
271 |
272 | if return_rgb:
273 | return tsdfs, rgbs
274 |
275 | return tsdfs
276 |
277 | normalize = lambda x: (x - self.center) / self.radius
278 | unnormalize = lambda x: (x * self.radius) + self.center
279 | inv_contraction = lambda x: unnormalize(uncontract(x))
280 |
281 | N = resolution
282 | voxel_size = (self.radius * 2 / N)
283 |         print(f"Computing sdf grid at resolution {N} x {N} x {N}")
284 |         print(f"Using voxel_size {voxel_size}")
285 | sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size)
286 | from utils.mcube_utils import marching_cubes_with_contraction
287 | R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy()
288 | R = np.quantile(R, q=0.95)
289 | R = min(R + 0.01, 1.9)
290 |
291 | mesh = marching_cubes_with_contraction(
292 | sdf=sdf_function,
293 | bounding_box_min=(-R, -R, -R),
294 | bounding_box_max=(R, R, R),
295 | level=0,
296 | resolution=N,
297 | inv_contraction=inv_contraction,
298 | )
299 |
300 | # coloring the mesh
301 | torch.cuda.empty_cache()
302 | mesh = mesh.as_open3d
303 | print("texturing mesh ... ")
304 | _, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None,
305 | voxel_size=voxel_size, return_rgb=True)
306 | mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy())
307 | return mesh
308 |
309 | @torch.no_grad()
310 | def export_image(self, path, export_gt=True, print_fps=False):
311 | render_path = os.path.join(path, "renders")
312 | vis_path = os.path.join(path, "vis")
313 | os.makedirs(render_path, exist_ok=True)
314 | os.makedirs(vis_path, exist_ok=True)
315 | if export_gt:
316 | gts_path = os.path.join(path, "gt")
317 | os.makedirs(gts_path, exist_ok=True)
318 |
319 | for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"):
320 | if export_gt:
321 | gt = viewpoint_cam.original_image[0:3, :, :]
322 | save_img_u8(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
323 |
324 | image = self.rgbmaps[idx]
325 | if print_fps:
326 | fps = '{:4d}'.format(int(self.times[idx]))
327 | image = image.numpy()
328 | image = np.transpose(image, (1, 2, 0)).copy()
329 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX,
330 | 1, (0, 0, 0), 3, 2)
331 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX,
332 | 1, (1, 1, 1), 1, 2)
333 | image = np.transpose(image, (2, 0, 1))
334 | image = torch.tensor(image)
335 | save_img_u8(image, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
336 | if self._additional_return:
337 | depth = self.depthmaps[idx][0]
338 | #save_img_f32(depth, os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff"))
339 | depth[depth > 30] = 30
340 | depth = depth**0.1
341 | depth = 1 - (depth - depth.min()) / (depth.max() - depth.min())
342 | save_img_u8(depth, os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".png"))
343 | #save_img_u8(self.normals[idx] * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png"))
344 | #save_img_u8(self.depth_normals[idx] * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png"))
345 |
--------------------------------------------------------------------------------
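
A hedged end-to-end sketch of how GaussianExtractor is typically driven for bounded TSDF fusion; the gaussians, scene and pipe objects are assumed to have been built elsewhere (e.g. as in render.py), and the output filename is made up.

import open3d as o3d
from gaussian_renderer import render
from utils.mesh_utils import GaussianExtractor, post_process_mesh

# `gaussians`, `scene` and `pipe` are assumed to exist already.
extractor = GaussianExtractor(gaussians, render, pipe, bg_color=[0, 0, 0])
extractor.reconstruction(scene.getTrainCameras())         # render RGB + depth for every view

mesh = extractor.extract_mesh_bounded(voxel_size=0.004,   # TSDF fusion over a fixed depth range
                                      sdf_trunc=0.02,
                                      depth_trunc=3.0)
mesh = post_process_mesh(mesh, cluster_to_keep=1000)      # drop floaters and small clusters
o3d.io.write_triangle_mesh("fused_mesh.ply", mesh)
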
/utils/point_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, cv2
6 | import matplotlib.pyplot as plt
7 | import math
8 |
9 | def depths_to_points(view, depthmap):
10 | c2w = (view.world_view_transform.T).inverse()
11 | W, H = view.image_width, view.image_height
12 | fx = W / (2 * math.tan(view.FoVx / 2.))
13 | fy = H / (2 * math.tan(view.FoVy / 2.))
14 | intrins = torch.tensor(
15 | [[fx, 0., W/2.],
16 | [0., fy, H/2.],
17 | [0., 0., 1.0]]
18 | ).float().cuda()
19 | grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy')
20 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3)
21 | rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T
22 | rays_o = c2w[:3,3]
23 | points = depthmap.reshape(-1, 1) * rays_d + rays_o
24 | return points
25 |
26 | def depth_to_normal(view, depth):
27 | """
28 | view: view camera
29 | depth: depthmap
30 | """
31 | points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3)
32 | output = torch.zeros_like(points)
33 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0)
34 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1)
35 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
36 | output[1:-1, 1:-1, :] = normal_map
37 | return output
--------------------------------------------------------------------------------
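
depth_to_normal above back-projects the depth map into world-space points and takes central differences to get per-pixel normals. A minimal sketch of calling it, where view is assumed to be one of the repository's camera objects (exposing world_view_transform, image_width/height and FoVx/FoVy) on CUDA:

import torch
from utils.point_utils import depth_to_normal

# `view` is an assumed camera object; the depth map here is random, for illustration only.
depth = torch.rand(1, view.image_height, view.image_width, device="cuda")
normals = depth_to_normal(view, depth)        # (H, W, 3), zero on the one-pixel border
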
/utils/reconstruction_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2024, ShanghaiTech
3 | # SVIP research group, https://github.com/svip-lab
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact huangbb@shanghaitech.edu.cn
10 | #
11 | import os
12 | from functools import partial
13 | from statistics import mean, stdev
14 |
15 | import cv2
16 | import numpy as np
17 | import torch
18 | from tqdm import tqdm
19 |
20 | from utils.render_utils import save_img_u8
21 |
22 |
23 | class GaussianExtractor(object):
24 | def __init__(self, gaussians, render, pipe, bg_color=None, additional_return=True):
25 | """
26 |         a class that extracts attributes from a scene represented by 2DGS
27 | 
28 |         Usage example:
29 |         >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe)
30 |         >>> gaussExtractor.reconstruction(view_points)
31 |         >>> gaussExtractor.export_image(path)
32 | """
33 | if bg_color is None:
34 | bg_color = [0, 0, 0]
35 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
36 | self.gaussians = gaussians
37 | self.render = partial(render, pipe=pipe, bg_color=background, additional_return=additional_return)
38 | self._additional_return = additional_return
39 | self.clean()
40 |
41 | @torch.no_grad()
42 | def clean(self):
43 | self.depthmaps = []
44 | self.alphamaps = []
45 | self.rgbmaps = []
46 | self.normals = []
47 | self.depth_normals = []
48 | self.viewpoint_stack = []
49 | self.times = []
50 |
51 | @torch.no_grad()
52 | def reconstruction(self, viewpoint_stack):
53 | """
54 | reconstruct radiance field given cameras
55 | """
56 | self.clean()
57 | self.viewpoint_stack = viewpoint_stack
58 | times = []
59 | if len(self.viewpoint_stack) > 1:
60 | iterator = tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields")
61 | else:
62 | iterator = enumerate(self.viewpoint_stack)
63 |
64 | for i, viewpoint_cam in iterator:
65 | render_pkg = self.render(viewpoint_cam, self.gaussians)
66 | times.append(render_pkg['fps'])
67 | rgb = render_pkg['render']
68 | self.rgbmaps.append(rgb.cpu())
69 | if self._additional_return:
70 | alpha = render_pkg['rend_alpha']
71 | normal = torch.nn.functional.normalize(render_pkg['rend_normal'], dim=0)
72 | depth = render_pkg['surf_depth']
73 | depth_normal = render_pkg['surf_normal']
74 | self.depthmaps.append(depth.cpu())
75 | self.alphamaps.append(alpha.cpu())
76 | self.normals.append(normal.cpu())
77 | self.depth_normals.append(depth_normal.cpu())
78 |
79 | self.times = times
80 | mean_time = mean(times)
81 | std_time = 0
82 | if len(times) > 1:
83 | std_time = stdev(times)
84 | print("FPS:", mean_time, " std:", std_time)
85 | #self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
86 | if self._additional_return:
87 | self.depthmaps = torch.stack(self.depthmaps, dim=0)
88 | self.alphamaps = torch.stack(self.alphamaps, dim=0)
89 | self.depth_normals = torch.stack(self.depth_normals, dim=0)
90 |
91 | return mean_time, std_time
92 |
93 | @torch.no_grad()
94 | def export_image(self, path, export_gt=True, print_fps=False):
95 | render_path = os.path.join(path, "renders")
96 | os.makedirs(render_path, exist_ok=True)
97 | if export_gt:
98 | gts_path = os.path.join(path, "gt")
99 | os.makedirs(gts_path, exist_ok=True)
100 |
101 | for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"):
102 | if export_gt:
103 | gt = viewpoint_cam.original_image[0:3, :, :]
104 | save_img_u8(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
105 |
106 | image = self.rgbmaps[idx]
107 | if print_fps:
108 | fps = '{:4d}'.format(int(self.times[idx]))
109 | image = image.numpy()
110 | image = np.transpose(image, (1, 2, 0)).copy()
111 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX,
112 | 1, (0, 0, 0), 3, 2)
113 | cv2.putText(image, 'FPS: ' + str(fps), (10, 50), cv2.FONT_HERSHEY_SIMPLEX,
114 | 1, (1, 1, 1), 1, 2)
115 | image = np.transpose(image, (2, 0, 1))
116 | image = torch.tensor(image)
117 | save_img_u8(image, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
118 |
--------------------------------------------------------------------------------
/utils/render_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import os
17 | from typing import Tuple
18 |
19 | import mediapy as media
20 | import numpy as np
21 | import torch
22 | import torchvision
23 | from PIL import Image
24 | from tqdm import tqdm
25 |
26 |
27 | def normalize(x: np.ndarray) -> np.ndarray:
28 | """Normalization helper function."""
29 | return x / np.linalg.norm(x)
30 |
31 | def pad_poses(p: np.ndarray) -> np.ndarray:
32 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
33 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
34 | return np.concatenate([p[..., :3, :4], bottom], axis=-2)
35 |
36 |
37 | def unpad_poses(p: np.ndarray) -> np.ndarray:
38 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
39 | return p[..., :3, :4]
40 |
41 |
42 | def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
43 | """Recenter poses around the origin."""
44 | cam2world = average_pose(poses)
45 | transform = np.linalg.inv(pad_poses(cam2world))
46 | poses = transform @ pad_poses(poses)
47 | return unpad_poses(poses), transform
48 |
49 |
50 | def average_pose(poses: np.ndarray) -> np.ndarray:
51 | """New pose using average position, z-axis, and up vector of input poses."""
52 | position = poses[:, :3, 3].mean(0)
53 | z_axis = poses[:, :3, 2].mean(0)
54 | up = poses[:, :3, 1].mean(0)
55 | cam2world = viewmatrix(z_axis, up, position)
56 | return cam2world
57 |
58 | def viewmatrix(lookdir: np.ndarray, up: np.ndarray,
59 | position: np.ndarray) -> np.ndarray:
60 | """Construct lookat view matrix."""
61 | vec2 = normalize(lookdir)
62 | vec0 = normalize(np.cross(up, vec2))
63 | vec1 = normalize(np.cross(vec2, vec0))
64 | m = np.stack([vec0, vec1, vec2, position], axis=1)
65 | return m
66 |
67 | def focus_point_fn(poses: np.ndarray) -> np.ndarray:
68 | """Calculate nearest point to all focal axes in poses."""
69 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
70 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
71 | mt_m = np.transpose(m, [0, 2, 1]) @ m
72 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
73 | return focus_pt
74 |
75 | def transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
76 | """Transforms poses so principal components lie on XYZ axes.
77 |
78 | Args:
79 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
80 |
81 | Returns:
82 | A tuple (poses, transform), with the transformed poses and the applied
83 | camera_to_world transforms.
84 | """
85 | t = poses[:, :3, 3]
86 | t_mean = t.mean(axis=0)
87 | t = t - t_mean
88 |
89 | eigval, eigvec = np.linalg.eig(t.T @ t)
90 | # Sort eigenvectors in order of largest to smallest eigenvalue.
91 | inds = np.argsort(eigval)[::-1]
92 | eigvec = eigvec[:, inds]
93 | rot = eigvec.T
94 | if np.linalg.det(rot) < 0:
95 | rot = np.diag(np.array([1, 1, -1])) @ rot
96 |
97 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
98 | poses_recentered = unpad_poses(transform @ pad_poses(poses))
99 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
100 |
101 | # Flip coordinate system if z component of y-axis is negative
102 | if poses_recentered.mean(axis=0)[2, 1] < 0:
103 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
104 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
105 |
106 | return poses_recentered, transform
107 | # points = np.random.rand(3,100)
108 | # points_h = np.concatenate((points,np.ones_like(points[:1])), axis=0)
109 | # (poses_recentered @ points_h)[0]
110 | # (transform @ pad_poses(poses) @ points_h)[0,:3]
111 | # import pdb; pdb.set_trace()
112 |
113 | # # Just make sure it's it in the [-1, 1]^3 cube
114 | # scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
115 | # poses_recentered[:, :3, 3] *= scale_factor
116 | # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
117 |
118 | # return poses_recentered, transform
119 |
120 | def generate_ellipse_path(poses: np.ndarray,
121 | n_frames: int = 120,
122 | const_speed: bool = True,
123 | z_variation: float = 0.,
124 | z_phase: float = 0.) -> np.ndarray:
125 | """Generate an elliptical render path based on the given poses."""
126 | # Calculate the focal point for the path (cameras point toward this).
127 | center = focus_point_fn(poses)
128 | # Path height sits at z=0 (in middle of zero-mean capture pattern).
129 | offset = np.array([center[0], center[1], 0])
130 |
131 | # Calculate scaling for ellipse axes based on input camera positions.
132 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
133 | # Use ellipse that is symmetric about the focal point in xy.
134 | low = -sc + offset
135 | high = sc + offset
136 | # Optional height variation need not be symmetric
137 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
138 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
139 |
140 | def get_positions(theta):
141 | # Interpolate between bounds with trig functions to get ellipse in x-y.
142 | # Optionally also interpolate in z to change camera height along path.
143 | return np.stack([
144 | low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
145 | low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
146 | z_variation * (z_low[2] + (z_high - z_low)[2] *
147 | (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
148 | ], -1)
149 |
150 | theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
151 | positions = get_positions(theta)
152 |
153 | #if const_speed:
154 |
155 | # # Resample theta angles so that the velocity is closer to constant.
156 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
157 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
158 | # positions = get_positions(theta)
159 |
160 | # Throw away duplicated last position.
161 | positions = positions[:-1]
162 |
163 | # Set path's up vector to axis closest to average of input pose up vectors.
164 | avg_up = poses[:, :3, 1].mean(0)
165 | avg_up = avg_up / np.linalg.norm(avg_up)
166 | ind_up = np.argmax(np.abs(avg_up))
167 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
168 |
169 | return np.stack([viewmatrix(p - center, up, p) for p in positions])
170 |
171 |
172 | def generate_path(viewpoint_cameras, n_frames=480):
173 | c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras])
174 | pose = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1])
175 | pose_recenter, colmap_to_world_transform = transform_poses_pca(pose)
176 |
177 | # generate new poses
178 | new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames)
179 |     # warp back to original scale
180 | new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses)
181 |
182 | traj = []
183 | for c2w in new_poses:
184 | c2w = c2w @ np.diag([1, -1, -1, 1])
185 | cam = copy.deepcopy(viewpoint_cameras[0])
186 | cam.image_height = int(cam.image_height / 2) * 2
187 | cam.image_width = int(cam.image_width / 2) * 2
188 | cam.world_view_transform = torch.from_numpy(np.linalg.inv(c2w).T).float().cuda()
189 | cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0)
190 | cam.camera_center = cam.world_view_transform.inverse()[3, :3]
191 | traj.append(cam)
192 |
193 | return traj
194 |
195 | def load_img(pth: str) -> np.ndarray:
196 | """Load an image and cast to float32."""
197 | with open(pth, 'rb') as f:
198 | image = np.array(Image.open(f), dtype=np.float32)
199 | return image
200 |
201 |
202 | def create_videos(base_dir, input_dir, out_name, num_frames=480):
203 | """Creates videos out of the images saved to disk."""
204 | # Last two parts of checkpoint path are experiment name and scene name.
205 | video_prefix = f'{out_name}'
206 |
207 | zpad = max(5, len(str(num_frames - 1)))
208 | idx_to_str = lambda idx: str(idx).zfill(zpad)
209 |
210 | os.makedirs(base_dir, exist_ok=True)
211 |
212 | # Load one example frame to get image shape and depth range.
213 | rgb_file = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.png')
214 | rgb_frame = load_img(rgb_file)
215 | shape = rgb_frame.shape
216 | print(f'Video shape is {shape[:2]}')
217 |
218 | video_kwargs = {
219 | 'shape': shape[:2],
220 | 'codec': 'h264',
221 | 'fps': 30,
222 | 'crf': 1,
223 | }
224 |
225 | for k in ['color']:
226 | video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
227 | input_format = 'rgb'
228 | file_ext = 'png'
229 |
230 | if k == 'color':
231 | file0 = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.{file_ext}')
232 |
233 | if not os.path.exists(file0):
234 | print(f'Images missing for tag {k}')
235 | continue
236 | print(f'Making video {video_file}...')
237 | with media.VideoWriter(
238 | video_file, **video_kwargs, input_format=input_format) as writer:
239 | for idx in tqdm(range(num_frames)):
240 | img_file = os.path.join(input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}')
241 |
242 | if not os.path.exists(img_file):
243 |                     raise ValueError(f'Image file {img_file} does not exist.')
244 | img = load_img(img_file)
245 | img = img / 255.
246 |
247 | frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
248 | writer.add_image(frame)
249 | idx += 1
250 |
251 | def save_img_u8(img, pth):
252 | """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG."""
253 | torchvision.utils.save_image(img, pth)
254 |
255 | def save_img_f32(depthmap, pth):
256 | """Save an image (probably a depthmap) to disk as a float32 TIFF."""
257 | with open(pth, 'wb') as f:
258 | Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF')
--------------------------------------------------------------------------------
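
A hedged sketch tying the path helpers above together: fit an elliptical fly-through to the training cameras, render each pose elsewhere into a renders/ folder, then encode an mp4. The scene object and the directory names are assumptions.

from utils.render_utils import generate_path, create_videos

# `scene` is assumed to be a loaded Scene; generate_path returns modified camera copies.
traj_cameras = generate_path(scene.getTrainCameras(), n_frames=480)

# ... render each camera in traj_cameras and save the frames as
#     output/traj/renders/00000.png, 00001.png, ... (hypothetical paths) ...

create_videos(base_dir="output/videos",    # where the mp4 is written
              input_dir="output/traj",     # must contain the 'renders' subfolder of PNGs
              out_name="flythrough",
              num_frames=480)
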
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
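
For degree 0, eval_sh reduces to the constant band C0 * sh, i.e. the stored colour minus the 0.5 offset that SH2RGB adds back, so the result is view-independent. A minimal sketch with a made-up colour:

import torch
from utils.sh_utils import RGB2SH, SH2RGB, eval_sh

rgb = torch.tensor([[0.2, 0.5, 0.8]])           # one point, one colour
sh0 = RGB2SH(rgb)                               # DC coefficient, shape (1, 3)

dirs = torch.nn.functional.normalize(torch.randn(1, 3), dim=-1)
out = eval_sh(0, sh0.unsqueeze(-1), dirs)       # sh shaped [..., C, (deg + 1) ** 2] = (1, 3, 1)

# eval_sh returns C0 * sh0 = rgb - 0.5; adding the offset back recovers the colour,
# exactly as SH2RGB does.
assert torch.allclose(out + 0.5, rgb)
assert torch.allclose(SH2RGB(sh0), rgb)
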
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 |     # Creates a directory. Equivalent to using mkdir -p on the command line
18 | try:
19 | makedirs(folder_path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == EEXIST and path.isdir(folder_path):
22 | pass
23 | else:
24 | raise
25 |
26 | def searchForMaxIteration(folder):
27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 | return max(saved_iters)
29 |
--------------------------------------------------------------------------------
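
searchForMaxIteration assumes the folder contains entries named like iteration_<N> and returns the largest N, which is how checkpoint folders are usually laid out; a minimal sketch with a throwaway directory:

import os, tempfile
from utils.system_utils import mkdir_p, searchForMaxIteration

root = tempfile.mkdtemp()
for it in (7000, 30000):
    mkdir_p(os.path.join(root, f"iteration_{it}"))   # mkdir -p style, no error if it exists

print(searchForMaxIteration(root))                   # -> 30000
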
/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from arguments import ModelParams, PipelineParams, get_combined_args
9 | from gaussian_renderer import GaussianModel
10 | from gaussian_renderer import render
11 | from scene import Scene
12 | from utils.general_utils import build_rotation
13 | from utils.reconstruction_utils import GaussianExtractor
14 |
15 | if __name__ == "__main__":
16 | # Set up command line argument parser
17 | parser = ArgumentParser(description="Testing script parameters")
18 | model = ModelParams(parser, sentinel=True)
19 | pipeline = PipelineParams(parser)
20 | parser.add_argument("--iteration", default=-1, type=int)
21 | args = get_combined_args(parser)
22 | print("Rendering " + args.model_path)
23 |
24 | control_panel = cv2.imread("assets/control_panel.png")[..., ::-1].astype(np.float32) / 255.
25 |
26 | dataset, iteration, pipe = model.extract(args), args.iteration, pipeline.extract(args)
27 | gaussians = GaussianModel(dataset.sh_degree, texture_preproc=True)
28 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
29 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
30 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
31 |
32 | train_dir = os.path.join(args.model_path, 'train', "ours_{}".format(scene.loaded_iter))
33 | test_dir = os.path.join(args.model_path, 'test', "ours_{}".format(scene.loaded_iter))
34 | gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color, additional_return=False)
35 |
36 | speed_data = {"points": len(gaussians.get_xyz)}
37 |
38 | idx = 0
39 | cameras = scene.getTestCameras()[idx: idx+1].copy()
40 | frame_num = 0
41 | while True:
42 | mean_time, std_time = gaussExtractor.reconstruction(cameras)
43 | render = gaussExtractor.rgbmaps[0].detach().cpu().numpy()
44 | render = np.transpose(render, (1, 2, 0)).copy()
45 | if frame_num == 0:
46 | scale = render.shape[1] / control_panel.shape[1]
47 | control_panel = cv2.resize(control_panel, None, fx=scale, fy=scale)
48 |
49 | if frame_num > 5:
50 | mean_time = int(mean_time)
51 | cv2.putText(render, 'FPS: ' + str(mean_time), (10, 50), cv2.FONT_HERSHEY_SIMPLEX,
52 | 1, (0, 0, 0), 3, 2)
53 | cv2.putText(render, 'FPS: ' + str(mean_time),(10, 50), cv2.FONT_HERSHEY_SIMPLEX,
54 | 1,(255, 255, 255),1,2)
55 |
56 | render = cv2.vconcat([render, control_panel])
57 | cv2.imshow("Render", render[..., ::-1])
58 | key = cv2.waitKey(-1) & 0b11111111
59 |
60 | speed_t = 0.03
61 | speed_r = speed_t / 2.0
62 | if key == ord("q"):
63 | break
64 | if key == ord("a"):
65 | cameras[0].world_view_transform[3, 0] += speed_t
66 | if key == ord("d"):
67 | cameras[0].world_view_transform[3, 0] -= speed_t
68 | if key == ord("w"):
69 | cameras[0].world_view_transform[3, 2] -= speed_t
70 | if key == ord("s"):
71 | cameras[0].world_view_transform[3, 2] += speed_t
72 | if key == ord("e"):
73 | cameras[0].world_view_transform[3, 1] += speed_t
74 | if key == ord("f"):
75 | cameras[0].world_view_transform[3, 1] -= speed_t
76 |
77 | if key == ord("j"):
78 | R = build_rotation(torch.tensor([[1-speed_r, -speed_r, 0, 0]]).cuda())[0]
79 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R)
80 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R)
81 | if key == ord("u"):
82 | R = build_rotation(torch.tensor([[1-speed_r, speed_r, 0, 0]]).cuda())[0]
83 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R)
84 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R)
85 | if key == ord("k"):
86 | R = build_rotation(torch.tensor([[1-speed_r, 0, speed_r, 0]]).cuda())[0]
87 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R)
88 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R)
89 | if key == ord("h"):
90 | R = build_rotation(torch.tensor([[1-speed_r, 0, -speed_r, 0]]).cuda())[0]
91 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R)
92 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R)
93 | if key == ord("l"):
94 | R = build_rotation(torch.tensor([[1-speed_r, 0, 0, speed_r]]).cuda())[0]
95 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R)
96 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R)
97 | if key == ord("i"):
98 | R = build_rotation(torch.tensor([[1-speed_r, 0, 0, -speed_r]]).cuda())[0]
99 | cameras[0].world_view_transform[:3, :3] = torch.mm(cameras[0].world_view_transform[:3, :3], R)
100 | cameras[0].world_view_transform[3:, :3] = torch.matmul(cameras[0].world_view_transform[3:, :3], R)
101 |
102 | if key == 32:
103 | idx += 1
104 | if idx >= len(scene.getTestCameras()):
105 | idx = 0
106 | cameras = scene.getTestCameras()[idx: idx+1].copy()
107 |
108 | cameras[0].update_proj_matrix()
109 | frame_num += 1
110 |
111 |
--------------------------------------------------------------------------------