├── .gitignore ├── LICENSE ├── README.md ├── assets └── imgs │ └── zero_gs_logo.png ├── conerf ├── __init__.py ├── base │ ├── checkpoint_manager.py │ └── model_base.py ├── datasets │ ├── __init__.py │ ├── ace_camera_loc_dataset.py │ ├── dataset_base.py │ ├── load_colmap.py │ └── utils.py ├── evaluators │ ├── ace_zero_evaluator.py │ └── evaluator.py ├── geometry │ ├── align_poses.py │ ├── camera.py │ ├── pose_util.py │ └── utils.py ├── loss │ └── ssim_torch.py ├── model │ ├── __init__.py │ ├── backbone │ │ ├── activations.py │ │ ├── encodings.py │ │ ├── feature_pyramid_net.py │ │ ├── mlp.py │ │ └── resnet3d.py │ ├── misc.py │ └── scene_regressor │ │ ├── ace_encoder_pretrained.pt │ │ ├── ace_loss.py │ │ ├── ace_network.py │ │ ├── ace_util.py │ │ ├── calibr.py │ │ ├── depth_network.py │ │ └── pose_refine_network.py ├── pycolmap │ ├── pycolmap │ │ ├── __init__.py │ │ ├── camera.py │ │ ├── database.py │ │ ├── image.py │ │ ├── rotation.py │ │ └── scene_manager.py │ └── tools │ │ ├── colmap_to_nvm.py │ │ ├── delete_images.py │ │ ├── impute_missing_cameras.py │ │ ├── save_cameras_as_ply.py │ │ ├── transform_model.py │ │ ├── write_camera_track_to_bundler.py │ │ └── write_depthmap_to_ply.py ├── trainers │ ├── ace_zero_trainer.py │ └── trainer.py ├── utils │ ├── config.py │ └── utils.py └── visualization │ ├── feature_visualizer.py │ ├── pose_visualizer.py │ └── scene_visualizer.py ├── config └── ace │ ├── llff.yaml │ ├── mipnerf360.yaml │ └── tanks_and_temples.yaml ├── eval.py ├── scripts ├── env │ └── install.sh ├── eval │ ├── eval_ace_zero.sh │ └── vis_recon.py ├── preprocess │ ├── colmap_mapping.sh │ ├── database.py │ ├── hloc_mapping │ │ ├── extract_features.py │ │ ├── extract_relative_poses.py │ │ ├── filter_matches.py │ │ ├── match_features.py │ │ ├── pairs_from_retrieval.py │ │ ├── reconstruction.py │ │ ├── sfm_pipeline.py │ │ ├── triangulate_from_existing_model.py │ │ └── utils.py │ ├── mapping.py │ ├── read_write_model.py │ ├── triangulate.sh │ └── utils.py └── train │ └── train_ace_zero.sh ├── submodules └── dsacstar │ ├── dsacstar.cpp │ ├── dsacstar_derivative.h │ ├── dsacstar_loss.h │ ├── dsacstar_types.h │ ├── dsacstar_util.h │ ├── dsacstar_util_rgbd.h │ ├── setup.py │ ├── stop_watch.h │ ├── thread_rand.cpp │ └── thread_rand.h ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | submodules/dsacstar/build 2 | submodules/dsacstar/dist 3 | submodules/dsacstar/dsacstar.egg-info -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The following files are under the license of ACE. Copyright © Niantic, Inc. 2022. Patent Pending: 2 | 3 | - Main ACE Files: 4 | - conerf/model/scene_regressor/ace_encoder_pretrained.pt 5 | - conerf/model/scene_regressor/ace_loss.py 6 | - conerf/model/scene_regressor/ace_network.py 7 | - conerf/model/scene_regressor/ace_util.py 8 | - conerf/trainers/ace_zero_trainer.py 9 | 10 | ------------------------------------------------------------------------------ 11 | 12 | The rest of the files are under the MIT license: 13 | 14 | Copyright (c) 2024, Chen Yu 15 | All rights reserved. 
16 | 17 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 18 | associated documentation files (the “Software”), to deal in the Software without restriction, 19 | including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 20 | and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, 21 | subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all copies or substantial 24 | portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT 27 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 28 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 29 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 30 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 31 | 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ZeroGS: Training 3D Gaussian Splatting from Unposed Images 2 | 3 | [[Project Page](https://aibluefisher.github.io/ZeroGS/) | [arXiv](https://arxiv.org/pdf/2411.15779)] 4 | 5 | --------------------------- 6 | 7 | ## 🛠️ Installation 8 | 9 | Install the conda environment for ZeroGS. 10 | 11 | ```sh 12 | conda create -n zero_gs python=3.9 13 | conda activate zero_gs 14 | cd ZeroGS 15 | ./scripts/env/install.sh 16 | ``` 17 | 18 | **Git hook for code style checking**: 19 | ```sh 20 | pre-commit install --hook-type pre-commit 21 | ``` 22 | 23 | 24 | ## 🚀 Features 25 | 26 | - [x] Release [ACE0](https://nianticlabs.github.io/acezero) implementation 27 | - [ ] Incorporate [GLACE](https://github.com/cvg/glace) into ACE0 28 | - [ ] Release our customized 3D Gaussian Splatting module 29 | - [ ] Incorporate [Scaffold-GS](https://city-super.github.io/scaffold-gs) 30 | - [ ] Incorporate [DOGS](https://github.com/aibluefisher/dogs) 31 | - [ ] Release ZeroGS implementation 32 | 33 | 34 | ## 📋 Train & Eval ACE0 35 | 36 | We aim to provide a framework that makes it easy to implement your own neural implicit module on top of this codebase. Since this project started before the official code release of ACE0, we re-implemented ACE0 within our codebase. 37 | 38 | ### ⌛ Train ACE0 39 | 40 | Before training ACE0, please download the [pretrained feature encoder](https://github.com/nianticlabs/ace/blob/main/ace_encoder_pretrained.pt) from ACE and put it under the folder `ZeroGS/conerf/model/scene_regressor`. 41 | 42 | ```bash 43 | conda activate zero_gs 44 | visdom -port=9000 # Keep the port the same as the `visdom_port` provided in the configuration file 45 | cd ZeroGS/scripts/train 46 | ./train_ace_zero.sh 0 ace_early_stop_resize_2k_anneal mipnerf360 ace 47 | ``` 48 | We use `visdom` to visualize the camera pose predictions during training. You can open `http://localhost:9000` to view them. 49 | 50 | ### 📊 Evaluate ACE0 51 | 52 | ```bash 53 | conda activate zero_gs 54 | cd ZeroGS/scripts/eval 55 | ./eval_ace_zero.sh 0 ace_early_stop_resize_2k_anneal mipnerf360 ace 56 | ``` 57 | Metrics files and camera poses will be recorded in the `eval/val/` folder.
Point clouds are recorded in `.ply` format in the `eval/val/ACE0_COLMAP` folder (this folder also contains the model files in COLMAP format). 58 | 59 | ### 🔢 Hyperparameters for training ACE0 60 | 61 | All the parameters related to training ACE0 are provided in the configuration file `config/ace/mipnerf360.yaml`. Most of the parameters can be kept the same as in this configuration file. However, the parameters listed below need to be adjusted per scene to obtain better performance: 62 | ```yaml 63 | trainer: 64 | # We can use fewer iterations for the `garden` scene (e.g. 2000). 65 | min_iterations_per_epoch: 5000 66 | 67 | pose_estimator: 68 | # Change this to a larger threshold (3000) for the `garden` scene of the mipnerf360 dataset. 69 | min_inlier_count: 2000 # minimum number of inlier correspondences when registering an image 70 | ``` 71 | 72 | A larger value of `min_iterations_per_epoch` can make the mapping more accurate, but it also leads to longer training times. 73 | 74 | 75 | ## ✏️ Cite 76 | 77 | If you find this project useful for your research, please consider citing our paper: 78 | ```bibtex 79 | @inproceedings{yuchen2024zerogs, 80 | title={ZeroGS: Training 3D Gaussian Splatting from Unposed Images}, 81 | author={Yu Chen and Rolandos Alexandros Potamias and Evangelos Ververas and Jifei Song and Jiankang Deng and Gim Hee Lee}, 82 | booktitle={arXiv}, 83 | year={2024}, 84 | } 85 | ``` 86 | 87 | ## 🙌 Acknowledgements 88 | 89 | This work is built upon [ACE](https://nianticlabs.github.io/ace/), [DUSt3R](https://github.com/naver/dust3r), and [Spann3R](https://hengyiwang.github.io/projects/spanner). We sincerely thank all the authors for releasing their code. 90 | 91 | ## 🪪 License 92 | 93 | Copyright © 2024, Chen Yu. 94 | All rights reserved. 95 | Please see the [license file](LICENSE) for terms. 96 | -------------------------------------------------------------------------------- /assets/imgs/zero_gs_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/assets/imgs/zero_gs_logo.png -------------------------------------------------------------------------------- /conerf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/__init__.py -------------------------------------------------------------------------------- /conerf/base/model_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ModelBase(torch.nn.Module): 5 | """ 6 | An abstract class that defines basic operations for a torch model.
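    Subclasses are expected to override the mode-switching hooks and `forward`.
    A minimal sketch (the subclass name and body are illustrative only):

        class MyModel(ModelBase):
            def switch_to_eval(self):
                self.eval()

            def switch_to_train(self):
                self.train()

            def forward(self, data, **kwargs):
                return data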
7 | """ 8 | def __init__(self, **kwargs) -> None: 9 | super().__init__() 10 | 11 | def to_distributed(self): 12 | """Change model to distributed mode.""" 13 | raise NotImplementedError 14 | 15 | def switch_to_eval(self): 16 | """Change model to evaluation mode.""" 17 | raise NotImplementedError 18 | 19 | def switch_to_train(self): 20 | """Change model to training mode.""" 21 | raise NotImplementedError 22 | 23 | def forward(self, data, **kwargs): 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /conerf/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/datasets/__init__.py -------------------------------------------------------------------------------- /conerf/datasets/ace_camera_loc_dataset.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | 3 | import logging 4 | import random 5 | import math 6 | 7 | import imageio 8 | import torch 9 | import torchvision.transforms.functional as TF 10 | from skimage import color 11 | from skimage import io 12 | from skimage.transform import rotate 13 | from torch.utils.data import Dataset 14 | from torch.utils.data.dataloader import default_collate 15 | from torchvision import transforms 16 | 17 | from conerf.datasets.load_colmap import load_colmap 18 | 19 | _logger = logging.getLogger(__name__) 20 | 21 | 22 | class CamLocDataset(Dataset): 23 | """Camera localization dataset. 24 | 25 | Access to image, calibration and ground truth data given a dataset directory. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | # root_dir: str, 31 | root_fp: str, 32 | subject_id: str, 33 | val_interval: int = 0, 34 | scale: bool = True, 35 | rotate: bool = True, 36 | augment: bool = False, 37 | aug_rotation: int = 15, 38 | aug_scale_min: float = 2 / 3, 39 | aug_scale_max: float = 3 / 2, 40 | aug_black_white: float = 0.1, 41 | aug_color: float = 0.3, 42 | factor: int = 8, 43 | use_half: bool = True, 44 | ): 45 | """ 46 | Params: 47 | root_dir: Folder of the data (training or test). 48 | augment: Use random data augmentation, note: note supported for mode=2 (RGB-D) since 49 | pre-generated eye coordinates cannot be augmented 50 | aug_rotation: Max 2D image rotation angle, sampled uniformly around 0, both 51 | directions, degrees 52 | aug_scale_min: Lower limit of image scale factor for uniform sampling 53 | aug_scale_max: Upper limit of image scale factor for uniform sampling 54 | aug_black_white: Max relative scale factor for image brightness/contrast sampling, 55 | e.g. 0.1 -> [0.9,1.1] 56 | aug_color: Max relative scale factor for image saturation/hue sampling, e.g. 57 | 0.1 -> [0.9,1.1] 58 | image_height: RGB images are rescaled to this maximum height (if augmentation is 59 | disabled, and in the range [aug_scale_min * image_height, aug_scale_max * 60 | image_height] otherwise). 61 | use_half: Enabled if training with half-precision floats. 
62 | """ 63 | 64 | self.use_half = use_half 65 | self.factor = factor 66 | self.augment = augment 67 | self.aug_rotation = aug_rotation 68 | self.aug_scale_min = aug_scale_min 69 | self.aug_scale_max = aug_scale_max 70 | self.aug_black_white = aug_black_white 71 | self.aug_color = aug_color 72 | 73 | data = load_colmap( 74 | root_fp, subject_id, split='train', factor=factor, 75 | val_interval=val_interval, scale=scale, rotate=rotate, 76 | ) 77 | self.rgb_files = data['image_paths'] 78 | self.gt_camtoworlds = data['poses'] 79 | 80 | # We use this to iterate over all frames. 81 | self.valid_file_indices = {i: i for i in range(len(self.rgb_files))} 82 | 83 | # Try to read an image and get its width and height. 84 | image = imageio.imread(self.rgb_files[0]) # [H,W,3] 85 | # Use a fixed 480px image height since the convolutional feature backbone 86 | # is pretrained to ingest images scaled to 480px. 87 | self.origin_image_height, self.origin_image_width = image.shape[:2] 88 | # self.image_height = image.shape[0] 89 | # self.image_width = image.shape[1] 90 | self.image_height = 480 91 | 92 | # Image transformations. Excluding scale since that can vary batch-by-batch. 93 | if self.augment: 94 | self.image_transform = transforms.Compose([ 95 | transforms.Grayscale(), 96 | transforms.ColorJitter( 97 | brightness=self.aug_black_white, contrast=self.aug_black_white), 98 | transforms.ToTensor(), 99 | transforms.Normalize(mean=[0.4], std=[0.25]), 100 | ]) 101 | else: 102 | self.image_transform = transforms.Compose([ 103 | transforms.Grayscale(), 104 | transforms.ToTensor(), 105 | transforms.Normalize(mean=[0.4], std=[0.25]), 106 | ]) 107 | 108 | def image(self, idx): 109 | idx = self.valid_file_indices[idx] 110 | return self._load_image(idx) 111 | 112 | def image_tensor(self, idx): 113 | return torch.from_numpy(self.image(idx)) 114 | 115 | def resized_image(self, idx, image_height: int, image_width: int = None): 116 | image = self.image(idx) 117 | return self._resize_image(image, image_height, image_width) 118 | 119 | def resized_grayscale_image(self, idx, image_height: int): 120 | color_image_pil = self.resized_image(idx, image_height) 121 | return color_image_pil, self.image_transform(color_image_pil) 122 | 123 | @staticmethod 124 | def _resize_image(image, image_height: int, image_width: int = None): 125 | # Resize a numpy image as PIL. Works slightly better than resizing the tensor 126 | # using torch's internal function. 127 | image = TF.to_pil_image(image) 128 | image = TF.resize(image, image_height) if image_width is None else \ 129 | TF.resize(image, [image_height, image_width]) 130 | return image 131 | 132 | @staticmethod 133 | def _rotate_image(image, angle, order, mode='constant'): 134 | # Image is a torch tensor (CxHxW), convert it to numpy as HxWxC. 135 | image = image.permute(1, 2, 0).numpy() 136 | # Apply rotation. 137 | image = rotate(image, angle, order=order, mode=mode) 138 | # Back to torch tensor. 139 | image = torch.from_numpy(image).permute(2, 0, 1).float() 140 | return image 141 | 142 | def _load_image(self, idx): 143 | image = io.imread(self.rgb_files[idx]) 144 | 145 | if len(image.shape) < 3: 146 | # Convert to RGB if needed. 147 | image = color.gray2rgb(image) 148 | 149 | return image 150 | 151 | def _get_single_item(self, idx, image_height): 152 | # Apply index indirection. 153 | idx = self.valid_file_indices[idx] 154 | 155 | # Load image. 156 | image = self._load_image(idx) 157 | 158 | # Rescale image. 
159 | image = self._resize_image(image, image_height) 160 | 161 | # Create mask of the same size as the resized image (it's a PIL image at this point). 162 | image_mask = torch.ones((1, image.size[1], image.size[0])) 163 | 164 | # Apply remaining transforms. 165 | image = self.image_transform(image) 166 | 167 | pose_rot = torch.eye(4) 168 | 169 | # Apply data augmentation if necessary. 170 | if self.augment: 171 | # Generate a random rotation angle. 172 | angle = random.uniform(-self.aug_rotation, self.aug_rotation) 173 | 174 | # Rotate input image and mask. 175 | image = self._rotate_image(image, angle, 1, 'reflect') 176 | image_mask = self._rotate_image(image_mask, angle, order=1, mode='constant') 177 | 178 | # Provide the rotation as well. 179 | # - pose = pose @ pose_rot 180 | angle = angle * math.pi / 180. 181 | pose_rot[0, 0] = math.cos(angle) 182 | pose_rot[0, 1] = -math.sin(angle) 183 | pose_rot[1, 0] = math.sin(angle) 184 | pose_rot[1, 1] = math.cos(angle) 185 | 186 | # Convert to half precision if needed. 187 | if self.use_half and torch.cuda.is_available(): 188 | image = image.half() 189 | 190 | # Binarize the mask. 191 | image_mask = image_mask > 0 192 | 193 | # TODO(chenyu): shall we return the augmented status for latter 3D Gaussian Splatting? 194 | 195 | return image, image_mask, pose_rot, idx, str(self.rgb_files[idx]) 196 | 197 | def __len__(self): 198 | return len(self.valid_file_indices) 199 | 200 | def __getitem__(self, idx): 201 | if self.augment: 202 | scale_factor = random.uniform(self.aug_scale_min, self.aug_scale_max) 203 | else: 204 | scale_factor = 1 205 | 206 | # Target image height. We compute it here in case we are asked for a full batch of tensors 207 | # because we need to apply the same scale factor to all of them. 208 | image_height = int(self.image_height * scale_factor) 209 | 210 | if type(idx) == list: 211 | # Whole batch. 212 | tensors = [self._get_single_item(i, image_height) for i in idx] 213 | return default_collate(tensors) 214 | else: 215 | # Single element. 216 | return self._get_single_item(idx, image_height) 217 | -------------------------------------------------------------------------------- /conerf/geometry/align_poses.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | 3 | import easydict as edict 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def convert3x4_4x4(input): 10 | """ 11 | Make into homogeneous coordinates by adding [0, 0, 0, 1] to the bottom. 
12 | :param input: (N, 3, 4) or (3, 4) torch or np 13 | :return: (N, 4, 4) or (4, 4) torch or np 14 | """ 15 | if torch.is_tensor(input): 16 | if len(input.shape) == 3: 17 | output = torch.cat([input, torch.zeros_like( 18 | input[:, 0:1])], dim=1) # (N, 4, 4) 19 | output[:, 3, 3] = 1.0 20 | else: 21 | output = torch.cat([input, torch.tensor( 22 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 23 | else: 24 | if len(input.shape) == 3: 25 | output = np.concatenate( 26 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 27 | output[:, 3, 3] = 1.0 28 | else: 29 | output = np.concatenate( 30 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4) 31 | output[3, 3] = 1.0 32 | return output 33 | 34 | 35 | def procrustes_analysis(X0, X1): # [N,3] 36 | # translation 37 | t0 = X0.mean(dim=0, keepdim=True) 38 | t1 = X1.mean(dim=0, keepdim=True) 39 | X0c = X0 - t0 40 | X1c = X1 - t1 41 | 42 | # scale 43 | s0 = (X0c ** 2).sum(dim=-1).mean().sqrt() 44 | s1 = (X1c ** 2).sum(dim=-1).mean().sqrt() 45 | X0cs = X0c / s0 46 | X1cs = X1c / s1 47 | 48 | # rotation (use double for SVD, float loses precision) 49 | U, S, V = (X0cs.t() @ X1cs).double().svd(some=True) 50 | R = (U @ V.t()).float() 51 | if R.det() < 0: 52 | R[2] *= -1 53 | 54 | # Align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0 55 | sim3 = edict(t0=t0[0], t1=t1[0], s0=s0, s1=s1, R=R) 56 | return sim3 57 | 58 | 59 | def get_best_yaw(C): 60 | ''' 61 | maximize trace(Rz(theta) * C) 62 | ''' 63 | assert C.shape == (3, 3) 64 | 65 | A = C[0, 1] - C[1, 0] 66 | B = C[0, 0] + C[1, 1] 67 | theta = np.pi / 2 - np.arctan2(B, A) 68 | 69 | return theta 70 | 71 | 72 | def align_umeyama(model, data, known_scale=False): 73 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation 74 | of Transformation Parameters Between Two Point Patterns, 75 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991. 
76 | 77 | model = s * R * data + t 78 | 79 | Input: 80 | model -- first trajectory (nx3), numpy array type 81 | data -- second trajectory (nx3), numpy array type 82 | 83 | Output: 84 | s -- scale factor (scalar) 85 | R -- rotation matrix (3x3) 86 | t -- translation vector (3x1) 87 | t_error -- translational error per point (1xn) 88 | 89 | """ 90 | 91 | # substract mean 92 | mu_M = model.mean(0) 93 | mu_D = data.mean(0) 94 | model_zerocentered = model - mu_M 95 | data_zerocentered = data - mu_D 96 | n = np.shape(model)[0] 97 | 98 | # correlation 99 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered) 100 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum() 101 | U_svd, D_svd, V_svd = np.linalg.linalg.svd(C) 102 | 103 | D_svd = np.diag(D_svd) 104 | V_svd = np.transpose(V_svd) 105 | 106 | S = np.eye(3) 107 | if (np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0): 108 | S[2, 2] = -1 109 | 110 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd))) 111 | 112 | if known_scale: 113 | s = 1 114 | else: 115 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S)) 116 | 117 | t = mu_M-s*np.dot(R, mu_D) 118 | 119 | return s, R, t 120 | 121 | 122 | def _getIndices(n_aligned, total_n): 123 | if n_aligned == -1: 124 | idxs = np.arange(0, total_n) 125 | else: 126 | assert n_aligned <= total_n and n_aligned >= 1 127 | idxs = np.arange(0, n_aligned) 128 | return idxs 129 | 130 | 131 | # align by similarity transformation 132 | def align_sim3(p_es, p_gt, n_aligned=-1): 133 | ''' 134 | calculate s, R, t so that: 135 | gt = R * s * est + t 136 | ''' 137 | idxs = _getIndices(n_aligned, p_es.shape[0]) 138 | est_pos = p_es[idxs, 0:3] 139 | gt_pos = p_gt[idxs, 0:3] 140 | try: 141 | s, R, t = align_umeyama(gt_pos, est_pos) # note the order 142 | except: # pylint: disable=W0702 143 | print('[WARNING] align_poses.py: SVD did not converge!') 144 | s, R, t = 1.0, np.eye(3), np.zeros(3) 145 | return s, R, t 146 | 147 | 148 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None): 149 | """Align c to b using the sim3 from a to b. 150 | :param traj_a: (N0, 3/4, 4) torch tensor 151 | :param traj_b: (N0, 3/4, 4) torch tensor 152 | :param traj_c: None or (N1, 3/4, 4) torch tensor 153 | :return: (N1, 4, 4) torch tensor 154 | """ 155 | device = traj_a.device 156 | if traj_c is None: 157 | traj_c = traj_a.clone() 158 | 159 | traj_a = traj_a.float().cpu().numpy() 160 | traj_b = traj_b.float().cpu().numpy() 161 | traj_c = traj_c.float().cpu().numpy() 162 | 163 | # R_a = traj_a[:, :3, :3] # (N0, 3, 3) 164 | t_a = traj_a[:, :3, 3] # (N0, 3) 165 | 166 | # R_b = traj_b[:, :3, :3] # (N0, 3, 3) 167 | t_b = traj_b[:, :3, 3] # (N0, 3) 168 | 169 | # This function works in quaternion. 170 | # scalar, (3, 3), (3, ) gt = R * s * est + t. 
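    # Note: the similarity transform is estimated from the translation components of the two
    # trajectories only, via the closed-form Umeyama solution in `align_umeyama`; the recovered
    # s, R, t are then applied to both the rotations and translations of `traj_c` below.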
171 | s, R, t = align_sim3(t_a, t_b) 172 | 173 | # reshape tensors 174 | R = R[None, :, :].astype(np.float32) # (1, 3, 3) 175 | t = t[None, :, None].astype(np.float32) # (1, 3, 1) 176 | s = float(s) 177 | 178 | R_c = traj_c[:, :3, :3] # (N1, 3, 3) 179 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1) 180 | 181 | R_c_aligned = R @ R_c # (N1, 3, 3) 182 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1) 183 | traj_c_aligned = np.concatenate( 184 | [R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4) 185 | 186 | # append the last row 187 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4) 188 | 189 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device) 190 | 191 | return traj_c_aligned, s, R, t # (N1, 4, 4) 192 | -------------------------------------------------------------------------------- /conerf/geometry/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def todevice(batch, device, callback=None, non_blocking=False): 6 | ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). 7 | 8 | batch: list, tuple, dict of tensors or other things 9 | device: pytorch device or 'numpy' 10 | callback: function that would be called on every sub-elements. 11 | ''' 12 | if callback: 13 | batch = callback(batch) 14 | 15 | if isinstance(batch, dict): 16 | return {k: todevice(v, device) for k, v in batch.items()} 17 | 18 | if isinstance(batch, (tuple, list)): 19 | return type(batch)(todevice(x, device) for x in batch) 20 | 21 | x = batch 22 | if device == 'numpy': 23 | if isinstance(x, torch.Tensor): 24 | x = x.detach().cpu().numpy() 25 | elif x is not None: 26 | if isinstance(x, np.ndarray): 27 | x = torch.from_numpy(x) 28 | if torch.is_tensor(x): 29 | x = x.to(device, non_blocking=non_blocking) 30 | return x 31 | 32 | 33 | def to_numpy(x): 34 | return todevice(x, 'numpy') 35 | 36 | 37 | def to_cpu(x): 38 | return todevice(x, 'cpu') 39 | 40 | 41 | def to_cuda(x): 42 | return todevice(x, 'cuda') 43 | -------------------------------------------------------------------------------- /conerf/loss/ssim_torch.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1102,C0103,R0903] 2 | 3 | from math import exp 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | def gaussian(window_size, sigma): 12 | gauss = torch.Tensor([ 13 | exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \ 14 | for x in range(window_size) 15 | ]) 16 | return gauss / gauss.sum() 17 | 18 | 19 | def create_window(window_size, channel): 20 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 21 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 22 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 23 | return window 24 | 25 | 26 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 27 | mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel) 28 | mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel) 29 | 30 | mu1_sq = mu1.pow(2) 31 | mu2_sq = mu2.pow(2) 32 | mu1_mu2 = mu1*mu2 33 | 34 | sigma1_sq = F.conv2d( 35 | img1 * img1, window, padding=window_size//2, groups=channel 36 | ) - mu1_sq 37 | sigma2_sq = F.conv2d( 38 | img2 * img2, window, padding=window_size//2, groups=channel 39 | ) - mu2_sq 40 | sigma12 = F.conv2d( 41 | img1 * img2, window, padding=window_size//2, groups=channel 42 | ) - 
mu1_mu2 43 | 44 | C1 = 0.01 ** 2 45 | C2 = 0.03 ** 2 46 | 47 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (( 48 | mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 49 | 50 | if size_average: 51 | return ssim_map.mean() 52 | return ssim_map.mean(1).mean(1).mean(1) 53 | 54 | 55 | class SSIM(torch.nn.Module): 56 | def __init__(self, window_size=11, size_average = True): 57 | super().__init__() 58 | self.window_size = window_size 59 | self.size_average = size_average 60 | self.channel = 1 61 | self.window = create_window(window_size, self.channel) 62 | 63 | def forward(self, img1, img2): 64 | channel = img1.size(-3) 65 | 66 | if channel == self.channel and self.window.data.type() == img1.data.type(): 67 | window = self.window 68 | else: 69 | window = create_window(self.window_size, channel) 70 | 71 | if img1.is_cuda: 72 | window = window.cuda(img1.get_device()) 73 | window = window.type_as(img1) 74 | 75 | self.window = window 76 | self.channel = channel 77 | 78 | 79 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 80 | 81 | 82 | def ssim(img1, img2, window_size = 11, size_average = True): 83 | channel = img1.size(-3) 84 | window = create_window(window_size, channel) 85 | 86 | if img1.is_cuda: 87 | window = window.cuda(img1.get_device()) 88 | window = window.type_as(img1) 89 | 90 | return _ssim(img1, img2, window, window_size, channel, size_average) 91 | -------------------------------------------------------------------------------- /conerf/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/model/__init__.py -------------------------------------------------------------------------------- /conerf/model/backbone/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Gaussian(nn.Module): 6 | """ 7 | Gaussian activation function. 
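    Computes exp(-0.5 * x**2 / sigma**2); note that, as implemented in `forward` below,
    `mean` is stored but not subtracted from the input.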
8 | """ 9 | def __init__( 10 | self, 11 | mean: float = 0.0, 12 | sigma: float = 0.1, 13 | ) -> None: 14 | super().__init__() 15 | 16 | self.mean = mean 17 | self.sigma = sigma 18 | self.sigma_square = self.sigma ** 2 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | return ( # torch.exp( 22 | -0.5 * (x ** 2) / self.sigma_square 23 | ).exp() 24 | -------------------------------------------------------------------------------- /conerf/model/backbone/encodings.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from typing import List, Callable, Dict 9 | 10 | 11 | class SinusoidalEncoder(nn.Module): 12 | """Sinusoidal Positional Encoder used in Nerf.""" 13 | 14 | def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True): 15 | super().__init__() 16 | self.x_dim = x_dim 17 | self.min_deg = min_deg 18 | self.max_deg = max_deg 19 | self.use_identity = use_identity 20 | self.c2f = None 21 | self.register_buffer( 22 | "scales", torch.tensor([2**i for i in range(min_deg, max_deg)]) 23 | ) 24 | 25 | @property 26 | def latent_dim(self) -> int: 27 | return ( 28 | int(self.use_identity) + (self.max_deg - self.min_deg) * 2 29 | ) * self.x_dim 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Args: 34 | x: [..., x_dim] 35 | Returns: 36 | latent: [..., latent_dim] 37 | """ 38 | if self.max_deg == self.min_deg: 39 | return x 40 | xb = torch.reshape( 41 | (x[Ellipsis, None, :] * self.scales[:, None]), 42 | list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim], 43 | ) 44 | latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1)) 45 | if self.use_identity: 46 | latent = torch.cat([x] + [latent], dim=-1) 47 | return latent 48 | 49 | 50 | class ProgressiveSinusoidalEncoder(SinusoidalEncoder): 51 | """ 52 | Coarse-to-fine positional encodings. 53 | """ 54 | 55 | def __init__( 56 | self, 57 | x_dim: int, 58 | min_deg: int, 59 | max_deg: int, 60 | use_identity: bool = True, 61 | c2f: List = [0.1, 0.5], 62 | half_dim: bool = False, 63 | ): 64 | super().__init__(x_dim, min_deg, max_deg, use_identity) 65 | 66 | # Use nn.Parameter so it could be checkpointed. 67 | self.progress = torch.nn.Parameter(torch.tensor(0.), requires_grad=False) 68 | 69 | self.c2f = c2f 70 | self.half_dim = half_dim 71 | 72 | @property 73 | def latent_dim(self) -> int: 74 | latent_dim = super().latent_dim 75 | if self.half_dim: 76 | latent_dim = (latent_dim - self.x_dim) // 2 + self.x_dim 77 | return latent_dim 78 | 79 | def anneal( 80 | self, 81 | iteration: int, 82 | max_iteration: int, 83 | factor: float = 1.0, 84 | reduction: float = 0.0, 85 | bias: float = 0.0, 86 | anneal_surface: bool = False, 87 | ): 88 | """ 89 | Gradually increase the controllable parameter during training. 90 | """ 91 | if anneal_surface: 92 | if iteration > max_iteration // 2: 93 | progress_data = 1.0 94 | else: 95 | progress_data = 0.5 + float(iteration) / float(max_iteration) 96 | else: 97 | # For camera pose annealing. 98 | progress_data = float(iteration) / float(max_iteration) 99 | 100 | progress_data = factor * (progress_data - reduction) + bias 101 | 102 | self.progress.data.fill_(progress_data) 103 | 104 | def forward(self, x: torch.Tensor) -> torch.Tensor: 105 | latent = super().forward(x) 106 | latent_dim = super().latent_dim 107 | 108 | # Computing weights. 
109 | start, end = self.c2f 110 | alpha = (self.progress.data - start) / (end - start) * self.max_deg 111 | ks = torch.arange(self.min_deg, self.max_deg, 112 | dtype=torch.float32, device=x.device) 113 | weight = ( 114 | 1.0 - (alpha - ks).clamp_(min=0, max=1).mul_(np.pi).cos_() 115 | ) / 2.0 116 | 117 | # Apply weight to positional encodings. 118 | shape = latent.shape 119 | L = self.max_deg - self.min_deg 120 | 121 | if self.use_identity: 122 | latent_freq = latent[:, self.x_dim:].reshape(-1, L) 123 | latent_freq = ( 124 | latent_freq * weight).reshape(shape[0], shape[-1] - self.x_dim) 125 | latent[:, self.x_dim:] = latent_freq 126 | else: 127 | latent = (latent.reshape(-1, L) * weight).reshape(*shape) 128 | 129 | if self.half_dim: 130 | half_freq = L // 2 131 | # input coordinates are excluded. 132 | half_latent_dim = (latent_dim - self.x_dim) // 2 133 | num_feat_each_band = (latent_dim - self.x_dim) // L 134 | half_latent = latent[:, self.x_dim:].view(-1, L, num_feat_each_band)[ 135 | :, :half_freq, :].view(-1, half_latent_dim) 136 | 137 | half_latent_contg = latent[:, self.x_dim:].view(-1, L, num_feat_each_band)[ 138 | :, :half_freq, :].view(-1, half_latent_dim).contiguous() 139 | half_latent_contg = ( 140 | half_latent_contg.view(-1, half_freq) * weight[:half_freq] 141 | ).view(-1, half_latent_dim) 142 | flag = weight[:half_freq].tile(shape[0], num_feat_each_band, 1).transpose( 143 | 1, 2).contiguous().view(-1, half_latent_dim) 144 | half_latent = torch.where( 145 | flag > 0.01, half_latent, half_latent_contg) 146 | latent = torch.cat([latent[:, :self.x_dim], half_latent], dim=-1) 147 | 148 | return latent 149 | 150 | 151 | class GaussianEncoder(nn.Module): 152 | """ 153 | Gaussian encodings. 154 | """ 155 | 156 | def __init__( 157 | self, 158 | x_dim: int, 159 | feature_dim: int, 160 | init_func: Callable = nn.init.uniform_, 161 | init_range: float = 0.1, 162 | sigma: float = 0.1, 163 | ) -> None: 164 | super().__init__() 165 | 166 | self.init_func = init_func 167 | self.init_range = init_range 168 | self.sigma = sigma 169 | self.sigma_square = sigma ** 2 170 | self.latent_dim = feature_dim 171 | 172 | gaussian_linear = torch.nn.Linear(x_dim, feature_dim) 173 | self.init_func(gaussian_linear.weight, - 174 | self.init_range, self.init_range) 175 | self.gaussian_linear = nn.utils.weight_norm(gaussian_linear) 176 | 177 | def forward(self, x: torch.Tensor) -> torch.Tensor: 178 | x = self.gaussian_linear(x) 179 | mu = torch.mean(x, axis=-1).unsqueeze(-1) 180 | x = torch.exp( 181 | -0.5 * ((x - mu) ** 2) / self.sigma_square 182 | ) 183 | return x 184 | 185 | 186 | def create_encoder(x_dim: int, config: Dict): 187 | """ 188 | Factory function for creating encodings that applied to coordinate input. 
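    A minimal example config for the 'progressive' encoder (the values are illustrative):

        config = {
            "type": "progressive",
            "min_deg": 0,
            "max_deg": 10,
            "use_identity": True,
            "c2f": [0.1, 0.5],
            "half_dim": False,
        }
        encoder = create_encoder(x_dim=3, config=config)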
189 | """ 190 | encoder_type = config["type"] 191 | if encoder_type == "sinusoidal": 192 | return SinusoidalEncoder( 193 | x_dim=x_dim, 194 | min_deg=config["min_deg"], 195 | max_deg=config["max_deg"], 196 | use_identity=config["use_identity"], 197 | ) 198 | elif encoder_type == "progressive": 199 | return ProgressiveSinusoidalEncoder( 200 | x_dim=x_dim, 201 | min_deg=config["min_deg"], 202 | max_deg=config["max_deg"], 203 | use_identity=config["use_identity"], 204 | c2f=config["c2f"], 205 | half_dim=config["half_dim"], 206 | ) 207 | elif encoder_type == "gaussian": 208 | return GaussianEncoder( 209 | x_dim=x_dim, 210 | feature_dim=config["feature_dim"] // 2, 211 | init_range=config["init_range"], 212 | sigma=config["sigma"], 213 | ) 214 | else: 215 | raise NotImplementedError 216 | -------------------------------------------------------------------------------- /conerf/model/backbone/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim: int, # The number of input tensor channels. 13 | output_dim: int = None, # The number of output tensor channels. 14 | net_depth: int = 8, # The depth of the MLP. 15 | net_width: int = 256, # The width of the MLP. 16 | skip_layer: int = 4, # The layer to add skip layers to. 17 | hidden_init: Callable = nn.init.xavier_uniform_, 18 | hidden_activation: Callable = nn.ReLU(), 19 | output_enabled: bool = True, 20 | output_init: Optional[Callable] = nn.init.xavier_uniform_, 21 | output_activation: Optional[Callable] = nn.Identity(), 22 | bias_enabled: bool = True, 23 | bias_init: Callable = nn.init.zeros_, 24 | ): 25 | super().__init__() 26 | self.input_dim = input_dim 27 | self.output_dim = output_dim 28 | self.net_depth = net_depth 29 | self.net_width = net_width 30 | self.skip_layer = skip_layer 31 | self.hidden_init = hidden_init 32 | self.hidden_activation = hidden_activation 33 | self.output_enabled = output_enabled 34 | self.output_init = output_init 35 | self.output_activation = output_activation 36 | self.bias_enabled = bias_enabled 37 | self.bias_init = bias_init 38 | 39 | self.hidden_layers = nn.ModuleList() 40 | in_features = self.input_dim 41 | for i in range(self.net_depth): 42 | self.hidden_layers.append( 43 | nn.Linear(in_features, self.net_width, bias=bias_enabled) 44 | ) 45 | if ( 46 | (self.skip_layer is not None) 47 | and (i % self.skip_layer == 0) 48 | and (i > 0) 49 | ): 50 | in_features = self.net_width + self.input_dim 51 | else: 52 | in_features = self.net_width 53 | if self.output_enabled: 54 | self.output_layer = nn.Linear( 55 | in_features, self.output_dim, bias=bias_enabled 56 | ) 57 | else: 58 | self.output_dim = in_features 59 | 60 | self.initialize() 61 | 62 | def initialize(self): 63 | def init_func_hidden(m): 64 | if isinstance(m, nn.Linear): 65 | if self.hidden_init is not None: 66 | self.hidden_init(m.weight) 67 | if self.bias_enabled and self.bias_init is not None: 68 | self.bias_init(m.bias) 69 | 70 | self.hidden_layers.apply(init_func_hidden) 71 | if self.output_enabled: 72 | 73 | def init_func_output(m): 74 | if isinstance(m, nn.Linear): 75 | if self.output_init is not None: 76 | self.output_init(m.weight) 77 | if self.bias_enabled and self.bias_init is not None: 78 | self.bias_init(m.bias) 79 | 80 | self.output_layer.apply(init_func_output) 81 | 82 | def 
forward(self, x): 83 | inputs = x 84 | for i in range(self.net_depth): 85 | x = self.hidden_layers[i](x) 86 | x = self.hidden_activation(x) 87 | if ( 88 | (self.skip_layer is not None) 89 | and (i % self.skip_layer == 0) 90 | and (i > 0) 91 | ): 92 | x = torch.cat([x, inputs], dim=-1) 93 | if self.output_enabled: 94 | x = self.output_layer(x) 95 | x = self.output_activation(x) 96 | return x 97 | 98 | 99 | class DenseLayer(MLP): 100 | def __init__(self, input_dim, output_dim, **kwargs): 101 | super().__init__( 102 | input_dim=input_dim, 103 | output_dim=output_dim, 104 | net_depth=0, # no hidden layers 105 | **kwargs, 106 | ) 107 | 108 | 109 | class NormalizedMLP(nn.Module): 110 | def __init__( 111 | self, 112 | input_dim: int, 113 | output_dim: int = None, 114 | net_depth: int = 8, 115 | net_width: int = 256, 116 | skip_layer: List = [4], 117 | hidden_activation: Callable = nn.ReLU(), 118 | bias: float = 0.5, 119 | weight_norm: bool = True, 120 | geometric_init: bool = True, 121 | ) -> None: 122 | super().__init__() 123 | 124 | self.input_dim = input_dim 125 | self.output_dim = output_dim 126 | self.net_depth = net_depth 127 | self.net_width = net_width 128 | self.skip_layer = skip_layer 129 | self.hidden_activation = hidden_activation 130 | self.bias = bias 131 | dims = [input_dim] + [net_width for _ in range(net_depth)] + [output_dim] 132 | self.num_layers = len(dims) 133 | 134 | for i in range(0, self.num_layers - 1): 135 | if (self.skip_layer is not None) and (i + 1) in self.skip_layer: 136 | out_dim = dims[i + 1] - dims[0] 137 | else: 138 | out_dim = dims[i + 1] 139 | 140 | lin = nn.Linear(dims[i], out_dim) 141 | 142 | if geometric_init: 143 | if i == self.num_layers - 2: 144 | torch.nn.init.normal_( 145 | lin.weight, 146 | mean=np.sqrt(np.pi) / np.sqrt(dims[i]), 147 | std=0.0001 148 | ) 149 | torch.nn.init.constant_(lin.bias, -bias) 150 | elif i == 0: 151 | torch.nn.init.constant_(lin.bias, 0.0) 152 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 153 | torch.nn.init.normal_( 154 | lin.weight[:, :3], 155 | 0.0, 156 | np.sqrt(2) / np.sqrt(out_dim) 157 | ) 158 | elif (self.skip_layer is not None) and (i in self.skip_layer): 159 | torch.nn.init.constant_(lin.bias, 0.0) 160 | torch.nn.init.normal_( 161 | lin.weight, 162 | 0.0, 163 | np.sqrt(2) / np.sqrt(out_dim) 164 | ) 165 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 166 | else: 167 | torch.nn.init.constant_(lin.bias, 0.0) 168 | torch.nn.init.normal_( 169 | lin.weight, 170 | 0.0, 171 | np.sqrt(2) / np.sqrt(out_dim) 172 | ) 173 | 174 | if weight_norm: 175 | lin = nn.utils.weight_norm(lin) 176 | 177 | setattr(self, "lin" + str(i), lin) 178 | 179 | def forward(self, x): 180 | inputs = x 181 | for i in range(0, self.num_layers - 1): 182 | lin = getattr(self, "lin" + str(i)) 183 | 184 | if (self.skip_layer is not None) and (i in self.skip_layer): 185 | x = torch.cat([x, inputs], dim=-1) / np.sqrt(2) 186 | 187 | x = lin(x) 188 | 189 | if i < self.num_layers - 2: 190 | x = self.hidden_activation(x) 191 | 192 | return x 193 | -------------------------------------------------------------------------------- /conerf/model/backbone/resnet3d.py: -------------------------------------------------------------------------------- 1 | # Code is adapted from: https://github.com/DonGovi/pyramid-detection-3D 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | 11 | __all__ = ['ResNet3D', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 
'resnet152'] 13 | 14 | ''' 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | ''' 23 | 24 | 25 | def conv3x3x3(in_planes, out_planes, stride=1): 26 | """3x3x3 convolution with padding""" 27 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=1, bias=False) 29 | ''' 30 | def downsample_basic_block(x, planes, stride): 31 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 32 | zero_pads = torch.Tensor( 33 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 34 | out.size(4)).zero_() 35 | if isinstance(out.data, torch.cuda.FloatTensor): 36 | zero_pads = zero_pads.cuda() 37 | 38 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 39 | 40 | return out 41 | ''' 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, in_planes, planes, stride=1, downsample=None): 47 | super(BasicBlock, self).__init__() 48 | self.conv1 = conv3x3x3(in_planes, planes, stride) 49 | self.bn1 = nn.BatchNorm3d(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3x3(planes, planes) 52 | self.bn2 = nn.BatchNorm3d(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | #conv2_rep = out 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | residual = self.downsample(x) 69 | 70 | out += residual 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | 79 | def __init__(self, in_planes, planes, stride=1, downsample=None): 80 | super(Bottleneck, self).__init__() 81 | self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False) 82 | self.bn1 = nn.BatchNorm3d(planes) 83 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, 84 | padding=1, bias=False) 85 | self.bn2 = nn.BatchNorm3d(planes) 86 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 87 | self.bn3 = nn.BatchNorm3d(planes * 4) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | residual = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | #conv3_rep = out 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | residual = self.downsample(x) 109 | 110 | out += residual 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet3D(nn.Module): 117 | def __init__(self, in_channels, block, layers): 118 | self.in_planes = 64 119 | super(ResNet3D, self).__init__() 120 | self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=5, stride=2, padding=2, bias=False) # 128 -> 64 121 | self.bn1 = nn.BatchNorm3d(64) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) # 64 -> 32 124 | self.layer1 = self._make_layer(block, 64, layers[0]) 125 | self.layer2 = self._make_layer(block, 128, layers[1], 
stride=2) # 32 -> 16 126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 16 -> 8 127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 8 -> 4 128 | ''' 129 | self.avgpool = nn.AvgPool3d(7, stride=1) 130 | self.fc = nn.Linear(512 * block.expansion, num_classes) 131 | ''' 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv3d): 135 | m.weight = nn.init.xavier_normal_(m.weight) 136 | elif isinstance(m, nn.BatchNorm3d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | 140 | def _make_layer(self, block, planes, blocks, stride=1): 141 | downsample = None 142 | if stride != 1 or self.in_planes != planes * block.expansion: 143 | downsample = nn.Sequential( 144 | nn.Conv3d(self.in_planes, planes * block.expansion, 145 | kernel_size=1, stride=stride, bias=False), 146 | nn.BatchNorm3d(planes * block.expansion) 147 | ) 148 | 149 | layers = [] 150 | layers.append(block(self.in_planes, planes, stride, downsample)) 151 | self.in_planes = planes * block.expansion 152 | for i in range(1, blocks): 153 | layers.append(block(self.in_planes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def forward(self, x): # 128 158 | c1 = self.conv1(x) # 64 --> 8 anchor_area 159 | c1 = self.bn1(c1) 160 | c1 = self.relu(c1) 161 | c2 = self.maxpool(c1) # 32 162 | 163 | c2 = self.layer1(c2) # 32 --> 16 anchor_area 164 | c3 = self.layer2(c2) # 16 --> 32 anchor_area 165 | c4 = self.layer3(c3) # 8 166 | c5 = self.layer4(c4) # 4 167 | ''' 168 | x = self.avgpool(x) 169 | x = x.view(x.size(0), -1) 170 | x = self.fc(x) 171 | ''' 172 | return c1, c2, c3, c4, c5 173 | 174 | 175 | def resnet18(in_channels=3, pretrained=False, **kwargs): 176 | """Constructs a ResNet-18 model. 177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet3D(in_channels, BasicBlock, [2, 2, 2, 2], **kwargs) 181 | # if pretrained: 182 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 183 | return model 184 | 185 | 186 | def resnet34(in_channels=3, pretrained=False, **kwargs): 187 | """Constructs a ResNet-34 model. 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet3D(in_channels, BasicBlock, [3, 4, 6, 3], **kwargs) 192 | # if pretrained: 193 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 194 | return model 195 | 196 | 197 | def resnet50(in_channels=3, pretrained=False, **kwargs): 198 | """Constructs a ResNet-50 model. 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet3D(in_channels, Bottleneck, [3, 4, 6, 3], **kwargs) 203 | # if pretrained: 204 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 205 | return model 206 | 207 | 208 | def resnet101(in_channels=3, pretrained=False, **kwargs): 209 | """Constructs a ResNet-101 model. 210 | Args: 211 | pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | """ 213 | model = ResNet3D(in_channels, Bottleneck, [3, 4, 23, 3], **kwargs) 214 | # if pretrained: 215 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 216 | return model 217 | 218 | 219 | def resnet152(in_channels=3, pretrained=False, **kwargs): 220 | """Constructs a ResNet-152 model. 
221 | Args: 222 | pretrained (bool): If True, returns a model pre-trained on ImageNet 223 | """ 224 | model = ResNet3D(in_channels, Bottleneck, [3, 8, 36, 3], **kwargs) 225 | # if pretrained: 226 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 227 | return model 228 | -------------------------------------------------------------------------------- /conerf/model/misc.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1101,W0108] 2 | 3 | import gc 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.autograd import Function 9 | from torch.cuda.amp import custom_bwd, custom_fwd 10 | 11 | import tinycudann as tcnn 12 | 13 | 14 | def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): 15 | B = None 16 | for arg in args: 17 | if isinstance(arg, torch.Tensor): 18 | B = arg.shape[0] 19 | break 20 | out = defaultdict(list) 21 | out_type = None 22 | chunk_length = 0 23 | for i in range(0, B, chunk_size): 24 | out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) 25 | if out_chunk is None: 26 | continue 27 | out_type = type(out_chunk) 28 | if isinstance(out_chunk, torch.Tensor): 29 | out_chunk = {0: out_chunk} 30 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): 31 | chunk_length = len(out_chunk) 32 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} 33 | elif isinstance(out_chunk, dict): 34 | pass 35 | else: 36 | print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.') 37 | exit(1) 38 | for k, v in out_chunk.items(): 39 | if v is None: 40 | chunk_length -= 1 41 | continue 42 | v = v if torch.is_grad_enabled() else v.detach() 43 | v = v.cpu() if move_to_cpu else v 44 | out[k].append(v) 45 | 46 | if out_type is None: 47 | return 48 | 49 | out = {k: torch.cat(v, dim=0) for k, v in out.items()} 50 | if out_type is torch.Tensor: 51 | return out[0] 52 | elif out_type in [tuple, list]: 53 | # return out_type([out[i] for i in range(chunk_length)]) 54 | return out_type([out[i] for i in out.keys()]) 55 | elif out_type is dict: 56 | return out 57 | 58 | 59 | class _TruncExp(Function): # pylint: disable=abstract-method 60 | # Implementation from torch-ngp: 61 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 62 | @staticmethod 63 | @custom_fwd(cast_inputs=torch.float32) 64 | def forward(ctx, x): # pylint: disable=arguments-differ 65 | ctx.save_for_backward(x) 66 | return torch.exp(x) 67 | 68 | @staticmethod 69 | @custom_bwd 70 | def backward(ctx, g): # pylint: disable=arguments-differ 71 | x = ctx.saved_tensors[0] 72 | return g * torch.exp(torch.clamp(x, max=15)) 73 | 74 | trunc_exp = _TruncExp.apply 75 | 76 | 77 | def get_activation(name): 78 | if name is None: 79 | return lambda x: x 80 | name = name.lower() 81 | if name == 'none': 82 | return lambda x: x 83 | elif name.startswith('scale'): 84 | scale_factor = float(name[5:]) 85 | return lambda x: x.clamp(0., scale_factor) / scale_factor 86 | elif name.startswith('clamp'): 87 | clamp_max = float(name[5:]) 88 | return lambda x: x.clamp(0., clamp_max) 89 | elif name.startswith('mul'): 90 | mul_factor = float(name[3:]) 91 | return lambda x: x * mul_factor 92 | elif name == 'lin2srgb': 93 | return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) 
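    # ('lin2srgb' above is the standard linear-to-sRGB transfer function: 12.92 * x for
    # x <= 0.0031308 and 1.055 * x**(1/2.4) - 0.055 otherwise, clamped to [0, 1].)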
94 | elif name == 'trunc_exp': 95 | return trunc_exp 96 | elif name.startswith('+') or name.startswith('-'): 97 | return lambda x: x + float(name) 98 | elif name == 'sigmoid': 99 | return lambda x: torch.sigmoid(x) 100 | elif name == 'tanh': 101 | return lambda x: torch.tanh(x) 102 | else: 103 | return getattr(F, name) 104 | 105 | 106 | def dot(x, y): 107 | return torch.sum(x*y, -1, keepdim=True) 108 | 109 | 110 | def reflect(x, n): 111 | return 2 * dot(x, n) * n - x 112 | 113 | 114 | def cleanup(): 115 | gc.collect() 116 | torch.cuda.empty_cache() 117 | tcnn.free_temporary_memory() -------------------------------------------------------------------------------- /conerf/model/scene_regressor/ace_encoder_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/model/scene_regressor/ace_encoder_pretrained.pt -------------------------------------------------------------------------------- /conerf/model/scene_regressor/ace_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright © Niantic, Inc. 2022. 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def weighted_tanh(repro_errs, weight): 9 | return weight * torch.tanh(repro_errs / weight).sum() 10 | 11 | 12 | class ReproLoss: 13 | """ 14 | Compute per-pixel reprojection loss using different configurable approaches. 15 | 16 | - tanh: tanh loss with a constant scale factor given by the `soft_clamp` parameter 17 | (when a pixel's reprojection error is equal to `soft_clamp`, its loss is equal 18 | to `soft_clamp * tanh(1)`). 19 | - dyntanh: Used in the paper, similar to the tanh loss above, but the scaling factor 20 | decreases during the course of the training from `soft_clamp` to `soft_clamp_min`. 21 | The decrease is linear, unless `circle_schedule` is True (default), in which 22 | case it applies a circular scheduling. See paper for details. 23 | - l1: Standard L1 loss, computed only on those pixels having an error lower than `soft_clamp` 24 | - l1+sqrt: L1 loss for pixels with reprojection error smaller than `soft_clamp` and 25 | `sqrt(soft_clamp * reprojection_error)` for pixels with a higher error. 26 | - l1+logl1: Similar to the above, but using log L1 for pixels with high reprojection error. 27 | """ 28 | 29 | def __init__(self, 30 | total_iterations, 31 | soft_clamp, 32 | soft_clamp_min, 33 | type='dyntanh', 34 | circle_schedule=True): 35 | 36 | self.total_iterations = total_iterations 37 | self.soft_clamp = soft_clamp 38 | self.soft_clamp_min = soft_clamp_min 39 | self.type = type 40 | self.circle_schedule = circle_schedule 41 | 42 | def compute(self, repro_errs_b1N, iteration): 43 | if repro_errs_b1N.nelement() == 0: 44 | return 0 45 | 46 | if self.type == "tanh": 47 | return weighted_tanh(repro_errs_b1N, self.soft_clamp) 48 | 49 | elif self.type == "dyntanh": 50 | # Compute the progress over the training process. 51 | schedule_weight = iteration / self.total_iterations 52 | 53 | if self.circle_schedule: 54 | # Optionally scale it using the circular schedule. 55 | schedule_weight = 1 - math.sqrt(1 - schedule_weight ** 2) 56 | 57 | # Compute the weight to use in the tanh loss. 58 | loss_weight = (1 - schedule_weight) * self.soft_clamp + self.soft_clamp_min 59 | 60 | # Compute actual loss. 61 | return weighted_tanh(repro_errs_b1N, loss_weight) 62 | 63 | elif self.type == "l1": 64 | # L1 loss on all pixels with small-enough error. 
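            # Pixels whose reprojection error exceeds `soft_clamp` are masked out entirely,
            # so they contribute neither loss nor gradient in this branch.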
65 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp 66 | return repro_errs_b1N[~softclamp_mask_b1].sum() 67 | 68 | elif self.type == "l1+sqrt": 69 | # L1 loss on pixels with small errors and sqrt for the others. 70 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp 71 | loss_l1 = repro_errs_b1N[~softclamp_mask_b1].sum() 72 | loss_sqrt = torch.sqrt(self.soft_clamp * repro_errs_b1N[softclamp_mask_b1]).sum() 73 | 74 | return loss_l1 + loss_sqrt 75 | 76 | else: 77 | # l1+logl1: same as above, but use log(L1) for pixels with a larger error. 78 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp 79 | loss_l1 = repro_errs_b1N[~softclamp_mask_b1].sum() 80 | loss_logl1 = torch.log(1 + (self.soft_clamp * repro_errs_b1N[softclamp_mask_b1])).sum() 81 | 82 | return loss_l1 + loss_logl1 83 | -------------------------------------------------------------------------------- /conerf/model/scene_regressor/ace_util.py: -------------------------------------------------------------------------------- 1 | # Copyright © Niantic, Inc. 2022. 2 | # pylint: disable=[E1101] 3 | 4 | import torch 5 | 6 | from conerf.datasets.utils import store_ply 7 | 8 | 9 | def get_pixel_grid(image_height: int, image_width: int): 10 | """ 11 | Generate target pixel positions according to image height and width, assuming 12 | prediction at center pixel. 13 | """ 14 | ys = torch.arange(image_height, dtype=torch.float32) 15 | xs = torch.arange(image_width, dtype=torch.float32) 16 | yy, xx = torch.meshgrid(ys, xs, indexing='ij') 17 | 18 | return torch.stack([xx, yy]) + 0.5 19 | 20 | 21 | # def get_pixel_grid(subsampling_factor): 22 | # """ 23 | # Generate target pixel positions according to a subsampling factor, assuming prediction 24 | # at center pixel. 25 | # """ 26 | # pix_range = torch.arange(np.ceil(5000 / subsampling_factor), dtype=torch.float32) 27 | # yy, xx = torch.meshgrid(pix_range, pix_range, indexing='ij') 28 | 29 | # return subsampling_factor * (torch.stack([xx, yy]) + 0.5) 30 | 31 | 32 | def to_homogeneous(input_tensor, dim=1): 33 | """ 34 | Converts tensor to homogeneous coordinates by adding ones to the specified dimension 35 | """ 36 | ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim)) 37 | output = torch.cat([input_tensor, ones], dim=dim) 38 | 39 | return output 40 | 41 | 42 | def save_point_cloud(points3d: torch.Tensor, colors: torch.Tensor = None, path: str = ""): 43 | """Save point cloud to '.ply' file. 44 | """ 45 | if isinstance(points3d, torch.Tensor): 46 | points3d = points3d.detach().cpu().numpy() 47 | 48 | if colors is not None: 49 | if isinstance(colors, torch.Tensor): 50 | colors = colors.detach().cpu().numpy() 51 | 52 | store_ply(path, points3d, colors) 53 | -------------------------------------------------------------------------------- /conerf/model/scene_regressor/calibr.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1101] 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Calibr(nn.Module): 8 | """ 9 | A modular class for calibration refinement. 10 | NOTE: 11 | The class assumes that: 12 | (1) the principle point is in the center; 13 | (2) pixels are unskewed and square; 14 | (3) image distortion is not modelled. 
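    The initial focal length is set to 70% of the image diagonal and refined through a single
    learned scale parameter shared by all images: f = f_init * (1 + scaler).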
15 | """ 16 | 17 | def __init__(self, device: str = "cuda"): 18 | super(Calibr, self).__init__() 19 | 20 | self.device = device 21 | self.scaler = torch.nn.Parameter(torch.tensor(0.), requires_grad=True) 22 | 23 | def forward(self, heights: torch.Tensor, widths: torch.Tensor) -> torch.Tensor: 24 | batch_size = heights.shape[0] 25 | # The initial focal length is set to 70% of the image diagonal. 26 | focal_lengths_init = 0.7 * torch.sqrt(heights ** 2 + widths ** 2) 27 | 28 | # assume principle point is in the center. 29 | cxs = widths / 2 30 | cys = heights / 2 31 | 32 | focal_lengths = focal_lengths_init * (1 + self.scaler) 33 | 34 | Ks = torch.eye(3, device=self.device)[None, ...].repeat(batch_size, 1, 1) 35 | Ks[:, 0, 0] = Ks[:, 1, 1] = focal_lengths 36 | Ks[:, 0, 2] = cxs 37 | Ks[:, 1, 2] = cys 38 | 39 | return Ks 40 | -------------------------------------------------------------------------------- /conerf/model/scene_regressor/depth_network.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1101,W0212] 2 | 3 | import torch 4 | import ssl 5 | ssl._create_default_https_context = ssl._create_unverified_context 6 | 7 | 8 | class DepthNetwork: 9 | """ 10 | A wrapper of different depth network (ZoeDepth and Metric3D) 11 | """ 12 | def __init__( 13 | self, 14 | method: str = "ZoeDepth", 15 | depth_type: str = "ZoeD_NK", 16 | pretrain: bool = True, 17 | depth_min: float = 0.1, 18 | depth_max: float = 1000, 19 | device: str = "cuda" 20 | ) -> None: 21 | self.method = method 22 | self.depth_min = depth_min 23 | self.depth_max = depth_max 24 | self.device = device 25 | 26 | if method == "ZoeDepth": 27 | self.depth_network = torch.hub.load( 28 | 'isl-org/ZoeDepth', 29 | depth_type, 30 | pretrained=pretrain, 31 | ).to(self.device) 32 | elif method == "metric3d": 33 | self.depth_network = torch.hub.load( 34 | 'yvanyin/metric3d', 35 | depth_type, 36 | pretrain=pretrain, 37 | ).to(self.device) 38 | else: 39 | raise NotImplementedError 40 | 41 | def infer(self, image: torch.Tensor): 42 | """ 43 | Param: 44 | @param image: [B,3,H,W] 45 | Return: 46 | depth: depth map for image [B,1,H,W] 47 | confidence: confidence score corresponds to the depth map 48 | output_dict: other outputs from metric3d 49 | """ 50 | confidence = None 51 | output_dict = None 52 | if self.method == "ZoeDepth": 53 | depth = self.depth_network.infer(image) 54 | elif self.method == "metric3d": 55 | depth, confidence, output_dict = self.depth_network.inference({'input': image}) 56 | else: 57 | raise NotImplementedError 58 | 59 | depth = torch.clamp(depth, self.depth_min, self.depth_max) 60 | 61 | return depth, confidence, output_dict 62 | -------------------------------------------------------------------------------- /conerf/model/scene_regressor/pose_refine_network.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1101,E1102] 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from conerf.geometry.pose_util import se3_exp_map 7 | from conerf.model.backbone.mlp import MLP 8 | 9 | 10 | class PoseRefineNetwork(nn.Module): 11 | """ 12 | Optimize the 6DoF camera poses. 13 | """ 14 | 15 | def __init__(self, input_dim: int = 12, output_dim: int = 6, hidden_dim: int = 128): 16 | super(PoseRefineNetwork, self).__init__() 17 | 18 | self.hidden_dim = hidden_dim 19 | 20 | self.mlp = MLP( 21 | input_dim=input_dim, 22 | output_dim=output_dim, 23 | net_depth=6, # hard-coded. 
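            # Added note: this MLP consumes the flattened 3x4 pose (12 values) and
            # predicts a 6-DoF se(3) increment; `forward` maps it through `se3_exp_map`
            # and right-multiplies the resulting 4x4 matrix onto the input pose.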
24 | net_width=hidden_dim, 25 | skip_layer=3, # hard-coded. 26 | # TODO(chenyu): check with the hidden activation since it is not mentioned in the paper. 27 | hidden_activation=nn.ReLU(), 28 | ) 29 | 30 | def forward(self, poses: torch.Tensor): 31 | """ 32 | Parameters: 33 | @param poses: [N,3/4,4] 34 | Returns: 35 | optimized poses [N,4,4] 36 | """ 37 | batch_size = poses.shape[0] 38 | 39 | poses_3x4 = poses[:, :3, :].reshape(batch_size, -1) # [B,12] 40 | delta_se3 = self.mlp(poses_3x4) # [B,6] 41 | delta_pose_4x4 = se3_exp_map(delta_se3) # [B,4,4] 42 | 43 | updated_poses = poses @ delta_pose_4x4 44 | 45 | return updated_poses, delta_se3 46 | 47 | # poses = poses[:, :3, :].reshape(batch_size, -1) # [B,12] 48 | # poses = self.mlp(poses) # [B,12] 49 | # poses = poses.reshape(batch_size, 3, 4) # [B,3,4] 50 | 51 | # # Retraction to recover the rotational part. 52 | # Us, _, Vhs = torch.linalg.svd(poses[:, :3, :3]) # pylint: disable=C0103 53 | 54 | # updated_poses = torch.eye(4, device=poses.device).reshape(-1, 4).repeat(batch_size, 1, 1) 55 | 56 | # # R = U @ V^T. 57 | # # Construct Z to fix the orientation of R to get det(R) = 1. 58 | # Z = torch.eye(3, device=poses.device).reshape(-1, 3).repeat(batch_size, 1, 1) 59 | # Z[:, -1, -1] = Z [:, -1, -1] * torch.sign(torch.linalg.det(Us @ Vhs)) 60 | # updated_poses[:, :3, :3] = Us @ Z @ Vhs 61 | 62 | # # Copy translational part. 63 | # updated_poses[:, :3, 3:] = poses[:, :3, 3:] 64 | 65 | # return updated_poses 66 | -------------------------------------------------------------------------------- /conerf/pycolmap/pycolmap/__init__.py: -------------------------------------------------------------------------------- 1 | from conerf.pycolmap.pycolmap.camera import Camera 2 | from conerf.pycolmap.pycolmap.database import COLMAPDatabase 3 | from conerf.pycolmap.pycolmap.image import Image 4 | from conerf.pycolmap.pycolmap.scene_manager import SceneManager 5 | from conerf.pycolmap.pycolmap.rotation import Quaternion, DualQuaternion 6 | -------------------------------------------------------------------------------- /conerf/pycolmap/pycolmap/image.py: -------------------------------------------------------------------------------- 1 | # Author: True Price 2 | 3 | import numpy as np 4 | 5 | #------------------------------------------------------------------------------- 6 | # 7 | # Image 8 | # 9 | #------------------------------------------------------------------------------- 10 | 11 | class Image: 12 | def __init__(self, name_, camera_id_, q_, tvec_): 13 | self.name = name_ 14 | self.camera_id = camera_id_ 15 | self.q = q_ 16 | self.tvec = tvec_ 17 | 18 | self.points2D = np.empty((0, 2), dtype=np.float64) 19 | self.point3D_ids = np.empty((0,), dtype=np.uint64) 20 | 21 | #--------------------------------------------------------------------------- 22 | 23 | def R(self): 24 | return self.q.ToR() 25 | 26 | #--------------------------------------------------------------------------- 27 | 28 | def C(self): 29 | return -self.R().T.dot(self.tvec) 30 | 31 | #--------------------------------------------------------------------------- 32 | 33 | @property 34 | def t(self): 35 | return self.tvec 36 | -------------------------------------------------------------------------------- /conerf/pycolmap/tools/colmap_to_nvm.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import sys 3 | sys.path.append("..") 4 | 5 | import numpy as np 6 | 7 | from pycolmap import Quaternion, SceneManager 8 | 9 | 10 | 
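# Added usage sketch (paths are placeholders, not taken from the repository):
#   python colmap_to_nvm.py <colmap_model_folder> <output.nvm> --image_name_prefix images/
# Note that this vendored tool still uses Python 2 idioms (e.g. `iteritems`), so it
# needs minor porting before it will run inside the project's Python 3 environment.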
#------------------------------------------------------------------------------- 11 | 12 | def main(args): 13 | scene_manager = SceneManager(args.input_folder) 14 | scene_manager.load() 15 | 16 | with open(args.output_file, "w") as fid: 17 | fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images))) 18 | 19 | image_fmt_str = " {:.3f} " + 7 * "{:.7f} " 20 | for image_id, image in scene_manager.images.iteritems(): 21 | camera = scene_manager.cameras[image.camera_id] 22 | f = 0.5 * (camera.fx + camera.fy) 23 | fid.write(args.image_name_prefix + image.name) 24 | fid.write(image_fmt_str.format( 25 | *((f,) + tuple(image.q.q) + tuple(image.C())))) 26 | if camera.distortion_func is None: 27 | fid.write("0 0\n") 28 | else: 29 | fid.write("{:.7f} 0\n".format(-camera.k1)) 30 | 31 | image_id_to_idx = dict( 32 | (image_id, i) for i, image_id in enumerate(scene_manager.images)) 33 | 34 | fid.write("{:d}\n".format(len(scene_manager.points3D))) 35 | for i, point3D_id in enumerate(scene_manager.point3D_ids): 36 | fid.write( 37 | "{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i])) 38 | fid.write( 39 | "{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i])) 40 | keypoints = [ 41 | (image_id_to_idx[image_id], kp_idx) + 42 | tuple(scene_manager.images[image_id].points2D[kp_idx]) 43 | for image_id, kp_idx in 44 | scene_manager.point3D_id_to_images[point3D_id]] 45 | fid.write("{:d}".format(len(keypoints))) 46 | fid.write( 47 | (len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format( 48 | *itertools.chain(*keypoints))) 49 | 50 | 51 | #------------------------------------------------------------------------------- 52 | 53 | if __name__ == "__main__": 54 | import argparse 55 | 56 | parser = argparse.ArgumentParser( 57 | description="Save a COLMAP reconstruction in the NVM format " 58 | "(http://ccwu.me/vsfm/doc.html#nvm).", 59 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 60 | 61 | parser.add_argument("input_folder") 62 | parser.add_argument("output_file") 63 | 64 | parser.add_argument("--image_name_prefix", type=str, default="", 65 | help="prefix image names with this string (e.g., 'images/')") 66 | 67 | args = parser.parse_args() 68 | 69 | main(args) 70 | -------------------------------------------------------------------------------- /conerf/pycolmap/tools/delete_images.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import numpy as np 5 | 6 | from pycolmap import DualQuaternion, Image, SceneManager 7 | 8 | 9 | #------------------------------------------------------------------------------- 10 | 11 | def main(args): 12 | scene_manager = SceneManager(args.input_folder) 13 | scene_manager.load() 14 | 15 | image_ids = map(scene_manager.get_image_from_name, 16 | iter(lambda: sys.stdin.readline().strip(), "")) 17 | scene_manager.delete_images(image_ids) 18 | 19 | scene_manager.save(args.output_folder) 20 | 21 | 22 | #------------------------------------------------------------------------------- 23 | 24 | if __name__ == "__main__": 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser( 28 | description="Deletes images (filenames read from stdin) from a model.", 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | 31 | parser.add_argument("input_folder") 32 | parser.add_argument("output_folder") 33 | 34 | args = parser.parse_args() 35 | 36 | main(args) 37 | -------------------------------------------------------------------------------- 
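The command-line tools under `conerf/pycolmap/tools` share one pattern: build a `SceneManager` on a COLMAP sparse-model folder, load the model, then read or edit its `cameras`, `images` and `points3D`. The snippet below is a minimal Python 3 sketch of that pattern, not a script from the repository; the `sparse/0` path is an assumed example.

```python
from conerf.pycolmap.pycolmap.scene_manager import SceneManager

scene_manager = SceneManager("sparse/0")  # hypothetical COLMAP model folder
scene_manager.load()                      # cameras, images and 3D points

for image_id, image in scene_manager.images.items():
    camera = scene_manager.cameras[image.camera_id]
    # average focal length and camera center in world coordinates
    print(image.name, 0.5 * (camera.fx + camera.fy), image.C())
```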
/conerf/pycolmap/tools/impute_missing_cameras.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import numpy as np 5 | 6 | from pycolmap import DualQuaternion, Image, SceneManager 7 | 8 | 9 | #------------------------------------------------------------------------------- 10 | 11 | image_to_idx = lambda im: int(im.name[:im.name.rfind(".")]) 12 | 13 | 14 | #------------------------------------------------------------------------------- 15 | 16 | def interpolate_linear(images, camera_id, file_format): 17 | if len(images) < 2: 18 | raise ValueError("Need at least two images for linear interpolation!") 19 | 20 | prev_image = images[0] 21 | prev_idx = image_to_idx(prev_image) 22 | prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t) 23 | start = prev_idx 24 | 25 | new_images = [] 26 | 27 | for image in images[1:]: 28 | curr_idx = image_to_idx(image) 29 | curr_dq = DualQuaternion.FromQT(image.q, image.t) 30 | T = curr_idx - prev_idx 31 | Tinv = 1. / T 32 | 33 | # like quaternions, dq(x) = -dq(x), so we'll need to pick the one more 34 | # appropriate for interpolation by taking -dq if the dot product of the 35 | # two q-vectors is negative 36 | if prev_dq.q0.dot(curr_dq.q0) < 0: 37 | curr_dq = -curr_dq 38 | 39 | for i in xrange(1, T): 40 | t = i * Tinv 41 | dq = t * prev_dq + (1. - t) * curr_dq 42 | q, t = dq.ToQT() 43 | new_images.append( 44 | Image(file_format.format(prev_idx + i), args.camera_id, q, t)) 45 | 46 | prev_idx = curr_idx 47 | prev_dq = curr_dq 48 | 49 | return new_images 50 | 51 | 52 | #------------------------------------------------------------------------------- 53 | 54 | def interpolate_hermite(images, camera_id, file_format): 55 | if len(images) < 4: 56 | raise ValueError( 57 | "Need at least four images for Hermite spline interpolation!") 58 | 59 | new_images = [] 60 | 61 | # linear blending for the first frames 62 | T0 = image_to_idx(images[0]) 63 | dq0 = DualQuaternion.FromQT(images[0].q, images[0].t) 64 | T1 = image_to_idx(images[1]) 65 | dq1 = DualQuaternion.FromQT(images[1].q, images[1].t) 66 | 67 | if dq0.q0.dot(dq1.q0) < 0: 68 | dq1 = -dq1 69 | dT = 1. / float(T1 - T0) 70 | for j in xrange(1, T1 - T0): 71 | t = j * dT 72 | dq = ((1. - t) * dq0 + t * dq1).normalize() 73 | new_images.append( 74 | Image(file_format.format(T0 + j), camera_id, *dq.ToQT())) 75 | 76 | T2 = image_to_idx(images[2]) 77 | dq2 = DualQuaternion.FromQT(images[2].q, images[2].t) 78 | if dq1.q0.dot(dq2.q0) < 0: 79 | dq2 = -dq2 80 | 81 | # Hermite spline interpolation of dual quaternions 82 | # pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf 83 | for i in xrange(1, len(images) - 2): 84 | T3 = image_to_idx(images[i + 2]) 85 | dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t) 86 | if dq2.q0.dot(dq3.q0) < 0: 87 | dq3 = -dq3 88 | 89 | prev_duration = T1 - T0 90 | current_duration = T2 - T1 91 | next_duration = T3 - T2 92 | 93 | # approximate the derivatives at dq1 and dq2 using weighted central 94 | # differences 95 | dt1 = 1. / float(T2 - T0) 96 | dt2 = 1. / float(T3 - T1) 97 | 98 | m1 = (current_duration * dt1) * (dq2 - dq1) + \ 99 | (prev_duration * dt1) * (dq1 - dq0) 100 | m2 = (next_duration * dt2) * (dq3 - dq2) + \ 101 | (current_duration * dt2) * (dq2 - dq1) 102 | 103 | dT = 1. 
/ float(current_duration) 104 | 105 | for j in xrange(1, current_duration): 106 | t = j * dT # 0 to 1 107 | t2 = t * t # t squared 108 | t3 = t2 * t # t cubed 109 | 110 | # coefficients of the Hermite spline (a=>dq and b=>m) 111 | a1 = 2. * t3 - 3. * t2 + 1. 112 | b1 = t3 - 2. * t2 + t 113 | a2 = -2. * t3 + 3. * t2 114 | b2 = t3 - t2 115 | 116 | dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize() 117 | 118 | new_images.append( 119 | Image(file_format.format(T1 + j), camera_id, *dq.ToQT())) 120 | 121 | T0, T1, T2 = T1, T2, T3 122 | dq0, dq1, dq2 = dq1, dq2, dq3 123 | 124 | # linear blending for the last frames 125 | dT = 1. / float(T2 - T1) 126 | for j in xrange(1, T2 - T1): 127 | t = j * dT # 0 to 1 128 | dq = ((1. - t) * dq1 + t * dq2).normalize() 129 | new_images.append( 130 | Image(file_format.format(T1 + j), camera_id, *dq.ToQT())) 131 | 132 | return new_images 133 | 134 | 135 | #------------------------------------------------------------------------------- 136 | 137 | def main(args): 138 | scene_manager = SceneManager(args.input_folder) 139 | scene_manager.load() 140 | 141 | images = sorted(scene_manager.images.itervalues(), key=image_to_idx) 142 | 143 | if args.method.lower() == "linear": 144 | new_images = interpolate_linear(images, args.camera_id, args.format) 145 | else: 146 | new_images = interpolate_hermite(images, args.camera_id, args.format) 147 | 148 | map(scene_manager.add_image, new_images) 149 | 150 | scene_manager.save(args.output_folder) 151 | 152 | 153 | #------------------------------------------------------------------------------- 154 | 155 | if __name__ == "__main__": 156 | import argparse 157 | 158 | parser = argparse.ArgumentParser( 159 | description="Given a reconstruction with ordered images *with integer " 160 | "filenames* like '000100.png', fill in missing camera positions for " 161 | "intermediate frames.", 162 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 163 | 164 | parser.add_argument("input_folder") 165 | parser.add_argument("output_folder") 166 | 167 | parser.add_argument("--camera_id", type=int, default=1, 168 | help="camera id to use for the missing images") 169 | 170 | parser.add_argument("--format", type=str, default="{:06d}.png", 171 | help="filename format to use for added images") 172 | 173 | parser.add_argument( 174 | "--method", type=str.lower, choices=("linear", "hermite"), 175 | default="hermite", 176 | help="Pose imputation method") 177 | 178 | args = parser.parse_args() 179 | 180 | main(args) 181 | -------------------------------------------------------------------------------- /conerf/pycolmap/tools/save_cameras_as_ply.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import numpy as np 5 | import os 6 | 7 | from pycolmap import SceneManager 8 | 9 | 10 | #------------------------------------------------------------------------------- 11 | 12 | # Saves the cameras as a mesh 13 | # 14 | # inputs: 15 | # - ply_file: output file 16 | # - images: ordered array of pycolmap Image objects 17 | # - color: color string for the camera 18 | # - scale: amount to shrink/grow the camera model 19 | def save_camera_ply(ply_file, images, scale): 20 | points3D = scale * np.array(( 21 | (0., 0., 0.), 22 | (-1., -1., 1.), 23 | (-1., 1., 1.), 24 | (1., -1., 1.), 25 | (1., 1., 1.))) 26 | 27 | faces = np.array(((0, 2, 1), 28 | (0, 4, 2), 29 | (0, 3, 4), 30 | (0, 1, 3), 31 | (1, 2, 4), 32 | (1, 4, 3))) 33 | 34 | r = np.linspace(0, 255, len(images), 
dtype=np.uint8) 35 | g = 255 - r 36 | b = r - np.linspace(0, 128, len(images), dtype=np.uint8) 37 | color = np.column_stack((r, g, b)) 38 | 39 | with open(ply_file, "w") as fid: 40 | print>>fid, "ply" 41 | print>>fid, "format ascii 1.0" 42 | print>>fid, "element vertex", len(points3D) * len(images) 43 | print>>fid, "property float x" 44 | print>>fid, "property float y" 45 | print>>fid, "property float z" 46 | print>>fid, "property uchar red" 47 | print>>fid, "property uchar green" 48 | print>>fid, "property uchar blue" 49 | print>>fid, "element face", len(faces) * len(images) 50 | print>>fid, "property list uchar int vertex_index" 51 | print>>fid, "end_header" 52 | 53 | for image, c in zip(images, color): 54 | for p3D in (points3D.dot(image.R()) + image.C()): 55 | print>>fid, p3D[0], p3D[1], p3D[2], c[0], c[1], c[2] 56 | 57 | for i in xrange(len(images)): 58 | for f in (faces + len(points3D) * i): 59 | print>>fid, "3 {} {} {}".format(*f) 60 | 61 | 62 | #------------------------------------------------------------------------------- 63 | 64 | def main(args): 65 | scene_manager = SceneManager(args.input_folder) 66 | scene_manager.load_images() 67 | 68 | images = sorted(scene_manager.images.itervalues(), 69 | key=lambda image: image.name) 70 | 71 | save_camera_ply(args.output_file, images, args.scale) 72 | 73 | 74 | #------------------------------------------------------------------------------- 75 | 76 | if __name__ == "__main__": 77 | import argparse 78 | 79 | parser = argparse.ArgumentParser( 80 | description="Saves camera positions to a PLY for easy viewing outside " 81 | "of COLMAP. Currently, camera FoV is not reflected in the output.", 82 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 83 | 84 | parser.add_argument("input_folder") 85 | parser.add_argument("output_file") 86 | 87 | parser.add_argument("--scale", type=float, default=1., 88 | help="Scaling factor for the camera mesh.") 89 | 90 | args = parser.parse_args() 91 | 92 | main(args) 93 | -------------------------------------------------------------------------------- /conerf/pycolmap/tools/transform_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import numpy as np 5 | 6 | from pycolmap import Quaternion, SceneManager 7 | 8 | 9 | #------------------------------------------------------------------------------- 10 | 11 | def main(args): 12 | scene_manager = SceneManager(args.input_folder) 13 | scene_manager.load() 14 | 15 | # expect each line of input corresponds to one row 16 | P = np.array([ 17 | map(float, sys.stdin.readline().strip().split()) for _ in xrange(3)]) 18 | 19 | scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3] 20 | 21 | # get rotation without any global scaling (assuming isotropic scaling) 22 | scale = np.cbrt(np.linalg.det(P[:,:3])) 23 | q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale) 24 | 25 | for image in scene_manager.images.itervalues(): 26 | image.q *= q_old_from_new 27 | image.tvec = scale * image.tvec - image.R().dot(P[:,3]) 28 | 29 | scene_manager.save(args.output_folder) 30 | 31 | 32 | #------------------------------------------------------------------------------- 33 | 34 | if __name__ == "__main__": 35 | import argparse 36 | 37 | parser = argparse.ArgumentParser( 38 | description="Apply a 3x4 transformation matrix to a COLMAP model and " 39 | "save the result as a new model. 
Row-major input can be piped in from " 40 | "a file or entered via the command line.", 41 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 42 | 43 | parser.add_argument("input_folder") 44 | parser.add_argument("output_folder") 45 | 46 | args = parser.parse_args() 47 | 48 | main(args) 49 | -------------------------------------------------------------------------------- /conerf/pycolmap/tools/write_camera_track_to_bundler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import numpy as np 5 | 6 | from pycolmap import SceneManager 7 | 8 | 9 | #------------------------------------------------------------------------------- 10 | 11 | def main(args): 12 | scene_manager = SceneManager(args.input_folder) 13 | scene_manager.load_cameras() 14 | scene_manager.load_images() 15 | 16 | if args.sort: 17 | images = sorted( 18 | scene_manager.images.itervalues(), key=lambda im: im.name) 19 | else: 20 | images = scene_manager.images.values() 21 | 22 | fid = open(args.output_file, "w") 23 | fid_filenames = open(args.output_file + ".list.txt", "w") 24 | 25 | print>>fid, "# Bundle file v0.3" 26 | print>>fid, len(images), 0 27 | 28 | for image in images: 29 | print>>fid_filenames, image.name 30 | camera = scene_manager.cameras[image.camera_id] 31 | print>>fid, 0.5 * (camera.fx + camera.fy), 0, 0 32 | R, t = image.R(), image.t 33 | print>>fid, R[0, 0], R[0, 1], R[0, 2] 34 | print>>fid, -R[1, 0], -R[1, 1], -R[1, 2] 35 | print>>fid, -R[2, 0], -R[2, 1], -R[2, 2] 36 | print>>fid, t[0], -t[1], -t[2] 37 | 38 | fid.close() 39 | fid_filenames.close() 40 | 41 | 42 | #------------------------------------------------------------------------------- 43 | 44 | if __name__ == "__main__": 45 | import argparse 46 | 47 | parser = argparse.ArgumentParser( 48 | description="Saves the camera positions in the Bundler format. 
Note " 49 | "that 3D points are not saved.", 50 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | 52 | parser.add_argument("input_folder") 53 | parser.add_argument("output_file") 54 | 55 | parser.add_argument("--sort", default=False, action="store_true", 56 | help="sort the images by their filename") 57 | 58 | args = parser.parse_args() 59 | 60 | main(args) 61 | -------------------------------------------------------------------------------- /conerf/pycolmap/tools/write_depthmap_to_ply.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import imageio 5 | import numpy as np 6 | import os 7 | 8 | from plyfile import PlyData, PlyElement 9 | from pycolmap import SceneManager 10 | from scipy.ndimage.interpolation import zoom 11 | 12 | 13 | #------------------------------------------------------------------------------- 14 | 15 | def main(args): 16 | suffix = ".photometric.bin" if args.photometric else ".geometric.bin" 17 | 18 | image_file = os.path.join(args.dense_folder, "images", args.image_filename) 19 | depth_file = os.path.join( 20 | args.dense_folder, args.stereo_folder, "depth_maps", 21 | args.image_filename + suffix) 22 | if args.save_normals: 23 | normals_file = os.path.join( 24 | args.dense_folder, args.stereo_folder, "normal_maps", 25 | args.image_filename + suffix) 26 | 27 | # load camera intrinsics from the COLMAP reconstruction 28 | scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse")) 29 | scene_manager.load_cameras() 30 | scene_manager.load_images() 31 | 32 | image_id, image = scene_manager.get_image_from_name(args.image_filename) 33 | camera = scene_manager.cameras[image.camera_id] 34 | rotation_camera_from_world = image.R() 35 | camera_center = image.C() 36 | 37 | # load image, depth map, and normal map 38 | image = imageio.imread(image_file) 39 | 40 | with open(depth_file, "rb") as fid: 41 | w = int("".join(iter(lambda: fid.read(1), "&"))) 42 | h = int("".join(iter(lambda: fid.read(1), "&"))) 43 | c = int("".join(iter(lambda: fid.read(1), "&"))) 44 | depth_map = np.fromfile(fid, np.float32).reshape(h, w) 45 | if (h, w) != image.shape[:2]: 46 | depth_map = zoom( 47 | depth_map, 48 | (float(image.shape[0]) / h, float(image.shape[1]) / w), 49 | order=0) 50 | 51 | if args.save_normals: 52 | with open(normals_file, "rb") as fid: 53 | w = int("".join(iter(lambda: fid.read(1), "&"))) 54 | h = int("".join(iter(lambda: fid.read(1), "&"))) 55 | c = int("".join(iter(lambda: fid.read(1), "&"))) 56 | normals = np.fromfile( 57 | fid, np.float32).reshape(c, h, w).transpose([1, 2, 0]) 58 | if (h, w) != image.shape[:2]: 59 | normals = zoom( 60 | normals, 61 | (float(image.shape[0]) / h, float(image.shape[1]) / w, 1.), 62 | order=0) 63 | 64 | if args.min_depth is not None: 65 | depth_map[depth_map < args.min_depth] = 0. 66 | if args.max_depth is not None: 67 | depth_map[depth_map > args.max_depth] = 0. 68 | 69 | # create 3D points 70 | #depth_map = np.minimum(depth_map, 100.) 
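    # Added note: the two lines below back-project every pixel to camera space,
    # X = d * x_n, Y = d * y_n, Z = d, assuming `camera.get_image_grid()` returns
    # intrinsics-normalised coordinates (x_n, y_n) = ((u - cx) / fx, (v - cy) / fy).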
71 | points3D = np.dstack(camera.get_image_grid() + [depth_map]) 72 | points3D[:,:,:2] *= depth_map[:,:,np.newaxis] 73 | 74 | # save 75 | points3D = points3D.astype(np.float32).reshape(-1, 3) 76 | if args.save_normals: 77 | normals = normals.astype(np.float32).reshape(-1, 3) 78 | image = image.reshape(-1, 3) 79 | if image.dtype != np.uint8: 80 | if image.max() <= 1: 81 | image = (image * 255.).astype(np.uint8) 82 | else: 83 | image = image.astype(np.uint8) 84 | 85 | if args.world_space: 86 | points3D = points3D.dot(rotation_camera_from_world) + camera_center 87 | if args.save_normals: 88 | normals = normals.dot(rotation_camera_from_world) 89 | 90 | if args.save_normals: 91 | vertices = np.rec.fromarrays( 92 | tuple(points3D.T) + tuple(normals.T) + tuple(image.T), 93 | names="x,y,z,nx,ny,nz,red,green,blue") 94 | else: 95 | vertices = np.rec.fromarrays( 96 | tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue") 97 | vertices = PlyElement.describe(vertices, "vertex") 98 | PlyData([vertices]).write(args.output_filename) 99 | 100 | 101 | #------------------------------------------------------------------------------- 102 | 103 | if __name__ == "__main__": 104 | import argparse 105 | 106 | parser = argparse.ArgumentParser( 107 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 108 | 109 | parser.add_argument("dense_folder", type=str) 110 | parser.add_argument("image_filename", type=str) 111 | parser.add_argument("output_filename", type=str) 112 | 113 | parser.add_argument( 114 | "--photometric", default=False, action="store_true", 115 | help="use photometric depthmap instead of geometric") 116 | 117 | parser.add_argument( 118 | "--world_space", default=False, action="store_true", 119 | help="apply the camera->world extrinsic transformation to the result") 120 | 121 | parser.add_argument( 122 | "--save_normals", default=False, action="store_true", 123 | help="load the estimated normal map and save as part of the PLY") 124 | 125 | parser.add_argument( 126 | "--stereo_folder", type=str, default="stereo", 127 | help="folder in the dense workspace containing depth and normal maps") 128 | 129 | parser.add_argument( 130 | "--min_depth", type=float, default=None, 131 | help="set pixels with depth less than this value to zero depth") 132 | 133 | parser.add_argument( 134 | "--max_depth", type=float, default=None, 135 | help="set pixels with depth greater than this value to zero depth") 136 | 137 | args = parser.parse_args() 138 | 139 | main(args) 140 | -------------------------------------------------------------------------------- /conerf/utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from omegaconf import OmegaConf 3 | 4 | 5 | def strs2ints(strs): 6 | strs = strs.split(',') 7 | ints = [] 8 | for num in strs: 9 | ints.append(int(num)) 10 | print(f'ints: {ints}') 11 | return ints 12 | 13 | 14 | def calc_milestones(max_step, muls, divs): 15 | # muls, divs = strs2ints(muls), strs2ints(divs) 16 | milestones = "[" 17 | for mul, div in zip(muls, divs): 18 | milestones += str(max_step * mul // div) 19 | milestones += "," 20 | real_milestones = milestones[:-1] 21 | real_milestones += "]" 22 | return real_milestones 23 | 24 | 25 | OmegaConf.register_new_resolver( 26 | 'calc_exp_lr_decay_rate', 27 | lambda factor, n: factor**(1./n) 28 | ) 29 | OmegaConf.register_new_resolver('add', lambda a, b: a + b) 30 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b) 31 | OmegaConf.register_new_resolver('mul', lambda a, b: a 
* b) 32 | OmegaConf.register_new_resolver('divi', lambda a, b: a // b) 33 | OmegaConf.register_new_resolver( 34 | 'calc_milestones', 35 | lambda max_step, muls, divs: calc_milestones(max_step, muls, divs) # pylint: disable=W0108 36 | ) 37 | 38 | 39 | def config_parser(): 40 | parser = argparse.ArgumentParser() 41 | 42 | ##################################### Base configs ######################################## 43 | parser.add_argument("--config", 44 | type=str, 45 | default="", 46 | help="absolute path of config file") 47 | parser.add_argument("--suffix", 48 | type=str, 49 | default="", 50 | help="suffix for training folder") 51 | parser.add_argument("--scene", 52 | type=str, 53 | default="", 54 | help="name for the trained scene") 55 | parser.add_argument("--expname", 56 | type=str, 57 | default="", 58 | help="experiment name") 59 | parser.add_argument("--model_folder", 60 | type=str, 61 | default="sparse", # ['sparse', 'zero_gs'] 62 | help="folder that contain colmap model output") 63 | parser.add_argument("--init_ply_type", 64 | type=str, 65 | default="sparse", # ['sparse', 'dense'] 66 | help="use dense or sparse point cloud to initialize 3DGS") 67 | parser.add_argument("--load_specified_images", 68 | action="store_true", 69 | help="Only load the specified images to train.") 70 | 71 | ##################################### Block Training ######################################## 72 | parser.add_argument("--block_id", 73 | type=int, 74 | default=0, 75 | help="block id") 76 | parser.add_argument("--block_data_path", 77 | type=str, 78 | default="", 79 | help="directory that stores the block data") 80 | parser.add_argument("--train_local", 81 | action="store_true", 82 | help="train local blocks") 83 | 84 | ##################################### registration ######################################## 85 | parser.add_argument("--position_embedding_type", 86 | type=str, 87 | default="sine", 88 | help="which kind of positional embedding to use in transformer") 89 | parser.add_argument("--position_embedding_dim", 90 | type=int, 91 | default=256, 92 | help="dimensionality of position embeddings") 93 | parser.add_argument("--position_embedding_scaling", 94 | type=float, 95 | default=1.0, 96 | help="position embedding scale factor") 97 | parser.add_argument("--num_downsample", 98 | type=int, 99 | default=6, 100 | help="how many layers used to downsample points") 101 | parser.add_argument("--robust_loss", 102 | action="store_true", 103 | help="whether to use robust loss function") 104 | 105 | #################################### composite inr blocks ################################# 106 | parser.add_argument("--enable_composite", 107 | action="store_true", 108 | help="whether to composite implicit neural representation blocks.") 109 | 110 | args = parser.parse_args() 111 | 112 | return args 113 | 114 | 115 | def load_config(*yaml_files, cli_args=[]): 116 | yaml_confs = [OmegaConf.load(f) for f in yaml_files] 117 | cli_conf = OmegaConf.from_cli(cli_args) 118 | conf = OmegaConf.merge(*yaml_confs, cli_conf) 119 | OmegaConf.resolve(conf) 120 | 121 | return conf 122 | -------------------------------------------------------------------------------- /conerf/visualization/feature_visualizer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | 3 | import math 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | 9 | import numpy as np 10 | import cv2 11 | 12 | 13 | def plot_feature_map(writer, global_step, 
ray_sampler, feat_maps, prefix=''): 14 | coarse_feat_map = ray_sampler.target_feat_map[0].transpose(0, 1) 15 | feat_map_grid = torchvision.utils.make_grid( 16 | coarse_feat_map, normalize=True, scale_each=True, nrow=8) 17 | writer.add_image(prefix + 'target_feat_map', feat_map_grid, global_step) 18 | 19 | num_nearby_views = feat_maps[0].shape[0] 20 | for i in range(num_nearby_views): 21 | feat_map = feat_maps[0][i].unsqueeze(0).transpose(0, 1) 22 | # print(f'[DEBUG] feat_map shape: {feat_map}') 23 | feat_map_grid = torchvision.utils.make_grid( 24 | feat_map, normalize=True, scale_each=True, nrow=8) 25 | writer.add_image( 26 | prefix + f'nearby_feat_map-{i}', feat_map_grid, global_step) 27 | 28 | 29 | def feature_map_to_heatmap(feat_maps): 30 | ''' 31 | feat_maps: [C, H, W] 32 | ''' 33 | # Define a transform to convert the image to tensor 34 | transform = transforms.ToTensor() 35 | 36 | num_channels = feat_maps.shape[0] 37 | heat_maps = [] 38 | 39 | for i in range(num_channels): 40 | feat_map = np.asarray(feat_maps)[i] 41 | # print('feat_map.shape:', feat_map.shape) # [H, W] 42 | # print('feat_map type:', feat_map.dtype) # float32 43 | 44 | feat_map = np.asarray(feat_map * 255, dtype=np.uint8) # [0,255] 45 | # print('feat_map type:', feat_map.dtype) # uint8 46 | 47 | # https://www.sohu.com/a/343215045_120197868 48 | feat_map = cv2.applyColorMap(feat_map, cv2.COLORMAP_RAINBOW) 49 | feat_map = transform(feat_map) 50 | 51 | heat_maps.append(feat_map) 52 | 53 | heat_maps = torch.stack(heat_maps, dim=0) # [C, 3, 25, 25] 54 | return heat_maps 55 | 56 | 57 | def feature_maps_to_heatmap(feat_maps): 58 | ''' 59 | Args: 60 | feat_maps: [C, H, W] 61 | Return: 62 | A composed heat map with shape [H, W] 63 | ''' 64 | # Define a transform to convert the image to tensor 65 | transform = transforms.ToTensor() 66 | 67 | # print(f'[DEBUG] feat_maps shape: {feat_maps.shape}') 68 | [c, h, w] = feat_maps.shape 69 | 70 | heatmap = torch.zeros((h, w)) 71 | weight = [] 72 | feat_maps = np.asarray(feat_maps) 73 | 74 | for i in range(c): 75 | feat_map = feat_maps[i] 76 | weight = np.mean(feat_map) 77 | heatmap[:, :] += weight * feat_map 78 | 79 | heatmap = (heatmap - heatmap.min()) / heatmap.max() # normalization 80 | 81 | heatmap = np.asarray(heatmap * 255, dtype=np.uint8) 82 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_RAINBOW) 83 | heatmap = transform(heatmap) 84 | 85 | return heatmap 86 | 87 | 88 | def plot_sampled_feature_map(writer, global_step, target_rgb_feat, rgb_feats, N_rand, prefix='train/'): 89 | width = int(math.sqrt(N_rand)) 90 | target_rgb_feat = target_rgb_feat.detach().cpu() 91 | rgb_feats = rgb_feats.detach().cpu() 92 | 93 | target_rgb_feat = target_rgb_feat.permute( 94 | 3, 2, 0, 1).reshape(35, -1, width, width) 95 | rgb_feats = rgb_feats.permute(3, 2, 0, 1).reshape(35, -1, width, width) 96 | res_rgb_feats = torch.abs(target_rgb_feat - rgb_feats) 97 | target_rgb_feat = target_rgb_feat[:, 0, ...] 
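    # Added note: the 35 channels handled here appear to be RGB (channels 0..2)
    # concatenated with a 32-D feature descriptor (channels 3..34); the code below
    # plots the two parts as an RGB map and a composed heat map respectively.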
98 | 99 | # target_feat_map = feature_map_to_heatmap(target_rgb_feat[3:]) 100 | # feat_map_grid = torchvision.utils.make_grid(target_feat_map, normalize=True, scale_each=True, nrow=8) 101 | # writer.add_image(prefix + f'target_feat_map', feat_map_grid, global_step) # feature map 102 | target_feat_map = feature_maps_to_heatmap(target_rgb_feat[3:]) 103 | writer.add_image(prefix + f'target_feat_map', 104 | target_feat_map, global_step) # feature map 105 | writer.add_image(prefix + f'target_rgb_map', 106 | target_rgb_feat[0:3], global_step) 107 | 108 | num_nearby_views = rgb_feats.shape[1] 109 | nearby_feat_maps, nearby_rgb_maps = [], [] 110 | nearby_res_feat_maps, nearby_res_rgb_maps = [], [] 111 | for i in range(num_nearby_views): 112 | rgb_feat_map = rgb_feats[:, i, ...] 113 | feat_map = feature_maps_to_heatmap(rgb_feat_map[3:]) 114 | nearby_feat_maps.append(feat_map) 115 | nearby_rgb_maps.append(rgb_feat_map[0:3]) 116 | 117 | res_rgb_feat_map = res_rgb_feats[:, i, ...] 118 | res_feat_map = feature_maps_to_heatmap(res_rgb_feat_map[3:]) 119 | nearby_res_feat_maps.append(res_feat_map) 120 | nearby_res_rgb_maps.append(res_rgb_feat_map[0:3]) 121 | 122 | nearby_feat_maps = torch.stack(nearby_feat_maps, dim=0) 123 | nearby_feat_grid = torchvision.utils.make_grid( 124 | nearby_feat_maps, normalize=True, scale_each=True, nrow=5) 125 | writer.add_image(prefix + f'nearby_feat_maps', 126 | nearby_feat_grid, global_step) 127 | 128 | nearby_rgb_maps = torch.stack(nearby_rgb_maps, dim=0) # [n_views, 3, h, w] 129 | nearby_rgb_grid = torchvision.utils.make_grid( 130 | nearby_rgb_maps, normalize=True, scale_each=True, nrow=5) 131 | writer.add_image(prefix + f'nearby_rgb_maps', nearby_rgb_grid, global_step) 132 | 133 | nearby_res_feat_maps = torch.stack(nearby_res_feat_maps, dim=0) 134 | nearby_res_feat_grid = torchvision.utils.make_grid( 135 | nearby_res_feat_maps, normalize=True, scale_each=True, nrow=5) 136 | writer.add_image(prefix + f'nearby_res_feat_maps', 137 | nearby_res_feat_grid, global_step) 138 | 139 | nearby_res_rgb_maps = torch.stack( 140 | nearby_res_rgb_maps, dim=0) # [n_views, 3, h, w] 141 | nearby_res_rgb_grid = torchvision.utils.make_grid( 142 | nearby_res_rgb_maps, normalize=True, scale_each=True, nrow=5) 143 | writer.add_image(prefix + f'nearby_res_rgb_maps', 144 | nearby_res_rgb_grid, global_step) 145 | -------------------------------------------------------------------------------- /conerf/visualization/pose_visualizer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | 3 | from typing import List 4 | 5 | import torch 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from easydict import EasyDict as edict 9 | 10 | 11 | def to_hom(X): 12 | # get homogeneous coordinates of the input 13 | X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1) 14 | 15 | return X_hom 16 | 17 | 18 | def get_camera_mesh(pose, depth=1): 19 | vertices = torch.tensor([[-0.5, -0.5, 1], 20 | [0.5, -0.5, 1], 21 | [0.5, 0.5, 1], 22 | [-0.5, 0.5, 1], 23 | [0, 0, 0]]) * depth 24 | 25 | faces = torch.tensor([[0, 1, 2], 26 | [0, 2, 3], 27 | [0, 1, 4], 28 | [1, 2, 4], 29 | [2, 3, 4], 30 | [3, 0, 4]]) 31 | 32 | # vertices = camera.cam2world(vertices[None], pose) 33 | vertices = to_hom(vertices[None]) @ pose.transpose(-1, -2) 34 | 35 | wire_frame = vertices[:, [0, 1, 2, 3, 0, 4, 1, 2, 4, 3]] 36 | 37 | return vertices, faces, wire_frame 38 | 39 | 40 | def merge_wire_frames(wire_frame): 41 | wire_frame_merged = [[], [], []] 42 | for w in 
wire_frame: 43 | wire_frame_merged[0] += [float(n) for n in w[:, 0]] + [None] 44 | wire_frame_merged[1] += [float(n) for n in w[:, 1]] + [None] 45 | wire_frame_merged[2] += [float(n) for n in w[:, 2]] + [None] 46 | 47 | return wire_frame_merged 48 | 49 | 50 | def merge_meshes(vertices, faces): 51 | mesh_N, vertex_N = vertices.shape[:2] 52 | faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)], dim=0) 53 | vertices_merged = vertices.view(-1, vertices.shape[-1]) 54 | 55 | return vertices_merged, faces_merged 56 | 57 | 58 | def merge_centers(centers): 59 | center_merged = [[], [], []] 60 | 61 | for c1, c2 in zip(*centers): 62 | center_merged[0] += [float(c1[0]), float(c2[0]), None] 63 | center_merged[1] += [float(c1[1]), float(c2[1]), None] 64 | center_merged[2] += [float(c1[2]), float(c2[2]), None] 65 | 66 | return center_merged 67 | 68 | 69 | @torch.no_grad() 70 | def visualize_cameras( 71 | vis, 72 | step: int = 0, 73 | poses: List = [], 74 | cam_depth: float = 0.5, 75 | colors: List = ["blue", "magenta"], 76 | plot_dist: bool = True 77 | ): 78 | win_name = "gt_pred" 79 | data = [] 80 | 81 | # set up plots 82 | centers = [] 83 | for pose, color in zip(poses, colors): 84 | pose = pose.detach().cpu() 85 | vertices, faces, wire_frame = get_camera_mesh(pose, depth=cam_depth) 86 | center = vertices[:, -1] 87 | centers.append(center) 88 | 89 | # camera centers 90 | data.append(dict( 91 | type="scatter3d", 92 | x=[float(n) for n in center[:, 0]], 93 | y=[float(n) for n in center[:, 1]], 94 | z=[float(n) for n in center[:, 2]], 95 | mode="markers", 96 | marker=dict(color=color, size=3), 97 | )) 98 | 99 | # colored camera mesh 100 | vertices_merged, faces_merged = merge_meshes(vertices, faces) 101 | 102 | data.append(dict( 103 | type="mesh3d", 104 | x=[float(n) for n in vertices_merged[:, 0]], 105 | y=[float(n) for n in vertices_merged[:, 1]], 106 | z=[float(n) for n in vertices_merged[:, 2]], 107 | i=[int(n) for n in faces_merged[:, 0]], 108 | j=[int(n) for n in faces_merged[:, 1]], 109 | k=[int(n) for n in faces_merged[:, 2]], 110 | flatshading=True, 111 | color=color, 112 | opacity=0.05, 113 | )) 114 | 115 | # camera wire_frame 116 | wire_frame_merged = merge_wire_frames(wire_frame) 117 | data.append(dict( 118 | type="scatter3d", 119 | x=wire_frame_merged[0], 120 | y=wire_frame_merged[1], 121 | z=wire_frame_merged[2], 122 | mode="lines", 123 | line=dict(color=color,), 124 | opacity=0.3, 125 | )) 126 | 127 | if plot_dist: 128 | # distance between two poses (camera centers) 129 | center_merged = merge_centers(centers[:2]) 130 | data.append(dict( 131 | type="scatter3d", 132 | x=center_merged[0], 133 | y=center_merged[1], 134 | z=center_merged[2], 135 | mode="lines", 136 | line=dict(color="red", width=4,), 137 | )) 138 | 139 | if len(centers) == 4: 140 | center_merged = merge_centers(centers[2:4]) 141 | data.append(dict( 142 | type="scatter3d", 143 | x=center_merged[0], 144 | y=center_merged[1], 145 | z=center_merged[2], 146 | mode="lines", 147 | line=dict(color="red", width=4,), 148 | )) 149 | 150 | # send data to visdom 151 | vis._send(dict( 152 | data=data, 153 | win="poses", 154 | eid=win_name, 155 | layout=dict( 156 | title=f"({step})", 157 | autosize=True, 158 | margin=dict(l=30, r=30, b=30, t=30,), 159 | showlegend=False, 160 | yaxis=dict( 161 | scaleanchor="x", 162 | scaleratio=1, 163 | ) 164 | ), 165 | opts=dict(title=f"{win_name} poses ({step})",), 166 | )) 167 | 168 | 169 | def plot_save_poses( 170 | cam_depth: float, 171 | fig, 172 | pose: torch.Tensor, 173 | pose_ref: 
torch.Tensor = None, 174 | path: str = None, 175 | ep=None, 176 | axis_len: float = 1.0, 177 | ): 178 | # get the camera meshes 179 | _, _, cam = get_camera_mesh(pose, depth=cam_depth) 180 | cam = cam.numpy() 181 | 182 | if pose_ref is not None: 183 | _, _, cam_ref = get_camera_mesh(pose_ref, depth=cam_depth) 184 | cam_ref = cam_ref.numpy() 185 | 186 | # set up plot window(s) 187 | plt.title(f"epoch {ep}") 188 | ax1 = fig.add_subplot(121, projection="3d") 189 | ax2 = fig.add_subplot(122, projection="3d") 190 | setup_3D_plot( 191 | ax1, elev=-90, azim=-90, 192 | lim=edict(x=(-axis_len, axis_len), y=(-axis_len, 193 | axis_len), z=(-axis_len, axis_len)) 194 | ) 195 | setup_3D_plot( 196 | ax2, elev=0, azim=-90, 197 | lim=edict(x=(-axis_len, axis_len), y=(-axis_len, 198 | axis_len), z=(-axis_len, axis_len)) 199 | ) 200 | ax1.set_title("forward-facing view", pad=0) 201 | ax2.set_title("top-down view", pad=0) 202 | plt.subplots_adjust(left=0, right=1, bottom=0, 203 | top=0.95, wspace=0, hspace=0) 204 | plt.margins(tight=True, x=0, y=0) 205 | 206 | # plot the cameras 207 | N = len(cam) 208 | color = plt.get_cmap("gist_rainbow") 209 | for i in range(N): 210 | if pose_ref is not None: 211 | ax1.plot(cam_ref[i, :, 0], cam_ref[i, :, 1], 212 | cam_ref[i, :, 2], color=(0.1, 0.1, 0.1), linewidth=1) 213 | ax2.plot(cam_ref[i, :, 0], cam_ref[i, :, 1], 214 | cam_ref[i, :, 2], color=(0.1, 0.1, 0.1), linewidth=1) 215 | ax1.scatter(cam_ref[i, 5, 0], cam_ref[i, 5, 1], 216 | cam_ref[i, 5, 2], color=(0.1, 0.1, 0.1), s=40) 217 | ax2.scatter(cam_ref[i, 5, 0], cam_ref[i, 5, 1], 218 | cam_ref[i, 5, 2], color=(0.1, 0.1, 0.1), s=40) 219 | c = np.array(color(float(i) / N)) * 0.8 220 | ax1.plot(cam[i, :, 0], cam[i, :, 1], cam[i, :, 2], color=c) 221 | ax2.plot(cam[i, :, 0], cam[i, :, 1], cam[i, :, 2], color=c) 222 | ax1.scatter(cam[i, 5, 0], cam[i, 5, 1], cam[i, 5, 2], color=c, s=40) 223 | ax2.scatter(cam[i, 5, 0], cam[i, 5, 1], cam[i, 5, 2], color=c, s=40) 224 | 225 | png_fname = f"{path}/{ep}.png" 226 | plt.savefig(png_fname, dpi=75) 227 | # clean up 228 | plt.clf() 229 | 230 | 231 | def setup_3D_plot(ax, elev, azim, lim=None): 232 | ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 233 | ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 234 | ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 235 | ax.xaxis._axinfo["grid"]["color"] = (0.9, 0.9, 0.9, 1) 236 | ax.yaxis._axinfo["grid"]["color"] = (0.9, 0.9, 0.9, 1) 237 | ax.zaxis._axinfo["grid"]["color"] = (0.9, 0.9, 0.9, 1) 238 | ax.xaxis.set_tick_params(labelsize=8) 239 | ax.yaxis.set_tick_params(labelsize=8) 240 | ax.zaxis.set_tick_params(labelsize=8) 241 | ax.set_xlabel("X", fontsize=16) 242 | ax.set_ylabel("Y", fontsize=16) 243 | ax.set_zlabel("Z", fontsize=16) 244 | ax.set_xlim(lim.x[0], lim.x[1]) 245 | ax.set_ylim(lim.y[0], lim.y[1]) 246 | ax.set_zlim(lim.z[0], lim.z[1]) 247 | ax.view_init(elev=elev, azim=azim) 248 | -------------------------------------------------------------------------------- /config/ace/llff.yaml: -------------------------------------------------------------------------------- 1 | neural_field_type: mlp 2 | expname: ${neural_field_type}_${task}_${dataset.name}_${dataset.scene} 3 | task: pose 4 | seed: 42 5 | 6 | dataset: 7 | name: llff 8 | root_dir: # eg.: /home/user/datasets/${dataset.name} 9 | encoder_path: # eg: /home/user/Projects/ZeroGS/conerf/model/scene_regressor/ace_encoder_pretrained.pt 10 | scene: ['fern', 'flower', 'fortress', 'horns', 'leaves', 'orchids', 'room', 'trex'] 11 | image_resolution: 12 | scale: true 13 | rotate: false 14 | use_aug: 
true 15 | aug_rotation: 15 16 | aug_scale: 1.5 17 | factor: 4 18 | val_interval: -1 19 | apply_mask: false 20 | cam_depth: 0.2 21 | axis_len: 1.7 22 | 23 | trainer: 24 | epochs: 16 25 | max_patch_loops_per_epoch: 10 26 | samples_per_image: 1024 27 | training_buffer_size: 8000000 28 | batch_size: 5120 29 | min_iterations_per_epoch: 5000 30 | max_iterations_per_epoch: 10000 31 | early_stop_thresh: 6 32 | use_half: true 33 | ckpt_path: "" 34 | no_load_opt: false 35 | no_load_scheduler: false 36 | enable_tensorboard: true 37 | enable_visdom: false 38 | visdom_server: localhost 39 | visdom_port: 9002 40 | n_tensorboard: 100 41 | n_validation: 5000 42 | n_checkpoint: 1000 43 | distributed: false 44 | excluded_gpus: [] 45 | num_workers: 4 46 | local_rank: 0 47 | 48 | optimizer: 49 | lr_sc_min: 0.0005 # lowest learning rate of 1 cycle scheduler 50 | lr_sc_max: 0.003 # highest learning rate of 1 cycle scheduler 51 | lr_pr: 1e-3 # learning rate for the pose refiner 52 | lr_cr: 1e-3 # learning rate for the calibration refiner 53 | 54 | regressor: 55 | # ZoeD_N is fine-tuned for metric depth on NYU Depth v2 for relative depth estimation, 56 | # ZoeD_K is fine-tuned for metric depth on KITTI for relative depth estimation. 57 | # ZoeD_NK has two separate heads fine-tuned on both NYU Depth v2 and KITTI. 58 | 59 | # [ZoeDepth, metric3d] 60 | depth_net_method: ZoeDepth 61 | # ZoeDepth: [ZoeD_N, ZoeD_K, ZoeD_NK]; metric3d: [metric3d_vit_small, metric3d_vit_large, metric3d_vit_giant2] 62 | depth_net_type: ZoeD_NK 63 | num_seed_image_trials: 5 64 | num_reloc_images_max: 1000 # the number of relocalization test during seed reconstruction. 65 | num_head_blocks: 1 # The depth of the head network. 66 | use_homogeneous: true 67 | depth_min: 0.1 68 | depth_max: 1000 # [ZoeDepth: 1000; metric3d: 200] 69 | depth_target: 10 70 | 71 | pose_estimator: 72 | reproj_thresh: 10 # inlier threshold in pixels (RGB) or centimeters (RGB-D) 73 | hypotheses: 64 # number of hypotheses, i.e. number of RANSAC iterations. 74 | inlier_alpha: 100 # alpha parameter of the soft inlier count. 75 | max_pixel_error: 100 # maximum reprojection (RGB, in px) or 3D distance (RGB-D, in cm) error when checking pose consistency. 
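    # Added note: an image is accepted into the reconstruction only when its estimated
    # pose is supported by at least this many inlier correspondences; lowering the
    # value registers more images at the risk of keeping poorly localised poses.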
76 | min_inlier_count: 2000 # minimum number of inlier correspondences when registering an image 77 | 78 | loss: 79 | repro_loss_hard_clamp: 1000 80 | repro_loss_soft_clamp: 50 81 | repro_loss_soft_clamp_min: 1 82 | repro_loss_type: tanh # dyntanh 83 | repro_loss_scheduler: circle 84 | -------------------------------------------------------------------------------- /config/ace/mipnerf360.yaml: -------------------------------------------------------------------------------- 1 | neural_field_type: mlp 2 | expname: ${neural_field_type}_${task}_${dataset.name}_${dataset.scene} 3 | task: pose 4 | seed: 42 5 | 6 | dataset: 7 | name: mipnerf360 8 | root_dir: # eg.: /home/user/datasets/${dataset.name} 9 | encoder_path: # eg: /home/user/Projects/ZeroGS/conerf/model/scene_regressor/ace_encoder_pretrained.pt 10 | scene: ["bicycle", "bonsai", "counter", "garden", "kitchen", "room", "stump", "flowers", "treehill"] 11 | image_resolution: 12 | scale: true 13 | rotate: false 14 | use_aug: true 15 | aug_rotation: 15 16 | aug_scale: 1.5 17 | factor: 4 18 | val_interval: -1 19 | apply_mask: false 20 | cam_depth: 0.1 21 | axis_len: 1.0 22 | 23 | trainer: 24 | epochs: 16 25 | max_patch_loops_per_epoch: 10 26 | samples_per_image: 1024 27 | training_buffer_size: 8000000 28 | batch_size: 5120 29 | min_iterations_per_epoch: 5000 30 | max_iterations_per_epoch: 10000 31 | early_stop_thresh: 6 32 | use_half: true 33 | ckpt_path: "" 34 | no_load_opt: false 35 | no_load_scheduler: false 36 | enable_tensorboard: true 37 | enable_visdom: false 38 | visdom_server: localhost 39 | visdom_port: 9002 40 | n_tensorboard: 100 41 | n_validation: 5000 42 | n_checkpoint: 1000 43 | distributed: false 44 | excluded_gpus: [] 45 | num_workers: 4 46 | local_rank: 0 47 | 48 | optimizer: 49 | lr_sc_min: 0.0005 # lowest learning rate of 1 cycle scheduler 50 | lr_sc_max: 0.003 # highest learning rate of 1 cycle scheduler 51 | lr_pr: 1e-3 # learning rate for the pose refiner 52 | lr_cr: 1e-3 # learning rate for the calibration refiner 53 | 54 | regressor: 55 | # ZoeD_N is fine-tuned for metric depth on NYU Depth v2 for relative depth estimation, 56 | # ZoeD_K is fine-tuned for metric depth on KITTI for relative depth estimation. 57 | # ZoeD_NK has two separate heads fine-tuned on both NYU Depth v2 and KITTI. 58 | 59 | # [ZoeDepth, metric3d] 60 | depth_net_method: ZoeDepth 61 | # ZoeDepth: [ZoeD_N, ZoeD_K, ZoeD_NK]; metric3d: [metric3d_vit_small, metric3d_vit_large, metric3d_vit_giant2] 62 | depth_net_type: ZoeD_NK 63 | num_seed_image_trials: 5 64 | num_reloc_images_max: 1000 # the number of relocalization test during seed reconstruction. 65 | num_head_blocks: 1 # The depth of the head network. 66 | use_homogeneous: true 67 | depth_min: 0.1 68 | depth_max: 1000 # [ZoeDepth: 1000; metric3d: 200] 69 | depth_target: 10 70 | 71 | pose_estimator: 72 | reproj_thresh: 10 # inlier threshold in pixels (RGB) or centimeters (RGB-D) 73 | hypotheses: 64 # number of hypotheses, i.e. number of RANSAC iterations. 74 | inlier_alpha: 100 # alpha parameter of the soft inlier count. 75 | max_pixel_error: 100 # maximum reprojection (RGB, in px) or 3D distance (RGB-D, in cm) error when checking pose consistency. 
76 | min_inlier_count: 2000 # minimum number of inlier correspondences when registering an image 77 | 78 | loss: 79 | repro_loss_hard_clamp: 1000 80 | repro_loss_soft_clamp: 50 81 | repro_loss_soft_clamp_min: 1 82 | repro_loss_type: tanh # dyntanh 83 | repro_loss_scheduler: circle 84 | -------------------------------------------------------------------------------- /config/ace/tanks_and_temples.yaml: -------------------------------------------------------------------------------- 1 | neural_field_type: mlp 2 | expname: ${neural_field_type}_${task}_${dataset.name}_${dataset.scene} 3 | task: pose 4 | seed: 42 5 | 6 | dataset: 7 | name: tanks_and_temples 8 | root_dir: # eg.: /home/user/datasets/${dataset.name} 9 | encoder_path: # eg: /home/user/Projects/ZeroGS/conerf/model/scene_regressor/ace_encoder_pretrained.pt 10 | scene: ["Family", "Francis", "Ignatius", "Train", "Truck", "Playground"] 11 | image_resolution: 12 | scale: true 13 | rotate: false 14 | use_aug: true 15 | aug_rotation: 15 16 | aug_scale: 1.5 17 | factor: 2 18 | val_interval: -1 19 | apply_mask: false 20 | cam_depth: 0.1 21 | axis_len: 1.0 22 | 23 | trainer: 24 | epochs: 16 25 | max_patch_loops_per_epoch: 10 26 | samples_per_image: 1024 27 | training_buffer_size: 8000000 28 | batch_size: 5120 29 | min_iterations_per_epoch: 5000 30 | max_iterations_per_epoch: 10000 31 | early_stop_thresh: 6 32 | use_half: true 33 | ckpt_path: "" 34 | no_load_opt: false 35 | no_load_scheduler: false 36 | enable_tensorboard: true 37 | enable_visdom: false 38 | visdom_server: localhost 39 | visdom_port: 9002 40 | n_tensorboard: 100 41 | n_validation: 5000 42 | n_checkpoint: 1000 43 | distributed: false 44 | excluded_gpus: [] 45 | num_workers: 4 46 | local_rank: 0 47 | 48 | optimizer: 49 | lr_sc_min: 0.0005 # lowest learning rate of 1 cycle scheduler 50 | lr_sc_max: 0.003 # highest learning rate of 1 cycle scheduler 51 | lr_pr: 1e-3 # learning rate for the pose refiner 52 | lr_cr: 1e-3 # learning rate for the calibration refiner 53 | 54 | regressor: 55 | # ZoeD_N is fine-tuned for metric depth on NYU Depth v2 for relative depth estimation, 56 | # ZoeD_K is fine-tuned for metric depth on KITTI for relative depth estimation. 57 | # ZoeD_NK has two separate heads fine-tuned on both NYU Depth v2 and KITTI. 58 | 59 | # [ZoeDepth, metric3d] 60 | depth_net_method: ZoeDepth 61 | # ZoeDepth: [ZoeD_N, ZoeD_K, ZoeD_NK]; metric3d: [metric3d_vit_small, metric3d_vit_large, metric3d_vit_giant2] 62 | depth_net_type: ZoeD_NK 63 | num_seed_image_trials: 5 64 | num_reloc_images_max: 1000 # the number of relocalization test during seed reconstruction. 65 | num_head_blocks: 1 # The depth of the head network. 66 | use_homogeneous: true 67 | depth_min: 0.1 68 | depth_max: 1000 # [ZoeDepth: 1000; metric3d: 200] 69 | depth_target: 10 70 | 71 | pose_estimator: 72 | reproj_thresh: 10 # inlier threshold in pixels (RGB) or centimeters (RGB-D) 73 | hypotheses: 64 # number of hypotheses, i.e. number of RANSAC iterations. 74 | inlier_alpha: 100 # alpha parameter of the soft inlier count. 75 | max_pixel_error: 100 # maximum reprojection (RGB, in px) or 3D distance (RGB-D, in cm) error when checking pose consistency. 
76 | min_inlier_count: 2200 # minimum number of inlier correspondences when registering an image 77 | 78 | loss: 79 | repro_loss_hard_clamp: 1000 80 | repro_loss_soft_clamp: 50 81 | repro_loss_soft_clamp_min: 1 82 | repro_loss_type: tanh # dyntanh 83 | repro_loss_scheduler: circle 84 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1101,W0621] 2 | 3 | import os 4 | import copy 5 | import json 6 | import warnings 7 | from typing import List 8 | 9 | import omegaconf 10 | from omegaconf import OmegaConf 11 | 12 | from conerf.evaluators.ace_zero_evaluator import AceZeroEvaluator 13 | from conerf.utils.utils import setup_seed 14 | 15 | warnings.filterwarnings("ignore", category=UserWarning) 16 | 17 | 18 | def create_evaluator( 19 | config: OmegaConf, 20 | load_train_data: bool = False, 21 | trainset=None, 22 | load_val_data: bool = True, 23 | valset=None, 24 | load_test_data: bool = False, 25 | testset = None, 26 | models: List = None, 27 | meta_data: List = None, 28 | verbose: bool = False, 29 | device: str = "cuda", 30 | ): 31 | """Factory function for training neural network trainers.""" 32 | if config.task == "pose": 33 | evaluator = AceZeroEvaluator( 34 | config, load_train_data, trainset, 35 | load_val_data, valset, load_test_data, 36 | testset, models, meta_data, verbose, device 37 | ) 38 | else: 39 | raise NotImplementedError 40 | 41 | return evaluator 42 | 43 | 44 | if __name__ == "__main__": 45 | from conerf.utils.config import config_parser, load_config 46 | args = config_parser() 47 | 48 | # parse YAML config to OmegaConf 49 | config = load_config(args.config) 50 | 51 | assert config.dataset.get("data_split_json", "") != "" or config.dataset.scene != "" 52 | 53 | setup_seed(config.seed) 54 | 55 | scenes = [] 56 | if config.dataset.get("data_split_json", "") != "" and config.dataset.scene == "": 57 | # For objaverse only. 
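        # Added note: the JSON file is assumed to map object ids to scene names,
        # e.g. {"000": "chair", "001": "lego"}; every name becomes one scene to evaluate.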
58 | with open(config.dataset.data_split_json, "r", encoding="utf-8") as fp: 59 | obj_id_to_name = json.load(fp) 60 | 61 | for idx, name in obj_id_to_name.items(): 62 | scenes.append(name) 63 | elif ( 64 | type(config.dataset.scene) == omegaconf.listconfig.ListConfig # pylint: disable=C0123 65 | ): # pylint: disable=C0123 66 | scene_list = list(config.dataset.scene) 67 | for sc in config.dataset.scene: 68 | scenes.append(sc) 69 | else: 70 | scenes.append(config.dataset.scene) 71 | 72 | for scene in scenes: 73 | data_dir = os.path.join(config.dataset.root_dir, scene) 74 | if not os.path.exists(data_dir): 75 | continue 76 | 77 | local_config = copy.deepcopy(config) 78 | local_config.expname = ( 79 | f"{config.neural_field_type}_{config.task}_{config.dataset.name}_{scene}" 80 | ) 81 | local_config.expname = local_config.expname + "_" + args.suffix 82 | local_config.dataset.scene = scene 83 | 84 | evaluator = create_evaluator( 85 | local_config, 86 | load_train_data=False, 87 | trainset=None, 88 | load_val_data=True, 89 | valset=None, 90 | load_test_data=True, 91 | testset=None, 92 | verbose=True, 93 | ) 94 | evaluator.eval(split="val") 95 | # evaluator.eval(split="test") 96 | evaluator.export_mesh() 97 | -------------------------------------------------------------------------------- /scripts/env/install.sh: -------------------------------------------------------------------------------- 1 | # conda create -n zero_gs python=3.9 2 | # conda activate zero_gs 3 | 4 | # install pytorch 5 | # Ref: https://pytorch.org/get-started/previous-versions/ 6 | 7 | # CUDA 11.7 8 | # conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.7 -c pytorch -c nvidia 9 | conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.8 -c pytorch -c nvidia 10 | conda install -c "nvidia/label/cuda-11.8.0" cuda 11 | 12 | # Basic packages. 13 | pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg easydict \ 14 | kornia lpips tensorboard visdom tensorboardX matplotlib plyfile trimesh h5py pandas \ 15 | omegaconf PyMCubes Ninja pyransac3d einops pyglet pre-commit pylint GPUtil \ 16 | open3d pyrender 17 | pip install timm==0.6.7 18 | pip install -U scikit-learn 19 | pip install git+https://github.com/jonbarron/robust_loss_pytorch 20 | pip install torch-geometric==2.4.0 21 | 22 | conda install pytorch3d -c pytorch3d 23 | conda install conda-forge::opencv 24 | conda install pytorch-scatter -c pyg 25 | conda remove ffmpeg --force 26 | 27 | # Third-parties. 28 | 29 | cd submodules/dsacstar 30 | python setup.py install 31 | 32 | cd ../../ 33 | pip install submodules/simple-knn 34 | pip install submodules/diff-gaussian-rasterization 35 | 36 | mkdir 3rd_party && cd 3rd_party 37 | 38 | git clone https://github.com/cvg/sfm-disambiguation-colmap.git 39 | cd sfm-disambiguation-colmap 40 | python -m pip install -e . 41 | cd .. 42 | 43 | # HLoc is used for extracting keypoints and matching features. 44 | git clone --recursive https://github.com/cvg/Hierarchical-Localization/ 45 | cd Hierarchical-Localization/ 46 | python -m pip install -e . 47 | cd .. 
48 | 49 | # Tiny-cuda-cnn & nerfacc 50 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 51 | 52 | # nerfacc 53 | # pip install nerfacc -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.1_cu117.html 54 | # or install the latest version 55 | # pip install git+https://github.com/KAIR-BAIR/nerfacc.git 56 | # To install a specified version: 57 | # pip install nerfacc==0.3.5 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.1_cu117.html 58 | pip install nerfacc==0.3.5 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-2.0.0_cu118.html 59 | 60 | # Install CURope 61 | cd croco/models/curope/ 62 | python setup.py build_ext --inplace 63 | -------------------------------------------------------------------------------- /scripts/eval/eval_ace_zero.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_IDS=$1 # {'0,1,2,...'} 4 | 5 | export PYTHONDONTWRITEBYTECODE=1 6 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS} 7 | 8 | # Default parameters. 9 | DATASET='blender' # [blender, mipnerf360, tanks_and_temples] 10 | ENCODING='ace' 11 | SUFFIX='' 12 | 13 | NUM_CMD_PARAMS=$# 14 | if [ $NUM_CMD_PARAMS -eq 2 ] 15 | then 16 | SUFFIX=$2 17 | elif [ $NUM_CMD_PARAMS -eq 3 ] 18 | then 19 | SUFFIX=$2 20 | DATASET=$3 21 | elif [ $NUM_CMD_PARAMS -eq 4 ] 22 | then 23 | SUFFIX=$2 24 | DATASET=$3 25 | ENCODING=$4 26 | fi 27 | 28 | YAML=${ENCODING}/${DATASET}'.yaml' 29 | echo "Using yaml file: ${YAML}" 30 | 31 | HOME_DIR=$HOME 32 | CODE_ROOT_DIR=$HOME/'Projects/ZeroGS' 33 | 34 | cd $CODE_ROOT_DIR 35 | 36 | python eval.py --config 'config/'${YAML} \ 37 | --suffix $SUFFIX 38 | -------------------------------------------------------------------------------- /scripts/eval/vis_recon.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import numpy as np 6 | import open3d as o3d 7 | 8 | from conerf.datasets.realworld import similarity_from_cameras, normalize_poses 9 | from conerf.datasets.utils import compute_bounding_box3D, points_in_bbox3D 10 | from conerf.pycolmap.pycolmap.scene_manager import SceneManager 11 | from conerf.visualization.scene_visualizer import visualize_single_scene 12 | 13 | 14 | def config_parser(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("--colmap_dir", 18 | type=str, 19 | default="", 20 | help="absolute path of config file") 21 | parser.add_argument("--output_dir", 22 | type=str, 23 | default="", 24 | help="absolute path of config file") 25 | 26 | args = parser.parse_args() 27 | 28 | return args 29 | 30 | 31 | if __name__ == '__main__': 32 | args = config_parser() 33 | rotate = False 34 | 35 | if not os.path.exists(args.output_dir): 36 | os.makedirs(args.output_dir) 37 | 38 | # (1) Loading camera poses and 3D points. 
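    # SceneManager parses the COLMAP binary model (cameras.bin / images.bin) only;
    # the 3D points are read separately from points3D.ply below, and each image's
    # world-to-camera matrix is assembled from its rotation R() and translation tvec
    # before being inverted to camera-to-world.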
39 | manager = SceneManager(args.colmap_dir, load_points=False) 40 | manager.load() 41 | 42 | ply_path = os.path.join(args.colmap_dir, "points3D.ply") 43 | pcd = o3d.io.read_point_cloud(ply_path) 44 | points = np.asarray(pcd.points) 45 | colors = np.asarray(pcd.colors) 46 | num_points = np.asarray(pcd.points).shape[0] 47 | print(f'num points: {num_points}') 48 | 49 | colmap_image_data = manager.images 50 | colmap_camera_data = manager.cameras 51 | 52 | w2c_mats = [] 53 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4) 54 | for k in colmap_image_data: 55 | im_data = colmap_image_data[k] 56 | w2c = np.concatenate([ 57 | np.concatenate( 58 | [im_data.R(), im_data.tvec.reshape(3, 1)], 1), bottom 59 | ], axis=0) 60 | w2c_mats.append(w2c) 61 | w2c_mats = np.stack(w2c_mats, axis=0) 62 | cam_to_world = np.linalg.inv(w2c_mats) 63 | 64 | # (2) Normalize the scene. 65 | T, scale = similarity_from_cameras( 66 | cam_to_world, strict_scaling=False 67 | ) 68 | cam_to_world = np.einsum("nij, ki -> nkj", cam_to_world, T) 69 | cam_to_world[:, :3, 3:4] *= scale 70 | 71 | points = scale * (T[:3, :3] @ points.T + T[:3, 3][..., None]).T # [Np, 3] 72 | 73 | # (3) Rotate the scene to align with ground plane. 74 | if rotate: 75 | down_pcd = pcd.voxel_down_sample(voxel_size=0.1) 76 | points_for_est_normal = np.asarray(down_pcd.points) 77 | print( 78 | f'num points for estimating normal: {points_for_est_normal.shape}') 79 | cam_to_world, _, R, t = normalize_poses( 80 | torch.from_numpy(cam_to_world).float(), # pylint: disable=E1101 81 | torch.from_numpy(points_for_est_normal).float( 82 | ), # pylint: disable=E1101 83 | up_est_method="ground", 84 | center_est_method="lookat", 85 | ) 86 | cam_to_world = cam_to_world.numpy() 87 | points[:, :] = (R @ points.T + t).T 88 | 89 | # (4) Compute bounding box to exclude points outside the bounding box. 90 | aabb = compute_bounding_box3D( 91 | torch.from_numpy(cam_to_world[..., :, -1]), # pylint: disable=E1101 92 | scale_factor=[7, 7, 7], # [4.0,4.0,4.0] 93 | ).numpy() 94 | valid_point_indices = points_in_bbox3D(points, aabb).reshape(-1) 95 | points = points[valid_point_indices] 96 | colors = colors[valid_point_indices] 97 | colors = np.clip(colors, 0, 1) 98 | print(f'num points: {points.shape[0]}') 99 | 100 | pcd.points = o3d.utility.Vector3dVector(points) 101 | pcd.colors = o3d.utility.Vector3dVector(colors) 102 | 103 | # (5) Downsample points if there are too many. 
104 | if num_points > 2000000: 105 | down_pcd = pcd.voxel_down_sample(voxel_size=0.005) 106 | points = np.asarray(down_pcd.points) 107 | colors = np.asarray(down_pcd.colors) 108 | print(f'points shape: {points.shape}') 109 | 110 | pcd.points = o3d.utility.Vector3dVector(points) 111 | pcd.colors = o3d.utility.Vector3dVector(colors) 112 | else: 113 | colors = np.asarray(pcd.colors) 114 | pcd.colors = o3d.utility.Vector3dVector(colors) 115 | 116 | visualize_single_scene( 117 | pcd, 118 | cam_to_world, 119 | size=0.05, 120 | rainbow_color=True, 121 | output_directory=args.output_dir 122 | ) 123 | 124 | video_filename = os.path.join(args.output_dir, "zero_gs_scene.mp4") 125 | os.system(f"ffmpeg -framerate 10 -i {args.output_dir}/screenshot_%05d.png -c:v libx264 " + 126 | f"-pix_fmt yuv420p {video_filename}" 127 | ) 128 | -------------------------------------------------------------------------------- /scripts/preprocess/colmap_mapping.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | DATASET_PATH=$1 5 | OUTPUT_PATH=$2 6 | VOC_TREE_PATH=$3 7 | MOST_SIMILAR_IMAGES_NUM=$4 8 | CUDA_IDS=$5 9 | 10 | NUM_THREADS=24 11 | # export PYTHONDONTWRITEBYTECODE=1 12 | # export CUDA_VISIBLE_DEVICES=${CUDA_IDS} 13 | 14 | COLMAP_DIR=/usr/local/bin 15 | COLMAP_EXE=$COLMAP_DIR/colmap 16 | 17 | mkdir $OUTPUT_PATH/sparse 18 | 19 | $COLMAP_EXE feature_extractor \ 20 | --database_path=$OUTPUT_PATH/database.db \ 21 | --image_path=$DATASET_PATH/images \ 22 | --SiftExtraction.num_threads=$NUM_THREADS \ 23 | --SiftExtraction.use_gpu=1 \ 24 | --SiftExtraction.gpu_index=$CUDA_IDS \ 25 | --SiftExtraction.estimate_affine_shape=true \ 26 | --SiftExtraction.domain_size_pooling=true \ 27 | --ImageReader.camera_model PINHOLE \ 28 | --ImageReader.single_camera 1 \ 29 | --SiftExtraction.max_num_features 8192 \ 30 | > $DATASET_PATH/log_extract_feature.txt 2>&1 31 | 32 | $COLMAP_EXE vocab_tree_matcher \ 33 | --database_path=$OUTPUT_PATH/database.db \ 34 | --SiftMatching.num_threads=$NUM_THREADS \ 35 | --SiftMatching.use_gpu=1 \ 36 | --SiftMatching.gpu_index=$CUDA_IDS \ 37 | --SiftMatching.guided_matching=false \ 38 | --VocabTreeMatching.num_images=$MOST_SIMILAR_IMAGES_NUM \ 39 | --VocabTreeMatching.num_nearest_neighbors=5 \ 40 | --VocabTreeMatching.vocab_tree_path=$VOC_TREE_PATH \ 41 | > $DATASET_PATH/log_match.txt 2>&1 42 | 43 | $COLMAP_EXE mapper $OUTPUT_PATH \ 44 | --database_path=$OUTPUT_PATH/database.db \ 45 | --image_path=$DATASET_PATH/images \ 46 | --output_path=$OUTPUT_PATH/sparse \ 47 | --Mapper.num_threads=$NUM_THREADS \ 48 | > $DATASET_PATH/log_sfm.txt 2>&1 49 | -------------------------------------------------------------------------------- /scripts/preprocess/database.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import sqlite3 3 | 4 | from typing import Dict 5 | 6 | import numpy as np 7 | 8 | IS_PYTHON3 = sys.version_info[0] >= 3 9 | 10 | #------------------------------------------------------------------------------- 11 | # create table commands 12 | 13 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( 14 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 15 | model INTEGER NOT NULL, 16 | width INTEGER NOT NULL, 17 | height INTEGER NOT NULL, 18 | params BLOB, 19 | prior_focal_length INTEGER NOT NULL)""" 20 | 21 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( 22 | image_id INTEGER PRIMARY KEY NOT NULL, 23 | rows INTEGER NOT NULL, 24 | cols INTEGER 
NOT NULL, 25 | data BLOB, 26 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" 27 | 28 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( 29 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 30 | name TEXT NOT NULL UNIQUE, 31 | camera_id INTEGER NOT NULL, 32 | prior_qw REAL, 33 | prior_qx REAL, 34 | prior_qy REAL, 35 | prior_qz REAL, 36 | prior_tx REAL, 37 | prior_ty REAL, 38 | prior_tz REAL, 39 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647), 40 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))""" 41 | 42 | CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries ( 43 | pair_id INTEGER PRIMARY KEY NOT NULL, 44 | rows INTEGER NOT NULL, 45 | cols INTEGER NOT NULL, 46 | data BLOB, 47 | config INTEGER NOT NULL, 48 | F BLOB, 49 | E BLOB, 50 | H BLOB)""" 51 | 52 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( 53 | image_id INTEGER PRIMARY KEY NOT NULL, 54 | rows INTEGER NOT NULL, 55 | cols INTEGER NOT NULL, 56 | data BLOB, 57 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" 58 | 59 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( 60 | pair_id INTEGER PRIMARY KEY NOT NULL, 61 | rows INTEGER NOT NULL, 62 | cols INTEGER NOT NULL, 63 | data BLOB)""" 64 | 65 | CREATE_NAME_INDEX = \ 66 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" 67 | 68 | CREATE_ALL = "; ".join([CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE, 69 | CREATE_IMAGES_TABLE, CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE, 70 | CREATE_MATCHES_TABLE, CREATE_NAME_INDEX]) 71 | 72 | 73 | def array_to_blob(array): 74 | if IS_PYTHON3: 75 | return array.tostring() 76 | 77 | return np.getbuffer(array) 78 | 79 | 80 | class COLMAPDatabase(sqlite3.Connection): 81 | @staticmethod 82 | def connect(database_path): 83 | return sqlite3.connect(database_path, factory=COLMAPDatabase) 84 | 85 | 86 | def __init__(self, *args, **kwargs): 87 | super().__init__(*args, **kwargs) 88 | 89 | self.initialize_tables = lambda: self.executescript(CREATE_ALL) 90 | 91 | self.initialize_cameras = \ 92 | lambda: self.executescript(CREATE_CAMERAS_TABLE) 93 | self.initialize_descriptors = \ 94 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) 95 | self.initialize_images = \ 96 | lambda: self.executescript(CREATE_IMAGES_TABLE) 97 | self.initialize_inlier_matches = \ 98 | lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE) 99 | self.initialize_keypoints = \ 100 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 101 | self.initialize_matches = \ 102 | lambda: self.executescript(CREATE_MATCHES_TABLE) 103 | 104 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 105 | 106 | def add_camera(self, model, width, height, params, 107 | prior_focal_length=False, camera_id=None): 108 | params = np.asarray(params, np.float64) 109 | cursor = self.execute( 110 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", 111 | (camera_id, model, width, height, array_to_blob(params), 112 | prior_focal_length)) 113 | return cursor.lastrowid 114 | 115 | 116 | def fetch_images_from_database(database_path: str) -> Dict: 117 | db = COLMAPDatabase.connect(database_path) # pylint: disable=[C0103] 118 | rows = db.execute("SELECT * FROM images") 119 | name_to_image_id = {} 120 | for row in rows: 121 | image_id, name = row[0], row[1] 122 | # print(f'image_id: {image_id}, name: {name}') 123 | name_to_image_id[name] = image_id 124 | 125 | return name_to_image_id 126 | 
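The `COLMAPDatabase` wrapper above exposes only a thin slice of COLMAP's SQLite schema (table creation, `add_camera`, and the `fetch_images_from_database` helper). A minimal usage sketch, assuming the repository root is on `PYTHONPATH` and using a made-up database path and PINHOLE intrinsics:

```python
from scripts.preprocess.database import COLMAPDatabase, fetch_images_from_database

# Create a fresh database that follows COLMAP's table layout.
db = COLMAPDatabase.connect("/tmp/example_database.db")
db.initialize_tables()

# Register a single PINHOLE camera (COLMAP model id 1); params are [fx, fy, cx, cy].
camera_id = db.add_camera(
    model=1, width=1920, height=1080,
    params=[1600.0, 1600.0, 960.0, 540.0],
    prior_focal_length=True,
)
db.commit()
db.close()

# Map image names to their ids (empty here, since no images were inserted yet).
name_to_image_id = fetch_images_from_database("/tmp/example_database.db")
print(camera_id, name_to_image_id)
```

COLMAP stores camera parameters as a float64 blob, which is why `add_camera` converts `params` with `np.asarray` before insertion.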
--------------------------------------------------------------------------------
/scripts/preprocess/hloc_mapping/filter_matches.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import argparse
3 | import numpy as np
4 | import networkx as nx
5 | 
6 | from matplotlib import pyplot as plt
7 | from scipy.sparse.csgraph import minimum_spanning_tree
8 | 
9 | from disambiguation.utils.read_write_database import remove_matches_from_db
10 | from disambiguation.utils.run_colmap import run_matches_importer
11 | 
12 | 
13 | def draw_graph(scores, plot_path, display=False):
14 |     graph = nx.from_numpy_array(scores)
15 |     # print(scores)
16 |     pos = nx.nx_agraph.graphviz_layout(graph)
17 |     edge_vmin = np.percentile(scores[scores.nonzero()], 10)
18 |     edge_vmax = np.percentile(scores[scores.nonzero()], 90)
19 |     # print(edge_vmin, edge_vmax)
20 |     nx.draw(
21 |         graph,
22 |         pos,
23 |         with_labels=True,
24 |         edge_color=[graph[u][v]['weight'] for u, v in graph.edges],
25 |         # edge_cmap=plt.cm.plasma,
26 |         edge_cmap=plt.cm.YlOrRd,
27 |         edge_vmin=edge_vmin,
28 |         edge_vmax=edge_vmax)
29 |     plt.savefig(plot_path)
30 |     if display:
31 |         plt.show()
32 |     plt.close()
33 |     return
34 | 
35 | 
36 | def filter_with_fixed_threshold(scores, thres, plot_path=None):
37 |     valid = scores >= thres
38 |     invalid = np.logical_not(valid)
39 |     scores[invalid] = 0.
40 |     if plot_path is not None:
41 |         draw_graph(scores, plot_path, display=False)
42 |     return valid
43 | 
44 | 
45 | def filter_with_knn(scores, k, plot_path):
46 |     valid = np.zeros_like(scores, dtype=bool)
47 |     valid_indices = scores.argsort()[:, -k:]
48 |     for i in range(scores.shape[0]):
49 |         for j in valid_indices[i]:
50 |             valid[i, j] = True
51 |     invalid = np.logical_not(valid)
52 |     scores[invalid] = 0.
53 |     if plot_path is not None:
54 |         draw_graph(scores, plot_path, display=False)
55 |     return valid
56 | 
57 | 
58 | def filter_with_mst_min(scores, plot_path=None):
59 |     min_scores = np.minimum(scores, scores.T)
60 |     assert np.allclose(min_scores, min_scores.T)
61 |     mst = minimum_spanning_tree(-min_scores)
62 |     valid = (-mst).toarray() > 0
63 |     invalid = np.logical_not(valid)
64 |     scores[invalid] = 0.
65 |     if plot_path is not None:
66 |         draw_graph(scores, plot_path, display=False)
67 |     return valid
68 | 
69 | 
70 | def filter_with_mst_mean(scores, plot_path=None):
71 |     mean_scores = (scores + scores.T) / 2
72 |     assert np.allclose(mean_scores, mean_scores.T)
73 |     mst = minimum_spanning_tree(-mean_scores)
74 |     valid = (-mst).toarray() > 0
75 |     invalid = np.logical_not(valid)
76 |     scores[invalid] = 0.
77 |     if plot_path is not None:
78 |         draw_graph(scores, plot_path, display=False)
79 |     return valid
80 | 
81 | 
82 | def filter_with_percentile(scores, percentile, plot_path=None):
83 |     num_images = scores.shape[0]
84 |     thres = np.zeros((num_images, 1))
85 |     for i in range(num_images):
86 |         thres[i] = np.percentile(scores[i, scores[i].nonzero()], percentile)
87 |     valid = scores >= thres
88 |     invalid = np.logical_not(valid)
89 |     scores[invalid] = 0.
90 | if plot_path is not None: 91 | draw_graph(scores, plot_path, display=False) 92 | return valid 93 | 94 | 95 | def main(colmap_path: str, 96 | results_path: str, 97 | filter_type: str, 98 | threshold: float, 99 | scores_dir: Path, 100 | scores_name: str, 101 | topk: int, 102 | percentile: float, 103 | old_db_path: str, 104 | new_db_path: str, 105 | geometric_verification_type: str): 106 | scores_path = scores_dir / scores_name 107 | scores = np.load(scores_path) 108 | 109 | # valid = scores >= args.threshold 110 | if filter_type == 'threshold': 111 | assert threshold is not None 112 | output_path = results_path / ('sparse' + scores_name[6:-4] + 113 | f'_t{threshold:.2f}') 114 | output_path.mkdir(exist_ok=True) 115 | plot_path = output_path / 'match_graph.png' 116 | match_list_path = results_path / ( 117 | 'match_list' + scores_name[6:-4] + f'_t{threshold}.txt') 118 | valid = filter_with_fixed_threshold(scores, threshold, plot_path) 119 | elif filter_type == 'knn': 120 | assert topk is not None 121 | output_path = results_path / ('sparse' + scores_name[6:-4] + 122 | f'_k{topk}') 123 | output_path.mkdir(exist_ok=True) 124 | plot_path = output_path / 'match_graph.png' 125 | match_list_path = results_path / ( 126 | 'match_list' + scores_name[6:-4] + f'_k{topk}.txt') 127 | valid = filter_with_knn(scores, topk, plot_path) 128 | elif filter_type == 'percentile': 129 | assert percentile is not None 130 | output_path = results_path / ('sparse' + scores_name[6:-4] + 131 | f'_p{percentile}') 132 | output_path.mkdir(exist_ok=True) 133 | plot_path = output_path / 'match_graph.png' 134 | match_list_path = results_path / ( 135 | 'match_list' + scores_name[6:-4] + f'_p{percentile}.txt') 136 | valid = filter_with_percentile(scores, percentile, plot_path) 137 | elif filter_type == 'mst_min': 138 | output_path = results_path / ('sparse' + scores_name[6:-4] + 139 | '_mst_min') 140 | output_path.mkdir(exist_ok=True) 141 | plot_path = output_path / 'match_graph.png' 142 | match_list_path = results_path / ( 143 | 'match_list' + scores_name[6:-4] + '_mst_min.txt') 144 | valid = filter_with_mst_min(scores, plot_path) 145 | # we don't do reconstruction based with mst graph as it is too sparse. 146 | # use it for visualization only 147 | exit(0) 148 | elif filter_type == 'mst_mean': 149 | output_path = results_path / ('sparse' + scores_name[6:-4] + 150 | '_mst_mean') 151 | output_path.mkdir(exist_ok=True) 152 | plot_path = output_path / 'match_graph.png' 153 | match_list_path = results_path / ( 154 | 'match_list' + scores_name[6:-4] + '_mst_mean.txt') 155 | valid = filter_with_mst_mean(scores, plot_path) 156 | # we don't do reconstruction based with mst graph as it is too sparse. 
157 |         # use it for visualization only
158 |         exit(0)
159 |     else:
160 |         raise NotImplementedError
161 | 
162 |     remove_matches_from_db(old_db_path, new_db_path, match_list_path, valid)
163 |     run_matches_importer(colmap_path,
164 |                          new_db_path,
165 |                          match_list_path,
166 |                          use_gpu=False,
167 |                          colmap_matching_type=geometric_verification_type)
168 | 
169 | 
170 | if __name__ == '__main__':
171 |     parser = argparse.ArgumentParser()
172 |     parser.add_argument('--dataset_dir', type=Path, default='datasets',
173 |                         help='Path to the dataset, default: %(default)s')
174 |     parser.add_argument('--results_path', type=Path, default='outputs',
175 |                         help='Path to the output directory, default: %(default)s')
176 |     parser.add_argument('--scores_name', type=str, required=True,
177 |                         default='yan', choices=['yan', 'cui'])
178 |     parser.add_argument('--filter_type',
179 |                         type=str,
180 |                         choices=['threshold', 'knn', 'mst_min', 'mst_mean', 'percentile'])
181 |     parser.add_argument('--threshold', type=float)
182 |     parser.add_argument('--topk', type=int)
183 |     parser.add_argument('--percentile', type=float)
184 |     parser.add_argument('--colmap_path', type=Path, default='colmap')
185 |     parser.add_argument('--old_db_path', type=str, required=True)
186 |     parser.add_argument('--new_db_path', type=str, required=True)
187 |     parser.add_argument('--geometric_verification_type',
188 |                         type=str,
189 |                         required=True,
190 |                         choices=['default', 'strict'])
191 | 
192 | 
193 |     args = parser.parse_args()
194 |     # NOTE: the score matrix is assumed to be stored under `results_path`.
195 |     main(args.colmap_path, args.results_path, args.filter_type, args.threshold,
196 |          args.results_path, args.scores_name, args.topk, args.percentile,
197 |          args.old_db_path, args.new_db_path, args.geometric_verification_type)
198 | 
--------------------------------------------------------------------------------
/scripts/preprocess/hloc_mapping/match_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Union, Optional, Dict, List, Tuple
3 | from pathlib import Path
4 | import pprint
5 | import collections.abc as collections
6 | from tqdm import tqdm
7 | import h5py
8 | import torch
9 | 
10 | from hloc import matchers, logger
11 | from hloc.utils.base_model import dynamic_load
12 | from hloc.utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
13 | from hloc.utils.io import list_h5_names
14 | 
15 | 
16 | '''
17 | A set of standard configurations that can be directly selected from the command
18 | line using their name. Each is a dictionary with the following entries:
19 | - output: the name of the match file that will be generated.
20 | - model: the model configuration, as passed to a feature matcher.
21 | ''' 22 | confs = { 23 | 'superglue': { 24 | 'output': 'matches-superglue', 25 | 'model': { 26 | 'name': 'superglue', 27 | 'weights': 'outdoor', 28 | 'sinkhorn_iterations': 50, 29 | }, 30 | }, 31 | 'superglue-fast': { 32 | 'output': 'matches-superglue-it5', 33 | 'model': { 34 | 'name': 'superglue', 35 | 'weights': 'outdoor', 36 | 'sinkhorn_iterations': 5, 37 | }, 38 | }, 39 | 'NN-superpoint': { 40 | 'output': 'matches-NN-mutual-dist.7', 41 | 'model': { 42 | 'name': 'nearest_neighbor', 43 | 'do_mutual_check': True, 44 | 'distance_threshold': 0.7, 45 | }, 46 | }, 47 | 'NN-ratio': { 48 | 'output': 'matches-NN-mutual-ratio.8', 49 | 'model': { 50 | 'name': 'nearest_neighbor', 51 | 'do_mutual_check': True, 52 | 'ratio_threshold': 0.8, 53 | } 54 | }, 55 | 'NN-mutual': { 56 | 'output': 'matches-NN-mutual', 57 | 'model': { 58 | 'name': 'nearest_neighbor', 59 | 'do_mutual_check': True, 60 | }, 61 | } 62 | } 63 | 64 | 65 | def main(conf: Dict, 66 | pairs: Path, features: Union[Path, str], 67 | export_dir: Optional[Path] = None, 68 | matches: Optional[Path] = None, 69 | features_ref: Optional[Path] = None, 70 | overwrite: bool = False, 71 | device='cuda') -> Path: 72 | 73 | if isinstance(features, Path) or Path(features).exists(): 74 | features_q = features 75 | if matches is None: 76 | raise ValueError('Either provide both features and matches as Path' 77 | ' or both as names.') 78 | else: 79 | if export_dir is None: 80 | raise ValueError('Provide an export_dir if features is not' 81 | f' a file path: {features}.') 82 | features_q = Path(export_dir, features+'.h5') 83 | if matches is None: 84 | matches = Path( 85 | export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5') 86 | 87 | if features_ref is None: 88 | features_ref = features_q 89 | if isinstance(features_ref, collections.Iterable): 90 | features_ref = list(features_ref) 91 | else: 92 | features_ref = [features_ref] 93 | 94 | match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite, device) 95 | 96 | return matches 97 | 98 | 99 | def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): 100 | '''Avoid to recompute duplicates to save time.''' 101 | pairs = set() 102 | for i, j in pairs_all: 103 | if (j, i) not in pairs: 104 | pairs.add((i, j)) 105 | pairs = list(pairs) 106 | if match_path is not None and match_path.exists(): 107 | with h5py.File(str(match_path), 'r') as fd: 108 | pairs_filtered = [] 109 | for i, j in pairs: 110 | if (names_to_pair(i, j) in fd or 111 | names_to_pair(j, i) in fd or 112 | names_to_pair_old(i, j) in fd or 113 | names_to_pair_old(j, i) in fd): 114 | continue 115 | pairs_filtered.append((i, j)) 116 | return pairs_filtered 117 | return pairs 118 | 119 | 120 | @torch.no_grad() 121 | def match_from_paths(conf: Dict, 122 | pairs_path: Path, 123 | match_path: Path, 124 | feature_path_q: Path, 125 | feature_paths_refs: Path, 126 | overwrite: bool = False, 127 | device='cuda') -> Path: 128 | logger.info('Matching local features with configuration:' 129 | f'\n{pprint.pformat(conf)}') 130 | 131 | if not feature_path_q.exists(): 132 | raise FileNotFoundError(f'Query feature file {feature_path_q}.') 133 | for path in feature_paths_refs: 134 | if not path.exists(): 135 | raise FileNotFoundError(f'Reference feature file {path}.') 136 | name2ref = {n: i for i, p in enumerate(feature_paths_refs) 137 | for n in list_h5_names(p)} 138 | match_path.parent.mkdir(exist_ok=True, parents=True) 139 | 140 | assert pairs_path.exists(), pairs_path 141 | pairs = parse_retrieval(pairs_path) 142 | 
pairs = [(q, r) for q, rs in pairs.items() for r in rs] 143 | pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) 144 | if len(pairs) == 0: 145 | logger.info('Skipping the matching.') 146 | return 147 | 148 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 149 | Model = dynamic_load(matchers, conf['model']['name']) 150 | model = Model(conf['model']).eval().to(device) 151 | 152 | for (name0, name1) in tqdm(pairs, smoothing=.1): 153 | data = {} 154 | with h5py.File(str(feature_path_q), 'r') as fd: 155 | grp = fd[name0] 156 | for k, v in grp.items(): 157 | data[k+'0'] = torch.from_numpy(v.__array__()).float().to(device) 158 | # some matchers might expect an image but only use its size 159 | data['image0'] = torch.empty((1,)+tuple(grp['image_size'])[::-1]) 160 | with h5py.File(str(feature_paths_refs[name2ref[name1]]), 'r') as fd: 161 | grp = fd[name1] 162 | for k, v in grp.items(): 163 | data[k+'1'] = torch.from_numpy(v.__array__()).float().to(device) 164 | data['image1'] = torch.empty((1,)+tuple(grp['image_size'])[::-1]) 165 | data = {k: v[None] for k, v in data.items()} 166 | 167 | pred = model(data) 168 | pair = names_to_pair(name0, name1) 169 | with h5py.File(str(match_path), 'a') as fd: 170 | if pair in fd: 171 | del fd[pair] 172 | grp = fd.create_group(pair) 173 | matches = pred['matches0'][0].cpu().short().numpy() 174 | grp.create_dataset('matches0', data=matches) 175 | 176 | if 'matching_scores0' in pred: 177 | scores = pred['matching_scores0'][0].cpu().half().numpy() 178 | grp.create_dataset('matching_scores0', data=scores) 179 | 180 | logger.info('Finished exporting matches.') 181 | 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument('--pairs', type=Path, required=True) 186 | parser.add_argument('--export_dir', type=Path) 187 | parser.add_argument('--features', type=str, 188 | default='feats-superpoint-n4096-r1024') 189 | parser.add_argument('--matches', type=Path) 190 | parser.add_argument('--conf', type=str, default='superglue', 191 | choices=list(confs.keys())) 192 | args = parser.parse_args() 193 | main(confs[args.conf], args.pairs, args.features, args.export_dir) 194 | -------------------------------------------------------------------------------- /scripts/preprocess/hloc_mapping/pairs_from_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | from typing import Optional 5 | import h5py 6 | import numpy as np 7 | import torch 8 | import collections.abc as collections 9 | 10 | from hloc import logger 11 | from hloc.utils.parsers import parse_image_lists 12 | from hloc.utils.read_write_model import read_images_binary 13 | from hloc.utils.io import list_h5_names 14 | 15 | 16 | def parse_names(prefix, names, names_all): 17 | if prefix is not None: 18 | if not isinstance(prefix, str): 19 | prefix = tuple(prefix) 20 | names = [n for n in names_all if n.startswith(prefix)] 21 | elif names is not None: 22 | if isinstance(names, (str, Path)): 23 | names = parse_image_lists(names) 24 | elif isinstance(names, collections.Iterable): 25 | names = list(names) 26 | else: 27 | raise ValueError(f'Unknown type of image list: {names}.' 
28 | 'Provide either a list or a path to a list file.') 29 | else: 30 | names = names_all 31 | return names 32 | 33 | 34 | def get_descriptors(names, path, name2idx=None, key='global_descriptor'): 35 | if name2idx is None: 36 | with h5py.File(str(path), 'r') as fd: 37 | desc = [fd[n][key].__array__() for n in names] 38 | else: 39 | desc = [] 40 | for n in names: 41 | with h5py.File(str(path[name2idx[n]]), 'r') as fd: 42 | desc.append(fd[n][key].__array__()) 43 | return torch.from_numpy(np.stack(desc, 0)).float() 44 | 45 | 46 | def pairs_from_score_matrix(scores: torch.Tensor, 47 | invalid: np.array, 48 | num_select: int, 49 | min_score: Optional[float] = None): 50 | assert scores.shape == invalid.shape 51 | if isinstance(scores, np.ndarray): 52 | scores = torch.from_numpy(scores) 53 | invalid = torch.from_numpy(invalid).to(scores.device) 54 | if min_score is not None: 55 | invalid |= scores < min_score 56 | scores.masked_fill_(invalid, float('-inf')) 57 | 58 | topk = torch.topk(scores, num_select, dim=1) 59 | indices = topk.indices.cpu().numpy() 60 | valid = topk.values.isfinite().cpu().numpy() 61 | 62 | pairs = [] 63 | for i, j in zip(*np.where(valid)): 64 | pairs.append((i, indices[i, j])) 65 | return pairs 66 | 67 | 68 | def get_query_names( 69 | descriptors, 70 | query_prefix=None, query_list=None, 71 | db_prefix=None, db_list=None, db_model=None, db_descriptors=None, 72 | ): 73 | # We handle multiple reference feature files. 74 | # We only assume that names are unique among them and map names to files. 75 | if db_descriptors is None: 76 | db_descriptors = descriptors 77 | if isinstance(db_descriptors, (Path, str)): 78 | db_descriptors = [db_descriptors] 79 | name2db = {n: i for i, p in enumerate(db_descriptors) 80 | for n in list_h5_names(p)} 81 | 82 | db_names_h5 = list(name2db.keys()) 83 | db_names_h5 = sorted(db_names_h5) 84 | 85 | query_names_h5 = list_h5_names(descriptors) 86 | query_names_h5 = sorted(query_names_h5) 87 | 88 | if db_model: 89 | images = read_images_binary(os.path.join(db_model, 'images.bin')) 90 | db_names = [i.name for i in images.values()] 91 | else: 92 | db_names = parse_names(db_prefix, db_list, db_names_h5) 93 | 94 | num_images = len(db_names) 95 | if num_images == 0: 96 | raise ValueError('Could not find any database image.') 97 | query_names = parse_names(query_prefix, query_list, query_names_h5) 98 | 99 | return db_names, db_descriptors, query_names, name2db 100 | 101 | 102 | def compute_similarity_score( 103 | descriptors, output, 104 | db_names, db_descriptors, query_names, name2db, 105 | device='cuda'): 106 | logger.info('Extracting image pairs from a retrieval database.') 107 | 108 | db_desc = get_descriptors(db_names, db_descriptors, name2db) 109 | query_desc = get_descriptors(query_names, descriptors) 110 | sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device)) 111 | 112 | torch.save(sim, output) 113 | 114 | return sim 115 | 116 | 117 | def main(descriptors, output, num_matched, 118 | query_prefix=None, query_list=None, 119 | db_prefix=None, db_list=None, db_model=None, db_descriptors=None, 120 | device='cuda'): 121 | logger.info('Extracting image pairs from a retrieval database.') 122 | 123 | # We handle multiple reference feature files. 124 | # We only assume that names are unique among them and map names to files. 
125 | if db_descriptors is None: 126 | db_descriptors = descriptors 127 | if isinstance(db_descriptors, (Path, str)): 128 | db_descriptors = [db_descriptors] 129 | name2db = {n: i for i, p in enumerate(db_descriptors) 130 | for n in list_h5_names(p)} 131 | db_names_h5 = list(name2db.keys()) 132 | query_names_h5 = list_h5_names(descriptors) 133 | 134 | if db_model: 135 | images = read_images_binary(os.path.join(db_model, 'images.bin')) 136 | db_names = [i.name for i in images.values()] 137 | else: 138 | db_names = parse_names(db_prefix, db_list, db_names_h5) 139 | 140 | num_images = len(db_names) 141 | if num_images == 0: 142 | raise ValueError('Could not find any database image.') 143 | query_names = parse_names(query_prefix, query_list, query_names_h5) 144 | 145 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 146 | db_desc = get_descriptors(db_names, db_descriptors, name2db) 147 | query_desc = get_descriptors(query_names, descriptors) 148 | sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device)) 149 | 150 | # Avoid self-matching 151 | self = np.array(query_names)[:, None] == np.array(db_names)[None] 152 | num_matched = min(num_images, num_matched) 153 | pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) 154 | pairs = [(query_names[i], db_names[j]) for i, j in pairs] 155 | 156 | logger.info(f'Found {len(pairs)} pairs.') 157 | with open(output, 'w') as f: 158 | f.write('\n'.join(' '.join([i, j]) for i, j in pairs)) 159 | 160 | 161 | if __name__ == "__main__": 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument('--descriptors', type=Path, required=True) 164 | parser.add_argument('--output', type=Path, required=True) 165 | parser.add_argument('--num_matched', type=int, required=True) 166 | parser.add_argument('--query_prefix', type=str, nargs='+') 167 | parser.add_argument('--query_list', type=Path) 168 | parser.add_argument('--db_prefix', type=str, nargs='+') 169 | parser.add_argument('--db_list', type=Path) 170 | parser.add_argument('--db_model', type=Path) 171 | parser.add_argument('--db_descriptors', type=Path) 172 | args = parser.parse_args() 173 | main(**args.__dict__) 174 | -------------------------------------------------------------------------------- /scripts/preprocess/hloc_mapping/reconstruction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from typing import Optional, List 4 | import multiprocessing 5 | from pathlib import Path 6 | import pycolmap 7 | 8 | from hloc import logger 9 | from hloc.utils.database import COLMAPDatabase 10 | from hloc.triangulation import ( 11 | import_features, import_matches, geometric_verification, OutputCapture) 12 | 13 | 14 | def create_empty_db(database_path): 15 | if database_path.exists(): 16 | logger.warning('The database already exists, deleting it.') 17 | database_path.unlink() 18 | logger.info('Creating an empty database...') 19 | db = COLMAPDatabase.connect(database_path) 20 | db.create_tables() 21 | db.commit() 22 | db.close() 23 | 24 | 25 | def import_images(image_dir, database_path, camera_mode, image_list=None): 26 | logger.info('Importing images into the database...') 27 | images = list(image_dir.iterdir()) 28 | if len(images) == 0: 29 | raise IOError(f'No images found in {image_dir}.') 30 | with pycolmap.ostream(): 31 | pycolmap.import_images(database_path, image_dir, camera_mode, 32 | image_list=image_list or []) 33 | 34 | 35 | def get_image_ids(database_path): 36 | db = 
COLMAPDatabase.connect(database_path) 37 | images = {} 38 | for name, image_id in db.execute("SELECT name, image_id FROM images;"): 39 | images[name] = image_id 40 | db.close() 41 | return images 42 | 43 | 44 | def run_reconstruction(sfm_dir, database_path, image_dir, verbose=False): 45 | models_path = sfm_dir / 'models' 46 | models_path.mkdir(exist_ok=True, parents=True) 47 | logger.info('Running 3D reconstruction...') 48 | with OutputCapture(verbose): 49 | with pycolmap.ostream(): 50 | reconstructions = pycolmap.incremental_mapping( 51 | database_path, image_dir, models_path, 52 | num_threads=min(multiprocessing.cpu_count(), 16)) 53 | 54 | if len(reconstructions) == 0: 55 | logger.error('Could not reconstruct any model!') 56 | return None 57 | logger.info(f'Reconstructed {len(reconstructions)} model(s).') 58 | 59 | largest_index = None 60 | largest_num_images = 0 61 | for index, rec in reconstructions.items(): 62 | num_images = rec.num_reg_images() 63 | if num_images > largest_num_images: 64 | largest_index = index 65 | largest_num_images = num_images 66 | assert largest_index is not None 67 | logger.info(f'Largest model is #{largest_index} ' 68 | f'with {largest_num_images} images.') 69 | 70 | for filename in ['images.bin', 'cameras.bin', 'points3D.bin']: 71 | if (sfm_dir / filename).exists(): 72 | (sfm_dir / filename).unlink() 73 | shutil.move( 74 | str(models_path / str(largest_index) / filename), str(models_path)) 75 | return reconstructions[largest_index] 76 | 77 | 78 | def main(database, output_dir, image_dir, pairs, features, matches, 79 | camera_mode=pycolmap.CameraMode.AUTO, verbose=False, 80 | skip_geometric_verification=False, min_match_score=None, 81 | image_list: Optional[List[str]] = None): 82 | 83 | assert features.exists(), features 84 | assert pairs.exists(), pairs 85 | assert matches.exists(), matches 86 | 87 | output_dir.mkdir(parents=True, exist_ok=True) 88 | 89 | # create_empty_db(database) 90 | # import_images(image_dir, database, camera_mode, image_list) 91 | image_ids = get_image_ids(database) 92 | # import_features(image_ids, database, features) 93 | # import_matches(image_ids, database, pairs, matches, 94 | # min_match_score, skip_geometric_verification) 95 | # if not skip_geometric_verification: 96 | # geometric_verification(database, pairs, verbose) 97 | reconstruction = run_reconstruction(output_dir, database, image_dir, verbose) 98 | if reconstruction is not None: 99 | logger.info(f'Reconstruction statistics:\n{reconstruction.summary()}' 100 | + f'\n\tnum_input_images = {len(image_ids)}') 101 | return reconstruction 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--output_dir', type=Path, required=True) 107 | parser.add_argument('--image_dir', type=Path, required=True) 108 | 109 | parser.add_argument('--pairs', type=Path, required=True) 110 | parser.add_argument('--features', type=Path, required=True) 111 | parser.add_argument('--matches', type=Path, required=True) 112 | 113 | parser.add_argument('--camera_mode', type=str, default="AUTO", 114 | choices=list(pycolmap.CameraMode.__members__.keys())) 115 | parser.add_argument('--skip_geometric_verification', action='store_true') 116 | parser.add_argument('--min_match_score', type=float) 117 | parser.add_argument('--verbose', action='store_true') 118 | args = parser.parse_args() 119 | 120 | main(**args.__dict__) 121 | -------------------------------------------------------------------------------- /scripts/preprocess/hloc_mapping/sfm_pipeline.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from scripts.preprocess.hloc_mapping import extract_relative_poses 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dataset_dir', type=Path, default='datasets', 10 | help='Path to the dataset, default: %(default)s') 11 | parser.add_argument('--outputs', type=Path, default='outputs', 12 | help='Path to the output directory, default: %(default)s') 13 | parser.add_argument('--num_matches', type=int, default=30, 14 | help='Number of image pairs for loc, default: %(default)s') 15 | parser.add_argument('--disambiguate', action="store_true", 16 | help='Enable/Disable disambiguating wrong matches.') 17 | parser.add_argument('--min_track_length', type=int, default=3) 18 | parser.add_argument('--max_track_length', type=int, default=40) 19 | parser.add_argument('--track_degree', type=int, default=3) 20 | parser.add_argument('--coverage_thres', type=float, default=0.9) 21 | parser.add_argument('--alpha', type=float, default=0.1) 22 | parser.add_argument('--minimal_views', type=int, default=5) 23 | parser.add_argument('--ds', type=str, 24 | choices=['dict', 'smallarray', 'largearray'], 25 | default='largearray') 26 | parser.add_argument('--filter_type', type=str, choices=[ 27 | 'threshold', 'knn', 'mst_min', 'mst_mean', 'percentile'], 28 | default='threshold') 29 | parser.add_argument('--threshold', type=float, default=0.15) 30 | parser.add_argument('--topk', type=int, default=3) 31 | parser.add_argument('--percentile', type=float) 32 | parser.add_argument('--colmap_path', type=Path, default='colmap') 33 | parser.add_argument('--geometric_verification_type', 34 | type=str, 35 | choices=['default', 'strict'], 36 | default='default') 37 | parser.add_argument('--recon', action="store_true", 38 | help='Indicates whether to reconstruct the scene.') 39 | parser.add_argument('--visualize', action="store_true", 40 | help='Whether to visualize the reconstruction.') 41 | parser.add_argument('--gpu_idx', type=str, default='0') 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def main(): 47 | args = parse_args() 48 | # Extracting relative poses and store as g2o file. 
49 | view_graph_path, database_path, num_view_pairs = extract_relative_poses.main(args=args) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /scripts/preprocess/hloc_mapping/triangulate_from_existing_model.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | import argparse 4 | import contextlib 5 | 6 | from typing import Optional, List, Dict, Any 7 | from pathlib import Path 8 | 9 | import pycolmap 10 | 11 | 12 | class OutputCapture: 13 | def __init__(self, verbose: bool): 14 | self.verbose = verbose 15 | 16 | def __enter__(self): 17 | if not self.verbose: 18 | self.capture = contextlib.redirect_stdout(io.StringIO()) # pylint: disable=W0201 19 | self.out = self.capture.__enter__() # pylint: disable=W0201 20 | 21 | def __exit__(self, exc_type, *args): 22 | if not self.verbose: 23 | self.capture.__exit__(exc_type, *args) 24 | if exc_type is not None: 25 | print('Failed with output:\n%s', self.out.getvalue()) 26 | sys.stdout.flush() 27 | 28 | 29 | def run_triangulation( 30 | output_path: Path, 31 | database_path: Path, 32 | image_dir: Path, 33 | reference_model: pycolmap.Reconstruction, 34 | verbose: bool = False, 35 | options: Optional[Dict[str, Any]] = None, 36 | ) -> pycolmap.Reconstruction: 37 | output_path.mkdir(parents=True, exist_ok=True) 38 | print('Running 3D triangulation...') 39 | if options is None: 40 | options = {} 41 | with OutputCapture(verbose): 42 | with pycolmap.ostream(): 43 | reconstruction = pycolmap.triangulate_points( 44 | reference_model, database_path, image_dir, output_path) 45 | return reconstruction 46 | 47 | 48 | def main( 49 | sfm_dir: Path, 50 | reference_model: Path, 51 | image_dir: Path, 52 | output_dir: Path, 53 | verbose: bool = False, 54 | mapper_options: Optional[Dict[str, Any]] = None, 55 | ) -> pycolmap.Reconstruction: 56 | 57 | assert reference_model.exists(), reference_model 58 | 59 | sfm_dir.mkdir(parents=True, exist_ok=True) 60 | database_path = sfm_dir / 'database.db' 61 | reference_model = pycolmap.Reconstruction(reference_model) 62 | 63 | reconstruction = run_triangulation(output_dir, database_path, image_dir, reference_model, 64 | verbose, mapper_options) 65 | print('Finished the triangulation with statistics:\n%s', 66 | reconstruction.summary()) 67 | return reconstruction 68 | 69 | 70 | def parse_option_args(args: List[str], default_options) -> Dict[str, Any]: 71 | options = {} 72 | for arg in args: 73 | idx = arg.find('=') 74 | if idx == -1: 75 | raise ValueError('Options format: key1=value1 key2=value2 etc.') 76 | key, value = arg[:idx], arg[idx+1:] 77 | if not hasattr(default_options, key): 78 | raise ValueError( 79 | f'Unknown option "{key}", allowed options and default values' 80 | f' for {default_options.summary()}') 81 | value = eval(value) # pylint: disable=W0123 82 | target_type = type(getattr(default_options, key)) 83 | if not isinstance(value, target_type): 84 | raise ValueError(f'Incorrect type for option "{key}":' 85 | f' {type(value)} vs {target_type}') 86 | options[key] = value 87 | return options 88 | 89 | 90 | if __name__ == '__main__': 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--sfm_dir', type=Path, required=True) 93 | parser.add_argument('--reference_model', type=Path, required=True) 94 | parser.add_argument('--image_dir', type=Path, required=True) 95 | parser.add_argument('--output_dir', type=Path, required=True) 96 | 
parser.add_argument('--verbose', action='store_true') 97 | args = parser.parse_args().__dict__ 98 | 99 | # mapper_options = parse_option_args( 100 | # args.pop("mapper_options"), pycolmap.IncrementalMapperOptions()) 101 | mapper_options = pycolmap.IncrementalMapperOptions() 102 | 103 | main(**args, mapper_options=mapper_options) 104 | -------------------------------------------------------------------------------- /scripts/preprocess/hloc_mapping/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import cv2 4 | from tqdm import tqdm 5 | 6 | from hloc.utils.database import COLMAPDatabase, blob_to_array 7 | from hloc import logger 8 | from hloc.utils.io import get_matches 9 | 10 | 11 | def import_matches(image_ids, database_path, pairs_path, matches_path, 12 | min_match_score=None, skip_geometric_verification=False 13 | ) -> int : 14 | logger.info('Importing matches into the database...') 15 | 16 | with open(str(pairs_path), 'r') as f: 17 | pairs = [p.split() for p in f.readlines()] 18 | 19 | db = COLMAPDatabase.connect(database_path) 20 | 21 | matched = set() 22 | for name0, name1 in tqdm(pairs): 23 | id0, id1 = image_ids[name0], image_ids[name1] 24 | if len({(id0, id1), (id1, id0)} & matched) > 0: 25 | continue 26 | matches, scores = get_matches(matches_path, name0, name1) 27 | if min_match_score: 28 | matches = matches[scores > min_match_score] 29 | db.add_matches(id0, id1, matches) 30 | matched |= {(id0, id1), (id1, id0)} 31 | 32 | if skip_geometric_verification: 33 | db.add_two_view_geometry(id0, id1, matches) 34 | 35 | db.commit() 36 | db.close() 37 | return len(pairs) 38 | 39 | 40 | def read_camera_intrinsics_by_image_id(image_id: int, db: COLMAPDatabase): 41 | rows = db.execute(f'SELECT camera_id FROM images WHERE image_id={image_id}') 42 | camera_id = next(rows)[0] 43 | rows = db.execute(f'SELECT params FROM cameras WHERE camera_id={camera_id}') 44 | params = blob_to_array(next(rows)[0], dtype=np.float64) 45 | 46 | # FIXME(chenyu): when camera model is not a simple pinhole. 47 | intrinsics = np.zeros((3, 3), dtype=np.float64) 48 | intrinsics[0, 0] = intrinsics[1, 1] = params[0] 49 | intrinsics[0, 2], intrinsics[1, 2] = params[1], params[2] 50 | intrinsics[2, 2] = 1. 
51 | return intrinsics 52 | 53 | 54 | def read_all_keypoints(db: COLMAPDatabase): 55 | keypoints_dict = dict( 56 | (image_id, blob_to_array(data, np.float32, (-1, 2))) 57 | for image_id, data in db.execute( 58 | "SELECT image_id, data FROM keypoints")) 59 | return keypoints_dict 60 | 61 | 62 | def extract_inlier_keypoints_pair(inlier_matches, keypoints1, keypoints2): 63 | inlier_keypoints1, inlier_keypoints2 = [], [] 64 | num_inliers = inlier_matches.shape[0] 65 | for i in range(num_inliers): 66 | idx = inlier_matches[i] 67 | inlier_keypoints1.append(keypoints1[idx[0]]) 68 | inlier_keypoints2.append(keypoints2[idx[1]]) 69 | 70 | inlier_keypoints1 = np.stack(inlier_keypoints1, axis=0) 71 | inlier_keypoints2 = np.stack(inlier_keypoints2, axis=0) 72 | return inlier_keypoints1, inlier_keypoints2 73 | 74 | 75 | def triangulate(inlier_keypoints1, inlier_keypoints2, 76 | extrinsics1: np.ndarray, extrinsics2: np.ndarray, 77 | intrinsics1: np.ndarray, intrinsics2: np.ndarray): 78 | proj_mtx1 = np.matmul(intrinsics1, extrinsics1) 79 | proj_mtx2 = np.matmul(intrinsics2, extrinsics2) 80 | 81 | points3d = cv2.triangulatePoints(projMatr1=proj_mtx1, projMatr2=proj_mtx2, 82 | projPoints1=inlier_keypoints1.transpose(1, 0), 83 | projPoints2=inlier_keypoints2.transpose(1, 0)) 84 | points3d = points3d.transpose(1, 0) 85 | points3d = points3d[:, :3] / points3d[:, 3].reshape(-1, 1) 86 | return points3d 87 | 88 | 89 | def compute_depth(proj_matrix, point3d): 90 | homo_point3d = np.ones(4) 91 | homo_point3d[0:3] = point3d 92 | proj_z = np.dot(proj_matrix[2, :].T, homo_point3d) 93 | return proj_z * np.linalg.norm(proj_matrix[:, 2], ord=2) 94 | 95 | 96 | def check_cheirality(inlier_keypoints1, inlier_keypoints2, 97 | extrinsic1: np.ndarray, extrinsic2: np.ndarray, 98 | intrinsics1: np.ndarray, intrinsics2: np.ndarray): 99 | min_depth = 1e-16 100 | max_depth = 1000 * np.linalg.norm( 101 | np.dot(extrinsic2[:3, :3].T, extrinsic2[:, 3]), ord=2) 102 | points3d = [] 103 | 104 | tmp_points3d = triangulate(inlier_keypoints1, inlier_keypoints2, 105 | extrinsic1, extrinsic2, intrinsics1, intrinsics2) 106 | for point3d in tmp_points3d: 107 | # Checking for positive depth in front of both cameras. 108 | depth1 = compute_depth(extrinsic1, point3d) 109 | if depth1 < max_depth and depth1 > min_depth: 110 | depth2 = compute_depth(extrinsic2, point3d) 111 | if depth2 < max_depth and depth2 > min_depth: 112 | points3d.append(point3d) 113 | 114 | return points3d 115 | 116 | 117 | def decompose_essential_matrix( 118 | keypoints1, keypoints2, 119 | essential_matrix, inlier_matches, 120 | intrinsics1: np.ndarray, intrinsics2: np.ndarray 121 | ) -> (np.ndarray, np.ndarray): 122 | """ 123 | Assume that the image_id1 is at [I|0] and second image_id2 is at [R|t] 124 | where R, t are derived from the essential matrix. 
125 | 126 | Args: 127 | keypoints1: keypoints locations of image1 128 | keypoints1: keypoints locations of image2 129 | essential_matrix: 3 x 3 numpy array, 130 | inlier_matches: matched keypoints indices between image1 and image2 131 | intrinsics1: 3 x 3 numpy array for image1 132 | intrinsics2: 3 x 3 numpy array for image2 133 | 134 | Returns: 135 | extrinsic matrix of shape (1, 12) from image 1 to image 2 136 | """ 137 | 138 | inlier_keypoints1, inlier_keypoints2 = extract_inlier_keypoints_pair( 139 | inlier_matches, keypoints1, keypoints2) 140 | # print(f'{inlier_keypoints1.shape}') 141 | # print(f'{inlier_keypoints2.shape}') 142 | 143 | extrinsic1 = np.zeros(shape=[3, 4], dtype=np.float64) 144 | extrinsic1[:3, :3] = np.eye(3) 145 | # relative motion from camera1 to camera2. 146 | extrinsics2 = np.zeros(shape=[3, 4], dtype=np.float64) 147 | 148 | W = np.zeros((3, 3)) 149 | W[0, 1], W[1, 0], W[2, 2] = -1, 1, 1 150 | U, _, Vh = np.linalg.svd(essential_matrix) 151 | 152 | if np.linalg.det(U) < 0: 153 | U *= -1 154 | if np.linalg.det(Vh) < 0: 155 | Vh *= -1 156 | 157 | R1, R2 = np.dot(np.dot(U, W), Vh), np.dot(np.dot(U, np.transpose(W)), Vh) 158 | t = U[:, 2] 159 | t /= np.linalg.norm(t, ord=2) 160 | 161 | def compose_projection_matrix(R, t): 162 | P = np.zeros(shape=[3, 4], dtype=float) 163 | P[:3, :3], P[:, 3] = R, t 164 | return P 165 | 166 | # Generate candidate projection matrices. 167 | P2_list = [] 168 | P2_list.append(compose_projection_matrix(R1, t)) 169 | P2_list.append(compose_projection_matrix(R2, t)) 170 | P2_list.append(compose_projection_matrix(R1, -t)) 171 | P2_list.append(compose_projection_matrix(R2, -t)) 172 | 173 | candidate_points3d, points3d = [], [] 174 | # Then, we need to iterate over each projection matrix and 175 | # make the cheirality validation. 
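    # For a valid essential matrix, only one of the four candidate poses
    # ([R1|t], [R2|t], [R1|-t], [R2|-t]) places the triangulated points in front
    # of both cameras, so we keep the candidate that passes this cheirality check
    # for the largest number of points.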
176 | for extrinsic2 in P2_list: 177 | candidate_points3d = check_cheirality( 178 | inlier_keypoints1, inlier_keypoints2, 179 | extrinsic1, extrinsic2, intrinsics1, intrinsics2) 180 | # print(f'len points3d: {len(points3d)}') 181 | if len(points3d) < len(candidate_points3d): 182 | points3d[:] = candidate_points3d 183 | extrinsics2[:] = extrinsic2 184 | 185 | # print(f'final len points3d: {len(points3d)}') 186 | if len(points3d) == 0: 187 | return None, None 188 | 189 | points3d = np.stack(points3d, axis=0) 190 | 191 | return extrinsics2.reshape(1, -1), points3d 192 | -------------------------------------------------------------------------------- /scripts/preprocess/mapping.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tqdm 4 | import argparse 5 | 6 | from conerf.datasets.hypersim import _collect_camera_names, _get_all_image_names 7 | 8 | 9 | SFM_SCRIPT_PATH = os.path.join(os.getcwd(), 'scripts/preprocess/colmap_mapping.sh') 10 | VOC_TREE_PATH = '/home/chenyu/HD_Datasets/datasets/vocab_tree_flickr100K_words256K.bin' 11 | TOPK_IMAGES = 100 12 | GPU_IDS = 1 13 | 14 | # DATASETS = ['Hypersim'] #, 'nerf_synthetic'] # ['nerf_llff_data', 'ibrnet_collected_more', 'BlendedMVS'] 15 | DATASETS = ['DTU'] #, 'nerf_synthetic'] # ['nerf_llff_data', 'ibrnet_collected_more', 'BlendedMVS'] 16 | ROOT_DIR = '/home/chenyu/HD_Datasets/datasets' 17 | # DATASETS = ['BlendedMVS'] 18 | # ROOT_DIR = '/media/chenyu/SSD_Data/datasets' 19 | 20 | 21 | def config_parser(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--preprocess", 24 | action="store_true", 25 | help="whether to preprocess data") 26 | parser.add_argument("--run_colmap", 27 | action="store_true", 28 | help="whether to preprocess data") 29 | parser.add_argument("--start_index", type=int, default=0) 30 | parser.add_argument("--end_index", type=int, default=0) 31 | 32 | return parser.parse_args() 33 | 34 | 35 | def get_filename_from_abs_path(abs_path): 36 | return abs_path.split('/')[-1] 37 | 38 | 39 | def get_filename_no_ext(filename): 40 | return os.path.splitext(filename)[0] 41 | 42 | 43 | def get_file_extension(filename): 44 | return os.path.splitext(filename)[-1] 45 | 46 | 47 | def preprocess_nerf_synthetic_dataset(dataset_dir): 48 | # The DTU dataset follows pixel-nerf: https://github.com/sxyu/pixel-nerf , 49 | # Url: https://drive.google.com/drive/folders/1PsT3uKwqHHD2bEEHkIXB99AlIjtmrEiR 50 | scenes = sorted(os.listdir(dataset_dir)) 51 | for scene in scenes: 52 | scene_dir = os.path.join(dataset_dir, scene) 53 | image_dir = os.path.join(scene_dir, 'train') 54 | new_image_dir = os.path.join(scene_dir, 'images') 55 | 56 | os.system(f'cp -r {image_dir} {new_image_dir}') 57 | 58 | 59 | def preprocess_dtu_dataset(dataset_dir): 60 | # The DTU dataset follows pixel-nerf: https://github.com/sxyu/pixel-nerf , 61 | # Url: https://drive.google.com/drive/folders/1PsT3uKwqHHD2bEEHkIXB99AlIjtmrEiR 62 | scenes = sorted(os.listdir(dataset_dir)) 63 | for scene in scenes: 64 | scene_dir = os.path.join(dataset_dir, scene) 65 | if not os.path.isdir(scene_dir): 66 | continue 67 | image_dir = os.path.join(scene_dir, 'image') 68 | new_image_dir = os.path.join(scene_dir, 'images') 69 | 70 | os.system(f'mv {image_dir} {new_image_dir}') 71 | 72 | 73 | def preprocess_blended_mvs_dataset(dataset_dir): 74 | scenes = sorted(os.listdir(dataset_dir)) 75 | if args.start_index < args.end_index: 76 | scenes = scenes[args.start_index:args.end_index] 77 | 78 | for scene in scenes: 79 
| scene_dir = os.path.join(dataset_dir, scene) 80 | 81 | blended_image_dir = os.path.join(scene_dir, 'blended_images') 82 | image_dir = os.path.join(scene_dir, 'images') 83 | ori_image_dir = os.path.join(scene_dir, 'ori_images') 84 | masked_image_dir = os.path.join(scene_dir, 'masked_images') 85 | 86 | # os.system(f'rm -r {scene_dir}/output') 87 | # os.system(f'rm {scene_dir}/database.db {scene_dir}/poses_bounds.npy {scene_dir}/track.txt {scene_dir}/*.g2o {scene_dir}/*.json') 88 | # os.system(f'mv {image_dir} {ori_image_dir}') 89 | # os.system(f'mv {masked_image_dir} {image_dir}') 90 | 91 | # if not os.path.exists(image_dir): 92 | # os.mkdir(image_dir) 93 | 94 | # if not os.path.exists(masked_image_dir): 95 | # os.mkdir(masked_image_dir) 96 | 97 | # for root, dirs, files in os.walk(blended_image_dir): 98 | # for file in files: 99 | # image_path = os.path.join(blended_image_dir, root, file) 100 | # if file.find('masked') >= 0: 101 | # shutil.move(image_path, os.path.join(masked_image_dir, file)) 102 | # else: 103 | # shutil.move(image_path, os.path.join(image_dir, file)) 104 | 105 | # os.system(f'rm -r {blended_image_dir}') 106 | 107 | 108 | def preprocess_hypersim_dataset(dataset_dir): 109 | scenes = sorted(os.listdir(dataset_dir)) 110 | if args.start_index < args.end_index: 111 | scenes = scenes[args.start_index:args.end_index] 112 | 113 | pbar = tqdm.trange(len(scenes), desc="Preprocessing", leave=False) 114 | for scene in scenes: 115 | scene_dir = os.path.join(dataset_dir, scene) 116 | if not os.path.isdir(scene_dir): 117 | continue 118 | 119 | camera_names = _collect_camera_names(os.path.join(scene_dir, '_detail')) 120 | 121 | new_image_dir = os.path.join(scene_dir, 'images') 122 | origin_image_dir = os.path.join(scene_dir, 'ori_images') 123 | 124 | if not os.path.exists(origin_image_dir): 125 | os.mkdir(origin_image_dir) 126 | 127 | # backup 128 | os.system(f'mv {new_image_dir}/* {origin_image_dir}') 129 | # os.system(f'rm -r {origin_image_dir}/images') 130 | 131 | for i, camera_name in enumerate(camera_names): 132 | image_dir = os.path.join(origin_image_dir, 'scene_' + camera_name + '_final_preview') 133 | image_files, _ = _get_all_image_names(image_dir, image_type='tonemap') 134 | 135 | for image_file in image_files: 136 | image_name = get_filename_from_abs_path(image_file) 137 | 138 | sub_image_dir = os.path.join(new_image_dir, str(i)) 139 | os.makedirs(sub_image_dir, exist_ok=True) 140 | new_image_file = os.path.join(sub_image_dir, image_name) 141 | shutil.copy(image_file, new_image_file) 142 | 143 | pbar.update(1) 144 | 145 | 146 | def run_sfm(root_dir, dataset_list, args): 147 | for dataset in dataset_list: 148 | dataset_dir = os.path.join(root_dir, dataset) 149 | 150 | if args.preprocess: 151 | if dataset == 'BlendedMVS': 152 | preprocess_blended_mvs_dataset(dataset_dir) 153 | 154 | if dataset == 'DTU': 155 | preprocess_dtu_dataset(dataset_dir) 156 | 157 | if dataset == 'Hypersim': 158 | preprocess_hypersim_dataset(dataset_dir) 159 | 160 | scenes = sorted(os.listdir(dataset_dir)) 161 | if args.start_index < args.end_index: 162 | scenes = scenes[args.start_index:args.end_index] 163 | 164 | pbar = tqdm.trange(len(scenes), desc="Running SfM", leave=False) 165 | for scene in scenes: 166 | data_dir = os.path.join(dataset_dir, scene) 167 | if not os.path.isdir(data_dir): 168 | continue 169 | output_dir = os.path.join(data_dir, 'sparse') 170 | if not os.path.exists(output_dir): 171 | os.makedirs(output_dir) 172 | 173 | if not args.run_colmap: 174 | # Compute bounding box. 
175 | os.system(f'python -m scripts.preprocess.compute_bbox --colmap_dir {output_dir}/0') 176 | continue 177 | 178 | # print(f'output dir: {output_dir}') 179 | os.system(f'{SFM_SCRIPT_PATH} {data_dir} {output_dir} {VOC_TREE_PATH} {TOPK_IMAGES} {GPU_IDS}') 180 | 181 | shutil.move(os.path.join(output_dir, 'database.db'), os.path.join(data_dir, 'database.db')) 182 | 183 | # Compute bounding box. 184 | os.system(f'python -m scripts.preprocess.compute_bbox --colmap_dir {output_dir}/0') 185 | 186 | pbar.update(1) 187 | 188 | 189 | if __name__ == '__main__': 190 | args = config_parser() 191 | 192 | run_sfm(ROOT_DIR, DATASETS, args) 193 | -------------------------------------------------------------------------------- /scripts/preprocess/triangulate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run triangulator with known camera poses. 4 | 5 | COLMAP_DIR=/usr/local/bin 6 | COLMAP_EXE=$COLMAP_DIR/colmap 7 | 8 | export PYTHONDONTWRITEBYTECODE=1 9 | 10 | PROJECT_PATH=$1 11 | colmap_method=$2 # ['colmap', 'pycolmap'] 12 | 13 | if [ `echo $colmap_method | grep -c "py" ` -gt 0 ] 14 | then 15 | HOME_DIR=$HOME 16 | CODE_ROOT_DIR=$HOME/'Projects/ZeroGS' 17 | cd $CODE_ROOT_DIR 18 | 19 | python -m scripts.preprocess.hloc_mapping.triangulate_from_existing_model \ 20 | --sfm_dir $PROJECT_PATH \ 21 | --reference_model $PROJECT_PATH/sparse/triangulator_input \ 22 | --output_dir $PROJECT_PATH/sparse/0 \ 23 | --image_dir $PROJECT_PATH \ 24 | --verbose \ 25 | > $PROJECT_PATH/log_triangulate.txt 2>&1 26 | else 27 | $COLMAP_EXE point_triangulator \ 28 | --database_path $PROJECT_PATH/database.db \ 29 | --image_path $PROJECT_PATH \ 30 | --input_path $PROJECT_PATH/sparse/triangulator_input \ 31 | --output_path $PROJECT_PATH/sparse/0 \ 32 | > $PROJECT_PATH/log_triangulate.txt 2>&1 33 | fi 34 | -------------------------------------------------------------------------------- /scripts/preprocess/utils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E0402, C)103] 2 | 3 | from pathlib import Path 4 | from typing import List, Dict, Tuple 5 | 6 | from .read_write_model import Camera 7 | 8 | 9 | def list_images(data: str) -> List[str]: 10 | """Lists all supported images in a directory 11 | Modified from: 12 | https://github.com/hturki/nerfstudio/nerfstudio/process_data/process_data_utils.py#L60 13 | 14 | Args: 15 | data: Path to the directory of images. 16 | Returns: 17 | Paths to images contained in the directory 18 | """ 19 | data = Path(data) 20 | allowed_exts = [".jpg", ".jpeg", ".png", ".tif", ".tiff"] 21 | image_paths = sorted([p for p in data.glob("[!.]*") if p.suffix.lower() in allowed_exts]) 22 | return image_paths 23 | 24 | 25 | def list_metadata(data: str) -> List[str]: 26 | """Lists all supported images in a directory 27 | Modified from: 28 | https://github.com/hturki/nerfstudio/nerfstudio/process_data/process_data_utils.py#L60 29 | 30 | Args: 31 | data: Path to the directory of images. 
32 | Returns: 33 | Paths to images contained in the directory 34 | """ 35 | data = Path(data) 36 | allowed_exts = [".pt"] 37 | metadata_paths = sorted([p for p in data.glob("[!.]*") if p.suffix.lower() in allowed_exts]) 38 | return metadata_paths 39 | 40 | 41 | def list_jsons(data: str) -> List[str]: 42 | """Lists all supported images in a directory 43 | Modified from: 44 | https://github.com/hturki/nerfstudio/nerfstudio/process_data/process_data_utils.py#L60 45 | 46 | Args: 47 | data: Path to the directory of images. 48 | Returns: 49 | Paths to images contained in the directory 50 | """ 51 | data = Path(data) 52 | allowed_exts = [".json"] 53 | metadata_paths = sorted([p for p in data.glob("[!.]*") if p.suffix.lower() in allowed_exts]) 54 | return metadata_paths 55 | 56 | 57 | def read_meganerf_mappings(mappings_path: str) -> Tuple[Dict, Dict]: 58 | image_name_to_metadata, metadata_to_image_name = {}, {} 59 | with open(mappings_path, "r", encoding="utf-8") as file: 60 | line = file.readline() 61 | while line: 62 | image_name, pt_name = line.split(',') 63 | pt_name = pt_name.strip() 64 | image_name_to_metadata[image_name] = pt_name 65 | metadata_to_image_name[pt_name] = image_name 66 | line = file.readline() 67 | 68 | return image_name_to_metadata, metadata_to_image_name 69 | 70 | 71 | def get_filename_from_path(path: str) -> str: 72 | last_slash_index = path.rfind('/') 73 | return path[last_slash_index+1:] 74 | 75 | 76 | def is_same_camera(camera1: Camera, camera2: Camera) -> bool: 77 | if camera1.width != camera2.width: 78 | return False 79 | 80 | if camera1.height != camera2.height: 81 | return False 82 | 83 | if len(camera1.params) != len(camera2.params): 84 | return False 85 | 86 | for i in range(len(camera1.params)): 87 | if camera1.params[i] != camera2.params[i]: 88 | return False 89 | 90 | return True 91 | 92 | 93 | def get_camera_id(cameras: Dict, query_camera: Camera) -> int: 94 | for idx, camera in cameras.items(): 95 | if is_same_camera(camera, query_camera): 96 | return idx 97 | 98 | return len(cameras) + 1 99 | -------------------------------------------------------------------------------- /scripts/train/train_ace_zero.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_IDS=$1 # {'0,1,2,...'} 4 | 5 | export PYTHONDONTWRITEBYTECODE=1 6 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS} 7 | 8 | # Default parameters. 9 | DATASET='blender' # [llff, mipnerf360, tanks_and_temples] 10 | ENCODING='ace' # [ace, zero_gs] 11 | SUFFIX='' 12 | 13 | NUM_CMD_PARAMS=$# 14 | if [ $NUM_CMD_PARAMS -eq 2 ] 15 | then 16 | SUFFIX=$2 17 | elif [ $NUM_CMD_PARAMS -eq 3 ] 18 | then 19 | SUFFIX=$2 20 | DATASET=$3 21 | elif [ $NUM_CMD_PARAMS -eq 4 ] 22 | then 23 | SUFFIX=$2 24 | DATASET=$3 25 | ENCODING=$4 26 | fi 27 | 28 | YAML=${ENCODING}/${DATASET}'.yaml' 29 | echo "Using yaml file: ${YAML}" 30 | 31 | HOME_DIR=$HOME 32 | CODE_ROOT_DIR=$HOME/'Projects/ZeroGS' 33 | 34 | cd $CODE_ROOT_DIR 35 | 36 | python train.py --config 'config/'${YAML} \ 37 | --suffix $SUFFIX 38 | -------------------------------------------------------------------------------- /submodules/dsacstar/dsacstar.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Based on the DSAC++ and ESAC code. 3 | https://github.com/vislearn/LessMore 4 | https://github.com/vislearn/esac 5 | 6 | Copyright (c) 2016, TU Dresden 7 | Copyright (c) 2020, Heidelberg University 8 | All rights reserved. 
9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | * Redistributions of source code must retain the above copyright 13 | notice, this list of conditions and the following disclaimer. 14 | * Redistributions in binary form must reproduce the above copyright 15 | notice, this list of conditions and the following disclaimer in the 16 | documentation and/or other materials provided with the distribution. 17 | * Neither the name of the TU Dresden, Heidelberg University nor the 18 | names of its contributors may be used to endorse or promote products 19 | derived from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY 25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | */ 32 | 33 | #include 34 | #include 35 | 36 | #include 37 | 38 | #include "thread_rand.h" 39 | #include "stop_watch.h" 40 | 41 | #include "dsacstar_types.h" 42 | #include "dsacstar_util.h" 43 | //#include "dsacstar_util_rgbd.h" 44 | #include "dsacstar_loss.h" 45 | #include "dsacstar_derivative.h" 46 | 47 | #define MAX_REF_STEPS 100 // max pose refienment iterations 48 | #define MAX_HYPOTHESES_TRIES 1000000 // repeat sampling x times hypothesis if hypothesis is invalid 49 | 50 | /** 51 | * @brief Estimate a camera pose based on a scene coordinate prediction 52 | * @param sceneCoordinatesSrc Scene coordinate prediction, (1x3xHxW) with 1=batch dimension (only batch_size=1 supported atm), 3=scene coordainte dimensions, H=height and W=width. 53 | * @param outPoseSrc Camera pose (output parameter), (4x4) tensor containing the homogeneous camera tranformation matrix. 54 | * @param ransacHypotheses Number of RANSAC iterations. 55 | * @param inlierThreshold Inlier threshold for RANSAC in px. 56 | * @param focalLength Focal length of the camera in px. 57 | * @param ppointX Coordinate (X) of the prinicpal points. 58 | * @param ppointY Coordinate (Y) of the prinicpal points. 59 | * @param inlierAlpha Alpha parameter for soft inlier counting. 60 | * @param maxReproj Reprojection errors are clamped above this value (px). 61 | * @param subSampling Sub-sampling of the scene coordinate prediction wrt the input image. 62 | * @return The number of inliers for the output pose. 
63 | */ 64 | int dsacstar_rgb_forward( 65 | at::Tensor sceneCoordinatesSrc, 66 | at::Tensor outPoseSrc, 67 | int ransacHypotheses, 68 | float inlierThreshold, 69 | float fx, 70 | float fy, 71 | float ppointX, 72 | float ppointY, 73 | float inlierAlpha, 74 | float maxReproj, 75 | int subSampling, 76 | int maxHypothesesTries = 10000000, 77 | bool verbose = false) 78 | { 79 | ThreadRand::init(); 80 | 81 | // access to tensor objects 82 | dsacstar::coord_t sceneCoordinates = 83 | sceneCoordinatesSrc.accessor(); 84 | 85 | // dimensions of scene coordinate predictions 86 | int imH = sceneCoordinates.size(2); 87 | int imW = sceneCoordinates.size(3); 88 | 89 | // internal camera calibration matrix 90 | cv::Mat_ camMat = cv::Mat_::eye(3, 3); 91 | camMat(0, 0) = fx; 92 | camMat(1, 1) = fy; 93 | camMat(0, 2) = ppointX; 94 | camMat(1, 2) = ppointY; 95 | 96 | // calculate original image position for each scene coordinate prediction 97 | cv::Mat_ sampling = 98 | dsacstar::createSampling(imW, imH, subSampling, 0, 0); 99 | 100 | if (verbose) { 101 | std::cout << BLUETEXT("Sampling " << ransacHypotheses << " hypotheses.") << std::endl; 102 | } 103 | 104 | StopWatch stopW; 105 | 106 | // sample RANSAC hypotheses 107 | std::vector hypotheses; 108 | std::vector> sampledPoints; 109 | std::vector> imgPts; 110 | std::vector> objPts; 111 | 112 | dsacstar::sampleHypotheses( 113 | sceneCoordinates, 114 | sampling, 115 | camMat, 116 | ransacHypotheses, 117 | maxHypothesesTries, 118 | inlierThreshold, 119 | hypotheses, 120 | sampledPoints, 121 | imgPts, 122 | objPts); 123 | 124 | if (verbose) { 125 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl; 126 | std::cout << BLUETEXT("Calculating scores.") << std::endl; 127 | } 128 | 129 | // compute reprojection error images 130 | std::vector> reproErrs(ransacHypotheses); 131 | cv::Mat_ jacobeanDummy; 132 | 133 | #pragma omp parallel for 134 | for(unsigned h = 0; h < hypotheses.size(); h++) 135 | reproErrs[h] = dsacstar::getReproErrs( 136 | sceneCoordinates, 137 | hypotheses[h], 138 | sampling, 139 | camMat, 140 | maxReproj, 141 | jacobeanDummy); 142 | 143 | // soft inlier counting 144 | std::vector scores = dsacstar::getHypScores( 145 | reproErrs, 146 | inlierThreshold, 147 | inlierAlpha); 148 | 149 | if (verbose) { 150 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl; 151 | std::cout << BLUETEXT("Drawing final hypothesis.") << std::endl; 152 | } 153 | 154 | // apply soft max to scores to get a distribution 155 | std::vector hypProbs = dsacstar::softMax(scores); 156 | double hypEntropy = dsacstar::entropy(hypProbs); // measure distribution entropy 157 | int hypIdx = dsacstar::draw(hypProbs, false); // select winning hypothesis 158 | 159 | if (verbose) { 160 | std::cout << "Soft inlier count: " << scores[hypIdx] << " (Selection Probability: " << (int) (hypProbs[hypIdx]*100) << "%)" << std::endl; 161 | std::cout << "Entropy of hypothesis distribution: " << hypEntropy << std::endl; 162 | 163 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl; 164 | std::cout << BLUETEXT("Refining winning pose:") << std::endl; 165 | } 166 | 167 | // refine selected hypothesis 168 | cv::Mat_ inlierMap; 169 | 170 | dsacstar::refineHyp( 171 | sceneCoordinates, 172 | reproErrs[hypIdx], 173 | sampling, 174 | camMat, 175 | inlierThreshold, 176 | MAX_REF_STEPS, 177 | maxReproj, 178 | hypotheses[hypIdx], 179 | inlierMap); 180 | 181 | if (verbose) { 182 | std::cout << "Done in " << stopW.stop() / 1000 << "s." 
<< std::endl; 183 | } 184 | 185 | // write result back to PyTorch 186 | dsacstar::trans_t estTrans = dsacstar::pose2trans(hypotheses[hypIdx]); 187 | 188 | auto outPose = outPoseSrc.accessor(); 189 | for(unsigned x = 0; x < 4; x++) 190 | for(unsigned y = 0; y < 4; y++) 191 | outPose[y][x] = estTrans(y, x); 192 | 193 | // Return the inlier count. cv::sum returns a scalar, so we return its first element. 194 | return cv::sum(inlierMap)[0]; 195 | } 196 | 197 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 198 | m.def("forward_rgb", &dsacstar_rgb_forward, "DSAC* forward (RGB)"); 199 | // m.def("backward_rgb", &dsacstar_rgb_backward, "DSAC* backward (RGB)"); 200 | // m.def("forward_rgbd", &dsacstar_rgbd_forward, "DSAC* forward (RGB-D)"); 201 | // m.def("backward_rgbd", &dsacstar_rgbd_backward, "DSAC* backward (RGB-D)"); 202 | } 203 | -------------------------------------------------------------------------------- /submodules/dsacstar/dsacstar_loss.h: -------------------------------------------------------------------------------- 1 | /* 2 | Based on the DSAC++ and ESAC code. 3 | https://github.com/vislearn/LessMore 4 | https://github.com/vislearn/esac 5 | 6 | Copyright (c) 2016, TU Dresden 7 | Copyright (c) 2020, Heidelberg University 8 | All rights reserved. 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | * Redistributions of source code must retain the above copyright 13 | notice, this list of conditions and the following disclaimer. 14 | * Redistributions in binary form must reproduce the above copyright 15 | notice, this list of conditions and the following disclaimer in the 16 | documentation and/or other materials provided with the distribution. 17 | * Neither the name of the TU Dresden, Heidelberg University nor the 18 | names of its contributors may be used to endorse or promote products 19 | derived from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY 25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | */ 32 | 33 | #pragma once 34 | 35 | #define MAXLOSS 10000000.0 // clamp for stability 36 | 37 | namespace dsacstar 38 | { 39 | /** 40 | * @brief Calculates the rotational distance in degree between two transformations. 41 | * Translation will be ignored. 42 | * 43 | * @param trans1 Transformation 1. 44 | * @param trans2 Transformation 2. 45 | * @return Angle in degree. 
46 | */ 47 | double calcAngularDistance(const dsacstar::trans_t& trans1, const dsacstar::trans_t& trans2) 48 | { 49 | cv::Mat rot1 = trans1.colRange(0, 3).rowRange(0, 3); 50 | cv::Mat rot2 = trans2.colRange(0, 3).rowRange(0, 3); 51 | 52 | cv::Mat rotDiff= rot2 * rot1.t(); 53 | double trace = cv::trace(rotDiff)[0]; 54 | 55 | trace = std::min(3.0, std::max(-1.0, trace)); 56 | return 180*acos((trace-1.0)/2.0)/PI; 57 | } 58 | 59 | /** 60 | * @brief Weighted average of translational error and rotational error between two pose hypothesis. 61 | * @param h1 Pose 1. 62 | * @param h2 Pose 2. 63 | * @param wRot Weight of rotation error. 64 | * @param wTrans Weight of translation error. 65 | * @param cut Apply soft clamping after this value. 66 | * @return Loss. 67 | */ 68 | double loss( 69 | const dsacstar::trans_t& trans1, 70 | const dsacstar::trans_t& trans2, 71 | double wRot = 1.0, 72 | double wTrans = 1.0, 73 | double cut = 100) 74 | { 75 | double rotErr = dsacstar::calcAngularDistance(trans1, trans2); 76 | double tErr = cv::norm( 77 | trans1.col(3).rowRange(0, 3) - trans2.col(3).rowRange(0, 3)); 78 | 79 | double loss = wRot * rotErr + wTrans * tErr; 80 | 81 | if(loss > cut) 82 | loss = std::sqrt(cut * loss); 83 | 84 | return std::min(loss, MAXLOSS); 85 | } 86 | 87 | /** 88 | * @brief Calculate the derivative of the loss w.r.t. the estimated pose. 89 | * @param est Estimated pose (6 DoF). 90 | * @param gt Ground truth pose (6 DoF). 91 | * @param wRot Weight of rotation error. 92 | * @param wTrans Weight of translation error. 93 | * @param cut Apply soft clamping after this value. 94 | * @return 1x6 Jacobean. 95 | */ 96 | cv::Mat_ dLoss( 97 | const dsacstar::pose_t& est, 98 | const dsacstar::pose_t& gt, 99 | double wRot = 1.0, 100 | double wTrans = 1.0, 101 | double cut = 100) 102 | { 103 | cv::Mat rot1, rot2, dRod; 104 | cv::Rodrigues(est.first, rot1, dRod); 105 | cv::Rodrigues(gt.first, rot2); 106 | 107 | // measure loss of inverted poses (camera pose instead of scene pose) 108 | cv::Mat_ invRot1 = rot1.t(); 109 | cv::Mat_ invRot2 = rot2.t(); 110 | 111 | // get the difference rotation 112 | cv::Mat diffRot = rot1 * invRot2; 113 | 114 | // calculate rotational and translational error 115 | double trace = cv::trace(diffRot)[0]; 116 | trace = std::min(3.0, std::max(-1.0, trace)); 117 | double rotErr = 180*acos((trace-1.0)/2.0)/CV_PI; 118 | 119 | cv::Mat_ invT1 = est.second.clone(); 120 | invT1 = invRot1 * invT1; 121 | 122 | cv::Mat_ invT2 = gt.second.clone(); 123 | invT2 = invRot2 * invT2; 124 | 125 | // zero error, abort 126 | double tErr = cv::norm(invT1 - invT2); 127 | 128 | cv::Mat_ jacobean = cv::Mat_::zeros(1, 6); 129 | 130 | // clamped loss, return zero gradient if loss is bigger than threshold 131 | double loss = wRot * rotErr + wTrans * tErr; 132 | bool cutLoss = false; 133 | 134 | 135 | if(loss > cut) 136 | { 137 | loss = std::sqrt(loss); 138 | cutLoss = true; 139 | } 140 | 141 | if(loss > MAXLOSS) 142 | return jacobean; 143 | 144 | if((tErr + rotErr) < EPS) 145 | return jacobean; 146 | 147 | 148 | // return gradient of translational error 149 | cv::Mat_ dDist_dInvT1(1, 3); 150 | for(unsigned i = 0; i < 3; i++) 151 | dDist_dInvT1(0, i) = (invT1(i, 0) - invT2(i, 0)) / tErr; 152 | 153 | cv::Mat_ dInvT1_dEstT(3, 3); 154 | dInvT1_dEstT = invRot1; 155 | 156 | cv::Mat_ dDist_dEstT = dDist_dInvT1 * dInvT1_dEstT; 157 | jacobean.colRange(3, 6) += dDist_dEstT * wTrans; 158 | 159 | cv::Mat_ dInvT1_dInvRot1 = cv::Mat_::zeros(3, 9); 160 | 161 | dInvT1_dInvRot1(0, 0) = est.second.at(0, 0); 162 | 
dInvT1_dInvRot1(0, 3) = est.second.at(1, 0); 163 | dInvT1_dInvRot1(0, 6) = est.second.at(2, 0); 164 | 165 | dInvT1_dInvRot1(1, 1) = est.second.at(0, 0); 166 | dInvT1_dInvRot1(1, 4) = est.second.at(1, 0); 167 | dInvT1_dInvRot1(1, 7) = est.second.at(2, 0); 168 | 169 | dInvT1_dInvRot1(2, 2) = est.second.at(0, 0); 170 | dInvT1_dInvRot1(2, 5) = est.second.at(1, 0); 171 | dInvT1_dInvRot1(2, 8) = est.second.at(2, 0); 172 | 173 | dRod = dRod.t(); 174 | 175 | cv::Mat_ dDist_dRod = dDist_dInvT1 * dInvT1_dInvRot1 * dRod; 176 | jacobean.colRange(0, 3) += dDist_dRod * wTrans; 177 | 178 | 179 | // return gradient of rotational error 180 | cv::Mat_ dRotDiff = cv::Mat_::zeros(9, 9); 181 | invRot2.row(0).copyTo(dRotDiff.row(0).colRange(0, 3)); 182 | invRot2.row(1).copyTo(dRotDiff.row(1).colRange(0, 3)); 183 | invRot2.row(2).copyTo(dRotDiff.row(2).colRange(0, 3)); 184 | 185 | invRot2.row(0).copyTo(dRotDiff.row(3).colRange(3, 6)); 186 | invRot2.row(1).copyTo(dRotDiff.row(4).colRange(3, 6)); 187 | invRot2.row(2).copyTo(dRotDiff.row(5).colRange(3, 6)); 188 | 189 | invRot2.row(0).copyTo(dRotDiff.row(6).colRange(6, 9)); 190 | invRot2.row(1).copyTo(dRotDiff.row(7).colRange(6, 9)); 191 | invRot2.row(2).copyTo(dRotDiff.row(8).colRange(6, 9)); 192 | 193 | dRotDiff = dRotDiff.t(); 194 | 195 | cv::Mat_ dTrace = cv::Mat_::zeros(1, 9); 196 | dTrace(0, 0) = 1; 197 | dTrace(0, 4) = 1; 198 | dTrace(0, 8) = 1; 199 | 200 | cv::Mat_ dAngle = (180 / CV_PI * -1 / sqrt(3 - trace * trace + 2 * trace)) * dTrace * dRotDiff * dRod; 201 | 202 | jacobean.colRange(0, 3) += dAngle * wRot; 203 | 204 | if(cutLoss) 205 | jacobean *= 0.5 / loss; 206 | 207 | 208 | if(cv::sum(cv::Mat(jacobean != jacobean))[0] > 0) //check for NaNs 209 | return cv::Mat_::zeros(1, 6); 210 | 211 | return jacobean; 212 | } 213 | 214 | 215 | } 216 | -------------------------------------------------------------------------------- /submodules/dsacstar/dsacstar_types.h: -------------------------------------------------------------------------------- 1 | /* 2 | Based on the DSAC++ and ESAC code. 3 | https://github.com/vislearn/LessMore 4 | https://github.com/vislearn/esac 5 | 6 | Copyright (c) 2016, TU Dresden 7 | Copyright (c) 2020, Heidelberg University 8 | All rights reserved. 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | * Redistributions of source code must retain the above copyright 13 | notice, this list of conditions and the following disclaimer. 14 | * Redistributions in binary form must reproduce the above copyright 15 | notice, this list of conditions and the following disclaimer in the 16 | documentation and/or other materials provided with the distribution. 17 | * Neither the name of the TU Dresden, Heidelberg University nor the 18 | names of its contributors may be used to endorse or promote products 19 | derived from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY 25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | */ 32 | 33 | #pragma once 34 | 35 | #include "opencv2/opencv.hpp" 36 | 37 | /** Several important types used troughout all this code. If types have to be changed, it can be done here, conveniently. */ 38 | 39 | namespace dsacstar 40 | { 41 | // scene pose type (OpenCV convention: axis-angle + translation) 42 | typedef std::pair pose_t; 43 | // camera transformation type (inverted scene pose as 4x4 matrix) 44 | typedef cv::Mat_ trans_t; 45 | // ATen accessor type 46 | typedef at::TensorAccessor coord_t; 47 | } 48 | -------------------------------------------------------------------------------- /submodules/dsacstar/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CppExtension, BuildExtension 3 | import os 4 | 5 | opencv_inc_dir = '' # directory containing OpenCV header files 6 | opencv_lib_dir = '' # directory containing OpenCV library files 7 | 8 | #if not explicitly provided, we try to locate OpenCV in the current Conda environment 9 | conda_env = os.environ['CONDA_PREFIX'] 10 | 11 | if len(conda_env) > 0 and len(opencv_inc_dir) == 0 and len(opencv_lib_dir) == 0: 12 | print("Detected active conda environment:", conda_env) 13 | 14 | opencv_inc_dir = conda_env + '/include/opencv4' 15 | opencv_lib_dir = conda_env + '/lib/opencv4' 16 | 17 | print("Assuming OpenCV dependencies in:") 18 | print(opencv_inc_dir) 19 | print(opencv_lib_dir) 20 | 21 | if len(opencv_inc_dir) == 0: 22 | print("Error: You have to provide an OpenCV include directory. Edit this file.") 23 | exit() 24 | if len(opencv_lib_dir) == 0: 25 | print("Error: You have to provide an OpenCV library directory. Edit this file.") 26 | exit() 27 | 28 | setup( 29 | name='dsacstar', 30 | ext_modules=[CppExtension( 31 | name='dsacstar', 32 | sources=['dsacstar.cpp','thread_rand.cpp'], 33 | include_dirs=[opencv_inc_dir], 34 | library_dirs=[opencv_lib_dir], 35 | libraries=['opencv_core','opencv_calib3d'], 36 | extra_compile_args=['-fopenmp'] 37 | )], 38 | cmdclass={'build_ext': BuildExtension}) 39 | -------------------------------------------------------------------------------- /submodules/dsacstar/stop_watch.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2016, TU Dresden 3 | Copyright (c) 2017, Heidelberg University 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 
13 | * Neither the name of the TU Dresden, Heidelberg University nor the 14 | names of its contributors may be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY 21 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | 29 | 30 | #pragma once 31 | 32 | #include 33 | 34 | /** 35 | * @brief Class for time measurements. 36 | */ 37 | class StopWatch 38 | { 39 | public: 40 | /** 41 | * @brief Construction. Initializes the stop watch. 42 | */ 43 | StopWatch(){ init(); } 44 | 45 | /** 46 | * @brief Initialization. Starts the time measurement. 47 | * 48 | * @return void 49 | */ 50 | void init() 51 | { 52 | start = std::chrono::high_resolution_clock::now(); 53 | } 54 | 55 | /** 56 | * @brief Stops and restarts the time measurement. 57 | * 58 | * @return float The time in ms since the last init or stop call. 59 | */ 60 | float stop() 61 | { 62 | std::chrono::high_resolution_clock::time_point now; 63 | now = std::chrono::high_resolution_clock::now(); 64 | 65 | std::chrono::high_resolution_clock::duration duration = now - start; 66 | 67 | start = now; 68 | 69 | return static_cast( 70 | 1000.0 * std::chrono::duration_cast>( 71 | duration).count()); 72 | } 73 | 74 | private: 75 | std::chrono::high_resolution_clock::time_point start; // start time of the current measurement. 
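// Usage sketch (editorial addition, not part of the upstream DSAC*/ACE sources):
//   StopWatch stopW;           // construction calls init() and starts the measurement
//   ... timed work, e.g. hypothesis sampling ...
//   float ms = stopW.stop();   // milliseconds since init()/the previous stop(); also restarts the watch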
76 | };
77 | 
--------------------------------------------------------------------------------
/submodules/dsacstar/thread_rand.cpp:
--------------------------------------------------------------------------------
1 | #include "thread_rand.h"
2 | #include <omp.h>
3 | 
4 | std::vector<std::mt19937> ThreadRand::generators;
5 | bool ThreadRand::initialised = false;
6 | 
7 | void ThreadRand::forceInit(unsigned seed)
8 | {
9 |     initialised = false;
10 |     init(seed);
11 | }
12 | 
13 | void ThreadRand::init(unsigned seed)
14 | {
15 |     #pragma omp critical
16 |     {
17 |         if(!initialised)
18 |         {
19 |             unsigned nThreads = omp_get_max_threads();
20 | 
21 |             for(unsigned i = 0; i < nThreads; i++)
22 |             {
23 |                 generators.push_back(std::mt19937());
24 |                 generators[i].seed(i+seed);
25 |             }
26 | 
27 |             initialised = true;
28 |         }
29 |     }
30 | }
31 | 
32 | int ThreadRand::irand(int min, int max, int tid)
33 | {
34 |     std::uniform_int_distribution<int> dist(min, max);
35 | 
36 |     unsigned threadID = omp_get_thread_num();
37 |     if(tid >= 0) threadID = tid;
38 | 
39 |     if(!initialised) init();
40 | 
41 |     return dist(ThreadRand::generators[threadID]);
42 | }
43 | 
44 | double ThreadRand::drand(double min, double max, int tid)
45 | {
46 |     std::uniform_real_distribution<double> dist(min, max);
47 | 
48 |     unsigned threadID = omp_get_thread_num();
49 |     if(tid >= 0) threadID = tid;
50 | 
51 |     if(!initialised) init();
52 | 
53 |     return dist(ThreadRand::generators[threadID]);
54 | }
55 | 
56 | double ThreadRand::dgauss(double mean, double stdDev, int tid)
57 | {
58 |     std::normal_distribution<double> dist(mean, stdDev);
59 | 
60 |     unsigned threadID = omp_get_thread_num();
61 |     if(tid >= 0) threadID = tid;
62 | 
63 |     if(!initialised) init();
64 | 
65 |     return dist(ThreadRand::generators[threadID]);
66 | }
67 | 
68 | int irand(int incMin, int excMax, int tid)
69 | {
70 |     return ThreadRand::irand(incMin, excMax - 1, tid);
71 | }
72 | 
73 | double drand(double incMin, double incMax, int tid)
74 | {
75 |     return ThreadRand::drand(incMin, incMax, tid);
76 | }
77 | 
78 | int igauss(int mean, int stdDev, int tid)
79 | {
80 |     return (int) ThreadRand::dgauss(mean, stdDev, tid);
81 | }
82 | 
83 | double dgauss(double mean, double stdDev, int tid)
84 | {
85 |     return ThreadRand::dgauss(mean, stdDev, tid);
86 | }
--------------------------------------------------------------------------------
/submodules/dsacstar/thread_rand.h:
--------------------------------------------------------------------------------
1 | /*
2 | Based on the DSAC++ and ESAC code.
3 | https://github.com/vislearn/LessMore
4 | https://github.com/vislearn/esac
5 | 
6 | Copyright (c) 2016, TU Dresden
7 | Copyright (c) 2019, Heidelberg University
8 | All rights reserved.
9 | 
10 | Redistribution and use in source and binary forms, with or without
11 | modification, are permitted provided that the following conditions are met:
12 |     * Redistributions of source code must retain the above copyright
13 |       notice, this list of conditions and the following disclaimer.
14 |     * Redistributions in binary form must reproduce the above copyright
15 |       notice, this list of conditions and the following disclaimer in the
16 |       documentation and/or other materials provided with the distribution.
17 |     * Neither the name of the TU Dresden, Heidelberg University nor the
18 |       names of its contributors may be used to endorse or promote products
19 |       derived from this software without specific prior written permission.
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY 25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | */ 32 | 33 | #pragma once 34 | 35 | #include 36 | 37 | /** Classes and methods for generating random numbers in multi-threaded programs. */ 38 | 39 | /** 40 | * @brief Provides random numbers for multiple threads. 41 | * 42 | * Singelton class. Holds a random number generator for each thread and gives random numbers for the current thread. 43 | */ 44 | class ThreadRand 45 | { 46 | public: 47 | /** 48 | * @brief Returns a random integer (uniform distribution). 49 | * 50 | * @param min Minimum value of the random integer (inclusive). 51 | * @param max Maximum value of the random integer (exclusive). 52 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 53 | * @return int Random integer value. 54 | */ 55 | static int irand(int min, int max, int tid = -1); 56 | 57 | /** 58 | * @brief Returns a random double value (uniform distribution). 59 | * 60 | * @param min Minimum value of the random double (inclusive). 61 | * @param max Maximum value of the random double (inclusive). 62 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 63 | * @return double Random double value. 64 | */ 65 | static double drand(double min, double max, int tid = -1); 66 | 67 | /** 68 | * @brief Returns a random double value (Gauss distribution). 69 | * 70 | * @param mean Mean of the Gauss distribution to sample from. 71 | * @param stdDev Standard deviation of the Gauss distribution to sample from. 72 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 73 | * @return double Random double value. 74 | */ 75 | static double dgauss(double mean, double stdDev, int tid = -1); 76 | 77 | /** 78 | * @brief Re-Initialize the object with the given seed. 79 | * 80 | * @param seed Seed to initialize the random number generators (seed is incremented by one for each generator). 81 | * @return void 82 | */ 83 | static void forceInit(unsigned seed); 84 | 85 | /** 86 | * @brief List of random number generators. One for each thread. 87 | * 88 | */ 89 | static std::vector generators; 90 | 91 | /** 92 | * @brief Initialize class with the given seed. 93 | * 94 | * Method will create a random number generator for each thread. The given seed 95 | * will be incremented by one for each generator. This methods is automatically 96 | * called when this calss is used the first time. 97 | * 98 | * @param seed Optional parameter. Seed to be used when initializing the generators. Will be incremented by one for each generator. 
99 | * @return void 100 | */ 101 | static void init(unsigned seed = 1305); 102 | 103 | private: 104 | /** 105 | * @brief True if the class has been initialized already 106 | */ 107 | static bool initialised; 108 | }; 109 | 110 | /** 111 | * @brief Returns a random integer (uniform distribution). 112 | * 113 | * This method used the ThreadRand class. 114 | * 115 | * @param min Minimum value of the random integer (inclusive). 116 | * @param max Maximum value of the random integer (exclusive). 117 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 118 | * @return int Random integer value. 119 | */ 120 | int irand(int incMin, int excMax, int tid = -1); 121 | /** 122 | * @brief Returns a random double value (uniform distribution). 123 | * 124 | * This method used the ThreadRand class. 125 | * 126 | * @param min Minimum value of the random double (inclusive). 127 | * @param max Maximum value of the random double (inclusive). 128 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 129 | * @return double Random double value. 130 | */ 131 | double drand(double incMin, double incMax, int tid = -1); 132 | 133 | /** 134 | * @brief Returns a random integer value (Gauss distribution). 135 | * 136 | * This method used the ThreadRand class. 137 | * 138 | * @param mean Mean of the Gauss distribution to sample from. 139 | * @param stdDev Standard deviation of the Gauss distribution to sample from. 140 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 141 | * @return double Random integer value. 142 | */ 143 | int igauss(int mean, int stdDev, int tid = -1); 144 | 145 | /** 146 | * @brief Returns a random double value (Gauss distribution). 147 | * 148 | * This method used the ThreadRand class. 149 | * 150 | * @param mean Mean of the Gauss distribution to sample from. 151 | * @param stdDev Standard deviation of the Gauss distribution to sample from. 152 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 153 | * @return double Random double value. 
154 | */ 155 | double dgauss(double mean, double stdDev, int tid = -1); 156 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=[E1101,W0621,E0401] 2 | 3 | import copy 4 | import os 5 | import warnings 6 | import logging 7 | 8 | import omegaconf 9 | from omegaconf import OmegaConf 10 | 11 | from conerf.utils.config import config_parser, load_config 12 | from conerf.utils.utils import setup_seed 13 | from utils import create_trainer # pylint: disable=E0611 14 | 15 | warnings.filterwarnings("ignore", category=UserWarning) 16 | 17 | 18 | def run_cmd(cmd: str): 19 | os.system(cmd) 20 | 21 | return True 22 | 23 | 24 | def train(config: OmegaConf): 25 | trainer = create_trainer(config) 26 | trainer.update_meta_data() 27 | trainer.train() 28 | print(f"total iteration: {trainer.iteration}") 29 | 30 | 31 | if __name__ == "__main__": 32 | args = config_parser() 33 | 34 | logging.basicConfig( 35 | format='%(asctime)s %(levelname)-6s [%(filename)s:%(lineno)d] %(message)s', 36 | datefmt='%Y-%m-%d:%H:%M:%S', 37 | level=logging.INFO 38 | ) 39 | 40 | # parse YAML config to OmegaConf 41 | config = load_config(args.config) 42 | config["config_file_path"] = args.config 43 | 44 | assert config.dataset.scene != "" 45 | 46 | setup_seed(config.seed) 47 | 48 | scenes = [] 49 | if ( 50 | type(config.dataset.scene) == omegaconf.listconfig.ListConfig # pylint: disable=C0123 51 | ): 52 | scene_list = list(config.dataset.scene) 53 | for sc in config.dataset.scene: 54 | scenes.append(sc) 55 | else: 56 | scenes.append(config.dataset.scene) 57 | 58 | for scene in scenes: 59 | data_dir = os.path.join(config.dataset.root_dir, scene) 60 | assert os.path.exists(data_dir), f"Dataset does not exist: {data_dir}!" 61 | 62 | local_config = copy.deepcopy(config) 63 | local_config.expname = ( 64 | f"{config.neural_field_type}_{config.task}_{config.dataset.name}_{scene}" 65 | ) 66 | local_config.expname = local_config.expname + "_" + args.suffix 67 | local_config.dataset.scene = scene 68 | local_config.dataset.model_folder = args.model_folder 69 | local_config.dataset.init_ply_type = args.init_ply_type 70 | local_config.dataset.load_specified_images = args.load_specified_images 71 | 72 | train(local_config) 73 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | from conerf.base.model_base import ModelBase 4 | from conerf.trainers.ace_zero_trainer import AceZeroTrainer 5 | 6 | 7 | def create_trainer( 8 | config: OmegaConf, 9 | prefetch_dataset=True, 10 | trainset=None, 11 | valset=None, 12 | model: ModelBase = None 13 | ): 14 | """Factory function for training neural network trainers.""" 15 | if config.task == "pose": 16 | trainer = AceZeroTrainer(config, prefetch_dataset, trainset, valset) 17 | else: 18 | raise NotImplementedError 19 | 20 | return trainer 21 | --------------------------------------------------------------------------------
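The `create_trainer` factory above is what `train.py` invokes for every scene. Below is a minimal, hypothetical sketch (not part of the repository) of driving it directly; it assumes an ACE0 config such as `config/ace/mipnerf360.yaml` with `task: pose`, and a scene name (`garden` here) that exists under `dataset.root_dir`.

```python
# Hypothetical usage sketch of the create_trainer() factory; mirrors train.py.
from conerf.utils.config import load_config
from conerf.utils.utils import setup_seed
from utils import create_trainer

config = load_config("config/ace/mipnerf360.yaml")        # assumed config path
config["config_file_path"] = "config/ace/mipnerf360.yaml"
config.dataset.scene = "garden"                           # hypothetical scene name
setup_seed(config.seed)

# train.py additionally fills in expname, dataset.model_folder, etc. before this call.
trainer = create_trainer(config)   # config.task == "pose" -> AceZeroTrainer
trainer.update_meta_data()
trainer.train()
```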