├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── imgs
│       └── zero_gs_logo.png
├── conerf
│   ├── __init__.py
│   ├── base
│   │   ├── checkpoint_manager.py
│   │   └── model_base.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── ace_camera_loc_dataset.py
│   │   ├── dataset_base.py
│   │   ├── load_colmap.py
│   │   └── utils.py
│   ├── evaluators
│   │   ├── ace_zero_evaluator.py
│   │   └── evaluator.py
│   ├── geometry
│   │   ├── align_poses.py
│   │   ├── camera.py
│   │   ├── pose_util.py
│   │   └── utils.py
│   ├── loss
│   │   └── ssim_torch.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── activations.py
│   │   │   ├── encodings.py
│   │   │   ├── feature_pyramid_net.py
│   │   │   ├── mlp.py
│   │   │   └── resnet3d.py
│   │   ├── misc.py
│   │   └── scene_regressor
│   │       ├── ace_encoder_pretrained.pt
│   │       ├── ace_loss.py
│   │       ├── ace_network.py
│   │       ├── ace_util.py
│   │       ├── calibr.py
│   │       ├── depth_network.py
│   │       └── pose_refine_network.py
│   ├── pycolmap
│   │   ├── pycolmap
│   │   │   ├── __init__.py
│   │   │   ├── camera.py
│   │   │   ├── database.py
│   │   │   ├── image.py
│   │   │   ├── rotation.py
│   │   │   └── scene_manager.py
│   │   └── tools
│   │       ├── colmap_to_nvm.py
│   │       ├── delete_images.py
│   │       ├── impute_missing_cameras.py
│   │       ├── save_cameras_as_ply.py
│   │       ├── transform_model.py
│   │       ├── write_camera_track_to_bundler.py
│   │       └── write_depthmap_to_ply.py
│   ├── trainers
│   │   ├── ace_zero_trainer.py
│   │   └── trainer.py
│   ├── utils
│   │   ├── config.py
│   │   └── utils.py
│   └── visualization
│       ├── feature_visualizer.py
│       ├── pose_visualizer.py
│       └── scene_visualizer.py
├── config
│   └── ace
│       ├── llff.yaml
│       ├── mipnerf360.yaml
│       └── tanks_and_temples.yaml
├── eval.py
├── scripts
│   ├── env
│   │   └── install.sh
│   ├── eval
│   │   ├── eval_ace_zero.sh
│   │   └── vis_recon.py
│   ├── preprocess
│   │   ├── colmap_mapping.sh
│   │   ├── database.py
│   │   ├── hloc_mapping
│   │   │   ├── extract_features.py
│   │   │   ├── extract_relative_poses.py
│   │   │   ├── filter_matches.py
│   │   │   ├── match_features.py
│   │   │   ├── pairs_from_retrieval.py
│   │   │   ├── reconstruction.py
│   │   │   ├── sfm_pipeline.py
│   │   │   ├── triangulate_from_existing_model.py
│   │   │   └── utils.py
│   │   ├── mapping.py
│   │   ├── read_write_model.py
│   │   ├── triangulate.sh
│   │   └── utils.py
│   └── train
│       └── train_ace_zero.sh
├── submodules
│   └── dsacstar
│       ├── dsacstar.cpp
│       ├── dsacstar_derivative.h
│       ├── dsacstar_loss.h
│       ├── dsacstar_types.h
│       ├── dsacstar_util.h
│       ├── dsacstar_util_rgbd.h
│       ├── setup.py
│       ├── stop_watch.h
│       ├── thread_rand.cpp
│       └── thread_rand.h
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | submodules/dsacstar/build
2 | submodules/dsacstar/dist
3 | submodules/dsacstar/dsacstar.egg-info
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The following files are under the license of ACE. Copyright © Niantic, Inc. 2022. Patent Pending:
2 |
3 | - Main ACE Files:
4 | - conerf/model/scene_regressor/ace_encoder_pretrained.pt
5 | - conerf/model/scene_regressor/ace_loss.py
6 | - conerf/model/scene_regressor/ace_network.py
7 | - conerf/model/scene_regressor/ace_util.py
8 | - conerf/trainers/ace_zero_trainer.py
9 |
10 | ------------------------------------------------------------------------------
11 |
12 | The rest of the files are under the MIT license:
13 |
14 | Copyright (c) 2024, Chen Yu
15 | All rights reserved.
16 |
17 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
18 | associated documentation files (the “Software”), to deal in the Software without restriction,
19 | including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
20 | and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
21 | subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all copies or substantial
24 | portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
27 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
28 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
30 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 |
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ZeroGS: Training 3D Gaussian Splatting from Unposed Images
2 |
3 | [[Project Page](https://aibluefisher.github.io/ZeroGS/) | [arXiv](https://arxiv.org/pdf/2411.15779)]
4 |
5 | ---------------------------
6 |
7 | ## 🛠️ Installation
8 |
9 | Create the conda environment for ZeroGS and install the dependencies:
10 |
11 | ```sh
12 | conda create -n zero_gs python=3.9
13 | conda activate zero_gs
14 | cd ZeroGS/scripts
15 | ./env/install.sh
16 | ```
17 |
18 | **git hook for code style checking**:
19 | ```sh
20 | pre-commit install --hook-type pre-commit
21 | ```
22 |
23 |
24 | ## 🚀 Features
25 |
26 | - [x] Release [ACE0](https://nianticlabs.github.io/acezero) implementation
27 | - [ ] Incorporate [GLACE](https://github.com/cvg/glace) into ACE0
28 | - [ ] Release our customized 3D Gaussian Splatting module
29 | - [ ] Incorporate [Scaffold-GS](https://city-super.github.io/scaffold-gs)
30 | - [ ] Incorporate [DOGS](https://github.com/aibluefisher/dogs)
31 | - [ ] Release ZeroGS implementation
32 |
33 |
34 | ## 📋 Train & Eval ACE0
35 |
36 | We aim to provide a framework that makes it easy to implement your own neural implicit module on top of this codebase. Since this project started before the official ACE0 code release, we re-implemented ACE0 within our codebase (a minimal sketch of the extension interface is shown below).
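
A minimal sketch of such a custom module, assuming only the `ModelBase` interface defined in `conerf/base/model_base.py`; the network and data layout below are purely illustrative:

```python
import torch

from conerf.base.model_base import ModelBase


class MyImplicitModel(ModelBase):
    """Toy neural implicit module built on the ModelBase interface."""

    def __init__(self, feature_dim: int = 128, **kwargs) -> None:
        super().__init__(**kwargs)
        # Placeholder network; a real module would define its own field / scene regressor here.
        self.net = torch.nn.Linear(3, feature_dim)

    def switch_to_eval(self):
        self.eval()

    def switch_to_train(self):
        self.train()

    def forward(self, data, **kwargs):
        # Assumes `data` carries 3D points under the key "points" (hypothetical layout).
        return self.net(data["points"])
```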
37 |
38 | ### ⌛Train ACE0
39 |
40 | Before training ACE0, please download the [pretrained feature encoder](https://github.com/nianticlabs/ace/blob/main/ace_encoder_pretrained.pt) from ACE, and put it under the folder `ZeroGS/conerf/model/scene_regressor`.
41 |
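Optionally, you can sanity-check the downloaded weights before training. This is an illustrative snippet (not part of the original instructions); the path assumes you run it from the `ZeroGS` repository root:

```python
import torch

# The pretrained ACE encoder should load as a plain PyTorch checkpoint (typically a state dict).
encoder_path = "conerf/model/scene_regressor/ace_encoder_pretrained.pt"
state = torch.load(encoder_path, map_location="cpu")
print(type(state))
```
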
42 | ```bash
43 | conda activate zero_gs
44 | visdom -port=9000 # Keep the port the same as the `visdom_port` provided in the configuration file
45 | cd ZeroGS/scripts/train
46 | ./train_ace_zero.sh 0 ace_early_stop_resize_2k_anneal mipnerf360 ace
47 | ```
48 | We use `visdom` to visualize the camera pose predictions during training. You can open `http://localhost:9000` in a browser to view them.
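
For reference, a minimal `visdom` client call looks like the sketch below. It is purely illustrative and is not the trainer's own plotting code (which lives in `conerf/visualization/pose_visualizer.py`); the camera centers here are random stand-ins:

```python
import numpy as np
import visdom

# Connect to the visdom server started above; the port must match `visdom_port` in the config.
vis = visdom.Visdom(server="http://localhost", port=9000)

# Stand-in for predicted camera centers with shape (N, 3).
camera_centers = np.random.rand(100, 3)
vis.scatter(X=camera_centers, opts=dict(title="Predicted camera centers", markersize=3))
```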
49 |
50 | ### 📊 Evaluate ACE0
51 |
52 | ```bash
53 | conda activate zero_gs
54 | cd ZeroGS/scripts/eval
55 | ./eval_ace_zero.sh 0 ace_early_stop_resize_2k_anneal mipnerf360 ace
56 | ```
57 | The metrics file and camera poses are written to the `eval/val/` folder. Point clouds are saved in `.ply` format under `eval/val/ACE0_COLMAP`, which also contains the model files in COLMAP format.
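
To take a quick look at the exported geometry, the `.ply` files can be opened with any point-cloud viewer (see also `scripts/eval/vis_recon.py` in this repository). Below is a minimal sketch using Open3D, which is not a dependency of this repository; the file name is only illustrative and depends on the evaluated scene:

```python
import open3d as o3d  # assumption: Open3D is installed separately

# Hypothetical output file under the COLMAP-format export folder.
pcd = o3d.io.read_point_cloud("eval/val/ACE0_COLMAP/points3D.ply")
print(pcd)  # e.g. "PointCloud with N points."
o3d.visualization.draw_geometries([pcd])
```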
58 |
59 | ### 🔢 Hyper Parameters for training ACE0
60 |
61 | All the parameters related to training ACE0 are provided in the configuration file `config/ace/mipnerf360.yaml`. Most of them can be kept as they are; however, the parameters listed below should be adjusted per scene to obtain better performance:
62 | ```yaml
63 | trainer:
64 |   # We can use fewer iterations for the `garden` scene (e.g. 2000).
65 |   min_iterations_per_epoch: 5000
66 |
67 | pose_estimator:
68 |   # Change this to a larger threshold (3000) for the `garden` scene of the mipnerf360 dataset.
69 |   min_inlier_count: 2000  # minimum number of inlier correspondences when registering an image
70 | ```
71 |
72 | A larger value of `min_iterations_per_epoch` can make the mapping more accurate, but it also leads to longer training times.
73 |
74 |
75 | ## ✏️ Cite
76 |
77 | If you find this project useful for your research, please consider citing our paper:
78 | ```bibtex
79 | @inproceedings{yuchen2024zerogs,
80 | title={ZeroGS: Training 3D Gaussian Splatting from Unposed Images},
81 | author={Yu Chen and Rolandos Alexandros Potamias and Evangelos Ververas and Jifei Song and Jiankang Deng and Gim Hee Lee},
82 | booktitle={arXiv},
83 | year={2024},
84 | }
85 | ```
86 |
87 | ## 🙌 Acknowledgements
88 |
89 | This work is built upon [ACE](https://nianticlabs.github.io/ace/), [DUSt3R](https://github.com/naver/dust3r), and [Spann3R](https://hengyiwang.github.io/projects/spanner). We sincerely thank all the authors for releasing their code.
90 |
91 | ## 🪪 License
92 |
93 | Copyright © 2024, Chen Yu.
94 | All rights reserved.
95 | Please see the [license file](LICENSE) for terms.
96 |
--------------------------------------------------------------------------------
/assets/imgs/zero_gs_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/assets/imgs/zero_gs_logo.png
--------------------------------------------------------------------------------
/conerf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/__init__.py
--------------------------------------------------------------------------------
/conerf/base/model_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ModelBase(torch.nn.Module):
5 | """
6 | An abstract class which defines some basic operations for a torch model.
7 | """
8 | def __init__(self, **kwargs) -> None:
9 | super().__init__()
10 |
11 | def to_distributed(self):
12 | """Change model to distributed mode."""
13 | raise NotImplementedError
14 |
15 | def switch_to_eval(self):
16 | """Change model to evaluation mode."""
17 | raise NotImplementedError
18 |
19 | def switch_to_train(self):
20 | """Change model to training mode."""
21 | raise NotImplementedError
22 |
23 | def forward(self, data, **kwargs):
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/conerf/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/datasets/__init__.py
--------------------------------------------------------------------------------
/conerf/datasets/ace_camera_loc_dataset.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 |
3 | import logging
4 | import random
5 | import math
6 |
7 | import imageio
8 | import torch
9 | import torchvision.transforms.functional as TF
10 | from skimage import color
11 | from skimage import io
12 | from skimage.transform import rotate
13 | from torch.utils.data import Dataset
14 | from torch.utils.data.dataloader import default_collate
15 | from torchvision import transforms
16 |
17 | from conerf.datasets.load_colmap import load_colmap
18 |
19 | _logger = logging.getLogger(__name__)
20 |
21 |
22 | class CamLocDataset(Dataset):
23 | """Camera localization dataset.
24 |
25 | Access to image, calibration and ground truth data given a dataset directory.
26 | """
27 |
28 | def __init__(
29 | self,
30 | # root_dir: str,
31 | root_fp: str,
32 | subject_id: str,
33 | val_interval: int = 0,
34 | scale: bool = True,
35 | rotate: bool = True,
36 | augment: bool = False,
37 | aug_rotation: int = 15,
38 | aug_scale_min: float = 2 / 3,
39 | aug_scale_max: float = 3 / 2,
40 | aug_black_white: float = 0.1,
41 | aug_color: float = 0.3,
42 | factor: int = 8,
43 | use_half: bool = True,
44 | ):
45 | """
46 | Params:
47 | root_dir: Folder of the data (training or test).
48 | augment: Use random data augmentation, note: not supported for mode=2 (RGB-D) since
49 | pre-generated eye coordinates cannot be augmented
50 | aug_rotation: Max 2D image rotation angle, sampled uniformly around 0, both
51 | directions, degrees
52 | aug_scale_min: Lower limit of image scale factor for uniform sampling
53 | aug_scale_max: Upper limit of image scale factor for uniform sampling
54 | aug_black_white: Max relative scale factor for image brightness/contrast sampling,
55 | e.g. 0.1 -> [0.9,1.1]
56 | aug_color: Max relative scale factor for image saturation/hue sampling, e.g.
57 | 0.1 -> [0.9,1.1]
58 | image_height: RGB images are rescaled to this maximum height (if augmentation is
59 | disabled, and in the range [aug_scale_min * image_height, aug_scale_max *
60 | image_height] otherwise).
61 | use_half: Enabled if training with half-precision floats.
62 | """
63 |
64 | self.use_half = use_half
65 | self.factor = factor
66 | self.augment = augment
67 | self.aug_rotation = aug_rotation
68 | self.aug_scale_min = aug_scale_min
69 | self.aug_scale_max = aug_scale_max
70 | self.aug_black_white = aug_black_white
71 | self.aug_color = aug_color
72 |
73 | data = load_colmap(
74 | root_fp, subject_id, split='train', factor=factor,
75 | val_interval=val_interval, scale=scale, rotate=rotate,
76 | )
77 | self.rgb_files = data['image_paths']
78 | self.gt_camtoworlds = data['poses']
79 |
80 | # We use this to iterate over all frames.
81 | self.valid_file_indices = {i: i for i in range(len(self.rgb_files))}
82 |
83 | # Try to read an image and get its width and height.
84 | image = imageio.imread(self.rgb_files[0]) # [H,W,3]
85 | # Use a fixed 480px image height since the convolutional feature backbone
86 | # is pretrained to ingest images scaled to 480px.
87 | self.origin_image_height, self.origin_image_width = image.shape[:2]
88 | # self.image_height = image.shape[0]
89 | # self.image_width = image.shape[1]
90 | self.image_height = 480
91 |
92 | # Image transformations. Excluding scale since that can vary batch-by-batch.
93 | if self.augment:
94 | self.image_transform = transforms.Compose([
95 | transforms.Grayscale(),
96 | transforms.ColorJitter(
97 | brightness=self.aug_black_white, contrast=self.aug_black_white),
98 | transforms.ToTensor(),
99 | transforms.Normalize(mean=[0.4], std=[0.25]),
100 | ])
101 | else:
102 | self.image_transform = transforms.Compose([
103 | transforms.Grayscale(),
104 | transforms.ToTensor(),
105 | transforms.Normalize(mean=[0.4], std=[0.25]),
106 | ])
107 |
108 | def image(self, idx):
109 | idx = self.valid_file_indices[idx]
110 | return self._load_image(idx)
111 |
112 | def image_tensor(self, idx):
113 | return torch.from_numpy(self.image(idx))
114 |
115 | def resized_image(self, idx, image_height: int, image_width: int = None):
116 | image = self.image(idx)
117 | return self._resize_image(image, image_height, image_width)
118 |
119 | def resized_grayscale_image(self, idx, image_height: int):
120 | color_image_pil = self.resized_image(idx, image_height)
121 | return color_image_pil, self.image_transform(color_image_pil)
122 |
123 | @staticmethod
124 | def _resize_image(image, image_height: int, image_width: int = None):
125 | # Resize a numpy image as PIL. Works slightly better than resizing the tensor
126 | # using torch's internal function.
127 | image = TF.to_pil_image(image)
128 | image = TF.resize(image, image_height) if image_width is None else \
129 | TF.resize(image, [image_height, image_width])
130 | return image
131 |
132 | @staticmethod
133 | def _rotate_image(image, angle, order, mode='constant'):
134 | # Image is a torch tensor (CxHxW), convert it to numpy as HxWxC.
135 | image = image.permute(1, 2, 0).numpy()
136 | # Apply rotation.
137 | image = rotate(image, angle, order=order, mode=mode)
138 | # Back to torch tensor.
139 | image = torch.from_numpy(image).permute(2, 0, 1).float()
140 | return image
141 |
142 | def _load_image(self, idx):
143 | image = io.imread(self.rgb_files[idx])
144 |
145 | if len(image.shape) < 3:
146 | # Convert to RGB if needed.
147 | image = color.gray2rgb(image)
148 |
149 | return image
150 |
151 | def _get_single_item(self, idx, image_height):
152 | # Apply index indirection.
153 | idx = self.valid_file_indices[idx]
154 |
155 | # Load image.
156 | image = self._load_image(idx)
157 |
158 | # Rescale image.
159 | image = self._resize_image(image, image_height)
160 |
161 | # Create mask of the same size as the resized image (it's a PIL image at this point).
162 | image_mask = torch.ones((1, image.size[1], image.size[0]))
163 |
164 | # Apply remaining transforms.
165 | image = self.image_transform(image)
166 |
167 | pose_rot = torch.eye(4)
168 |
169 | # Apply data augmentation if necessary.
170 | if self.augment:
171 | # Generate a random rotation angle.
172 | angle = random.uniform(-self.aug_rotation, self.aug_rotation)
173 |
174 | # Rotate input image and mask.
175 | image = self._rotate_image(image, angle, 1, 'reflect')
176 | image_mask = self._rotate_image(image_mask, angle, order=1, mode='constant')
177 |
178 | # Provide the rotation as well.
179 | # - pose = pose @ pose_rot
180 | angle = angle * math.pi / 180.
181 | pose_rot[0, 0] = math.cos(angle)
182 | pose_rot[0, 1] = -math.sin(angle)
183 | pose_rot[1, 0] = math.sin(angle)
184 | pose_rot[1, 1] = math.cos(angle)
185 |
186 | # Convert to half precision if needed.
187 | if self.use_half and torch.cuda.is_available():
188 | image = image.half()
189 |
190 | # Binarize the mask.
191 | image_mask = image_mask > 0
192 |
193 | # TODO(chenyu): shall we return the augmentation status for later 3D Gaussian Splatting?
194 |
195 | return image, image_mask, pose_rot, idx, str(self.rgb_files[idx])
196 |
197 | def __len__(self):
198 | return len(self.valid_file_indices)
199 |
200 | def __getitem__(self, idx):
201 | if self.augment:
202 | scale_factor = random.uniform(self.aug_scale_min, self.aug_scale_max)
203 | else:
204 | scale_factor = 1
205 |
206 | # Target image height. We compute it here in case we are asked for a full batch of tensors
207 | # because we need to apply the same scale factor to all of them.
208 | image_height = int(self.image_height * scale_factor)
209 |
210 | if isinstance(idx, list):
211 | # Whole batch.
212 | tensors = [self._get_single_item(i, image_height) for i in idx]
213 | return default_collate(tensors)
214 | else:
215 | # Single element.
216 | return self._get_single_item(idx, image_height)
217 |
--------------------------------------------------------------------------------
/conerf/geometry/align_poses.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 |
3 | from easydict import EasyDict as edict
4 |
5 | import torch
6 | import numpy as np
7 |
8 |
9 | def convert3x4_4x4(input):
10 | """
11 | Make into homogeneous coordinates by adding [0, 0, 0, 1] to the bottom.
12 | :param input: (N, 3, 4) or (3, 4) torch or np
13 | :return: (N, 4, 4) or (4, 4) torch or np
14 | """
15 | if torch.is_tensor(input):
16 | if len(input.shape) == 3:
17 | output = torch.cat([input, torch.zeros_like(
18 | input[:, 0:1])], dim=1) # (N, 4, 4)
19 | output[:, 3, 3] = 1.0
20 | else:
21 | output = torch.cat([input, torch.tensor(
22 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
23 | else:
24 | if len(input.shape) == 3:
25 | output = np.concatenate(
26 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
27 | output[:, 3, 3] = 1.0
28 | else:
29 | output = np.concatenate(
30 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4)
31 | output[3, 3] = 1.0
32 | return output
33 |
34 |
35 | def procrustes_analysis(X0, X1): # [N,3]
36 | # translation
37 | t0 = X0.mean(dim=0, keepdim=True)
38 | t1 = X1.mean(dim=0, keepdim=True)
39 | X0c = X0 - t0
40 | X1c = X1 - t1
41 |
42 | # scale
43 | s0 = (X0c ** 2).sum(dim=-1).mean().sqrt()
44 | s1 = (X1c ** 2).sum(dim=-1).mean().sqrt()
45 | X0cs = X0c / s0
46 | X1cs = X1c / s1
47 |
48 | # rotation (use double for SVD, float loses precision)
49 | U, S, V = (X0cs.t() @ X1cs).double().svd(some=True)
50 | R = (U @ V.t()).float()
51 | if R.det() < 0:
52 | R[2] *= -1
53 |
54 | # Align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
55 | sim3 = edict(t0=t0[0], t1=t1[0], s0=s0, s1=s1, R=R)
56 | return sim3
57 |
58 |
59 | def get_best_yaw(C):
60 | '''
61 | maximize trace(Rz(theta) * C)
62 | '''
63 | assert C.shape == (3, 3)
64 |
65 | A = C[0, 1] - C[1, 0]
66 | B = C[0, 0] + C[1, 1]
67 | theta = np.pi / 2 - np.arctan2(B, A)
68 |
69 | return theta
70 |
71 |
72 | def align_umeyama(model, data, known_scale=False):
73 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation
74 | of Transformation Parameters Between Two Point Patterns,
75 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991.
76 |
77 | model = s * R * data + t
78 |
79 | Input:
80 | model -- first trajectory (nx3), numpy array type
81 | data -- second trajectory (nx3), numpy array type
82 |
83 | Output:
84 | s -- scale factor (scalar)
85 | R -- rotation matrix (3x3)
86 | t -- translation vector (3x1)
87 | t_error -- translational error per point (1xn)
88 |
89 | """
90 |
91 | # subtract mean
92 | mu_M = model.mean(0)
93 | mu_D = data.mean(0)
94 | model_zerocentered = model - mu_M
95 | data_zerocentered = data - mu_D
96 | n = np.shape(model)[0]
97 |
98 | # correlation
99 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered)
100 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum()
101 | U_svd, D_svd, V_svd = np.linalg.svd(C)
102 |
103 | D_svd = np.diag(D_svd)
104 | V_svd = np.transpose(V_svd)
105 |
106 | S = np.eye(3)
107 | if (np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0):
108 | S[2, 2] = -1
109 |
110 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd)))
111 |
112 | if known_scale:
113 | s = 1
114 | else:
115 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S))
116 |
117 | t = mu_M-s*np.dot(R, mu_D)
118 |
119 | return s, R, t
120 |
121 |
122 | def _getIndices(n_aligned, total_n):
123 | if n_aligned == -1:
124 | idxs = np.arange(0, total_n)
125 | else:
126 | assert n_aligned <= total_n and n_aligned >= 1
127 | idxs = np.arange(0, n_aligned)
128 | return idxs
129 |
130 |
131 | # align by similarity transformation
132 | def align_sim3(p_es, p_gt, n_aligned=-1):
133 | '''
134 | calculate s, R, t so that:
135 | gt = R * s * est + t
136 | '''
137 | idxs = _getIndices(n_aligned, p_es.shape[0])
138 | est_pos = p_es[idxs, 0:3]
139 | gt_pos = p_gt[idxs, 0:3]
140 | try:
141 | s, R, t = align_umeyama(gt_pos, est_pos) # note the order
142 | except: # pylint: disable=W0702
143 | print('[WARNING] align_poses.py: SVD did not converge!')
144 | s, R, t = 1.0, np.eye(3), np.zeros(3)
145 | return s, R, t
146 |
147 |
148 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None):
149 | """Align c to b using the sim3 from a to b.
150 | :param traj_a: (N0, 3/4, 4) torch tensor
151 | :param traj_b: (N0, 3/4, 4) torch tensor
152 | :param traj_c: None or (N1, 3/4, 4) torch tensor
153 | :return: (N1, 4, 4) torch tensor
154 | """
155 | device = traj_a.device
156 | if traj_c is None:
157 | traj_c = traj_a.clone()
158 |
159 | traj_a = traj_a.float().cpu().numpy()
160 | traj_b = traj_b.float().cpu().numpy()
161 | traj_c = traj_c.float().cpu().numpy()
162 |
163 | # R_a = traj_a[:, :3, :3] # (N0, 3, 3)
164 | t_a = traj_a[:, :3, 3] # (N0, 3)
165 |
166 | # R_b = traj_b[:, :3, :3] # (N0, 3, 3)
167 | t_b = traj_b[:, :3, 3] # (N0, 3)
168 |
169 | # Estimate a similarity transform from the camera centers.
170 | # Returns scalar s, (3, 3) R, (3,) t such that gt = R * s * est + t.
171 | s, R, t = align_sim3(t_a, t_b)
172 |
173 | # reshape tensors
174 | R = R[None, :, :].astype(np.float32) # (1, 3, 3)
175 | t = t[None, :, None].astype(np.float32) # (1, 3, 1)
176 | s = float(s)
177 |
178 | R_c = traj_c[:, :3, :3] # (N1, 3, 3)
179 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1)
180 |
181 | R_c_aligned = R @ R_c # (N1, 3, 3)
182 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1)
183 | traj_c_aligned = np.concatenate(
184 | [R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4)
185 |
186 | # append the last row
187 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4)
188 |
189 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device)
190 |
191 | return traj_c_aligned, s, R, t # (N1, 4, 4)
192 |
--------------------------------------------------------------------------------
/conerf/geometry/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def todevice(batch, device, callback=None, non_blocking=False):
6 | ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
7 |
8 | batch: list, tuple, dict of tensors or other things
9 | device: pytorch device or 'numpy'
10 | callback: function that would be called on every sub-elements.
11 | '''
12 | if callback:
13 | batch = callback(batch)
14 |
15 | if isinstance(batch, dict):
16 | return {k: todevice(v, device) for k, v in batch.items()}
17 |
18 | if isinstance(batch, (tuple, list)):
19 | return type(batch)(todevice(x, device) for x in batch)
20 |
21 | x = batch
22 | if device == 'numpy':
23 | if isinstance(x, torch.Tensor):
24 | x = x.detach().cpu().numpy()
25 | elif x is not None:
26 | if isinstance(x, np.ndarray):
27 | x = torch.from_numpy(x)
28 | if torch.is_tensor(x):
29 | x = x.to(device, non_blocking=non_blocking)
30 | return x
31 |
32 |
33 | def to_numpy(x):
34 | return todevice(x, 'numpy')
35 |
36 |
37 | def to_cpu(x):
38 | return todevice(x, 'cpu')
39 |
40 |
41 | def to_cuda(x):
42 | return todevice(x, 'cuda')
43 |
--------------------------------------------------------------------------------
/conerf/loss/ssim_torch.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1102,C0103,R0903]
2 |
3 | from math import exp
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from torch.autograd import Variable
9 |
10 |
11 | def gaussian(window_size, sigma):
12 | gauss = torch.Tensor([
13 | exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \
14 | for x in range(window_size)
15 | ])
16 | return gauss / gauss.sum()
17 |
18 |
19 | def create_window(window_size, channel):
20 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
21 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
22 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
23 | return window
24 |
25 |
26 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
27 | mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
28 | mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)
29 |
30 | mu1_sq = mu1.pow(2)
31 | mu2_sq = mu2.pow(2)
32 | mu1_mu2 = mu1*mu2
33 |
34 | sigma1_sq = F.conv2d(
35 | img1 * img1, window, padding=window_size//2, groups=channel
36 | ) - mu1_sq
37 | sigma2_sq = F.conv2d(
38 | img2 * img2, window, padding=window_size//2, groups=channel
39 | ) - mu2_sq
40 | sigma12 = F.conv2d(
41 | img1 * img2, window, padding=window_size//2, groups=channel
42 | ) - mu1_mu2
43 |
44 | C1 = 0.01 ** 2
45 | C2 = 0.03 ** 2
46 |
47 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((
48 | mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
49 |
50 | if size_average:
51 | return ssim_map.mean()
52 | return ssim_map.mean(1).mean(1).mean(1)
53 |
54 |
55 | class SSIM(torch.nn.Module):
56 | def __init__(self, window_size=11, size_average = True):
57 | super().__init__()
58 | self.window_size = window_size
59 | self.size_average = size_average
60 | self.channel = 1
61 | self.window = create_window(window_size, self.channel)
62 |
63 | def forward(self, img1, img2):
64 | channel = img1.size(-3)
65 |
66 | if channel == self.channel and self.window.data.type() == img1.data.type():
67 | window = self.window
68 | else:
69 | window = create_window(self.window_size, channel)
70 |
71 | if img1.is_cuda:
72 | window = window.cuda(img1.get_device())
73 | window = window.type_as(img1)
74 |
75 | self.window = window
76 | self.channel = channel
77 |
78 |
79 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
80 |
81 |
82 | def ssim(img1, img2, window_size = 11, size_average = True):
83 | channel = img1.size(-3)
84 | window = create_window(window_size, channel)
85 |
86 | if img1.is_cuda:
87 | window = window.cuda(img1.get_device())
88 | window = window.type_as(img1)
89 |
90 | return _ssim(img1, img2, window, window_size, channel, size_average)
91 |
--------------------------------------------------------------------------------
/conerf/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/model/__init__.py
--------------------------------------------------------------------------------
/conerf/model/backbone/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Gaussian(nn.Module):
6 | """
7 | Gaussian activation function.
8 | """
9 | def __init__(
10 | self,
11 | mean: float = 0.0,
12 | sigma: float = 0.1,
13 | ) -> None:
14 | super().__init__()
15 |
16 | self.mean = mean
17 | self.sigma = sigma
18 | self.sigma_square = self.sigma ** 2
19 |
20 | def forward(self, x: torch.Tensor) -> torch.Tensor:
21 | return torch.exp(
22 | -0.5 * (x ** 2) / self.sigma_square
23 | )
24 |
--------------------------------------------------------------------------------
/conerf/model/backbone/encodings.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 | from typing import List, Callable, Dict
9 |
10 |
11 | class SinusoidalEncoder(nn.Module):
12 | """Sinusoidal Positional Encoder used in Nerf."""
13 |
14 | def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
15 | super().__init__()
16 | self.x_dim = x_dim
17 | self.min_deg = min_deg
18 | self.max_deg = max_deg
19 | self.use_identity = use_identity
20 | self.c2f = None
21 | self.register_buffer(
22 | "scales", torch.tensor([2**i for i in range(min_deg, max_deg)])
23 | )
24 |
25 | @property
26 | def latent_dim(self) -> int:
27 | return (
28 | int(self.use_identity) + (self.max_deg - self.min_deg) * 2
29 | ) * self.x_dim
30 |
31 | def forward(self, x: torch.Tensor) -> torch.Tensor:
32 | """
33 | Args:
34 | x: [..., x_dim]
35 | Returns:
36 | latent: [..., latent_dim]
37 | """
38 | if self.max_deg == self.min_deg:
39 | return x
40 | xb = torch.reshape(
41 | (x[Ellipsis, None, :] * self.scales[:, None]),
42 | list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim],
43 | )
44 | latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
45 | if self.use_identity:
46 | latent = torch.cat([x] + [latent], dim=-1)
47 | return latent
48 |
49 |
50 | class ProgressiveSinusoidalEncoder(SinusoidalEncoder):
51 | """
52 | Coarse-to-fine positional encodings.
53 | """
54 |
55 | def __init__(
56 | self,
57 | x_dim: int,
58 | min_deg: int,
59 | max_deg: int,
60 | use_identity: bool = True,
61 | c2f: List = [0.1, 0.5],
62 | half_dim: bool = False,
63 | ):
64 | super().__init__(x_dim, min_deg, max_deg, use_identity)
65 |
66 | # Use nn.Parameter so it could be checkpointed.
67 | self.progress = torch.nn.Parameter(torch.tensor(0.), requires_grad=False)
68 |
69 | self.c2f = c2f
70 | self.half_dim = half_dim
71 |
72 | @property
73 | def latent_dim(self) -> int:
74 | latent_dim = super().latent_dim
75 | if self.half_dim:
76 | latent_dim = (latent_dim - self.x_dim) // 2 + self.x_dim
77 | return latent_dim
78 |
79 | def anneal(
80 | self,
81 | iteration: int,
82 | max_iteration: int,
83 | factor: float = 1.0,
84 | reduction: float = 0.0,
85 | bias: float = 0.0,
86 | anneal_surface: bool = False,
87 | ):
88 | """
89 | Gradually increase the controllable parameter during training.
90 | """
91 | if anneal_surface:
92 | if iteration > max_iteration // 2:
93 | progress_data = 1.0
94 | else:
95 | progress_data = 0.5 + float(iteration) / float(max_iteration)
96 | else:
97 | # For camera pose annealing.
98 | progress_data = float(iteration) / float(max_iteration)
99 |
100 | progress_data = factor * (progress_data - reduction) + bias
101 |
102 | self.progress.data.fill_(progress_data)
103 |
104 | def forward(self, x: torch.Tensor) -> torch.Tensor:
105 | latent = super().forward(x)
106 | latent_dim = super().latent_dim
107 |
108 | # Computing weights.
109 | start, end = self.c2f
110 | alpha = (self.progress.data - start) / (end - start) * self.max_deg
111 | ks = torch.arange(self.min_deg, self.max_deg,
112 | dtype=torch.float32, device=x.device)
113 | weight = (
114 | 1.0 - (alpha - ks).clamp_(min=0, max=1).mul_(np.pi).cos_()
115 | ) / 2.0
116 |
117 | # Apply weight to positional encodings.
118 | shape = latent.shape
119 | L = self.max_deg - self.min_deg
120 |
121 | if self.use_identity:
122 | latent_freq = latent[:, self.x_dim:].reshape(-1, L)
123 | latent_freq = (
124 | latent_freq * weight).reshape(shape[0], shape[-1] - self.x_dim)
125 | latent[:, self.x_dim:] = latent_freq
126 | else:
127 | latent = (latent.reshape(-1, L) * weight).reshape(*shape)
128 |
129 | if self.half_dim:
130 | half_freq = L // 2
131 | # input coordinates are excluded.
132 | half_latent_dim = (latent_dim - self.x_dim) // 2
133 | num_feat_each_band = (latent_dim - self.x_dim) // L
134 | half_latent = latent[:, self.x_dim:].view(-1, L, num_feat_each_band)[
135 | :, :half_freq, :].view(-1, half_latent_dim)
136 |
137 | half_latent_contg = latent[:, self.x_dim:].view(-1, L, num_feat_each_band)[
138 | :, :half_freq, :].view(-1, half_latent_dim).contiguous()
139 | half_latent_contg = (
140 | half_latent_contg.view(-1, half_freq) * weight[:half_freq]
141 | ).view(-1, half_latent_dim)
142 | flag = weight[:half_freq].tile(shape[0], num_feat_each_band, 1).transpose(
143 | 1, 2).contiguous().view(-1, half_latent_dim)
144 | half_latent = torch.where(
145 | flag > 0.01, half_latent, half_latent_contg)
146 | latent = torch.cat([latent[:, :self.x_dim], half_latent], dim=-1)
147 |
148 | return latent
149 |
150 |
151 | class GaussianEncoder(nn.Module):
152 | """
153 | Gaussian encodings.
154 | """
155 |
156 | def __init__(
157 | self,
158 | x_dim: int,
159 | feature_dim: int,
160 | init_func: Callable = nn.init.uniform_,
161 | init_range: float = 0.1,
162 | sigma: float = 0.1,
163 | ) -> None:
164 | super().__init__()
165 |
166 | self.init_func = init_func
167 | self.init_range = init_range
168 | self.sigma = sigma
169 | self.sigma_square = sigma ** 2
170 | self.latent_dim = feature_dim
171 |
172 | gaussian_linear = torch.nn.Linear(x_dim, feature_dim)
173 | self.init_func(gaussian_linear.weight, -
174 | self.init_range, self.init_range)
175 | self.gaussian_linear = nn.utils.weight_norm(gaussian_linear)
176 |
177 | def forward(self, x: torch.Tensor) -> torch.Tensor:
178 | x = self.gaussian_linear(x)
179 | mu = torch.mean(x, axis=-1).unsqueeze(-1)
180 | x = torch.exp(
181 | -0.5 * ((x - mu) ** 2) / self.sigma_square
182 | )
183 | return x
184 |
185 |
186 | def create_encoder(x_dim: int, config: Dict):
187 | """
188 | Factory function for creating encodings that applied to coordinate input.
189 | """
190 | encoder_type = config["type"]
191 | if encoder_type == "sinusoidal":
192 | return SinusoidalEncoder(
193 | x_dim=x_dim,
194 | min_deg=config["min_deg"],
195 | max_deg=config["max_deg"],
196 | use_identity=config["use_identity"],
197 | )
198 | elif encoder_type == "progressive":
199 | return ProgressiveSinusoidalEncoder(
200 | x_dim=x_dim,
201 | min_deg=config["min_deg"],
202 | max_deg=config["max_deg"],
203 | use_identity=config["use_identity"],
204 | c2f=config["c2f"],
205 | half_dim=config["half_dim"],
206 | )
207 | elif encoder_type == "gaussian":
208 | return GaussianEncoder(
209 | x_dim=x_dim,
210 | feature_dim=config["feature_dim"] // 2,
211 | init_range=config["init_range"],
212 | sigma=config["sigma"],
213 | )
214 | else:
215 | raise NotImplementedError
216 |
--------------------------------------------------------------------------------
/conerf/model/backbone/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, List
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 |
9 | class MLP(nn.Module):
10 | def __init__(
11 | self,
12 | input_dim: int, # The number of input tensor channels.
13 | output_dim: int = None, # The number of output tensor channels.
14 | net_depth: int = 8, # The depth of the MLP.
15 | net_width: int = 256, # The width of the MLP.
16 | skip_layer: int = 4, # The layer to add skip layers to.
17 | hidden_init: Callable = nn.init.xavier_uniform_,
18 | hidden_activation: Callable = nn.ReLU(),
19 | output_enabled: bool = True,
20 | output_init: Optional[Callable] = nn.init.xavier_uniform_,
21 | output_activation: Optional[Callable] = nn.Identity(),
22 | bias_enabled: bool = True,
23 | bias_init: Callable = nn.init.zeros_,
24 | ):
25 | super().__init__()
26 | self.input_dim = input_dim
27 | self.output_dim = output_dim
28 | self.net_depth = net_depth
29 | self.net_width = net_width
30 | self.skip_layer = skip_layer
31 | self.hidden_init = hidden_init
32 | self.hidden_activation = hidden_activation
33 | self.output_enabled = output_enabled
34 | self.output_init = output_init
35 | self.output_activation = output_activation
36 | self.bias_enabled = bias_enabled
37 | self.bias_init = bias_init
38 |
39 | self.hidden_layers = nn.ModuleList()
40 | in_features = self.input_dim
41 | for i in range(self.net_depth):
42 | self.hidden_layers.append(
43 | nn.Linear(in_features, self.net_width, bias=bias_enabled)
44 | )
45 | if (
46 | (self.skip_layer is not None)
47 | and (i % self.skip_layer == 0)
48 | and (i > 0)
49 | ):
50 | in_features = self.net_width + self.input_dim
51 | else:
52 | in_features = self.net_width
53 | if self.output_enabled:
54 | self.output_layer = nn.Linear(
55 | in_features, self.output_dim, bias=bias_enabled
56 | )
57 | else:
58 | self.output_dim = in_features
59 |
60 | self.initialize()
61 |
62 | def initialize(self):
63 | def init_func_hidden(m):
64 | if isinstance(m, nn.Linear):
65 | if self.hidden_init is not None:
66 | self.hidden_init(m.weight)
67 | if self.bias_enabled and self.bias_init is not None:
68 | self.bias_init(m.bias)
69 |
70 | self.hidden_layers.apply(init_func_hidden)
71 | if self.output_enabled:
72 |
73 | def init_func_output(m):
74 | if isinstance(m, nn.Linear):
75 | if self.output_init is not None:
76 | self.output_init(m.weight)
77 | if self.bias_enabled and self.bias_init is not None:
78 | self.bias_init(m.bias)
79 |
80 | self.output_layer.apply(init_func_output)
81 |
82 | def forward(self, x):
83 | inputs = x
84 | for i in range(self.net_depth):
85 | x = self.hidden_layers[i](x)
86 | x = self.hidden_activation(x)
87 | if (
88 | (self.skip_layer is not None)
89 | and (i % self.skip_layer == 0)
90 | and (i > 0)
91 | ):
92 | x = torch.cat([x, inputs], dim=-1)
93 | if self.output_enabled:
94 | x = self.output_layer(x)
95 | x = self.output_activation(x)
96 | return x
97 |
98 |
99 | class DenseLayer(MLP):
100 | def __init__(self, input_dim, output_dim, **kwargs):
101 | super().__init__(
102 | input_dim=input_dim,
103 | output_dim=output_dim,
104 | net_depth=0, # no hidden layers
105 | **kwargs,
106 | )
107 |
108 |
109 | class NormalizedMLP(nn.Module):
110 | def __init__(
111 | self,
112 | input_dim: int,
113 | output_dim: int = None,
114 | net_depth: int = 8,
115 | net_width: int = 256,
116 | skip_layer: List = [4],
117 | hidden_activation: Callable = nn.ReLU(),
118 | bias: float = 0.5,
119 | weight_norm: bool = True,
120 | geometric_init: bool = True,
121 | ) -> None:
122 | super().__init__()
123 |
124 | self.input_dim = input_dim
125 | self.output_dim = output_dim
126 | self.net_depth = net_depth
127 | self.net_width = net_width
128 | self.skip_layer = skip_layer
129 | self.hidden_activation = hidden_activation
130 | self.bias = bias
131 | dims = [input_dim] + [net_width for _ in range(net_depth)] + [output_dim]
132 | self.num_layers = len(dims)
133 |
134 | for i in range(0, self.num_layers - 1):
135 | if (self.skip_layer is not None) and (i + 1) in self.skip_layer:
136 | out_dim = dims[i + 1] - dims[0]
137 | else:
138 | out_dim = dims[i + 1]
139 |
140 | lin = nn.Linear(dims[i], out_dim)
141 |
142 | if geometric_init:
143 | if i == self.num_layers - 2:
144 | torch.nn.init.normal_(
145 | lin.weight,
146 | mean=np.sqrt(np.pi) / np.sqrt(dims[i]),
147 | std=0.0001
148 | )
149 | torch.nn.init.constant_(lin.bias, -bias)
150 | elif i == 0:
151 | torch.nn.init.constant_(lin.bias, 0.0)
152 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
153 | torch.nn.init.normal_(
154 | lin.weight[:, :3],
155 | 0.0,
156 | np.sqrt(2) / np.sqrt(out_dim)
157 | )
158 | elif (self.skip_layer is not None) and (i in self.skip_layer):
159 | torch.nn.init.constant_(lin.bias, 0.0)
160 | torch.nn.init.normal_(
161 | lin.weight,
162 | 0.0,
163 | np.sqrt(2) / np.sqrt(out_dim)
164 | )
165 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
166 | else:
167 | torch.nn.init.constant_(lin.bias, 0.0)
168 | torch.nn.init.normal_(
169 | lin.weight,
170 | 0.0,
171 | np.sqrt(2) / np.sqrt(out_dim)
172 | )
173 |
174 | if weight_norm:
175 | lin = nn.utils.weight_norm(lin)
176 |
177 | setattr(self, "lin" + str(i), lin)
178 |
179 | def forward(self, x):
180 | inputs = x
181 | for i in range(0, self.num_layers - 1):
182 | lin = getattr(self, "lin" + str(i))
183 |
184 | if (self.skip_layer is not None) and (i in self.skip_layer):
185 | x = torch.cat([x, inputs], dim=-1) / np.sqrt(2)
186 |
187 | x = lin(x)
188 |
189 | if i < self.num_layers - 2:
190 | x = self.hidden_activation(x)
191 |
192 | return x
193 |
--------------------------------------------------------------------------------
/conerf/model/backbone/resnet3d.py:
--------------------------------------------------------------------------------
1 | # Code is adapted from: https://github.com/DonGovi/pyramid-detection-3D
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 |
11 | __all__ = ['ResNet3D', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 | 'resnet152']
13 |
14 | '''
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | }
22 | '''
23 |
24 |
25 | def conv3x3x3(in_planes, out_planes, stride=1):
26 | """3x3x3 convolution with padding"""
27 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
28 | padding=1, bias=False)
29 | '''
30 | def downsample_basic_block(x, planes, stride):
31 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
32 | zero_pads = torch.Tensor(
33 | out.size(0), planes - out.size(1), out.size(2), out.size(3),
34 | out.size(4)).zero_()
35 | if isinstance(out.data, torch.cuda.FloatTensor):
36 | zero_pads = zero_pads.cuda()
37 |
38 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
39 |
40 | return out
41 | '''
42 |
43 | class BasicBlock(nn.Module):
44 | expansion = 1
45 |
46 | def __init__(self, in_planes, planes, stride=1, downsample=None):
47 | super(BasicBlock, self).__init__()
48 | self.conv1 = conv3x3x3(in_planes, planes, stride)
49 | self.bn1 = nn.BatchNorm3d(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3x3(planes, planes)
52 | self.bn2 = nn.BatchNorm3d(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | residual = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | #conv2_rep = out
65 | out = self.bn2(out)
66 |
67 | if self.downsample is not None:
68 | residual = self.downsample(x)
69 |
70 | out += residual
71 | out = self.relu(out)
72 |
73 | return out
74 |
75 |
76 | class Bottleneck(nn.Module):
77 | expansion = 4
78 |
79 | def __init__(self, in_planes, planes, stride=1, downsample=None):
80 | super(Bottleneck, self).__init__()
81 | self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
82 | self.bn1 = nn.BatchNorm3d(planes)
83 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
84 | padding=1, bias=False)
85 | self.bn2 = nn.BatchNorm3d(planes)
86 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
87 | self.bn3 = nn.BatchNorm3d(planes * 4)
88 | self.relu = nn.ReLU(inplace=True)
89 | self.downsample = downsample
90 | self.stride = stride
91 |
92 | def forward(self, x):
93 | residual = x
94 |
95 | out = self.conv1(x)
96 | out = self.bn1(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv2(out)
100 | out = self.bn2(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv3(out)
104 | #conv3_rep = out
105 | out = self.bn3(out)
106 |
107 | if self.downsample is not None:
108 | residual = self.downsample(x)
109 |
110 | out += residual
111 | out = self.relu(out)
112 |
113 | return out
114 |
115 |
116 | class ResNet3D(nn.Module):
117 | def __init__(self, in_channels, block, layers):
118 | self.in_planes = 64
119 | super(ResNet3D, self).__init__()
120 | self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=5, stride=2, padding=2, bias=False) # 128 -> 64
121 | self.bn1 = nn.BatchNorm3d(64)
122 | self.relu = nn.ReLU(inplace=True)
123 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) # 64 -> 32
124 | self.layer1 = self._make_layer(block, 64, layers[0])
125 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) # 32 -> 16
126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 16 -> 8
127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 8 -> 4
128 | '''
129 | self.avgpool = nn.AvgPool3d(7, stride=1)
130 | self.fc = nn.Linear(512 * block.expansion, num_classes)
131 | '''
132 |
133 | for m in self.modules():
134 | if isinstance(m, nn.Conv3d):
135 | nn.init.xavier_normal_(m.weight)
136 | elif isinstance(m, nn.BatchNorm3d):
137 | m.weight.data.fill_(1)
138 | m.bias.data.zero_()
139 |
140 | def _make_layer(self, block, planes, blocks, stride=1):
141 | downsample = None
142 | if stride != 1 or self.in_planes != planes * block.expansion:
143 | downsample = nn.Sequential(
144 | nn.Conv3d(self.in_planes, planes * block.expansion,
145 | kernel_size=1, stride=stride, bias=False),
146 | nn.BatchNorm3d(planes * block.expansion)
147 | )
148 |
149 | layers = []
150 | layers.append(block(self.in_planes, planes, stride, downsample))
151 | self.in_planes = planes * block.expansion
152 | for i in range(1, blocks):
153 | layers.append(block(self.in_planes, planes))
154 |
155 | return nn.Sequential(*layers)
156 |
157 | def forward(self, x): # 128
158 | c1 = self.conv1(x) # 64 --> 8 anchor_area
159 | c1 = self.bn1(c1)
160 | c1 = self.relu(c1)
161 | c2 = self.maxpool(c1) # 32
162 |
163 | c2 = self.layer1(c2) # 32 --> 16 anchor_area
164 | c3 = self.layer2(c2) # 16 --> 32 anchor_area
165 | c4 = self.layer3(c3) # 8
166 | c5 = self.layer4(c4) # 4
167 | '''
168 | x = self.avgpool(x)
169 | x = x.view(x.size(0), -1)
170 | x = self.fc(x)
171 | '''
172 | return c1, c2, c3, c4, c5
173 |
174 |
175 | def resnet18(in_channels=3, pretrained=False, **kwargs):
176 | """Constructs a ResNet-18 model.
177 | Args:
178 | pretrained (bool): If True, returns a model pre-trained on ImageNet
179 | """
180 | model = ResNet3D(in_channels, BasicBlock, [2, 2, 2, 2], **kwargs)
181 | # if pretrained:
182 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
183 | return model
184 |
185 |
186 | def resnet34(in_channels=3, pretrained=False, **kwargs):
187 | """Constructs a ResNet-34 model.
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | model = ResNet3D(in_channels, BasicBlock, [3, 4, 6, 3], **kwargs)
192 | # if pretrained:
193 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
194 | return model
195 |
196 |
197 | def resnet50(in_channels=3, pretrained=False, **kwargs):
198 | """Constructs a ResNet-50 model.
199 | Args:
200 | pretrained (bool): If True, returns a model pre-trained on ImageNet
201 | """
202 | model = ResNet3D(in_channels, Bottleneck, [3, 4, 6, 3], **kwargs)
203 | # if pretrained:
204 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
205 | return model
206 |
207 |
208 | def resnet101(in_channels=3, pretrained=False, **kwargs):
209 | """Constructs a ResNet-101 model.
210 | Args:
211 | pretrained (bool): If True, returns a model pre-trained on ImageNet
212 | """
213 | model = ResNet3D(in_channels, Bottleneck, [3, 4, 23, 3], **kwargs)
214 | # if pretrained:
215 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
216 | return model
217 |
218 |
219 | def resnet152(in_channels=3, pretrained=False, **kwargs):
220 | """Constructs a ResNet-152 model.
221 | Args:
222 | pretrained (bool): If True, returns a model pre-trained on ImageNet
223 | """
224 | model = ResNet3D(in_channels, Bottleneck, [3, 8, 36, 3], **kwargs)
225 | # if pretrained:
226 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
227 | return model
228 |
--------------------------------------------------------------------------------
/conerf/model/misc.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1101,W0108]
2 |
3 | import gc
4 | from collections import defaultdict
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.autograd import Function
9 | from torch.cuda.amp import custom_bwd, custom_fwd
10 |
11 | import tinycudann as tcnn
12 |
13 |
14 | def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs):
15 | B = None
16 | for arg in args:
17 | if isinstance(arg, torch.Tensor):
18 | B = arg.shape[0]
19 | break
20 | out = defaultdict(list)
21 | out_type = None
22 | chunk_length = 0
23 | for i in range(0, B, chunk_size):
24 | out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs)
25 | if out_chunk is None:
26 | continue
27 | out_type = type(out_chunk)
28 | if isinstance(out_chunk, torch.Tensor):
29 | out_chunk = {0: out_chunk}
30 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
31 | chunk_length = len(out_chunk)
32 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
33 | elif isinstance(out_chunk, dict):
34 | pass
35 | else:
36 | print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.')
37 | exit(1)
38 | for k, v in out_chunk.items():
39 | if v is None:
40 | chunk_length -= 1
41 | continue
42 | v = v if torch.is_grad_enabled() else v.detach()
43 | v = v.cpu() if move_to_cpu else v
44 | out[k].append(v)
45 |
46 | if out_type is None:
47 | return
48 |
49 | out = {k: torch.cat(v, dim=0) for k, v in out.items()}
50 | if out_type is torch.Tensor:
51 | return out[0]
52 | elif out_type in [tuple, list]:
53 | # return out_type([out[i] for i in range(chunk_length)])
54 | return out_type([out[i] for i in out.keys()])
55 | elif out_type is dict:
56 | return out
57 |
58 |
59 | class _TruncExp(Function): # pylint: disable=abstract-method
60 | # Implementation from torch-ngp:
61 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
62 | @staticmethod
63 | @custom_fwd(cast_inputs=torch.float32)
64 | def forward(ctx, x): # pylint: disable=arguments-differ
65 | ctx.save_for_backward(x)
66 | return torch.exp(x)
67 |
68 | @staticmethod
69 | @custom_bwd
70 | def backward(ctx, g): # pylint: disable=arguments-differ
71 | x = ctx.saved_tensors[0]
72 | return g * torch.exp(torch.clamp(x, max=15))
73 |
74 | trunc_exp = _TruncExp.apply
75 |
76 |
77 | def get_activation(name):
78 | if name is None:
79 | return lambda x: x
80 | name = name.lower()
81 | if name == 'none':
82 | return lambda x: x
83 | elif name.startswith('scale'):
84 | scale_factor = float(name[5:])
85 | return lambda x: x.clamp(0., scale_factor) / scale_factor
86 | elif name.startswith('clamp'):
87 | clamp_max = float(name[5:])
88 | return lambda x: x.clamp(0., clamp_max)
89 | elif name.startswith('mul'):
90 | mul_factor = float(name[3:])
91 | return lambda x: x * mul_factor
92 | elif name == 'lin2srgb':
93 | return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.)
94 | elif name == 'trunc_exp':
95 | return trunc_exp
96 | elif name.startswith('+') or name.startswith('-'):
97 | return lambda x: x + float(name)
98 | elif name == 'sigmoid':
99 | return lambda x: torch.sigmoid(x)
100 | elif name == 'tanh':
101 | return lambda x: torch.tanh(x)
102 | else:
103 | return getattr(F, name)
104 |
105 |
106 | def dot(x, y):
107 | return torch.sum(x*y, -1, keepdim=True)
108 |
109 |
110 | def reflect(x, n):
111 | return 2 * dot(x, n) * n - x
112 |
113 |
114 | def cleanup():
115 | gc.collect()
116 | torch.cuda.empty_cache()
117 | tcnn.free_temporary_memory()
--------------------------------------------------------------------------------
/conerf/model/scene_regressor/ace_encoder_pretrained.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIBluefisher/ZeroGS/de01ec444296b887d610939ac9b0abf276ab54b1/conerf/model/scene_regressor/ace_encoder_pretrained.pt
--------------------------------------------------------------------------------
/conerf/model/scene_regressor/ace_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright © Niantic, Inc. 2022.
2 |
3 | import math
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def weighted_tanh(repro_errs, weight):
9 | return weight * torch.tanh(repro_errs / weight).sum()
10 |
11 |
12 | class ReproLoss:
13 | """
14 | Compute per-pixel reprojection loss using different configurable approaches.
15 |
16 | - tanh: tanh loss with a constant scale factor given by the `soft_clamp` parameter
17 | (when a pixel's reprojection error is equal to `soft_clamp`, its loss is equal
18 | to `soft_clamp * tanh(1)`).
19 | - dyntanh: Used in the paper, similar to the tanh loss above, but the scaling factor
20 | decreases during the course of the training from `soft_clamp` to `soft_clamp_min`.
21 | The decrease is linear, unless `circle_schedule` is True (default), in which
22 | case it applies a circular scheduling. See paper for details.
23 | - l1: Standard L1 loss, computed only on those pixels having an error lower than `soft_clamp`
24 | - l1+sqrt: L1 loss for pixels with reprojection error smaller than `soft_clamp` and
25 | `sqrt(soft_clamp * reprojection_error)` for pixels with a higher error.
26 | - l1+logl1: Similar to the above, but using log L1 for pixels with high reprojection error.
27 | """
28 |
29 | def __init__(self,
30 | total_iterations,
31 | soft_clamp,
32 | soft_clamp_min,
33 | type='dyntanh',
34 | circle_schedule=True):
35 |
36 | self.total_iterations = total_iterations
37 | self.soft_clamp = soft_clamp
38 | self.soft_clamp_min = soft_clamp_min
39 | self.type = type
40 | self.circle_schedule = circle_schedule
41 |
42 | def compute(self, repro_errs_b1N, iteration):
43 | if repro_errs_b1N.nelement() == 0:
44 | return 0
45 |
46 | if self.type == "tanh":
47 | return weighted_tanh(repro_errs_b1N, self.soft_clamp)
48 |
49 | elif self.type == "dyntanh":
50 | # Compute the progress over the training process.
51 | schedule_weight = iteration / self.total_iterations
52 |
53 | if self.circle_schedule:
54 | # Optionally scale it using the circular schedule.
55 | schedule_weight = 1 - math.sqrt(1 - schedule_weight ** 2)
56 |
57 | # Compute the weight to use in the tanh loss.
58 | loss_weight = (1 - schedule_weight) * self.soft_clamp + self.soft_clamp_min
59 |
60 | # Compute actual loss.
61 | return weighted_tanh(repro_errs_b1N, loss_weight)
62 |
63 | elif self.type == "l1":
64 | # L1 loss on all pixels with small-enough error.
65 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp
66 | return repro_errs_b1N[~softclamp_mask_b1].sum()
67 |
68 | elif self.type == "l1+sqrt":
69 | # L1 loss on pixels with small errors and sqrt for the others.
70 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp
71 | loss_l1 = repro_errs_b1N[~softclamp_mask_b1].sum()
72 | loss_sqrt = torch.sqrt(self.soft_clamp * repro_errs_b1N[softclamp_mask_b1]).sum()
73 |
74 | return loss_l1 + loss_sqrt
75 |
76 | else:
77 | # l1+logl1: same as above, but use log(L1) for pixels with a larger error.
78 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp
79 | loss_l1 = repro_errs_b1N[~softclamp_mask_b1].sum()
80 | loss_logl1 = torch.log(1 + (self.soft_clamp * repro_errs_b1N[softclamp_mask_b1])).sum()
81 |
82 | return loss_l1 + loss_logl1
83 |
--------------------------------------------------------------------------------
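A minimal usage sketch for ReproLoss (illustrative values only; the error tensor below is random, not a real batch of reprojection errors):

import torch
from conerf.model.scene_regressor.ace_loss import ReproLoss

repro_loss = ReproLoss(
    total_iterations=10000, soft_clamp=50, soft_clamp_min=1,
    type="dyntanh", circle_schedule=True,
)
repro_errs = torch.rand(5120) * 100                       # fake per-correspondence errors in pixels
early = repro_loss.compute(repro_errs, iteration=100)     # large tanh scale early in training
late = repro_loss.compute(repro_errs, iteration=9900)     # scale shrinks toward soft_clamp_min
print(early.item(), late.item())
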
/conerf/model/scene_regressor/ace_util.py:
--------------------------------------------------------------------------------
1 | # Copyright © Niantic, Inc. 2022.
2 | # pylint: disable=[E1101]
3 |
4 | import torch
5 |
6 | from conerf.datasets.utils import store_ply
7 |
8 |
9 | def get_pixel_grid(image_height: int, image_width: int):
10 | """
11 | Generate target pixel positions according to image height and width, assuming
12 | prediction at center pixel.
13 | """
14 | ys = torch.arange(image_height, dtype=torch.float32)
15 | xs = torch.arange(image_width, dtype=torch.float32)
16 | yy, xx = torch.meshgrid(ys, xs, indexing='ij')
17 |
18 | return torch.stack([xx, yy]) + 0.5
19 |
20 |
21 | # def get_pixel_grid(subsampling_factor):
22 | # """
23 | # Generate target pixel positions according to a subsampling factor, assuming prediction
24 | # at center pixel.
25 | # """
26 | # pix_range = torch.arange(np.ceil(5000 / subsampling_factor), dtype=torch.float32)
27 | # yy, xx = torch.meshgrid(pix_range, pix_range, indexing='ij')
28 |
29 | # return subsampling_factor * (torch.stack([xx, yy]) + 0.5)
30 |
31 |
32 | def to_homogeneous(input_tensor, dim=1):
33 | """
34 | Converts tensor to homogeneous coordinates by adding ones to the specified dimension
35 | """
36 | ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim))
37 | output = torch.cat([input_tensor, ones], dim=dim)
38 |
39 | return output
40 |
41 |
42 | def save_point_cloud(points3d: torch.Tensor, colors: torch.Tensor = None, path: str = ""):
43 | """Save point cloud to '.ply' file.
44 | """
45 | if isinstance(points3d, torch.Tensor):
46 | points3d = points3d.detach().cpu().numpy()
47 |
48 | if colors is not None:
49 | if isinstance(colors, torch.Tensor):
50 | colors = colors.detach().cpu().numpy()
51 |
52 | store_ply(path, points3d, colors)
53 |
--------------------------------------------------------------------------------
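A quick sketch of the two helpers above on toy shapes:

import torch
from conerf.model.scene_regressor.ace_util import get_pixel_grid, to_homogeneous

grid = get_pixel_grid(image_height=4, image_width=6)
print(grid.shape)                            # torch.Size([2, 4, 6]); grid[:, 0, 0] is (0.5, 0.5)
points = torch.rand(8, 2)
print(to_homogeneous(points, dim=1).shape)   # torch.Size([8, 3])
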
/conerf/model/scene_regressor/calibr.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1101]
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Calibr(nn.Module):
8 | """
9 | A modular class for calibration refinement.
10 | NOTE:
11 | The class assumes that:
12 |             (1) the principal point is in the center;
13 | (2) pixels are unskewed and square;
14 | (3) image distortion is not modelled.
15 | """
16 |
17 | def __init__(self, device: str = "cuda"):
18 | super(Calibr, self).__init__()
19 |
20 | self.device = device
21 | self.scaler = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)
22 |
23 | def forward(self, heights: torch.Tensor, widths: torch.Tensor) -> torch.Tensor:
24 | batch_size = heights.shape[0]
25 | # The initial focal length is set to 70% of the image diagonal.
26 | focal_lengths_init = 0.7 * torch.sqrt(heights ** 2 + widths ** 2)
27 |
28 |         # assume the principal point is in the center.
29 | cxs = widths / 2
30 | cys = heights / 2
31 |
32 | focal_lengths = focal_lengths_init * (1 + self.scaler)
33 |
34 | Ks = torch.eye(3, device=self.device)[None, ...].repeat(batch_size, 1, 1)
35 | Ks[:, 0, 0] = Ks[:, 1, 1] = focal_lengths
36 | Ks[:, 0, 2] = cxs
37 | Ks[:, 1, 2] = cys
38 |
39 | return Ks
40 |
--------------------------------------------------------------------------------
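An illustrative sketch of how Calibr builds the initial intrinsics for a batch of two image sizes (run on CPU here for simplicity; the module defaults to CUDA):

import torch
from conerf.model.scene_regressor.calibr import Calibr

calibr = Calibr(device="cpu")
heights = torch.tensor([480., 768.])
widths = torch.tensor([640., 1024.])
Ks = calibr(heights, widths)   # [2, 3, 3]
print(Ks[0, 0, 0])             # 560.0 = 0.7 * sqrt(480^2 + 640^2), scaler still at its initial value 0
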
/conerf/model/scene_regressor/depth_network.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1101,W0212]
2 |
3 | import torch
4 | import ssl
5 | ssl._create_default_https_context = ssl._create_unverified_context
6 |
7 |
8 | class DepthNetwork:
9 | """
10 |     A wrapper around different depth networks (ZoeDepth and Metric3D).
11 | """
12 | def __init__(
13 | self,
14 | method: str = "ZoeDepth",
15 | depth_type: str = "ZoeD_NK",
16 | pretrain: bool = True,
17 | depth_min: float = 0.1,
18 | depth_max: float = 1000,
19 | device: str = "cuda"
20 | ) -> None:
21 | self.method = method
22 | self.depth_min = depth_min
23 | self.depth_max = depth_max
24 | self.device = device
25 |
26 | if method == "ZoeDepth":
27 | self.depth_network = torch.hub.load(
28 | 'isl-org/ZoeDepth',
29 | depth_type,
30 | pretrained=pretrain,
31 | ).to(self.device)
32 | elif method == "metric3d":
33 | self.depth_network = torch.hub.load(
34 | 'yvanyin/metric3d',
35 | depth_type,
36 | pretrain=pretrain,
37 | ).to(self.device)
38 | else:
39 | raise NotImplementedError
40 |
41 | def infer(self, image: torch.Tensor):
42 | """
43 | Param:
44 | @param image: [B,3,H,W]
45 | Return:
46 | depth: depth map for image [B,1,H,W]
47 |             confidence: confidence score corresponding to the depth map
48 | output_dict: other outputs from metric3d
49 | """
50 | confidence = None
51 | output_dict = None
52 | if self.method == "ZoeDepth":
53 | depth = self.depth_network.infer(image)
54 | elif self.method == "metric3d":
55 | depth, confidence, output_dict = self.depth_network.inference({'input': image})
56 | else:
57 | raise NotImplementedError
58 |
59 | depth = torch.clamp(depth, self.depth_min, self.depth_max)
60 |
61 | return depth, confidence, output_dict
62 |
--------------------------------------------------------------------------------
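A hypothetical usage sketch for DepthNetwork; constructing it pulls the ZoeDepth weights from torch.hub, so it needs network access and a CUDA device:

import torch
from conerf.model.scene_regressor.depth_network import DepthNetwork

depth_net = DepthNetwork(method="ZoeDepth", depth_type="ZoeD_NK", device="cuda")
image = torch.rand(1, 3, 480, 640, device="cuda")   # dummy RGB batch in [0, 1]
depth, confidence, _ = depth_net.infer(image)       # confidence is None for ZoeDepth
print(depth.shape)                                  # depth map for the batch, clamped to [depth_min, depth_max]
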
/conerf/model/scene_regressor/pose_refine_network.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1101,E1102]
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from conerf.geometry.pose_util import se3_exp_map
7 | from conerf.model.backbone.mlp import MLP
8 |
9 |
10 | class PoseRefineNetwork(nn.Module):
11 | """
12 | Optimize the 6DoF camera poses.
13 | """
14 |
15 | def __init__(self, input_dim: int = 12, output_dim: int = 6, hidden_dim: int = 128):
16 | super(PoseRefineNetwork, self).__init__()
17 |
18 | self.hidden_dim = hidden_dim
19 |
20 | self.mlp = MLP(
21 | input_dim=input_dim,
22 | output_dim=output_dim,
23 | net_depth=6, # hard-coded.
24 | net_width=hidden_dim,
25 | skip_layer=3, # hard-coded.
26 | # TODO(chenyu): check with the hidden activation since it is not mentioned in the paper.
27 | hidden_activation=nn.ReLU(),
28 | )
29 |
30 | def forward(self, poses: torch.Tensor):
31 | """
32 | Parameters:
33 | @param poses: [N,3/4,4]
34 | Returns:
35 | optimized poses [N,4,4]
36 | """
37 | batch_size = poses.shape[0]
38 |
39 | poses_3x4 = poses[:, :3, :].reshape(batch_size, -1) # [B,12]
40 | delta_se3 = self.mlp(poses_3x4) # [B,6]
41 | delta_pose_4x4 = se3_exp_map(delta_se3) # [B,4,4]
42 |
43 | updated_poses = poses @ delta_pose_4x4
44 |
45 | return updated_poses, delta_se3
46 |
47 | # poses = poses[:, :3, :].reshape(batch_size, -1) # [B,12]
48 | # poses = self.mlp(poses) # [B,12]
49 | # poses = poses.reshape(batch_size, 3, 4) # [B,3,4]
50 |
51 | # # Retraction to recover the rotational part.
52 | # Us, _, Vhs = torch.linalg.svd(poses[:, :3, :3]) # pylint: disable=C0103
53 |
54 | # updated_poses = torch.eye(4, device=poses.device).reshape(-1, 4).repeat(batch_size, 1, 1)
55 |
56 | # # R = U @ V^T.
57 | # # Construct Z to fix the orientation of R to get det(R) = 1.
58 | # Z = torch.eye(3, device=poses.device).reshape(-1, 3).repeat(batch_size, 1, 1)
59 | # Z[:, -1, -1] = Z [:, -1, -1] * torch.sign(torch.linalg.det(Us @ Vhs))
60 | # updated_poses[:, :3, :3] = Us @ Z @ Vhs
61 |
62 | # # Copy translational part.
63 | # updated_poses[:, :3, 3:] = poses[:, :3, 3:]
64 |
65 | # return updated_poses
66 |
--------------------------------------------------------------------------------
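A minimal sketch of the refiner on a batch of identity camera poses (the network is untrained here, so the predicted update is arbitrary):

import torch
from conerf.model.scene_regressor.pose_refine_network import PoseRefineNetwork

refiner = PoseRefineNetwork()
poses = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)   # [2, 4, 4]
updated_poses, delta_se3 = refiner(poses)
print(updated_poses.shape, delta_se3.shape)         # torch.Size([2, 4, 4]) torch.Size([2, 6])
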
/conerf/pycolmap/pycolmap/__init__.py:
--------------------------------------------------------------------------------
1 | from conerf.pycolmap.pycolmap.camera import Camera
2 | from conerf.pycolmap.pycolmap.database import COLMAPDatabase
3 | from conerf.pycolmap.pycolmap.image import Image
4 | from conerf.pycolmap.pycolmap.scene_manager import SceneManager
5 | from conerf.pycolmap.pycolmap.rotation import Quaternion, DualQuaternion
6 |
--------------------------------------------------------------------------------
/conerf/pycolmap/pycolmap/image.py:
--------------------------------------------------------------------------------
1 | # Author: True Price
2 |
3 | import numpy as np
4 |
5 | #-------------------------------------------------------------------------------
6 | #
7 | # Image
8 | #
9 | #-------------------------------------------------------------------------------
10 |
11 | class Image:
12 | def __init__(self, name_, camera_id_, q_, tvec_):
13 | self.name = name_
14 | self.camera_id = camera_id_
15 | self.q = q_
16 | self.tvec = tvec_
17 |
18 | self.points2D = np.empty((0, 2), dtype=np.float64)
19 | self.point3D_ids = np.empty((0,), dtype=np.uint64)
20 |
21 | #---------------------------------------------------------------------------
22 |
23 | def R(self):
24 | return self.q.ToR()
25 |
26 | #---------------------------------------------------------------------------
27 |
28 | def C(self):
29 | return -self.R().T.dot(self.tvec)
30 |
31 | #---------------------------------------------------------------------------
32 |
33 | @property
34 | def t(self):
35 | return self.tvec
36 |
--------------------------------------------------------------------------------
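A small check of the conventions encoded by Image, assuming Quaternion() from the bundled rotation module defaults to the identity rotation:

import numpy as np
from conerf.pycolmap.pycolmap import Image, Quaternion

image = Image("000001.png", 1, Quaternion(), np.array([0.1, 0.2, 0.3]))
print(image.R())   # identity rotation
print(image.C())   # camera center -R^T t = [-0.1, -0.2, -0.3]
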
/conerf/pycolmap/tools/colmap_to_nvm.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import sys
3 | sys.path.append("..")
4 |
5 | import numpy as np
6 |
7 | from pycolmap import Quaternion, SceneManager
8 |
9 |
10 | #-------------------------------------------------------------------------------
11 |
12 | def main(args):
13 | scene_manager = SceneManager(args.input_folder)
14 | scene_manager.load()
15 |
16 | with open(args.output_file, "w") as fid:
17 | fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images)))
18 |
19 | image_fmt_str = " {:.3f} " + 7 * "{:.7f} "
20 |         for image_id, image in scene_manager.images.items():
21 | camera = scene_manager.cameras[image.camera_id]
22 | f = 0.5 * (camera.fx + camera.fy)
23 | fid.write(args.image_name_prefix + image.name)
24 | fid.write(image_fmt_str.format(
25 | *((f,) + tuple(image.q.q) + tuple(image.C()))))
26 | if camera.distortion_func is None:
27 | fid.write("0 0\n")
28 | else:
29 | fid.write("{:.7f} 0\n".format(-camera.k1))
30 |
31 | image_id_to_idx = dict(
32 | (image_id, i) for i, image_id in enumerate(scene_manager.images))
33 |
34 | fid.write("{:d}\n".format(len(scene_manager.points3D)))
35 | for i, point3D_id in enumerate(scene_manager.point3D_ids):
36 | fid.write(
37 | "{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i]))
38 | fid.write(
39 | "{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i]))
40 | keypoints = [
41 | (image_id_to_idx[image_id], kp_idx) +
42 | tuple(scene_manager.images[image_id].points2D[kp_idx])
43 | for image_id, kp_idx in
44 | scene_manager.point3D_id_to_images[point3D_id]]
45 | fid.write("{:d}".format(len(keypoints)))
46 | fid.write(
47 | (len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format(
48 | *itertools.chain(*keypoints)))
49 |
50 |
51 | #-------------------------------------------------------------------------------
52 |
53 | if __name__ == "__main__":
54 | import argparse
55 |
56 | parser = argparse.ArgumentParser(
57 | description="Save a COLMAP reconstruction in the NVM format "
58 | "(http://ccwu.me/vsfm/doc.html#nvm).",
59 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
60 |
61 | parser.add_argument("input_folder")
62 | parser.add_argument("output_file")
63 |
64 | parser.add_argument("--image_name_prefix", type=str, default="",
65 | help="prefix image names with this string (e.g., 'images/')")
66 |
67 | args = parser.parse_args()
68 |
69 | main(args)
70 |
--------------------------------------------------------------------------------
/conerf/pycolmap/tools/delete_images.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import numpy as np
5 |
6 | from pycolmap import DualQuaternion, Image, SceneManager
7 |
8 |
9 | #-------------------------------------------------------------------------------
10 |
11 | def main(args):
12 | scene_manager = SceneManager(args.input_folder)
13 | scene_manager.load()
14 |
15 | image_ids = map(scene_manager.get_image_from_name,
16 | iter(lambda: sys.stdin.readline().strip(), ""))
17 | scene_manager.delete_images(image_ids)
18 |
19 | scene_manager.save(args.output_folder)
20 |
21 |
22 | #-------------------------------------------------------------------------------
23 |
24 | if __name__ == "__main__":
25 | import argparse
26 |
27 | parser = argparse.ArgumentParser(
28 | description="Deletes images (filenames read from stdin) from a model.",
29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30 |
31 | parser.add_argument("input_folder")
32 | parser.add_argument("output_folder")
33 |
34 | args = parser.parse_args()
35 |
36 | main(args)
37 |
--------------------------------------------------------------------------------
/conerf/pycolmap/tools/impute_missing_cameras.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import numpy as np
5 |
6 | from pycolmap import DualQuaternion, Image, SceneManager
7 |
8 |
9 | #-------------------------------------------------------------------------------
10 |
11 | image_to_idx = lambda im: int(im.name[:im.name.rfind(".")])
12 |
13 |
14 | #-------------------------------------------------------------------------------
15 |
16 | def interpolate_linear(images, camera_id, file_format):
17 | if len(images) < 2:
18 | raise ValueError("Need at least two images for linear interpolation!")
19 |
20 | prev_image = images[0]
21 | prev_idx = image_to_idx(prev_image)
22 | prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t)
23 | start = prev_idx
24 |
25 | new_images = []
26 |
27 | for image in images[1:]:
28 | curr_idx = image_to_idx(image)
29 | curr_dq = DualQuaternion.FromQT(image.q, image.t)
30 | T = curr_idx - prev_idx
31 | Tinv = 1. / T
32 |
33 | # like quaternions, dq(x) = -dq(x), so we'll need to pick the one more
34 | # appropriate for interpolation by taking -dq if the dot product of the
35 | # two q-vectors is negative
36 | if prev_dq.q0.dot(curr_dq.q0) < 0:
37 | curr_dq = -curr_dq
38 |
39 |         for i in range(1, T):
40 |             t = i * Tinv
41 |             dq = ((1. - t) * prev_dq + t * curr_dq).normalize()
42 |             q, t = dq.ToQT()
43 |             new_images.append(
44 |                 Image(file_format.format(prev_idx + i), camera_id, q, t))
45 |
46 | prev_idx = curr_idx
47 | prev_dq = curr_dq
48 |
49 | return new_images
50 |
51 |
52 | #-------------------------------------------------------------------------------
53 |
54 | def interpolate_hermite(images, camera_id, file_format):
55 | if len(images) < 4:
56 | raise ValueError(
57 | "Need at least four images for Hermite spline interpolation!")
58 |
59 | new_images = []
60 |
61 | # linear blending for the first frames
62 | T0 = image_to_idx(images[0])
63 | dq0 = DualQuaternion.FromQT(images[0].q, images[0].t)
64 | T1 = image_to_idx(images[1])
65 | dq1 = DualQuaternion.FromQT(images[1].q, images[1].t)
66 |
67 | if dq0.q0.dot(dq1.q0) < 0:
68 | dq1 = -dq1
69 | dT = 1. / float(T1 - T0)
70 |     for j in range(1, T1 - T0):
71 | t = j * dT
72 | dq = ((1. - t) * dq0 + t * dq1).normalize()
73 | new_images.append(
74 | Image(file_format.format(T0 + j), camera_id, *dq.ToQT()))
75 |
76 | T2 = image_to_idx(images[2])
77 | dq2 = DualQuaternion.FromQT(images[2].q, images[2].t)
78 | if dq1.q0.dot(dq2.q0) < 0:
79 | dq2 = -dq2
80 |
81 | # Hermite spline interpolation of dual quaternions
82 | # pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf
83 |     for i in range(1, len(images) - 2):
84 | T3 = image_to_idx(images[i + 2])
85 | dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t)
86 | if dq2.q0.dot(dq3.q0) < 0:
87 | dq3 = -dq3
88 |
89 | prev_duration = T1 - T0
90 | current_duration = T2 - T1
91 | next_duration = T3 - T2
92 |
93 | # approximate the derivatives at dq1 and dq2 using weighted central
94 | # differences
95 | dt1 = 1. / float(T2 - T0)
96 | dt2 = 1. / float(T3 - T1)
97 |
98 | m1 = (current_duration * dt1) * (dq2 - dq1) + \
99 | (prev_duration * dt1) * (dq1 - dq0)
100 | m2 = (next_duration * dt2) * (dq3 - dq2) + \
101 | (current_duration * dt2) * (dq2 - dq1)
102 |
103 | dT = 1. / float(current_duration)
104 |
105 |         for j in range(1, current_duration):
106 | t = j * dT # 0 to 1
107 | t2 = t * t # t squared
108 | t3 = t2 * t # t cubed
109 |
110 | # coefficients of the Hermite spline (a=>dq and b=>m)
111 | a1 = 2. * t3 - 3. * t2 + 1.
112 | b1 = t3 - 2. * t2 + t
113 | a2 = -2. * t3 + 3. * t2
114 | b2 = t3 - t2
115 |
116 | dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize()
117 |
118 | new_images.append(
119 | Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
120 |
121 | T0, T1, T2 = T1, T2, T3
122 | dq0, dq1, dq2 = dq1, dq2, dq3
123 |
124 | # linear blending for the last frames
125 | dT = 1. / float(T2 - T1)
126 |     for j in range(1, T2 - T1):
127 | t = j * dT # 0 to 1
128 | dq = ((1. - t) * dq1 + t * dq2).normalize()
129 | new_images.append(
130 | Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
131 |
132 | return new_images
133 |
134 |
135 | #-------------------------------------------------------------------------------
136 |
137 | def main(args):
138 | scene_manager = SceneManager(args.input_folder)
139 | scene_manager.load()
140 |
141 |     images = sorted(scene_manager.images.values(), key=image_to_idx)
142 |
143 | if args.method.lower() == "linear":
144 | new_images = interpolate_linear(images, args.camera_id, args.format)
145 | else:
146 | new_images = interpolate_hermite(images, args.camera_id, args.format)
147 |
148 |     for new_image in new_images:
149 |         scene_manager.add_image(new_image)
150 |     scene_manager.save(args.output_folder)
151 |
152 |
153 | #-------------------------------------------------------------------------------
154 |
155 | if __name__ == "__main__":
156 | import argparse
157 |
158 | parser = argparse.ArgumentParser(
159 | description="Given a reconstruction with ordered images *with integer "
160 | "filenames* like '000100.png', fill in missing camera positions for "
161 | "intermediate frames.",
162 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
163 |
164 | parser.add_argument("input_folder")
165 | parser.add_argument("output_folder")
166 |
167 | parser.add_argument("--camera_id", type=int, default=1,
168 | help="camera id to use for the missing images")
169 |
170 | parser.add_argument("--format", type=str, default="{:06d}.png",
171 | help="filename format to use for added images")
172 |
173 | parser.add_argument(
174 | "--method", type=str.lower, choices=("linear", "hermite"),
175 | default="hermite",
176 | help="Pose imputation method")
177 |
178 | args = parser.parse_args()
179 |
180 | main(args)
181 |
--------------------------------------------------------------------------------
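A small numeric aside on the cubic Hermite basis used above (pure Python, no dual quaternions needed): the position weights a1 and a2 always sum to 1, so each interpolated pose is an affine blend of the two key poses plus the derivative terms:

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    t2, t3 = t * t, t * t * t
    a1 = 2. * t3 - 3. * t2 + 1.
    b1 = t3 - 2. * t2 + t
    a2 = -2. * t3 + 3. * t2
    b2 = t3 - t2
    print(t, a1 + a2, b1, b2)   # a1 + a2 == 1.0 for every t
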
/conerf/pycolmap/tools/save_cameras_as_ply.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import numpy as np
5 | import os
6 |
7 | from pycolmap import SceneManager
8 |
9 |
10 | #-------------------------------------------------------------------------------
11 |
12 | # Saves the cameras as a mesh
13 | #
14 | # inputs:
15 | # - ply_file: output file
16 | # - images: ordered array of pycolmap Image objects
17 | #  - note: per-camera colors are assigned automatically (rainbow gradient)
18 | # - scale: amount to shrink/grow the camera model
19 | def save_camera_ply(ply_file, images, scale):
20 | points3D = scale * np.array((
21 | (0., 0., 0.),
22 | (-1., -1., 1.),
23 | (-1., 1., 1.),
24 | (1., -1., 1.),
25 | (1., 1., 1.)))
26 |
27 | faces = np.array(((0, 2, 1),
28 | (0, 4, 2),
29 | (0, 3, 4),
30 | (0, 1, 3),
31 | (1, 2, 4),
32 | (1, 4, 3)))
33 |
34 | r = np.linspace(0, 255, len(images), dtype=np.uint8)
35 | g = 255 - r
36 | b = r - np.linspace(0, 128, len(images), dtype=np.uint8)
37 | color = np.column_stack((r, g, b))
38 |
39 | with open(ply_file, "w") as fid:
40 |         print("ply", file=fid)
41 |         print("format ascii 1.0", file=fid)
42 |         print("element vertex", len(points3D) * len(images), file=fid)
43 |         print("property float x", file=fid)
44 |         print("property float y", file=fid)
45 |         print("property float z", file=fid)
46 |         print("property uchar red", file=fid)
47 |         print("property uchar green", file=fid)
48 |         print("property uchar blue", file=fid)
49 |         print("element face", len(faces) * len(images), file=fid)
50 |         print("property list uchar int vertex_index", file=fid)
51 |         print("end_header", file=fid)
52 |
53 | for image, c in zip(images, color):
54 | for p3D in (points3D.dot(image.R()) + image.C()):
55 |                 print(p3D[0], p3D[1], p3D[2], c[0], c[1], c[2], file=fid)
56 |
57 |         for i in range(len(images)):
58 | for f in (faces + len(points3D) * i):
59 |                 print("3 {} {} {}".format(*f), file=fid)
60 |
61 |
62 | #-------------------------------------------------------------------------------
63 |
64 | def main(args):
65 | scene_manager = SceneManager(args.input_folder)
66 | scene_manager.load_images()
67 |
68 |     images = sorted(scene_manager.images.values(),
69 | key=lambda image: image.name)
70 |
71 | save_camera_ply(args.output_file, images, args.scale)
72 |
73 |
74 | #-------------------------------------------------------------------------------
75 |
76 | if __name__ == "__main__":
77 | import argparse
78 |
79 | parser = argparse.ArgumentParser(
80 | description="Saves camera positions to a PLY for easy viewing outside "
81 | "of COLMAP. Currently, camera FoV is not reflected in the output.",
82 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
83 |
84 | parser.add_argument("input_folder")
85 | parser.add_argument("output_file")
86 |
87 | parser.add_argument("--scale", type=float, default=1.,
88 | help="Scaling factor for the camera mesh.")
89 |
90 | args = parser.parse_args()
91 |
92 | main(args)
93 |
--------------------------------------------------------------------------------
/conerf/pycolmap/tools/transform_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import numpy as np
5 |
6 | from pycolmap import Quaternion, SceneManager
7 |
8 |
9 | #-------------------------------------------------------------------------------
10 |
11 | def main(args):
12 | scene_manager = SceneManager(args.input_folder)
13 | scene_manager.load()
14 |
15 | # expect each line of input corresponds to one row
16 | P = np.array([
17 |         [float(x) for x in sys.stdin.readline().strip().split()] for _ in range(3)])
18 |
19 | scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3]
20 |
21 | # get rotation without any global scaling (assuming isotropic scaling)
22 | scale = np.cbrt(np.linalg.det(P[:,:3]))
23 | q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale)
24 |
25 |     for image in scene_manager.images.values():
26 | image.q *= q_old_from_new
27 | image.tvec = scale * image.tvec - image.R().dot(P[:,3])
28 |
29 | scene_manager.save(args.output_folder)
30 |
31 |
32 | #-------------------------------------------------------------------------------
33 |
34 | if __name__ == "__main__":
35 | import argparse
36 |
37 | parser = argparse.ArgumentParser(
38 | description="Apply a 3x4 transformation matrix to a COLMAP model and "
39 | "save the result as a new model. Row-major input can be piped in from "
40 | "a file or entered via the command line.",
41 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
42 |
43 | parser.add_argument("input_folder")
44 | parser.add_argument("output_folder")
45 |
46 | args = parser.parse_args()
47 |
48 | main(args)
49 |
--------------------------------------------------------------------------------
/conerf/pycolmap/tools/write_camera_track_to_bundler.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import numpy as np
5 |
6 | from pycolmap import SceneManager
7 |
8 |
9 | #-------------------------------------------------------------------------------
10 |
11 | def main(args):
12 | scene_manager = SceneManager(args.input_folder)
13 | scene_manager.load_cameras()
14 | scene_manager.load_images()
15 |
16 | if args.sort:
17 | images = sorted(
18 |             scene_manager.images.values(), key=lambda im: im.name)
19 | else:
20 | images = scene_manager.images.values()
21 |
22 | fid = open(args.output_file, "w")
23 | fid_filenames = open(args.output_file + ".list.txt", "w")
24 |
25 |     print("# Bundle file v0.3", file=fid)
26 |     print(len(images), 0, file=fid)
27 |
28 | for image in images:
29 |         print(image.name, file=fid_filenames)
30 | camera = scene_manager.cameras[image.camera_id]
31 |         print(0.5 * (camera.fx + camera.fy), 0, 0, file=fid)
32 | R, t = image.R(), image.t
33 |         print(R[0, 0], R[0, 1], R[0, 2], file=fid)
34 |         print(-R[1, 0], -R[1, 1], -R[1, 2], file=fid)
35 |         print(-R[2, 0], -R[2, 1], -R[2, 2], file=fid)
36 |         print(t[0], -t[1], -t[2], file=fid)
37 |
38 | fid.close()
39 | fid_filenames.close()
40 |
41 |
42 | #-------------------------------------------------------------------------------
43 |
44 | if __name__ == "__main__":
45 | import argparse
46 |
47 | parser = argparse.ArgumentParser(
48 | description="Saves the camera positions in the Bundler format. Note "
49 | "that 3D points are not saved.",
50 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
51 |
52 | parser.add_argument("input_folder")
53 | parser.add_argument("output_file")
54 |
55 | parser.add_argument("--sort", default=False, action="store_true",
56 | help="sort the images by their filename")
57 |
58 | args = parser.parse_args()
59 |
60 | main(args)
61 |
--------------------------------------------------------------------------------
/conerf/pycolmap/tools/write_depthmap_to_ply.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import imageio
5 | import numpy as np
6 | import os
7 |
8 | from plyfile import PlyData, PlyElement
9 | from pycolmap import SceneManager
10 | from scipy.ndimage import zoom
11 |
12 |
13 | #-------------------------------------------------------------------------------
14 |
15 | def main(args):
16 | suffix = ".photometric.bin" if args.photometric else ".geometric.bin"
17 |
18 | image_file = os.path.join(args.dense_folder, "images", args.image_filename)
19 | depth_file = os.path.join(
20 | args.dense_folder, args.stereo_folder, "depth_maps",
21 | args.image_filename + suffix)
22 | if args.save_normals:
23 | normals_file = os.path.join(
24 | args.dense_folder, args.stereo_folder, "normal_maps",
25 | args.image_filename + suffix)
26 |
27 | # load camera intrinsics from the COLMAP reconstruction
28 | scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse"))
29 | scene_manager.load_cameras()
30 | scene_manager.load_images()
31 |
32 | image_id, image = scene_manager.get_image_from_name(args.image_filename)
33 | camera = scene_manager.cameras[image.camera_id]
34 | rotation_camera_from_world = image.R()
35 | camera_center = image.C()
36 |
37 | # load image, depth map, and normal map
38 | image = imageio.imread(image_file)
39 |
40 | with open(depth_file, "rb") as fid:
41 |         w = int(b"".join(iter(lambda: fid.read(1), b"&")))
42 |         h = int(b"".join(iter(lambda: fid.read(1), b"&")))
43 |         c = int(b"".join(iter(lambda: fid.read(1), b"&")))
44 | depth_map = np.fromfile(fid, np.float32).reshape(h, w)
45 | if (h, w) != image.shape[:2]:
46 | depth_map = zoom(
47 | depth_map,
48 | (float(image.shape[0]) / h, float(image.shape[1]) / w),
49 | order=0)
50 |
51 | if args.save_normals:
52 | with open(normals_file, "rb") as fid:
53 |             w = int(b"".join(iter(lambda: fid.read(1), b"&")))
54 |             h = int(b"".join(iter(lambda: fid.read(1), b"&")))
55 |             c = int(b"".join(iter(lambda: fid.read(1), b"&")))
56 | normals = np.fromfile(
57 | fid, np.float32).reshape(c, h, w).transpose([1, 2, 0])
58 | if (h, w) != image.shape[:2]:
59 | normals = zoom(
60 | normals,
61 | (float(image.shape[0]) / h, float(image.shape[1]) / w, 1.),
62 | order=0)
63 |
64 | if args.min_depth is not None:
65 | depth_map[depth_map < args.min_depth] = 0.
66 | if args.max_depth is not None:
67 | depth_map[depth_map > args.max_depth] = 0.
68 |
69 | # create 3D points
70 | #depth_map = np.minimum(depth_map, 100.)
71 | points3D = np.dstack(camera.get_image_grid() + [depth_map])
72 | points3D[:,:,:2] *= depth_map[:,:,np.newaxis]
73 |
74 | # save
75 | points3D = points3D.astype(np.float32).reshape(-1, 3)
76 | if args.save_normals:
77 | normals = normals.astype(np.float32).reshape(-1, 3)
78 | image = image.reshape(-1, 3)
79 | if image.dtype != np.uint8:
80 | if image.max() <= 1:
81 | image = (image * 255.).astype(np.uint8)
82 | else:
83 | image = image.astype(np.uint8)
84 |
85 | if args.world_space:
86 | points3D = points3D.dot(rotation_camera_from_world) + camera_center
87 | if args.save_normals:
88 | normals = normals.dot(rotation_camera_from_world)
89 |
90 | if args.save_normals:
91 | vertices = np.rec.fromarrays(
92 | tuple(points3D.T) + tuple(normals.T) + tuple(image.T),
93 | names="x,y,z,nx,ny,nz,red,green,blue")
94 | else:
95 | vertices = np.rec.fromarrays(
96 | tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue")
97 | vertices = PlyElement.describe(vertices, "vertex")
98 | PlyData([vertices]).write(args.output_filename)
99 |
100 |
101 | #-------------------------------------------------------------------------------
102 |
103 | if __name__ == "__main__":
104 | import argparse
105 |
106 | parser = argparse.ArgumentParser(
107 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
108 |
109 | parser.add_argument("dense_folder", type=str)
110 | parser.add_argument("image_filename", type=str)
111 | parser.add_argument("output_filename", type=str)
112 |
113 | parser.add_argument(
114 | "--photometric", default=False, action="store_true",
115 | help="use photometric depthmap instead of geometric")
116 |
117 | parser.add_argument(
118 | "--world_space", default=False, action="store_true",
119 | help="apply the camera->world extrinsic transformation to the result")
120 |
121 | parser.add_argument(
122 | "--save_normals", default=False, action="store_true",
123 | help="load the estimated normal map and save as part of the PLY")
124 |
125 | parser.add_argument(
126 | "--stereo_folder", type=str, default="stereo",
127 | help="folder in the dense workspace containing depth and normal maps")
128 |
129 | parser.add_argument(
130 | "--min_depth", type=float, default=None,
131 | help="set pixels with depth less than this value to zero depth")
132 |
133 | parser.add_argument(
134 | "--max_depth", type=float, default=None,
135 | help="set pixels with depth greater than this value to zero depth")
136 |
137 | args = parser.parse_args()
138 |
139 | main(args)
140 |
--------------------------------------------------------------------------------
/conerf/utils/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from omegaconf import OmegaConf
3 |
4 |
5 | def strs2ints(strs):
6 | strs = strs.split(',')
7 | ints = []
8 | for num in strs:
9 | ints.append(int(num))
10 | print(f'ints: {ints}')
11 | return ints
12 |
13 |
14 | def calc_milestones(max_step, muls, divs):
15 | # muls, divs = strs2ints(muls), strs2ints(divs)
16 | milestones = "["
17 | for mul, div in zip(muls, divs):
18 | milestones += str(max_step * mul // div)
19 | milestones += ","
20 | real_milestones = milestones[:-1]
21 | real_milestones += "]"
22 | return real_milestones
23 |
24 |
25 | OmegaConf.register_new_resolver(
26 | 'calc_exp_lr_decay_rate',
27 | lambda factor, n: factor**(1./n)
28 | )
29 | OmegaConf.register_new_resolver('add', lambda a, b: a + b)
30 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
31 | OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
32 | OmegaConf.register_new_resolver('divi', lambda a, b: a // b)
33 | OmegaConf.register_new_resolver(
34 | 'calc_milestones',
35 | lambda max_step, muls, divs: calc_milestones(max_step, muls, divs) # pylint: disable=W0108
36 | )
37 |
38 |
39 | def config_parser():
40 | parser = argparse.ArgumentParser()
41 |
42 | ##################################### Base configs ########################################
43 | parser.add_argument("--config",
44 | type=str,
45 | default="",
46 | help="absolute path of config file")
47 | parser.add_argument("--suffix",
48 | type=str,
49 | default="",
50 | help="suffix for training folder")
51 | parser.add_argument("--scene",
52 | type=str,
53 | default="",
54 | help="name for the trained scene")
55 | parser.add_argument("--expname",
56 | type=str,
57 | default="",
58 | help="experiment name")
59 | parser.add_argument("--model_folder",
60 | type=str,
61 | default="sparse", # ['sparse', 'zero_gs']
62 |                         help="folder that contains the colmap model output")
63 | parser.add_argument("--init_ply_type",
64 | type=str,
65 | default="sparse", # ['sparse', 'dense']
66 | help="use dense or sparse point cloud to initialize 3DGS")
67 | parser.add_argument("--load_specified_images",
68 | action="store_true",
69 | help="Only load the specified images to train.")
70 |
71 | ##################################### Block Training ########################################
72 | parser.add_argument("--block_id",
73 | type=int,
74 | default=0,
75 | help="block id")
76 | parser.add_argument("--block_data_path",
77 | type=str,
78 | default="",
79 | help="directory that stores the block data")
80 | parser.add_argument("--train_local",
81 | action="store_true",
82 | help="train local blocks")
83 |
84 | ##################################### registration ########################################
85 | parser.add_argument("--position_embedding_type",
86 | type=str,
87 | default="sine",
88 | help="which kind of positional embedding to use in transformer")
89 | parser.add_argument("--position_embedding_dim",
90 | type=int,
91 | default=256,
92 | help="dimensionality of position embeddings")
93 | parser.add_argument("--position_embedding_scaling",
94 | type=float,
95 | default=1.0,
96 | help="position embedding scale factor")
97 | parser.add_argument("--num_downsample",
98 | type=int,
99 | default=6,
100 | help="how many layers used to downsample points")
101 | parser.add_argument("--robust_loss",
102 | action="store_true",
103 | help="whether to use robust loss function")
104 |
105 | #################################### composite inr blocks #################################
106 | parser.add_argument("--enable_composite",
107 | action="store_true",
108 | help="whether to composite implicit neural representation blocks.")
109 |
110 | args = parser.parse_args()
111 |
112 | return args
113 |
114 |
115 | def load_config(*yaml_files, cli_args=[]):
116 | yaml_confs = [OmegaConf.load(f) for f in yaml_files]
117 | cli_conf = OmegaConf.from_cli(cli_args)
118 | conf = OmegaConf.merge(*yaml_confs, cli_conf)
119 | OmegaConf.resolve(conf)
120 |
121 | return conf
122 |
--------------------------------------------------------------------------------
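A minimal sketch of the helper and the custom OmegaConf resolvers registered above (the inline config is made up for illustration):

from omegaconf import OmegaConf

from conerf.utils.config import calc_milestones  # importing the module registers the resolvers

print(calc_milestones(30000, [1, 3], [2, 4]))    # "[15000,22500]"

conf = OmegaConf.create({"max_step": 30000, "half": "${divi:${max_step},2}"})
OmegaConf.resolve(conf)
print(conf.half)                                 # 15000
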
/conerf/visualization/feature_visualizer.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 |
3 | import math
4 |
5 | import torch
6 | import torchvision
7 | import torchvision.transforms as transforms
8 |
9 | import numpy as np
10 | import cv2
11 |
12 |
13 | def plot_feature_map(writer, global_step, ray_sampler, feat_maps, prefix=''):
14 | coarse_feat_map = ray_sampler.target_feat_map[0].transpose(0, 1)
15 | feat_map_grid = torchvision.utils.make_grid(
16 | coarse_feat_map, normalize=True, scale_each=True, nrow=8)
17 | writer.add_image(prefix + 'target_feat_map', feat_map_grid, global_step)
18 |
19 | num_nearby_views = feat_maps[0].shape[0]
20 | for i in range(num_nearby_views):
21 | feat_map = feat_maps[0][i].unsqueeze(0).transpose(0, 1)
22 | # print(f'[DEBUG] feat_map shape: {feat_map}')
23 | feat_map_grid = torchvision.utils.make_grid(
24 | feat_map, normalize=True, scale_each=True, nrow=8)
25 | writer.add_image(
26 | prefix + f'nearby_feat_map-{i}', feat_map_grid, global_step)
27 |
28 |
29 | def feature_map_to_heatmap(feat_maps):
30 | '''
31 | feat_maps: [C, H, W]
32 | '''
33 | # Define a transform to convert the image to tensor
34 | transform = transforms.ToTensor()
35 |
36 | num_channels = feat_maps.shape[0]
37 | heat_maps = []
38 |
39 | for i in range(num_channels):
40 | feat_map = np.asarray(feat_maps)[i]
41 | # print('feat_map.shape:', feat_map.shape) # [H, W]
42 | # print('feat_map type:', feat_map.dtype) # float32
43 |
44 | feat_map = np.asarray(feat_map * 255, dtype=np.uint8) # [0,255]
45 | # print('feat_map type:', feat_map.dtype) # uint8
46 |
47 | # https://www.sohu.com/a/343215045_120197868
48 | feat_map = cv2.applyColorMap(feat_map, cv2.COLORMAP_RAINBOW)
49 | feat_map = transform(feat_map)
50 |
51 | heat_maps.append(feat_map)
52 |
53 | heat_maps = torch.stack(heat_maps, dim=0) # [C, 3, 25, 25]
54 | return heat_maps
55 |
56 |
57 | def feature_maps_to_heatmap(feat_maps):
58 | '''
59 | Args:
60 | feat_maps: [C, H, W]
61 | Return:
62 | A composed heat map with shape [H, W]
63 | '''
64 | # Define a transform to convert the image to tensor
65 | transform = transforms.ToTensor()
66 |
67 | # print(f'[DEBUG] feat_maps shape: {feat_maps.shape}')
68 | [c, h, w] = feat_maps.shape
69 |
70 | heatmap = torch.zeros((h, w))
71 | weight = []
72 | feat_maps = np.asarray(feat_maps)
73 |
74 | for i in range(c):
75 | feat_map = feat_maps[i]
76 | weight = np.mean(feat_map)
77 | heatmap[:, :] += weight * feat_map
78 |
79 |     heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())  # min-max normalization to [0, 1]
80 |
81 | heatmap = np.asarray(heatmap * 255, dtype=np.uint8)
82 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_RAINBOW)
83 | heatmap = transform(heatmap)
84 |
85 | return heatmap
86 |
87 |
88 | def plot_sampled_feature_map(writer, global_step, target_rgb_feat, rgb_feats, N_rand, prefix='train/'):
89 | width = int(math.sqrt(N_rand))
90 | target_rgb_feat = target_rgb_feat.detach().cpu()
91 | rgb_feats = rgb_feats.detach().cpu()
92 |
93 | target_rgb_feat = target_rgb_feat.permute(
94 | 3, 2, 0, 1).reshape(35, -1, width, width)
95 | rgb_feats = rgb_feats.permute(3, 2, 0, 1).reshape(35, -1, width, width)
96 | res_rgb_feats = torch.abs(target_rgb_feat - rgb_feats)
97 | target_rgb_feat = target_rgb_feat[:, 0, ...]
98 |
99 | # target_feat_map = feature_map_to_heatmap(target_rgb_feat[3:])
100 | # feat_map_grid = torchvision.utils.make_grid(target_feat_map, normalize=True, scale_each=True, nrow=8)
101 | # writer.add_image(prefix + f'target_feat_map', feat_map_grid, global_step) # feature map
102 | target_feat_map = feature_maps_to_heatmap(target_rgb_feat[3:])
103 | writer.add_image(prefix + f'target_feat_map',
104 | target_feat_map, global_step) # feature map
105 | writer.add_image(prefix + f'target_rgb_map',
106 | target_rgb_feat[0:3], global_step)
107 |
108 | num_nearby_views = rgb_feats.shape[1]
109 | nearby_feat_maps, nearby_rgb_maps = [], []
110 | nearby_res_feat_maps, nearby_res_rgb_maps = [], []
111 | for i in range(num_nearby_views):
112 | rgb_feat_map = rgb_feats[:, i, ...]
113 | feat_map = feature_maps_to_heatmap(rgb_feat_map[3:])
114 | nearby_feat_maps.append(feat_map)
115 | nearby_rgb_maps.append(rgb_feat_map[0:3])
116 |
117 | res_rgb_feat_map = res_rgb_feats[:, i, ...]
118 | res_feat_map = feature_maps_to_heatmap(res_rgb_feat_map[3:])
119 | nearby_res_feat_maps.append(res_feat_map)
120 | nearby_res_rgb_maps.append(res_rgb_feat_map[0:3])
121 |
122 | nearby_feat_maps = torch.stack(nearby_feat_maps, dim=0)
123 | nearby_feat_grid = torchvision.utils.make_grid(
124 | nearby_feat_maps, normalize=True, scale_each=True, nrow=5)
125 | writer.add_image(prefix + f'nearby_feat_maps',
126 | nearby_feat_grid, global_step)
127 |
128 | nearby_rgb_maps = torch.stack(nearby_rgb_maps, dim=0) # [n_views, 3, h, w]
129 | nearby_rgb_grid = torchvision.utils.make_grid(
130 | nearby_rgb_maps, normalize=True, scale_each=True, nrow=5)
131 | writer.add_image(prefix + f'nearby_rgb_maps', nearby_rgb_grid, global_step)
132 |
133 | nearby_res_feat_maps = torch.stack(nearby_res_feat_maps, dim=0)
134 | nearby_res_feat_grid = torchvision.utils.make_grid(
135 | nearby_res_feat_maps, normalize=True, scale_each=True, nrow=5)
136 | writer.add_image(prefix + f'nearby_res_feat_maps',
137 | nearby_res_feat_grid, global_step)
138 |
139 | nearby_res_rgb_maps = torch.stack(
140 | nearby_res_rgb_maps, dim=0) # [n_views, 3, h, w]
141 | nearby_res_rgb_grid = torchvision.utils.make_grid(
142 | nearby_res_rgb_maps, normalize=True, scale_each=True, nrow=5)
143 | writer.add_image(prefix + f'nearby_res_rgb_maps',
144 | nearby_res_rgb_grid, global_step)
145 |
--------------------------------------------------------------------------------
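An illustrative sketch that colorizes a random feature map channel-by-channel with feature_map_to_heatmap:

import torch
from conerf.visualization.feature_visualizer import feature_map_to_heatmap

feat_maps = torch.rand(32, 25, 25)      # [C, H, W]
heat_maps = feature_map_to_heatmap(feat_maps)
print(heat_maps.shape)                  # [C, 3, H, W] after the colormap and ToTensor
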
/conerf/visualization/pose_visualizer.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 |
3 | from typing import List
4 |
5 | import torch
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from easydict import EasyDict as edict
9 |
10 |
11 | def to_hom(X):
12 | # get homogeneous coordinates of the input
13 | X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
14 |
15 | return X_hom
16 |
17 |
18 | def get_camera_mesh(pose, depth=1):
19 | vertices = torch.tensor([[-0.5, -0.5, 1],
20 | [0.5, -0.5, 1],
21 | [0.5, 0.5, 1],
22 | [-0.5, 0.5, 1],
23 | [0, 0, 0]]) * depth
24 |
25 | faces = torch.tensor([[0, 1, 2],
26 | [0, 2, 3],
27 | [0, 1, 4],
28 | [1, 2, 4],
29 | [2, 3, 4],
30 | [3, 0, 4]])
31 |
32 | # vertices = camera.cam2world(vertices[None], pose)
33 | vertices = to_hom(vertices[None]) @ pose.transpose(-1, -2)
34 |
35 | wire_frame = vertices[:, [0, 1, 2, 3, 0, 4, 1, 2, 4, 3]]
36 |
37 | return vertices, faces, wire_frame
38 |
39 |
40 | def merge_wire_frames(wire_frame):
41 | wire_frame_merged = [[], [], []]
42 | for w in wire_frame:
43 | wire_frame_merged[0] += [float(n) for n in w[:, 0]] + [None]
44 | wire_frame_merged[1] += [float(n) for n in w[:, 1]] + [None]
45 | wire_frame_merged[2] += [float(n) for n in w[:, 2]] + [None]
46 |
47 | return wire_frame_merged
48 |
49 |
50 | def merge_meshes(vertices, faces):
51 | mesh_N, vertex_N = vertices.shape[:2]
52 | faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)], dim=0)
53 | vertices_merged = vertices.view(-1, vertices.shape[-1])
54 |
55 | return vertices_merged, faces_merged
56 |
57 |
58 | def merge_centers(centers):
59 | center_merged = [[], [], []]
60 |
61 | for c1, c2 in zip(*centers):
62 | center_merged[0] += [float(c1[0]), float(c2[0]), None]
63 | center_merged[1] += [float(c1[1]), float(c2[1]), None]
64 | center_merged[2] += [float(c1[2]), float(c2[2]), None]
65 |
66 | return center_merged
67 |
68 |
69 | @torch.no_grad()
70 | def visualize_cameras(
71 | vis,
72 | step: int = 0,
73 | poses: List = [],
74 | cam_depth: float = 0.5,
75 | colors: List = ["blue", "magenta"],
76 | plot_dist: bool = True
77 | ):
78 | win_name = "gt_pred"
79 | data = []
80 |
81 | # set up plots
82 | centers = []
83 | for pose, color in zip(poses, colors):
84 | pose = pose.detach().cpu()
85 | vertices, faces, wire_frame = get_camera_mesh(pose, depth=cam_depth)
86 | center = vertices[:, -1]
87 | centers.append(center)
88 |
89 | # camera centers
90 | data.append(dict(
91 | type="scatter3d",
92 | x=[float(n) for n in center[:, 0]],
93 | y=[float(n) for n in center[:, 1]],
94 | z=[float(n) for n in center[:, 2]],
95 | mode="markers",
96 | marker=dict(color=color, size=3),
97 | ))
98 |
99 | # colored camera mesh
100 | vertices_merged, faces_merged = merge_meshes(vertices, faces)
101 |
102 | data.append(dict(
103 | type="mesh3d",
104 | x=[float(n) for n in vertices_merged[:, 0]],
105 | y=[float(n) for n in vertices_merged[:, 1]],
106 | z=[float(n) for n in vertices_merged[:, 2]],
107 | i=[int(n) for n in faces_merged[:, 0]],
108 | j=[int(n) for n in faces_merged[:, 1]],
109 | k=[int(n) for n in faces_merged[:, 2]],
110 | flatshading=True,
111 | color=color,
112 | opacity=0.05,
113 | ))
114 |
115 | # camera wire_frame
116 | wire_frame_merged = merge_wire_frames(wire_frame)
117 | data.append(dict(
118 | type="scatter3d",
119 | x=wire_frame_merged[0],
120 | y=wire_frame_merged[1],
121 | z=wire_frame_merged[2],
122 | mode="lines",
123 | line=dict(color=color,),
124 | opacity=0.3,
125 | ))
126 |
127 | if plot_dist:
128 | # distance between two poses (camera centers)
129 | center_merged = merge_centers(centers[:2])
130 | data.append(dict(
131 | type="scatter3d",
132 | x=center_merged[0],
133 | y=center_merged[1],
134 | z=center_merged[2],
135 | mode="lines",
136 | line=dict(color="red", width=4,),
137 | ))
138 |
139 | if len(centers) == 4:
140 | center_merged = merge_centers(centers[2:4])
141 | data.append(dict(
142 | type="scatter3d",
143 | x=center_merged[0],
144 | y=center_merged[1],
145 | z=center_merged[2],
146 | mode="lines",
147 | line=dict(color="red", width=4,),
148 | ))
149 |
150 | # send data to visdom
151 | vis._send(dict(
152 | data=data,
153 | win="poses",
154 | eid=win_name,
155 | layout=dict(
156 | title=f"({step})",
157 | autosize=True,
158 | margin=dict(l=30, r=30, b=30, t=30,),
159 | showlegend=False,
160 | yaxis=dict(
161 | scaleanchor="x",
162 | scaleratio=1,
163 | )
164 | ),
165 | opts=dict(title=f"{win_name} poses ({step})",),
166 | ))
167 |
168 |
169 | def plot_save_poses(
170 | cam_depth: float,
171 | fig,
172 | pose: torch.Tensor,
173 | pose_ref: torch.Tensor = None,
174 | path: str = None,
175 | ep=None,
176 | axis_len: float = 1.0,
177 | ):
178 | # get the camera meshes
179 | _, _, cam = get_camera_mesh(pose, depth=cam_depth)
180 | cam = cam.numpy()
181 |
182 | if pose_ref is not None:
183 | _, _, cam_ref = get_camera_mesh(pose_ref, depth=cam_depth)
184 | cam_ref = cam_ref.numpy()
185 |
186 | # set up plot window(s)
187 | plt.title(f"epoch {ep}")
188 | ax1 = fig.add_subplot(121, projection="3d")
189 | ax2 = fig.add_subplot(122, projection="3d")
190 | setup_3D_plot(
191 | ax1, elev=-90, azim=-90,
192 | lim=edict(x=(-axis_len, axis_len), y=(-axis_len,
193 | axis_len), z=(-axis_len, axis_len))
194 | )
195 | setup_3D_plot(
196 | ax2, elev=0, azim=-90,
197 | lim=edict(x=(-axis_len, axis_len), y=(-axis_len,
198 | axis_len), z=(-axis_len, axis_len))
199 | )
200 | ax1.set_title("forward-facing view", pad=0)
201 | ax2.set_title("top-down view", pad=0)
202 | plt.subplots_adjust(left=0, right=1, bottom=0,
203 | top=0.95, wspace=0, hspace=0)
204 | plt.margins(tight=True, x=0, y=0)
205 |
206 | # plot the cameras
207 | N = len(cam)
208 | color = plt.get_cmap("gist_rainbow")
209 | for i in range(N):
210 | if pose_ref is not None:
211 | ax1.plot(cam_ref[i, :, 0], cam_ref[i, :, 1],
212 | cam_ref[i, :, 2], color=(0.1, 0.1, 0.1), linewidth=1)
213 | ax2.plot(cam_ref[i, :, 0], cam_ref[i, :, 1],
214 | cam_ref[i, :, 2], color=(0.1, 0.1, 0.1), linewidth=1)
215 | ax1.scatter(cam_ref[i, 5, 0], cam_ref[i, 5, 1],
216 | cam_ref[i, 5, 2], color=(0.1, 0.1, 0.1), s=40)
217 | ax2.scatter(cam_ref[i, 5, 0], cam_ref[i, 5, 1],
218 | cam_ref[i, 5, 2], color=(0.1, 0.1, 0.1), s=40)
219 | c = np.array(color(float(i) / N)) * 0.8
220 | ax1.plot(cam[i, :, 0], cam[i, :, 1], cam[i, :, 2], color=c)
221 | ax2.plot(cam[i, :, 0], cam[i, :, 1], cam[i, :, 2], color=c)
222 | ax1.scatter(cam[i, 5, 0], cam[i, 5, 1], cam[i, 5, 2], color=c, s=40)
223 | ax2.scatter(cam[i, 5, 0], cam[i, 5, 1], cam[i, 5, 2], color=c, s=40)
224 |
225 | png_fname = f"{path}/{ep}.png"
226 | plt.savefig(png_fname, dpi=75)
227 | # clean up
228 | plt.clf()
229 |
230 |
231 | def setup_3D_plot(ax, elev, azim, lim=None):
232 | ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
233 | ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
234 | ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
235 | ax.xaxis._axinfo["grid"]["color"] = (0.9, 0.9, 0.9, 1)
236 | ax.yaxis._axinfo["grid"]["color"] = (0.9, 0.9, 0.9, 1)
237 | ax.zaxis._axinfo["grid"]["color"] = (0.9, 0.9, 0.9, 1)
238 | ax.xaxis.set_tick_params(labelsize=8)
239 | ax.yaxis.set_tick_params(labelsize=8)
240 | ax.zaxis.set_tick_params(labelsize=8)
241 | ax.set_xlabel("X", fontsize=16)
242 | ax.set_ylabel("Y", fontsize=16)
243 | ax.set_zlabel("Z", fontsize=16)
244 | ax.set_xlim(lim.x[0], lim.x[1])
245 | ax.set_ylim(lim.y[0], lim.y[1])
246 | ax.set_zlim(lim.z[0], lim.z[1])
247 | ax.view_init(elev=elev, azim=azim)
248 |
--------------------------------------------------------------------------------
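A hypothetical sketch that renders a short synthetic camera track to a PNG with plot_save_poses; the poses and the /tmp output path are placeholders:

import matplotlib
matplotlib.use("Agg")                        # headless rendering
import matplotlib.pyplot as plt
import torch

from conerf.visualization.pose_visualizer import plot_save_poses

poses = torch.eye(4)[:3].repeat(5, 1, 1)     # five camera-to-world poses, [5, 3, 4]
poses[:, 0, 3] = torch.linspace(-1, 1, 5)    # slide the cameras along x
fig = plt.figure(figsize=(10, 5))
plot_save_poses(cam_depth=0.2, fig=fig, pose=poses, path="/tmp", ep=0)   # writes /tmp/0.png
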
/config/ace/llff.yaml:
--------------------------------------------------------------------------------
1 | neural_field_type: mlp
2 | expname: ${neural_field_type}_${task}_${dataset.name}_${dataset.scene}
3 | task: pose
4 | seed: 42
5 |
6 | dataset:
7 | name: llff
8 |   root_dir: # e.g.: /home/user/datasets/${dataset.name}
9 |   encoder_path: # e.g.: /home/user/Projects/ZeroGS/conerf/model/scene_regressor/ace_encoder_pretrained.pt
10 | scene: ['fern', 'flower', 'fortress', 'horns', 'leaves', 'orchids', 'room', 'trex']
11 | image_resolution:
12 | scale: true
13 | rotate: false
14 | use_aug: true
15 | aug_rotation: 15
16 | aug_scale: 1.5
17 | factor: 4
18 | val_interval: -1
19 | apply_mask: false
20 | cam_depth: 0.2
21 | axis_len: 1.7
22 |
23 | trainer:
24 | epochs: 16
25 | max_patch_loops_per_epoch: 10
26 | samples_per_image: 1024
27 | training_buffer_size: 8000000
28 | batch_size: 5120
29 | min_iterations_per_epoch: 5000
30 | max_iterations_per_epoch: 10000
31 | early_stop_thresh: 6
32 | use_half: true
33 | ckpt_path: ""
34 | no_load_opt: false
35 | no_load_scheduler: false
36 | enable_tensorboard: true
37 | enable_visdom: false
38 | visdom_server: localhost
39 | visdom_port: 9002
40 | n_tensorboard: 100
41 | n_validation: 5000
42 | n_checkpoint: 1000
43 | distributed: false
44 | excluded_gpus: []
45 | num_workers: 4
46 | local_rank: 0
47 |
48 | optimizer:
49 | lr_sc_min: 0.0005 # lowest learning rate of 1 cycle scheduler
50 | lr_sc_max: 0.003 # highest learning rate of 1 cycle scheduler
51 | lr_pr: 1e-3 # learning rate for the pose refiner
52 | lr_cr: 1e-3 # learning rate for the calibration refiner
53 |
54 | regressor:
55 | # ZoeD_N is fine-tuned for metric depth on NYU Depth v2 for relative depth estimation,
56 | # ZoeD_K is fine-tuned for metric depth on KITTI for relative depth estimation.
57 | # ZoeD_NK has two separate heads fine-tuned on both NYU Depth v2 and KITTI.
58 |
59 | # [ZoeDepth, metric3d]
60 | depth_net_method: ZoeDepth
61 | # ZoeDepth: [ZoeD_N, ZoeD_K, ZoeD_NK]; metric3d: [metric3d_vit_small, metric3d_vit_large, metric3d_vit_giant2]
62 | depth_net_type: ZoeD_NK
63 | num_seed_image_trials: 5
64 |   num_reloc_images_max: 1000 # the maximum number of relocalization tests during seed reconstruction.
65 | num_head_blocks: 1 # The depth of the head network.
66 | use_homogeneous: true
67 | depth_min: 0.1
68 | depth_max: 1000 # [ZoeDepth: 1000; metric3d: 200]
69 | depth_target: 10
70 |
71 | pose_estimator:
72 | reproj_thresh: 10 # inlier threshold in pixels (RGB) or centimeters (RGB-D)
73 | hypotheses: 64 # number of hypotheses, i.e. number of RANSAC iterations.
74 | inlier_alpha: 100 # alpha parameter of the soft inlier count.
75 | max_pixel_error: 100 # maximum reprojection (RGB, in px) or 3D distance (RGB-D, in cm) error when checking pose consistency.
76 | min_inlier_count: 2000 # minimum number of inlier correspondences when registering an image
77 |
78 | loss:
79 | repro_loss_hard_clamp: 1000
80 | repro_loss_soft_clamp: 50
81 | repro_loss_soft_clamp_min: 1
82 | repro_loss_type: tanh # dyntanh
83 | repro_loss_scheduler: circle
84 |
--------------------------------------------------------------------------------
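A hypothetical snippet showing how a shipped config like the one above is typically loaded and overridden from the command line via load_config; the paths are placeholders:

from conerf.utils.config import load_config

config = load_config(
    "config/ace/llff.yaml",
    cli_args=["dataset.scene=fern", "dataset.root_dir=/path/to/llff"],
)
print(config.expname)           # mlp_pose_llff_fern
print(config.trainer.epochs)    # 16
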
/config/ace/mipnerf360.yaml:
--------------------------------------------------------------------------------
1 | neural_field_type: mlp
2 | expname: ${neural_field_type}_${task}_${dataset.name}_${dataset.scene}
3 | task: pose
4 | seed: 42
5 |
6 | dataset:
7 | name: mipnerf360
8 |   root_dir: # e.g.: /home/user/datasets/${dataset.name}
9 |   encoder_path: # e.g.: /home/user/Projects/ZeroGS/conerf/model/scene_regressor/ace_encoder_pretrained.pt
10 | scene: ["bicycle", "bonsai", "counter", "garden", "kitchen", "room", "stump", "flowers", "treehill"]
11 | image_resolution:
12 | scale: true
13 | rotate: false
14 | use_aug: true
15 | aug_rotation: 15
16 | aug_scale: 1.5
17 | factor: 4
18 | val_interval: -1
19 | apply_mask: false
20 | cam_depth: 0.1
21 | axis_len: 1.0
22 |
23 | trainer:
24 | epochs: 16
25 | max_patch_loops_per_epoch: 10
26 | samples_per_image: 1024
27 | training_buffer_size: 8000000
28 | batch_size: 5120
29 | min_iterations_per_epoch: 5000
30 | max_iterations_per_epoch: 10000
31 | early_stop_thresh: 6
32 | use_half: true
33 | ckpt_path: ""
34 | no_load_opt: false
35 | no_load_scheduler: false
36 | enable_tensorboard: true
37 | enable_visdom: false
38 | visdom_server: localhost
39 | visdom_port: 9002
40 | n_tensorboard: 100
41 | n_validation: 5000
42 | n_checkpoint: 1000
43 | distributed: false
44 | excluded_gpus: []
45 | num_workers: 4
46 | local_rank: 0
47 |
48 | optimizer:
49 | lr_sc_min: 0.0005 # lowest learning rate of 1 cycle scheduler
50 | lr_sc_max: 0.003 # highest learning rate of 1 cycle scheduler
51 | lr_pr: 1e-3 # learning rate for the pose refiner
52 | lr_cr: 1e-3 # learning rate for the calibration refiner
53 |
54 | regressor:
55 | # ZoeD_N is fine-tuned for metric depth on NYU Depth v2 for relative depth estimation,
56 | # ZoeD_K is fine-tuned for metric depth on KITTI for relative depth estimation.
57 | # ZoeD_NK has two separate heads fine-tuned on both NYU Depth v2 and KITTI.
58 |
59 | # [ZoeDepth, metric3d]
60 | depth_net_method: ZoeDepth
61 | # ZoeDepth: [ZoeD_N, ZoeD_K, ZoeD_NK]; metric3d: [metric3d_vit_small, metric3d_vit_large, metric3d_vit_giant2]
62 | depth_net_type: ZoeD_NK
63 | num_seed_image_trials: 5
64 |   num_reloc_images_max: 1000 # the maximum number of relocalization tests during seed reconstruction.
65 | num_head_blocks: 1 # The depth of the head network.
66 | use_homogeneous: true
67 | depth_min: 0.1
68 | depth_max: 1000 # [ZoeDepth: 1000; metric3d: 200]
69 | depth_target: 10
70 |
71 | pose_estimator:
72 | reproj_thresh: 10 # inlier threshold in pixels (RGB) or centimeters (RGB-D)
73 | hypotheses: 64 # number of hypotheses, i.e. number of RANSAC iterations.
74 | inlier_alpha: 100 # alpha parameter of the soft inlier count.
75 | max_pixel_error: 100 # maximum reprojection (RGB, in px) or 3D distance (RGB-D, in cm) error when checking pose consistency.
76 | min_inlier_count: 2000 # minimum number of inlier correspondences when registering an image
77 |
78 | loss:
79 | repro_loss_hard_clamp: 1000
80 | repro_loss_soft_clamp: 50
81 | repro_loss_soft_clamp_min: 1
82 | repro_loss_type: tanh # dyntanh
83 | repro_loss_scheduler: circle
84 |
--------------------------------------------------------------------------------
/config/ace/tanks_and_temples.yaml:
--------------------------------------------------------------------------------
1 | neural_field_type: mlp
2 | expname: ${neural_field_type}_${task}_${dataset.name}_${dataset.scene}
3 | task: pose
4 | seed: 42
5 |
6 | dataset:
7 | name: tanks_and_temples
8 |   root_dir: # e.g.: /home/user/datasets/${dataset.name}
9 |   encoder_path: # e.g.: /home/user/Projects/ZeroGS/conerf/model/scene_regressor/ace_encoder_pretrained.pt
10 | scene: ["Family", "Francis", "Ignatius", "Train", "Truck", "Playground"]
11 | image_resolution:
12 | scale: true
13 | rotate: false
14 | use_aug: true
15 | aug_rotation: 15
16 | aug_scale: 1.5
17 | factor: 2
18 | val_interval: -1
19 | apply_mask: false
20 | cam_depth: 0.1
21 | axis_len: 1.0
22 |
23 | trainer:
24 | epochs: 16
25 | max_patch_loops_per_epoch: 10
26 | samples_per_image: 1024
27 | training_buffer_size: 8000000
28 | batch_size: 5120
29 | min_iterations_per_epoch: 5000
30 | max_iterations_per_epoch: 10000
31 | early_stop_thresh: 6
32 | use_half: true
33 | ckpt_path: ""
34 | no_load_opt: false
35 | no_load_scheduler: false
36 | enable_tensorboard: true
37 | enable_visdom: false
38 | visdom_server: localhost
39 | visdom_port: 9002
40 | n_tensorboard: 100
41 | n_validation: 5000
42 | n_checkpoint: 1000
43 | distributed: false
44 | excluded_gpus: []
45 | num_workers: 4
46 | local_rank: 0
47 |
48 | optimizer:
49 | lr_sc_min: 0.0005 # lowest learning rate of 1 cycle scheduler
50 | lr_sc_max: 0.003 # highest learning rate of 1 cycle scheduler
51 | lr_pr: 1e-3 # learning rate for the pose refiner
52 | lr_cr: 1e-3 # learning rate for the calibration refiner
53 |
54 | regressor:
55 | # ZoeD_N is fine-tuned for metric depth on NYU Depth v2 for relative depth estimation,
56 | # ZoeD_K is fine-tuned for metric depth on KITTI for relative depth estimation.
57 | # ZoeD_NK has two separate heads fine-tuned on both NYU Depth v2 and KITTI.
58 |
59 | # [ZoeDepth, metric3d]
60 | depth_net_method: ZoeDepth
61 | # ZoeDepth: [ZoeD_N, ZoeD_K, ZoeD_NK]; metric3d: [metric3d_vit_small, metric3d_vit_large, metric3d_vit_giant2]
62 | depth_net_type: ZoeD_NK
63 | num_seed_image_trials: 5
64 | num_reloc_images_max: 1000 # maximum number of images used for relocalization tests during seed reconstruction.
65 | num_head_blocks: 1 # The depth of the head network.
66 | use_homogeneous: true
67 | depth_min: 0.1
68 | depth_max: 1000 # [ZoeDepth: 1000; metric3d: 200]
69 | depth_target: 10
70 |
71 | pose_estimator:
72 | reproj_thresh: 10 # inlier threshold in pixels (RGB) or centimeters (RGB-D)
73 | hypotheses: 64 # number of hypotheses, i.e. number of RANSAC iterations.
74 | inlier_alpha: 100 # alpha parameter of the soft inlier count.
75 | max_pixel_error: 100 # maximum reprojection (RGB, in px) or 3D distance (RGB-D, in cm) error when checking pose consistency.
76 | min_inlier_count: 2200 # minimum number of inlier correspondences when registering an image
77 |
78 | loss:
79 | repro_loss_hard_clamp: 1000
80 | repro_loss_soft_clamp: 50
81 | repro_loss_soft_clamp_min: 1
82 | repro_loss_type: tanh # dyntanh
83 | repro_loss_scheduler: circle
84 |
--------------------------------------------------------------------------------
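The `pose_estimator` block feeds the DSAC*-style RANSAC implemented in `submodules/dsacstar`: `hypotheses` pose candidates are sampled, each is scored with a soft inlier count, and a frame is only registered above `min_inlier_count`. The snippet below is a rough sketch of such a soft inlier score, assuming a sigmoid relaxation of the hard `reproj_thresh` test; the sharpness `beta` is a hypothetical constant and the real scoring happens inside the compiled extension.

import torch

def soft_inlier_score(repro_errors: torch.Tensor,
                      reproj_thresh: float = 10.0,
                      inlier_alpha: float = 100.0,
                      beta: float = 0.5) -> torch.Tensor:
    # Hard test would be (repro_errors < reproj_thresh); the sigmoid makes it differentiable.
    soft_inliers = torch.sigmoid(beta * (reproj_thresh - repro_errors))
    # inlier_alpha rescales the count before hypotheses are compared and selected.
    return inlier_alpha * soft_inliers.sum() / repro_errors.numel()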
/eval.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1101,W0621]
2 |
3 | import os
4 | import copy
5 | import json
6 | import warnings
7 | from typing import List
8 |
9 | import omegaconf
10 | from omegaconf import OmegaConf
11 |
12 | from conerf.evaluators.ace_zero_evaluator import AceZeroEvaluator
13 | from conerf.utils.utils import setup_seed
14 |
15 | warnings.filterwarnings("ignore", category=UserWarning)
16 |
17 |
18 | def create_evaluator(
19 | config: OmegaConf,
20 | load_train_data: bool = False,
21 | trainset=None,
22 | load_val_data: bool = True,
23 | valset=None,
24 | load_test_data: bool = False,
25 |     testset=None,
26 | models: List = None,
27 | meta_data: List = None,
28 | verbose: bool = False,
29 | device: str = "cuda",
30 | ):
31 |     """Factory function for creating evaluators."""
32 | if config.task == "pose":
33 | evaluator = AceZeroEvaluator(
34 | config, load_train_data, trainset,
35 | load_val_data, valset, load_test_data,
36 | testset, models, meta_data, verbose, device
37 | )
38 | else:
39 | raise NotImplementedError
40 |
41 | return evaluator
42 |
43 |
44 | if __name__ == "__main__":
45 | from conerf.utils.config import config_parser, load_config
46 | args = config_parser()
47 |
48 | # parse YAML config to OmegaConf
49 | config = load_config(args.config)
50 |
51 | assert config.dataset.get("data_split_json", "") != "" or config.dataset.scene != ""
52 |
53 | setup_seed(config.seed)
54 |
55 | scenes = []
56 | if config.dataset.get("data_split_json", "") != "" and config.dataset.scene == "":
57 | # For objaverse only.
58 | with open(config.dataset.data_split_json, "r", encoding="utf-8") as fp:
59 | obj_id_to_name = json.load(fp)
60 |
61 | for idx, name in obj_id_to_name.items():
62 | scenes.append(name)
63 |     elif isinstance(
64 |         config.dataset.scene, omegaconf.listconfig.ListConfig
65 |     ):
66 |         # A list of scenes can be specified directly in the config.
67 |         for sc in config.dataset.scene:
68 |             scenes.append(sc)
69 | else:
70 | scenes.append(config.dataset.scene)
71 |
72 | for scene in scenes:
73 | data_dir = os.path.join(config.dataset.root_dir, scene)
74 | if not os.path.exists(data_dir):
75 | continue
76 |
77 | local_config = copy.deepcopy(config)
78 | local_config.expname = (
79 | f"{config.neural_field_type}_{config.task}_{config.dataset.name}_{scene}"
80 | )
81 | local_config.expname = local_config.expname + "_" + args.suffix
82 | local_config.dataset.scene = scene
83 |
84 | evaluator = create_evaluator(
85 | local_config,
86 | load_train_data=False,
87 | trainset=None,
88 | load_val_data=True,
89 | valset=None,
90 | load_test_data=True,
91 | testset=None,
92 | verbose=True,
93 | )
94 | evaluator.eval(split="val")
95 | # evaluator.eval(split="test")
96 | evaluator.export_mesh()
97 |
--------------------------------------------------------------------------------
/scripts/env/install.sh:
--------------------------------------------------------------------------------
1 | # conda create -n zero_gs python=3.9
2 | # conda activate zero_gs
3 |
4 | # install pytorch
5 | # Ref: https://pytorch.org/get-started/previous-versions/
6 |
7 | # CUDA 11.7
8 | # conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.7 -c pytorch -c nvidia
9 | conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.8 -c pytorch -c nvidia
10 | conda install -c "nvidia/label/cuda-11.8.0" cuda
11 |
12 | # Basic packages.
13 | pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg easydict \
14 | kornia lpips tensorboard visdom tensorboardX matplotlib plyfile trimesh h5py pandas \
15 | omegaconf PyMCubes Ninja pyransac3d einops pyglet pre-commit pylint GPUtil \
16 | open3d pyrender
17 | pip install timm==0.6.7
18 | pip install -U scikit-learn
19 | pip install git+https://github.com/jonbarron/robust_loss_pytorch
20 | pip install torch-geometric==2.4.0
21 |
22 | conda install pytorch3d -c pytorch3d
23 | conda install conda-forge::opencv
24 | conda install pytorch-scatter -c pyg
25 | conda remove ffmpeg --force
26 |
27 | # Third-party dependencies.
28 |
29 | cd submodules/dsacstar
30 | python setup.py install
31 |
32 | cd ../../
33 | pip install submodules/simple-knn
34 | pip install submodules/diff-gaussian-rasterization
35 |
36 | mkdir 3rd_party && cd 3rd_party
37 |
38 | git clone https://github.com/cvg/sfm-disambiguation-colmap.git
39 | cd sfm-disambiguation-colmap
40 | python -m pip install -e .
41 | cd ..
42 |
43 | # HLoc is used for extracting keypoints and matching features.
44 | git clone --recursive https://github.com/cvg/Hierarchical-Localization/
45 | cd Hierarchical-Localization/
46 | python -m pip install -e .
47 | cd ..
48 |
49 | # Tiny-cuda-nn
50 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
51 |
52 | # nerfacc
53 | # pip install nerfacc -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.1_cu117.html
54 | # or install the latest version
55 | # pip install git+https://github.com/KAIR-BAIR/nerfacc.git
56 | # To install a specified version:
57 | # pip install nerfacc==0.3.5 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.1_cu117.html
58 | pip install nerfacc==0.3.5 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-2.0.0_cu118.html
59 |
60 | # Install CURope
61 | cd croco/models/curope/
62 | python setup.py build_ext --inplace
63 |
--------------------------------------------------------------------------------
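A quick post-install sanity check (a minimal sketch, assuming the `zero_gs` environment above is active and the dsacstar extension built successfully):

import torch

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())

# Built by `python setup.py install` inside submodules/dsacstar above.
import dsacstar  # noqa: F401  -- raises ImportError if the extension did not build

print("dsacstar extension imported successfully")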
/scripts/eval/eval_ace_zero.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_IDS=$1 # {'0,1,2,...'}
4 |
5 | export PYTHONDONTWRITEBYTECODE=1
6 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS}
7 |
8 | # Default parameters.
9 | DATASET='blender' # [blender, mipnerf360, tanks_and_temples]
10 | ENCODING='ace'
11 | SUFFIX=''
12 |
13 | NUM_CMD_PARAMS=$#
14 | if [ $NUM_CMD_PARAMS -eq 2 ]
15 | then
16 | SUFFIX=$2
17 | elif [ $NUM_CMD_PARAMS -eq 3 ]
18 | then
19 | SUFFIX=$2
20 | DATASET=$3
21 | elif [ $NUM_CMD_PARAMS -eq 4 ]
22 | then
23 | SUFFIX=$2
24 | DATASET=$3
25 | ENCODING=$4
26 | fi
27 |
28 | YAML=${ENCODING}/${DATASET}'.yaml'
29 | echo "Using yaml file: ${YAML}"
30 |
31 | HOME_DIR=$HOME
32 | CODE_ROOT_DIR=$HOME/'Projects/ZeroGS'
33 |
34 | cd $CODE_ROOT_DIR
35 |
36 | python eval.py --config 'config/'${YAML} \
37 | --suffix $SUFFIX
38 |
--------------------------------------------------------------------------------
/scripts/eval/vis_recon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import numpy as np
6 | import open3d as o3d
7 |
8 | from conerf.datasets.realworld import similarity_from_cameras, normalize_poses
9 | from conerf.datasets.utils import compute_bounding_box3D, points_in_bbox3D
10 | from conerf.pycolmap.pycolmap.scene_manager import SceneManager
11 | from conerf.visualization.scene_visualizer import visualize_single_scene
12 |
13 |
14 | def config_parser():
15 | parser = argparse.ArgumentParser()
16 |
17 |     parser.add_argument("--colmap_dir",
18 |                         type=str,
19 |                         default="",
20 |                         help="absolute path of the COLMAP sparse reconstruction directory")
21 |     parser.add_argument("--output_dir",
22 |                         type=str,
23 |                         default="",
24 |                         help="absolute path of the output directory")
25 |
26 | args = parser.parse_args()
27 |
28 | return args
29 |
30 |
31 | if __name__ == '__main__':
32 | args = config_parser()
33 | rotate = False
34 |
35 | if not os.path.exists(args.output_dir):
36 | os.makedirs(args.output_dir)
37 |
38 | # (1) Loading camera poses and 3D points.
39 | manager = SceneManager(args.colmap_dir, load_points=False)
40 | manager.load()
41 |
42 | ply_path = os.path.join(args.colmap_dir, "points3D.ply")
43 | pcd = o3d.io.read_point_cloud(ply_path)
44 | points = np.asarray(pcd.points)
45 | colors = np.asarray(pcd.colors)
46 | num_points = np.asarray(pcd.points).shape[0]
47 | print(f'num points: {num_points}')
48 |
49 | colmap_image_data = manager.images
50 | colmap_camera_data = manager.cameras
51 |
52 | w2c_mats = []
53 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
54 | for k in colmap_image_data:
55 | im_data = colmap_image_data[k]
56 | w2c = np.concatenate([
57 | np.concatenate(
58 | [im_data.R(), im_data.tvec.reshape(3, 1)], 1), bottom
59 | ], axis=0)
60 | w2c_mats.append(w2c)
61 | w2c_mats = np.stack(w2c_mats, axis=0)
62 | cam_to_world = np.linalg.inv(w2c_mats)
63 |
64 | # (2) Normalize the scene.
65 | T, scale = similarity_from_cameras(
66 | cam_to_world, strict_scaling=False
67 | )
68 | cam_to_world = np.einsum("nij, ki -> nkj", cam_to_world, T)
69 | cam_to_world[:, :3, 3:4] *= scale
70 |
71 | points = scale * (T[:3, :3] @ points.T + T[:3, 3][..., None]).T # [Np, 3]
72 |
73 | # (3) Rotate the scene to align with ground plane.
74 | if rotate:
75 | down_pcd = pcd.voxel_down_sample(voxel_size=0.1)
76 | points_for_est_normal = np.asarray(down_pcd.points)
77 | print(
78 | f'num points for estimating normal: {points_for_est_normal.shape}')
79 | cam_to_world, _, R, t = normalize_poses(
80 | torch.from_numpy(cam_to_world).float(), # pylint: disable=E1101
81 | torch.from_numpy(points_for_est_normal).float(
82 | ), # pylint: disable=E1101
83 | up_est_method="ground",
84 | center_est_method="lookat",
85 | )
86 | cam_to_world = cam_to_world.numpy()
87 | points[:, :] = (R @ points.T + t).T
88 |
89 | # (4) Compute bounding box to exclude points outside the bounding box.
90 | aabb = compute_bounding_box3D(
91 | torch.from_numpy(cam_to_world[..., :, -1]), # pylint: disable=E1101
92 | scale_factor=[7, 7, 7], # [4.0,4.0,4.0]
93 | ).numpy()
94 | valid_point_indices = points_in_bbox3D(points, aabb).reshape(-1)
95 | points = points[valid_point_indices]
96 | colors = colors[valid_point_indices]
97 | colors = np.clip(colors, 0, 1)
98 | print(f'num points: {points.shape[0]}')
99 |
100 | pcd.points = o3d.utility.Vector3dVector(points)
101 | pcd.colors = o3d.utility.Vector3dVector(colors)
102 |
103 | # (5) Downsample points if there are too many.
104 | if num_points > 2000000:
105 | down_pcd = pcd.voxel_down_sample(voxel_size=0.005)
106 | points = np.asarray(down_pcd.points)
107 | colors = np.asarray(down_pcd.colors)
108 | print(f'points shape: {points.shape}')
109 |
110 | pcd.points = o3d.utility.Vector3dVector(points)
111 | pcd.colors = o3d.utility.Vector3dVector(colors)
112 | else:
113 | colors = np.asarray(pcd.colors)
114 | pcd.colors = o3d.utility.Vector3dVector(colors)
115 |
116 | visualize_single_scene(
117 | pcd,
118 | cam_to_world,
119 | size=0.05,
120 | rainbow_color=True,
121 | output_directory=args.output_dir
122 | )
123 |
124 | video_filename = os.path.join(args.output_dir, "zero_gs_scene.mp4")
125 | os.system(f"ffmpeg -framerate 10 -i {args.output_dir}/screenshot_%05d.png -c:v libx264 " +
126 | f"-pix_fmt yuv420p {video_filename}"
127 | )
128 |
--------------------------------------------------------------------------------
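A note on step (2) of `vis_recon.py`: the `np.einsum("nij, ki -> nkj", ...)` call is just a batched left-multiplication of every camera-to-world matrix by the similarity transform `T`. The small check below (with random stand-in matrices) shows the equivalence:

import numpy as np

rng = np.random.default_rng(0)
cam_to_world = rng.normal(size=(5, 4, 4))   # stand-in camera-to-world matrices
T = rng.normal(size=(4, 4))                 # stand-in similarity transform

a = np.einsum("nij, ki -> nkj", cam_to_world, T)   # form used in vis_recon.py
b = np.stack([T @ c2w for c2w in cam_to_world])    # explicit per-camera product
assert np.allclose(a, b)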
/scripts/preprocess/colmap_mapping.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | DATASET_PATH=$1
5 | OUTPUT_PATH=$2
6 | VOC_TREE_PATH=$3
7 | MOST_SIMILAR_IMAGES_NUM=$4
8 | CUDA_IDS=$5
9 |
10 | NUM_THREADS=24
11 | # export PYTHONDONTWRITEBYTECODE=1
12 | # export CUDA_VISIBLE_DEVICES=${CUDA_IDS}
13 |
14 | COLMAP_DIR=/usr/local/bin
15 | COLMAP_EXE=$COLMAP_DIR/colmap
16 |
17 | mkdir $OUTPUT_PATH/sparse
18 |
19 | $COLMAP_EXE feature_extractor \
20 | --database_path=$OUTPUT_PATH/database.db \
21 | --image_path=$DATASET_PATH/images \
22 | --SiftExtraction.num_threads=$NUM_THREADS \
23 | --SiftExtraction.use_gpu=1 \
24 | --SiftExtraction.gpu_index=$CUDA_IDS \
25 | --SiftExtraction.estimate_affine_shape=true \
26 | --SiftExtraction.domain_size_pooling=true \
27 | --ImageReader.camera_model PINHOLE \
28 | --ImageReader.single_camera 1 \
29 | --SiftExtraction.max_num_features 8192 \
30 | > $DATASET_PATH/log_extract_feature.txt 2>&1
31 |
32 | $COLMAP_EXE vocab_tree_matcher \
33 | --database_path=$OUTPUT_PATH/database.db \
34 | --SiftMatching.num_threads=$NUM_THREADS \
35 | --SiftMatching.use_gpu=1 \
36 | --SiftMatching.gpu_index=$CUDA_IDS \
37 | --SiftMatching.guided_matching=false \
38 | --VocabTreeMatching.num_images=$MOST_SIMILAR_IMAGES_NUM \
39 | --VocabTreeMatching.num_nearest_neighbors=5 \
40 | --VocabTreeMatching.vocab_tree_path=$VOC_TREE_PATH \
41 | > $DATASET_PATH/log_match.txt 2>&1
42 |
43 | $COLMAP_EXE mapper \
44 | --database_path=$OUTPUT_PATH/database.db \
45 | --image_path=$DATASET_PATH/images \
46 | --output_path=$OUTPUT_PATH/sparse \
47 | --Mapper.num_threads=$NUM_THREADS \
48 | > $DATASET_PATH/log_sfm.txt 2>&1
49 |
--------------------------------------------------------------------------------
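After the mapper finishes, the sparse model(s) end up under `$OUTPUT_PATH/sparse/<model_index>`. A minimal sketch of inspecting one of them with the bundled `SceneManager`, in the same way `scripts/eval/vis_recon.py` does (the paths are placeholders and model `0` is assumed to exist):

import os

from conerf.pycolmap.pycolmap.scene_manager import SceneManager

output_path = "/path/to/OUTPUT_PATH"                 # placeholder for $OUTPUT_PATH above
sparse_dir = os.path.join(output_path, "sparse", "0")

manager = SceneManager(sparse_dir, load_points=False)
manager.load()
print(f"images: {len(manager.images)}, cameras: {len(manager.cameras)}")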
/scripts/preprocess/database.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import sqlite3
3 |
4 | from typing import Dict
5 |
6 | import numpy as np
7 |
8 | IS_PYTHON3 = sys.version_info[0] >= 3
9 |
10 | #-------------------------------------------------------------------------------
11 | # create table commands
12 |
13 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
14 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
15 | model INTEGER NOT NULL,
16 | width INTEGER NOT NULL,
17 | height INTEGER NOT NULL,
18 | params BLOB,
19 | prior_focal_length INTEGER NOT NULL)"""
20 |
21 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
22 | image_id INTEGER PRIMARY KEY NOT NULL,
23 | rows INTEGER NOT NULL,
24 | cols INTEGER NOT NULL,
25 | data BLOB,
26 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
27 |
28 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
29 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
30 | name TEXT NOT NULL UNIQUE,
31 | camera_id INTEGER NOT NULL,
32 | prior_qw REAL,
33 | prior_qx REAL,
34 | prior_qy REAL,
35 | prior_qz REAL,
36 | prior_tx REAL,
37 | prior_ty REAL,
38 | prior_tz REAL,
39 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647),
40 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))"""
41 |
42 | CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries (
43 | pair_id INTEGER PRIMARY KEY NOT NULL,
44 | rows INTEGER NOT NULL,
45 | cols INTEGER NOT NULL,
46 | data BLOB,
47 | config INTEGER NOT NULL,
48 | F BLOB,
49 | E BLOB,
50 | H BLOB)"""
51 |
52 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
53 | image_id INTEGER PRIMARY KEY NOT NULL,
54 | rows INTEGER NOT NULL,
55 | cols INTEGER NOT NULL,
56 | data BLOB,
57 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
58 |
59 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
60 | pair_id INTEGER PRIMARY KEY NOT NULL,
61 | rows INTEGER NOT NULL,
62 | cols INTEGER NOT NULL,
63 | data BLOB)"""
64 |
65 | CREATE_NAME_INDEX = \
66 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
67 |
68 | CREATE_ALL = "; ".join([CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE,
69 | CREATE_IMAGES_TABLE, CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE,
70 | CREATE_MATCHES_TABLE, CREATE_NAME_INDEX])
71 |
72 |
73 | def array_to_blob(array):
74 | if IS_PYTHON3:
75 |         return array.tobytes()
76 |
77 | return np.getbuffer(array)
78 |
79 |
80 | class COLMAPDatabase(sqlite3.Connection):
81 | @staticmethod
82 | def connect(database_path):
83 | return sqlite3.connect(database_path, factory=COLMAPDatabase)
84 |
85 |
86 | def __init__(self, *args, **kwargs):
87 | super().__init__(*args, **kwargs)
88 |
89 | self.initialize_tables = lambda: self.executescript(CREATE_ALL)
90 |
91 | self.initialize_cameras = \
92 | lambda: self.executescript(CREATE_CAMERAS_TABLE)
93 | self.initialize_descriptors = \
94 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
95 | self.initialize_images = \
96 | lambda: self.executescript(CREATE_IMAGES_TABLE)
97 | self.initialize_inlier_matches = \
98 | lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE)
99 | self.initialize_keypoints = \
100 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
101 | self.initialize_matches = \
102 | lambda: self.executescript(CREATE_MATCHES_TABLE)
103 |
104 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
105 |
106 | def add_camera(self, model, width, height, params,
107 | prior_focal_length=False, camera_id=None):
108 | params = np.asarray(params, np.float64)
109 | cursor = self.execute(
110 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
111 | (camera_id, model, width, height, array_to_blob(params),
112 | prior_focal_length))
113 | return cursor.lastrowid
114 |
115 |
116 | def fetch_images_from_database(database_path: str) -> Dict:
117 | db = COLMAPDatabase.connect(database_path) # pylint: disable=[C0103]
118 | rows = db.execute("SELECT * FROM images")
119 | name_to_image_id = {}
120 | for row in rows:
121 | image_id, name = row[0], row[1]
122 | # print(f'image_id: {image_id}, name: {name}')
123 | name_to_image_id[name] = image_id
124 |
125 | return name_to_image_id
126 |
--------------------------------------------------------------------------------
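A short usage sketch of the helpers above: create a database, register a single PINHOLE camera, and read back the image-name-to-id mapping. The import path assumes the repository root is on `PYTHONPATH`, and the intrinsics are made-up values.

from scripts.preprocess.database import COLMAPDatabase, fetch_images_from_database

db = COLMAPDatabase.connect("example_database.db")   # creates the file if it does not exist
db.initialize_tables()

# COLMAP camera model id 1 is PINHOLE; params are (fx, fy, cx, cy) -- made-up values.
camera_id = db.add_camera(model=1, width=1920, height=1080,
                          params=[1200.0, 1200.0, 960.0, 540.0])
db.commit()
db.close()

print(camera_id, fetch_images_from_database("example_database.db"))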
/scripts/preprocess/hloc_mapping/filter_matches.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import argparse
3 | import numpy as np
4 | import networkx as nx
5 |
6 | from matplotlib import pyplot as plt
7 | from scipy.sparse.csgraph import minimum_spanning_tree
8 |
9 | from disambiguation.utils.read_write_database import remove_matches_from_db
10 | from disambiguation.utils.run_colmap import run_matches_importer
11 |
12 |
13 | def draw_graph(scores, plot_path, display=False):
14 | graph = nx.from_numpy_array(scores)
15 | # print(scores)
16 | pos = nx.nx_agraph.graphviz_layout(graph)
17 | edge_vmin = np.percentile(scores[scores.nonzero()], 10)
18 | edge_vmax = np.percentile(scores[scores.nonzero()], 90)
19 | # print(edge_vmin, edge_vmax)
20 | nx.draw(
21 | graph,
22 | pos,
23 | with_labels=True,
24 | edge_color=[graph[u][v]['weight'] for u, v in graph.edges],
25 | # edge_cmap=plt.cm.plasma,
26 | edge_cmap=plt.cm.YlOrRd,
27 | edge_vmin=edge_vmin,
28 | edge_vmax=edge_vmax)
29 | plt.savefig(plot_path)
30 | if display:
31 | plt.show()
32 | plt.close()
33 | return
34 |
35 |
36 | def filter_with_fixed_threshold(scores, thres, plot_path=None):
37 | valid = scores >= thres
38 | invalid = np.logical_not(valid)
39 | scores[invalid] = 0.
40 | if plot_path is not None:
41 | draw_graph(scores, plot_path, display=False)
42 | return valid
43 |
44 |
45 | def filter_with_knn(scores, k, plot_path):
46 |     valid = np.zeros_like(scores, dtype=bool)
47 | valid_indices = scores.argsort()[:, -k:]
48 | for i in range(scores.shape[0]):
49 | for j in valid_indices[i]:
50 | valid[i, j] = True
51 | invalid = np.logical_not(valid)
52 | scores[invalid] = 0.
53 | if plot_path is not None:
54 | draw_graph(scores, plot_path, display=False)
55 | return valid
56 |
57 |
58 | def filter_with_mst_min(scores, plot_path=None):
59 | min_scores = np.minimum(scores, scores.T)
60 | assert np.allclose(min_scores, min_scores.T)
61 | mst = minimum_spanning_tree(-min_scores)
62 | valid = (-mst).toarray() > 0
63 | invalid = np.logical_not(valid)
64 | scores[invalid] = 0.
65 | if plot_path is not None:
66 | draw_graph(scores, plot_path, display=False)
67 | return valid
68 |
69 |
70 | def filter_with_mst_mean(scores, plot_path=None):
71 | mean_scores = (scores + scores.T) / 2
72 | assert np.allclose(mean_scores, mean_scores.T)
73 | mst = minimum_spanning_tree(-mean_scores)
74 | valid = (-mst).toarray() > 0
75 | invalid = np.logical_not(valid)
76 | scores[invalid] = 0.
77 | if plot_path is not None:
78 | draw_graph(scores, plot_path, display=False)
79 | return valid
80 |
81 |
82 | def filter_with_percentile(scores, percentile, plot_path=None):
83 | num_images = scores.shape[0]
84 | thres = np.zeros((num_images, 1))
85 | for i in range(num_images):
86 | thres[i] = np.percentile(scores[i, scores[i].nonzero()], percentile)
87 | valid = scores >= thres
88 | invalid = np.logical_not(valid)
89 | scores[invalid] = 0.
90 | if plot_path is not None:
91 | draw_graph(scores, plot_path, display=False)
92 | return valid
93 |
94 |
95 | def main(colmap_path: str,
96 | results_path: str,
97 | filter_type: str,
98 | threshold: float,
99 | scores_dir: Path,
100 | scores_name: str,
101 | topk: int,
102 | percentile: float,
103 | old_db_path: str,
104 | new_db_path: str,
105 | geometric_verification_type: str):
106 | scores_path = scores_dir / scores_name
107 | scores = np.load(scores_path)
108 |
109 | # valid = scores >= args.threshold
110 | if filter_type == 'threshold':
111 | assert threshold is not None
112 | output_path = results_path / ('sparse' + scores_name[6:-4] +
113 | f'_t{threshold:.2f}')
114 | output_path.mkdir(exist_ok=True)
115 | plot_path = output_path / 'match_graph.png'
116 | match_list_path = results_path / (
117 | 'match_list' + scores_name[6:-4] + f'_t{threshold}.txt')
118 | valid = filter_with_fixed_threshold(scores, threshold, plot_path)
119 | elif filter_type == 'knn':
120 | assert topk is not None
121 | output_path = results_path / ('sparse' + scores_name[6:-4] +
122 | f'_k{topk}')
123 | output_path.mkdir(exist_ok=True)
124 | plot_path = output_path / 'match_graph.png'
125 | match_list_path = results_path / (
126 | 'match_list' + scores_name[6:-4] + f'_k{topk}.txt')
127 | valid = filter_with_knn(scores, topk, plot_path)
128 | elif filter_type == 'percentile':
129 | assert percentile is not None
130 | output_path = results_path / ('sparse' + scores_name[6:-4] +
131 | f'_p{percentile}')
132 | output_path.mkdir(exist_ok=True)
133 | plot_path = output_path / 'match_graph.png'
134 | match_list_path = results_path / (
135 | 'match_list' + scores_name[6:-4] + f'_p{percentile}.txt')
136 | valid = filter_with_percentile(scores, percentile, plot_path)
137 | elif filter_type == 'mst_min':
138 | output_path = results_path / ('sparse' + scores_name[6:-4] +
139 | '_mst_min')
140 | output_path.mkdir(exist_ok=True)
141 | plot_path = output_path / 'match_graph.png'
142 | match_list_path = results_path / (
143 | 'match_list' + scores_name[6:-4] + '_mst_min.txt')
144 | valid = filter_with_mst_min(scores, plot_path)
145 |         # We don't run reconstruction with the MST graph since it is too sparse;
146 |         # use it for visualization only.
147 | exit(0)
148 | elif filter_type == 'mst_mean':
149 | output_path = results_path / ('sparse' + scores_name[6:-4] +
150 | '_mst_mean')
151 | output_path.mkdir(exist_ok=True)
152 | plot_path = output_path / 'match_graph.png'
153 | match_list_path = results_path / (
154 | 'match_list' + scores_name[6:-4] + '_mst_mean.txt')
155 | valid = filter_with_mst_mean(scores, plot_path)
156 |         # We don't run reconstruction with the MST graph since it is too sparse;
157 |         # use it for visualization only.
158 | exit(0)
159 | else:
160 | raise NotImplementedError
161 |
162 | remove_matches_from_db(old_db_path, new_db_path, match_list_path, valid)
163 | run_matches_importer(colmap_path,
164 | new_db_path,
165 | match_list_path,
166 | use_gpu=False,
167 | colmap_matching_type=geometric_verification_type)
168 |
169 |
170 | if __name__ == '__main__':
171 | parser = argparse.ArgumentParser()
172 | parser.add_argument('--dataset_dir', type=Path, default='datasets',
173 | help='Path to the dataset, default: %(default)s')
174 | parser.add_argument('--results_path', type=Path, default='outputs',
175 | help='Path to the output directory, default: %(default)s')
176 | parser.add_argument('--scores_name', type=str, required=True,
177 | default='yan', choices=['yan', 'cui'])
178 | parser.add_argument('--filter_type',
179 | type=str,
180 | choices=['threshold', 'knn', 'mst_min', 'mst_mean', 'percentile'])
181 | parser.add_argument('--threshold', type=float)
182 | parser.add_argument('--topk', type=int)
183 | parser.add_argument('--percentile', type=float)
184 | parser.add_argument('--colmap_path', type=Path, default='colmap')
185 |     parser.add_argument('--old_db_path', type=str, required=True)
186 |     parser.add_argument('--new_db_path', type=str, required=True)
187 | parser.add_argument('--geometric_verification_type',
188 | type=str,
189 | required=True,
190 | choices=['default', 'strict'])
191 |
192 |
193 | args = parser.parse_args()
194 |
195 |     main(args.colmap_path, args.results_path, args.filter_type, args.threshold,
196 |          args.results_path, args.scores_name, args.topk, args.percentile,  # scores assumed to live under results_path
197 |          args.old_db_path, args.new_db_path, args.geometric_verification_type)
198 |
--------------------------------------------------------------------------------
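A toy run of the fixed-threshold filter, to make the behaviour concrete. The import assumes the repository root is on `PYTHONPATH` and that the `disambiguation` package (from sfm-disambiguation-colmap, installed in `scripts/env/install.sh`) is available, since it is imported at module load time:

import numpy as np

from scripts.preprocess.hloc_mapping.filter_matches import filter_with_fixed_threshold

scores = np.array([
    [0.0, 0.8, 0.1],
    [0.8, 0.0, 0.4],
    [0.1, 0.4, 0.0],
])
# Pairs scoring below the threshold are zeroed out and excluded from matching.
valid = filter_with_fixed_threshold(scores, thres=0.3)
print(valid)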
/scripts/preprocess/hloc_mapping/match_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Union, Optional, Dict, List, Tuple
3 | from pathlib import Path
4 | import pprint
5 | import collections.abc as collections
6 | from tqdm import tqdm
7 | import h5py
8 | import torch
9 |
10 | from hloc import matchers, logger
11 | from hloc.utils.base_model import dynamic_load
12 | from hloc.utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
13 | from hloc.utils.io import list_h5_names
14 |
15 |
16 | '''
17 | A set of standard configurations that can be directly selected from the command
18 | line using their name. Each is a dictionary with the following entries:
19 | - output: the name of the match file that will be generated.
20 | - model: the model configuration, as passed to a feature matcher.
21 | '''
22 | confs = {
23 | 'superglue': {
24 | 'output': 'matches-superglue',
25 | 'model': {
26 | 'name': 'superglue',
27 | 'weights': 'outdoor',
28 | 'sinkhorn_iterations': 50,
29 | },
30 | },
31 | 'superglue-fast': {
32 | 'output': 'matches-superglue-it5',
33 | 'model': {
34 | 'name': 'superglue',
35 | 'weights': 'outdoor',
36 | 'sinkhorn_iterations': 5,
37 | },
38 | },
39 | 'NN-superpoint': {
40 | 'output': 'matches-NN-mutual-dist.7',
41 | 'model': {
42 | 'name': 'nearest_neighbor',
43 | 'do_mutual_check': True,
44 | 'distance_threshold': 0.7,
45 | },
46 | },
47 | 'NN-ratio': {
48 | 'output': 'matches-NN-mutual-ratio.8',
49 | 'model': {
50 | 'name': 'nearest_neighbor',
51 | 'do_mutual_check': True,
52 | 'ratio_threshold': 0.8,
53 | }
54 | },
55 | 'NN-mutual': {
56 | 'output': 'matches-NN-mutual',
57 | 'model': {
58 | 'name': 'nearest_neighbor',
59 | 'do_mutual_check': True,
60 | },
61 | }
62 | }
63 |
64 |
65 | def main(conf: Dict,
66 | pairs: Path, features: Union[Path, str],
67 | export_dir: Optional[Path] = None,
68 | matches: Optional[Path] = None,
69 | features_ref: Optional[Path] = None,
70 | overwrite: bool = False,
71 | device='cuda') -> Path:
72 |
73 | if isinstance(features, Path) or Path(features).exists():
74 | features_q = features
75 | if matches is None:
76 | raise ValueError('Either provide both features and matches as Path'
77 | ' or both as names.')
78 | else:
79 | if export_dir is None:
80 | raise ValueError('Provide an export_dir if features is not'
81 | f' a file path: {features}.')
82 | features_q = Path(export_dir, features+'.h5')
83 | if matches is None:
84 | matches = Path(
85 | export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5')
86 |
87 | if features_ref is None:
88 | features_ref = features_q
89 | if isinstance(features_ref, collections.Iterable):
90 | features_ref = list(features_ref)
91 | else:
92 | features_ref = [features_ref]
93 |
94 | match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite, device)
95 |
96 | return matches
97 |
98 |
99 | def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None):
100 | '''Avoid to recompute duplicates to save time.'''
101 | pairs = set()
102 | for i, j in pairs_all:
103 | if (j, i) not in pairs:
104 | pairs.add((i, j))
105 | pairs = list(pairs)
106 | if match_path is not None and match_path.exists():
107 | with h5py.File(str(match_path), 'r') as fd:
108 | pairs_filtered = []
109 | for i, j in pairs:
110 | if (names_to_pair(i, j) in fd or
111 | names_to_pair(j, i) in fd or
112 | names_to_pair_old(i, j) in fd or
113 | names_to_pair_old(j, i) in fd):
114 | continue
115 | pairs_filtered.append((i, j))
116 | return pairs_filtered
117 | return pairs
118 |
119 |
120 | @torch.no_grad()
121 | def match_from_paths(conf: Dict,
122 | pairs_path: Path,
123 | match_path: Path,
124 | feature_path_q: Path,
125 | feature_paths_refs: Path,
126 | overwrite: bool = False,
127 | device='cuda') -> Path:
128 | logger.info('Matching local features with configuration:'
129 | f'\n{pprint.pformat(conf)}')
130 |
131 | if not feature_path_q.exists():
132 | raise FileNotFoundError(f'Query feature file {feature_path_q}.')
133 | for path in feature_paths_refs:
134 | if not path.exists():
135 | raise FileNotFoundError(f'Reference feature file {path}.')
136 | name2ref = {n: i for i, p in enumerate(feature_paths_refs)
137 | for n in list_h5_names(p)}
138 | match_path.parent.mkdir(exist_ok=True, parents=True)
139 |
140 | assert pairs_path.exists(), pairs_path
141 | pairs = parse_retrieval(pairs_path)
142 | pairs = [(q, r) for q, rs in pairs.items() for r in rs]
143 | pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
144 | if len(pairs) == 0:
145 | logger.info('Skipping the matching.')
146 | return
147 |
148 | # device = 'cuda' if torch.cuda.is_available() else 'cpu'
149 | Model = dynamic_load(matchers, conf['model']['name'])
150 | model = Model(conf['model']).eval().to(device)
151 |
152 | for (name0, name1) in tqdm(pairs, smoothing=.1):
153 | data = {}
154 | with h5py.File(str(feature_path_q), 'r') as fd:
155 | grp = fd[name0]
156 | for k, v in grp.items():
157 | data[k+'0'] = torch.from_numpy(v.__array__()).float().to(device)
158 | # some matchers might expect an image but only use its size
159 | data['image0'] = torch.empty((1,)+tuple(grp['image_size'])[::-1])
160 | with h5py.File(str(feature_paths_refs[name2ref[name1]]), 'r') as fd:
161 | grp = fd[name1]
162 | for k, v in grp.items():
163 | data[k+'1'] = torch.from_numpy(v.__array__()).float().to(device)
164 | data['image1'] = torch.empty((1,)+tuple(grp['image_size'])[::-1])
165 | data = {k: v[None] for k, v in data.items()}
166 |
167 | pred = model(data)
168 | pair = names_to_pair(name0, name1)
169 | with h5py.File(str(match_path), 'a') as fd:
170 | if pair in fd:
171 | del fd[pair]
172 | grp = fd.create_group(pair)
173 | matches = pred['matches0'][0].cpu().short().numpy()
174 | grp.create_dataset('matches0', data=matches)
175 |
176 | if 'matching_scores0' in pred:
177 | scores = pred['matching_scores0'][0].cpu().half().numpy()
178 | grp.create_dataset('matching_scores0', data=scores)
179 |
180 | logger.info('Finished exporting matches.')
181 |
182 |
183 | if __name__ == '__main__':
184 | parser = argparse.ArgumentParser()
185 | parser.add_argument('--pairs', type=Path, required=True)
186 | parser.add_argument('--export_dir', type=Path)
187 | parser.add_argument('--features', type=str,
188 | default='feats-superpoint-n4096-r1024')
189 | parser.add_argument('--matches', type=Path)
190 | parser.add_argument('--conf', type=str, default='superglue',
191 | choices=list(confs.keys()))
192 | args = parser.parse_args()
193 |     main(confs[args.conf], args.pairs, args.features, args.export_dir, args.matches)
194 |
--------------------------------------------------------------------------------
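A minimal sketch of driving the matcher from Python instead of the CLI, mirroring the `__main__` block. All paths and names are hypothetical; the pairs file comes from `pairs_from_retrieval.py` and the feature file from the extraction step:

from pathlib import Path

from scripts.preprocess.hloc_mapping.match_features import confs, main

export_dir = Path("outputs/scene")               # hypothetical output directory
pairs = export_dir / "pairs-netvlad.txt"         # image pairs from the retrieval step
features = "feats-superpoint-n4096-r1024"        # name of the local-feature .h5 in export_dir

match_path = main(confs["superglue"], pairs, features, export_dir=export_dir)
print("matches written to", match_path)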
/scripts/preprocess/hloc_mapping/pairs_from_retrieval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pathlib import Path
4 | from typing import Optional
5 | import h5py
6 | import numpy as np
7 | import torch
8 | import collections.abc as collections
9 |
10 | from hloc import logger
11 | from hloc.utils.parsers import parse_image_lists
12 | from hloc.utils.read_write_model import read_images_binary
13 | from hloc.utils.io import list_h5_names
14 |
15 |
16 | def parse_names(prefix, names, names_all):
17 | if prefix is not None:
18 | if not isinstance(prefix, str):
19 | prefix = tuple(prefix)
20 | names = [n for n in names_all if n.startswith(prefix)]
21 | elif names is not None:
22 | if isinstance(names, (str, Path)):
23 | names = parse_image_lists(names)
24 | elif isinstance(names, collections.Iterable):
25 | names = list(names)
26 | else:
27 | raise ValueError(f'Unknown type of image list: {names}.'
28 | 'Provide either a list or a path to a list file.')
29 | else:
30 | names = names_all
31 | return names
32 |
33 |
34 | def get_descriptors(names, path, name2idx=None, key='global_descriptor'):
35 | if name2idx is None:
36 | with h5py.File(str(path), 'r') as fd:
37 | desc = [fd[n][key].__array__() for n in names]
38 | else:
39 | desc = []
40 | for n in names:
41 | with h5py.File(str(path[name2idx[n]]), 'r') as fd:
42 | desc.append(fd[n][key].__array__())
43 | return torch.from_numpy(np.stack(desc, 0)).float()
44 |
45 |
46 | def pairs_from_score_matrix(scores: torch.Tensor,
47 | invalid: np.array,
48 | num_select: int,
49 | min_score: Optional[float] = None):
50 | assert scores.shape == invalid.shape
51 | if isinstance(scores, np.ndarray):
52 | scores = torch.from_numpy(scores)
53 | invalid = torch.from_numpy(invalid).to(scores.device)
54 | if min_score is not None:
55 | invalid |= scores < min_score
56 | scores.masked_fill_(invalid, float('-inf'))
57 |
58 | topk = torch.topk(scores, num_select, dim=1)
59 | indices = topk.indices.cpu().numpy()
60 | valid = topk.values.isfinite().cpu().numpy()
61 |
62 | pairs = []
63 | for i, j in zip(*np.where(valid)):
64 | pairs.append((i, indices[i, j]))
65 | return pairs
66 |
67 |
68 | def get_query_names(
69 | descriptors,
70 | query_prefix=None, query_list=None,
71 | db_prefix=None, db_list=None, db_model=None, db_descriptors=None,
72 | ):
73 | # We handle multiple reference feature files.
74 | # We only assume that names are unique among them and map names to files.
75 | if db_descriptors is None:
76 | db_descriptors = descriptors
77 | if isinstance(db_descriptors, (Path, str)):
78 | db_descriptors = [db_descriptors]
79 | name2db = {n: i for i, p in enumerate(db_descriptors)
80 | for n in list_h5_names(p)}
81 |
82 | db_names_h5 = list(name2db.keys())
83 | db_names_h5 = sorted(db_names_h5)
84 |
85 | query_names_h5 = list_h5_names(descriptors)
86 | query_names_h5 = sorted(query_names_h5)
87 |
88 | if db_model:
89 | images = read_images_binary(os.path.join(db_model, 'images.bin'))
90 | db_names = [i.name for i in images.values()]
91 | else:
92 | db_names = parse_names(db_prefix, db_list, db_names_h5)
93 |
94 | num_images = len(db_names)
95 | if num_images == 0:
96 | raise ValueError('Could not find any database image.')
97 | query_names = parse_names(query_prefix, query_list, query_names_h5)
98 |
99 | return db_names, db_descriptors, query_names, name2db
100 |
101 |
102 | def compute_similarity_score(
103 | descriptors, output,
104 | db_names, db_descriptors, query_names, name2db,
105 | device='cuda'):
106 | logger.info('Extracting image pairs from a retrieval database.')
107 |
108 | db_desc = get_descriptors(db_names, db_descriptors, name2db)
109 | query_desc = get_descriptors(query_names, descriptors)
110 | sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device))
111 |
112 | torch.save(sim, output)
113 |
114 | return sim
115 |
116 |
117 | def main(descriptors, output, num_matched,
118 | query_prefix=None, query_list=None,
119 | db_prefix=None, db_list=None, db_model=None, db_descriptors=None,
120 | device='cuda'):
121 | logger.info('Extracting image pairs from a retrieval database.')
122 |
123 | # We handle multiple reference feature files.
124 | # We only assume that names are unique among them and map names to files.
125 | if db_descriptors is None:
126 | db_descriptors = descriptors
127 | if isinstance(db_descriptors, (Path, str)):
128 | db_descriptors = [db_descriptors]
129 | name2db = {n: i for i, p in enumerate(db_descriptors)
130 | for n in list_h5_names(p)}
131 | db_names_h5 = list(name2db.keys())
132 | query_names_h5 = list_h5_names(descriptors)
133 |
134 | if db_model:
135 | images = read_images_binary(os.path.join(db_model, 'images.bin'))
136 | db_names = [i.name for i in images.values()]
137 | else:
138 | db_names = parse_names(db_prefix, db_list, db_names_h5)
139 |
140 | num_images = len(db_names)
141 | if num_images == 0:
142 | raise ValueError('Could not find any database image.')
143 | query_names = parse_names(query_prefix, query_list, query_names_h5)
144 |
145 | # device = 'cuda' if torch.cuda.is_available() else 'cpu'
146 | db_desc = get_descriptors(db_names, db_descriptors, name2db)
147 | query_desc = get_descriptors(query_names, descriptors)
148 | sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device))
149 |
150 | # Avoid self-matching
151 | self = np.array(query_names)[:, None] == np.array(db_names)[None]
152 | num_matched = min(num_images, num_matched)
153 | pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0)
154 | pairs = [(query_names[i], db_names[j]) for i, j in pairs]
155 |
156 | logger.info(f'Found {len(pairs)} pairs.')
157 | with open(output, 'w') as f:
158 | f.write('\n'.join(' '.join([i, j]) for i, j in pairs))
159 |
160 |
161 | if __name__ == "__main__":
162 | parser = argparse.ArgumentParser()
163 | parser.add_argument('--descriptors', type=Path, required=True)
164 | parser.add_argument('--output', type=Path, required=True)
165 | parser.add_argument('--num_matched', type=int, required=True)
166 | parser.add_argument('--query_prefix', type=str, nargs='+')
167 | parser.add_argument('--query_list', type=Path)
168 | parser.add_argument('--db_prefix', type=str, nargs='+')
169 | parser.add_argument('--db_list', type=Path)
170 | parser.add_argument('--db_model', type=Path)
171 | parser.add_argument('--db_descriptors', type=Path)
172 | args = parser.parse_args()
173 | main(**args.__dict__)
174 |
--------------------------------------------------------------------------------
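The core of the retrieval step is `pairs_from_score_matrix`: given a query-by-database similarity matrix and a boolean mask of entries to ignore (e.g. self-matches, scores below `min_score`), it keeps the top `num_select` finite entries per row. A toy example, assuming hloc is installed as in `scripts/env/install.sh`:

import numpy as np
import torch

from scripts.preprocess.hloc_mapping.pairs_from_retrieval import pairs_from_score_matrix

sim = torch.tensor([
    [0.9, 0.2, 0.7],
    [0.1, 0.8, 0.3],
])
invalid = np.zeros((2, 3), dtype=bool)   # nothing masked out in this toy case
# Keeps the two best database images per query whose score is at least 0.25.
pairs = pairs_from_score_matrix(sim, invalid, num_select=2, min_score=0.25)
print(pairs)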
/scripts/preprocess/hloc_mapping/reconstruction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | from typing import Optional, List
4 | import multiprocessing
5 | from pathlib import Path
6 | import pycolmap
7 |
8 | from hloc import logger
9 | from hloc.utils.database import COLMAPDatabase
10 | from hloc.triangulation import (
11 | import_features, import_matches, geometric_verification, OutputCapture)
12 |
13 |
14 | def create_empty_db(database_path):
15 | if database_path.exists():
16 | logger.warning('The database already exists, deleting it.')
17 | database_path.unlink()
18 | logger.info('Creating an empty database...')
19 | db = COLMAPDatabase.connect(database_path)
20 | db.create_tables()
21 | db.commit()
22 | db.close()
23 |
24 |
25 | def import_images(image_dir, database_path, camera_mode, image_list=None):
26 | logger.info('Importing images into the database...')
27 | images = list(image_dir.iterdir())
28 | if len(images) == 0:
29 | raise IOError(f'No images found in {image_dir}.')
30 | with pycolmap.ostream():
31 | pycolmap.import_images(database_path, image_dir, camera_mode,
32 | image_list=image_list or [])
33 |
34 |
35 | def get_image_ids(database_path):
36 | db = COLMAPDatabase.connect(database_path)
37 | images = {}
38 | for name, image_id in db.execute("SELECT name, image_id FROM images;"):
39 | images[name] = image_id
40 | db.close()
41 | return images
42 |
43 |
44 | def run_reconstruction(sfm_dir, database_path, image_dir, verbose=False):
45 | models_path = sfm_dir / 'models'
46 | models_path.mkdir(exist_ok=True, parents=True)
47 | logger.info('Running 3D reconstruction...')
48 | with OutputCapture(verbose):
49 | with pycolmap.ostream():
50 | reconstructions = pycolmap.incremental_mapping(
51 | database_path, image_dir, models_path,
52 | num_threads=min(multiprocessing.cpu_count(), 16))
53 |
54 | if len(reconstructions) == 0:
55 | logger.error('Could not reconstruct any model!')
56 | return None
57 | logger.info(f'Reconstructed {len(reconstructions)} model(s).')
58 |
59 | largest_index = None
60 | largest_num_images = 0
61 | for index, rec in reconstructions.items():
62 | num_images = rec.num_reg_images()
63 | if num_images > largest_num_images:
64 | largest_index = index
65 | largest_num_images = num_images
66 | assert largest_index is not None
67 | logger.info(f'Largest model is #{largest_index} '
68 | f'with {largest_num_images} images.')
69 |
70 | for filename in ['images.bin', 'cameras.bin', 'points3D.bin']:
71 | if (sfm_dir / filename).exists():
72 | (sfm_dir / filename).unlink()
73 | shutil.move(
74 |             str(models_path / str(largest_index) / filename), str(sfm_dir))
75 | return reconstructions[largest_index]
76 |
77 |
78 | def main(database, output_dir, image_dir, pairs, features, matches,
79 | camera_mode=pycolmap.CameraMode.AUTO, verbose=False,
80 | skip_geometric_verification=False, min_match_score=None,
81 | image_list: Optional[List[str]] = None):
82 |
83 | assert features.exists(), features
84 | assert pairs.exists(), pairs
85 | assert matches.exists(), matches
86 |
87 | output_dir.mkdir(parents=True, exist_ok=True)
88 |
89 | # create_empty_db(database)
90 | # import_images(image_dir, database, camera_mode, image_list)
91 | image_ids = get_image_ids(database)
92 | # import_features(image_ids, database, features)
93 | # import_matches(image_ids, database, pairs, matches,
94 | # min_match_score, skip_geometric_verification)
95 | # if not skip_geometric_verification:
96 | # geometric_verification(database, pairs, verbose)
97 | reconstruction = run_reconstruction(output_dir, database, image_dir, verbose)
98 | if reconstruction is not None:
99 | logger.info(f'Reconstruction statistics:\n{reconstruction.summary()}'
100 | + f'\n\tnum_input_images = {len(image_ids)}')
101 | return reconstruction
102 |
103 |
104 | if __name__ == '__main__':
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('--output_dir', type=Path, required=True)
107 | parser.add_argument('--image_dir', type=Path, required=True)
108 |
109 | parser.add_argument('--pairs', type=Path, required=True)
110 | parser.add_argument('--features', type=Path, required=True)
111 | parser.add_argument('--matches', type=Path, required=True)
112 |
113 | parser.add_argument('--camera_mode', type=str, default="AUTO",
114 | choices=list(pycolmap.CameraMode.__members__.keys()))
115 | parser.add_argument('--skip_geometric_verification', action='store_true')
116 | parser.add_argument('--min_match_score', type=float)
117 | parser.add_argument('--verbose', action='store_true')
118 | args = parser.parse_args()
119 |
120 | main(**args.__dict__)
121 |
--------------------------------------------------------------------------------
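Most of the import calls in `main` above are commented out because the pipeline reuses a database populated by earlier steps; only the image-id lookup and the incremental mapping remain. A small sketch of the lookup (the database path is hypothetical):

from pathlib import Path

from scripts.preprocess.hloc_mapping.reconstruction import get_image_ids

database = Path("outputs/scene/database.db")   # hypothetical, produced by earlier steps
image_ids = get_image_ids(database)
print(f"{len(image_ids)} images listed in the database")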
/scripts/preprocess/hloc_mapping/sfm_pipeline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from scripts.preprocess.hloc_mapping import extract_relative_poses
5 |
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--dataset_dir', type=Path, default='datasets',
10 | help='Path to the dataset, default: %(default)s')
11 | parser.add_argument('--outputs', type=Path, default='outputs',
12 | help='Path to the output directory, default: %(default)s')
13 | parser.add_argument('--num_matches', type=int, default=30,
14 | help='Number of image pairs for loc, default: %(default)s')
15 | parser.add_argument('--disambiguate', action="store_true",
16 | help='Enable/Disable disambiguating wrong matches.')
17 | parser.add_argument('--min_track_length', type=int, default=3)
18 | parser.add_argument('--max_track_length', type=int, default=40)
19 | parser.add_argument('--track_degree', type=int, default=3)
20 | parser.add_argument('--coverage_thres', type=float, default=0.9)
21 | parser.add_argument('--alpha', type=float, default=0.1)
22 | parser.add_argument('--minimal_views', type=int, default=5)
23 | parser.add_argument('--ds', type=str,
24 | choices=['dict', 'smallarray', 'largearray'],
25 | default='largearray')
26 | parser.add_argument('--filter_type', type=str, choices=[
27 | 'threshold', 'knn', 'mst_min', 'mst_mean', 'percentile'],
28 | default='threshold')
29 | parser.add_argument('--threshold', type=float, default=0.15)
30 | parser.add_argument('--topk', type=int, default=3)
31 | parser.add_argument('--percentile', type=float)
32 | parser.add_argument('--colmap_path', type=Path, default='colmap')
33 | parser.add_argument('--geometric_verification_type',
34 | type=str,
35 | choices=['default', 'strict'],
36 | default='default')
37 | parser.add_argument('--recon', action="store_true",
38 | help='Indicates whether to reconstruct the scene.')
39 | parser.add_argument('--visualize', action="store_true",
40 | help='Whether to visualize the reconstruction.')
41 | parser.add_argument('--gpu_idx', type=str, default='0')
42 | args = parser.parse_args()
43 | return args
44 |
45 |
46 | def main():
47 | args = parse_args()
48 |     # Extract relative poses and store them as a g2o file.
49 | view_graph_path, database_path, num_view_pairs = extract_relative_poses.main(args=args)
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
--------------------------------------------------------------------------------
/scripts/preprocess/hloc_mapping/triangulate_from_existing_model.py:
--------------------------------------------------------------------------------
1 | import io
2 | import sys
3 | import argparse
4 | import contextlib
5 |
6 | from typing import Optional, List, Dict, Any
7 | from pathlib import Path
8 |
9 | import pycolmap
10 |
11 |
12 | class OutputCapture:
13 | def __init__(self, verbose: bool):
14 | self.verbose = verbose
15 |
16 | def __enter__(self):
17 | if not self.verbose:
18 | self.capture = contextlib.redirect_stdout(io.StringIO()) # pylint: disable=W0201
19 | self.out = self.capture.__enter__() # pylint: disable=W0201
20 |
21 | def __exit__(self, exc_type, *args):
22 | if not self.verbose:
23 | self.capture.__exit__(exc_type, *args)
24 | if exc_type is not None:
25 |                 print('Failed with output:\n%s' % self.out.getvalue())
26 | sys.stdout.flush()
27 |
28 |
29 | def run_triangulation(
30 | output_path: Path,
31 | database_path: Path,
32 | image_dir: Path,
33 | reference_model: pycolmap.Reconstruction,
34 | verbose: bool = False,
35 | options: Optional[Dict[str, Any]] = None,
36 | ) -> pycolmap.Reconstruction:
37 | output_path.mkdir(parents=True, exist_ok=True)
38 | print('Running 3D triangulation...')
39 | if options is None:
40 | options = {}
41 | with OutputCapture(verbose):
42 | with pycolmap.ostream():
43 | reconstruction = pycolmap.triangulate_points(
44 |                 reference_model, database_path, image_dir, output_path, options=options)
45 | return reconstruction
46 |
47 |
48 | def main(
49 | sfm_dir: Path,
50 | reference_model: Path,
51 | image_dir: Path,
52 | output_dir: Path,
53 | verbose: bool = False,
54 | mapper_options: Optional[Dict[str, Any]] = None,
55 | ) -> pycolmap.Reconstruction:
56 |
57 | assert reference_model.exists(), reference_model
58 |
59 | sfm_dir.mkdir(parents=True, exist_ok=True)
60 | database_path = sfm_dir / 'database.db'
61 | reference_model = pycolmap.Reconstruction(reference_model)
62 |
63 | reconstruction = run_triangulation(output_dir, database_path, image_dir, reference_model,
64 | verbose, mapper_options)
65 |     print('Finished the triangulation with statistics:\n%s'
66 |           % reconstruction.summary())
67 | return reconstruction
68 |
69 |
70 | def parse_option_args(args: List[str], default_options) -> Dict[str, Any]:
71 | options = {}
72 | for arg in args:
73 | idx = arg.find('=')
74 | if idx == -1:
75 | raise ValueError('Options format: key1=value1 key2=value2 etc.')
76 | key, value = arg[:idx], arg[idx+1:]
77 | if not hasattr(default_options, key):
78 | raise ValueError(
79 | f'Unknown option "{key}", allowed options and default values'
80 | f' for {default_options.summary()}')
81 | value = eval(value) # pylint: disable=W0123
82 | target_type = type(getattr(default_options, key))
83 | if not isinstance(value, target_type):
84 | raise ValueError(f'Incorrect type for option "{key}":'
85 | f' {type(value)} vs {target_type}')
86 | options[key] = value
87 | return options
88 |
89 |
90 | if __name__ == '__main__':
91 | parser = argparse.ArgumentParser()
92 | parser.add_argument('--sfm_dir', type=Path, required=True)
93 | parser.add_argument('--reference_model', type=Path, required=True)
94 | parser.add_argument('--image_dir', type=Path, required=True)
95 | parser.add_argument('--output_dir', type=Path, required=True)
96 | parser.add_argument('--verbose', action='store_true')
97 | args = parser.parse_args().__dict__
98 |
99 | # mapper_options = parse_option_args(
100 | # args.pop("mapper_options"), pycolmap.IncrementalMapperOptions())
101 | mapper_options = pycolmap.IncrementalMapperOptions()
102 |
103 | main(**args, mapper_options=mapper_options)
104 |
--------------------------------------------------------------------------------
/scripts/preprocess/hloc_mapping/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import cv2
4 | from tqdm import tqdm
5 |
6 | from hloc.utils.database import COLMAPDatabase, blob_to_array
7 | from hloc import logger
8 | from hloc.utils.io import get_matches
9 |
10 |
11 | def import_matches(image_ids, database_path, pairs_path, matches_path,
12 | min_match_score=None, skip_geometric_verification=False
13 | ) -> int :
14 | logger.info('Importing matches into the database...')
15 |
16 | with open(str(pairs_path), 'r') as f:
17 | pairs = [p.split() for p in f.readlines()]
18 |
19 | db = COLMAPDatabase.connect(database_path)
20 |
21 | matched = set()
22 | for name0, name1 in tqdm(pairs):
23 | id0, id1 = image_ids[name0], image_ids[name1]
24 | if len({(id0, id1), (id1, id0)} & matched) > 0:
25 | continue
26 | matches, scores = get_matches(matches_path, name0, name1)
27 | if min_match_score:
28 | matches = matches[scores > min_match_score]
29 | db.add_matches(id0, id1, matches)
30 | matched |= {(id0, id1), (id1, id0)}
31 |
32 |         if skip_geometric_verification:
33 |             db.add_two_view_geometry(id0, id1, matches)
34 |
35 | db.commit()
36 | db.close()
37 | return len(pairs)
38 |
39 |
40 | def read_camera_intrinsics_by_image_id(image_id: int, db: COLMAPDatabase):
41 | rows = db.execute(f'SELECT camera_id FROM images WHERE image_id={image_id}')
42 | camera_id = next(rows)[0]
43 | rows = db.execute(f'SELECT params FROM cameras WHERE camera_id={camera_id}')
44 | params = blob_to_array(next(rows)[0], dtype=np.float64)
45 |
46 | # FIXME(chenyu): when camera model is not a simple pinhole.
47 | intrinsics = np.zeros((3, 3), dtype=np.float64)
48 | intrinsics[0, 0] = intrinsics[1, 1] = params[0]
49 | intrinsics[0, 2], intrinsics[1, 2] = params[1], params[2]
50 | intrinsics[2, 2] = 1.
51 | return intrinsics
52 |
53 |
54 | def read_all_keypoints(db: COLMAPDatabase):
55 | keypoints_dict = dict(
56 | (image_id, blob_to_array(data, np.float32, (-1, 2)))
57 | for image_id, data in db.execute(
58 | "SELECT image_id, data FROM keypoints"))
59 | return keypoints_dict
60 |
61 |
62 | def extract_inlier_keypoints_pair(inlier_matches, keypoints1, keypoints2):
63 | inlier_keypoints1, inlier_keypoints2 = [], []
64 | num_inliers = inlier_matches.shape[0]
65 | for i in range(num_inliers):
66 | idx = inlier_matches[i]
67 | inlier_keypoints1.append(keypoints1[idx[0]])
68 | inlier_keypoints2.append(keypoints2[idx[1]])
69 |
70 | inlier_keypoints1 = np.stack(inlier_keypoints1, axis=0)
71 | inlier_keypoints2 = np.stack(inlier_keypoints2, axis=0)
72 | return inlier_keypoints1, inlier_keypoints2
73 |
74 |
75 | def triangulate(inlier_keypoints1, inlier_keypoints2,
76 | extrinsics1: np.ndarray, extrinsics2: np.ndarray,
77 | intrinsics1: np.ndarray, intrinsics2: np.ndarray):
78 | proj_mtx1 = np.matmul(intrinsics1, extrinsics1)
79 | proj_mtx2 = np.matmul(intrinsics2, extrinsics2)
80 |
81 | points3d = cv2.triangulatePoints(projMatr1=proj_mtx1, projMatr2=proj_mtx2,
82 | projPoints1=inlier_keypoints1.transpose(1, 0),
83 | projPoints2=inlier_keypoints2.transpose(1, 0))
84 | points3d = points3d.transpose(1, 0)
85 | points3d = points3d[:, :3] / points3d[:, 3].reshape(-1, 1)
86 | return points3d
87 |
88 |
89 | def compute_depth(proj_matrix, point3d):
90 | homo_point3d = np.ones(4)
91 | homo_point3d[0:3] = point3d
92 | proj_z = np.dot(proj_matrix[2, :].T, homo_point3d)
93 | return proj_z * np.linalg.norm(proj_matrix[:, 2], ord=2)
94 |
95 |
96 | def check_cheirality(inlier_keypoints1, inlier_keypoints2,
97 | extrinsic1: np.ndarray, extrinsic2: np.ndarray,
98 | intrinsics1: np.ndarray, intrinsics2: np.ndarray):
99 | min_depth = 1e-16
100 | max_depth = 1000 * np.linalg.norm(
101 | np.dot(extrinsic2[:3, :3].T, extrinsic2[:, 3]), ord=2)
102 | points3d = []
103 |
104 | tmp_points3d = triangulate(inlier_keypoints1, inlier_keypoints2,
105 | extrinsic1, extrinsic2, intrinsics1, intrinsics2)
106 | for point3d in tmp_points3d:
107 | # Checking for positive depth in front of both cameras.
108 | depth1 = compute_depth(extrinsic1, point3d)
109 | if depth1 < max_depth and depth1 > min_depth:
110 | depth2 = compute_depth(extrinsic2, point3d)
111 | if depth2 < max_depth and depth2 > min_depth:
112 | points3d.append(point3d)
113 |
114 | return points3d
115 |
116 |
117 | def decompose_essential_matrix(
118 | keypoints1, keypoints2,
119 | essential_matrix, inlier_matches,
120 | intrinsics1: np.ndarray, intrinsics2: np.ndarray
121 | ) -> (np.ndarray, np.ndarray):
122 | """
123 |     Assume that image 1 is at [I|0] and image 2 is at [R|t],
124 | where R, t are derived from the essential matrix.
125 |
126 | Args:
127 |         keypoints1: keypoint locations of image1
128 |         keypoints2: keypoint locations of image2
129 | essential_matrix: 3 x 3 numpy array,
130 | inlier_matches: matched keypoints indices between image1 and image2
131 | intrinsics1: 3 x 3 numpy array for image1
132 | intrinsics2: 3 x 3 numpy array for image2
133 |
134 | Returns:
135 |         extrinsic matrix of shape (1, 12) from image 1 to image 2, plus the triangulated 3D points
136 | """
137 |
138 | inlier_keypoints1, inlier_keypoints2 = extract_inlier_keypoints_pair(
139 | inlier_matches, keypoints1, keypoints2)
140 | # print(f'{inlier_keypoints1.shape}')
141 | # print(f'{inlier_keypoints2.shape}')
142 |
143 | extrinsic1 = np.zeros(shape=[3, 4], dtype=np.float64)
144 | extrinsic1[:3, :3] = np.eye(3)
145 | # relative motion from camera1 to camera2.
146 | extrinsics2 = np.zeros(shape=[3, 4], dtype=np.float64)
147 |
148 | W = np.zeros((3, 3))
149 | W[0, 1], W[1, 0], W[2, 2] = -1, 1, 1
150 | U, _, Vh = np.linalg.svd(essential_matrix)
151 |
152 | if np.linalg.det(U) < 0:
153 | U *= -1
154 | if np.linalg.det(Vh) < 0:
155 | Vh *= -1
156 |
157 | R1, R2 = np.dot(np.dot(U, W), Vh), np.dot(np.dot(U, np.transpose(W)), Vh)
158 | t = U[:, 2]
159 | t /= np.linalg.norm(t, ord=2)
160 |
161 | def compose_projection_matrix(R, t):
162 | P = np.zeros(shape=[3, 4], dtype=float)
163 | P[:3, :3], P[:, 3] = R, t
164 | return P
165 |
166 | # Generate candidate projection matrices.
167 | P2_list = []
168 | P2_list.append(compose_projection_matrix(R1, t))
169 | P2_list.append(compose_projection_matrix(R2, t))
170 | P2_list.append(compose_projection_matrix(R1, -t))
171 | P2_list.append(compose_projection_matrix(R2, -t))
172 |
173 | candidate_points3d, points3d = [], []
174 | # Then, we need to iterate over each projection matrix and
175 | # make the cheirality validation.
176 | for extrinsic2 in P2_list:
177 | candidate_points3d = check_cheirality(
178 | inlier_keypoints1, inlier_keypoints2,
179 | extrinsic1, extrinsic2, intrinsics1, intrinsics2)
180 | # print(f'len points3d: {len(points3d)}')
181 | if len(points3d) < len(candidate_points3d):
182 | points3d[:] = candidate_points3d
183 | extrinsics2[:] = extrinsic2
184 |
185 | # print(f'final len points3d: {len(points3d)}')
186 | if len(points3d) == 0:
187 | return None, None
188 |
189 | points3d = np.stack(points3d, axis=0)
190 |
191 | return extrinsics2.reshape(1, -1), points3d
192 |
--------------------------------------------------------------------------------
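The SVD factorization above yields two candidate rotations (R1, R2) and a translation known only up to sign, and the cheirality check keeps the candidate that puts the triangulated points in front of both cameras. Below is a minimal, self-contained NumPy sanity-check sketch of that four-fold ambiguity (all values are synthetic, not part of the repository):

import numpy as np

def skew(t):
    """Cross-product matrix, so that skew(t) @ x == np.cross(t, x)."""
    return np.array([[0.0, -t[2], t[1]],
                     [t[2], 0.0, -t[0]],
                     [-t[1], t[0], 0.0]])

def rot_z(angle):
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

# Ground-truth relative motion and the essential matrix it induces: E = [t]x R.
R_true = rot_z(0.3)
t_true = np.array([1.0, 0.2, -0.1])
t_true /= np.linalg.norm(t_true)
E = skew(t_true) @ R_true

# Same factorization as in decompose_essential_matrix() above.
W = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
U, _, Vh = np.linalg.svd(E)
if np.linalg.det(U) < 0:
    U *= -1
if np.linalg.det(Vh) < 0:
    Vh *= -1
R1, R2 = U @ W @ Vh, U @ W.T @ Vh
t = U[:, 2] / np.linalg.norm(U[:, 2])

# Exactly one of the four (R, +/-t) candidates matches the true motion; in the
# real pipeline the cheirality check (positive depths) is what singles it out.
candidates = [(R1, t), (R1, -t), (R2, t), (R2, -t)]
best_R, best_t = min(
    candidates,
    key=lambda Rt: np.linalg.norm(Rt[0] - R_true) + np.linalg.norm(Rt[1] - t_true))
print(np.allclose(best_R, R_true, atol=1e-6), np.allclose(best_t, t_true, atol=1e-6))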
/scripts/preprocess/mapping.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tqdm
4 | import argparse
5 |
6 | from conerf.datasets.hypersim import _collect_camera_names, _get_all_image_names
7 |
8 |
9 | SFM_SCRIPT_PATH = os.path.join(os.getcwd(), 'scripts/preprocess/colmap_mapping.sh')
10 | VOC_TREE_PATH = '/home/chenyu/HD_Datasets/datasets/vocab_tree_flickr100K_words256K.bin'
11 | TOPK_IMAGES = 100
12 | GPU_IDS = 1
13 |
14 | # DATASETS = ['Hypersim'] #, 'nerf_synthetic'] # ['nerf_llff_data', 'ibrnet_collected_more', 'BlendedMVS']
15 | DATASETS = ['DTU'] #, 'nerf_synthetic'] # ['nerf_llff_data', 'ibrnet_collected_more', 'BlendedMVS']
16 | ROOT_DIR = '/home/chenyu/HD_Datasets/datasets'
17 | # DATASETS = ['BlendedMVS']
18 | # ROOT_DIR = '/media/chenyu/SSD_Data/datasets'
19 |
20 |
21 | def config_parser():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--preprocess",
24 | action="store_true",
25 | help="whether to preprocess data")
26 | parser.add_argument("--run_colmap",
27 | action="store_true",
28 |                         help="whether to run COLMAP mapping on each scene")
29 | parser.add_argument("--start_index", type=int, default=0)
30 | parser.add_argument("--end_index", type=int, default=0)
31 |
32 | return parser.parse_args()
33 |
34 |
35 | def get_filename_from_abs_path(abs_path):
36 | return abs_path.split('/')[-1]
37 |
38 |
39 | def get_filename_no_ext(filename):
40 | return os.path.splitext(filename)[0]
41 |
42 |
43 | def get_file_extension(filename):
44 | return os.path.splitext(filename)[-1]
45 |
46 |
47 | def preprocess_nerf_synthetic_dataset(dataset_dir):
48 |     # The nerf_synthetic (Blender) dataset stores its training images under 'train/';
49 |     # copy them to 'images/' to match the layout expected by the SfM script.
50 | scenes = sorted(os.listdir(dataset_dir))
51 | for scene in scenes:
52 | scene_dir = os.path.join(dataset_dir, scene)
53 | image_dir = os.path.join(scene_dir, 'train')
54 | new_image_dir = os.path.join(scene_dir, 'images')
55 |
56 | os.system(f'cp -r {image_dir} {new_image_dir}')
57 |
58 |
59 | def preprocess_dtu_dataset(dataset_dir):
60 | # The DTU dataset follows pixel-nerf: https://github.com/sxyu/pixel-nerf ,
61 | # Url: https://drive.google.com/drive/folders/1PsT3uKwqHHD2bEEHkIXB99AlIjtmrEiR
62 | scenes = sorted(os.listdir(dataset_dir))
63 | for scene in scenes:
64 | scene_dir = os.path.join(dataset_dir, scene)
65 | if not os.path.isdir(scene_dir):
66 | continue
67 | image_dir = os.path.join(scene_dir, 'image')
68 | new_image_dir = os.path.join(scene_dir, 'images')
69 |
70 | os.system(f'mv {image_dir} {new_image_dir}')
71 |
72 |
73 | def preprocess_blended_mvs_dataset(dataset_dir):
74 | scenes = sorted(os.listdir(dataset_dir))
75 | if args.start_index < args.end_index:
76 | scenes = scenes[args.start_index:args.end_index]
77 |
78 | for scene in scenes:
79 | scene_dir = os.path.join(dataset_dir, scene)
80 |
81 | blended_image_dir = os.path.join(scene_dir, 'blended_images')
82 | image_dir = os.path.join(scene_dir, 'images')
83 | ori_image_dir = os.path.join(scene_dir, 'ori_images')
84 | masked_image_dir = os.path.join(scene_dir, 'masked_images')
85 |
86 | # os.system(f'rm -r {scene_dir}/output')
87 | # os.system(f'rm {scene_dir}/database.db {scene_dir}/poses_bounds.npy {scene_dir}/track.txt {scene_dir}/*.g2o {scene_dir}/*.json')
88 | # os.system(f'mv {image_dir} {ori_image_dir}')
89 | # os.system(f'mv {masked_image_dir} {image_dir}')
90 |
91 | # if not os.path.exists(image_dir):
92 | # os.mkdir(image_dir)
93 |
94 | # if not os.path.exists(masked_image_dir):
95 | # os.mkdir(masked_image_dir)
96 |
97 | # for root, dirs, files in os.walk(blended_image_dir):
98 | # for file in files:
99 | # image_path = os.path.join(blended_image_dir, root, file)
100 | # if file.find('masked') >= 0:
101 | # shutil.move(image_path, os.path.join(masked_image_dir, file))
102 | # else:
103 | # shutil.move(image_path, os.path.join(image_dir, file))
104 |
105 | # os.system(f'rm -r {blended_image_dir}')
106 |
107 |
108 | def preprocess_hypersim_dataset(dataset_dir):
109 | scenes = sorted(os.listdir(dataset_dir))
110 | if args.start_index < args.end_index:
111 | scenes = scenes[args.start_index:args.end_index]
112 |
113 | pbar = tqdm.trange(len(scenes), desc="Preprocessing", leave=False)
114 | for scene in scenes:
115 | scene_dir = os.path.join(dataset_dir, scene)
116 | if not os.path.isdir(scene_dir):
117 | continue
118 |
119 | camera_names = _collect_camera_names(os.path.join(scene_dir, '_detail'))
120 |
121 | new_image_dir = os.path.join(scene_dir, 'images')
122 | origin_image_dir = os.path.join(scene_dir, 'ori_images')
123 |
124 | if not os.path.exists(origin_image_dir):
125 | os.mkdir(origin_image_dir)
126 |
127 | # backup
128 | os.system(f'mv {new_image_dir}/* {origin_image_dir}')
129 | # os.system(f'rm -r {origin_image_dir}/images')
130 |
131 | for i, camera_name in enumerate(camera_names):
132 | image_dir = os.path.join(origin_image_dir, 'scene_' + camera_name + '_final_preview')
133 | image_files, _ = _get_all_image_names(image_dir, image_type='tonemap')
134 |
135 | for image_file in image_files:
136 | image_name = get_filename_from_abs_path(image_file)
137 |
138 | sub_image_dir = os.path.join(new_image_dir, str(i))
139 | os.makedirs(sub_image_dir, exist_ok=True)
140 | new_image_file = os.path.join(sub_image_dir, image_name)
141 | shutil.copy(image_file, new_image_file)
142 |
143 | pbar.update(1)
144 |
145 |
146 | def run_sfm(root_dir, dataset_list, args):
147 | for dataset in dataset_list:
148 | dataset_dir = os.path.join(root_dir, dataset)
149 |
150 | if args.preprocess:
151 | if dataset == 'BlendedMVS':
152 | preprocess_blended_mvs_dataset(dataset_dir)
153 |
154 | if dataset == 'DTU':
155 | preprocess_dtu_dataset(dataset_dir)
156 |
157 | if dataset == 'Hypersim':
158 | preprocess_hypersim_dataset(dataset_dir)
159 |
160 | scenes = sorted(os.listdir(dataset_dir))
161 | if args.start_index < args.end_index:
162 | scenes = scenes[args.start_index:args.end_index]
163 |
164 | pbar = tqdm.trange(len(scenes), desc="Running SfM", leave=False)
165 | for scene in scenes:
166 | data_dir = os.path.join(dataset_dir, scene)
167 | if not os.path.isdir(data_dir):
168 | continue
169 | output_dir = os.path.join(data_dir, 'sparse')
170 | if not os.path.exists(output_dir):
171 | os.makedirs(output_dir)
172 |
173 | if not args.run_colmap:
174 | # Compute bounding box.
175 | os.system(f'python -m scripts.preprocess.compute_bbox --colmap_dir {output_dir}/0')
176 | continue
177 |
178 | # print(f'output dir: {output_dir}')
179 | os.system(f'{SFM_SCRIPT_PATH} {data_dir} {output_dir} {VOC_TREE_PATH} {TOPK_IMAGES} {GPU_IDS}')
180 |
181 | shutil.move(os.path.join(output_dir, 'database.db'), os.path.join(data_dir, 'database.db'))
182 |
183 | # Compute bounding box.
184 | os.system(f'python -m scripts.preprocess.compute_bbox --colmap_dir {output_dir}/0')
185 |
186 | pbar.update(1)
187 |
188 |
189 | if __name__ == '__main__':
190 | args = config_parser()
191 |
192 | run_sfm(ROOT_DIR, DATASETS, args)
193 |
--------------------------------------------------------------------------------
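The preprocess helpers above shell out with os.system for copies and moves, which silently ignores failures. As a sketch only (assuming the same scene layout as preprocess_dtu_dataset; move_dtu_images is a hypothetical name, not part of the repository), the same renaming can be done with shutil so errors raise:

import os
import shutil

def move_dtu_images(dataset_dir: str) -> None:
    """Rename each scene's 'image' folder to 'images', mirroring preprocess_dtu_dataset."""
    for scene in sorted(os.listdir(dataset_dir)):
        scene_dir = os.path.join(dataset_dir, scene)
        if not os.path.isdir(scene_dir):
            continue
        image_dir = os.path.join(scene_dir, 'image')
        new_image_dir = os.path.join(scene_dir, 'images')
        # Only move when the source exists and the target has not been created yet.
        if os.path.isdir(image_dir) and not os.path.exists(new_image_dir):
            shutil.move(image_dir, new_image_dir)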
/scripts/preprocess/triangulate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Run triangulator with known camera poses.
4 |
5 | COLMAP_DIR=/usr/local/bin
6 | COLMAP_EXE=$COLMAP_DIR/colmap
7 |
8 | export PYTHONDONTWRITEBYTECODE=1
9 |
10 | PROJECT_PATH=$1
11 | colmap_method=$2 # ['colmap', 'pycolmap']
12 |
13 | if [ `echo $colmap_method | grep -c "py" ` -gt 0 ]
14 | then
15 | HOME_DIR=$HOME
16 | CODE_ROOT_DIR=$HOME/'Projects/ZeroGS'
17 | cd $CODE_ROOT_DIR
18 |
19 | python -m scripts.preprocess.hloc_mapping.triangulate_from_existing_model \
20 | --sfm_dir $PROJECT_PATH \
21 | --reference_model $PROJECT_PATH/sparse/triangulator_input \
22 | --output_dir $PROJECT_PATH/sparse/0 \
23 | --image_dir $PROJECT_PATH \
24 | --verbose \
25 | > $PROJECT_PATH/log_triangulate.txt 2>&1
26 | else
27 | $COLMAP_EXE point_triangulator \
28 | --database_path $PROJECT_PATH/database.db \
29 | --image_path $PROJECT_PATH \
30 | --input_path $PROJECT_PATH/sparse/triangulator_input \
31 | --output_path $PROJECT_PATH/sparse/0 \
32 | > $PROJECT_PATH/log_triangulate.txt 2>&1
33 | fi
34 |
--------------------------------------------------------------------------------
/scripts/preprocess/utils.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E0402, C0103]
2 |
3 | from pathlib import Path
4 | from typing import List, Dict, Tuple
5 |
6 | from .read_write_model import Camera
7 |
8 |
9 | def list_images(data: str) -> List[Path]:
10 | """Lists all supported images in a directory
11 | Modified from:
12 | https://github.com/hturki/nerfstudio/nerfstudio/process_data/process_data_utils.py#L60
13 |
14 | Args:
15 | data: Path to the directory of images.
16 | Returns:
17 | Paths to images contained in the directory
18 | """
19 | data = Path(data)
20 | allowed_exts = [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
21 | image_paths = sorted([p for p in data.glob("[!.]*") if p.suffix.lower() in allowed_exts])
22 | return image_paths
23 |
24 |
25 | def list_metadata(data: str) -> List[Path]:
26 |     """Lists all supported metadata files in a directory
27 |     Modified from:
28 |     https://github.com/hturki/nerfstudio/nerfstudio/process_data/process_data_utils.py#L60
29 | 
30 |     Args:
31 |         data: Path to the directory of metadata (.pt) files.
32 |     Returns:
33 |         Paths to metadata files contained in the directory
34 | """
35 | data = Path(data)
36 | allowed_exts = [".pt"]
37 | metadata_paths = sorted([p for p in data.glob("[!.]*") if p.suffix.lower() in allowed_exts])
38 | return metadata_paths
39 |
40 |
41 | def list_jsons(data: str) -> List[Path]:
42 |     """Lists all supported JSON files in a directory
43 |     Modified from:
44 |     https://github.com/hturki/nerfstudio/nerfstudio/process_data/process_data_utils.py#L60
45 | 
46 |     Args:
47 |         data: Path to the directory of JSON files.
48 |     Returns:
49 |         Paths to JSON files contained in the directory
50 | """
51 | data = Path(data)
52 | allowed_exts = [".json"]
53 | metadata_paths = sorted([p for p in data.glob("[!.]*") if p.suffix.lower() in allowed_exts])
54 | return metadata_paths
55 |
56 |
57 | def read_meganerf_mappings(mappings_path: str) -> Tuple[Dict, Dict]:
58 | image_name_to_metadata, metadata_to_image_name = {}, {}
59 | with open(mappings_path, "r", encoding="utf-8") as file:
60 | line = file.readline()
61 | while line:
62 | image_name, pt_name = line.split(',')
63 | pt_name = pt_name.strip()
64 | image_name_to_metadata[image_name] = pt_name
65 | metadata_to_image_name[pt_name] = image_name
66 | line = file.readline()
67 |
68 | return image_name_to_metadata, metadata_to_image_name
69 |
70 |
71 | def get_filename_from_path(path: str) -> str:
72 | last_slash_index = path.rfind('/')
73 | return path[last_slash_index+1:]
74 |
75 |
76 | def is_same_camera(camera1: Camera, camera2: Camera) -> bool:
77 | if camera1.width != camera2.width:
78 | return False
79 |
80 | if camera1.height != camera2.height:
81 | return False
82 |
83 | if len(camera1.params) != len(camera2.params):
84 | return False
85 |
86 | for i in range(len(camera1.params)):
87 | if camera1.params[i] != camera2.params[i]:
88 | return False
89 |
90 | return True
91 |
92 |
93 | def get_camera_id(cameras: Dict, query_camera: Camera) -> int:
94 | for idx, camera in cameras.items():
95 | if is_same_camera(camera, query_camera):
96 | return idx
97 |
98 | return len(cameras) + 1
99 |
--------------------------------------------------------------------------------
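A small usage sketch for is_same_camera()/get_camera_id(), assuming the COLMAP-style Camera namedtuple from read_write_model (fields id, model, width, height, params) and that the snippet runs from the repository root; the camera values are placeholders, not part of the repository:

from scripts.preprocess.read_write_model import Camera
from scripts.preprocess.utils import get_camera_id

cameras = {
    1: Camera(id=1, model="PINHOLE", width=640, height=480,
              params=[500.0, 500.0, 320.0, 240.0]),
}
query = Camera(id=-1, model="PINHOLE", width=640, height=480,
               params=[500.0, 500.0, 320.0, 240.0])

# Returns the id of the matching camera (1 here); for an unseen camera it
# returns len(cameras) + 1, i.e. the next free 1-based id.
print(get_camera_id(cameras, query))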
/scripts/train/train_ace_zero.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_IDS=$1 # {'0,1,2,...'}
4 |
5 | export PYTHONDONTWRITEBYTECODE=1
6 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS}
7 |
8 | # Default parameters.
9 | DATASET='blender' # [llff, mipnerf360, tanks_and_temples]
10 | ENCODING='ace' # [ace, zero_gs]
11 | SUFFIX=''
12 |
13 | NUM_CMD_PARAMS=$#
14 | if [ $NUM_CMD_PARAMS -eq 2 ]
15 | then
16 | SUFFIX=$2
17 | elif [ $NUM_CMD_PARAMS -eq 3 ]
18 | then
19 | SUFFIX=$2
20 | DATASET=$3
21 | elif [ $NUM_CMD_PARAMS -eq 4 ]
22 | then
23 | SUFFIX=$2
24 | DATASET=$3
25 | ENCODING=$4
26 | fi
27 |
28 | YAML=${ENCODING}/${DATASET}'.yaml'
29 | echo "Using yaml file: ${YAML}"
30 |
31 | HOME_DIR=$HOME
32 | CODE_ROOT_DIR=$HOME/'Projects/ZeroGS'
33 |
34 | cd $CODE_ROOT_DIR
35 |
36 | python train.py --config 'config/'${YAML} \
37 | --suffix $SUFFIX
38 |
--------------------------------------------------------------------------------
/submodules/dsacstar/dsacstar.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Based on the DSAC++ and ESAC code.
3 | https://github.com/vislearn/LessMore
4 | https://github.com/vislearn/esac
5 |
6 | Copyright (c) 2016, TU Dresden
7 | Copyright (c) 2020, Heidelberg University
8 | All rights reserved.
9 |
10 | Redistribution and use in source and binary forms, with or without
11 | modification, are permitted provided that the following conditions are met:
12 | * Redistributions of source code must retain the above copyright
13 | notice, this list of conditions and the following disclaimer.
14 | * Redistributions in binary form must reproduce the above copyright
15 | notice, this list of conditions and the following disclaimer in the
16 | documentation and/or other materials provided with the distribution.
17 | * Neither the name of the TU Dresden, Heidelberg University nor the
18 | names of its contributors may be used to endorse or promote products
19 | derived from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY
25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | */
32 |
33 | #include <torch/extension.h>
34 | #include <opencv2/opencv.hpp>
35 | 
36 | #include <iostream>
37 |
38 | #include "thread_rand.h"
39 | #include "stop_watch.h"
40 |
41 | #include "dsacstar_types.h"
42 | #include "dsacstar_util.h"
43 | //#include "dsacstar_util_rgbd.h"
44 | #include "dsacstar_loss.h"
45 | #include "dsacstar_derivative.h"
46 |
47 | #define MAX_REF_STEPS 100 // max pose refinement iterations
48 | #define MAX_HYPOTHESES_TRIES 1000000 // max number of times to re-sample a hypothesis if it turns out invalid
49 |
50 | /**
51 | * @brief Estimate a camera pose based on a scene coordinate prediction
52 | * @param sceneCoordinatesSrc Scene coordinate prediction, (1x3xHxW) with 1=batch dimension (only batch_size=1 supported atm), 3=scene coordinate dimensions, H=height and W=width.
53 | * @param outPoseSrc Camera pose (output parameter), (4x4) tensor containing the homogeneous camera transformation matrix.
54 | * @param ransacHypotheses Number of RANSAC iterations.
55 | * @param inlierThreshold Inlier threshold for RANSAC in px.
56 | * @param fx Focal length of the camera along x in px (fy is the focal length along y).
57 | * @param ppointX Coordinate (X) of the principal point.
58 | * @param ppointY Coordinate (Y) of the principal point.
59 | * @param inlierAlpha Alpha parameter for soft inlier counting.
60 | * @param maxReproj Reprojection errors are clamped above this value (px).
61 | * @param subSampling Sub-sampling of the scene coordinate prediction wrt the input image.
62 | * @return The number of inliers for the output pose.
63 | */
64 | int dsacstar_rgb_forward(
65 | at::Tensor sceneCoordinatesSrc,
66 | at::Tensor outPoseSrc,
67 | int ransacHypotheses,
68 | float inlierThreshold,
69 | float fx,
70 | float fy,
71 | float ppointX,
72 | float ppointY,
73 | float inlierAlpha,
74 | float maxReproj,
75 | int subSampling,
76 | int maxHypothesesTries = 10000000,
77 | bool verbose = false)
78 | {
79 | ThreadRand::init();
80 |
81 | // access to tensor objects
82 | dsacstar::coord_t sceneCoordinates =
83 | sceneCoordinatesSrc.accessor<float, 4>();
84 |
85 | // dimensions of scene coordinate predictions
86 | int imH = sceneCoordinates.size(2);
87 | int imW = sceneCoordinates.size(3);
88 |
89 | // internal camera calibration matrix
90 | cv::Mat_<float> camMat = cv::Mat_<float>::eye(3, 3);
91 | camMat(0, 0) = fx;
92 | camMat(1, 1) = fy;
93 | camMat(0, 2) = ppointX;
94 | camMat(1, 2) = ppointY;
95 |
96 | // calculate original image position for each scene coordinate prediction
97 | cv::Mat_<cv::Point2i> sampling =
98 | dsacstar::createSampling(imW, imH, subSampling, 0, 0);
99 |
100 | if (verbose) {
101 | std::cout << BLUETEXT("Sampling " << ransacHypotheses << " hypotheses.") << std::endl;
102 | }
103 |
104 | StopWatch stopW;
105 |
106 | // sample RANSAC hypotheses
107 | std::vector<dsacstar::pose_t> hypotheses;
108 | std::vector<std::vector<cv::Point2i>> sampledPoints;
109 | std::vector<std::vector<cv::Point2f>> imgPts;
110 | std::vector<std::vector<cv::Point3f>> objPts;
111 |
112 | dsacstar::sampleHypotheses(
113 | sceneCoordinates,
114 | sampling,
115 | camMat,
116 | ransacHypotheses,
117 | maxHypothesesTries,
118 | inlierThreshold,
119 | hypotheses,
120 | sampledPoints,
121 | imgPts,
122 | objPts);
123 |
124 | if (verbose) {
125 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl;
126 | std::cout << BLUETEXT("Calculating scores.") << std::endl;
127 | }
128 |
129 | // compute reprojection error images
130 | std::vector<cv::Mat_<float>> reproErrs(ransacHypotheses);
131 | cv::Mat_<double> jacobeanDummy;
132 |
133 | #pragma omp parallel for
134 | for(unsigned h = 0; h < hypotheses.size(); h++)
135 | reproErrs[h] = dsacstar::getReproErrs(
136 | sceneCoordinates,
137 | hypotheses[h],
138 | sampling,
139 | camMat,
140 | maxReproj,
141 | jacobeanDummy);
142 |
143 | // soft inlier counting
144 | std::vector<double> scores = dsacstar::getHypScores(
145 | reproErrs,
146 | inlierThreshold,
147 | inlierAlpha);
148 |
149 | if (verbose) {
150 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl;
151 | std::cout << BLUETEXT("Drawing final hypothesis.") << std::endl;
152 | }
153 |
154 | // apply soft max to scores to get a distribution
155 | std::vector<double> hypProbs = dsacstar::softMax(scores);
156 | double hypEntropy = dsacstar::entropy(hypProbs); // measure distribution entropy
157 | int hypIdx = dsacstar::draw(hypProbs, false); // select winning hypothesis
158 |
159 | if (verbose) {
160 | std::cout << "Soft inlier count: " << scores[hypIdx] << " (Selection Probability: " << (int) (hypProbs[hypIdx]*100) << "%)" << std::endl;
161 | std::cout << "Entropy of hypothesis distribution: " << hypEntropy << std::endl;
162 |
163 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl;
164 | std::cout << BLUETEXT("Refining winning pose:") << std::endl;
165 | }
166 |
167 | // refine selected hypothesis
168 | cv::Mat_<int> inlierMap;
169 |
170 | dsacstar::refineHyp(
171 | sceneCoordinates,
172 | reproErrs[hypIdx],
173 | sampling,
174 | camMat,
175 | inlierThreshold,
176 | MAX_REF_STEPS,
177 | maxReproj,
178 | hypotheses[hypIdx],
179 | inlierMap);
180 |
181 | if (verbose) {
182 | std::cout << "Done in " << stopW.stop() / 1000 << "s." << std::endl;
183 | }
184 |
185 | // write result back to PyTorch
186 | dsacstar::trans_t estTrans = dsacstar::pose2trans(hypotheses[hypIdx]);
187 |
188 | auto outPose = outPoseSrc.accessor<float, 2>();
189 | for(unsigned x = 0; x < 4; x++)
190 | for(unsigned y = 0; y < 4; y++)
191 | outPose[y][x] = estTrans(y, x);
192 |
193 | // Return the inlier count. cv::sum returns a scalar, so we return its first element.
194 | return cv::sum(inlierMap)[0];
195 | }
196 |
197 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
198 | m.def("forward_rgb", &dsacstar_rgb_forward, "DSAC* forward (RGB)");
199 | // m.def("backward_rgb", &dsacstar_rgb_backward, "DSAC* backward (RGB)");
200 | // m.def("forward_rgbd", &dsacstar_rgbd_forward, "DSAC* forward (RGB-D)");
201 | // m.def("backward_rgbd", &dsacstar_rgbd_backward, "DSAC* backward (RGB-D)");
202 | }
203 |
--------------------------------------------------------------------------------
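Once the extension is compiled (see setup.py below), the binding above is exposed as dsacstar.forward_rgb. A minimal call sketch with placeholder values; the binding registers no py::arg defaults, so all thirteen arguments are passed explicitly (not part of the repository):

import torch
import dsacstar  # compiled from submodules/dsacstar

# (1, 3, H, W) scene-coordinate prediction; values here are placeholders.
scene_coords = torch.rand((1, 3, 60, 80), dtype=torch.float32)
out_pose = torch.eye(4, dtype=torch.float32)  # overwritten with the estimated 4x4 pose

num_inliers = dsacstar.forward_rgb(
    scene_coords, out_pose,
    64,              # ransacHypotheses
    10.0,            # inlierThreshold (px)
    480.0, 480.0,    # fx, fy (px)
    320.0, 240.0,    # ppointX, ppointY (px)
    100.0,           # inlierAlpha
    100.0,           # maxReproj (px)
    8,               # subSampling
    1000000,         # maxHypothesesTries
    False)           # verbose
print(num_inliers, out_pose)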
/submodules/dsacstar/dsacstar_loss.h:
--------------------------------------------------------------------------------
1 | /*
2 | Based on the DSAC++ and ESAC code.
3 | https://github.com/vislearn/LessMore
4 | https://github.com/vislearn/esac
5 |
6 | Copyright (c) 2016, TU Dresden
7 | Copyright (c) 2020, Heidelberg University
8 | All rights reserved.
9 |
10 | Redistribution and use in source and binary forms, with or without
11 | modification, are permitted provided that the following conditions are met:
12 | * Redistributions of source code must retain the above copyright
13 | notice, this list of conditions and the following disclaimer.
14 | * Redistributions in binary form must reproduce the above copyright
15 | notice, this list of conditions and the following disclaimer in the
16 | documentation and/or other materials provided with the distribution.
17 | * Neither the name of the TU Dresden, Heidelberg University nor the
18 | names of its contributors may be used to endorse or promote products
19 | derived from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY
25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | */
32 |
33 | #pragma once
34 |
35 | #define MAXLOSS 10000000.0 // clamp for stability
36 |
37 | namespace dsacstar
38 | {
39 | /**
40 | * @brief Calculates the rotational distance in degree between two transformations.
41 | * Translation will be ignored.
42 | *
43 | * @param trans1 Transformation 1.
44 | * @param trans2 Transformation 2.
45 | * @return Angle in degree.
46 | */
47 | double calcAngularDistance(const dsacstar::trans_t& trans1, const dsacstar::trans_t& trans2)
48 | {
49 | cv::Mat rot1 = trans1.colRange(0, 3).rowRange(0, 3);
50 | cv::Mat rot2 = trans2.colRange(0, 3).rowRange(0, 3);
51 |
52 | cv::Mat rotDiff= rot2 * rot1.t();
53 | double trace = cv::trace(rotDiff)[0];
54 |
55 | trace = std::min(3.0, std::max(-1.0, trace));
56 | return 180*acos((trace-1.0)/2.0)/PI;
57 | }
58 |
59 | /**
60 | * @brief Weighted average of translational error and rotational error between two pose hypotheses.
61 | * @param trans1 Pose 1.
62 | * @param trans2 Pose 2.
63 | * @param wRot Weight of rotation error.
64 | * @param wTrans Weight of translation error.
65 | * @param cut Apply soft clamping after this value.
66 | * @return Loss.
67 | */
68 | double loss(
69 | const dsacstar::trans_t& trans1,
70 | const dsacstar::trans_t& trans2,
71 | double wRot = 1.0,
72 | double wTrans = 1.0,
73 | double cut = 100)
74 | {
75 | double rotErr = dsacstar::calcAngularDistance(trans1, trans2);
76 | double tErr = cv::norm(
77 | trans1.col(3).rowRange(0, 3) - trans2.col(3).rowRange(0, 3));
78 |
79 | double loss = wRot * rotErr + wTrans * tErr;
80 |
81 | if(loss > cut)
82 | loss = std::sqrt(cut * loss);
83 |
84 | return std::min(loss, MAXLOSS);
85 | }
86 |
87 | /**
88 | * @brief Calculate the derivative of the loss w.r.t. the estimated pose.
89 | * @param est Estimated pose (6 DoF).
90 | * @param gt Ground truth pose (6 DoF).
91 | * @param wRot Weight of rotation error.
92 | * @param wTrans Weight of translation error.
93 | * @param cut Apply soft clamping after this value.
94 | * @return 1x6 Jacobean.
95 | */
96 | cv::Mat_<double> dLoss(
97 | const dsacstar::pose_t& est,
98 | const dsacstar::pose_t& gt,
99 | double wRot = 1.0,
100 | double wTrans = 1.0,
101 | double cut = 100)
102 | {
103 | cv::Mat rot1, rot2, dRod;
104 | cv::Rodrigues(est.first, rot1, dRod);
105 | cv::Rodrigues(gt.first, rot2);
106 |
107 | // measure loss of inverted poses (camera pose instead of scene pose)
108 | cv::Mat_<double> invRot1 = rot1.t();
109 | cv::Mat_<double> invRot2 = rot2.t();
110 |
111 | // get the difference rotation
112 | cv::Mat diffRot = rot1 * invRot2;
113 |
114 | // calculate rotational and translational error
115 | double trace = cv::trace(diffRot)[0];
116 | trace = std::min(3.0, std::max(-1.0, trace));
117 | double rotErr = 180*acos((trace-1.0)/2.0)/CV_PI;
118 |
119 | cv::Mat_<double> invT1 = est.second.clone();
120 | invT1 = invRot1 * invT1;
121 |
122 | cv::Mat_<double> invT2 = gt.second.clone();
123 | invT2 = invRot2 * invT2;
124 |
125 | // zero error, abort
126 | double tErr = cv::norm(invT1 - invT2);
127 |
128 | cv::Mat_<double> jacobean = cv::Mat_<double>::zeros(1, 6);
129 |
130 | // clamped loss, return zero gradient if loss is bigger than threshold
131 | double loss = wRot * rotErr + wTrans * tErr;
132 | bool cutLoss = false;
133 |
134 |
135 | if(loss > cut)
136 | {
137 | loss = std::sqrt(loss);
138 | cutLoss = true;
139 | }
140 |
141 | if(loss > MAXLOSS)
142 | return jacobean;
143 |
144 | if((tErr + rotErr) < EPS)
145 | return jacobean;
146 |
147 |
148 | // return gradient of translational error
149 | cv::Mat_<double> dDist_dInvT1(1, 3);
150 | for(unsigned i = 0; i < 3; i++)
151 | dDist_dInvT1(0, i) = (invT1(i, 0) - invT2(i, 0)) / tErr;
152 |
153 | cv::Mat_<double> dInvT1_dEstT(3, 3);
154 | dInvT1_dEstT = invRot1;
155 |
156 | cv::Mat_<double> dDist_dEstT = dDist_dInvT1 * dInvT1_dEstT;
157 | jacobean.colRange(3, 6) += dDist_dEstT * wTrans;
158 |
159 | cv::Mat_<double> dInvT1_dInvRot1 = cv::Mat_<double>::zeros(3, 9);
160 |
161 | dInvT1_dInvRot1(0, 0) = est.second.at<double>(0, 0);
162 | dInvT1_dInvRot1(0, 3) = est.second.at<double>(1, 0);
163 | dInvT1_dInvRot1(0, 6) = est.second.at<double>(2, 0);
164 | 
165 | dInvT1_dInvRot1(1, 1) = est.second.at<double>(0, 0);
166 | dInvT1_dInvRot1(1, 4) = est.second.at<double>(1, 0);
167 | dInvT1_dInvRot1(1, 7) = est.second.at<double>(2, 0);
168 | 
169 | dInvT1_dInvRot1(2, 2) = est.second.at<double>(0, 0);
170 | dInvT1_dInvRot1(2, 5) = est.second.at<double>(1, 0);
171 | dInvT1_dInvRot1(2, 8) = est.second.at<double>(2, 0);
172 |
173 | dRod = dRod.t();
174 |
175 | cv::Mat_<double> dDist_dRod = dDist_dInvT1 * dInvT1_dInvRot1 * dRod;
176 | jacobean.colRange(0, 3) += dDist_dRod * wTrans;
177 |
178 |
179 | // return gradient of rotational error
180 | cv::Mat_<double> dRotDiff = cv::Mat_<double>::zeros(9, 9);
181 | invRot2.row(0).copyTo(dRotDiff.row(0).colRange(0, 3));
182 | invRot2.row(1).copyTo(dRotDiff.row(1).colRange(0, 3));
183 | invRot2.row(2).copyTo(dRotDiff.row(2).colRange(0, 3));
184 |
185 | invRot2.row(0).copyTo(dRotDiff.row(3).colRange(3, 6));
186 | invRot2.row(1).copyTo(dRotDiff.row(4).colRange(3, 6));
187 | invRot2.row(2).copyTo(dRotDiff.row(5).colRange(3, 6));
188 |
189 | invRot2.row(0).copyTo(dRotDiff.row(6).colRange(6, 9));
190 | invRot2.row(1).copyTo(dRotDiff.row(7).colRange(6, 9));
191 | invRot2.row(2).copyTo(dRotDiff.row(8).colRange(6, 9));
192 |
193 | dRotDiff = dRotDiff.t();
194 |
195 | cv::Mat_<double> dTrace = cv::Mat_<double>::zeros(1, 9);
196 | dTrace(0, 0) = 1;
197 | dTrace(0, 4) = 1;
198 | dTrace(0, 8) = 1;
199 |
200 | cv::Mat_<double> dAngle = (180 / CV_PI * -1 / sqrt(3 - trace * trace + 2 * trace)) * dTrace * dRotDiff * dRod;
201 |
202 | jacobean.colRange(0, 3) += dAngle * wRot;
203 |
204 | if(cutLoss)
205 | jacobean *= 0.5 / loss;
206 |
207 |
208 | if(cv::sum(cv::Mat(jacobean != jacobean))[0] > 0) //check for NaNs
209 | return cv::Mat_<double>::zeros(1, 6);
210 |
211 | return jacobean;
212 | }
213 |
214 |
215 | }
216 |
--------------------------------------------------------------------------------
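For reference, a NumPy transcription of the weighted pose loss above (rotation error in degrees plus translation error, with the same square-root soft clamp); a sketch only, not part of the repository:

import numpy as np

def angular_distance_deg(rot1, rot2):
    # Mirrors dsacstar::calcAngularDistance: rotation difference in degrees.
    trace = np.clip(np.trace(rot2 @ rot1.T), -1.0, 3.0)
    return np.degrees(np.arccos((trace - 1.0) / 2.0))

def pose_loss(trans1, trans2, w_rot=1.0, w_trans=1.0, cut=100.0, max_loss=1e7):
    # trans1/trans2 are 4x4 (or 3x4) camera transformation matrices.
    rot_err = angular_distance_deg(trans1[:3, :3], trans2[:3, :3])
    t_err = np.linalg.norm(trans1[:3, 3] - trans2[:3, 3])
    loss = w_rot * rot_err + w_trans * t_err
    if loss > cut:
        loss = np.sqrt(cut * loss)  # soft clamp, as in loss() above
    return min(loss, max_loss)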
/submodules/dsacstar/dsacstar_types.h:
--------------------------------------------------------------------------------
1 | /*
2 | Based on the DSAC++ and ESAC code.
3 | https://github.com/vislearn/LessMore
4 | https://github.com/vislearn/esac
5 |
6 | Copyright (c) 2016, TU Dresden
7 | Copyright (c) 2020, Heidelberg University
8 | All rights reserved.
9 |
10 | Redistribution and use in source and binary forms, with or without
11 | modification, are permitted provided that the following conditions are met:
12 | * Redistributions of source code must retain the above copyright
13 | notice, this list of conditions and the following disclaimer.
14 | * Redistributions in binary form must reproduce the above copyright
15 | notice, this list of conditions and the following disclaimer in the
16 | documentation and/or other materials provided with the distribution.
17 | * Neither the name of the TU Dresden, Heidelberg University nor the
18 | names of its contributors may be used to endorse or promote products
19 | derived from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY
25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | */
32 |
33 | #pragma once
34 |
35 | #include "opencv2/opencv.hpp"
36 |
37 | /** Several important types used throughout all this code. If types have to be changed, it can be done here, conveniently. */
38 |
39 | namespace dsacstar
40 | {
41 | // scene pose type (OpenCV convention: axis-angle + translation)
42 | typedef std::pair<cv::Mat, cv::Mat> pose_t;
43 | // camera transformation type (inverted scene pose as 4x4 matrix)
44 | typedef cv::Mat_<double> trans_t;
45 | // ATen accessor type
46 | typedef at::TensorAccessor<float, 4> coord_t;
47 | }
48 |
--------------------------------------------------------------------------------
/submodules/dsacstar/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import CppExtension, BuildExtension
3 | import os
4 |
5 | opencv_inc_dir = '' # directory containing OpenCV header files
6 | opencv_lib_dir = '' # directory containing OpenCV library files
7 |
8 | # If not explicitly provided, we try to locate OpenCV in the current Conda environment.
9 | conda_env = os.environ.get('CONDA_PREFIX', '')
10 |
11 | if len(conda_env) > 0 and len(opencv_inc_dir) == 0 and len(opencv_lib_dir) == 0:
12 | print("Detected active conda environment:", conda_env)
13 |
14 | opencv_inc_dir = conda_env + '/include/opencv4'
15 | opencv_lib_dir = conda_env + '/lib/opencv4'
16 |
17 | print("Assuming OpenCV dependencies in:")
18 | print(opencv_inc_dir)
19 | print(opencv_lib_dir)
20 |
21 | if len(opencv_inc_dir) == 0:
22 | print("Error: You have to provide an OpenCV include directory. Edit this file.")
23 | exit()
24 | if len(opencv_lib_dir) == 0:
25 | print("Error: You have to provide an OpenCV library directory. Edit this file.")
26 | exit()
27 |
28 | setup(
29 | name='dsacstar',
30 | ext_modules=[CppExtension(
31 | name='dsacstar',
32 | sources=['dsacstar.cpp','thread_rand.cpp'],
33 | include_dirs=[opencv_inc_dir],
34 | library_dirs=[opencv_lib_dir],
35 | libraries=['opencv_core','opencv_calib3d'],
36 | extra_compile_args=['-fopenmp']
37 | )],
38 | cmdclass={'build_ext': BuildExtension})
39 |
--------------------------------------------------------------------------------
/submodules/dsacstar/stop_watch.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (c) 2016, TU Dresden
3 | Copyright (c) 2017, Heidelberg University
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 | * Redistributions in binary form must reproduce the above copyright
11 | notice, this list of conditions and the following disclaimer in the
12 | documentation and/or other materials provided with the distribution.
13 | * Neither the name of the TU Dresden, Heidelberg University nor the
14 | names of its contributors may be used to endorse or promote products
15 | derived from this software without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY
21 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | */
28 |
29 |
30 | #pragma once
31 |
32 | #include <chrono>
33 |
34 | /**
35 | * @brief Class for time measurements.
36 | */
37 | class StopWatch
38 | {
39 | public:
40 | /**
41 | * @brief Construction. Initializes the stop watch.
42 | */
43 | StopWatch(){ init(); }
44 |
45 | /**
46 | * @brief Initialization. Starts the time measurement.
47 | *
48 | * @return void
49 | */
50 | void init()
51 | {
52 | start = std::chrono::high_resolution_clock::now();
53 | }
54 |
55 | /**
56 | * @brief Stops and restarts the time measurement.
57 | *
58 | * @return float The time in ms since the last init or stop call.
59 | */
60 | float stop()
61 | {
62 | std::chrono::high_resolution_clock::time_point now;
63 | now = std::chrono::high_resolution_clock::now();
64 |
65 | std::chrono::high_resolution_clock::duration duration = now - start;
66 |
67 | start = now;
68 |
69 | return static_cast<float>(
70 | 1000.0 * std::chrono::duration_cast<std::chrono::duration<double>>(
71 | duration).count());
72 | }
73 |
74 | private:
75 | std::chrono::high_resolution_clock::time_point start; // start time of the current measurement.
76 | };
77 |
--------------------------------------------------------------------------------
/submodules/dsacstar/thread_rand.cpp:
--------------------------------------------------------------------------------
1 | #include "thread_rand.h"
2 | #include <omp.h>
3 |
4 | std::vector<std::mt19937> ThreadRand::generators;
5 | bool ThreadRand::initialised = false;
6 |
7 | void ThreadRand::forceInit(unsigned seed)
8 | {
9 | initialised = false;
10 | init(seed);
11 | }
12 |
13 | void ThreadRand::init(unsigned seed)
14 | {
15 | #pragma omp critical
16 | {
17 | if(!initialised)
18 | {
19 | unsigned nThreads = omp_get_max_threads();
20 |
21 | for(unsigned i = 0; i < nThreads; i++)
22 | {
23 | generators.push_back(std::mt19937());
24 | generators[i].seed(i+seed);
25 | }
26 |
27 | initialised = true;
28 | }
29 | }
30 | }
31 |
32 | int ThreadRand::irand(int min, int max, int tid)
33 | {
34 | std::uniform_int_distribution<int> dist(min, max);
35 |
36 | unsigned threadID = omp_get_thread_num();
37 | if(tid >= 0) threadID = tid;
38 |
39 | if(!initialised) init();
40 |
41 | return dist(ThreadRand::generators[threadID]);
42 | }
43 |
44 | double ThreadRand::drand(double min, double max, int tid)
45 | {
46 | std::uniform_real_distribution<double> dist(min, max);
47 |
48 | unsigned threadID = omp_get_thread_num();
49 | if(tid >= 0) threadID = tid;
50 |
51 | if(!initialised) init();
52 |
53 | return dist(ThreadRand::generators[threadID]);
54 | }
55 |
56 | double ThreadRand::dgauss(double mean, double stdDev, int tid)
57 | {
58 | std::normal_distribution<double> dist(mean, stdDev);
59 |
60 | unsigned threadID = omp_get_thread_num();
61 | if(tid >= 0) threadID = tid;
62 |
63 | if(!initialised) init();
64 |
65 | return dist(ThreadRand::generators[threadID]);
66 | }
67 |
68 | int irand(int incMin, int excMax, int tid)
69 | {
70 | return ThreadRand::irand(incMin, excMax - 1, tid);
71 | }
72 |
73 | double drand(double incMin, double incMax,int tid)
74 | {
75 | return ThreadRand::drand(incMin, incMax, tid);
76 | }
77 |
78 | int igauss(int mean, int stdDev, int tid)
79 | {
80 | return (int) ThreadRand::dgauss(mean, stdDev, tid);
81 | }
82 |
83 | double dgauss(double mean, double stdDev, int tid)
84 | {
85 | return ThreadRand::dgauss(mean, stdDev, tid);
86 | }
--------------------------------------------------------------------------------
/submodules/dsacstar/thread_rand.h:
--------------------------------------------------------------------------------
1 | /*
2 | Based on the DSAC++ and ESAC code.
3 | https://github.com/vislearn/LessMore
4 | https://github.com/vislearn/esac
5 |
6 | Copyright (c) 2016, TU Dresden
7 | Copyright (c) 2019, Heidelberg University
8 | All rights reserved.
9 |
10 | Redistribution and use in source and binary forms, with or without
11 | modification, are permitted provided that the following conditions are met:
12 | * Redistributions of source code must retain the above copyright
13 | notice, this list of conditions and the following disclaimer.
14 | * Redistributions in binary form must reproduce the above copyright
15 | notice, this list of conditions and the following disclaimer in the
16 | documentation and/or other materials provided with the distribution.
17 | * Neither the name of the TU Dresden, Heidelberg University nor the
18 | names of its contributors may be used to endorse or promote products
19 | derived from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY
25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | */
32 |
33 | #pragma once
34 |
35 | #include <random>
36 |
37 | /** Classes and methods for generating random numbers in multi-threaded programs. */
38 |
39 | /**
40 | * @brief Provides random numbers for multiple threads.
41 | *
42 | * Singleton class. Holds a random number generator for each thread and gives random numbers for the current thread.
43 | */
44 | class ThreadRand
45 | {
46 | public:
47 | /**
48 | * @brief Returns a random integer (uniform distribution).
49 | *
50 | * @param min Minimum value of the random integer (inclusive).
51 | * @param max Maximum value of the random integer (exclusive).
52 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
53 | * @return int Random integer value.
54 | */
55 | static int irand(int min, int max, int tid = -1);
56 |
57 | /**
58 | * @brief Returns a random double value (uniform distribution).
59 | *
60 | * @param min Minimum value of the random double (inclusive).
61 | * @param max Maximum value of the random double (inclusive).
62 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
63 | * @return double Random double value.
64 | */
65 | static double drand(double min, double max, int tid = -1);
66 |
67 | /**
68 | * @brief Returns a random double value (Gauss distribution).
69 | *
70 | * @param mean Mean of the Gauss distribution to sample from.
71 | * @param stdDev Standard deviation of the Gauss distribution to sample from.
72 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
73 | * @return double Random double value.
74 | */
75 | static double dgauss(double mean, double stdDev, int tid = -1);
76 |
77 | /**
78 | * @brief Re-Initialize the object with the given seed.
79 | *
80 | * @param seed Seed to initialize the random number generators (seed is incremented by one for each generator).
81 | * @return void
82 | */
83 | static void forceInit(unsigned seed);
84 |
85 | /**
86 | * @brief List of random number generators. One for each thread.
87 | *
88 | */
89 | static std::vector<std::mt19937> generators;
90 |
91 | /**
92 | * @brief Initialize class with the given seed.
93 | *
94 | * Method will create a random number generator for each thread. The given seed
95 | * will be incremented by one for each generator. This method is automatically
96 | * called when this class is used for the first time.
97 | *
98 | * @param seed Optional parameter. Seed to be used when initializing the generators. Will be incremented by one for each generator.
99 | * @return void
100 | */
101 | static void init(unsigned seed = 1305);
102 |
103 | private:
104 | /**
105 | * @brief True if the class has been initialized already
106 | */
107 | static bool initialised;
108 | };
109 |
110 | /**
111 | * @brief Returns a random integer (uniform distribution).
112 | *
113 | * This method uses the ThreadRand class.
114 | *
115 | * @param min Minimum value of the random integer (inclusive).
116 | * @param max Maximum value of the random integer (exclusive).
117 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
118 | * @return int Random integer value.
119 | */
120 | int irand(int incMin, int excMax, int tid = -1);
121 | /**
122 | * @brief Returns a random double value (uniform distribution).
123 | *
124 | * This method uses the ThreadRand class.
125 | *
126 | * @param min Minimum value of the random double (inclusive).
127 | * @param max Maximum value of the random double (inclusive).
128 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
129 | * @return double Random double value.
130 | */
131 | double drand(double incMin, double incMax, int tid = -1);
132 |
133 | /**
134 | * @brief Returns a random integer value (Gauss distribution).
135 | *
136 | * This method uses the ThreadRand class.
137 | *
138 | * @param mean Mean of the Gauss distribution to sample from.
139 | * @param stdDev Standard deviation of the Gauss distribution to sample from.
140 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
141 | * @return int Random integer value.
142 | */
143 | int igauss(int mean, int stdDev, int tid = -1);
144 |
145 | /**
146 | * @brief Returns a random double value (Gauss distribution).
147 | *
148 | * This method uses the ThreadRand class.
149 | *
150 | * @param mean Mean of the Gauss distribution to sample from.
151 | * @param stdDev Standard deviation of the Gauss distribution to sample from.
152 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself.
153 | * @return double Random double value.
154 | */
155 | double dgauss(double mean, double stdDev, int tid = -1);
156 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=[E1101,W0621,E0401]
2 |
3 | import copy
4 | import os
5 | import warnings
6 | import logging
7 |
8 | import omegaconf
9 | from omegaconf import OmegaConf
10 |
11 | from conerf.utils.config import config_parser, load_config
12 | from conerf.utils.utils import setup_seed
13 | from utils import create_trainer # pylint: disable=E0611
14 |
15 | warnings.filterwarnings("ignore", category=UserWarning)
16 |
17 |
18 | def run_cmd(cmd: str):
19 | os.system(cmd)
20 |
21 | return True
22 |
23 |
24 | def train(config: OmegaConf):
25 | trainer = create_trainer(config)
26 | trainer.update_meta_data()
27 | trainer.train()
28 | print(f"total iteration: {trainer.iteration}")
29 |
30 |
31 | if __name__ == "__main__":
32 | args = config_parser()
33 |
34 | logging.basicConfig(
35 | format='%(asctime)s %(levelname)-6s [%(filename)s:%(lineno)d] %(message)s',
36 | datefmt='%Y-%m-%d:%H:%M:%S',
37 | level=logging.INFO
38 | )
39 |
40 | # parse YAML config to OmegaConf
41 | config = load_config(args.config)
42 | config["config_file_path"] = args.config
43 |
44 | assert config.dataset.scene != ""
45 |
46 | setup_seed(config.seed)
47 |
48 | scenes = []
49 | if (
50 | type(config.dataset.scene) == omegaconf.listconfig.ListConfig # pylint: disable=C0123
51 | ):
52 |         scene_list = list(config.dataset.scene)
53 |         for sc in scene_list:
54 | scenes.append(sc)
55 | else:
56 | scenes.append(config.dataset.scene)
57 |
58 | for scene in scenes:
59 | data_dir = os.path.join(config.dataset.root_dir, scene)
60 | assert os.path.exists(data_dir), f"Dataset does not exist: {data_dir}!"
61 |
62 | local_config = copy.deepcopy(config)
63 | local_config.expname = (
64 | f"{config.neural_field_type}_{config.task}_{config.dataset.name}_{scene}"
65 | )
66 | local_config.expname = local_config.expname + "_" + args.suffix
67 | local_config.dataset.scene = scene
68 | local_config.dataset.model_folder = args.model_folder
69 | local_config.dataset.init_ply_type = args.init_ply_type
70 | local_config.dataset.load_specified_images = args.load_specified_images
71 |
72 | train(local_config)
73 |
--------------------------------------------------------------------------------
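train.py only relies on a handful of config fields: neural_field_type, task, seed, and a dataset block whose scene entry may be a single name or a list. A minimal OmegaConf sketch with placeholder values (paths and scene names are hypothetical, not part of the repository):

from omegaconf import OmegaConf

config = OmegaConf.create({
    "neural_field_type": "ace",
    "task": "pose",                    # utils.create_trainer dispatches on this
    "seed": 42,
    "dataset": {
        "name": "llff",
        "root_dir": "/path/to/datasets/nerf_llff_data",  # hypothetical path
        "scene": ["fern", "flower"],   # a single string is also accepted
    },
})

scenes = (list(config.dataset.scene)
          if OmegaConf.is_list(config.dataset.scene)
          else [config.dataset.scene])
print(scenes)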
/utils.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 |
3 | from conerf.base.model_base import ModelBase
4 | from conerf.trainers.ace_zero_trainer import AceZeroTrainer
5 |
6 |
7 | def create_trainer(
8 | config: OmegaConf,
9 | prefetch_dataset=True,
10 | trainset=None,
11 | valset=None,
12 | model: ModelBase = None
13 | ):
14 |     """Factory function that creates a trainer for the configured task."""
15 | if config.task == "pose":
16 | trainer = AceZeroTrainer(config, prefetch_dataset, trainset, valset)
17 | else:
18 | raise NotImplementedError
19 |
20 | return trainer
21 |
--------------------------------------------------------------------------------