├── .gitignore ├── .gitmodules ├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── assets ├── hierarchy_viewer_0.gif ├── logo_graphdeco.png ├── logo_inria.png ├── logo_mpi.png ├── logo_mpi.svg ├── logo_tuwien.svg └── logo_uca.png ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── preprocess ├── auto_reorient.py ├── black_mask.py ├── concat_chunks_info.py ├── copy_file_to_chunks.py ├── database.py ├── fill_database.py ├── generate_chunks.py ├── generate_colmap.py ├── generate_depth.py ├── jz_test_gen_chunk.py ├── make_chunk.py ├── make_chunks_depth_scale.py ├── make_colmap_custom_matcher.py ├── make_colmap_custom_matcher_distance.py ├── make_depth_scale.py ├── make_mask_uint8.py ├── prepare_chunk.py ├── prepare_chunk.slurm ├── read_write_model.py ├── reorient.py ├── simplify_images.py └── transform_colmap.py ├── render_hierarchy.py ├── requirements.txt ├── scene ├── OurAdam.py ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── coarse_train.slurm ├── consolidate.slurm ├── full_train.py └── train_chunk.slurm ├── train_coarse.py ├── train_post.py ├── train_single.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | tensorboard_3d 6 | screenshots 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "SIBR_viewers"] 2 | path = SIBR_viewers 3 | url = https://gitlab.inria.fr/sibr/sibr_core.git 4 | branch = gaussian_code_release_hierarchy 5 | [submodule "submodules/hierarchy-rasterizer"] 6 | path = submodules/hierarchy-rasterizer 7 | url = https://github.com/graphdeco-inria/hierarchy-rasterizer.git 8 | [submodule "submodules/simple-knn"] 9 | path = submodules/simple-knn 10 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 11 | [submodule "submodules/gaussianhierarchy"] 12 | path = submodules/gaussianhierarchy 13 | url = https://github.com/graphdeco-inria/gaussian-hierarchy.git 14 | [submodule "submodules/Depth-Anything-V2"] 15 | path = submodules/Depth-Anything-V2 16 | url = https://github.com/DepthAnything/Depth-Anything-V2.git 17 | [submodule "submodules/DPT"] 18 | path = submodules/DPT 19 | url = https://github.com/isl-org/DPT.git 20 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The code for Hierarchical 3D Gaussians is an enhancement of the original 3DGS codebase; the license of the original 3DGS code [License3DGS](https://github.com/graphdeco-inria/gaussian-splatting/blob/main/LICENSE.md) also applies to the parts of the codebase in this repository that build on it. 2 | 3 | 4 | Hierarchical-3d-Gaussians License 5 | =========================== 6 | 7 | **Inria** holds all the ownership rights on the *Software* named **hierarchical-3d-gaussians**. 8 | The *Software* is in the process of being registered with the Agence pour la Protection des 9 | Programmes (APP). 10 | 11 | The *Software* is still being developed by the *Licensor*. 
12 | 13 | *Licensor*'s goal is to allow the research community to use, test and evaluate 14 | the *Software*. 15 | 16 | ## 1. Definitions 17 | 18 | *Licensee* means any person or entity that uses the *Software* and distributes 19 | its *Work*. 20 | 21 | *Licensor* means the owners of the *Software*, i.e., Inria and TUW. 22 | 23 | *Software* means the original work of authorship made available under this 24 | License, i.e., hierarchical-3d-gaussians. 25 | 26 | *Work* means the *Software* and any additions to or derivative works of the 27 | *Software* that are made available under this License. 28 | 29 | 30 | ## 2. Purpose 31 | This license is intended to define the rights granted to the *Licensee* by 32 | Licensors under the *Software*. 33 | 34 | ## 3. Rights granted 35 | 36 | For the above reasons Licensors have decided to distribute the *Software*. 37 | Licensors grant non-exclusive rights to use the *Software* for research purposes 38 | to research users (both academic and industrial), free of charge, without right 39 | to sublicense. The *Software* may be used "non-commercially", i.e., for research 40 | and/or evaluation purposes only. 41 | 42 | Subject to the terms and conditions of this License, you are granted a 43 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 44 | publicly display, publicly perform and distribute its *Work* and any resulting 45 | derivative works in any form. 46 | 47 | ## 4. Limitations 48 | 49 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 50 | so under this License, (b) you include a complete copy of this License with 51 | your distribution, and (c) you retain without modification any copyright, 52 | patent, trademark, or attribution notices that are present in the *Work*. 53 | 54 | **4.2 Derivative Works.** You may specify that additional or different terms apply 55 | to the use, reproduction, and distribution of your derivative works of the *Work* 56 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 57 | Section 3 applies to your derivative works, and (b) you identify the specific 58 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 59 | this License (including the redistribution requirements in Section 4.1) will 60 | continue to apply to the *Work* itself. 61 | 62 | **4.3** Any other use without the prior consent of Licensors is prohibited. Research 63 | users explicitly acknowledge having received from Licensors all information 64 | allowing them to assess the adequacy of the *Software* to their needs and 65 | to undertake all necessary precautions for its execution and use. 66 | 67 | **4.4** The *Software* is provided both as a compiled library file and as source 68 | code. In case of using the *Software* for a publication or other results obtained 69 | through the use of the *Software*, users are strongly encouraged to cite the 70 | corresponding publications as explained in the documentation of the *Software*. 71 | 72 | ## 5. Disclaimer 73 | 74 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 75 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 76 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr. ANY SUCH ACTION WILL 77 | CONSTITUTE A FORGERY.
THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 78 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARD TO COMMERCIAL 79 | USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 80 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT SHALL INRIA OR THE 81 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 82 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 83 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 84 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 85 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 86 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 87 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._exp_name = "" 53 | self._images = "images" 54 | self._alpha_masks = "" 55 | self._depths = "" 56 | self._resolution = -1 57 | self._white_background = False 58 | self.train_test_exp = False # Include the left half of the test images in the train set to optimize exposures 59 | self.data_device = "cuda" 60 | self.eval = False 61 | self.skip_scale_big_gauss = False 62 | self.hierarchy = "" 63 | self.pretrained = "" 64 | self.skybox_num = 0 65 | self.scaffold_file = "" 66 | self.bounds_file = "" 67 | self.skybox_locked = False 68 | super().__init__(parser, "Loading Parameters", sentinel) 69 | 70 | def extract(self, args): 71 | g = super().extract(args) 72 | g.source_path = os.path.abspath(g.source_path) 73 | return g 74 | 75 | class PipelineParams(ParamGroup): 76 | def __init__(self, parser): 77 | self.convert_SHs_python = False 78 | self.compute_cov3D_python = False 79 | self.debug = False 80 | super().__init__(parser, "Pipeline Parameters") 81 |
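# Note on ParamGroup above: attributes declared with a leading underscore (e.g.
# "_source_path") additionally get a one-letter shorthand flag ("-s"), booleans
# become store_true flags, and fill_none=True registers every default as None so
# that get_combined_args() below can tell which values were actually passed on
# the command line.
82 |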
class OptimizationParams(ParamGroup): 83 | def __init__(self, parser): 84 | self.iterations = 30_000 85 | self.position_lr_init = 0.00002 86 | self.position_lr_final = 0.0000002 87 | self.position_lr_delay_mult = 0.01 88 | self.position_lr_max_steps = 30_000 89 | self.feature_lr = 0.0025 90 | self.opacity_lr = 0.05 91 | self.scaling_lr = 0.005 92 | self.rotation_lr = 0.001 93 | self.exposure_lr_init = 0.001 94 | self.exposure_lr_final = 0.0001 95 | self.exposure_lr_delay_steps = 5000 96 | self.exposure_lr_delay_mult = 0.001 97 | self.percent_dense = 0.0001 98 | self.lambda_dssim = 0.2 99 | self.densification_interval = 300 100 | self.opacity_reset_interval = 3000 101 | self.densify_from_iter = 500 102 | self.densify_until_iter = 15_000 103 | self.densify_grad_threshold = 0.015 104 | self.depth_l1_weight_init = 1.0 105 | self.depth_l1_weight_final = 0.01 106 | super().__init__(parser, "Optimization Parameters") 107 | 108 | def get_combined_args(parser: ArgumentParser): 109 | cmdline_string = sys.argv[1:] 110 | cfgfile_string = "Namespace()" 111 | args_cmdline = parser.parse_args(cmdline_string) 112 | 113 | try: 114 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 115 | print("Looking for config file in", cfgfilepath) 116 | with open(cfgfilepath) as cfg_file: 117 | print("Config file found: {}".format(cfgfilepath)) 118 | cfgfile_string = cfg_file.read() 119 | except (TypeError, FileNotFoundError): 120 | # no model path given or no cfg_args on disk: fall back to command-line values only 121 | print("Config file not found") 122 | args_cfgfile = eval(cfgfile_string) 123 | 124 | merged_dict = vars(args_cfgfile).copy() 125 | for k, v in vars(args_cmdline).items(): 126 | if v is not None: 127 | merged_dict[k] = v 128 | return Namespace(**merged_dict) 129 | -------------------------------------------------------------------------------- /assets/hierarchy_viewer_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphdeco-inria/hierarchical-3d-gaussians/f2e849ebccaa5cd963b86cea180a05262867786a/assets/hierarchy_viewer_0.gif -------------------------------------------------------------------------------- /assets/logo_graphdeco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphdeco-inria/hierarchical-3d-gaussians/f2e849ebccaa5cd963b86cea180a05262867786a/assets/logo_graphdeco.png -------------------------------------------------------------------------------- /assets/logo_inria.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphdeco-inria/hierarchical-3d-gaussians/f2e849ebccaa5cd963b86cea180a05262867786a/assets/logo_inria.png -------------------------------------------------------------------------------- /assets/logo_mpi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphdeco-inria/hierarchical-3d-gaussians/f2e849ebccaa5cd963b86cea180a05262867786a/assets/logo_mpi.png -------------------------------------------------------------------------------- /assets/logo_tuwien.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/logo_uca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphdeco-inria/hierarchical-3d-gaussians/f2e849ebccaa5cd963b86cea180a05262867786a/assets/logo_uca.png
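The parameter groups above are consumed by the training and rendering entry points. A minimal sketch of the usual wiring (illustrative only; the exact call sites live in the train_*.py scripts, and the argument handling there may differ):

from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)         # registers --source_path/-s, --model_path/-m, ...
op = OptimizationParams(parser)  # registers --iterations, --position_lr_init, ...
pp = PipelineParams(parser)      # registers --convert_SHs_python, --debug, ...
args = parser.parse_args()

# Each group extracts only the arguments it declared:
dataset, opt, pipe = lp.extract(args), op.extract(args), pp.extract(args)
print(dataset.source_path, opt.iterations, pipe.debug)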
-------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes is not None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return 
output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /preprocess/auto_reorient.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import argparse 14 | from read_write_model import * 15 | import torch 16 | 17 | import os, time 18 | from scipy import spatial 19 | 20 | def fit_plane_least_squares(points): 21 | # Augment the point cloud with a column of ones 22 | A = np.c_[points[:, 0], points[:, 1], np.ones(points.shape[0])] 23 | B = points[:, 2] 24 | 25 | # Solve the least squares problem A * [a, b, c].T = B to get the plane equation z = a*x + b*y + c 26 | coefficients, _, _, _ = np.linalg.lstsq(A, B, rcond=None) 27 | 28 | # Plane coefficients: z = a*x + b*y + c 29 | a, b, c = coefficients 30 | 31 | # The normal vector is [a, b, -1] 32 | normal_vector = np.array([a, b, -1]) 33 | normal_vector /= np.linalg.norm(normal_vector) # Normalize the normal vector 34 | 35 | # An in-plane vector can be any vector orthogonal to the normal; one simple choice is:
36 | in_plane_vector = np.cross(normal_vector, np.array([0, 0, 1])) 37 | if np.linalg.norm(in_plane_vector) == 0: 38 | in_plane_vector = np.cross(normal_vector, np.array([0, 1, 0])) 39 | in_plane_vector /= np.linalg.norm(in_plane_vector) # Normalize the in-plane vector 40 | 41 | return normal_vector, in_plane_vector, np.mean(points, axis=0) 42 | 43 | def rotate_camera(qvec, tvec, rot_matrix, upscale): 44 | # Assuming cameras have 'T' (translation) field 45 | 46 | R = qvec2rotmat(qvec) 47 | T = np.array(tvec) 48 | 49 | Rt = np.zeros((4, 4)) 50 | Rt[:3, :3] = R 51 | Rt[:3, 3] = T 52 | Rt[3, 3] = 1.0 53 | 54 | C2W = np.linalg.inv(Rt) 55 | cam_center = np.copy(C2W[:3, 3]) 56 | cam_rot_orig = np.copy(C2W[:3, :3]) 57 | cam_center = np.matmul(cam_center, rot_matrix) 58 | cam_rot = np.linalg.inv(rot_matrix) @ cam_rot_orig 59 | C2W[:3, 3] = upscale * cam_center 60 | C2W[:3, :3] = cam_rot 61 | Rt = np.linalg.inv(C2W) 62 | new_pos = Rt[:3, 3] 63 | new_rot = rotmat2qvec(Rt[:3, :3]) 64 | 65 | # R_test = qvec2rotmat(new_rots[-1]) 66 | # T_test = np.array(new_poss[-1]) 67 | # Rttest = np.zeros((4, 4)) 68 | # Rttest[:3, :3] = R_test 69 | # Rttest[:3, 3] = T_test 70 | # Rttest[3, 3] = 1.0 71 | # C2Wtest = np.linalg.inv(Rttest) 72 | 73 | return new_pos, new_rot 74 | 75 | if __name__ == '__main__': 76 | 77 | parser = argparse.ArgumentParser(description='Automatically reorient colmap') 78 | 79 | # Add command-line argument(s) 80 | parser.add_argument('--input_path', type=str, help='Path to input colmap dir', required=True) 81 | parser.add_argument('--output_path', type=str, help='Path to output colmap dir', required=True) 82 | parser.add_argument('--upscale', type=float, help='Upscaling factor', default=0) 83 | parser.add_argument('--target_med_dist', type=float, default=20) 84 | parser.add_argument('--model_type', type=str, help='Specify which file format to use when processing colmap files (txt or bin)', choices=['bin','txt'], default='bin') 85 | 86 | args = parser.parse_args() 87 | 88 | 89 | # Read colmap cameras, images and points 90 | start_time = time.time() 91 | cameras, images_metas_in, points3d_in = read_model(args.input_path, ext=f".{args.model_type}") 92 | end_time = time.time() 93 | print(f"{len(images_metas_in)} images read in {end_time - start_time} seconds.") 94 | 95 | if args.upscale != 0: 96 | upscale = args.upscale 97 | print("manual upscale") 98 | else: 99 | # compute upscale factor 100 | median_distances = [] 101 | for key in images_metas_in: 102 | image_meta = images_metas_in[key] 103 | cam_center = -qvec2rotmat(image_meta.qvec).astype(np.float32).T @ image_meta.tvec.astype(np.float32) 104 | 105 | median_distances.extend([ 106 | np.linalg.norm(points3d_in[pt_idx].xyz - cam_center) for pt_idx in image_meta.point3D_ids if pt_idx != -1 107 | ]) 108 | 109 | median_distance = np.median(np.array(median_distances)) 110 | upscale = (args.target_med_dist / median_distance) 111 | 112 | 113 | cam_centers = np.array([ 114 | -qvec2rotmat(images_metas_in[key].qvec).T @ images_metas_in[key].tvec 115 | for key in images_metas_in 116 | ]) 117 | 118 | up, _, _ = fit_plane_least_squares(cam_centers) 119 | 120 | # the two cameras which are furthest apart will occur as vertices of the convex hull 121 | candidates = cam_centers[spatial.ConvexHull(cam_centers).vertices] 122 | 123 | # get distances between each pair of cameras 124 | dist_mat = spatial.distance_matrix(candidates, candidates) 125 | 126 | # get indices of cameras that are furthest apart 127 | i, j = np.unravel_index(dist_mat.argmax(),
dist_mat.shape) 128 | right = candidates[i] - candidates[j] 129 | right /= np.linalg.norm(right) 130 | 131 | up = torch.from_numpy(up).double() 132 | right = torch.from_numpy(right).double() 133 | 134 | forward = torch.cross(up, right) 135 | forward /= torch.norm(forward, p=2) 136 | 137 | right = torch.cross(forward, up) 138 | right /= torch.norm(right, p=2) 139 | 140 | # Stack the target axes as columns to form the rotation matrix 141 | rotation_matrix = torch.stack([right, forward, up], dim=1) 142 | 143 | 144 | positions = [] 145 | print("Doing points") 146 | for key in points3d_in: 147 | positions.append(points3d_in[key].xyz) 148 | 149 | positions = torch.from_numpy(np.array(positions)) 150 | 151 | # Perform the rotation by matrix multiplication 152 | rotated_points = upscale * torch.matmul(positions, rotation_matrix) 153 | 154 | 155 | 156 | points3d_out = {} 157 | for key, rotated in zip(points3d_in, rotated_points): 158 | point3d_in = points3d_in[key] 159 | points3d_out[key] = Point3D( 160 | id=point3d_in.id, 161 | xyz=rotated, 162 | rgb=point3d_in.rgb, 163 | error=point3d_in.error, 164 | image_ids=point3d_in.image_ids, 165 | point2D_idxs=point3d_in.point2D_idxs, 166 | ) 167 | 168 | print("Doing images") 169 | images_metas_out = {} 170 | for key in images_metas_in: 171 | image_meta_in = images_metas_in[key] 172 | new_pos, new_rot = rotate_camera(image_meta_in.qvec, image_meta_in.tvec, rotation_matrix.double().numpy(), upscale) 173 | 174 | images_metas_out[key] = Image( 175 | id=image_meta_in.id, 176 | qvec=new_rot, 177 | tvec=new_pos, 178 | camera_id=image_meta_in.camera_id, 179 | name=image_meta_in.name, 180 | xys=image_meta_in.xys, 181 | point3D_ids=image_meta_in.point3D_ids, 182 | ) 183 | 184 | if not os.path.isdir(args.output_path): 185 | os.makedirs(args.output_path) 186 | write_model(cameras, images_metas_out, points3d_out, args.output_path, f".{args.model_type}") 187 | 188 | global_end = time.time() 189 | 190 | print(f"reorient script took {global_end - start_time} seconds ({args.model_type} file processed).") 191 | -------------------------------------------------------------------------------- /preprocess/black_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import cv2 4 | import numpy as np 5 | from joblib import delayed, Parallel 6 | import argparse 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--project_dir') 11 | args = parser.parse_args() 12 | 13 | images_dir = os.path.join(args.project_dir, "camera_calibration/rectified/images") 14 | masks_dir = os.path.join(args.project_dir, "camera_calibration/rectified/masks") 15 | 16 | folders = os.listdir(images_dir) 17 | if "jpg" in folders[0]: 18 | all_img_names = folders 19 | else: 20 | all_img_names = [] 21 | for folder in folders: 22 | img_names = os.listdir(f"{images_dir}/{folder}") 23 | img_names = [f"{folder}/{img_name}" for img_name in img_names] 24 | all_img_names += img_names 25 | 26 | def split_mask(name): 27 | img = cv2.imread(f"{images_dir}/{name}", cv2.IMREAD_UNCHANGED) 28 | mask = cv2.imread(f"{masks_dir}/{name[:-4]}.png", cv2.IMREAD_UNCHANGED) 29 | mask = cv2.dilate(mask, np.ones([5, 5])) 30 | img[mask == 0] = 0 31 | cv2.imwrite(f"{images_dir}/{name}", img, [int(cv2.IMWRITE_JPEG_QUALITY), 95]) 32 | 33 | Parallel(n_jobs=-1, backend="threading")( 34 | delayed(split_mask)(name) for name in all_img_names 35 | ) 36 | 
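black_mask.py above expects the rectified project layout written by generate_colmap.py (camera_calibration/rectified/images and .../masks). A hypothetical standalone invocation, mirroring the subprocess style the other preprocess scripts use (the project path is a placeholder):

import subprocess

subprocess.run([
    "python", "preprocess/black_mask.py",
    "--project_dir", "/data/my_scene",  # placeholder project directory
], check=True)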
-------------------------------------------------------------------------------- /preprocess/concat_chunks_info.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import argparse 13 | import os 14 | 15 | # if __name__ == 'main': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--base_dir', required=True, help="Chunks folder") 18 | parser.add_argument('--dest_dir', required=True, help="Folder to which chunks.txt file will be written") 19 | args = parser.parse_args() 20 | 21 | chunks = os.listdir(args.base_dir) 22 | 23 | chunks_data = [] 24 | for chunk in chunks: 25 | center_file_path = os.path.join(args.base_dir, chunk + "/center.txt") 26 | extents_file_path = os.path.join(args.base_dir, chunk + "/extent.txt") 27 | 28 | chunk = { 29 | "name":chunk, 30 | "center": [0,0,0], 31 | "extent": [0,0,0] 32 | } 33 | 34 | try: 35 | with open(center_file_path, 'r') as file: 36 | content = file.read() 37 | chunk["center"] = content.split(" ") 38 | except FileNotFoundError: 39 | print(f"File not found: {center_file_path}") 40 | 41 | try: 42 | with open(extents_file_path, 'r') as file: 43 | content = file.read() 44 | chunk["extent"] = content.split(" ") 45 | except FileNotFoundError: 46 | print(f"File not found: {extents_file_path}") 47 | 48 | chunks_data.append(chunk) 49 | 50 | def write_chunks(data, output_directory): 51 | file_path = os.path.join(output_directory, "chunks.txt") 52 | try: 53 | with open(file_path, 'w') as file: 54 | ind = 0 55 | for chunk in data: 56 | line = chunk['name'] + " " + ' '.join(map(str, chunk['center'])) + " " +' '.join(map(str, chunk['extent'])) + "\n" 57 | 58 | if ind == len(data)-1: 59 | line = line[:-1] 60 | 61 | # Write content to the file 62 | file.write(line) 63 | ind += 1 64 | print(f"Content written to {file_path}") 65 | 66 | except IOError: 67 | print(f"Error writing to {file_path}") 68 | 69 | write_chunks(chunks_data, args.dest_dir) -------------------------------------------------------------------------------- /preprocess/copy_file_to_chunks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--file_path', required=True, help="File to copy") 7 | parser.add_argument('--chunks_path', required=True, help="Path containing folders 0_0/, ...") 8 | parser.add_argument('--out_subdir', default="sparse/0", help="Copy file_path to chunks_path/x_y/out_subdir/") 9 | args = parser.parse_args() 10 | 11 | chunks = os.listdir(args.chunks_path) 12 | 13 | for chunk in chunks: 14 | shutil.copy(args.file_path, os.path.join(args.chunks_path, chunk, args.out_subdir)) 15 | -------------------------------------------------------------------------------- /preprocess/fill_database.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import database 3 | from read_write_model import read_model, CAMERA_MODEL_NAMES 4 | import os 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--in_dir', required=True) 9 | parser.add_argument('--database_path', required=True) 
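# fill_database.py rebuilds a fresh COLMAP database.db from an existing .bin
# model: any database already at --database_path is removed, then the model's
# cameras and images are re-registered under their original ids so that COLMAP
# matching/triangulation can run against the known calibration.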
10 | args = parser.parse_args() 11 | 12 | if os.path.exists(args.database_path): 13 | os.remove(args.database_path) 14 | 15 | cam_intrinsics, images_metas, _ = read_model(args.in_dir, ".bin") 16 | db = database.COLMAPDatabase.connect(args.database_path) 17 | db.create_tables() 18 | 19 | for key in cam_intrinsics: 20 | cam = cam_intrinsics[key] 21 | db.add_camera(CAMERA_MODEL_NAMES[cam.model].model_id, cam.width, cam.height, cam.params, camera_id=key) 22 | 23 | for key in images_metas: 24 | image_meta = images_metas[key] 25 | db.add_image(image_meta.name, image_meta.camera_id, image_id=key) 26 | 27 | db.commit() 28 | # shutil.copy(f"{args.in_dir}/cameras.txt", f"{args.out_dir}/cameras.txt") 29 | 30 | print(0) -------------------------------------------------------------------------------- /preprocess/generate_chunks.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os, sys 13 | import subprocess 14 | import argparse 15 | import time, platform 16 | 17 | def submit_job(slurm_args): 18 | """Submit a job using sbatch and return the job ID.""" 19 | try: 20 | result = subprocess.run(slurm_args, capture_output=True, text=True, check=True)  # check=True makes a failed sbatch raise CalledProcessError; text=True decodes stdout 21 | except subprocess.CalledProcessError as e: 22 | print(f"Error when submitting a job: {e}") 23 | sys.exit(1) 24 | 25 | job = result.stdout.strip().split()[-1] 26 | print(f"submitted job {job}") 27 | return job 28 | 29 | def is_job_finished(job): 30 | """Check if the job has finished using sacct.""" 31 | result = subprocess.run(['sacct', '-j', job, '--format=State', '--noheader', '--parsable2'], capture_output=True, text=True) 32 | 33 | job_state = result.stdout.split('\n')[0] 34 | return job_state if job_state in {'COMPLETED', 'FAILED', 'CANCELLED'} else "" 35 | 36 | def setup_dirs(images, colmap, chunks, project): 37 | images_dir = os.path.join(project, "camera_calibration", "rectified", "images") if images == "" else images 38 | colmap_dir = os.path.join(project, "camera_calibration", "aligned") if colmap == "" else colmap 39 | chunks_dir = os.path.join(project, "camera_calibration") if chunks == "" else chunks 40 | 41 | return images_dir, colmap_dir, chunks_dir 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--project_dir', required=True, help="images, colmap and chunks paths do not have to be set if you generated the colmap using the generate_colmap script.") 46 | parser.add_argument('--images_dir', default="") 47 | parser.add_argument('--global_colmap_dir', default="") 48 | parser.add_argument('--chunks_dir', default="") 49 | parser.add_argument('--use_slurm', action="store_true", default=False) 50 | parser.add_argument('--skip_bundle_adjustment', action="store_true", default=False) 51 | parser.add_argument('--n_jobs', type=int, default=8, help="Run per chunk COLMAP in parallel on the same machine. Does not handle multi GPU systems.
--use_slurm overrides this.") 52 | args = parser.parse_args() 53 | 54 | images_dir, colmap_dir, chunks_dir = setup_dirs( 55 | args.images_dir, 56 | args.global_colmap_dir, args.chunks_dir, 57 | args.project_dir 58 | ) 59 | 60 | if args.use_slurm: 61 | slurm_args = [ 62 | "sbatch" 63 | ] 64 | submitted_jobs = [] 65 | 66 | colmap_exe = "colmap.bat" if platform.system() == "Windows" else "colmap" 67 | start_time = time.time() 68 | 69 | ## First create raw_chunks, each chunk has its own colmap. 70 | print(f"chunking colmap from {colmap_dir} to {args.chunks_dir}/raw_chunks") 71 | make_chunk_args = [ 72 | "python", f"preprocess/make_chunk.py", 73 | "--base_dir", os.path.join(colmap_dir, "sparse", "0"), 74 | "--images_dir", f"{images_dir}", 75 | "--output_path", f"{chunks_dir}/raw_chunks", 76 | ] 77 | try: 78 | subprocess.run(make_chunk_args, check=True) 79 | except subprocess.CalledProcessError as e: 80 | print(f"Error executing make_chunk.py: {e}") 81 | sys.exit(1) 82 | 83 | ## Then we refine chunks with 2 rounds of bundle adjustment/triangulation 84 | print("Starting per chunk triangulation and bundle adjustment (if required)") 85 | n_processed = 0 86 | chunk_names = os.listdir(os.path.join(chunks_dir, "raw_chunks")) 87 | for chunk_name in chunk_names: 88 | in_dir = os.path.join(chunks_dir, "raw_chunks", chunk_name) 89 | out_dir = os.path.join(chunks_dir, "chunks", chunk_name) 90 | 91 | if args.use_slurm: 92 | # Process chunks in parallel 93 | job = submit_job(slurm_args + [ 94 | f"--error={in_dir}/log.err", f"--output={in_dir}/log.out", 95 | "preprocess/prepare_chunk.slurm", in_dir, out_dir, images_dir, 96 | os.path.dirname(os.path.realpath(__file__)) 97 | ]) 98 | submitted_jobs.append(job) 99 | else: 100 | try: 101 | if len(submitted_jobs) >= args.n_jobs: 102 | submitted_jobs.pop(0).communicate() 103 | intermediate_dir = os.path.join(in_dir, "bundle_adjustment") 104 | if os.path.exists(intermediate_dir): 105 | print(f"{intermediate_dir} exists!
Per chunk triangulation might crash!") 106 | prepare_chunk_args = [ 107 | "python", f"preprocess/prepare_chunk.py", 108 | "--raw_chunk", in_dir, "--out_chunk", out_dir, 109 | "--images_dir", images_dir 110 | ] 111 | if args.skip_bundle_adjustment: 112 | prepare_chunk_args.append("--skip_bundle_adjustment") 113 | job = subprocess.Popen( 114 | prepare_chunk_args, 115 | stderr=open(f"{in_dir}/log.err", 'w'), 116 | stdout=open(f"{in_dir}/log.out", 'w'), 117 | ) 118 | submitted_jobs.append(job) 119 | n_processed += 1 120 | print(f"Launched triangulation for [{n_processed} / {len(chunk_names)} chunks].") 121 | print(f"Logs in {in_dir}/log.err (or .out)") 122 | except subprocess.CalledProcessError as e: 123 | print(f"Error executing prepare_chunk.py: {e}") 124 | sys.exit(1) 125 | 126 | 127 | if args.use_slurm: 128 | # Check the status of all jobs every 10 sec 129 | all_finished = False 130 | all_status = [] 131 | last_count = 0 132 | print(f"Waiting for chunks processed in parallel to be done ...") 133 | 134 | while not all_finished: 135 | # print("Checking status of all jobs...") 136 | all_status = [is_job_finished(id) for id in submitted_jobs if is_job_finished(id) != ""] 137 | if last_count != all_status.count("COMPLETED"): 138 | last_count = all_status.count("COMPLETED") 139 | print(f"processed [{last_count} / {len(chunk_names)} chunks].") 140 | 141 | all_finished = len(all_status) == len(submitted_jobs) 142 | 143 | if not all_finished: 144 | time.sleep(10) # Wait before checking again 145 | 146 | if not all(status == "COMPLETED" for status in all_status): 147 | print("At least one job failed or was cancelled, check the error logs.") 148 | else: 149 | for job in submitted_jobs: 150 | job.communicate() 151 | 152 | # create chunks.txt file that concatenates all chunks center.txt and extent.txt files 153 | try: 154 | subprocess.run([ 155 | "python", "preprocess/concat_chunks_info.py", 156 | "--base_dir", os.path.join(chunks_dir, "chunks"), 157 | "--dest_dir", colmap_dir 158 | ], check=True) 159 | n_processed += 1 160 | except subprocess.CalledProcessError as e: 161 | print(f"Error executing concat_chunks_info.py: {e}") 162 | sys.exit(1) 163 | 164 | end_time = time.time() 165 | print(f"chunks successfully prepared in {(end_time - start_time)/60.0} minutes.") 166 | 167 | -------------------------------------------------------------------------------- /preprocess/generate_colmap.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os, sys, shutil 13 | import subprocess 14 | import argparse 15 | from read_write_model import read_images_binary, write_images_binary, Image 16 | import time, platform 17 | 18 | def replace_images_by_masks(images_file, out_file): 19 | """Rename .jpg to .png in the colmap images.bin so masks can be processed the same way as images.""" 20 | images_metas = read_images_binary(images_file) 21 | out_images_metas = {} 22 | for key in images_metas: 23 | in_image_meta = images_metas[key] 24 | out_images_metas[key] = Image( 25 | id=key, 26 | qvec=in_image_meta.qvec, 27 | tvec=in_image_meta.tvec, 28 | camera_id=in_image_meta.camera_id, 29 | name=in_image_meta.name[:-3]+"png", 30 | xys=in_image_meta.xys, 31 | point3D_ids=in_image_meta.point3D_ids, 32 | ) 33 | 34 | write_images_binary(out_images_metas, out_file) 35 | 36 | def setup_dirs(project_dir): 37 | """Create the directories that will be required.""" 38 | if not os.path.exists(project_dir): 39 | print("creating project dir.") 40 | os.makedirs(project_dir) 41 | 42 | if not os.path.exists(os.path.join(project_dir, "camera_calibration/aligned")): 43 | os.makedirs(os.path.join(project_dir, "camera_calibration/aligned/sparse/0")) 44 | 45 | if not os.path.exists(os.path.join(project_dir, "camera_calibration/rectified")): 46 | os.makedirs(os.path.join(project_dir, "camera_calibration/rectified")) 47 | 48 | if not os.path.exists(os.path.join(project_dir, "camera_calibration/unrectified")): 49 | os.makedirs(os.path.join(project_dir, "camera_calibration/unrectified")) 50 | os.makedirs(os.path.join(project_dir, "camera_calibration/unrectified", "sparse")) 51 | 52 | if not os.path.exists(os.path.join(project_dir, "camera_calibration/unrectified", "sparse")): 53 | os.makedirs(os.path.join(project_dir, "camera_calibration/unrectified", "sparse")) 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--project_dir', type=str, required=True) 58 | parser.add_argument('--images_dir', default="", help="Will be set to project_dir/inputs/images if not set") 59 | parser.add_argument('--masks_dir', default="", help="Will be set to project_dir/inputs/masks if it exists and is not set") 60 | args = parser.parse_args() 61 | 62 | if args.images_dir == "": 63 | args.images_dir = os.path.join(args.project_dir, "inputs/images") 64 | if args.masks_dir == "": 65 | args.masks_dir = os.path.join(args.project_dir, "inputs/masks") 66 | args.masks_dir = args.masks_dir if os.path.exists(args.masks_dir) else "" 67 | 68 | colmap_exe = "colmap.bat" if platform.system() == "Windows" else "colmap" 69 | start_time = time.time() 70 | 71 | print(f"Project will be built in {args.project_dir}; base images are in {args.images_dir}.") 72 | 73 | setup_dirs(args.project_dir) 74 | 75 | ## Feature extraction and matching, then mapper to generate the colmap model.
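# The calibration pipeline below chains five steps: COLMAP feature_extractor
# (one shared OPENCV camera model), make_colmap_custom_matcher.py to build a
# match list, COLMAP matches_importer, COLMAP hierarchical_mapper to produce
# the sparse SfM model, and COLMAP image_undistorter to rectify the images.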
76 | print("extracting features ...") 77 | colmap_feature_extractor_args = [ 78 | colmap_exe, "feature_extractor", 79 | "--database_path", f"{args.project_dir}/camera_calibration/unrectified/database.db", 80 | "--image_path", f"{args.images_dir}", 81 | "--ImageReader.single_camera", "1", 82 | "--ImageReader.default_focal_length_factor", "0.5", 83 | "--ImageReader.camera_model", "OPENCV", 84 | ] 85 | 86 | try: 87 | subprocess.run(colmap_feature_extractor_args, check=True) 88 | except subprocess.CalledProcessError as e: 89 | print(f"Error executing colmap feature_extractor: {e}") 90 | sys.exit(1) 91 | 92 | print("making custom matches...") 93 | make_colmap_custom_matcher_args = [ 94 | "python", f"preprocess/make_colmap_custom_matcher.py", 95 | "--image_path", f"{args.images_dir}", 96 | "--output_path", f"{args.project_dir}/camera_calibration/unrectified/matching.txt" 97 | ] 98 | try: 99 | subprocess.run(make_colmap_custom_matcher_args, check=True) 100 | except subprocess.CalledProcessError as e: 101 | print(f"Error executing make_colmap_custom_matcher: {e}") 102 | sys.exit(1) 103 | 104 | ## Feature matching 105 | print("matching features...") 106 | colmap_matches_importer_args = [ 107 | colmap_exe, "matches_importer", 108 | "--database_path", f"{args.project_dir}/camera_calibration/unrectified/database.db", 109 | "--match_list_path", f"{args.project_dir}/camera_calibration/unrectified/matching.txt" 110 | ] 111 | try: 112 | subprocess.run(colmap_matches_importer_args, check=True) 113 | except subprocess.CalledProcessError as e: 114 | print(f"Error executing colmap matches_importer: {e}") 115 | sys.exit(1) 116 | 117 | ## Generate sfm pointcloud 118 | print("generating sfm point cloud...") 119 | colmap_hierarchical_mapper_args = [ 120 | colmap_exe, "hierarchical_mapper", 121 | "--database_path", f"{args.project_dir}/camera_calibration/unrectified/database.db", 122 | "--image_path", f"{args.images_dir}", 123 | "--output_path", f"{args.project_dir}/camera_calibration/unrectified/sparse", 124 | "--Mapper.ba_global_function_tolerance", "0.000001" 125 | ] 126 | try: 127 | subprocess.run(colmap_hierarchical_mapper_args, check=True) 128 | except subprocess.CalledProcessError as e: 129 | print(f"Error executing colmap hierarchical_mapper: {e}") 130 | sys.exit(1) 131 | 132 | ## Simplify images so that everything takes less time (reading colmap usually takes forever) 133 | simplify_images_args = [ 134 | "python", f"preprocess/simplify_images.py", 135 | "--base_dir", f"{args.project_dir}/camera_calibration/unrectified/sparse/0" 136 | ] 137 | try: 138 | subprocess.run(simplify_images_args, check=True) 139 | except subprocess.CalledProcessError as e: 140 | print(f"Error executing simplify_images: {e}") 141 | sys.exit(1) 142 | 143 | ## Undistort images 144 | print(f"undistorting images from {args.images_dir} to {args.project_dir}/camera_calibration/rectified images...") 145 | colmap_image_undistorter_args = [ 146 | colmap_exe, "image_undistorter", 147 | "--image_path", f"{args.images_dir}", 148 | "--input_path", f"{args.project_dir}/camera_calibration/unrectified/sparse/0", 149 | "--output_path", f"{args.project_dir}/camera_calibration/rectified/", 150 | "--output_type", "COLMAP", 151 | "--max_image_size", "2048", 152 | ] 153 | try: 154 | subprocess.run(colmap_image_undistorter_args, check=True) 155 | except subprocess.CalledProcessError as e: 156 | print(f"Error executing image_undistorter: {e}") 157 | sys.exit(1) 158 | 159 | if not args.masks_dir == "": 160 | # create a copy of colmap as txt and replace 
jpgs with pngs to undistort masks the same way images were undistorted 161 | if not os.path.exists(f"{args.project_dir}/camera_calibration/unrectified/sparse/0/masks"): 162 | os.makedirs(f"{args.project_dir}/camera_calibration/unrectified/sparse/0/masks") 163 | 164 | shutil.copy(f"{args.project_dir}/camera_calibration/unrectified/sparse/0/cameras.bin", f"{args.project_dir}/camera_calibration/unrectified/sparse/0/masks/cameras.bin") 165 | shutil.copy(f"{args.project_dir}/camera_calibration/unrectified/sparse/0/points3D.bin", f"{args.project_dir}/camera_calibration/unrectified/sparse/0/masks/points3D.bin") 166 | replace_images_by_masks(f"{args.project_dir}/camera_calibration/unrectified/sparse/0/images.bin", f"{args.project_dir}/camera_calibration/unrectified/sparse/0/masks/images.bin") 167 | 168 | print("undistorting masks as well...") 169 | colmap_image_undistorter_args = [ 170 | colmap_exe, "image_undistorter", 171 | "--image_path", f"{args.masks_dir}", 172 | "--input_path", f"{args.project_dir}/camera_calibration/unrectified/sparse/0/masks", 173 | "--output_path", f"{args.project_dir}/camera_calibration/tmp/", 174 | "--output_type", "COLMAP", 175 | "--max_image_size", "2048", 176 | ] 177 | try: 178 | subprocess.run(colmap_image_undistorter_args, check=True) 179 | except subprocess.CalledProcessError as e: 180 | print(f"Error executing image_undistorter: {e}") 181 | sys.exit(1) 182 | 183 | make_mask_uint8_args = [ 184 | "python", f"preprocess/make_mask_uint8.py", 185 | "--in_dir", f"{args.project_dir}/camera_calibration/tmp/images", 186 | "--out_dir", f"{args.project_dir}/camera_calibration/rectified/masks" 187 | ] 188 | try: 189 | subprocess.run(make_mask_uint8_args, check=True) 190 | except subprocess.CalledProcessError as e: 191 | print(f"Error executing make_mask_uint8: {e}") 192 | sys.exit(1) 193 | 194 | # remove temporary dir containing undistorted masks 195 | shutil.rmtree(f"{args.project_dir}/camera_calibration/tmp") 196 | 197 | # re-orient + scale colmap 198 | print(f"re-orienting and scaling scene to {args.project_dir}/camera_calibration/aligned/sparse/0") 199 | reorient_args = [ 200 | "python", f"preprocess/auto_reorient.py", 201 | "--input_path", f"{args.project_dir}/camera_calibration/rectified/sparse", 202 | "--output_path", f"{args.project_dir}/camera_calibration/aligned/sparse/0" 203 | ] 204 | try: 205 | subprocess.run(reorient_args, check=True) 206 | except subprocess.CalledProcessError as e: 207 | print(f"Error executing auto_reorient: {e}") 208 | sys.exit(1) 209 | 210 | end_time = time.time() 211 | print(f"Preprocessing done in {(end_time - start_time)/60.0} minutes.") 212 | -------------------------------------------------------------------------------- /preprocess/generate_depth.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import subprocess 3 | import argparse 4 | import time 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--project_dir', type=str, required=True) 9 | parser.add_argument('--images_dir', default="", help="Will be set to project_dir/camera_calibration/rectified/images if not set") 10 | parser.add_argument('--chunks_dir', default="", help="Will be set to project_dir/camera_calibration/chunks if not set") 11 | parser.add_argument('--depth_generator', default="Depth-Anything-V2", choices=["DPT", "Depth-Anything-V2"], help="depth generator can be DPT or Depth-Anything-V2, we suggest using Depth-Anything-V2.") 12 | args = parser.parse_args() 13 |
14 | if args.images_dir == "": 15 | args.images_dir = os.path.join(args.project_dir, "camera_calibration/rectified/images") 16 | 17 | if args.chunks_dir == "": 18 | args.chunks_dir = os.path.join(args.project_dir, "camera_calibration/chunks") 19 | 20 | print(f"generating depth maps using {args.depth_generator}.") 21 | start_time = time.time() 22 | 23 | # Generate depth maps 24 | generator_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "submodules", args.depth_generator) 25 | 26 | if args.depth_generator == "DPT": 27 | base_generator_args = [ 28 | "python", f"{generator_dir}/run_monodepth.py", 29 | "-t", "dpt_large" 30 | ] 31 | else: 32 | base_generator_args = [ 33 | "python", f"{generator_dir}/run.py", 34 | "--encoder", "vitl", "--pred-only", "--grayscale" 35 | ] 36 | 37 | images_dir = os.path.join(args.project_dir, "camera_calibration/rectified", "images") 38 | cam_dirs = os.listdir(images_dir) 39 | if all(os.path.isfile(os.path.join(images_dir, cam_dir)) for cam_dir in cam_dirs): 40 | cam_dirs = [""] 41 | for cam_dir in cam_dirs: 42 | full_cam_path = os.path.join(images_dir, cam_dir) 43 | print(f"Estimating depth for {full_cam_path}") 44 | full_depth_path = os.path.join(args.project_dir, "camera_calibration/rectified", "depths", cam_dir) 45 | if not os.path.isabs(full_cam_path): 46 | full_cam_path = os.path.join("../../", full_cam_path) 47 | if not os.path.isabs(full_depth_path): 48 | full_depth_path = os.path.join("../../", full_depth_path) 49 | os.makedirs(full_depth_path, exist_ok=True) 50 | if args.depth_generator == "DPT": 51 | generator_args = base_generator_args + [ 52 | "-i", full_cam_path, 53 | "-o", full_depth_path 54 | ] 55 | else: 56 | generator_args = base_generator_args + [ 57 | "--img-path", full_cam_path, 58 | "--outdir", full_depth_path 59 | ] 60 | try: 61 | subprocess.run(generator_args, check=True, cwd=generator_dir) 62 | except subprocess.CalledProcessError as e: 63 | print(f"Error executing the depth generator: {e}") 64 | sys.exit(1) 65 | 66 | # generate depth_params.json for each chunk 67 | print(f"generating depth_params.json for chunks {os.listdir(args.chunks_dir)}.") 68 | depths_dir = os.path.join(args.project_dir, "camera_calibration/rectified", "depths")  # precomputed to avoid nesting double quotes inside an f-string (a SyntaxError before Python 3.12) 69 | try: 70 | subprocess.run([ 71 | "python", "preprocess/make_chunks_depth_scale.py", "--chunks_dir", args.chunks_dir, "--depths_dir", depths_dir], 72 | check=True 73 | ) 74 | except subprocess.CalledProcessError as e: 75 | print(f"Error executing make_chunks_depth_scale: {e}") 76 | sys.exit(1) 77 | 78 | end_time = time.time() 79 | print(f"Monocular depth estimation done in {(end_time - start_time)/60.0} minutes.") -------------------------------------------------------------------------------- /preprocess/jz_test_gen_chunk.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil 2 | import subprocess 3 | import argparse 4 | import time, platform 5 | 6 | def submit_job(slurm_args): 7 | """Submit a job using sbatch and return the job ID.""" 8 | try: 9 | result = subprocess.run(slurm_args, capture_output=True, text=True, check=True)  # check=True so a failed sbatch is caught below 10 | except subprocess.CalledProcessError as e: 11 | print(f"Error when submitting a job: {e}") 12 | sys.exit(1) 13 | print(f"RESULT {result}") 14 | 15 | # Extract job ID from sbatch output 16 | job_id = result.stdout.strip().split()[-1] 17 | return job_id 18 | 19 | def is_job_finished(job_id): 20 | """Check if the job has finished using sacct.""" 21 | result = subprocess.run(['sacct', '-j', job_id, '--format=State', '--noheader', '--parsable2'],
capture_output=True, text=True) 22 | #test = subprocess.run(['scontrol', 'show', 'jobid', job_id], capture_output=True, text=True) 23 | 24 | # Get job state 25 | job_state = result.stdout.split('\n')[0] 26 | # print(f"res {job_state}") 27 | return job_state if job_state in {'COMPLETED', 'FAILED', 'CANCELLED'} else "" 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--images_dir', required=True) 33 | parser.add_argument('--depths_dir', default="") # can be set if depths were not generated using automatic generate_colmap script 34 | parser.add_argument('--global_colmap_dir', required=True) 35 | parser.add_argument('--chunks_dir', required=True) 36 | parser.add_argument('--use_slurm', action="store_true", default=False) 37 | # parser.add_argument('--colmap_exe', default="colmap.bat") 38 | args = parser.parse_args() 39 | preprocess_dir = os.path.dirname(os.path.realpath(__file__)) 40 | 41 | # if args.use_slurm: 42 | # gpu='-C a100 -A hzb@a100' 43 | # slurm_args = [ 44 | # "sbatch", 45 | # #gpu, 46 | ## "--ntasks=1", "--nodes=1", 47 | ## "--gres=gpu:1", "--cpus-per-task=20", 48 | ## "--time=3:00:00" 49 | # ] 50 | submitted_jobs_ids = [] 51 | os_name = platform.system() 52 | 53 | colmap_exe = "colmap.bat" if os_name == "Windows" else "colmap" 54 | start_time = time.time() 55 | depths_dir = args.depths_dir if args.depths_dir != "" else os.path.join(args.images_dir, "..", "depths") 56 | 57 | print(f"chunking colmap from {args.global_colmap_dir} to {args.chunks_dir}/raw_chunks") 58 | make_chunk_args = [ 59 | "python", f"{preprocess_dir}/make_chunk.py", 60 | "--base_dir", f"{args.global_colmap_dir}", 61 | "--images_dir", f"{args.images_dir}", 62 | "--output_path", f"{args.chunks_dir}/raw_chunks", 63 | ] 64 | try: 65 | subprocess.run(make_chunk_args, check=True) 66 | except subprocess.CalledProcessError as e: 67 | print(f"Error executing image_undistorter: {e}") 68 | sys.exit(1) 69 | 70 | print("TEST WITH ONLY 1 CHUNK") 71 | chunk_name = os.listdir(os.path.join(args.chunks_dir, "raw_chunks"))[0] 72 | in_dir = os.path.join(args.chunks_dir, "raw_chunks", chunk_name) 73 | bundle_adj_dir = os.path.join(args.chunks_dir, "raw_chunks", chunk_name, "bundle_adjustment") 74 | out_dir = os.path.join(args.chunks_dir, "chunks", chunk_name) 75 | 76 | if args.use_slurm: 77 | gpu='-C a100 -A hzb@a100' 78 | slurm_args = [ 79 | "sbatch", f"--error={in_dir}/log.err", 80 | f"--output={in_dir}/log.out" 81 | #gpu, 82 | # "--ntasks=1", "--nodes=1", 83 | # "--gres=gpu:1", "--cpus-per-task=20", 84 | # "--time=3:00:00" 85 | ] 86 | 87 | # Process chunks in parallel 88 | # str_args = " ".join(slurm_args + ["preprocess/prepare_chunk.slurm", in_dir, bundle_adj_dir, out_dir,args.images_dir, depths_dir, preprocess_dir]) 89 | #print(f"STR ARGS {str_args}") 90 | job_id = submit_job(slurm_args + ["preprocess/prepare_chunk.slurm", in_dir, bundle_adj_dir, out_dir,args.images_dir, depths_dir, preprocess_dir]) 91 | # job_id = submit_job(slurm_args, [ 92 | # "preprocess/prepare_chunk.slurm", 93 | # in_dir, bundle_adj_dir, out_dir, 94 | # args.images_dir, depths_dir, preprocess_dir 95 | # ]) 96 | # job_id = submit_job(slurm_args, [ 97 | # f"{preprocess_dir}/prepare_chunk.py", 98 | # "--in_dir", in_dir, "--bundle_adj_dir", bundle_adj_dir,"--out_dir", out_dir, 99 | # "--images_dir", args.images_dir, "--depths_dir", args.depths_dir,"--preprocess_dir", preprocess_dir, "--is_job" 100 | # ]) 101 | submitted_jobs_ids.append(job_id) 102 | 103 | # Check every 10 sec all the jobs status 104 | 
all_finished = False 105 | all_status = [] 106 | time_limit = 180 107 | while not all_finished and time_limit: 108 | # print("Checking status of all jobs...") 109 | all_status = [status for status in (is_job_finished(job_id) for job_id in submitted_jobs_ids) if status != ""] # query each job's state once per poll 110 | 111 | all_finished = len(all_status) == len(submitted_jobs_ids) 112 | if not all_finished: 113 | time.sleep(10) # Wait before checking again 114 | time_limit -= 10 115 | 116 | if not all(status == "COMPLETED" for status in all_status): 117 | print("At least one job failed or was cancelled, check the error logs.") 118 | print(f"STATUS WHEN DONE: {all_status}") 119 | end_time = time.time() 120 | print(f"chunk preparation finished in {(end_time - start_time)/60.0} minutes.") 121 | 122 | -------------------------------------------------------------------------------- /preprocess/make_chunk.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import argparse 14 | import cv2 15 | from joblib import delayed, Parallel 16 | import os 17 | import random 18 | from read_write_model import * 19 | import json 20 | 21 | def get_nb_pts(image_metas): 22 | n_pts = 0 23 | for key in image_metas: 24 | pts_idx = image_metas[key].point3D_ids 25 | if(len(pts_idx) > 5): 26 | n_pts = max(n_pts, np.max(pts_idx)) 27 | 28 | return n_pts + 1 29 | 30 | if __name__ == '__main__': 31 | random.seed(0) 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--base_dir', required=True) 34 | parser.add_argument('--images_dir', required=True) 35 | parser.add_argument('--chunk_size', default=100, type=float) 36 | parser.add_argument('--min_padd', default=0.2, type=float) 37 | parser.add_argument('--lapla_thresh', default=1, type=float, help="Discard images if their laplacians are < mean - lapla_thresh * std") # 1 38 | parser.add_argument('--min_n_cams', default=100, type=int) # 100 39 | parser.add_argument('--max_n_cams', default=1500, type=int) # 1500 40 | parser.add_argument('--output_path', required=True) 41 | parser.add_argument('--add_far_cams', default=True) 42 | parser.add_argument('--model_type', default="bin") 43 | 44 | args = parser.parse_args() 45 | 46 | # eval 47 | test_file = f"{args.base_dir}/test.txt" 48 | if os.path.exists(test_file): 49 | with open(test_file, 'r') as file: 50 | test_cam_names_list = file.readlines() 51 | blending_dict = {name[:-1] if name[-1] == '\n' else name: {} for name in test_cam_names_list} 52 | 53 | cam_intrinsics, images_metas, points3d = read_model(args.base_dir, ext=f".{args.model_type}") 54 | 55 | cam_centers = np.array([ 56 | -qvec2rotmat(images_metas[key].qvec).astype(np.float32).T @ images_metas[key].tvec.astype(np.float32) 57 | for key in images_metas 58 | ]) 59 | 60 | n_pts = get_nb_pts(images_metas) 61 | 62 | xyzs = np.zeros([n_pts, 3], np.float32) 63 | errors = np.zeros([n_pts], np.float32) + 9e9 64 | indices = np.zeros([n_pts], np.int64) 65 | n_images = np.zeros([n_pts], np.int64) 66 | colors = np.zeros([n_pts, 3], np.float32) 67 | 68 | idx = 0 69 | for key in points3d: 70 | xyzs[idx] = points3d[key].xyz 71 | indices[idx] = points3d[key].id 72 | errors[idx] = points3d[key].error 73 | colors[idx] = points3d[key].rgb 74 |
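# These buffers are filled in insertion order; `indices` keeps each point's
# original COLMAP id so that `points3d_ordered`, built right below, can be
# addressed directly with the ids stored in each image's point3D_ids.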
n_images[idx] = len(points3d[key].image_ids) 75 | idx +=1 76 | 77 | mask = errors < 1e1 78 | # mask *= n_images > 3 79 | xyzsC, colorsC, errorsC, indicesC, n_imagesC = xyzs[mask], colors[mask], errors[mask], indices[mask], n_images[mask] 80 | 81 | points3d_ordered = np.zeros([indicesC.max()+1, 3]) 82 | points3d_ordered[indicesC] = xyzsC 83 | images_points3d = {} 84 | 85 | for key in images_metas: 86 | pts_idx = images_metas[key].point3D_ids 87 | mask = pts_idx >= 0 88 | mask *= pts_idx < len(points3d_ordered) 89 | pts_idx = pts_idx[mask] 90 | if len(pts_idx) > 0: 91 | image_points3d = points3d_ordered[pts_idx] 92 | mask = (image_points3d != 0).sum(axis=-1) 93 | # images_metas[key]["points3d"] = image_points3d[mask>0] 94 | images_points3d[key] = image_points3d[mask>0] 95 | else: 96 | # images_metas[key]["points3d"] = np.array([]) 97 | images_points3d[key] = np.array([]) 98 | 99 | 100 | global_bbox = np.stack([cam_centers.min(axis=0), cam_centers.max(axis=0)]) 101 | global_bbox[0, :2] -= args.min_padd * args.chunk_size 102 | global_bbox[1, :2] += args.min_padd * args.chunk_size 103 | extent = global_bbox[1] - global_bbox[0] 104 | padd = np.array([args.chunk_size - extent[0] % args.chunk_size, args.chunk_size - extent[1] % args.chunk_size]) 105 | global_bbox[0, :2] -= padd / 2 106 | global_bbox[1, :2] += padd / 2 107 | 108 | global_bbox[0, 2] = -1e12 109 | global_bbox[1, 2] = 1e12 110 | 111 | def get_var_of_laplacian(key): 112 | image = cv2.imread(os.path.join(args.images_dir, images_metas[key].name)) 113 | if image is not None: 114 | gray = cv2.cvtColor(image[..., :3], cv2.COLOR_BGR2GRAY) 115 | return cv2.Laplacian(gray, cv2.CV_32F).var() 116 | else: 117 | return 0 118 | 119 | if args.lapla_thresh > 0: 120 | laplacians = Parallel(n_jobs=-1, backend="threading")( 121 | delayed(get_var_of_laplacian)(key) for key in images_metas 122 | ) 123 | laplacians_dict = {key: laplacian for key, laplacian in zip(images_metas, laplacians)} 124 | 125 | excluded_chunks = [] 126 | chunks_pcd = {} 127 | 128 | def make_chunk(i, j, n_width, n_height): 129 | # in_path = f"{args.base_dir}/chunk_{i}_{j}" 130 | # if os.path.exists(in_path): 131 | print(f"chunk {i}_{j}") 132 | # corner_min, corner_max = bboxes[i, j, :, 0], bboxes[i, j, :, 1] 133 | corner_min = global_bbox[0] + np.array([i * args.chunk_size, j * args.chunk_size, 0]) 134 | corner_max = global_bbox[0] + np.array([(i + 1) * args.chunk_size, (j + 1) * args.chunk_size, 1e12]) 135 | corner_min[2] = -1e12 136 | corner_max[2] = 1e12 137 | 138 | corner_min_for_pts = corner_min.copy() 139 | corner_max_for_pts = corner_max.copy() 140 | if i == 0: 141 | corner_min_for_pts[0] = -1e12 142 | if j == 0: 143 | corner_min_for_pts[1] = -1e12 144 | if i == n_width - 1: 145 | corner_max_for_pts[0] = 1e12 146 | if j == n_height - 1: 147 | corner_max_for_pts[1] = 1e12 148 | 149 | mask = np.all(xyzsC < corner_max_for_pts, axis=-1) * np.all(xyzsC > corner_min_for_pts, axis=-1) 150 | new_xyzs = xyzsC[mask] 151 | new_colors = colorsC[mask] 152 | new_indices = indicesC[mask] 153 | new_errors = errorsC[mask] 154 | 155 | new_colors = np.clip(new_colors, 0, 255).astype(np.uint8) 156 | 157 | valid_cam = np.all(cam_centers < corner_max, axis=-1) * np.all(cam_centers > corner_min, axis=-1) 158 | 159 | box_center = (corner_max + corner_min) / 2 160 | extent = (corner_max - corner_min) / 2 161 | acceptable_radius = 2 162 | extended_corner_min = box_center - acceptable_radius * extent 163 | extended_corner_max = box_center + acceptable_radius * extent 164 | 165 | for cam_idx, key in 
enumerate(images_metas): 166 | # if not valid_cam[cam_idx]: 167 | image_points3d = images_points3d[key] 168 | n_pts = (np.all(image_points3d < corner_max_for_pts, axis=-1) * np.all(image_points3d > corner_min_for_pts, axis=-1)).sum() if len(image_points3d) > 0 else 0 169 | 170 | # If within chunk 171 | if np.all(cam_centers[cam_idx] < corner_max) and np.all(cam_centers[cam_idx] > corner_min): 172 | valid_cam[cam_idx] = n_pts > 50 173 | # If within 2x of the chunk 174 | elif np.all(cam_centers[cam_idx] < extended_corner_max) and np.all(cam_centers[cam_idx] > extended_corner_min): 175 | valid_cam[cam_idx] = n_pts > 50 and random.uniform(0, 1) > 0.5 176 | # All distances 177 | if (not valid_cam[cam_idx]) and n_pts > 10 and args.add_far_cams: 178 | valid_cam[cam_idx] = random.uniform(0, 0.5) < (float(n_pts) / len(image_points3d)) 179 | 180 | print(f"{valid_cam.sum()} valid cameras after visibility-based selection") 181 | if args.lapla_thresh > 0: 182 | chunk_laplacians = np.array([laplacians_dict[key] for cam_idx, key in enumerate(images_metas) if valid_cam[cam_idx]]) 183 | laplacian_mean = chunk_laplacians.mean() 184 | laplacian_std_dev = chunk_laplacians.std() 185 | for cam_idx, key in enumerate(images_metas): 186 | if valid_cam[cam_idx] and laplacians_dict[key] < (laplacian_mean - args.lapla_thresh * laplacian_std_dev): 187 | # image = cv2.imread(f"{args.base_dir}/images_masked/{images_metas[key]['name']}") 188 | # cv2.imshow("blurry", image) 189 | # cv2.waitKey(0) 190 | valid_cam[cam_idx] = False 191 | 192 | print(f"{valid_cam.sum()} after Laplacian filtering") 193 | 194 | if valid_cam.sum() > args.max_n_cams: 195 | for _ in range(valid_cam.sum() - args.max_n_cams): 196 | remove_idx = random.randint(0, valid_cam.sum() - 1) 197 | remove_idx_glob = np.arange(len(valid_cam))[valid_cam][remove_idx] 198 | valid_cam[remove_idx_glob] = False 199 | 200 | print(f"{valid_cam.sum()} after random removal") 201 | 202 | valid_keys = [key for idx, key in enumerate(images_metas) if valid_cam[idx]] 203 | 204 | if valid_cam.sum() > args.min_n_cams:# or init_valid_cam.sum() > 0: 205 | out_path = os.path.join(args.output_path, f"{i}_{j}") 206 | out_colmap = os.path.join(out_path, "sparse", "0") 207 | os.makedirs(out_colmap, exist_ok=True) 208 | 209 | # must remove SfM points to use the colmap triangulator in the following steps 210 | images_out = {} 211 | for key in valid_keys: 212 | image_meta = images_metas[key] 213 | images_out[key] = Image( 214 | id = key, 215 | qvec = image_meta.qvec, 216 | tvec = image_meta.tvec, 217 | camera_id = image_meta.camera_id, 218 | name = image_meta.name, 219 | xys = [], 220 | point3D_ids = [] 221 | ) 222 | 223 | if os.path.exists(test_file) and image_meta.name in blending_dict: 224 | n_pts = np.isin(image_meta.point3D_ids, new_indices).sum() 225 | blending_dict[image_meta.name][f"{i}_{j}"] = str(n_pts) 226 | 227 | 228 | points_out = { 229 | new_indices[idx] : Point3D( 230 | id=new_indices[idx], 231 | xyz= new_xyzs[idx], 232 | rgb=new_colors[idx], 233 | error=new_errors[idx], 234 | image_ids=np.array([]), 235 | point2D_idxs=np.array([]) 236 | ) 237 | for idx in range(len(new_xyzs)) 238 | } 239 | 240 | write_model(cam_intrinsics, images_out, points_out, out_colmap, f".{args.model_type}") 241 | 242 | with open(os.path.join(out_path, "center.txt"), 'w') as f: 243 | f.write(' '.join(map(str, (corner_min + corner_max) / 2))) 244 | with open(os.path.join(out_path, "extent.txt"), 'w') as f: 245 | f.write(' '.join(map(str, corner_max - corner_min))) 246 | else: 247 | excluded_chunks.append([i, j]) 248 |
print("Chunk excluded") 249 | 250 | extent = global_bbox[1] - global_bbox[0] 251 | n_width = round(extent[0] / args.chunk_size) 252 | n_height = round(extent[1] / args.chunk_size) 253 | 254 | for i in range(n_width): 255 | for j in range(n_height): 256 | make_chunk(i, j, n_width, n_height) 257 | 258 | if os.path.exists(test_file): 259 | with open(f"{args.base_dir}/blending_dict.json", "w") as f: 260 | json.dump(blending_dict, f, indent=2) -------------------------------------------------------------------------------- /preprocess/make_chunks_depth_scale.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import sys, os 13 | import subprocess 14 | import argparse 15 | import time 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--chunks_dir', required=True) 21 | parser.add_argument('--depths_dir', required=True) 22 | args = parser.parse_args() 23 | 24 | chunk_names = os.listdir(args.chunks_dir) 25 | for chunk_name in chunk_names: 26 | 27 | ## Generate depth_params.json file for each chunks as each chunk now has its own colmap 28 | make_depth_scale_args = [ 29 | "python", "preprocess/make_depth_scale.py", 30 | "--base_dir", os.path.join(args.chunks_dir, chunk_name), 31 | "--depths_dir", args.depths_dir, 32 | ] 33 | try: 34 | subprocess.run(make_depth_scale_args, check=True) 35 | except subprocess.CalledProcessError as e: 36 | print(f"Error executing make_depth_scale: {e}") 37 | sys.exit(1) -------------------------------------------------------------------------------- /preprocess/make_colmap_custom_matcher.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import numpy as np 14 | from joblib import delayed, Parallel 15 | import argparse 16 | from exif import Image 17 | from sklearn.neighbors import NearestNeighbors 18 | 19 | #TODO: clean it 20 | def decimal_coords(coords, ref): 21 | decimal_degrees = coords[0] + coords[1] / 60 + coords[2] / 3600 22 | if ref == "S" or ref =='W' : 23 | decimal_degrees = -decimal_degrees 24 | return decimal_degrees 25 | 26 | def image_coordinates(image_name): 27 | with open(os.path.join(args.image_path, image_name), 'rb') as src: 28 | img = Image(src) 29 | if img.has_exif: 30 | try: 31 | img.gps_longitude 32 | coords = [ 33 | decimal_coords(img.gps_latitude, img.gps_latitude_ref), 34 | decimal_coords(img.gps_longitude, img.gps_longitude_ref) 35 | ] 36 | return coords 37 | except AttributeError: 38 | return None 39 | else: 40 | return None 41 | 42 | def get_matches(img_name, cam_center, cam_nbrs, img_names_gps): 43 | _, indices = cam_nbrs.kneighbors(cam_center[None]) 44 | matches = "" 45 | for idx in indices[0, 1:]: 46 | matches += f"{img_name} {img_names_gps[idx]}\n" 47 | return matches 48 | 49 | def find_images_names(root_dir): 50 | image_files_by_subdir = [] 51 | 52 | # Walk through the directory structure 53 | for dirpath, dirnames, filenames in os.walk(root_dir): 54 | 55 | # Filter for image files (you can add more extensions if needed), sort images 56 | image_files = sorted([f for f in filenames if f.lower().endswith(('.png', '.jpg', '.JPG', '.PNG'))]) 57 | 58 | # If there are image files in the current directory, add them to the list 59 | if image_files: 60 | image_files_by_subdir.append({ 61 | 'dir': os.path.basename(dirpath) if dirpath != root_dir else "", 62 | 'images': image_files 63 | }) 64 | 65 | return image_files_by_subdir 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--image_path', required=True) 70 | parser.add_argument('--output_path', required=True) 71 | parser.add_argument('--n_seq_matches_per_view', default=0, type=int) 72 | parser.add_argument('--n_quad_matches_per_view', default=10, type=int) 73 | parser.add_argument('--n_loop_closure_match_per_view', default=5, type=int) 74 | parser.add_argument('--loop_matches', default=[], type=int) 75 | parser.add_argument('--n_gps_neighbours', default=25, type=int) 76 | args = parser.parse_args() 77 | 78 | 79 | loop_matches = np.array(args.loop_matches, dtype=np.int64).reshape(-1, 2) 80 | 81 | loop_rel_matches = np.arange(0, args.n_loop_closure_match_per_view) 82 | loop_rel_matches = 2**loop_rel_matches 83 | loop_rel_matches = np.concatenate([-loop_rel_matches[::-1], np.array([0]), loop_rel_matches]) 84 | 85 | image_files_organised = find_images_names(args.image_path) 86 | 87 | cam_folder_list = [] 88 | cam_folder_list = os.listdir(f"{args.image_path}") 89 | 90 | 91 | matches_str = [] 92 | def add_match(cam_id, matched_cam_id, current_image_file, matched_frame_id): 93 | # REMOVE AFTER 94 | # if (cam_folder_list[cam_id + matched_cam_id] == "backleft") and (not cam_folder_list[cam_id] == "backleft") and (matched_frame_id >= 785): 95 | # matched_frame_id -= 647 96 | # if (not cam_folder_list[cam_id + matched_cam_id] == "backleft") and (cam_folder_list[cam_id] == "backleft") and (matched_frame_id >= 785): 97 | # matched_frame_id += 647 98 | 99 | if matched_frame_id < len(matched_cam['images']): 100 | matched_image_file = matched_cam['images'][matched_frame_id] 101 | 
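# Each appended entry is one "camdir/imageA camdir/imageB" pair; this
# plain-text list is what `colmap matches_importer --match_list_path`
# consumes later in the pipeline.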
matches_str.append(f"{cam_folder_list[cam_id]}/{current_image_file} {cam_folder_list[cam_id + matched_cam_id]}/{matched_image_file}\n") 102 | 103 | 104 | for cam_id, current_cam in enumerate(image_files_organised): 105 | for matched_cam_id, matched_cam in enumerate(image_files_organised[cam_id:]): 106 | for current_image_id, current_image_file in enumerate(current_cam['images']): 107 | for frame_step in range(args.n_seq_matches_per_view): 108 | matched_frame_id = current_image_id + frame_step 109 | add_match(cam_id, matched_cam_id, current_image_file, matched_frame_id) 110 | 111 | for match_id in range(args.n_quad_matches_per_view): 112 | frame_step = args.n_seq_matches_per_view + int(2**match_id) - 1 113 | matched_frame_id = current_image_id + frame_step 114 | add_match(cam_id, matched_cam_id, current_image_file, matched_frame_id) 115 | 116 | ## Loop closure 117 | for loop_match in loop_matches: 118 | for current_loop_rel_match in loop_rel_matches: 119 | current_image_id = (loop_match[0] + current_loop_rel_match) 120 | if current_image_id < len(current_cam['images']): 121 | current_image_file = current_cam['images'][current_image_id] 122 | for matched_loop_rel_match in loop_rel_matches: 123 | matched_frame_id = (loop_match[1] + matched_loop_rel_match) 124 | add_match(cam_id, matched_cam_id, current_image_file, matched_frame_id) 125 | 126 | 127 | ## Add GPS matches 128 | if args.n_gps_neighbours > 0: 129 | all_img_names = [] 130 | for ind, cam in enumerate(image_files_organised): 131 | all_img_names += [os.path.join(cam['dir'], img_name) for img_name in cam['images']] 132 | 133 | all_cam_centers = [image_coordinates(img_name) for img_name in all_img_names] 134 | # all_cam_centers = Parallel(n_jobs=-1, backend="threading")( 135 | # delayed(image_coordinates)(img_name) for img_name in all_img_names 136 | # ) 137 | img_names_gps = [img_name for img_name, cam_center in zip(all_img_names, all_cam_centers) if cam_center is not None] 138 | cam_centers_gps = [cam_center for cam_center in all_cam_centers if cam_center is not None] 139 | cam_centers = np.array(cam_centers_gps) 140 | cam_nbrs = NearestNeighbors(n_neighbors=args.n_gps_neighbours).fit(cam_centers) if cam_centers.size else [] 141 | 142 | matches_str += [get_matches(img_name, cam_center, cam_nbrs, img_names_gps) for img_name, cam_center in zip(img_names_gps, cam_centers)] 143 | 144 | 145 | ## Remove duplicate matches 146 | intermediate_out_matches = list(dict.fromkeys(matches_str)) 147 | reciproc_matches = [f"{match.split(' ')[1][:-1]} {match.split(' ')[0]}\n" for match in intermediate_out_matches] 148 | reciproc_matches_dict = dict.fromkeys(reciproc_matches) 149 | out_matches = [ 150 | match for match in intermediate_out_matches 151 | if not match in reciproc_matches_dict 152 | ] 153 | 154 | # with open(f"{args.image_path}/TEST_new_{args.n_seq_matches_per_view}_{args.n_quad_matches_per_view}_{args.n_loop_closure_match_per_view}_{args.n_gps_neighbours}.txt", "w") as f: 155 | with open(args.output_path, "w") as f: 156 | f.write(''.join(out_matches)) 157 | 158 | print(0) -------------------------------------------------------------------------------- /preprocess/make_colmap_custom_matcher_distance.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
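#
# Usage note: distance-based variant of the custom matcher. Each image is
# paired with its --n_neighbours nearest cameras, with camera centers
# recovered from images.bin. Hypothetical example invocation:
#
#   python preprocess/make_colmap_custom_matcher_distance.py \
#       --base_dir <chunk>/sparse/0 --n_neighbours 200
#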
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import argparse 14 | from sklearn.neighbors import NearestNeighbors 15 | from read_write_model import read_images_binary 16 | 17 | def qvec2rotmat(qvec): 18 | return np.array([ 19 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 20 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 21 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 22 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 23 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 24 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 25 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 26 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 27 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 28 | 29 | def read_images_metas(path): 30 | """ 31 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 32 | """ 33 | images_metas = {} 34 | with open(path, "r") as fid: 35 | while True: 36 | line = fid.readline() 37 | if not line: 38 | break 39 | line = line.strip() 40 | if len(line) > 0 and line[0] != "#": 41 | elems = line.split() 42 | idx = int(elems[0]) 43 | images_metas[idx] = { 44 | "camera_id": int(elems[8]), 45 | "name":elems[9], 46 | "qvec": np.array(tuple(map(float, elems[1:5]))), 47 | "tvec": np.array(tuple(map(float, elems[5:8]))), 48 | } 49 | elems = fid.readline().split() 50 | 51 | return images_metas 52 | 53 | 54 | if __name__ == '__main__': 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--base_dir', required=True) 57 | parser.add_argument('--n_neighbours', default=100, type=int) 58 | args = parser.parse_args() 59 | 60 | images_metas = read_images_binary(f"{args.base_dir}/images.bin") 61 | cam_centers = np.array([ 62 | -qvec2rotmat(images_metas[key].qvec).astype(np.float32).T @ images_metas[key].tvec.astype(np.float32) 63 | for key in images_metas 64 | ]) 65 | n_neighbours = min(args.n_neighbours, len(cam_centers)) 66 | cam_nbrs = NearestNeighbors(n_neighbors=n_neighbours).fit(cam_centers) 67 | 68 | def get_matches(key, cam_center): 69 | _, indices = cam_nbrs.kneighbors(cam_center[None]) 70 | matches = "" 71 | keys = list(images_metas.keys()) 72 | for idx in indices[0, 1:]: 73 | matches += f"{images_metas[key].name} {images_metas[keys[idx]].name}\n" 74 | return matches 75 | 76 | matches = [get_matches(key, cam_center) for key, cam_center in zip(images_metas, cam_centers)] 77 | # matches = Parallel(n_jobs=-1, backend="threading")( 78 | # delayed(get_matches)(key, cam_center) for key, cam_center in zip(images_metas, cam_centers) 79 | # ) 80 | 81 | matches_str = [] 82 | for match in matches: 83 | matches_str.append(match) 84 | 85 | with open(f"{args.base_dir}/matching_{args.n_neighbours}.txt", "w") as f: 86 | f.write(''.join(matches_str)) -------------------------------------------------------------------------------- /preprocess/make_depth_scale.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
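#
# Usage note: fits, per image, a robust (scale, offset) pair mapping
# inverse monocular depth onto COLMAP inverse depth, and writes the result
# to sparse/0/depth_params.json. Hypothetical example invocation:
#
#   python preprocess/make_depth_scale.py \
#       --base_dir <chunk_dir> --depths_dir <depths_dir> --model_type bin
#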
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import argparse 14 | import cv2 15 | from joblib import delayed, Parallel 16 | import json 17 | from read_write_model import * 18 | 19 | def get_scales(key, cameras, images, points3d_ordered, args): 20 | image_meta = images[key] 21 | cam_intrinsic = cameras[image_meta.camera_id] 22 | 23 | pts_idx = images_metas[key].point3D_ids 24 | 25 | mask = pts_idx >= 0 26 | mask *= pts_idx < len(points3d_ordered) 27 | 28 | pts_idx = pts_idx[mask] 29 | valid_xys = image_meta.xys[mask] 30 | 31 | if len(pts_idx) > 0: 32 | pts = points3d_ordered[pts_idx] 33 | else: 34 | pts = np.array([0, 0, 0]) 35 | 36 | R = qvec2rotmat(image_meta.qvec) 37 | pts = np.dot(pts, R.T) + image_meta.tvec 38 | 39 | invcolmapdepth = 1. / pts[..., 2] 40 | n_remove = len(image_meta.name.split('.')[-1]) + 1 41 | invmonodepthmap = cv2.imread(f"{args.depths_dir}/{image_meta.name[:-n_remove]}.png", cv2.IMREAD_UNCHANGED) 42 | 43 | if invmonodepthmap is None: 44 | return None 45 | 46 | if invmonodepthmap.ndim != 2: 47 | invmonodepthmap = invmonodepthmap[..., 0] 48 | 49 | invmonodepthmap = invmonodepthmap.astype(np.float32) / (2**16) 50 | s = invmonodepthmap.shape[0] / cam_intrinsic.height 51 | 52 | maps = (valid_xys * s).astype(np.float32) 53 | valid = ( 54 | (maps[..., 0] >= 0) * 55 | (maps[..., 1] >= 0) * 56 | (maps[..., 0] < cam_intrinsic.width * s) * 57 | (maps[..., 1] < cam_intrinsic.height * s) * (invcolmapdepth > 0)) 58 | 59 | if valid.sum() > 10 and (invcolmapdepth.max() - invcolmapdepth.min()) > 1e-3: 60 | maps = maps[valid, :] 61 | invcolmapdepth = invcolmapdepth[valid] 62 | invmonodepth = cv2.remap(invmonodepthmap, maps[..., 0], maps[..., 1], interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)[..., 0] 63 | 64 | ## Median / dev 65 | t_colmap = np.median(invcolmapdepth) 66 | s_colmap = np.mean(np.abs(invcolmapdepth - t_colmap)) 67 | 68 | t_mono = np.median(invmonodepth) 69 | s_mono = np.mean(np.abs(invmonodepth - t_mono)) 70 | scale = s_colmap / s_mono 71 | offset = t_colmap - t_mono * scale 72 | else: 73 | scale = 0 74 | offset = 0 75 | return {"image_name": image_meta.name[:-n_remove], "scale": scale, "offset": offset} 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--base_dir', required=True) 80 | parser.add_argument('--depths_dir', required=True) 81 | parser.add_argument('--model_type', default="bin") 82 | args = parser.parse_args() 83 | 84 | 85 | cam_intrinsics, images_metas, points3d = read_model(os.path.join(args.base_dir, "sparse", "0"), ext=f".{args.model_type}") 86 | 87 | pts_indices = np.array([points3d[key].id for key in points3d]) 88 | pts_xyzs = np.array([points3d[key].xyz for key in points3d]) 89 | points3d_ordered = np.zeros([pts_indices.max()+1, 3]) 90 | points3d_ordered[pts_indices] = pts_xyzs 91 | 92 | # depth_param_list = [get_scales(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas] 93 | depth_param_list = Parallel(n_jobs=-1, backend="threading")( 94 | delayed(get_scales)(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas 95 | ) 96 | 97 | depth_params = { 98 | depth_param["image_name"]: {"scale": depth_param["scale"], "offset": depth_param["offset"]} 99 | for depth_param in depth_param_list if depth_param != None 100 | } 101 | 102 | with open(f"{args.base_dir}/sparse/0/depth_params.json", "w") as f: 103 | json.dump(depth_params, f, indent=2) 104 | 105 | print(0) 106 | 
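The alignment computed by get_scales above is a robust affine fit between inverse depths: both signals are centered on their median and scaled by their mean absolute deviation, giving scale = s_colmap / s_mono and offset = t_colmap - t_mono * scale. A minimal self-contained sketch of estimating and applying these parameters, with synthetic numbers rather than data from any real scene:

import numpy as np

inv_colmap = np.array([0.5, 0.6, 0.55, 0.7])    # inverse depth of COLMAP points
inv_mono = np.array([0.20, 0.28, 0.24, 0.36])   # inverse depth from the mono network
t_c, t_m = np.median(inv_colmap), np.median(inv_mono)
s_c = np.mean(np.abs(inv_colmap - t_c))  # mean absolute deviation, as in get_scales()
s_m = np.mean(np.abs(inv_mono - t_m))
scale = s_c / s_m            # 1.25 for these numbers
offset = t_c - t_m * scale   # 0.25 for these numbers
aligned = inv_mono * scale + offset
print(aligned)  # [0.5, 0.6, 0.55, 0.7] (exact here because the toy data is affine)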
-------------------------------------------------------------------------------- /preprocess/make_mask_uint8.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import cv2 4 | import numpy as np 5 | from joblib import delayed, Parallel 6 | import argparse 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--in_dir') 11 | parser.add_argument('--out_dir') 12 | args = parser.parse_args() 13 | 14 | in_dir = args.in_dir 15 | masks_dir = args.out_dir 16 | 17 | folders = os.listdir(in_dir) 18 | if "png" in folders[0]: 19 | all_img_names = folders 20 | else: 21 | all_img_names = [] 22 | for folder in folders: 23 | img_names = os.listdir(f"{in_dir}/{folder}") 24 | img_names = [f"{folder}/{img_name}" for img_name in img_names] 25 | all_img_names += img_names 26 | 27 | def split_mask(name): 28 | img = cv2.imread(f"{in_dir}/{name}", cv2.IMREAD_UNCHANGED) 29 | if img is not None: 30 | os.makedirs(os.path.dirname(f"{masks_dir}/{name}"), exist_ok=True) 31 | mask = (img[..., -1] > 250).astype(np.uint8) * 255 32 | cv2.imwrite(f"{masks_dir}/{name}", (cv2.erode(mask, np.ones([3, 3])) > 250).astype(np.uint8) * 255) 33 | 34 | Parallel(n_jobs=-1, backend="threading")( 35 | delayed(split_mask)(name) for name in all_img_names 36 | ) 37 | -------------------------------------------------------------------------------- /preprocess/prepare_chunk.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
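#
# Usage note: refines one raw chunk end to end. It fills a per-chunk COLMAP
# database, undistorts the chunk images, re-extracts and matches features,
# runs two rounds of triangulation + bundle adjustment, and finally realigns
# the result to the raw chunk frame with transform_colmap.py. Hypothetical
# example invocation (placeholder paths):
#
#   python preprocess/prepare_chunk.py \
#       --raw_chunk <chunks>/raw_chunks/0_0 \
#       --out_chunk <chunks>/chunks/0_0 \
#       --images_dir <project>/images
#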
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os, sys, shutil 13 | import subprocess 14 | import argparse 15 | import time, platform 16 | from read_write_model import write_points3D_binary 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--raw_chunk', type=str, help='Input raw chunk', required=True) 21 | parser.add_argument('--out_chunk', type=str, help='Output chunk', required=True) 22 | parser.add_argument('--images_dir', type=str, help='Images directory', required=True) 23 | parser.add_argument('--skip_bundle_adjustment', action="store_true", default=False) 24 | args = parser.parse_args() 25 | 26 | matching_nb = 50 if args.skip_bundle_adjustment else 200 27 | colmap_exe = "colmap.bat" if platform.system() == "Windows" else "colmap" 28 | bundle_adj_chunk = os.path.join(args.raw_chunk, "bundle_adjustment") 29 | 30 | if not os.path.exists(bundle_adj_chunk): 31 | os.makedirs(os.path.join(bundle_adj_chunk, "sparse")) 32 | 33 | # First, create a new colmap database for each chunk, it is filled with the raw chunk colmap 34 | gen_db_attr = [ 35 | "python", "preprocess/fill_database.py", 36 | "--in_dir", os.path.join(args.raw_chunk, "sparse", "0"), 37 | "--database_path", os.path.join(bundle_adj_chunk, "database.db") 38 | ] 39 | try: 40 | subprocess.run(gen_db_attr, check=True) 41 | except subprocess.CalledProcessError as e: 42 | print(f"Error executing gen_database: {e}") 43 | sys.exit(1) 44 | 45 | ## A custom matching file is generated for the chunk, this one is based on distance 46 | make_colmap_custom_matcher_args = [ 47 | "python", "preprocess/make_colmap_custom_matcher_distance.py", 48 | "--base_dir", os.path.join(args.raw_chunk, "sparse", "0"), 49 | "--n_neighbours", f"{matching_nb}" 50 | ] 51 | try: 52 | subprocess.run(make_colmap_custom_matcher_args, check=True) 53 | except subprocess.CalledProcessError as e: 54 | print(f"Error executing custom matcher distance: {e}") 55 | sys.exit(1) 56 | 57 | shutil.copy(os.path.join(args.raw_chunk, "sparse", "0", f"matching_{matching_nb}.txt"), os.path.join(bundle_adj_chunk, f"matching_{matching_nb}.txt")) 58 | 59 | ## Extracting the subset of images corresponding to that chunk 60 | print(f"undistorting to chunk {bundle_adj_chunk}...") 61 | colmap_image_undistorter_args = [ 62 | colmap_exe, "image_undistorter", 63 | "--image_path", f"{args.images_dir}", 64 | "--input_path", f"{args.raw_chunk}/sparse/0", 65 | "--output_path", f"{bundle_adj_chunk}", 66 | "--output_type", "COLMAP" 67 | ] 68 | try: 69 | subprocess.run(colmap_image_undistorter_args, check=True) 70 | except subprocess.CalledProcessError as e: 71 | print(f"Error executing image_undistorter: {e}") 72 | sys.exit(1) 73 | 74 | print("extracting features...") 75 | colmap_feature_extractor_args = [ 76 | colmap_exe, "feature_extractor", 77 | "--database_path", f"{bundle_adj_chunk}/database.db", 78 | "--image_path", f"{bundle_adj_chunk}/images", 79 | "--ImageReader.existing_camera_id", "1", 80 | ] 81 | 82 | try: 83 | subprocess.run(colmap_feature_extractor_args, check=True) 84 | except subprocess.CalledProcessError as e: 85 | print(f"Error executing colmap feature_extractor: {e}") 86 | sys.exit(1) 87 | 88 | print("feature matching...") 89 | colmap_matches_importer_args = [ 90 | colmap_exe, "matches_importer", 91 | "--database_path", f"{bundle_adj_chunk}/database.db", 92 | "--match_list_path", f"{bundle_adj_chunk}/matching_{matching_nb}.txt" 93 | ] 94 | try: 95 | subprocess.run(colmap_matches_importer_args, 
check=True) 96 | except subprocess.CalledProcessError as e: 97 | print(f"Error executing colmap matches_importer: {e}") 98 | sys.exit(1) 99 | 100 | os.makedirs(os.path.join(bundle_adj_chunk, "sparse", "o")) 101 | os.makedirs(os.path.join(bundle_adj_chunk, "sparse", "t")) 102 | os.makedirs(os.path.join(bundle_adj_chunk, "sparse", "b")) 103 | os.makedirs(os.path.join(bundle_adj_chunk, "sparse", "t2")) 104 | os.makedirs(os.path.join(bundle_adj_chunk, "sparse", "0")) 105 | 106 | shutil.copy(os.path.join(args.raw_chunk, "sparse", "0", "images.bin"), os.path.join(bundle_adj_chunk, "sparse", "o", "images.bin")) 107 | shutil.copy(os.path.join(args.raw_chunk, "sparse", "0", "cameras.bin"), os.path.join(bundle_adj_chunk, "sparse", "o", "cameras.bin")) 108 | 109 | # points3D.bin shouldnt be completely empty (must have 1 BYTE) 110 | write_points3D_binary({}, os.path.join(bundle_adj_chunk, "sparse", "o", "points3D.bin")) 111 | 112 | if args.skip_bundle_adjustment: 113 | subprocess.run([colmap_exe, "point_triangulator", 114 | "--Mapper.ba_global_max_num_iterations", "5", 115 | "--Mapper.ba_global_max_refinements", "1", 116 | "--database_path", f"{bundle_adj_chunk}/database.db", 117 | "--image_path", f"{bundle_adj_chunk}/images", 118 | "--input_path", f"{bundle_adj_chunk}/sparse/o", 119 | "--output_path", f"{bundle_adj_chunk}/sparse/0", 120 | ], check=True) 121 | else: 122 | colmap_point_triangulator_args = [ 123 | colmap_exe, "point_triangulator", 124 | "--Mapper.ba_global_function_tolerance", "0.000001", 125 | "--Mapper.ba_global_max_num_iterations", "30", 126 | "--Mapper.ba_global_max_refinements", "3", 127 | ] 128 | 129 | colmap_bundle_adjuster_args = [ 130 | colmap_exe, "bundle_adjuster", 131 | "--BundleAdjustment.refine_extra_params", "0", 132 | "--BundleAdjustment.function_tolerance", "0.000001", 133 | "--BundleAdjustment.max_linear_solver_iterations", "100", 134 | "--BundleAdjustment.max_num_iterations", "50", 135 | "--BundleAdjustment.refine_focal_length", "0" 136 | ] 137 | # 2 rounds of triangulation + bundle adjustment 138 | try: 139 | subprocess.run(colmap_point_triangulator_args + [ 140 | "--database_path", f"{bundle_adj_chunk}/database.db", 141 | "--image_path", f"{bundle_adj_chunk}/images", 142 | "--input_path", f"{bundle_adj_chunk}/sparse/o", 143 | "--output_path", f"{bundle_adj_chunk}/sparse/t", 144 | ], check=True) 145 | except subprocess.CalledProcessError as e: 146 | print(f"Error executing colmap_point_triangulator_args: {e}") 147 | sys.exit(1) 148 | 149 | try: 150 | subprocess.run(colmap_bundle_adjuster_args + [ 151 | "--input_path", f"{bundle_adj_chunk}/sparse/t", 152 | "--output_path", f"{bundle_adj_chunk}/sparse/b", 153 | ], check=True) 154 | except subprocess.CalledProcessError as e: 155 | print(f"Error executing colmap_bundle_adjuster_args: {e}") 156 | sys.exit(1) 157 | 158 | try: 159 | subprocess.run(colmap_point_triangulator_args + [ 160 | "--database_path", f"{bundle_adj_chunk}/database.db", 161 | "--image_path", f"{bundle_adj_chunk}/images", 162 | "--input_path", f"{bundle_adj_chunk}/sparse/b", 163 | "--output_path", f"{bundle_adj_chunk}/sparse/t2", 164 | ], check=True) 165 | except subprocess.CalledProcessError as e: 166 | print(f"Error executing colmap_point_triangulator_args: {e}") 167 | sys.exit(1) 168 | 169 | try: 170 | subprocess.run(colmap_bundle_adjuster_args + [ 171 | "--input_path", f"{bundle_adj_chunk}/sparse/t2", 172 | "--output_path", f"{bundle_adj_chunk}/sparse/0", 173 | ], check=True) 174 | except subprocess.CalledProcessError as e: 175 | print(f"Error 
executing colmap_bundle_adjuster_args: {e}") 176 | sys.exit(1) 177 | 178 | transform_colmap_args = [ 179 | "python", "preprocess/transform_colmap.py", 180 | "--in_dir", args.raw_chunk, 181 | "--new_colmap_dir", bundle_adj_chunk, 182 | "--out_dir", args.out_chunk 183 | ] 184 | 185 | ## Correct slight shifts that might have happened during bundle adjustments 186 | try: 187 | subprocess.run(transform_colmap_args, check=True) 188 | except subprocess.CalledProcessError as e: 189 | print(f"Error executing transform_colmap_args: {e}") 190 | sys.exit(1) 191 | -------------------------------------------------------------------------------- /preprocess/prepare_chunk.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## YOU SHOULD PUT YOUR SLURM PARAMETERS HERE (GPU, ACCOUNT, ETC ...) 4 | 5 | #SBATCH -A hzb@a100 6 | #SBATCH -C a100 7 | #SBATCH --ntasks=1 8 | #SBATCH --nodes=1 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --cpus-per-task=20 11 | #SBATCH --time=3:00:00 12 | 13 | RAW_CHUNK=$1 14 | OUT_CHUNK=$2 15 | IMAGES_DIR=$3 16 | PREPROCESS_DIR=$5 17 | N_NEIGHB=200 18 | 19 | 20 | BUNDLE_ADJ_CHUNK=${RAW_CHUNK}/bundle_adjustment 21 | echo "JOB FOR CHUNK" ${RAW_CHUNK} 22 | source $WORK/miniconda3/etc/profile.d/conda.sh 23 | 24 | module load cpuarch/amd 25 | module load colmap 26 | conda activate 3dgs_single 27 | 28 | ## Generate the chunk's colmap in an intermediate folder 29 | mkdir ${BUNDLE_ADJ_CHUNK} 30 | mkdir ${BUNDLE_ADJ_CHUNK}/sparse 31 | 32 | python ${PREPROCESS_DIR}/fill_database.py --in_dir ${RAW_CHUNK}/sparse/0 --database_path ${BUNDLE_ADJ_CHUNK}/database.db 33 | 34 | python ${PREPROCESS_DIR}/make_colmap_custom_matcher_distance.py --base_dir ${RAW_CHUNK}/sparse/0 --n_neighbours ${N_NEIGHB} 35 | cp ${RAW_CHUNK}/sparse/0/matching_${N_NEIGHB}.txt ${BUNDLE_ADJ_CHUNK}/matching_${N_NEIGHB}.txt 36 | 37 | colmap image_undistorter --image_path ${IMAGES_DIR} --input_path ${RAW_CHUNK}/sparse/0 --output_path ${BUNDLE_ADJ_CHUNK} --output_type COLMAP 38 | colmap feature_extractor --database_path ${BUNDLE_ADJ_CHUNK}/database.db --image_path ${BUNDLE_ADJ_CHUNK}/images --ImageReader.existing_camera_id 1 39 | colmap matches_importer --database_path ${BUNDLE_ADJ_CHUNK}/database.db --match_list_path ${BUNDLE_ADJ_CHUNK}/matching_${N_NEIGHB}.txt 40 | # colmap exhaustive_matcher --database_path ${BUNDLE_ADJ_CHUNK}/database.db 41 | 42 | mkdir ${BUNDLE_ADJ_CHUNK}/sparse/o ${BUNDLE_ADJ_CHUNK}/sparse/t ${BUNDLE_ADJ_CHUNK}/sparse/b ${BUNDLE_ADJ_CHUNK}/sparse/t2 ${BUNDLE_ADJ_CHUNK}/sparse/0 43 | cp ${RAW_CHUNK}/sparse/0/images.bin ${RAW_CHUNK}/sparse/0/cameras.bin ${BUNDLE_ADJ_CHUNK}/sparse/o/ 44 | 45 | touch ${BUNDLE_ADJ_CHUNK}/sparse/o/points3D.bin 46 | 47 | ## 2 Rounds of triangulation + bundle adjustment 48 | colmap point_triangulator --Mapper.ba_global_function_tolerance 0.000001 --Mapper.ba_global_max_num_iterations 30 --Mapper.ba_global_max_refinements 3 --database_path ${BUNDLE_ADJ_CHUNK}/database.db --image_path ${BUNDLE_ADJ_CHUNK}/images --input_path ${BUNDLE_ADJ_CHUNK}/sparse/o --output_path ${BUNDLE_ADJ_CHUNK}/sparse/t 49 | colmap bundle_adjuster --BundleAdjustment.refine_extra_params 0 --BundleAdjustment.function_tolerance 0.000001 --BundleAdjustment.max_linear_solver_iterations 100 --BundleAdjustment.max_num_iterations 50 --BundleAdjustment.refine_focal_length 0 --input_path ${BUNDLE_ADJ_CHUNK}/sparse/t --output_path ${BUNDLE_ADJ_CHUNK}/sparse/b 50 | 51 | colmap point_triangulator --Mapper.ba_global_function_tolerance 0.000001 
--Mapper.ba_global_max_num_iterations 30 --Mapper.ba_global_max_refinements 3 --database_path ${BUNDLE_ADJ_CHUNK}/database.db --image_path ${BUNDLE_ADJ_CHUNK}/images --input_path ${BUNDLE_ADJ_CHUNK}/sparse/b --output_path ${BUNDLE_ADJ_CHUNK}/sparse/t2 52 | colmap bundle_adjuster --BundleAdjustment.refine_extra_params 0 --BundleAdjustment.function_tolerance 0.000001 --BundleAdjustment.max_linear_solver_iterations 100 --BundleAdjustment.max_num_iterations 50 --BundleAdjustment.refine_focal_length 0 --input_path ${BUNDLE_ADJ_CHUNK}/sparse/t2 --output_path ${BUNDLE_ADJ_CHUNK}/sparse/0 53 | 54 | ## Correct shifts that might have happened when bundle adjusting 55 | python ${PREPROCESS_DIR}/transform_colmap.py --in_dir ${RAW_CHUNK} --new_colmap_dir ${BUNDLE_ADJ_CHUNK} --out_dir ${OUT_CHUNK} 56 | 57 | echo ${OUT_CHUNK} " DONE." 58 | -------------------------------------------------------------------------------- /preprocess/reorient.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import torch 14 | import argparse 15 | import ast 16 | import os, time 17 | from read_write_model import * 18 | 19 | def rotate_camera(qvec, tvec, rot_matrix, upscale): 20 | # Assuming cameras have 'T' (translation) field 21 | 22 | R = qvec2rotmat(qvec) 23 | T = np.array(tvec) 24 | 25 | Rt = np.zeros((4, 4)) 26 | Rt[:3, :3] = R 27 | Rt[:3, 3] = T 28 | Rt[3, 3] = 1.0 29 | 30 | C2W = np.linalg.inv(Rt) 31 | cam_center = np.copy(C2W[:3, 3]) 32 | cam_rot_orig = np.copy(C2W[:3, :3]) 33 | cam_center = np.matmul(cam_center, rot_matrix) 34 | cam_rot = np.linalg.inv(rot_matrix) @ cam_rot_orig 35 | C2W[:3, 3] = upscale * cam_center 36 | C2W[:3, :3] = cam_rot 37 | Rt = np.linalg.inv(C2W) 38 | new_pos = Rt[:3, 3] 39 | new_rot = rotmat2qvec(Rt[:3, :3]) 40 | 41 | # R_test = qvec2rotmat(new_rots[-1]) 42 | # T_test = np.array(new_poss[-1]) 43 | # Rttest = np.zeros((4, 4)) 44 | # Rttest[:3, :3] = R_test 45 | # Rttest[:3, 3] = T_test 46 | # Rttest[3, 3] = 1.0 47 | # C2Wtest = np.linalg.inv(Rttest) 48 | 49 | return new_pos, new_rot 50 | 51 | # Function to compute the cross product of two 3D vectors 52 | def cross_product(v1, v2): 53 | return torch.cross(v1, v2) 54 | 55 | # Function to normalize a 3D vector 56 | def normalize_vector(v): 57 | norm = torch.norm(v, p=2) # Calculate the Euclidean norm (L2 norm) 58 | return v / norm 59 | 60 | def parse_vector(s): 61 | try: 62 | result = ast.literal_eval(s) 63 | if isinstance(result, tuple) and len(result) == 3: 64 | return result 65 | else: 66 | raise ValueError("Invalid vector format. Must be a 3-element tuple.") 67 | except (ValueError, SyntaxError): 68 | raise argparse.ArgumentTypeError("Invalid vector format. 
Example: (1.0, 2.0, 3.0)") 69 | 70 | def main(): 71 | 72 | parser = argparse.ArgumentParser(description='Example script with command-line arguments.') 73 | 74 | # Add command-line argument(s) 75 | parser.add_argument('--input_path', type=str, help='Path to input colmap dir', required=True) 76 | parser.add_argument('--output_path', type=str, help='Path to output colmap dir', required=True) 77 | parser.add_argument('--upscale', type=float, help='Upscaling factor', required=True) 78 | 79 | # Add command-line arguments for two 3D vectors 80 | parser.add_argument('--up', type=parse_vector, help='Up 3D vector in the format (x, y, z)', required=True) 81 | parser.add_argument('--right', type=parse_vector, help='Right 3D vector in the format (x, y, z)', required = True) 82 | parser.add_argument('--input_format', type=str, help='specify which file format to use when processing colmap files (txt or bin)', choices=['bin','txt'], default='bin') 83 | 84 | # Parse the command-line arguments 85 | args = parser.parse_args() 86 | 87 | global_start = time.time() 88 | 89 | # Access the parsed vectors 90 | vector1 = args.up 91 | vector2 = args.right 92 | 93 | # Your main logic goes here 94 | print("Up Vector:", vector1) 95 | print("Right Vector:", vector2) 96 | 97 | ext = args.input_format 98 | 99 | #print(float(vector1[0]),float(vector1[1]),float(vector1[2])) 100 | 101 | up = torch.Tensor(vector1).double() 102 | right = torch.Tensor(vector2).double() 103 | 104 | # Parse the command-line arguments 105 | args = parser.parse_args() 106 | 107 | # Access the parsed arguments 108 | os.makedirs(args.output_path, exist_ok=True) 109 | 110 | # Your main logic goes here 111 | print("Input path:", args.input_path) 112 | print("Output path:", args.output_path) 113 | print(f"processing format: .{ext}") 114 | 115 | #up = torch.Tensor([0, -1, -0.2]).double() 116 | #right = torch.Tensor([-0.5, 0, -0.5]).double() 117 | 118 | up = normalize_vector(up) 119 | right = normalize_vector(right) 120 | 121 | forward = cross_product(up, right) 122 | forward = normalize_vector(forward) 123 | right = cross_product(forward, up) 124 | right = normalize_vector(right) 125 | 126 | # Stack the target axes as columns to form the rotation matrix 127 | rotation_matrix = torch.stack([right, forward, up], dim=1) 128 | 129 | # Read colmap cameras, images and points 130 | start_time = time.time() 131 | cameras, images_metas_in, points3d_in = read_model(args.input_path, ext=f".{ext}") 132 | end_time = time.time() 133 | print(f"{len(images_metas_in)} images read in {end_time - start_time} seconds.") 134 | 135 | positions = [] 136 | print("Doing points") 137 | for key in points3d_in: 138 | positions.append(points3d_in[key].xyz) 139 | 140 | positions = torch.from_numpy(np.array(positions)) 141 | 142 | # Perform the rotation by matrix multiplication 143 | rotated_points = args.upscale * torch.matmul(positions, rotation_matrix) 144 | 145 | points3d_out = {} 146 | for key, rotated in zip(points3d_in, rotated_points): 147 | point3d_in = points3d_in[key] 148 | points3d_out[key] = Point3D( 149 | id=point3d_in.id, 150 | xyz=rotated, 151 | rgb=point3d_in.rgb, 152 | error=point3d_in.error, 153 | image_ids=point3d_in.image_ids, 154 | point2D_idxs=point3d_in.point2D_idxs, 155 | ) 156 | 157 | print("Doing images") 158 | images_metas_out = {} 159 | for key in images_metas_in: 160 | image_meta_in = images_metas_in[key] 161 | new_pos, new_rot = rotate_camera(image_meta_in.qvec, image_meta_in.tvec, rotation_matrix.double().numpy(), args.upscale) 162 | 163 | 
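# rotate_camera() above works in camera-to-world space: it rotates and
# rescales the camera center, applies the inverse rotation to the camera
# orientation, then converts back to COLMAP's world-to-camera (qvec, tvec)
# convention before the image metadata is rewritten below.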
images_metas_out[key] = Image( 164 | id=image_meta_in.id, 165 | qvec=new_rot, 166 | tvec=new_pos, 167 | camera_id=image_meta_in.camera_id, 168 | name=image_meta_in.name, 169 | xys=image_meta_in.xys, 170 | point3D_ids=image_meta_in.point3D_ids, 171 | ) 172 | 173 | write_model(cameras, images_metas_out, points3d_out, args.output_path, f".{ext}") 174 | 175 | global_end = time.time() 176 | 177 | print(f"reorient script took {global_end - global_start} seconds ({ext} file processed).") 178 | 179 | if __name__ == "__main__": 180 | main() -------------------------------------------------------------------------------- /preprocess/simplify_images.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os, argparse 13 | import numpy as np 14 | from read_write_model import Image, read_images_binary, read_images_text, write_images_binary, write_images_text, qvec2rotmat 15 | from sklearn.neighbors import NearestNeighbors 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--base_dir', help="path to colmap images file (.txt or .bin)", required="true") 20 | parser.add_argument('--mult_min_dist', type=float, help="points at distance > [mult_min_dist]*median_dist_neighbors are removed. (default 10)", default=10) 21 | parser.add_argument('--model_type', default="bin") 22 | 23 | print("cleaning colmap images.bin file: removing useless data.") 24 | 25 | args = parser.parse_args() 26 | 27 | ext = args.model_type 28 | images_file = os.path.join(args.base_dir, f"images.{ext}") 29 | 30 | images_metas = {} 31 | if ext == "txt": 32 | images_metas = read_images_text(images_file) 33 | elif ext == "bin": 34 | images_metas = read_images_binary(images_file) 35 | 36 | 37 | cam_centers = np.array([ 38 | -qvec2rotmat(images_metas[key].qvec).T @ images_metas[key].tvec 39 | for key in images_metas 40 | ]) 41 | cam_nbrs = NearestNeighbors(n_neighbors=2).fit(cam_centers) 42 | centers_std = cam_centers.std(axis=0).mean() 43 | 44 | second_min_distances = [] 45 | for key, cam_center in zip(images_metas, cam_centers): 46 | distances, indices = cam_nbrs.kneighbors(cam_center[None]) 47 | second_min_distances.append(distances[0, -1]) 48 | 49 | 50 | med_dist = np.median(second_min_distances) 51 | filtered_images = {} 52 | 53 | for key, second_min_distance in zip(images_metas, second_min_distances): 54 | image_meta = images_metas[key] 55 | 56 | if len(image_meta.point3D_ids) > 0 and second_min_distance <= args.mult_min_dist * med_dist: 57 | valid_pts_mask = image_meta.point3D_ids >= 0 58 | if valid_pts_mask.sum() > 0: 59 | filtered_images[key] = Image( 60 | id=image_meta.id, 61 | qvec=image_meta.qvec, 62 | tvec=image_meta.tvec, 63 | camera_id=image_meta.camera_id, 64 | name=image_meta.name, 65 | xys=image_meta.xys[valid_pts_mask], 66 | point3D_ids=image_meta.point3D_ids[valid_pts_mask], 67 | ) 68 | 69 | 70 | # filtered_images[key].valid_point3D_ids = filtered_images[key].point3D_ids[filtered_images[key].point3D_ids >= 0] 71 | # filtered_images[key].valid_xys = filtered_images[key].xys[filtered_images[key].point3D_ids >= 0] 72 | 73 | # if(len(filtered_images[key].valid_point3D_ids) != 
len(filtered_images[key].point3D_ids)): 74 | # print("reducing size ...") 75 | 76 | print(f"{len(images_metas)} images before; {len(filtered_images)} images after") 77 | 78 | # rename old images.bin/txt as images_heavy 79 | if os.path.exists(f"images_heavy.{ext}"): 80 | os.remove(f"images_heavy.{ext}") 81 | os.rename(images_file, os.path.join(args.base_dir, f"images_heavy.{ext}")) 82 | 83 | if ext == "txt": 84 | write_images_text(filtered_images, images_file) 85 | elif ext == "bin": 86 | write_images_binary(filtered_images, images_file) 87 | 88 | print(0) -------------------------------------------------------------------------------- /preprocess/transform_colmap.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import argparse 14 | import os 15 | import shutil 16 | import random 17 | import torch 18 | from read_write_model import * 19 | 20 | Transform = collections.namedtuple( 21 | "Transform", ["t0", "t1", "s0", "s1", "R"] 22 | ) 23 | 24 | def procrustes_analysis(X0,X1): # [N,3] 25 | """ 26 | From https://github.com/chenhsuanlin/bundle-adjusting-NeRF/blob/803291bd0ee91c7c13fb5cc42195383c5ade7d15/camera.py#L278 27 | """ 28 | # translation 29 | t0 = X0.mean(dim=0,keepdim=True) 30 | t1 = X1.mean(dim=0,keepdim=True) 31 | X0c = X0-t0 32 | X1c = X1-t1 33 | # scale 34 | s0 = (X0c**2).sum(dim=-1).mean().sqrt() 35 | s1 = (X1c**2).sum(dim=-1).mean().sqrt() 36 | X0cs = X0c/s0 37 | X1cs = X1c/s1 38 | # rotation (use double for SVD, float loses precision) 39 | U,S,V = (X0cs.t()@X1cs).double().svd(some=True) 40 | R = (U@V.t()).float() 41 | if R.det()<0: R[2] *= -1 42 | # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0 43 | sim3 = Transform(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R) 44 | return sim3 45 | 46 | def get_nb_pts(image_metas): 47 | n_pts = 0 48 | for key in image_metas: 49 | pts_idx = image_metas[key].point3D_ids 50 | if(len(pts_idx) > 5): 51 | n_pts = max(n_pts, np.max(pts_idx)) 52 | 53 | return n_pts + 1 54 | 55 | if __name__ == '__main__': 56 | random.seed(0) 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--in_dir', required=True) 59 | parser.add_argument('--new_colmap_dir', required=True) 60 | parser.add_argument('--out_dir', required=True) 61 | args = parser.parse_args() 62 | 63 | old_images_metas = read_images_binary(f"{args.in_dir}/sparse/0/images.bin") 64 | new_images_metas = read_images_binary(f"{args.new_colmap_dir}/sparse/0/images.bin") 65 | n_pts = get_nb_pts(new_images_metas) 66 | 67 | old_keys = old_images_metas.keys() 68 | old_keys_dict = {old_images_metas[key].name: key for key in old_keys} 69 | new_old_key_mapping = {key: old_keys_dict[new_images_metas[key].name] for key in new_images_metas} 70 | 71 | old_cam_centers = np.array([ 72 | -qvec2rotmat(old_images_metas[new_old_key_mapping[key]].qvec).astype(np.float32).T @ old_images_metas[new_old_key_mapping[key]].tvec.astype(np.float32) 73 | for key in new_images_metas 74 | ]) 75 | new_cam_centers = np.array([ 76 | -qvec2rotmat(new_images_metas[key].qvec).astype(np.float32).T @ new_images_metas[key].tvec.astype(np.float32) 77 | for key in new_images_metas 78 | ]) 79 | 80 | dists = np.linalg.norm(old_cam_centers - new_cam_centers, 
axis=-1) 81 | valid_cams = dists <= (np.median(dists) * 5) + 1e-8 82 | 83 | old_cam_centers_torch = torch.from_numpy(old_cam_centers) 84 | new_cam_centers_torch = torch.from_numpy(new_cam_centers) 85 | 86 | old_cam_centers_trimmed = old_cam_centers[valid_cams] 87 | new_cam_centers_trimmed = new_cam_centers[valid_cams] 88 | old_cam_centers_torch_trimmed = torch.from_numpy(old_cam_centers_trimmed) 89 | new_cam_centers_torch_trimmed = torch.from_numpy(new_cam_centers_trimmed) 90 | 91 | sim3 = procrustes_analysis(old_cam_centers_torch_trimmed, new_cam_centers_torch_trimmed) 92 | center_aligned = (new_cam_centers_torch-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0 93 | points3d = read_points3D_binary(f"{args.new_colmap_dir}/sparse/0/points3D.bin") 94 | 95 | xyzs = np.zeros([n_pts, 3], np.float32) 96 | errors = np.zeros([n_pts], np.float32) + 9e9 97 | indices = np.zeros([n_pts], np.int64) 98 | n_images = np.zeros([n_pts], np.int64) 99 | colors = np.zeros([n_pts, 3], np.float32) 100 | 101 | idx = 0 102 | for key in points3d: 103 | xyzs[idx] = points3d[key].xyz 104 | indices[idx] = points3d[key].id 105 | errors[idx] = points3d[key].error 106 | colors[idx] = points3d[key].rgb 107 | n_images[idx] = len(points3d[key].image_ids) 108 | idx +=1 109 | 110 | mask = errors < 1.5 111 | mask *= n_images > 3 112 | 113 | xyzsC, colorsC, errorsC, indicesC, n_imagesC = xyzs[mask], colors[mask], errors[mask], indices[mask], n_images[mask] 114 | 115 | points3dC_aligned = ((torch.from_numpy(xyzsC)-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0).numpy() 116 | 117 | R = torch.from_numpy(np.array([ 118 | qvec2rotmat(new_images_metas[key].qvec).astype(np.float32) 119 | for key in new_images_metas 120 | ])) 121 | R_aligned = R@sim3.R.t() 122 | t_aligned = (-R_aligned@center_aligned[...,None])[...,0] 123 | 124 | with open(f"{args.in_dir}/center.txt", 'r') as f: 125 | center = np.array(tuple(map(float, f.readline().strip().split()))) 126 | with open(f"{args.in_dir}/extent.txt", 'r') as f: 127 | extent = np.array(tuple(map(float, f.readline().strip().split()))) 128 | 129 | corner_min = center - 1.1 * extent / 2 130 | corner_max = center + 1.1 * extent / 2 131 | 132 | out_colmap = f"{args.out_dir}/sparse/0" 133 | os.makedirs(out_colmap, exist_ok=True) 134 | 135 | mask = np.all(points3dC_aligned < corner_max, axis=-1) * np.all(points3dC_aligned > corner_min, axis=-1) 136 | new_points3d = points3dC_aligned#[mask] 137 | new_colors = np.clip(colorsC, 0, 255).astype(np.uint8) 138 | new_indices = indicesC#[mask] 139 | new_errors = errorsC#[mask] 140 | 141 | images_metas_out = {} 142 | for key, R, t, valid_cam in zip(new_images_metas, R_aligned.numpy(), t_aligned.numpy(), valid_cams): 143 | if valid_cam: 144 | image_meta = new_images_metas[key] 145 | 146 | images_metas_out[key] = Image( 147 | id = key, 148 | qvec = rotmat2qvec(R), 149 | tvec = t, 150 | camera_id = image_meta.camera_id, 151 | name = image_meta.name, 152 | xys = image_meta.xys, 153 | point3D_ids = image_meta.point3D_ids, 154 | ) 155 | 156 | write_images_binary(images_metas_out, f"{out_colmap}/images.bin") 157 | 158 | points_out = { 159 | new_indices[idx] : Point3D( 160 | id=indicesC[idx], 161 | xyz= points3dC_aligned[idx], 162 | rgb=new_colors[idx], 163 | error=errorsC[idx], 164 | image_ids=np.array([]), 165 | point2D_idxs=np.array([]) 166 | ) 167 | for idx in range(len(points3dC_aligned)) 168 | } 169 | 170 | write_points3D_binary(points_out, f"{out_colmap}/points3D.bin") 171 | 172 | shutil.copy(f"{args.new_colmap_dir}/sparse/0/cameras.bin", f"{out_colmap}/cameras.bin") 
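# Summary of the alignment applied above: procrustes_analysis() returns a
# similarity transform (t0, t1, s0, s1, R) such that points expressed in
# the new bundle-adjusted frame map back to the raw chunk frame via
#   x_aligned = (x - t1) / s1 @ R.T * s0 + t0
# Camera orientations follow with R_aligned = R_cam @ R.T, and the new
# world-to-camera translations are rebuilt as t = -R_aligned @ c_aligned.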
173 | shutil.copy(f"{args.in_dir}/center.txt", f"{args.out_dir}/center.txt") 174 | shutil.copy(f"{args.in_dir}/extent.txt", f"{args.out_dir}/extent.txt") 175 | 176 | print(0) -------------------------------------------------------------------------------- /render_hierarchy.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | import os 14 | import torch 15 | from random import randint 16 | from utils.loss_utils import ssim 17 | from gaussian_renderer import render_post 18 | import sys 19 | from scene import Scene, GaussianModel 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | from arguments import ModelParams, PipelineParams, OptimizationParams 24 | import torchvision 25 | from lpipsPyTorch import lpips 26 | 27 | from gaussian_hierarchy._C import expand_to_size, get_interpolation_weights 28 | 29 | def direct_collate(x): 30 | return x 31 | 32 | @torch.no_grad() 33 | def render_set(args, scene, pipe, out_dir, tau, eval): 34 | render_path = out_dir 35 | 36 | render_indices = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() 37 | parent_indices = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() 38 | nodes_for_render_indices = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() 39 | interpolation_weights = torch.zeros(scene.gaussians._xyz.size(0)).float().cuda() 40 | num_siblings = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() 41 | 42 | psnr_test = 0.0 43 | ssims = 0.0 44 | lpipss = 0.0 45 | 46 | cameras = scene.getTestCameras() if eval else scene.getTrainCameras() 47 | 48 | for viewpoint in tqdm(cameras): 49 | viewpoint=viewpoint 50 | viewpoint.world_view_transform = viewpoint.world_view_transform.cuda() 51 | viewpoint.projection_matrix = viewpoint.projection_matrix.cuda() 52 | viewpoint.full_proj_transform = viewpoint.full_proj_transform.cuda() 53 | viewpoint.camera_center = viewpoint.camera_center.cuda() 54 | 55 | tanfovx = math.tan(viewpoint.FoVx * 0.5) 56 | threshold = (2 * (tau + 0.5)) * tanfovx / (0.5 * viewpoint.image_width) 57 | 58 | to_render = expand_to_size( 59 | scene.gaussians.nodes, 60 | scene.gaussians.boxes, 61 | threshold, 62 | viewpoint.camera_center, 63 | torch.zeros((3)), 64 | render_indices, 65 | parent_indices, 66 | nodes_for_render_indices) 67 | 68 | indices = render_indices[:to_render].int().contiguous() 69 | node_indices = nodes_for_render_indices[:to_render].contiguous() 70 | 71 | get_interpolation_weights( 72 | node_indices, 73 | threshold, 74 | scene.gaussians.nodes, 75 | scene.gaussians.boxes, 76 | viewpoint.camera_center.cpu(), 77 | torch.zeros((3)), 78 | interpolation_weights, 79 | num_siblings 80 | ) 81 | 82 | image = torch.clamp(render_post( 83 | viewpoint, 84 | scene.gaussians, 85 | pipe, 86 | torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda"), 87 | render_indices=indices, 88 | parent_indices = parent_indices, 89 | interpolation_weights = interpolation_weights, 90 | num_node_kids = num_siblings, 91 | use_trained_exp=args.train_test_exp 92 | )["render"], 0.0, 1.0) 93 | 94 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 95 | 96 | alpha_mask = 
viewpoint.alpha_mask.cuda()
97 | 
98 |         if args.train_test_exp:
99 |             image = image[..., image.shape[-1] // 2:]
100 |             gt_image = gt_image[..., gt_image.shape[-1] // 2:]
101 |             alpha_mask = alpha_mask[..., alpha_mask.shape[-1] // 2:]
102 | 
103 |         # Build the output path once and create its directory on demand
104 |         # (image names may contain subfolders).
105 |         save_path = os.path.join(render_path, viewpoint.image_name.split(".")[0] + ".png")
106 |         os.makedirs(os.path.dirname(save_path), exist_ok=True)
107 |         torchvision.utils.save_image(image, save_path)
108 |         if eval:
109 |             image *= alpha_mask
110 |             gt_image *= alpha_mask
111 |             psnr_test += psnr(image, gt_image).mean().double()
112 |             ssims += ssim(image, gt_image).mean().double()
113 |             lpipss += lpips(image, gt_image, net_type='vgg').mean().double()
114 | 
115 |         torch.cuda.empty_cache()
116 |     if eval and len(scene.getTestCameras()) > 0:
117 |         psnr_test /= len(scene.getTestCameras())
118 |         ssims /= len(scene.getTestCameras())
119 |         lpipss /= len(scene.getTestCameras())
120 |         print(f"tau: {tau}, PSNR: {psnr_test:.5f} SSIM: {ssims:.5f} LPIPS: {lpipss:.5f}")
121 | 
122 | if __name__ == "__main__":
123 |     # Set up command line argument parser
124 |     parser = ArgumentParser(description="Rendering script parameters")
125 |     lp = ModelParams(parser)
126 |     op = OptimizationParams(parser)
127 |     pp = PipelineParams(parser)
128 |     parser.add_argument('--out_dir', type=str, default="")
129 |     parser.add_argument("--taus", nargs="+", type=float, default=[0.0, 3.0, 6.0, 15.0])
130 |     args = parser.parse_args(sys.argv[1:])
131 | 
132 |     print("Rendering " + args.model_path)
133 | 
134 |     dataset, pipe = lp.extract(args), pp.extract(args)
135 |     gaussians = GaussianModel(dataset.sh_degree)
136 |     gaussians.active_sh_degree = dataset.sh_degree
137 |     scene = Scene(dataset, gaussians, resolution_scales = [1], create_from_hier=True)
138 | 
139 |     for tau in args.taus:
140 |         render_set(args, scene, pipe, os.path.join(args.out_dir, f"render_{tau}"), tau, args.eval)
141 | 
142 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | plyfile
2 | tqdm
3 | joblib
4 | exif
5 | scikit-learn
6 | timm==0.4.5
7 | opencv-python==4.9.0.80
8 | gradio_imageslider
9 | gradio==4.29.0
10 | matplotlib
11 | submodules/hierarchy-rasterizer
12 | submodules/simple-knn
13 | submodules/gaussianhierarchy
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact  george.drettakis@inria.fr
10 | #
11 | 
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import camera_to_JSON, CameraDataset
20 | from utils.system_utils import mkdir_p
21 | 
22 | class Scene:
23 | 
24 |     gaussians : GaussianModel
25 | 
26 |     def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], create_from_hier=False):
27 |         """
28 |         :param args: model parameters; args.source_path is the COLMAP scene main folder.
29 |         """
30 |         self.model_path = args.model_path
31 |         self.loaded_iter = None
32 |         self.gaussians = gaussians
33 | 
34 |         if load_iteration:
35 |             if load_iteration == -1:
36 |                 self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
37 |             else:
38 |                 self.loaded_iter = load_iteration
39 |             print("Loading trained model at iteration {}".format(self.loaded_iter))
40 | 
41 |         self.train_cameras = {}
42 |         self.test_cameras = {}
43 | 
44 |         if os.path.exists(os.path.join(args.source_path, "sparse")):
45 |             scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.alpha_masks, args.depths, args.eval, args.train_test_exp)
46 |         else:
47 |             assert False, "Could not recognize scene type!"
48 | 
49 |         if not self.loaded_iter:
50 |             with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
51 |                 dest_file.write(src_file.read())
52 |             json_cams = []
53 |             camlist = []
54 |             if scene_info.test_cameras:
55 |                 camlist.extend(scene_info.test_cameras)
56 |             if scene_info.train_cameras:
57 |                 camlist.extend(scene_info.train_cameras)
58 |             for id, cam in enumerate(camlist):
59 |                 json_cams.append(camera_to_JSON(id, cam))
60 |             with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
61 |                 json.dump(json_cams, file)
62 | 
63 |         if shuffle:
64 |             random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
65 |             random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling
66 | 
67 |         self.cameras_extent = scene_info.nerf_normalization["radius"]
68 | 
69 |         for resolution_scale in resolution_scales:
70 |             print("Making Training Dataset")
71 |             self.train_cameras[resolution_scale] = CameraDataset(scene_info.train_cameras, args, resolution_scale, False)
72 | 
73 |             print("Making Test Dataset")
74 |             self.test_cameras[resolution_scale] = CameraDataset(scene_info.test_cameras, args, resolution_scale, True)
75 | 
76 |         if self.loaded_iter:
77 |             self.gaussians.load_ply(os.path.join(self.model_path,
78 |                                                  "point_cloud",
79 |                                                  "iteration_" + str(self.loaded_iter),
80 |                                                  "point_cloud.ply"))
81 |         elif args.pretrained:
82 |             self.gaussians.create_from_pt(args.pretrained, self.cameras_extent)
83 |         elif create_from_hier:
84 |             self.gaussians.create_from_hier(args.hierarchy, self.cameras_extent, args.scaffold_file)
85 |         else:
86 |             self.gaussians.create_from_pcd(scene_info.point_cloud,
87 |                                            scene_info.train_cameras,
88 |                                            self.cameras_extent,
89 |                                            args.skybox_num,
90 |                                            args.scaffold_file,
91 |                                            args.bounds_file,
92 |                                            args.skybox_locked)
93 | 
94 | 
95 |     def save(self, iteration):
96 |         point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
97 |         mkdir_p(point_cloud_path)
98 |         if self.gaussians.nodes is not None:
99 |             self.gaussians.save_hier()
100 |         else:
101 | 
with open(os.path.join(point_cloud_path, "pc_info.txt"), "w") as f: 102 | f.write(str(self.gaussians.skybox_points)) 103 | if self.gaussians._xyz.size(0) > 8_000_000: 104 | self.gaussians.save_pt(point_cloud_path) 105 | else: 106 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 107 | 108 | exposure_dict = { 109 | image_name: self.gaussians.get_exposure_from_name(image_name).detach().cpu().numpy().tolist() 110 | for image_name in self.gaussians.exposure_mapping 111 | } 112 | 113 | with open(os.path.join(self.model_path, "exposure.json"), "w") as f: 114 | json.dump(exposure_dict, f, indent=2) 115 | 116 | def getTrainCameras(self, scale=1.0): 117 | return self.train_cameras[scale] 118 | 119 | def getTestCameras(self, scale=1.0): 120 | return self.test_cameras[scale] 121 | 122 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | import cv2 17 | 18 | from utils.general_utils import PILtoTorch 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | class Camera(nn.Module): 24 | def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, depth_params, primx, primy, image, alpha_mask, 25 | invdepthmap, 26 | image_name, uid, 27 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", 28 | train_test_exp=False, is_test_dataset=False, is_test_view=False, 29 | ): 30 | super(Camera, self).__init__() 31 | 32 | self.uid = uid 33 | self.colmap_id = colmap_id 34 | self.R = R 35 | self.T = T 36 | self.FoVx = FoVx 37 | self.FoVy = FoVy 38 | self.image_name = image_name 39 | 40 | try: 41 | self.data_device = torch.device(data_device) 42 | except Exception as e: 43 | print(e) 44 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 45 | self.data_device = torch.device("cuda") 46 | 47 | resized_image_rgb = PILtoTorch(image, resolution) 48 | gt_image = resized_image_rgb[:3, ...] 
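        # Alpha-mask precedence (implemented just below): an explicit mask file
        # wins; failing that, a fourth (alpha) channel of the loaded image is
        # used; otherwise the mask defaults to all ones, i.e. every pixel is
        # supervised.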
49 | if alpha_mask is not None: 50 | self.alpha_mask = PILtoTorch(alpha_mask, resolution) 51 | elif resized_image_rgb.shape[0] == 4: 52 | self.alpha_mask = resized_image_rgb[3:4, ...].to(self.data_device) 53 | else: 54 | self.alpha_mask = torch.ones_like(resized_image_rgb[0:1, ...].to(self.data_device)) 55 | 56 | if train_test_exp and is_test_view: 57 | if is_test_dataset: 58 | self.alpha_mask[..., :self.alpha_mask.shape[-1] // 2] = 0 59 | else: 60 | self.alpha_mask[..., self.alpha_mask.shape[-1] // 2:] = 0 61 | 62 | self.original_image = gt_image.clamp(0.0, 1.0).to(self.data_device) 63 | self.image_width = self.original_image.shape[2] 64 | self.image_height = self.original_image.shape[1] 65 | 66 | if self.alpha_mask is not None: 67 | self.original_image *= self.alpha_mask 68 | 69 | self.invdepthmap = None 70 | self.depth_reliable = False 71 | if invdepthmap is not None and depth_params is not None and depth_params["scale"] > 0: 72 | invdepthmapScaled = invdepthmap * depth_params["scale"] + depth_params["offset"] 73 | invdepthmapScaled = cv2.resize(invdepthmapScaled, resolution) 74 | invdepthmapScaled[invdepthmapScaled < 0] = 0 75 | if invdepthmapScaled.ndim != 2: 76 | invdepthmapScaled = invdepthmapScaled[..., 0] 77 | self.invdepthmap = torch.from_numpy(invdepthmapScaled[None]).to(self.data_device) 78 | 79 | if self.alpha_mask is not None: 80 | self.depth_mask = self.alpha_mask.clone() 81 | else: 82 | self.depth_mask = torch.ones_like(self.invdepthmap > 0) 83 | 84 | if depth_params["scale"] < 0.2 * depth_params["med_scale"] or depth_params["scale"] > 5 * depth_params["med_scale"]: 85 | self.depth_mask *= 0 86 | else: 87 | self.depth_reliable = True 88 | 89 | self.zfar = 100.0 90 | self.znear = 0.01 91 | 92 | self.trans = trans 93 | self.scale = scale 94 | 95 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to(self.data_device) 96 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy, primx = primx, primy=primy).transpose(0,1).to(self.data_device) 97 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0).to(self.data_device) 98 | self.camera_center = self.world_view_transform.inverse()[3, :3].to(self.data_device) 99 | 100 | class MiniCam: 101 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 102 | self.image_width = width 103 | self.image_height = height 104 | self.FoVy = fovy 105 | self.FoVx = fovx 106 | self.znear = znear 107 | self.zfar = zfar 108 | self.world_view_transform = world_view_transform 109 | self.full_proj_transform = full_proj_transform 110 | view_inv = torch.inverse(self.world_view_transform) 111 | self.camera_center = view_inv[3][:3] 112 | 113 | 114 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | count = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | count += 1 101 | 102 | xyzs = np.zeros((count, 3)) 103 | rgbs = np.zeros((count, 3)) 104 | errors = np.zeros((count, 1)) 105 | 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | 122 | count += 1 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | with open(path_to_model_file, "rb") as fid: 132 | num_points = read_next_bytes(fid, 8, "Q")[0] 133 | 134 | xyzs = np.empty((num_points, 3)) 135 | rgbs = np.empty((num_points, 3)) 136 | errors = np.empty((num_points, 1)) 137 | 138 | for p_id in range(num_points): 139 | binary_point_line_properties = read_next_bytes( 140 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 141 | xyz = np.array(binary_point_line_properties[1:4]) 142 | rgb = np.array(binary_point_line_properties[4:7]) 143 | error = np.array(binary_point_line_properties[7]) 144 | track_length = read_next_bytes( 145 | fid, num_bytes=8, format_char_sequence="Q")[0] 146 | track_elems = read_next_bytes( 147 | fid, num_bytes=8*track_length, 148 | format_char_sequence="ii"*track_length) 149 | xyzs[p_id] = xyz 150 | rgbs[p_id] = rgb 151 | errors[p_id] = error 152 | return xyzs, rgbs, errors 153 | 154 | def read_intrinsics_text(path): 155 | """ 156 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 157 | """ 158 | cameras = {} 159 | with open(path, "r") as fid: 160 | while True: 161 | line = fid.readline() 162 | if not line: 163 | break 164 | line = line.strip() 165 | if len(line) > 0 and line[0] != "#": 166 | elems = line.split() 167 | camera_id = int(elems[0]) 168 | model = elems[1] 169 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 170 | width = int(elems[2]) 171 | height = int(elems[3]) 172 | params = np.array(tuple(map(float, elems[4:]))) 173 | cameras[camera_id] = Camera(id=camera_id, model=model, 174 | width=width, height=height, 175 | params=params) 176 | return cameras 177 | 178 | def read_extrinsics_binary(path_to_model_file): 179 | """ 180 | see: src/base/reconstruction.cc 181 | void Reconstruction::ReadImagesBinary(const std::string& path) 182 | void Reconstruction::WriteImagesBinary(const std::string& path) 183 | """ 184 | images = {} 185 | with open(path_to_model_file, "rb") as fid: 186 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 187 | for _ in 
range(num_reg_images): 188 | binary_image_properties = read_next_bytes( 189 | fid, num_bytes=64, format_char_sequence="idddddddi") 190 | image_id = binary_image_properties[0] 191 | qvec = np.array(binary_image_properties[1:5]) 192 | tvec = np.array(binary_image_properties[5:8]) 193 | camera_id = binary_image_properties[8] 194 | image_name = "" 195 | current_char = read_next_bytes(fid, 1, "c")[0] 196 | while current_char != b"\x00": # look for the ASCII 0 entry 197 | image_name += current_char.decode("utf-8") 198 | current_char = read_next_bytes(fid, 1, "c")[0] 199 | num_points2D = read_next_bytes(fid, num_bytes=8, 200 | format_char_sequence="Q")[0] 201 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 202 | format_char_sequence="ddq"*num_points2D) 203 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 204 | tuple(map(float, x_y_id_s[1::3]))]) 205 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 206 | images[image_id] = Image( 207 | id=image_id, qvec=qvec, tvec=tvec, 208 | camera_id=camera_id, name=image_name, 209 | xys=xys, point3D_ids=point3D_ids) 210 | return images 211 | 212 | 213 | def read_intrinsics_binary(path_to_model_file): 214 | """ 215 | see: src/base/reconstruction.cc 216 | void Reconstruction::WriteCamerasBinary(const std::string& path) 217 | void Reconstruction::ReadCamerasBinary(const std::string& path) 218 | """ 219 | cameras = {} 220 | with open(path_to_model_file, "rb") as fid: 221 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 222 | for _ in range(num_cameras): 223 | camera_properties = read_next_bytes( 224 | fid, num_bytes=24, format_char_sequence="iiQQ") 225 | camera_id = camera_properties[0] 226 | model_id = camera_properties[1] 227 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 228 | width = camera_properties[2] 229 | height = camera_properties[3] 230 | num_params = CAMERA_MODEL_IDS[model_id].num_params 231 | params = read_next_bytes(fid, num_bytes=8*num_params, 232 | format_char_sequence="d"*num_params) 233 | cameras[camera_id] = Camera(id=camera_id, 234 | model=model_name, 235 | width=width, 236 | height=height, 237 | params=np.array(params)) 238 | assert len(cameras) == num_cameras 239 | return cameras 240 | 241 | 242 | def read_extrinsics_text(path): 243 | """ 244 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 245 | """ 246 | images = {} 247 | with open(path, "r") as fid: 248 | while True: 249 | line = fid.readline() 250 | if not line: 251 | break 252 | line = line.strip() 253 | if len(line) > 0 and line[0] != "#": 254 | elems = line.split() 255 | image_id = int(elems[0]) 256 | qvec = np.array(tuple(map(float, elems[1:5]))) 257 | tvec = np.array(tuple(map(float, elems[5:8]))) 258 | camera_id = int(elems[8]) 259 | image_name = elems[9] 260 | elems = fid.readline().split() 261 | xys = np.column_stack([tuple(map(float, elems[0::3])), 262 | tuple(map(float, elems[1::3]))]) 263 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 264 | images[image_id] = Image( 265 | id=image_id, qvec=qvec, tvec=tvec, 266 | camera_id=camera_id, name=image_name, 267 | xys=xys, point3D_ids=point3D_ids) 268 | return images 269 | 270 | 271 | def read_colmap_bin_array(path): 272 | """ 273 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 274 | 275 | :param path: path to the colmap binary file. 
276 | :return: nd array with the floating point values in the value 277 | """ 278 | with open(path, "rb") as fid: 279 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 280 | usecols=(0, 1, 2), dtype=int) 281 | fid.seek(0) 282 | num_delimiter = 0 283 | byte = fid.read(1) 284 | while True: 285 | if byte == b"&": 286 | num_delimiter += 1 287 | if num_delimiter >= 3: 288 | break 289 | byte = fid.read(1) 290 | array = np.fromfile(fid, np.float32) 291 | array = array.reshape((width, height, channels), order="F") 292 | return np.transpose(array, (1, 0, 2)).squeeze() 293 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | import torch 26 | 27 | 28 | class CameraInfo(NamedTuple): 29 | uid: int 30 | R: np.array 31 | T: np.array 32 | FovY: np.array 33 | FovX: np.array 34 | primx:float 35 | primy:float 36 | depth_params: dict 37 | image_path: str 38 | mask_path: str 39 | depth_path: str 40 | image_name: str 41 | width: int 42 | height: int 43 | is_test: bool 44 | 45 | class SceneInfo(NamedTuple): 46 | point_cloud: BasicPointCloud 47 | train_cameras: list 48 | test_cameras: list 49 | nerf_normalization: dict 50 | ply_path: str 51 | 52 | def getNerfppNorm(cam_info): 53 | def get_center_and_diag(cam_centers): 54 | cam_centers = np.hstack(cam_centers) 55 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 56 | center = avg_cam_center 57 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 58 | diagonal = np.quantile(dist, 0.9) 59 | return center.flatten(), diagonal 60 | 61 | cam_centers = [] 62 | 63 | for cam in cam_info: 64 | W2C = getWorld2View2(cam.R, cam.T) 65 | C2W = np.linalg.inv(W2C) 66 | cam_centers.append(C2W[:3, 3:4]) 67 | 68 | center, diagonal = get_center_and_diag(cam_centers) 69 | radius = diagonal * 1.1 70 | 71 | translate = -center 72 | 73 | return {"translate": translate, "radius": radius} 74 | 75 | def readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_folder, masks_folder, depths_folder, test_cam_names_list): 76 | cam_infos = [] 77 | for idx, key in enumerate(cam_extrinsics): 78 | sys.stdout.write('\r') 79 | # the exact output you're looking for: 80 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 81 | sys.stdout.flush() 82 | 83 | extr = cam_extrinsics[key] 84 | intr = cam_intrinsics[extr.camera_id] 85 | height = intr.height 86 | width = intr.width 87 | 88 | uid = intr.id 89 | R = np.transpose(qvec2rotmat(extr.qvec)) 90 | T = np.array(extr.tvec) 91 | 92 | if 
intr.model=="SIMPLE_PINHOLE": 93 | focal_length_x = intr.params[0] 94 | primx = float(intr.params[1]) / width 95 | primy = float(intr.params[2]) / height 96 | FovY = focal2fov(focal_length_x, height) 97 | FovX = focal2fov(focal_length_x, width) 98 | elif intr.model=="PINHOLE": 99 | primx = float(intr.params[2]) / width 100 | primy = float(intr.params[3]) / height 101 | focal_length_x = intr.params[0] 102 | focal_length_y = intr.params[1] 103 | FovY = focal2fov(focal_length_y, height) 104 | FovX = focal2fov(focal_length_x, width) 105 | else: 106 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 107 | 108 | n_remove = len(extr.name.split('.')[-1]) + 1 109 | depth_params = None 110 | if depths_params is not None: 111 | try: 112 | depth_params = depths_params[extr.name[:-n_remove]] 113 | except: 114 | print("\n", key, "not found in depths_params") 115 | 116 | image_path = os.path.join(images_folder, extr.name) 117 | image_name = extr.name 118 | if not os.path.exists(image_path): 119 | image_path = os.path.join(images_folder, f"{extr.name[:-n_remove]}.jpg") 120 | image_name = f"{extr.name[:-n_remove]}.jpg" 121 | if not os.path.exists(image_path): 122 | image_path = os.path.join(images_folder, f"{extr.name[:-n_remove]}.png") 123 | image_name = f"{extr.name[:-n_remove]}.png" 124 | 125 | mask_path = os.path.join(masks_folder, f"{extr.name[:-n_remove]}.png") if masks_folder != "" else "" 126 | depth_path = os.path.join(depths_folder, f"{extr.name[:-n_remove]}.png") if depths_folder != "" else "" 127 | 128 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, primx=primx, primy=primy, depth_params=depth_params, 129 | image_path=image_path, mask_path=mask_path, depth_path=depth_path, image_name=image_name, 130 | width=width, height=height, is_test=image_name in test_cam_names_list) 131 | cam_infos.append(cam_info) 132 | sys.stdout.write('\n') 133 | return cam_infos 134 | 135 | def fetchPly(path): 136 | plydata = PlyData.read(path) 137 | vertices = plydata['vertex'] 138 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 139 | 140 | if('red' in vertices): 141 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 142 | else: 143 | colors = np.ones_like(positions) * 0.5 144 | if('nx' in vertices): 145 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 146 | else: 147 | normals = np.zeros_like(positions) 148 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 149 | 150 | def fetchPt(xyz_path, rgb_path): 151 | positions_tensor = torch.jit.load(xyz_path).state_dict()['0'] 152 | 153 | positions = positions_tensor.numpy() 154 | 155 | colors_tensor = torch.jit.load(rgb_path).state_dict()['0'] 156 | if colors_tensor.size(0) == 0: 157 | colors_tensor = 255 * (torch.ones_like(positions_tensor) * 0.5) 158 | colors = (colors_tensor.float().numpy()) / 255.0 159 | normals = torch.Tensor([]).numpy() 160 | 161 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 162 | 163 | def storePly(path, xyz, rgb): 164 | # Define the dtype for the structured array 165 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 166 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 167 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 168 | 169 | normals = np.zeros_like(xyz) 170 | 171 | elements = np.empty(xyz.shape[0], dtype=dtype) 172 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 173 | elements[:] = list(map(tuple, attributes)) 
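    # Each row of `elements` now packs (x, y, z, nx, ny, nz, red, green, blue)
    # into the structured dtype declared above; normals are zero-filled because
    # COLMAP sparse points carry none. A minimal round-trip check (a sketch,
    # using an illustrative /tmp path):
    #
    #   storePly("/tmp/test.ply", np.zeros((1, 3)), np.array([[255, 0, 0]]))
    #   assert fetchPly("/tmp/test.ply").colors[0, 0] == 1.0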
174 | 
175 |     # Create the PlyData object and write to file
176 |     vertex_element = PlyElement.describe(elements, 'vertex')
177 |     ply_data = PlyData([vertex_element])
178 |     ply_data.write(path)
179 | 
180 | def readColmapSceneInfo(path, images, masks, depths, eval, train_test_exp, llffhold=None):
181 |     try:
182 |         cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
183 |         cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
184 |         cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
185 |         cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
186 |     except:
187 |         cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
188 |         cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
189 |         cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
190 |         cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
191 | 
192 |     depth_params_file = os.path.join(path, "sparse/0", "depth_params.json")
193 |     # If depth_params.json is missing but depth maps were requested, abort with an error.
194 |     depths_params = None
195 |     if depths != "":
196 |         try:
197 |             with open(depth_params_file, "r") as f:
198 |                 depths_params = json.load(f)
199 |             all_scales = np.array([depths_params[key]["scale"] for key in depths_params])
200 |             if (all_scales > 0).sum():
201 |                 med_scale = np.median(all_scales[all_scales > 0])
202 |             else:
203 |                 med_scale = 0
204 |             for key in depths_params:
205 |                 depths_params[key]["med_scale"] = med_scale
206 | 
207 |         except FileNotFoundError:
208 |             print(f"Error: depth_params.json file not found at path '{depth_params_file}'.")
209 |             sys.exit(1)
210 |         except Exception as e:
211 |             print(f"An unexpected error occurred when trying to open depth_params.json file: {e}")
212 |             sys.exit(1)
213 | 
214 | 
215 |     ply_path = os.path.join(path, "sparse/0/points3D.ply")
216 |     bin_path = os.path.join(path, "sparse/0/points3D.bin")
217 |     txt_path = os.path.join(path, "sparse/0/points3D.txt")
218 | 
219 |     try:
220 |         xyz_path = os.path.join(path, "sparse/0/xyz.pt")
221 |         rgb_path = os.path.join(path, "sparse/0/rgb.pt")
222 |         pcd = fetchPt(xyz_path, rgb_path)
223 |     except:
224 |         if not os.path.exists(ply_path):
225 |             print("Converting points3D.bin to .ply; this happens only the first time a scene is opened.")
226 |             try:
227 |                 xyz, rgb, _ = read_points3D_binary(bin_path)
228 |             except:
229 |                 xyz, rgb, _ = read_points3D_text(txt_path)
230 |             storePly(ply_path, xyz, rgb)
231 |         pcd = fetchPly(ply_path)
232 | 
233 |     if eval:
234 |         if "360" in path:
235 |             llffhold = 8
236 |         if llffhold:
237 |             print("------------LLFF HOLD-------------")
238 |             cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
239 |             cam_names = sorted(cam_names)
240 |             test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
241 |         else:
242 |             with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
243 |                 test_cam_names_list = [line.strip() for line in file]
244 |     else:
245 |         test_cam_names_list = []
246 | 
247 |     reading_dir = "images" if images is None else images
248 |     masks_reading_dir = masks if masks == "" else os.path.join(path, masks)
249 | 
250 |     cam_infos_unsorted = readColmapCameras(
251 |         cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, depths_params=depths_params,
252 |         images_folder=os.path.join(path, reading_dir), masks_folder=masks_reading_dir,
253 |         depths_folder=os.path.join(path, depths) if depths != "" else "", test_cam_names_list=test_cam_names_list)
254 |     cam_infos = sorted(cam_infos_unsorted.copy(),
                       key = lambda x : x.image_name)
255 | 
256 |     train_cam_infos = [c for c in cam_infos if train_test_exp or not c.is_test]
257 |     test_cam_infos = [c for c in cam_infos if c.is_test]
258 |     print(len(test_cam_infos), "test images")
259 |     print(len(train_cam_infos), "train images")
260 | 
261 |     nerf_normalization = getNerfppNorm(train_cam_infos)
262 | 
263 |     scene_info = SceneInfo(point_cloud=pcd,
264 |                            train_cameras=train_cam_infos,
265 |                            test_cameras=test_cam_infos,
266 |                            nerf_normalization=nerf_normalization,
267 |                            ply_path=ply_path)
268 |     return scene_info
269 | 
270 | 
271 | sceneLoadTypeCallbacks = {
272 |     "Colmap": readColmapSceneInfo
273 | }
--------------------------------------------------------------------------------
/scripts/coarse_train.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #SBATCH -A hzb@v100
4 | #SBATCH -C v100-32g
5 | #SBATCH --ntasks=1
6 | #SBATCH --nodes=1
7 | #SBATCH --gres=gpu:1
8 | #SBATCH --cpus-per-task=20
9 | #SBATCH --time=3:00:00
10 | 
11 | COARSE_ENV=$1
12 | COLMAP_DIR=$2
13 | IMAGES_DIR=$3
14 | OUTPUT_DIR=$4
15 | 
16 | if [ -z "${5}" ]; then
17 |     echo "No masks provided."
18 |     MASKS_ARG=""
19 | else
20 |     MASKS_ARG="${5}"
21 |     echo "masks provided: ${MASKS_ARG}"
22 | fi
23 | 
24 | source $WORK/miniconda3/etc/profile.d/conda.sh
25 | conda activate ${COARSE_ENV}
26 | 
27 | python ../train_coarse.py -s ${COLMAP_DIR} --save_iterations -1 -i ${IMAGES_DIR} --skybox_num 100000 --model_path ${OUTPUT_DIR}/scaffold ${MASKS_ARG}
28 | 
--------------------------------------------------------------------------------
/scripts/consolidate.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #SBATCH -A hzb@v100
4 | #SBATCH -C v100-32g
5 | #SBATCH --ntasks=1
6 | #SBATCH --nodes=1
7 | #SBATCH --gres=gpu:1
8 | #SBATCH --cpus-per-task=20
9 | #SBATCH --time=3:00:00
10 | 
11 | HIERARCHY_MERGER_EXE=$1
12 | TRAINED_CHUNKS=$2
13 | NB=$3
14 | CHUNKS_COLMAP=$4
15 | OUTPUT=$5
16 | 
17 | # Remove the first five arguments, leaving only the array elements
18 | shift 5
19 | 
20 | # The remaining arguments are the array elements
21 | array_elements=("$@")
22 | 
23 | echo "CHUNKS TO BE MERGED: " "${array_elements[@]}"
24 | ${HIERARCHY_MERGER_EXE} ${TRAINED_CHUNKS} ${NB} ${CHUNKS_COLMAP} ${OUTPUT} "${array_elements[@]}"
25 | 
--------------------------------------------------------------------------------
/scripts/train_chunk.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #SBATCH -A hzb@v100
4 | #SBATCH -C v100-32g
5 | #SBATCH --ntasks=1
6 | #SBATCH --nodes=1
7 | #SBATCH --gres=gpu:1
8 | #SBATCH --cpus-per-task=20
9 | #SBATCH --time=3:00:00
10 | 
11 | 
12 | SOURCE_CHUNK=$1
13 | OUTPUT_DIR=$2
14 | ENV=$3
15 | CHUNK_NAME=$4
16 | HIERARCHY_CREATOR_EXE=$5
17 | IMAGES_DIR=$6
18 | DEPTHS_DIR=$7
19 | 
20 | if [ -z "$8" ]; then
21 |     echo "No masks provided."
22 |     MASKS_ARG=""
23 | else
24 |     MASKS_ARG="$8"
25 |     echo "masks provided: ${MASKS_ARG}"
26 | fi
27 | 
28 | TRAINED_CHUNK=${OUTPUT_DIR}"/trained_chunks/"${CHUNK_NAME}
29 | 
30 | source $WORK/miniconda3/etc/profile.d/conda.sh
31 | conda activate ${ENV}
32 | 
33 | # Train the chunk
34 | python -u train_single.py --save_iterations -1 -i ${IMAGES_DIR} -d ${DEPTHS_DIR} --scaffold_file ${OUTPUT_DIR}/scaffold/point_cloud/iteration_30000 --disable_viewer --skybox_locked ${MASKS_ARG} -s ${SOURCE_CHUNK} --model_path ${TRAINED_CHUNK} --bounds_file ${SOURCE_CHUNK}
35 | 
36 | # Generate a hierarchy within the chunk
37 | ${HIERARCHY_CREATOR_EXE} ${TRAINED_CHUNK}/point_cloud/iteration_30000/point_cloud.ply ${SOURCE_CHUNK} ${TRAINED_CHUNK} ${OUTPUT_DIR}/scaffold/point_cloud/iteration_30000
38 | 
39 | python -u train_post.py --iterations 15000 --feature_lr 0.0005 --opacity_lr 0.01 --scaling_lr 0.001 --save_iterations -1 -i ${IMAGES_DIR} --scaffold_file ${OUTPUT_DIR}/scaffold/point_cloud/iteration_30000 ${MASKS_ARG} -s ${SOURCE_CHUNK} --disable_viewer --model_path ${TRAINED_CHUNK} --hierarchy ${TRAINED_CHUNK}/hierarchy.hier
40 | 
41 | echo "CHUNK " ${CHUNK_NAME} " FULLY TRAINED."
42 | 
--------------------------------------------------------------------------------
/train_coarse.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact  george.drettakis@inria.fr
10 | #
11 | 
12 | import os
13 | import torch
14 | from utils.loss_utils import l1_loss, ssim
15 | from gaussian_renderer import render_coarse, network_gui
16 | import sys
17 | from scene import Scene, GaussianModel
18 | from utils.general_utils import safe_state
19 | import uuid
20 | from tqdm import tqdm
21 | from torch.utils.data import DataLoader
22 | from argparse import ArgumentParser, Namespace
23 | from arguments import ModelParams, PipelineParams, OptimizationParams
24 | 
25 | def direct_collate(x):
26 |     return x
27 | 
28 | def training(dataset, opt, pipe, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
29 |     first_iter = 0
30 |     prepare_output_and_logger(dataset)
31 |     gaussians = GaussianModel(1)
32 |     scene = Scene(dataset, gaussians)
33 |     gaussians.training_setup(opt)
34 |     if checkpoint:
35 |         (model_params, first_iter) = torch.load(checkpoint)
36 |         gaussians.restore(model_params, opt)
37 | 
38 |     bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
39 |     background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
40 | 
41 |     iter_start = torch.cuda.Event(enable_timing = True)
42 |     iter_end = torch.cuda.Event(enable_timing = True)
43 | 
44 |     #viewpoint_stack = None
45 |     ema_loss_for_log = 0.0
46 |     progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
47 |     first_iter += 1
48 | 
49 |     target = 0
50 |     indices = None
51 | 
52 |     iteration = first_iter
53 |     training_generator = DataLoader(scene.getTrainCameras(), num_workers = 8, prefetch_factor = 1, persistent_workers = True, collate_fn=direct_collate)
54 | 
55 |     for param_group in gaussians.optimizer.param_groups:
56 |         if param_group["name"] == "xyz":
57 |             param_group['lr'] = 0.0
58 | 
59 |     while iteration < opt.iterations + 1:
60 |         for viewpoint_batch in training_generator:
61 |             for viewpoint_cam in viewpoint_batch:
62 | 
background = torch.rand((3), dtype=torch.float32, device="cuda") 63 | 64 | viewpoint_cam.world_view_transform = viewpoint_cam.world_view_transform.cuda() 65 | viewpoint_cam.projection_matrix = viewpoint_cam.projection_matrix.cuda() 66 | viewpoint_cam.full_proj_transform = viewpoint_cam.full_proj_transform.cuda() 67 | viewpoint_cam.camera_center = viewpoint_cam.camera_center.cuda() 68 | 69 | if network_gui.conn == None: 70 | network_gui.try_connect() 71 | while network_gui.conn != None: 72 | try: 73 | net_image_bytes = None 74 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 75 | if custom_cam != None: 76 | net_image = render_coarse(custom_cam, gaussians, pipe, background, scaling_modifer, indices = indices)["render"] 77 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 78 | network_gui.send(net_image_bytes, dataset.source_path) 79 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 80 | break 81 | except Exception as e: 82 | network_gui.conn = None 83 | 84 | iter_start.record() 85 | 86 | # Every 1000 its we increase the levels of SH up to a maximum degree 87 | if iteration % 1000 == 0: 88 | gaussians.oneupSHdegree() 89 | 90 | # Render 91 | if (iteration - 1) == debug_from: 92 | pipe.debug = True 93 | 94 | render_pkg = render_coarse(viewpoint_cam, gaussians, pipe, background, indices = indices) 95 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 96 | 97 | # Loss 98 | gt_image = viewpoint_cam.original_image.cuda().float() 99 | if viewpoint_cam.alpha_mask is not None: 100 | alpha_mask = viewpoint_cam.alpha_mask.cuda().float() 101 | Ll1 = l1_loss(image * alpha_mask, gt_image) 102 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image * alpha_mask, gt_image)) 103 | else: 104 | Ll1 = l1_loss(image, gt_image) 105 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 106 | loss.backward() 107 | 108 | iter_end.record() 109 | 110 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii) 111 | 112 | with torch.no_grad(): 113 | # Progress bar 114 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 115 | if iteration % 10 == 0: 116 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Size": f"{gaussians._xyz.size(0)}", "Peak memory": f"{torch.cuda.max_memory_allocated(device='cuda')}"}) 117 | progress_bar.update(10) 118 | 119 | # Log and save 120 | if (iteration in saving_iterations): 121 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 122 | scene.save(iteration) 123 | 124 | if iteration == opt.iterations: 125 | progress_bar.close() 126 | training_generator._get_iterator()._shutdown_workers() 127 | return 128 | 129 | # Optimizer step 130 | 131 | if iteration < opt.iterations: 132 | gaussians._scaling.grad[:gaussians.skybox_points,:] = 0 133 | relevant = (gaussians._opacity.grad != 0).nonzero() 134 | gaussians.optimizer.step(relevant) 135 | gaussians.optimizer.zero_grad(set_to_none = True) 136 | 137 | if (iteration in checkpoint_iterations): 138 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 139 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 140 | 141 | with torch.no_grad(): 142 | vals, _ = 
gaussians.get_scaling.max(dim=1) 143 | violators = vals > scene.cameras_extent * 0.1 144 | violators[:gaussians.skybox_points] = False 145 | gaussians._scaling[violators] = gaussians.scaling_inverse_activation(gaussians.get_scaling[violators] * 0.8) 146 | 147 | 148 | iteration += 1 149 | 150 | 151 | def prepare_output_and_logger(args): 152 | if not args.model_path: 153 | if os.getenv('OAR_JOB_ID'): 154 | unique_str=os.getenv('OAR_JOB_ID') 155 | else: 156 | unique_str = str(uuid.uuid4()) 157 | args.model_path = os.path.join("./output/", unique_str[0:10]) 158 | 159 | # Set up output folder 160 | print("Output folder: {}".format(args.model_path)) 161 | os.makedirs(args.model_path, exist_ok = True) 162 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 163 | cfg_log_f.write(str(Namespace(**vars(args)))) 164 | 165 | if __name__ == "__main__": 166 | # Set up command line argument parser 167 | parser = ArgumentParser(description="Training script parameters") 168 | lp = ModelParams(parser) 169 | op = OptimizationParams(parser) 170 | pp = PipelineParams(parser) 171 | parser.add_argument('--ip', type=str, default="127.0.0.1") 172 | parser.add_argument('--port', type=int, default=6009) 173 | parser.add_argument('--debug_from', type=int, default=-1) 174 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 175 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[30_000]) 176 | parser.add_argument("--quiet", action="store_true") 177 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 178 | parser.add_argument("--start_checkpoint", type=str, default = None) 179 | args = parser.parse_args(sys.argv[1:]) 180 | args.save_iterations.append(args.iterations) 181 | 182 | print("Optimizing " + args.model_path) 183 | 184 | # Initialize system state (RNG) 185 | safe_state(args.quiet) 186 | 187 | # Start GUI server, configure and run training 188 | network_gui.init(args.ip, args.port) 189 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 190 | training(lp.extract(args), op.extract(args), pp.extract(args), args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 191 | 192 | # All done 193 | print("\nTraining complete.") 194 | -------------------------------------------------------------------------------- /train_post.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from utils.loss_utils import l1_loss, ssim 15 | from gaussian_renderer import render_post 16 | import sys 17 | from scene import Scene, GaussianModel 18 | from utils.general_utils import safe_state 19 | import uuid 20 | from tqdm import tqdm 21 | from torch.utils.data import DataLoader 22 | from argparse import ArgumentParser, Namespace 23 | from arguments import ModelParams, PipelineParams, OptimizationParams 24 | import math 25 | 26 | from gaussian_hierarchy._C import expand_to_size, get_interpolation_weights 27 | 28 | def direct_collate(x): 29 | return x 30 | 31 | def training(dataset, opt, pipe, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 32 | first_iter = 0 33 | prepare_output_and_logger(dataset) 34 | gaussians = GaussianModel(dataset.sh_degree) 35 | gaussians.active_sh_degree = dataset.sh_degree 36 | scene = Scene(dataset, gaussians, resolution_scales = [1], create_from_hier=True) 37 | gaussians.training_setup(opt, our_adam=False) 38 | if checkpoint: 39 | (model_params, first_iter) = torch.load(checkpoint) 40 | gaussians.restore(model_params, opt) 41 | 42 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 43 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 44 | 45 | iter_start = torch.cuda.Event(enable_timing = True) 46 | iter_end = torch.cuda.Event(enable_timing = True) 47 | 48 | ema_loss_for_log = 0.0 49 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 50 | first_iter += 1 51 | 52 | indices = None 53 | 54 | iteration = first_iter 55 | training_generator = DataLoader(scene.getTrainCameras(), num_workers = 8, prefetch_factor = 1, persistent_workers = True, collate_fn=direct_collate) 56 | 57 | limit = 0.001 58 | 59 | render_indices = torch.zeros(gaussians._xyz.size(0)).int().cuda() 60 | parent_indices = torch.zeros(gaussians._xyz.size(0)).int().cuda() 61 | nodes_for_render_indices = torch.zeros(gaussians._xyz.size(0)).int().cuda() 62 | interpolation_weights = torch.zeros(gaussians._xyz.size(0)).float().cuda() 63 | num_siblings = torch.zeros(gaussians._xyz.size(0)).int().cuda() 64 | to_render = 0 65 | 66 | limmax = 0.1 67 | limmin = 0.005 68 | 69 | while iteration < opt.iterations + 1: 70 | for viewpoint_batch in training_generator: 71 | for viewpoint_cam in viewpoint_batch: 72 | 73 | sample = torch.rand(1).item() 74 | limit = math.pow(2, sample * (math.log2(limmax) - math.log2(limmin)) + math.log2(limmin)) 75 | scale = 1 76 | 77 | viewpoint_cam.world_view_transform = viewpoint_cam.world_view_transform.cuda() 78 | viewpoint_cam.projection_matrix = viewpoint_cam.projection_matrix.cuda() 79 | viewpoint_cam.full_proj_transform = viewpoint_cam.full_proj_transform.cuda() 80 | viewpoint_cam.camera_center = viewpoint_cam.camera_center.cuda() 81 | 82 | #Then with blending training 83 | iter_start.record() 84 | 85 | gaussians.update_learning_rate(iteration) 86 | 87 | # Every 1000 its we increase the levels of SH up to a maximum degree 88 | if iteration % 1000 == 0: 89 | gaussians.oneupSHdegree() 90 | 91 | to_render = expand_to_size( 92 | gaussians.nodes, 93 | gaussians.boxes, 94 | limit * scale, 95 | viewpoint_cam.camera_center, 96 | torch.zeros((3)), 97 | render_indices, 98 | parent_indices, 99 | nodes_for_render_indices) 100 | 101 | indices = render_indices[:to_render].int() 102 | node_indices = nodes_for_render_indices[:to_render] 103 | 104 | get_interpolation_weights( 105 | node_indices, 
106 | limit * scale, 107 | gaussians.nodes, 108 | gaussians.boxes, 109 | viewpoint_cam.camera_center.cpu(), 110 | torch.zeros((3)), 111 | interpolation_weights, 112 | num_siblings 113 | ) 114 | 115 | # Render 116 | if (iteration - 1) == debug_from: 117 | pipe.debug = True 118 | 119 | render_pkg = render_post( 120 | viewpoint_cam, 121 | gaussians, 122 | pipe, 123 | background, 124 | render_indices=indices, 125 | parent_indices = parent_indices, 126 | interpolation_weights = interpolation_weights, 127 | num_node_kids = num_siblings, 128 | use_trained_exp=True, 129 | ) 130 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 131 | 132 | 133 | # Loss 134 | gt_image = viewpoint_cam.original_image.cuda() 135 | if viewpoint_cam.alpha_mask is not None: 136 | Ll1 = l1_loss(image * viewpoint_cam.alpha_mask.cuda(), gt_image) 137 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image * viewpoint_cam.alpha_mask.cuda(), gt_image)) 138 | else: 139 | Ll1 = l1_loss(image, gt_image) 140 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 141 | 142 | loss.backward() 143 | 144 | iter_end.record() 145 | 146 | with torch.no_grad(): 147 | # Progress bar 148 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 149 | if iteration % 10 == 0: 150 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Size": f"{gaussians._xyz.size(0)}", "Peak memory": f"{torch.cuda.max_memory_allocated(device='cuda')}"}) 151 | progress_bar.update(10) 152 | 153 | # Log and save 154 | if (iteration in saving_iterations): 155 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 156 | scene.save(iteration) 157 | print("peak memory: ", torch.cuda.max_memory_allocated(device='cuda')) 158 | 159 | if iteration == opt.iterations: 160 | 161 | progress_bar.close() 162 | return 163 | 164 | # Optimizer step 165 | if iteration < opt.iterations: 166 | 167 | if gaussians._xyz.grad != None: 168 | if gaussians.skybox_points != 0 and gaussians.skybox_locked: #No post-opt for skybox 169 | gaussians._xyz.grad[-gaussians.skybox_points:, :] = 0 170 | gaussians._rotation.grad[-gaussians.skybox_points:, :] = 0 171 | gaussians._features_dc.grad[-gaussians.skybox_points:, :, :] = 0 172 | gaussians._features_rest.grad[-gaussians.skybox_points:, :, :] = 0 173 | gaussians._opacity.grad[-gaussians.skybox_points:, :] = 0 174 | gaussians._scaling.grad[-gaussians.skybox_points:, :] = 0 175 | 176 | gaussians._xyz.grad[gaussians.anchors, :] = 0 177 | gaussians._rotation.grad[gaussians.anchors, :] = 0 178 | gaussians._features_dc.grad[gaussians.anchors, :, :] = 0 179 | gaussians._features_rest.grad[gaussians.anchors, :, :] = 0 180 | gaussians._opacity.grad[gaussians.anchors, :] = 0 181 | gaussians._scaling.grad[gaussians.anchors, :] = 0 182 | 183 | ## OurAdam version 184 | # if gaussians._opacity.grad != None: 185 | # relevant = (gaussians._opacity.grad.flatten() != 0).nonzero() 186 | # relevant = relevant.flatten().long() 187 | # if(relevant.size(0) > 0): 188 | # gaussians.optimizer.step(relevant) 189 | # gaussians.optimizer.zero_grad(set_to_none = True) 190 | 191 | gaussians.optimizer.step() 192 | gaussians.optimizer.zero_grad(set_to_none = True) 193 | 194 | if (iteration in checkpoint_iterations): 195 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 196 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 
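                # Resuming from such a checkpoint (a sketch): relaunch with
                #   python train_post.py ... --start_checkpoint <model_path>/chkpnt<iter>.pth
                # training() then torch.load()s the file and calls gaussians.restore()
                # before entering the loop (see the `if checkpoint:` block above).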
197 | 198 | iteration += 1 199 | 200 | 201 | 202 | def prepare_output_and_logger(args): 203 | if not args.model_path: 204 | if os.getenv('OAR_JOB_ID'): 205 | unique_str=os.getenv('OAR_JOB_ID') 206 | else: 207 | unique_str = str(uuid.uuid4()) 208 | args.model_path = os.path.join("./output/", unique_str[0:10]) 209 | 210 | # Set up output folder 211 | print("Output folder: {}".format(args.model_path)) 212 | os.makedirs(args.model_path, exist_ok = True) 213 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 214 | cfg_log_f.write(str(Namespace(**vars(args)))) 215 | 216 | if __name__ == "__main__": 217 | # Set up command line argument parser 218 | parser = ArgumentParser(description="Training script parameters") 219 | lp = ModelParams(parser) 220 | op = OptimizationParams(parser) 221 | pp = PipelineParams(parser) 222 | parser.add_argument('--ip', type=str, default="127.0.0.1") 223 | parser.add_argument('--port', type=int, default=6009) 224 | parser.add_argument('--disable_viewer', action='store_true', default=False) 225 | parser.add_argument('--debug_from', type=int, default=-1) 226 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 227 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[30_000]) 228 | parser.add_argument("--quiet", action="store_true") 229 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 230 | parser.add_argument("--start_checkpoint", type=str, default = None) 231 | args = parser.parse_args(sys.argv[1:]) 232 | args.save_iterations.append(args.iterations) 233 | 234 | print("Optimizing " + args.model_path) 235 | 236 | # Initialize system state (RNG) 237 | safe_state(args.quiet) 238 | 239 | # Start GUI server, configure and run training 240 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 241 | training(lp.extract(args), op.extract(args), pp.extract(args), args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 242 | 243 | print("\nTraining complete.") 244 | -------------------------------------------------------------------------------- /train_single.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 - 2024, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from utils.loss_utils import l1_loss, ssim 15 | from gaussian_renderer import render, network_gui 16 | import sys 17 | from scene import Scene, GaussianModel 18 | from utils.general_utils import safe_state, get_expon_lr_func 19 | import uuid 20 | from tqdm import tqdm 21 | from torch.utils.data import DataLoader 22 | from argparse import ArgumentParser, Namespace 23 | from arguments import ModelParams, PipelineParams, OptimizationParams 24 | 25 | def direct_collate(x): 26 | return x 27 | 28 | def training(dataset, opt, pipe, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 29 | first_iter = 0 30 | prepare_output_and_logger(dataset) 31 | gaussians = GaussianModel(dataset.sh_degree) 32 | scene = Scene(dataset, gaussians) 33 | gaussians.training_setup(opt) 34 | if checkpoint: 35 | (model_params, first_iter) = torch.load(checkpoint) 36 | gaussians.restore(model_params, opt) 37 | 38 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 39 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 40 | 41 | iter_start = torch.cuda.Event(enable_timing = True) 42 | iter_end = torch.cuda.Event(enable_timing = True) 43 | 44 | depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations) 45 | 46 | ema_loss_for_log = 0.0 47 | ema_Ll1depth_for_log = 0.0 48 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 49 | first_iter += 1 50 | 51 | indices = None 52 | 53 | training_generator = DataLoader(scene.getTrainCameras(), num_workers = 8, prefetch_factor = 1, persistent_workers = True, collate_fn=direct_collate) 54 | 55 | iteration = first_iter 56 | 57 | while iteration < opt.iterations + 1: 58 | for viewpoint_batch in training_generator: 59 | for viewpoint_cam in viewpoint_batch: 60 | background = torch.rand((3), dtype=torch.float32, device="cuda") 61 | 62 | viewpoint_cam.world_view_transform = viewpoint_cam.world_view_transform.cuda() 63 | viewpoint_cam.projection_matrix = viewpoint_cam.projection_matrix.cuda() 64 | viewpoint_cam.full_proj_transform = viewpoint_cam.full_proj_transform.cuda() 65 | viewpoint_cam.camera_center = viewpoint_cam.camera_center.cuda() 66 | 67 | if not args.disable_viewer: 68 | if network_gui.conn == None: 69 | network_gui.try_connect() 70 | while network_gui.conn != None: 71 | try: 72 | net_image_bytes = None 73 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 74 | if custom_cam != None: 75 | if keep_alive: 76 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer, indices = indices)["render"] 77 | else: 78 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer, indices = indices)["depth"].repeat(3, 1, 1) 79 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 80 | network_gui.send(net_image_bytes, dataset.source_path) 81 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 82 | break 83 | except Exception as e: 84 | network_gui.conn = None 85 | 86 | iter_start.record() 87 | 88 | gaussians.update_learning_rate(iteration) 89 | 90 | # Every 1000 its we increase the levels of SH up to a maximum degree 91 | if iteration % 1000 == 0: 92 | gaussians.oneupSHdegree() 93 | 94 | # Render 95 | if (iteration - 1) == debug_from: 96 | 
97 |                 render_pkg = render(viewpoint_cam, gaussians, pipe, background, indices = indices, use_trained_exp=True)
98 |                 image, invDepth, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["depth"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
99 | 
100 |                 # Loss
101 |                 gt_image = viewpoint_cam.original_image.cuda()
102 |                 if viewpoint_cam.alpha_mask is not None:
103 |                     alpha_mask = viewpoint_cam.alpha_mask.cuda()
104 |                     image *= alpha_mask
105 | 
106 |                 Ll1 = l1_loss(image, gt_image)
107 |                 Lssim = (1.0 - ssim(image, gt_image))
108 |                 photo_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * Lssim
109 |                 loss = photo_loss.clone()
110 |                 Ll1depth_pure = 0.0
111 |                 if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable:
112 |                     mono_invdepth = viewpoint_cam.invdepthmap.cuda()
113 |                     depth_mask = viewpoint_cam.depth_mask.cuda()
114 | 
115 |                     Ll1depth_pure = torch.abs((invDepth - mono_invdepth) * depth_mask).mean()
116 |                     Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure
117 |                     loss += Ll1depth
118 |                     Ll1depth = Ll1depth.item()
119 |                 else:
120 |                     Ll1depth = 0
121 | 
122 | 
123 |                 loss.backward()
124 |                 iter_end.record()
125 | 
126 |                 with torch.no_grad():
127 |                     # Progress bar
128 |                     ema_loss_for_log = 0.4 * photo_loss.item() + 0.6 * ema_loss_for_log
129 |                     ema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_log
130 |                     if iteration % 10 == 0:
131 |                         progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.7f}", "Depth Loss": f"{ema_Ll1depth_for_log:.7f}", "Size": f"{gaussians._xyz.size(0)}"})
132 |                         progress_bar.update(10)
133 | 
134 |                     # Log and save
135 |                     if iteration in saving_iterations:
136 |                         print("\n[ITER {}] Saving Gaussians".format(iteration))
137 |                         scene.save(iteration)
138 |                         print("peak memory: ", torch.cuda.max_memory_allocated(device='cuda'))
139 | 
140 |                     if iteration == opt.iterations:
141 |                         progress_bar.close()
142 |                         return
143 | 
144 |                     # Densification
145 |                     if iteration < opt.densify_until_iter:
146 |                         # Keep track of max radii in image-space for pruning
147 |                         gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii)
148 |                         gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
149 | 
150 |                         if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
151 |                             gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent)
152 | 
153 |                         if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
154 |                             print("-----------------RESET OPACITY!-------------")
155 |                             gaussians.reset_opacity()
156 | 
157 |                     # Optimizer step
158 |                     if iteration < opt.iterations:
159 |                         gaussians.exposure_optimizer.step()
160 |                         gaussians.exposure_optimizer.zero_grad(set_to_none = True)
161 | 
162 |                         if gaussians._xyz.grad is not None and gaussians.skybox_locked:
163 |                             gaussians._xyz.grad[:gaussians.skybox_points, :] = 0
164 |                             gaussians._rotation.grad[:gaussians.skybox_points, :] = 0
165 |                             gaussians._features_dc.grad[:gaussians.skybox_points, :, :] = 0
166 |                             gaussians._features_rest.grad[:gaussians.skybox_points, :, :] = 0
167 |                             gaussians._opacity.grad[:gaussians.skybox_points, :] = 0
168 |                             gaussians._scaling.grad[:gaussians.skybox_points, :] = 0
169 | 
170 |                         if gaussians._opacity.grad is not None:
171 |                             relevant = (gaussians._opacity.grad.flatten() != 0).nonzero()
172 |                             relevant = relevant.flatten().long()
173 |                             if relevant.size(0) > 0:
174 |                                 gaussians.optimizer.step(relevant)
175 |                             else:
176 |                                 gaussians.optimizer.step(relevant)  # called with an empty index set
177 |                                 print("No grads!")
178 |                             gaussians.optimizer.zero_grad(set_to_none = True)
179 | 
180 |                         if not args.skip_scale_big_gauss:
181 |                             with torch.no_grad():
182 |                                 vals, _ = gaussians.get_scaling.max(dim=1)
183 |                                 violators = vals > scene.cameras_extent * 0.02
184 |                                 if gaussians.scaffold_points is not None:
185 |                                     violators[:gaussians.scaffold_points] = False
186 |                                 gaussians._scaling[violators] = gaussians.scaling_inverse_activation(gaussians.get_scaling[violators] * 0.8)
187 | 
188 |                         if iteration in checkpoint_iterations:
189 |                             print("\n[ITER {}] Saving Checkpoint".format(iteration))
190 |                             torch.save((gaussians.capture(), iteration), os.path.join(scene.model_path, "chkpnt" + str(iteration) + ".pth"))
191 | 
192 |                 iteration += 1
193 | 
194 | def prepare_output_and_logger(args):
195 |     if not args.model_path:
196 |         if os.getenv('OAR_JOB_ID'):
197 |             unique_str = os.getenv('OAR_JOB_ID')
198 |         else:
199 |             unique_str = str(uuid.uuid4())
200 |         args.model_path = os.path.join("./output/", unique_str[0:10])
201 | 
202 |     # Set up output folder
203 |     print("Output folder: {}".format(args.model_path))
204 |     os.makedirs(args.model_path, exist_ok = True)
205 |     with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
206 |         cfg_log_f.write(str(Namespace(**vars(args))))
207 | 
208 | if __name__ == "__main__":
209 |     # Set up command line argument parser
210 |     parser = ArgumentParser(description="Training script parameters")
211 |     lp = ModelParams(parser)
212 |     op = OptimizationParams(parser)
213 |     pp = PipelineParams(parser)
214 |     parser.add_argument('--ip', type=str, default="127.0.0.1")
215 |     parser.add_argument('--port', type=int, default=6009)
216 |     parser.add_argument('--disable_viewer', action='store_true', default=False)
217 |     parser.add_argument('--debug_from', type=int, default=-1)
218 |     parser.add_argument('--detect_anomaly', action='store_true', default=False)
219 |     parser.add_argument("--save_iterations", nargs="+", type=int, default=[30_000])
220 |     parser.add_argument("--quiet", action="store_true")
221 |     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
222 |     parser.add_argument("--start_checkpoint", type=str, default = None)
223 |     args = parser.parse_args(sys.argv[1:])
224 |     args.save_iterations.append(args.iterations)
225 | 
226 |     print("Optimizing " + args.model_path)
227 | 
228 |     if args.eval and args.exposure_lr_init > 0 and not args.train_test_exp:
229 |         print("Reconstructing for evaluation (--eval) with exposure optimization on the train set but not for the test set.")
230 |         print("This will lead to high error when computing metrics. To optimize exposure on the left half of the test images, use --train_test_exp.")
231 | 
232 |     # Initialize system state (RNG)
233 |     safe_state(args.quiet)
234 | 
235 |     # Start GUI server, configure and run training
236 |     if not args.disable_viewer:
237 |         network_gui.init(args.ip, args.port)
238 |     torch.autograd.set_detect_anomaly(args.detect_anomaly)
239 |     training(lp.extract(args), op.extract(args), pp.extract(args), args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
240 | 
241 |     # All done
242 |     print("\nTraining complete.")
243 | 
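A note on the DataLoader used above: direct_collate is an identity collate function. PyTorch's default collate either restructures batch items or rejects custom objects outright, so handing the sampled list back unchanged is what lets the loader yield Camera objects. A minimal, self-contained sketch of the behaviour (ToyDataset is hypothetical, standing in for the repository's CameraDataset):

    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return {"name": f"cam_{idx}"}  # stand-in for a scene.cameras.Camera

    # identity collate, like direct_collate above: the batch is returned as-is
    loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=lambda batch: batch)
    for batch in loader:
        print(batch)  # a plain list of items, not stacked tensors
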
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import os
13 | import torch
14 | from scene.cameras import Camera
15 | import numpy as np
16 | from utils.graphics_utils import fov2focal
17 | from PIL import Image
18 | import cv2
19 | 
20 | WARNED = False
21 | 
22 | def loadCam(args, id, cam_info, resolution_scale, is_test_dataset):
23 |     image = Image.open(cam_info.image_path)
24 | 
25 |     if cam_info.mask_path != "":
26 |         try:
27 |             alpha_mask = Image.open(cam_info.mask_path)
28 |         except FileNotFoundError:
29 |             print(f"Error: The mask file at path '{cam_info.mask_path}' was not found.")
30 |             raise
31 |         except IOError:
32 |             print(f"Error: Unable to open the image file '{cam_info.mask_path}'. It may be corrupted or an unsupported format.")
33 |             raise
34 |         except Exception as e:
35 |             print(f"An unexpected error occurred: {e}")
36 |             raise
37 |     else:
38 |         alpha_mask = None
39 | 
40 |     if cam_info.depth_path != "":
41 |         # cv2.imread does not raise on a missing or unreadable file -- it
42 |         # returns None -- so the result must be checked explicitly.
43 |         invdepthmap = cv2.imread(cam_info.depth_path, -1)
44 |         if invdepthmap is None:
45 |             raise IOError(f"Unable to read the depth file at path '{cam_info.depth_path}'. "
46 |                           "It may be missing, corrupted or an unsupported format.")
47 |         invdepthmap = invdepthmap.astype(np.float32) / float(2**16)
48 |     else:
49 |         invdepthmap = None
50 | 
51 |     orig_w, orig_h = image.size
52 | 
53 |     if args.resolution in [1, 2, 4, 8]:
54 |         resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
55 |     else:  # should be a type that converts to float
56 |         if args.resolution == -1:
57 |             if orig_w > 1600:
58 |                 global WARNED
59 |                 if not WARNED:
60 |                     print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
61 |                         "If this is not desired, please explicitly specify '--resolution/-r' as 1")
62 |                     WARNED = True
63 |                 global_down = orig_w / 1600
64 |             else:
65 |                 global_down = 1
66 |         else:
67 |             global_down = orig_w / args.resolution
68 | 
69 |         scale = float(global_down) * float(resolution_scale)
70 |         resolution = (int(orig_w / scale), int(orig_h / scale))
71 | 
72 |     return Camera(resolution, colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
73 |                   FoVx=cam_info.FovX, FoVy=cam_info.FovY, depth_params=cam_info.depth_params,
74 |                   primx=cam_info.primx, primy=cam_info.primy,
75 |                   image=image, alpha_mask=alpha_mask, invdepthmap=invdepthmap,
76 |                   image_name=cam_info.image_name, uid=id, data_device=args.data_device,
77 |                   train_test_exp=args.train_test_exp, is_test_dataset=is_test_dataset, is_test_view=cam_info.is_test)
78 | 
79 | def cameraList_from_camInfos(cam_infos, resolution_scale, args, is_test_dataset=False):
80 |     camera_list = []
81 | 
82 |     # loadCam requires the is_test_dataset flag, so it is forwarded here
83 |     # (the original call omitted it, which raised a TypeError).
84 |     for id, c in enumerate(cam_infos):
85 |         camera_list.append(loadCam(args, id, c, resolution_scale, is_test_dataset))
86 | 
87 |     return camera_list
88 | 
89 | def camera_to_JSON(id, camera : Camera):
90 |     Rt = np.zeros((4, 4))
91 |     Rt[:3, :3] = camera.R.transpose()
92 |     Rt[:3, 3] = camera.T
93 |     Rt[3, 3] = 1.0
94 | 
95 |     # Rt is the world-to-view matrix, so its inverse is camera-to-world;
96 |     # position and rotation below are therefore expressed in world space.
97 |     C2W = np.linalg.inv(Rt)
98 |     pos = C2W[:3, 3]
99 |     rot = C2W[:3, :3]
100 |     serializable_array_2d = [x.tolist() for x in rot]
101 |     camera_entry = {
102 |         'id' : id,
103 |         'img_name' : camera.image_name,
104 |         'width' : camera.width,
105 |         'height' : camera.height,
106 |         'position': pos.tolist(),
107 |         'rotation': serializable_array_2d,
108 |         'fy' : fov2focal(camera.FovY, camera.height),
109 |         'fx' : fov2focal(camera.FovX, camera.width)
110 |     }
111 |     return camera_entry
112 | 
113 | class CameraDataset(torch.utils.data.Dataset):
114 |     'Characterizes a dataset for PyTorch'
115 |     def __init__(self, list_cam_infos, args, resolution_scales, is_test):
116 |         'Initialization'
117 |         self.resolution_scales = resolution_scales
118 |         self.list_cam_infos = list_cam_infos
119 |         self.args = args
120 |         self.args.data_device = 'cpu'
121 |         self.is_test = is_test
122 | 
123 |     def __len__(self):
124 |         'Denotes the total number of samples'
125 |         return len(self.list_cam_infos)
126 | 
127 |     def __getitem__(self, index):
128 |         'Generates one sample of data'
129 | 
130 |         # Select sample
131 |         info = self.list_cam_infos[index]
132 |         X = loadCam(self.args, index, info, self.resolution_scales, self.is_test)
133 | 
134 |         return X
135 | 
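A small, self-contained check of the pose convention serialized by camera_to_JSON: (R, T) define a world-to-view transform as in COLMAP, so the camera center is the translation column of the inverted matrix. The numbers are illustrative only:

    import numpy as np

    R = np.eye(3)                   # identity rotation, for illustration
    T = np.array([0.0, 0.0, 4.0])   # world-to-view translation

    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = T
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)         # view-to-world, as in camera_to_JSON
    print(C2W[:3, 3])               # camera center in world space: [ 0.  0. -4.]
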
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 | from PIL import Image
18 | 
19 | def inverse_sigmoid(x):
20 |     return torch.log(x/(1-x))
21 | 
22 | def PILtoTorch(pil_image, resolution):
23 |     if resolution[0] != pil_image.size[0] or resolution[1] != pil_image.size[1]:
24 |         pil_image = pil_image.resize(resolution, Image.LANCZOS)
25 |     resized_image = torch.from_numpy(np.array(pil_image)) / 255.0
26 |     if len(resized_image.shape) == 3:
27 |         return resized_image.permute(2, 0, 1)
28 |     else:
29 |         return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
30 | 
31 | def get_expon_lr_func(
32 |     lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
33 | ):
34 |     """
35 |     Copied from Plenoxels
36 | 
37 |     Continuous learning rate decay function. Adapted from JaxNeRF
38 |     The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
39 |     is log-linearly interpolated elsewhere (equivalent to exponential decay).
40 |     If lr_delay_steps>0 then the learning rate will be scaled by some smooth
41 |     function of lr_delay_mult, such that the initial learning rate is
42 |     lr_init*lr_delay_mult at the beginning of optimization but will be eased back
43 |     to the normal learning rate when steps>lr_delay_steps.
44 |     :param lr_init, lr_final: initial and final learning rates
45 |     :param max_steps: int, the number of steps during optimization.
46 |     :return a function mapping step -> learning rate
47 |     """
48 | 
49 |     def helper(step):
50 |         if lr_init == 0:
51 |             return 0
52 |         if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
53 |             # Disable this parameter
54 |             return 0.0
55 |         if lr_delay_steps > 0:
56 |             # A kind of reverse cosine decay.
57 |             delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
58 |                 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
59 |             )
60 |         else:
61 |             delay_rate = 1.0
62 |         t = np.clip(step / max_steps, 0, 1)
63 |         log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
64 |         return delay_rate * log_lerp
65 | 
66 |     return helper
67 | 
68 | def strip_lowerdiag(L):
69 |     uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
70 | 
71 |     uncertainty[:, 0] = L[:, 0, 0]
72 |     uncertainty[:, 1] = L[:, 0, 1]
73 |     uncertainty[:, 2] = L[:, 0, 2]
74 |     uncertainty[:, 3] = L[:, 1, 1]
75 |     uncertainty[:, 4] = L[:, 1, 2]
76 |     uncertainty[:, 5] = L[:, 2, 2]
77 |     return uncertainty
78 | 
79 | def strip_symmetric(sym):
80 |     return strip_lowerdiag(sym)
81 | 
82 | def build_rotation(r):
83 |     norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
84 | 
85 |     q = r / norm[:, None]
86 | 
87 |     R = torch.zeros((q.size(0), 3, 3), device='cuda')
88 | 
89 |     r = q[:, 0]
90 |     x = q[:, 1]
91 |     y = q[:, 2]
92 |     z = q[:, 3]
93 | 
94 |     R[:, 0, 0] = 1 - 2 * (y*y + z*z)
95 |     R[:, 0, 1] = 2 * (x*y - r*z)
96 |     R[:, 0, 2] = 2 * (x*z + r*y)
97 |     R[:, 1, 0] = 2 * (x*y + r*z)
98 |     R[:, 1, 1] = 1 - 2 * (x*x + z*z)
99 |     R[:, 1, 2] = 2 * (y*z - r*x)
100 |     R[:, 2, 0] = 2 * (x*z - r*y)
101 |     R[:, 2, 1] = 2 * (y*z + r*x)
102 |     R[:, 2, 2] = 1 - 2 * (x*x + y*y)
103 |     return R
104 | 
105 | def build_scaling_rotation(s, r):
106 |     L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
107 |     R = build_rotation(r)
108 | 
109 |     L[:,0,0] = s[:,0]
110 |     L[:,1,1] = s[:,1]
111 |     L[:,2,2] = s[:,2]
112 | 
113 |     L = R @ L
114 |     return L
115 | 
116 | def safe_state(silent):
117 |     old_f = sys.stdout
118 |     class F:
119 |         def __init__(self, silent):
120 |             self.silent = silent
121 | 
122 |         def write(self, x):
123 |             if not self.silent:
124 |                 if x.endswith("\n"):
125 |                     old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
126 |                 else:
127 |                     old_f.write(x)
128 | 
129 |         def flush(self):
130 |             old_f.flush()
131 | 
132 |     sys.stdout = F(silent)
133 | 
134 |     random.seed(0)
135 |     np.random.seed(0)
136 |     torch.manual_seed(0)
137 |     torch.cuda.set_device(torch.device("cuda:0"))
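A quick numeric illustration of get_expon_lr_func: the returned schedule is log-linear, so halfway through it passes through the geometric mean of the two endpoints. The values below are the original 3DGS position-learning-rate defaults, used here purely as an example:

    from utils.general_utils import get_expon_lr_func

    schedule = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
    print(schedule(0))       # 1.6e-4
    print(schedule(15_000))  # ~1.6e-5, the geometric mean of the endpoints
    print(schedule(30_000))  # 1.6e-6
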
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 | 
17 | class BasicPointCloud(NamedTuple):
18 |     points : np.array
19 |     colors : np.array
20 |     normals : np.array
21 | 
22 | def geom_transform_points(points, transf_matrix):
23 |     P, _ = points.shape
24 |     ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25 |     points_hom = torch.cat([points, ones], dim=1)
26 |     points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27 | 
28 |     denom = points_out[..., 3:] + 1e-7
29 |     return (points_out[..., :3] / denom).squeeze(dim=0)
30 | 
31 | def getWorld2View(R, t):
32 |     Rt = np.zeros((4, 4))
33 |     Rt[:3, :3] = R.transpose()
34 |     Rt[:3, 3] = t
35 |     Rt[3, 3] = 1.0
36 |     return np.float32(Rt)
37 | 
38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
39 |     Rt = np.zeros((4, 4))
40 |     Rt[:3, :3] = R.transpose()
41 |     Rt[:3, 3] = t
42 |     Rt[3, 3] = 1.0
43 | 
44 |     C2W = np.linalg.inv(Rt)
45 |     cam_center = C2W[:3, 3]
46 |     cam_center = (cam_center + translate) * scale
47 |     C2W[:3, 3] = cam_center
48 |     Rt = np.linalg.inv(C2W)
49 |     return np.float32(Rt)
50 | 
51 | def getProjectionMatrix(znear, zfar, fovX, fovY, primx, primy):
52 |     tanHalfFovY = math.tan((fovY / 2))
53 |     tanHalfFovX = math.tan((fovX / 2))
54 | 
55 |     top = tanHalfFovY * znear
56 |     bottom = (1 - primy) * 2 * -top
57 |     top = primy * 2 * top
58 | 
59 |     right = tanHalfFovX * znear
60 |     left = (1 - primx) * 2 * -right
61 |     right = primx * 2 * right
62 | 
63 |     P = torch.zeros(4, 4)
64 | 
65 |     z_sign = 1.0
66 | 
67 |     P[0, 0] = 2.0 * znear / (right - left)
68 |     P[1, 1] = 2.0 * znear / (top - bottom)
69 |     P[0, 2] = (right + left) / (right - left)
70 |     P[1, 2] = (top + bottom) / (top - bottom)
71 |     P[3, 2] = z_sign
72 |     P[2, 2] = z_sign * zfar / (zfar - znear)
73 |     P[2, 3] = -(zfar * znear) / (zfar - znear)
74 |     return P
75 | 
76 | def fov2focal(fov, pixels):
77 |     return pixels / (2 * math.tan(fov / 2))
78 | 
79 | def focal2fov(focal, pixels):
80 |     return 2 * math.atan(pixels / (2 * focal))
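Two sanity checks for the helpers above, assuming nothing beyond this file: fov2focal and focal2fov are inverses, and getProjectionMatrix with a centered principal point (primx = primy = 0.5) reduces to a standard symmetric pinhole frustum:

    import math
    from utils.graphics_utils import fov2focal, focal2fov, getProjectionMatrix

    fov = math.radians(60)
    focal = fov2focal(fov, pixels=1600)             # ~1385.6 px
    assert abs(focal2fov(focal, 1600) - fov) < 1e-9

    P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fov, fovY=fov,
                            primx=0.5, primy=0.5)
    print(P[0, 2].item(), P[1, 2].item())           # both 0.0 when centered
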
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import torch
13 | 
14 | def mse(img1, img2):
15 |     return ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16 | 
17 | def psnr(img1, img2):
18 |     # reuse mse() rather than recomputing (and shadowing) it locally
19 |     return 20 * torch.log10(1.0 / torch.sqrt(mse(img1, img2)))
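A minimal worked example of the PSNR helper for images in [0, 1]: a uniform error of 0.1 gives an MSE of 0.01, hence 20 * log10(1 / 0.1) = 20 dB:

    import torch
    from utils.image_utils import psnr

    gt = torch.zeros(1, 3, 4, 4)
    pred = torch.full((1, 3, 4, 4), 0.1)
    print(psnr(pred, gt))  # tensor([[20.]])
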
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import torch
13 | import torch.nn.functional as F
14 | from math import exp
15 | 
16 | def l1_loss(network_output, gt):
17 |     return torch.abs(network_output - gt).mean()
18 | 
19 | def l2_loss(network_output, gt):
20 |     return ((network_output - gt) ** 2).mean()
21 | 
22 | def gaussian(window_size, sigma):
23 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
24 |     return gauss / gauss.sum()
25 | 
26 | def create_window(window_size, channel):
27 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
28 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
29 |     # torch.autograd.Variable has been a deprecated no-op since PyTorch 0.4,
30 |     # so the window tensor is used directly
31 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
32 |     return window
33 | 
34 | def ssim(img1, img2, window_size=11, size_average=True):
35 |     channel = img1.size(-3)
36 |     window = create_window(window_size, channel)
37 | 
38 |     if img1.is_cuda:
39 |         window = window.cuda(img1.get_device())
40 |     window = window.type_as(img1)
41 | 
42 |     return _ssim(img1, img2, window, window_size, channel, size_average)
43 | 
44 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
45 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
46 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
47 | 
48 |     mu1_sq = mu1.pow(2)
49 |     mu2_sq = mu2.pow(2)
50 |     mu1_mu2 = mu1 * mu2
51 | 
52 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
53 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
54 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
55 | 
56 |     C1 = 0.01 ** 2
57 |     C2 = 0.03 ** 2
58 | 
59 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
60 | 
61 |     if size_average:
62 |         return ssim_map.mean()
63 |     else:
64 |         return ssim_map.mean(1).mean(1).mean(1)
65 | 
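Two sanity checks for the SSIM implementation, plus a sketch of how train_single.py combines it with L1 into the photometric loss; lambda_dssim = 0.2 is the usual 3DGS default and is assumed here:

    import torch
    from utils.loss_utils import l1_loss, ssim

    img = torch.rand(1, 3, 64, 64)
    print(float(ssim(img, img)))  # ~1.0: identical images have maximal SSIM

    gt = torch.rand(1, 3, 64, 64)
    lambda_dssim = 0.2
    loss = (1.0 - lambda_dssim) * l1_loss(img, gt) + lambda_dssim * (1.0 - ssim(img, gt))
    print(float(loss))
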
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | 
24 | import torch
25 | 
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 |     1.0925484305920792,
30 |     -1.0925484305920792,
31 |     0.31539156525252005,
32 |     -1.0925484305920792,
33 |     0.5462742152960396
34 | ]
35 | C3 = [
36 |     -0.5900435899266435,
37 |     2.890611442640554,
38 |     -0.4570457994644658,
39 |     0.3731763325901154,
40 |     -0.4570457994644658,
41 |     1.445305721320277,
42 |     -0.5900435899266435
43 | ]
44 | C4 = [
45 |     2.5033429417967046,
46 |     -1.7701307697799304,
47 |     0.9461746957575601,
48 |     -0.6690465435572892,
49 |     0.10578554691520431,
50 |     -0.6690465435572892,
51 |     0.47308734787878004,
52 |     -1.7701307697799304,
53 |     0.6258357354491761,
54 | ]
55 | 
56 | 
57 | def eval_sh(deg, sh, dirs):
58 |     """
59 |     Evaluate spherical harmonics at unit directions
60 |     using hardcoded SH polynomials.
61 |     Works with torch/np/jnp.
62 |     "..." can be 0 or more batch dimensions.
63 |     Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 |         sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 |         dirs: jnp.ndarray unit directions [..., 3]
67 |     Returns:
68 |         [..., C]
69 |     """
70 |     assert 0 <= deg <= 4
71 |     coeff = (deg + 1) ** 2
72 |     assert sh.shape[-1] >= coeff
73 | 
74 |     result = C0 * sh[..., 0]
75 |     if deg > 0:
76 |         x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 |         result = (result -
78 |                 C1 * y * sh[..., 1] +
79 |                 C1 * z * sh[..., 2] -
80 |                 C1 * x * sh[..., 3])
81 | 
82 |         if deg > 1:
83 |             xx, yy, zz = x * x, y * y, z * z
84 |             xy, yz, xz = x * y, y * z, x * z
85 |             result = (result +
86 |                     C2[0] * xy * sh[..., 4] +
87 |                     C2[1] * yz * sh[..., 5] +
88 |                     C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 |                     C2[3] * xz * sh[..., 7] +
90 |                     C2[4] * (xx - yy) * sh[..., 8])
91 | 
92 |             if deg > 2:
93 |                 result = (result +
94 |                 C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 |                 C3[1] * xy * z * sh[..., 10] +
96 |                 C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
97 |                 C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 |                 C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 |                 C3[5] * z * (xx - yy) * sh[..., 14] +
100 |                 C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 | 
102 |                 if deg > 3:
103 |                     result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 |                             C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 |                             C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 |                             C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 |                             C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 |                             C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 |                             C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 |                             C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 |                             C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 |     return result
113 | 
114 | def RGB2SH(rgb):
115 |     return (rgb - 0.5) / C0
116 | 
117 | def SH2RGB(sh):
118 |     return sh * C0 + 0.5
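A quick check of the degree-0 behaviour of the helpers above: eval_sh reduces to C0 * sh[..., 0], so a colour encoded with RGB2SH evaluates to rgb - 0.5 regardless of the direction, and SH2RGB inverts RGB2SH exactly:

    import torch
    from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

    rgb = torch.tensor([0.2, 0.5, 0.8])
    sh = RGB2SH(rgb).reshape(3, 1)   # [C, (deg + 1) ** 2] with deg = 0
    dirs = torch.randn(3)
    dirs = dirs / dirs.norm()        # unit direction (unused at degree 0)
    print(eval_sh(0, sh, dirs))      # ~[-0.3, 0.0, 0.3], i.e. rgb - 0.5
    print(SH2RGB(RGB2SH(rgb)))       # recovers [0.2, 0.5, 0.8]
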
--------------------------------------------------------------------------------
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 - 2024, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 | 
16 | def mkdir_p(folder_path):
17 |     # Creates a directory, equivalent to using mkdir -p on the command line.
18 |     try:
19 |         makedirs(folder_path)
20 |     except OSError as exc: # Python >2.5
21 |         if exc.errno == EEXIST and path.isdir(folder_path):
22 |             pass
23 |         else:
24 |             raise
25 | 
26 | def searchForMaxIteration(folder):
27 |     saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 |     return max(saved_iters)
29 | 
--------------------------------------------------------------------------------
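A hedged sketch of how searchForMaxIteration is typically used: given a model folder laid out as point_cloud/iteration_N (mirroring how 3DGS-style scene saving lays out checkpoints), it recovers the latest saved iteration by parsing the folder names. The paths below are illustrative only:

    import os
    from utils.system_utils import mkdir_p, searchForMaxIteration

    root = os.path.join("output", "example", "point_cloud")
    for it in (7000, 30000):
        mkdir_p(os.path.join(root, f"iteration_{it}"))
    print(searchForMaxIteration(root))  # 30000
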