├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── confidence.gif ├── logo.svg ├── teaser.svg └── viewer.gif ├── notebooks ├── demo.ipynb ├── disc.png ├── plot_convergence_basin.ipynb ├── plot_damping_factors.ipynb ├── plot_initial_errors.ipynb ├── training_CMU.ipynb ├── training_MegaDepth.ipynb └── visualize_confidences.ipynb ├── pixloc ├── __init__.py ├── download.py ├── localization │ ├── __init__.py │ ├── base_refiner.py │ ├── feature_extractor.py │ ├── localizer.py │ ├── model3d.py │ ├── refiners.py │ └── tracker.py ├── pixlib │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── train_gnnet_cmu.yaml │ │ ├── train_pixloc_cmu.yaml │ │ └── train_pixloc_megadepth.yaml │ ├── datasets │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── cmu.py │ │ ├── image_folder.py │ │ ├── megadepth.py │ │ ├── sampling.py │ │ ├── train_scenes.txt │ │ ├── valid_scenes.txt │ │ └── view.py │ ├── geometry │ │ ├── __init__.py │ │ ├── check_jacobians.py │ │ ├── costs.py │ │ ├── interpolation.py │ │ ├── losses.py │ │ ├── optimization.py │ │ ├── utils.py │ │ └── wrappers.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── base_optimizer.py │ │ ├── classic_optimizer.py │ │ ├── gaussiannet.py │ │ ├── gnnet.py │ │ ├── learned_optimizer.py │ │ ├── s2dnet.py │ │ ├── two_view_refiner.py │ │ ├── unet.py │ │ └── utils.py │ ├── preprocess_cmu.py │ ├── preprocess_megadepth.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── experiments.py │ │ ├── stdout_capturing.py │ │ ├── tensor.py │ │ └── tools.py ├── run_7Scenes.py ├── run_Aachen.py ├── run_CMU.py ├── run_Cambridge.py ├── run_RobotCar.py ├── settings.py ├── utils │ ├── colmap.py │ ├── data.py │ ├── eval.py │ ├── io.py │ ├── quaternions.py │ └── tools.py └── visualization │ ├── animation.py │ ├── viz_2d.py │ └── viz_3d.py ├── requirements.txt ├── setup.py └── viewer ├── disc.png ├── dumps └── sample │ ├── dump.json │ ├── dump_p2d.json │ ├── query.jpg │ ├── ref0.jpg │ ├── ref1.jpg │ └── ref2.jpg ├── jsm ├── OrbitControls.js ├── PLYLoader.js ├── lib3d.js ├── lines │ ├── Line2.js │ ├── LineGeometry.js │ ├── LineMaterial.js │ ├── LineSegments2.js │ ├── LineSegmentsGeometry.js │ ├── Wireframe.js │ └── WireframeGeometry2.js └── three.module.js ├── jupyter.html ├── server.py ├── style.css ├── viewer.html └── viewer.js /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | outputs/ 3 | *.mp4 4 | lsf* 5 | .DS_Store 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # vscode 137 | .vscode 138 | 139 | -------------------------------------------------------------------------------- /assets/confidence.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/assets/confidence.gif -------------------------------------------------------------------------------- /assets/viewer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/assets/viewer.gif -------------------------------------------------------------------------------- /notebooks/disc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/notebooks/disc.png -------------------------------------------------------------------------------- /notebooks/training_MegaDepth.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**This notebook shows how to run the inference in the training-time two-view settings on the validation or training set of MegaDepth to visualize the training metrics and losses.**" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "\n", 19 | "import torch\n", 20 | "import numpy as np\n", 21 | "import matplotlib as mpl\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from 
omegaconf import OmegaConf\n", 24 | "\n", 25 | "from pixloc import run_Aachen\n", 26 | "from pixloc.pixlib.datasets.megadepth import MegaDepth\n", 27 | "from pixloc.pixlib.utils.tensor import batch_to_device, map_tensor\n", 28 | "from pixloc.pixlib.utils.tools import set_seed\n", 29 | "from pixloc.pixlib.utils.experiments import load_experiment\n", 30 | "from pixloc.visualization.viz_2d import (\n", 31 | " plot_images, plot_keypoints, plot_matches, cm_RdGn,\n", 32 | " features_to_RGB, add_text)\n", 33 | "\n", 34 | "torch.set_grad_enabled(False);\n", 35 | "mpl.rcParams['image.interpolation'] = 'bilinear'" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "# Create a validation or training dataloader" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 5, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "[09/24/2021 16:58:00 pixloc.pixlib.datasets.base_dataset INFO] Creating dataset MegaDepth\n", 55 | "[09/24/2021 16:58:00 pixloc.pixlib.datasets.megadepth INFO] Sampling new images or pairs with seed 1\n", 56 | " 44%|██████████████████████▉ | 34/77 [00:03<00:04, 9.24it/s][09/24/2021 16:58:04 pixloc.pixlib.datasets.megadepth WARNING] Scene 0209 does not have an info file\n", 57 | "100%|████████████████████████████████████████████████████| 77/77 [00:06<00:00, 11.66it/s]\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "conf = {\n", 63 | " 'min_overlap': 0.4,\n", 64 | " 'max_overlap': 1.0,\n", 65 | " 'max_num_points3D': 512,\n", 66 | " 'force_num_points3D': True,\n", 67 | " \n", 68 | " 'resize': 512,\n", 69 | " 'resize_by': 'min',\n", 70 | " 'crop': 512,\n", 71 | " 'optimal_crop': True,\n", 72 | " \n", 73 | " 'init_pose': [0.75, 1.],\n", 74 | "# 'init_pose': 'max_error',\n", 75 | "# 'init_pose_max_error': 4,\n", 76 | "# 'init_pose_num_samples': 50,\n", 77 | " \n", 78 | " 'batch_size': 1,\n", 79 | " 'seed': 1,\n", 80 | " 'num_workers': 0,\n", 81 | "}\n", 82 | "loader = MegaDepth(conf).get_data_loader('val', shuffle=True)\n", 83 | "orig_items = loader.dataset.items" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "# Load the training experiment" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# Name of the example experiment. 
Replace with your own training experiment.\n", 100 | "exp = run_Aachen.experiment\n", 101 | "device = 'cuda'\n", 102 | "conf = {\n", 103 | " 'optimizer': {'num_iters': 20,},\n", 104 | "}\n", 105 | "refiner = load_experiment(exp, conf).to(device)\n", 106 | "print(OmegaConf.to_yaml(refiner.conf))" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "# Run on a few examples" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "- Reference image: red/green = reprojections of 3D points not/visible in the query at the ground truth pose\n", 121 | "- Query image: red/blue/green = reprojections of 3D points at the initial/final/GT poses\n", 122 | "- ΔP/ΔR/Δt are final errors in terms of 2D reprojections, rotation, and translation" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "set_seed(7)\n", 132 | "for _, data in zip(range(5), loader):\n", 133 | " data_ = batch_to_device(data, device)\n", 134 | " pred_ = refiner(data_)\n", 135 | " pred = map_tensor(pred_, lambda x: x[0].cpu())\n", 136 | " data = map_tensor(data, lambda x: x[0].cpu())\n", 137 | " cam_q = data['query']['camera']\n", 138 | " p3D_r = data['ref']['points3D']\n", 139 | " \n", 140 | " p2D_r, valid_r = data['ref']['camera'].world2image(p3D_r)\n", 141 | " p2D_q_gt, valid_q = cam_q.world2image(data['T_r2q_gt'] * p3D_r)\n", 142 | " p2D_q_init, _ = cam_q.world2image(data['T_r2q_init'] * p3D_r)\n", 143 | " p2D_q_opt, _ = cam_q.world2image(pred['T_r2q_opt'][-1] * p3D_r)\n", 144 | " valid = valid_q & valid_r\n", 145 | " \n", 146 | " losses = refiner.loss(pred_, data_)\n", 147 | " mets = refiner.metrics(pred_, data_)\n", 148 | " errP = f\"ΔP {losses['reprojection_error/init'].item():.2f} -> {losses['reprojection_error'].item():.3f} px; \"\n", 149 | " errR = f\"ΔR {mets['R_error/init'].item():.2f} -> {mets['R_error'].item():.3f} deg; \"\n", 150 | " errt = f\"Δt {mets['t_error/init'].item():.2f} -> {mets['t_error'].item():.3f} %m\"\n", 151 | " print(errP, errR, errt)\n", 152 | "\n", 153 | " imr, imq = data['ref']['image'].permute(1, 2, 0), data['query']['image'].permute(1, 2, 0)\n", 154 | " plot_images([imr, imq],titles=[(data['scene'][0], valid_r.sum().item(), valid_q.sum().item()), errP+'; '+errR])\n", 155 | " plot_keypoints([p2D_r[valid_r], p2D_q_gt[valid]], colors=[cm_RdGn(valid[valid_r]), 'lime'])\n", 156 | " plot_keypoints([np.empty((0, 2)), p2D_q_init[valid]], colors='red')\n", 157 | " plot_keypoints([np.empty((0, 2)), p2D_q_opt[valid]], colors='blue')\n", 158 | " add_text(0, 'reference')\n", 159 | " add_text(1, 'query')\n", 160 | "\n", 161 | " continue\n", 162 | " for i, (F0, F1) in enumerate(zip(pred['ref']['feature_maps'], pred['query']['feature_maps'])):\n", 163 | " C_r, C_q = pred['ref']['confidences'][i][0], pred['query']['confidences'][i][0]\n", 164 | " plot_images([C_r, C_q], cmaps=mpl.cm.turbo)\n", 165 | " add_text(0, f'Level {i}')\n", 166 | " \n", 167 | " axes = plt.gcf().axes\n", 168 | " axes[0].imshow(imr, alpha=0.2, extent=axes[0].images[0]._extent)\n", 169 | " axes[1].imshow(imq, alpha=0.2, extent=axes[1].images[0]._extent)\n", 170 | " plot_images(features_to_RGB(F0.numpy(), F1.numpy(), skip=1))" 171 | ] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "Python 3 (ipykernel)", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | 
"name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.7.8" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 4 195 | } 196 | -------------------------------------------------------------------------------- /pixloc/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | formatter = logging.Formatter( 4 | fmt='[%(asctime)s %(name)s %(levelname)s] %(message)s', 5 | datefmt='%m/%d/%Y %H:%M:%S') 6 | handler = logging.StreamHandler() 7 | handler.setFormatter(formatter) 8 | handler.setLevel(logging.INFO) 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | logger.addHandler(handler) 13 | logger.propagate = False 14 | 15 | 16 | def set_logging_debug(mode: bool): 17 | if mode: 18 | logger.setLevel(logging.DEBUG) 19 | -------------------------------------------------------------------------------- /pixloc/localization/__init__.py: -------------------------------------------------------------------------------- 1 | from .model3d import Model3D # noqa 2 | from .localizer import PoseLocalizer, RetrievalLocalizer # noqa 3 | from .refiners import PoseRefiner, RetrievalRefiner # noqa 4 | from .tracker import SimpleTracker # noqa 5 | -------------------------------------------------------------------------------- /pixloc/localization/feature_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from omegaconf import DictConfig, OmegaConf as oc 3 | import numpy as np 4 | import torch 5 | 6 | from ..pixlib.datasets.view import resize, numpy_image_to_torch 7 | 8 | 9 | class FeatureExtractor(torch.nn.Module): 10 | default_conf: Dict = dict( 11 | resize=1024, 12 | resize_by='max', 13 | ) 14 | 15 | def __init__(self, model: torch.nn.Module, device: torch.device, 16 | conf: Union[Dict, DictConfig]): 17 | super().__init__() 18 | self.conf = oc.merge(oc.create(self.default_conf), oc.create(conf)) 19 | self.device = device 20 | self.model = model 21 | 22 | assert hasattr(self.model, 'scales') 23 | assert self.conf.resize_by in ['max', 'max_force'], self.conf.resize_by 24 | self.to(device) 25 | self.eval() 26 | 27 | def prepare_input(self, image: np.array) -> torch.Tensor: 28 | return numpy_image_to_torch(image).to(self.device).unsqueeze(0) 29 | 30 | @torch.no_grad() 31 | def __call__(self, image: np.array, scale_image: int = 1): 32 | """Extract feature-maps for a given image. 33 | Args: 34 | image: input image (H, W, C) 35 | """ 36 | image = image.astype(np.float32) # better for resizing 37 | scale_resize = (1., 1.) 
38 | if self.conf.resize is not None: 39 | target_size = self.conf.resize // scale_image 40 | if (max(image.shape[:2]) > target_size or 41 | self.conf.resize_by == 'max_force'): 42 | image, scale_resize = resize(image, target_size, max, 'linear') 43 | 44 | image_tensor = self.prepare_input(image) 45 | pred = self.model({'image': image_tensor}) 46 | features = pred['feature_maps'] 47 | assert len(self.model.scales) == len(features) 48 | 49 | features = [feat.squeeze(0) for feat in features] # remove batch dim 50 | confidences = pred.get('confidences') 51 | if confidences is not None: 52 | confidences = [c.squeeze(0) for c in confidences] 53 | 54 | scales = [(scale_resize[0]/s, scale_resize[1]/s) 55 | for s in self.model.scales] 56 | 57 | return features, scales, confidences 58 | -------------------------------------------------------------------------------- /pixloc/localization/localizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from typing import Optional, Dict, Tuple, Union 4 | from omegaconf import DictConfig, OmegaConf as oc 5 | from tqdm import tqdm 6 | import torch 7 | 8 | from .model3d import Model3D 9 | from .feature_extractor import FeatureExtractor 10 | from .refiners import PoseRefiner, RetrievalRefiner 11 | 12 | from ..utils.data import Paths 13 | from ..utils.io import parse_image_lists, parse_retrieval, load_hdf5 14 | from ..utils.quaternions import rotmat2qvec 15 | from ..pixlib.utils.experiments import load_experiment 16 | from ..pixlib.models import get_model 17 | from ..pixlib.geometry import Camera 18 | 19 | logger = logging.getLogger(__name__) 20 | # TODO: despite torch.no_grad in BaseModel, requires_grad flips in ref interp 21 | torch.set_grad_enabled(False) 22 | 23 | 24 | class Localizer: 25 | def __init__(self, paths: Paths, conf: Union[DictConfig, Dict], 26 | device: Optional[torch.device] = None): 27 | if device is None: 28 | if torch.cuda.is_available(): 29 | device = torch.device('cuda:0') 30 | else: 31 | device = torch.device('cpu') 32 | 33 | self.model3d = Model3D(paths.reference_sfm) 34 | cameras = parse_image_lists(paths.query_list, with_intrinsics=True) 35 | self.queries = {n: c for n, c in cameras} 36 | 37 | # Loading feature extractor and optimizer from experiment or scratch 38 | conf = oc.create(conf) 39 | conf_features = conf.features.get('conf', {}) 40 | conf_optim = conf.get('optimizer', {}) 41 | if conf.get('experiment'): 42 | pipeline = load_experiment( 43 | conf.experiment, 44 | {'extractor': conf_features, 'optimizer': conf_optim}) 45 | pipeline = pipeline.to(device) 46 | logger.debug( 47 | 'Use full pipeline from experiment %s with config:\n%s', 48 | conf.experiment, oc.to_yaml(pipeline.conf)) 49 | extractor = pipeline.extractor 50 | optimizer = pipeline.optimizer 51 | if isinstance(optimizer, torch.nn.ModuleList): 52 | optimizer = list(optimizer) 53 | else: 54 | assert 'name' in conf.features 55 | extractor = get_model(conf.features.name)(conf_features) 56 | optimizer = get_model(conf.optimizer.name)(conf_optim) 57 | 58 | self.paths = paths 59 | self.conf = conf 60 | self.device = device 61 | self.optimizer = optimizer 62 | self.extractor = FeatureExtractor( 63 | extractor, device, conf.features.get('preprocessing', {})) 64 | 65 | def run_query(self, name: str, camera: Camera): 66 | raise NotImplementedError 67 | 68 | def run_batched(self, skip: Optional[int] = None, 69 | ) -> Tuple[Dict[str, Tuple], Dict]: 70 | output_poses = {} 71 | output_logs = { 72 | 'paths': 
self.paths.asdict(), 73 | 'configuration': oc.to_yaml(self.conf), 74 | 'localization': {}, 75 | } 76 | 77 | logger.info('Starting the localization process...') 78 | query_names = list(self.queries.keys())[::skip or 1] 79 | for name in tqdm(query_names): 80 | camera = Camera.from_colmap(self.queries[name]) 81 | try: 82 | ret = self.run_query(name, camera) 83 | except RuntimeError as e: 84 | if 'CUDA out of memory' in str(e): 85 | logger.info('Out of memory') 86 | torch.cuda.empty_cache() 87 | ret = {'success': False} 88 | else: 89 | raise 90 | output_logs['localization'][name] = ret 91 | if ret['success']: 92 | R, tvec = ret['T_refined'].numpy() 93 | elif 'T_init' in ret: 94 | R, tvec = ret['T_init'].numpy() 95 | else: 96 | continue 97 | output_poses[name] = (rotmat2qvec(R), tvec) 98 | 99 | return output_poses, output_logs 100 | 101 | 102 | class RetrievalLocalizer(Localizer): 103 | def __init__(self, paths: Paths, conf: Union[DictConfig, Dict], 104 | device: Optional[torch.device] = None): 105 | super().__init__(paths, conf, device) 106 | 107 | if paths.global_descriptors is not None: 108 | global_descriptors = load_hdf5(paths.global_descriptors) 109 | else: 110 | global_descriptors = None 111 | 112 | self.refiner = RetrievalRefiner( 113 | self.device, self.optimizer, self.model3d, self.extractor, paths, 114 | self.conf.refinement, global_descriptors=global_descriptors) 115 | 116 | if paths.hloc_logs is not None: 117 | logger.info('Reading hloc logs...') 118 | with open(paths.hloc_logs, 'rb') as f: 119 | self.logs = pickle.load(f)['loc'] 120 | self.retrieval = {q: [self.model3d.dbs[i].name for i in loc['db']] 121 | for q, loc in self.logs.items()} 122 | elif paths.retrieval_pairs is not None: 123 | self.logs = None 124 | self.retrieval = parse_retrieval(paths.retrieval_pairs) 125 | else: 126 | raise ValueError 127 | 128 | def run_query(self, name: str, camera: Camera): 129 | dbs = [self.model3d.name2id[r] for r in self.retrieval[name]] 130 | loc = None if self.logs is None else self.logs[name] 131 | ret = self.refiner.refine(name, camera, dbs, loc=loc) 132 | return ret 133 | 134 | 135 | class PoseLocalizer(Localizer): 136 | def __init__(self, paths: Paths, conf: Union[DictConfig, Dict], 137 | device: Optional[torch.device] = None): 138 | super().__init__(paths, conf, device) 139 | 140 | self.refiner = PoseRefiner( 141 | device, self.optimizer, self.model3d, self.extractor, paths, 142 | self.conf.refinement) 143 | 144 | logger.info('Reading hloc logs...') 145 | with open(paths.hloc_logs, 'rb') as f: 146 | self.logs = pickle.load(f)['loc'] 147 | 148 | def run_query(self, name: str, camera: Camera): 149 | loc = self.logs[name] 150 | if loc['PnP_ret']['success']: 151 | ret = self.refiner.refine(name, camera, loc) 152 | else: 153 | ret = {'success': False} 154 | return ret 155 | -------------------------------------------------------------------------------- /pixloc/localization/model3d.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from typing import Dict, List, Optional 4 | import numpy as np 5 | 6 | from ..utils.colmap import read_model 7 | from ..utils.quaternions import weighted_pose 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class Model3D: 13 | def __init__(self, path): 14 | logger.info('Reading COLMAP model %s.', path) 15 | self.cameras, self.dbs, self.points3D = read_model(path) 16 | self.name2id = {i.name: i.id for i in self.dbs.values()} 17 | 18 | def 
covisbility_filtering(self, dbids): 19 | clusters = do_covisibility_clustering(dbids, self.dbs, self.points3D) 20 | dbids = clusters[0] 21 | return dbids 22 | 23 | def pose_approximation(self, qname, dbids, global_descriptors, alpha=8): 24 | """Described in: 25 | Benchmarking Image Retrieval for Visual Localization. 26 | Noé Pion, Martin Humenberger, Gabriela Csurka, 27 | Yohann Cabon, Torsten Sattler. 3DV 2020. 28 | """ 29 | dbs = [self.dbs[i] for i in dbids] 30 | 31 | dbdescs = np.stack([global_descriptors[im.name] for im in dbs]) 32 | qdesc = global_descriptors[qname] 33 | sim = dbdescs @ qdesc 34 | weights = sim**alpha 35 | weights /= weights.sum() 36 | 37 | tvecs = [im.tvec for im in dbs] 38 | qvecs = [im.qvec for im in dbs] 39 | return weighted_pose(tvecs, qvecs, weights) 40 | 41 | def get_dbid_to_p3dids(self, p3did_to_dbids): 42 | """Link the database images to selected 3D points.""" 43 | dbid_to_p3dids = defaultdict(list) 44 | for p3id, obs_dbids in p3did_to_dbids.items(): 45 | for obs_dbid in obs_dbids: 46 | dbid_to_p3dids[obs_dbid].append(p3id) 47 | return dict(dbid_to_p3dids) 48 | 49 | def get_p3did_to_dbids(self, dbids: List, loc: Optional[Dict] = None, 50 | inliers: Optional[List] = None, 51 | point_selection: str = 'all', 52 | min_track_length: int = 3): 53 | """Return a dictionary mapping 3D point ids to their covisible dbids. 54 | This function can use hloc sfm logs to only select inliers. 55 | Which can be further used to select top reference images / in 56 | sufficient track length selection of points. 57 | """ 58 | p3did_to_dbids = defaultdict(set) 59 | if point_selection == 'all': 60 | for dbid in dbids: 61 | p3dids = self.dbs[dbid].point3D_ids 62 | for p3did in p3dids[p3dids != -1]: 63 | p3did_to_dbids[p3did].add(dbid) 64 | elif point_selection in ['inliers', 'matched']: 65 | if loc is None: 66 | raise ValueError('"{point_selection}" point selection requires' 67 | ' localization logs.') 68 | 69 | # The given SfM model must match the localization SfM model! 
70 | for (p3did, dbidxs), inlier in zip(loc["keypoint_index_to_db"][1], 71 | inliers): 72 | if inlier or point_selection == 'matched': 73 | obs_dbids = set(loc["db"][dbidx] for dbidx in dbidxs) 74 | obs_dbids &= set(dbids) 75 | if len(obs_dbids) > 0: 76 | p3did_to_dbids[p3did] |= obs_dbids 77 | else: 78 | raise ValueError(f"{point_selection} point selection not defined.") 79 | 80 | # Filter unstable points (min track length) 81 | p3did_to_dbids = { 82 | i: v 83 | for i, v in p3did_to_dbids.items() 84 | if len(self.points3D[i].image_ids) >= min_track_length 85 | } 86 | 87 | return p3did_to_dbids 88 | 89 | def rerank_and_filter_db_images(self, dbids: List, ninl_dbs: List, 90 | num_dbs: int, min_matches_db: int = 0): 91 | """Re-rank the images by inlier count and filter invalid images.""" 92 | dbids = [dbids[i] for i in np.argsort(-ninl_dbs) 93 | if ninl_dbs[i] > min_matches_db] 94 | # Keep top num_images matched image images 95 | dbids = dbids[:num_dbs] 96 | return dbids 97 | 98 | def get_db_inliers(self, loc: Dict, dbids: List, inliers: List): 99 | """Get the number of inliers for each db.""" 100 | inliers = loc["PnP_ret"]["inliers"] 101 | dbids = loc["db"] 102 | ninl_dbs = np.zeros(len(dbids)) 103 | for (_, dbidxs), inl in zip(loc["keypoint_index_to_db"][1], inliers): 104 | if not inl: 105 | continue 106 | for dbidx in dbidxs: 107 | ninl_dbs[dbidx] += 1 108 | return ninl_dbs 109 | 110 | 111 | def do_covisibility_clustering(frame_ids, all_images, points3D): 112 | clusters = [] 113 | visited = set() 114 | 115 | for frame_id in frame_ids: 116 | # Check if already labeled 117 | if frame_id in visited: 118 | continue 119 | 120 | # New component 121 | clusters.append([]) 122 | queue = {frame_id} 123 | while len(queue): 124 | exploration_frame = queue.pop() 125 | 126 | # Already part of the component 127 | if exploration_frame in visited: 128 | continue 129 | visited.add(exploration_frame) 130 | clusters[-1].append(exploration_frame) 131 | 132 | observed = all_images[exploration_frame].point3D_ids 133 | connected_frames = set( 134 | j for i in observed if i != -1 for j in points3D[i].image_ids) 135 | connected_frames &= set(frame_ids) 136 | connected_frames -= visited 137 | queue |= connected_frames 138 | 139 | clusters = sorted(clusters, key=len, reverse=True) 140 | return clusters 141 | -------------------------------------------------------------------------------- /pixloc/localization/refiners.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Optional, List 3 | 4 | from .base_refiner import BaseRefiner 5 | from ..pixlib.geometry import Pose, Camera 6 | from ..utils.colmap import qvec2rotmat 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class PoseRefiner(BaseRefiner): 12 | default_config = dict( 13 | min_matches_total=10, 14 | ) 15 | 16 | def refine(self, qname: str, qcamera: Camera, loc: Dict) -> Dict: 17 | # Unpack initial query pose 18 | T_init = Pose.from_Rt(qvec2rotmat(loc["PnP_ret"]["qvec"]), 19 | loc["PnP_ret"]["tvec"]) 20 | fail = {'success': False, 'T_init': T_init} 21 | 22 | num_inliers = loc["PnP_ret"]["num_inliers"] 23 | if num_inliers < self.conf.min_matches_total: 24 | logger.debug(f"Too few inliers: {num_inliers}") 25 | return fail 26 | 27 | # Fetch database inlier matches count 28 | dbids = loc["db"] 29 | inliers = loc["PnP_ret"]["inliers"] 30 | ninl_dbs = self.model3d.get_db_inliers(loc, dbids, inliers) 31 | 32 | # Re-rank and filter database images 33 | dbids = 
self.model3d.rerank_and_filter_db_images( 34 | dbids, ninl_dbs, self.conf.num_dbs, self.conf.min_matches_db) 35 | 36 | # Abort if no image matches the minimum number of inliers criterion 37 | if len(dbids) == 0: 38 | logger.debug("No DB image with min num matches") 39 | return fail 40 | 41 | # Select the 3D points and collect their observations 42 | p3did_to_dbids = self.model3d.get_p3did_to_dbids( 43 | dbids, loc, inliers, self.conf.point_selection, 44 | self.conf.min_track_length) 45 | 46 | # Abort if there are not enough 3D points after filtering 47 | if len(p3did_to_dbids) < self.conf.min_points_opt: 48 | logger.debug("Not enough valid 3D points to optimize") 49 | return fail 50 | 51 | ret = self.refine_query_pose(qname, qcamera, T_init, p3did_to_dbids) 52 | ret = {**ret, 'dbids': dbids} 53 | return ret 54 | 55 | 56 | class RetrievalRefiner(BaseRefiner): 57 | default_config = dict( 58 | multiscale=None, 59 | filter_covisibility=False, 60 | do_pose_approximation=False, 61 | do_inlier_ranking=False, 62 | ) 63 | 64 | def __init__(self, *args, **kwargs): 65 | self.global_descriptors = kwargs.pop('global_descriptors', None) 66 | super().__init__(*args, **kwargs) 67 | 68 | def refine(self, qname: str, qcamera: Camera, dbids: List[int], 69 | loc: Optional[Dict] = None) -> Dict: 70 | 71 | if self.conf.do_inlier_ranking: 72 | assert loc is not None 73 | 74 | if self.conf.do_inlier_ranking and loc['PnP_ret']['success']: 75 | inliers = loc['PnP_ret']['inliers'] 76 | ninl_dbs = self.model3d.get_db_inliers(loc, dbids, inliers) 77 | dbids = self.model3d.rerank_and_filter_db_images( 78 | dbids, ninl_dbs, self.conf.num_dbs, 79 | self.conf.min_matches_db) 80 | else: 81 | assert self.conf.point_selection == 'all' 82 | dbids = dbids[:self.conf.num_dbs] 83 | if self.conf.do_pose_approximation or self.conf.filter_covisibility: 84 | dbids = self.model3d.covisbility_filtering(dbids) 85 | inliers = None 86 | 87 | if self.conf.do_pose_approximation: 88 | if self.global_descriptors is None: 89 | raise RuntimeError( 90 | 'Pose approximation requires global descriptors') 91 | Rt_init = self.model3d.pose_approximation( 92 | qname, dbids, self.global_descriptors) 93 | else: 94 | id_init = dbids[0] 95 | image_init = self.model3d.dbs[id_init] 96 | Rt_init = (image_init.qvec2rotmat(), image_init.tvec) 97 | T_init = Pose.from_Rt(*Rt_init) 98 | fail = {'success': False, 'T_init': T_init, 'dbids': dbids} 99 | 100 | p3did_to_dbids = self.model3d.get_p3did_to_dbids( 101 | dbids, loc, inliers, self.conf.point_selection, 102 | self.conf.min_track_length) 103 | 104 | # Abort if there are not enough 3D points after filtering 105 | if len(p3did_to_dbids) < self.conf.min_points_opt: 106 | logger.debug("Not enough valid 3D points to optimize") 107 | return fail 108 | 109 | ret = self.refine_query_pose(qname, qcamera, T_init, p3did_to_dbids, 110 | self.conf.multiscale) 111 | ret = {**ret, 'dbids': dbids} 112 | return ret 113 | -------------------------------------------------------------------------------- /pixloc/localization/tracker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | class BaseTracker: 5 | def __init__(self, refiner): 6 | # attach the tracker to the refiner 7 | refiner.tracker = self 8 | 9 | # attach the tracker to the optimizer(s) 10 | opts = refiner.optimizer 11 | opts = opts if isinstance(opts, (tuple, list)) else [opts] 12 | for opt in opts: 13 | opt.logging_fn = self.log_optim_iter 14 | 15 | def log_dense(self, **args): 16 
| raise NotImplementedError 17 | 18 | def log_optim_done(self, **args): 19 | raise NotImplementedError 20 | 21 | def log_optim_iter(self, **args): 22 | raise NotImplementedError 23 | 24 | 25 | class SimpleTracker(BaseTracker): 26 | def __init__(self, refiner): 27 | super().__init__(refiner) 28 | 29 | self.dense = defaultdict(dict) 30 | self.costs = [] 31 | self.T = [] 32 | self.dt = [] 33 | self.p3d = None 34 | self.p3d_ids = None 35 | self.num_iters = [] 36 | 37 | def log_dense(self, **args): 38 | feats = [f.cpu() for f in args['features']] 39 | weights = [w.cpu()[0] for w in args['weight']] 40 | data = (args['image'], feats, weights) 41 | self.dense[args['name']][args['image_scale']] = data 42 | 43 | def log_optim_done(self, **args): 44 | self.p3d = args['p3d'] 45 | self.p3d_ids = args['p3d_ids'] 46 | 47 | def log_optim_iter(self, **args): 48 | if args['i'] == 0: # new scale or level 49 | self.costs.append([]) 50 | self.T.append(args['T_init'].cpu()) 51 | self.num_iters.append(None) 52 | 53 | valid = args['valid'].float() 54 | cost = ((valid*args['cost']).sum(-1)/valid.sum(-1)) 55 | 56 | self.costs[-1].append(cost.cpu().numpy()) 57 | self.dt.append(args['T_delta'].magnitude()[1].cpu().numpy()) 58 | self.num_iters[-1] = args['i']+1 59 | self.T.append(args['T'].cpu()) 60 | -------------------------------------------------------------------------------- /pixloc/pixlib/README.md: -------------------------------------------------------------------------------- 1 | # PixLib - training library 2 | 3 | `pixlib` is built on top of a framework whose core principles are: 4 | 5 | - modularity: it is easy to add a new dataset or model with custom loss and metrics; 6 | - reusability: components like geometric primitives, training loop, or experiment tools are reused across projects; 7 | - reproducibility: a training run is parametrized by a configuration, which is saved and reused for evaluation; 8 | - simplicity: it has few external dependencies, and can be easily grasped by a new user. 9 | 10 | ## Framework structure 11 | `pixlib` includes of the following components: 12 | - [`datasets/`](./datasets) contains the dataloaders, all inherited from [`BaseDataset`](./datasets/base_dataset.py). Each loader is configurable and produces a set of batched data dictionaries. 13 | - [`models/`](./models) contains the deep networks and learned blocks, all inherited from [`BaseModel`](./models/base_model.py). Each model is configurable, takes as input data, and outputs predictions. It also exposes its own loss and evaluation metrics. 14 | - [`geometry/`](pixlib/geometry) groups Numpy/PyTorch primitives for 3D vision: poses and camera models, linear algebra, optimization, etc. 15 | - [`utils/`](./utils) contains various utilities, for example to [manage experiments](./utils/experiments.py). 16 | 17 | Datasets, models, and training runs are parametrized by [omegaconf](https://github.com/omry/omegaconf) configurations. See examples of training configurations in [`configs/`](./configs/) as `.yaml` files. 18 | 19 | ## Workflow 20 |
21 | **Training:**
22 | 23 | The following command starts a new training run: 24 | ```bash 25 | python3 -m pixloc.pixlib.train experiment_name \ 26 | --conf pixloc/pixlib/configs/config_name.yaml 27 | ``` 28 | 29 | It creates a new directory `experiment_name/` in `TRAINING_PATH` and dumps the configuration, model checkpoints, logs of stdout, and [Tensorboard](https://pytorch.org/docs/stable/tensorboard.html) summaries. 30 | 31 | Extra flags can be given: 32 | 33 | - `--overfit` loops the training and validation sets on a single batch ([useful to test losses and metrics](http://karpathy.github.io/2019/04/25/recipe/)). 34 | - `--restore` restarts the training from the last checkpoint (last epoch) of the same experiment. 35 | - `--distributed` uses all GPUs available with multiple processes and batch norm synchronization. 36 | - individual configuration entries to overwrite the YAML entries. Examples: `train.lr=0.001` or `data.batch_size=8`. 37 | 38 | **Monitoring the training:** Launch a Tensorboard session with `tensorboard --logdir=path/to/TRAINING_PATH` to visualize losses and metrics, and compare them across experiments. Press `Ctrl+C` to gracefully interrupt the training. 39 |
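For instance, an interrupted run can be resumed with a slightly different configuration. The command below is only a sketch: the experiment name is a placeholder, and the `key=value` overrides are assumed to be appended after the flags, as in the examples above.

```bash
# Resume a previous (hypothetical) run and overwrite two YAML entries from the command line.
python3 -m pixloc.pixlib.train pixloc_cmu_run1 \
    --conf pixloc/pixlib/configs/train_pixloc_cmu.yaml \
    --restore \
    train.lr=0.001 data.batch_size=8
```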
40 | 41 |
42 | **Inference with a trained model:**
43 | 44 | After training, you can easily load a model to evaluate it: 45 | ```python 46 | from pixloc.pixlib.utils.experiments import load_experiment 47 | 48 | test_conf = {} # will overwrite the training and default configurations 49 | model = load_experiment('name_of_my_experiment', test_conf) 50 | model = model.eval().cuda() # optionally move the model to GPU 51 | predictions = model(data) # data is a dictionary of tensors 52 | ``` 53 | 54 |
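The `data` dictionary can come from any of the `pixlib` dataloaders. Following the MegaDepth notebook, a sketch of evaluating one validation batch could look like this (the dataset configuration is abbreviated, `'name_of_my_experiment'` is a placeholder, and the MegaDepth images and info files must be available under `DATA_PATH`):

```python
from pixloc.pixlib.datasets.megadepth import MegaDepth
from pixloc.pixlib.utils.experiments import load_experiment
from pixloc.pixlib.utils.tensor import batch_to_device

# Two-view validation loader; an init_pose strategy is required by the MegaDepth dataset.
loader = MegaDepth({'batch_size': 1, 'num_workers': 0,
                    'init_pose': [0.75, 1.0]}).get_data_loader('val', shuffle=True)
model = load_experiment('name_of_my_experiment', {}).eval().cuda()

for data in loader:
    pred = model(batch_to_device(data, 'cuda'))  # predictions for one two-view batch
    break
```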
55 | 56 |
57 | **Adding new datasets or models:**
58 | 59 | We simply need to create a new file in [`datasets/`](./datasets/) or [`models/`](./models/). This makes it easy to collaborate on the same codebase. Each class should inherit from the base class, declare a `default_conf`, and define some specific methods. Have a look at the base files [`BaseDataset`](./datasets/base_dataset.py) and [`BaseModel`](./models/base_model.py) for more details. Please follow [PEP 8](https://www.python.org/dev/peps/pep-0008/) and use relative imports. 60 | 61 |
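As an illustration, a minimal dataset dropped into [`datasets/`](./datasets/) could look like the sketch below. The file and class names (`toy.py`, `Toy`) and the configuration keys are placeholders; the structure mirrors [`image_folder.py`](./datasets/image_folder.py): subclass `BaseDataset`, declare a `default_conf`, and implement `_init` and `get_dataset`.

```python
import numpy as np
import torch

from .base_dataset import BaseDataset


class Toy(BaseDataset, torch.utils.data.Dataset):
    """Dummy dataset returning black images, only meant to show the interface."""
    default_conf = {
        'num_items': 8,      # number of dummy samples
        'image_size': 128,   # side length of the square dummy images
    }

    def _init(self, conf):
        self.items = list(range(conf.num_items))

    def get_dataset(self, split):
        # A real dataset would typically return a split-specific torch Dataset.
        return self

    def __getitem__(self, idx):
        size = self.conf.image_size
        return {
            'name': f'dummy/{idx}',
            'image': torch.zeros((3, size, size), dtype=torch.float32),
            'original_image_size': np.array([size, size]),
        }

    def __len__(self):
        return len(self.items)
```

Datasets are then looked up by module name through `get_dataset` in [`datasets/__init__.py`](./datasets/__init__.py), so the file name is what goes into the `data.name` configuration entry.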
62 | -------------------------------------------------------------------------------- /pixloc/pixlib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/pixloc/pixlib/__init__.py -------------------------------------------------------------------------------- /pixloc/pixlib/configs/train_gnnet_cmu.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: cmu 3 | min_overlap: 0.3 4 | max_overlap: 1.0 5 | max_num_points3D: 512 6 | force_num_points3D: true 7 | max_baseline: 7.0 8 | resize: 512 9 | resize_by: min 10 | crop: 512 11 | optimal_crop: false 12 | batch_size: 4 13 | num_workers: 6 14 | seed: 1 15 | model: 16 | name: gnnet 17 | extractor: 18 | name: unet 19 | encoder: vgg16 20 | decoder: [64, 64, 64, 32] 21 | output_scales: [0, 2, 4] 22 | output_dim: [32, 128, 128] 23 | freeze_batch_normalization: false 24 | do_average_pooling: false 25 | compute_uncertainty: false 26 | checkpointed: true 27 | optimizer: 28 | num_iters: 15 29 | pad: 3 30 | lambda_: 0.01 31 | verbose: false 32 | loss_fn: scaled_barron(0, 0.1) 33 | jacobi_scaling: false 34 | normalize_features: true 35 | loss: 36 | margin_positive: 0.2 37 | margin_negative: 1 38 | num_top_negative_sampling: 200 39 | gauss_newton_magnitude: 1.0 40 | gauss_newton_weight: 0.1 41 | contrastive_weight: 1 42 | train: 43 | seed: 0 44 | epochs: 200 45 | log_every_iter: 50 46 | eval_every_iter: 500 47 | dataset_callback_fn: sample_new_items 48 | lr: 1.0e-06 49 | median_metrics: 50 | - loss/reprojection_error 51 | - loss/reprojection_error/init 52 | - R_error 53 | - t_error 54 | -------------------------------------------------------------------------------- /pixloc/pixlib/configs/train_pixloc_cmu.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: cmu 3 | min_overlap: 0.3 4 | max_overlap: 1.0 5 | max_num_points3D: 512 6 | force_num_points3D: true 7 | max_baseline: 7.0 8 | resize: 720 9 | resize_by: min 10 | crop: 720 11 | optimal_crop: false 12 | batch_size: 3 13 | num_workers: 6 14 | seed: 1 15 | model: 16 | name: two_view_refiner 17 | success_thresh: 3 18 | normalize_features: true 19 | duplicate_optimizer_per_scale: true 20 | normalize_dt: false 21 | extractor: 22 | name: unet 23 | encoder: vgg16 24 | decoder: [64, 64, 64, 32] 25 | output_scales: [0, 2, 4] 26 | output_dim: [32, 128, 128] 27 | freeze_batch_normalization: false 28 | do_average_pooling: false 29 | compute_uncertainty: true 30 | checkpointed: true 31 | optimizer: 32 | name: learned_optimizer 33 | num_iters: 15 34 | pad: 3 35 | lambda_: 0.01 36 | verbose: false 37 | loss_fn: scaled_barron(0, 0.1) 38 | jacobi_scaling: false 39 | learned_damping: true 40 | damping: 41 | type: constant 42 | train: 43 | seed: 0 44 | epochs: 200 45 | log_every_iter: 50 46 | eval_every_iter: 500 47 | dataset_callback_fn: sample_new_items 48 | lr: 1.0e-05 49 | lr_scaling: [[100, ['dampingnet.const']]] 50 | median_metrics: 51 | - loss/reprojection_error 52 | - loss/reprojection_error/init 53 | clip_grad: 1.0 54 | -------------------------------------------------------------------------------- /pixloc/pixlib/configs/train_pixloc_megadepth.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: megadepth 3 | min_overlap: 0.4 4 | max_overlap: 1.0 5 | max_num_points3D: 512 6 | force_num_points3D: true 7 | resize_by: min 8 | resize: 
512 9 | crop: 512 10 | optimal_crop: false 11 | init_pose: [0.75, 1.0] 12 | train_num_per_scene: 150 13 | val_num_per_scene: 8 14 | batch_size: 6 15 | num_workers: 8 16 | seed: 1 17 | model: 18 | name: two_view_refiner 19 | success_thresh: 3 20 | normalize_features: true 21 | duplicate_optimizer_per_scale: true 22 | extractor: 23 | name: unet 24 | encoder: vgg19 25 | decoder: [64, 64, 64, 32] 26 | output_scales: [0, 2, 4] 27 | output_dim: [32, 128, 128] 28 | freeze_batch_normalization: false 29 | do_average_pooling: false 30 | compute_uncertainty: true 31 | checkpointed: true 32 | optimizer: 33 | name: learned_optimizer 34 | num_iters: 15 35 | pad: 2 36 | lambda_: 0.01 37 | verbose: false 38 | loss_fn: scaled_barron(0, 0.1) 39 | jacobi_scaling: false 40 | learned_damping: true 41 | damping: 42 | type: constant 43 | train: 44 | seed: 0 45 | epochs: 200 46 | log_every_iter: 50 47 | eval_every_iter: 500 48 | dataset_callback_fn: sample_new_items 49 | lr: 5.0e-06 50 | lr_scaling: [[100, ['dampingnet.const']]] 51 | median_metrics: 52 | - loss/reprojection_error 53 | - loss/reprojection_error/init 54 | - R_error 55 | - t_error 56 | clip_grad: 1.0 57 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils.tools import get_class 2 | from .base_dataset import BaseDataset 3 | 4 | 5 | def get_dataset(name): 6 | return get_class(name, __name__, BaseDataset) 7 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/cmu.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | import numpy as np 4 | import logging 5 | import torch 6 | import pickle 7 | 8 | from .base_dataset import BaseDataset 9 | from .view import read_view 10 | from ..geometry import Camera, Pose 11 | from ...settings import DATA_PATH 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | CAMERAS = '''c0 OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571 17 | c1 OPENCV 1024 768 873.382641 876.489513 529.324138 397.272397 -0.397066 0.181925 0.000176 -0.000579''' 18 | 19 | 20 | class CMU(BaseDataset): 21 | default_conf = { 22 | 'dataset_dir': 'CMU/', 23 | 'info_dir': 'cmu_pixloc_training/', 24 | 25 | 'train_slices': [8, 9, 10, 11, 12, 22, 23, 24, 25], 26 | 'val_slices': [6, 13, 21], 27 | 'train_num_per_slice': 1000, 28 | 'val_num_per_slice': 80, 29 | 30 | 'two_view': True, 31 | 'min_overlap': 0.3, 32 | 'max_overlap': 1., 33 | 'min_baseline': None, 34 | 'max_baseline': None, 35 | 'sort_by_overlap': False, 36 | 37 | 'grayscale': False, 38 | 'resize': None, 39 | 'resize_by': 'max', 40 | 'crop': None, 41 | 'pad': None, 42 | 'optimal_crop': True, 43 | 'seed': 0, 44 | 45 | 'max_num_points3D': 512, 46 | 'force_num_points3D': False, 47 | } 48 | 49 | def _init(self, conf): 50 | pass 51 | 52 | def get_dataset(self, split): 53 | assert split != 'test', 'Not supported' 54 | return _Dataset(self.conf, split) 55 | 56 | 57 | class _Dataset(torch.utils.data.Dataset): 58 | def __init__(self, conf, split): 59 | self.root = Path(DATA_PATH, conf.dataset_dir) 60 | self.slices = conf.get(split+'_slices') 61 | self.conf, self.split = conf, split 62 | 63 | self.info = {} 64 | for slice_ in self.slices: 65 | path = Path(DATA_PATH, self.conf.info_dir, f'slice{slice_}.pkl') 66 | assert path.exists(), path 67 | with 
open(path, 'rb') as f: 68 | info = pickle.load(f) 69 | self.info[slice_] = {k: info[k] for k in info if 'matrix' not in k} 70 | 71 | self.cameras = {} 72 | for c in CAMERAS.split('\n'): 73 | data = c.split() 74 | name, camera_model, width, height = data[:4] 75 | params = np.array(data[4:], float) 76 | camera = Camera.from_colmap(dict( 77 | model=camera_model, params=params, 78 | width=int(width), height=int(height))) 79 | self.cameras[name] = camera 80 | 81 | self.sample_new_items(conf.seed) 82 | 83 | def sample_new_items(self, seed): 84 | logger.info(f'Sampling new images or pairs with seed {seed}') 85 | self.items = [] 86 | for slice_ in tqdm(self.slices): 87 | num = self.conf[self.split+'_num_per_slice'] 88 | 89 | if self.conf.two_view: 90 | path = Path( 91 | DATA_PATH, self.conf.info_dir, f'slice{slice_}.pkl') 92 | assert path.exists(), path 93 | with open(path, 'rb') as f: 94 | info = pickle.load(f) 95 | 96 | mat = info['query_overlap_matrix'] 97 | pairs = ( 98 | (mat > self.conf.min_overlap) 99 | & (mat <= self.conf.max_overlap)) 100 | if self.conf.min_baseline: 101 | pairs &= (info['query_to_ref_distance_matrix'] 102 | > self.conf.min_baseline) 103 | if self.conf.max_baseline: 104 | pairs &= (info['query_to_ref_distance_matrix'] 105 | < self.conf.max_baseline) 106 | pairs = np.stack(np.where(pairs), -1) 107 | if len(pairs) > num: 108 | selected = np.random.RandomState(seed).choice( 109 | len(pairs), num, replace=False) 110 | pairs = pairs[selected] 111 | pairs = [(slice_, i, j, mat[i, j]) for i, j in pairs] 112 | self.items.extend(pairs) 113 | else: 114 | ids = np.arange(len(self.images[slice_])) 115 | if len(ids) > num: 116 | ids = np.random.RandomState(seed).choice( 117 | ids, num, replace=False) 118 | ids = [(slice_, i) for i in ids] 119 | self.items.extend(ids) 120 | 121 | if self.conf.two_view and self.conf.sort_by_overlap: 122 | self.items.sort(key=lambda i: i[-1], reverse=True) 123 | else: 124 | np.random.RandomState(seed).shuffle(self.items) 125 | 126 | def _read_view(self, slice_, idx, common_p3D_idx, is_reference=False): 127 | prefix = 'ref' if is_reference else 'query' 128 | path = self.root / f'slice{slice_}/' 129 | path /= 'database' if is_reference else 'query' 130 | path /= self.info[slice_][f'{prefix}_image_names'][idx] 131 | 132 | camera = self.cameras[path.name.split('_')[2]] 133 | R, t = self.info[slice_][f'{prefix}_poses'][idx] 134 | T = Pose.from_Rt(R, t) 135 | p3D = self.info[slice_]['points3D'] 136 | data = read_view(self.conf, path, camera, T, p3D, common_p3D_idx, 137 | random=(self.split == 'train')) 138 | data['index'] = idx 139 | assert (tuple(data['camera'].size.numpy()) 140 | == data['image'].shape[1:][::-1]) 141 | 142 | if is_reference: 143 | obs = self.info[slice_]['p3D_observed'][idx] 144 | if self.conf.crop: 145 | _, valid = data['camera'].world2image(data['T_w2cam']*p3D[obs]) 146 | obs = obs[valid.numpy()] 147 | num_diff = self.conf.max_num_points3D - len(obs) 148 | if num_diff < 0: 149 | obs = np.random.choice(obs, self.conf.max_num_points3D) 150 | elif num_diff > 0 and self.conf.force_num_points3D: 151 | add = np.random.choice( 152 | np.delete(np.arange(len(p3D)), obs), num_diff) 153 | obs = np.r_[obs, add] 154 | data['points3D'] = data['T_w2cam'] * p3D[obs] 155 | return data 156 | 157 | def __getitem__(self, idx): 158 | if self.conf.two_view: 159 | slice_, idx_q, idx_r, overlap = self.items[idx] 160 | obs_r = self.info[slice_]['p3D_observed'][idx_r] 161 | obs_q = self.info[slice_]['p3D_observed'][ 162 | 
self.info[slice_]['query_closest_indices'][idx_q]] 163 | common = np.array(list(set(obs_r) & set(obs_q))) 164 | 165 | data_r = self._read_view(slice_, idx_r, common, is_reference=True) 166 | data_q = self._read_view(slice_, idx_q, common) 167 | data = { 168 | 'ref': data_r, 169 | 'query': data_q, 170 | 'overlap': overlap, 171 | 'T_r2q_init': Pose.from_4x4mat(np.eye(4, dtype=np.float32)), 172 | 'T_r2q_gt': data_q['T_w2cam'] @ data_r['T_w2cam'].inv(), 173 | } 174 | else: 175 | slice_, idx = self.items[idx] 176 | data = self._read_view(slice_, idx, is_reference=True) 177 | data['scene'] = slice_ 178 | return data 179 | 180 | def __len__(self): 181 | return len(self.items) 182 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simply load images from a folder or nested folders (does not have any split). 3 | """ 4 | 5 | from pathlib import Path 6 | import torch 7 | import cv2 8 | import numpy as np 9 | import logging 10 | import omegaconf 11 | 12 | from .base_dataset import BaseDataset 13 | from .utils.preprocessing import resize, numpy_image_to_torch 14 | 15 | 16 | class ImageFolder(BaseDataset, torch.utils.data.Dataset): 17 | default_conf = { 18 | 'glob': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'], 19 | 'grayscale': False, 20 | 'images': '???', 21 | 'resize': None, 22 | 'resize_by': 'max', 23 | 'interpolation': 'linear', 24 | 'root_folder': '/', 25 | } 26 | 27 | def _init(self, conf): 28 | self.root = conf.root_folder 29 | if isinstance(conf.images, str): 30 | if not Path(conf.images).is_dir(): 31 | with open(conf.images, 'r') as f: 32 | self.images = f.read().rstrip('\n').split('\n') 33 | logging.info(f'Found {len(self.images)} images in list file.') 34 | else: 35 | self.images = [] 36 | glob = [conf.glob] if isinstance(conf.glob, str) else conf.glob 37 | for g in glob: 38 | self.images += list(Path(conf.images).glob('**/'+g)) 39 | if len(self.images) == 0: 40 | raise ValueError( 41 | f'Could not find any image in folder: {conf.images}.') 42 | self.images = [i.relative_to(conf.images) for i in self.images] 43 | self.root = conf.images 44 | logging.info(f'Found {len(self.images)} images in folder.') 45 | elif isinstance(conf.images, omegaconf.listconfig.ListConfig): 46 | self.images = conf.images.to_container() 47 | else: 48 | raise ValueError(conf.images) 49 | 50 | def get_dataset(self, split): 51 | return self 52 | 53 | def __getitem__(self, idx): 54 | path = self.images[idx] 55 | if self.conf.grayscale: 56 | mode = cv2.IMREAD_GRAYSCALE 57 | else: 58 | mode = cv2.IMREAD_COLOR 59 | img = cv2.imread(str(Path(self.root, path)), mode) 60 | if img is None: 61 | logging.warning(f'Image {str(path)} could not be read.') 62 | img = np.zeros((1024, 1024)+(() if self.conf.grayscale else (3,))) 63 | img = img.astype(np.float32) 64 | size = img.shape[:2][::-1] 65 | 66 | if self.conf.resize: 67 | args = {'interp': self.conf.interpolation} 68 | h, w = img.shape[:2] 69 | if self.conf.resize_by in ['max', 'force-max']: 70 | if ((self.conf.resize_by == 'force-max') or 71 | (max(h, w) > self.conf.resize)): 72 | img, _ = resize(img, self.conf.resize, fn=max, **args) 73 | elif self.conf.resize_by == 'min': 74 | if min(h, w) < self.conf.resize: 75 | img, _ = resize(img, self.conf.resize, fn=min, **args) 76 | else: 77 | img, _ = resize(img, self.conf.resize, **args) 78 | 79 | data = { 80 | 'name': str(path), 81 | 'image': 
numpy_image_to_torch(img), 82 | 'original_image_size': np.array(size), 83 | } 84 | return data 85 | 86 | def __len__(self): 87 | return len(self.images) 88 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/megadepth.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import collections 3 | from tqdm import tqdm 4 | import numpy as np 5 | import logging 6 | import torch 7 | import pickle 8 | 9 | from .base_dataset import BaseDataset 10 | from .view import read_view 11 | from .sampling import sample_pose_interval, sample_pose_reprojection 12 | from ..geometry import Camera, Pose 13 | from ...settings import DATA_PATH 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class MegaDepth(BaseDataset): 19 | default_conf = { 20 | 'dataset_dir': 'megadepth/', 21 | 'depth_subpath': 'phoenix/S6/zl548/MegaDepth_v1/{}/dense0/depths/', 22 | 'image_subpath': 'Undistorted_SfM/{}/images/', 23 | 'info_dir': 'megadepth_pixloc_training/', 24 | 25 | 'train_split': 'train_scenes.txt', 26 | 'val_split': 'valid_scenes.txt', 27 | 'train_num_per_scene': 500, 28 | 'val_num_per_scene': 10, 29 | 30 | 'two_view': True, 31 | 'min_overlap': 0.3, 32 | 'max_overlap': 1., 33 | 'sort_by_overlap': False, 34 | 'init_pose': None, 35 | 'init_pose_max_error': 63, 36 | 'init_pose_num_samples': 20, 37 | 38 | 'read_depth': False, 39 | 'grayscale': False, 40 | 'resize': None, 41 | 'resize_by': 'max', 42 | 'crop': None, 43 | 'pad': None, 44 | 'optimal_crop': True, 45 | 'seed': 0, 46 | 47 | 'max_num_points3D': 500, 48 | 'force_num_points3D': False, 49 | } 50 | 51 | def _init(self, conf): 52 | pass 53 | 54 | def get_dataset(self, split): 55 | assert split != 'test', 'Not supported' 56 | return _Dataset(self.conf, split) 57 | 58 | 59 | class _Dataset(torch.utils.data.Dataset): 60 | def __init__(self, conf, split): 61 | if conf.init_pose is None: 62 | raise ValueError('The initial pose sampling strategy is required.') 63 | 64 | self.root = Path(DATA_PATH, conf.dataset_dir) 65 | with open(Path(__file__).parent / conf[split+'_split'], 'r') as f: 66 | self.scenes = f.read().split() 67 | self.conf, self.split = conf, split 68 | 69 | self.sample_new_items(conf.seed) 70 | 71 | def sample_new_items(self, seed): 72 | logger.info(f'Sampling new images or pairs with seed {seed}') 73 | self.images, self.poses, self.intrinsics = {}, {}, {} 74 | self.rotations, self.points3D, self.p3D_observed = {}, {}, {} 75 | self.items = [] 76 | for scene in tqdm(self.scenes): 77 | path = Path(DATA_PATH, self.conf.info_dir, scene + '.pkl') 78 | if not path.exists(): 79 | logger.warning(f'Scene {scene} does not have an info file') 80 | continue 81 | with open(path, 'rb') as f: 82 | info = pickle.load(f) 83 | num = self.conf[self.split+'_num_per_scene'] 84 | 85 | self.images[scene] = info['image_names'] 86 | self.rotations[scene] = info['rotations'] 87 | self.points3D[scene] = info['points3D'] 88 | self.p3D_observed[scene] = info['p3D_observed'] 89 | self.poses[scene] = info['poses'] 90 | self.intrinsics[scene] = info['intrinsics'] 91 | 92 | if self.conf.two_view: 93 | mat = info['overlap_matrix'] 94 | pairs = ( 95 | (mat > self.conf.min_overlap) 96 | & (mat <= self.conf.max_overlap)) 97 | pairs = np.stack(np.where(pairs), -1) 98 | if len(pairs) > num: 99 | selected = np.random.RandomState(seed).choice( 100 | len(pairs), num, replace=False) 101 | pairs = pairs[selected] 102 | pairs = [(scene, i, j, mat[i, j]) for i, j in pairs] 103 | 
self.items.extend(pairs) 104 | else: 105 | ids = np.arange(len(self.images[scene])) 106 | if len(ids) > num: 107 | ids = np.random.RandomState(seed).choice( 108 | ids, num, replace=False) 109 | ids = [(scene, i) for i in ids] 110 | self.items.extend(ids) 111 | 112 | if self.conf.two_view and self.conf.sort_by_overlap: 113 | self.items.sort(key=lambda i: i[-1], reverse=True) 114 | else: 115 | np.random.RandomState(seed).shuffle(self.items) 116 | 117 | def _read_view(self, scene, idx, common_p3D_idx, is_reference=False): 118 | path = self.root / self.conf.image_subpath.format(scene) 119 | path /= self.images[scene][idx] 120 | 121 | if self.conf.read_depth: 122 | raise NotImplementedError 123 | 124 | K = self.intrinsics[scene][idx] 125 | camera = Camera.from_colmap(dict( 126 | model='PINHOLE', width=K[0, 2]*2, height=K[1, 2]*2, 127 | params=K[[0, 1, 0, 1], [0, 1, 2, 2]])) 128 | T = Pose.from_Rt(*self.poses[scene][idx]) 129 | rotation = self.rotations[scene][idx] 130 | p3D = self.points3D[scene] 131 | data = read_view(self.conf, path, camera, T, p3D, common_p3D_idx, 132 | rotation=rotation, random=(self.split == 'train')) 133 | data['index'] = idx 134 | assert (tuple(data['camera'].size.numpy()) 135 | == data['image'].shape[1:][::-1]) 136 | 137 | if is_reference: 138 | obs = self.p3D_observed[scene][idx] 139 | if self.conf.crop: 140 | _, valid = data['camera'].world2image(data['T_w2cam']*p3D[obs]) 141 | obs = obs[valid.numpy()] 142 | num_diff = self.conf.max_num_points3D - len(obs) 143 | if num_diff < 0: 144 | obs = np.random.choice(obs, self.conf.max_num_points3D) 145 | elif num_diff > 0 and self.conf.force_num_points3D: 146 | add = np.random.choice( 147 | np.delete(np.arange(len(p3D)), obs), num_diff) 148 | obs = np.r_[obs, add] 149 | data['points3D'] = data['T_w2cam'] * p3D[obs] 150 | return data 151 | 152 | def __getitem__(self, idx): 153 | if self.conf.two_view: 154 | scene, idx_r, idx_q, overlap = self.items[idx] 155 | common = np.array(list(set(self.p3D_observed[scene][idx_r]) 156 | & set(self.p3D_observed[scene][idx_q]))) 157 | 158 | data_r = self._read_view(scene, idx_r, common, is_reference=True) 159 | data_q = self._read_view(scene, idx_q, common) 160 | data = { 161 | 'ref': data_r, 162 | 'query': data_q, 163 | 'overlap': overlap, 164 | 'T_r2q_gt': data_q['T_w2cam'] @ data_r['T_w2cam'].inv(), 165 | } 166 | 167 | if self.conf.init_pose == 'identity': 168 | T_init = Pose.from_4x4mat(np.eye(4)) 169 | elif self.conf.init_pose == 'max_error': 170 | T_init = sample_pose_reprojection( 171 | data['T_r2q_gt'], data_q['camera'], data_r['points3D'], 172 | self.conf.seed+idx, self.conf.init_pose_num_samples, 173 | self.conf.init_pose_max_error) 174 | elif isinstance(self.conf.init_pose, collections.abc.Sequence): 175 | T_init = sample_pose_interval( 176 | data['T_r2q_gt'], self.conf.init_pose, self.conf.seed+idx) 177 | else: 178 | raise ValueError(self.conf.init_pose) 179 | data['T_r2q_init'] = T_init 180 | else: 181 | scene, idx = self.items[idx] 182 | data = self._read_view(scene, idx, is_reference=True) 183 | data['scene'] = scene 184 | return data 185 | 186 | def __len__(self): 187 | return len(self.items) 188 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/sampling.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | import torch 3 | import numpy as np 4 | import cv2 5 | 6 | from ..geometry import Pose, Camera 7 | 8 | 9 | def sample_pose_reprojection( 10 | 
T_r2q: Pose, camera: Camera, p3D_r: np.ndarray, seed: int, 11 | num_samples: int, max_err: Union[int, float, Tuple[int, float]], 12 | min_vis: int = 10): 13 | 14 | R0, t0 = T_r2q.R, T_r2q.t 15 | w0 = cv2.Rodrigues(R0.numpy())[0][:, 0] 16 | 17 | s = torch.linspace(0, 1, num_samples+1)[:, None] 18 | Ts = Pose.from_aa(torch.from_numpy(w0)[None] * s, t0[None] * s) 19 | 20 | p2Ds, vis = camera.world2image(Ts * p3D_r) 21 | p2Ds, vis = p2Ds.numpy(), vis.numpy() 22 | 23 | p2D0, vis0 = p2Ds[-1], vis[-1] 24 | err = np.linalg.norm(p2Ds - p2D0, axis=-1) 25 | err = np.where(vis & vis0, err, np.nan) 26 | valid = ~np.all(np.isnan(err), -1) 27 | err = np.where(valid[:, None], err, np.inf) 28 | err = np.nanmedian(err, -1) 29 | nvis = np.sum(vis & vis0, -1) 30 | 31 | if not isinstance(max_err, (int, float)): 32 | max_err = np.random.RandomState(seed).uniform(*max_err) 33 | valid = (nvis >= min_vis) & (err < max_err) 34 | if valid.any(): 35 | idx = np.where(valid)[0][0] 36 | else: 37 | idx = -1 38 | return Ts[idx] 39 | 40 | 41 | def sample_pose_interval(T_r2q: Pose, interval: Tuple[float], seed: int): 42 | a = np.random.RandomState(seed).uniform(*interval) 43 | R, t = T_r2q.numpy() 44 | t = t * a 45 | w = cv2.Rodrigues(R)[0][:, 0] * a 46 | T = Pose.from_Rt(cv2.Rodrigues(w)[0], t) 47 | return T 48 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/train_scenes.txt: -------------------------------------------------------------------------------- 1 | 0000 2 | 0001 3 | 0002 4 | 0003 5 | 0004 6 | 0005 7 | 0007 8 | 0008 9 | 0011 10 | 0012 11 | 0013 12 | 0015 13 | 0017 14 | 0019 15 | 0020 16 | 0021 17 | 0022 18 | 0023 19 | 0024 20 | 0025 21 | 0026 22 | 0027 23 | 0032 24 | 0035 25 | 0036 26 | 0037 27 | 0039 28 | 0042 29 | 0043 30 | 0046 31 | 0048 32 | 0050 33 | 0056 34 | 0057 35 | 0060 36 | 0061 37 | 0063 38 | 0065 39 | 0070 40 | 0080 41 | 0083 42 | 0086 43 | 0087 44 | 0092 45 | 0095 46 | 0098 47 | 0100 48 | 0101 49 | 0103 50 | 0104 51 | 0105 52 | 0107 53 | 0115 54 | 0117 55 | 0122 56 | 0130 57 | 0137 58 | 0143 59 | 0147 60 | 0148 61 | 0149 62 | 0150 63 | 0156 64 | 0160 65 | 0176 66 | 0183 67 | 0189 68 | 0190 69 | 0200 70 | 0214 71 | 0224 72 | 0235 73 | 0237 74 | 0240 75 | 0243 76 | 0258 77 | 0265 78 | 0269 79 | 0299 80 | 0312 81 | 0326 82 | 0327 83 | 0331 84 | 0335 85 | 0341 86 | 0348 87 | 0366 88 | 0377 89 | 0380 90 | 0394 91 | 0407 92 | 0411 93 | 0430 94 | 0446 95 | 0455 96 | 0472 97 | 0474 98 | 0476 99 | 0478 100 | 0493 101 | 0494 102 | 0496 103 | 0505 104 | 0559 105 | 0733 106 | 0860 107 | 1017 108 | 1589 109 | 4541 110 | 5004 111 | 5005 112 | 5006 113 | 5007 114 | 5009 115 | 5010 116 | 5012 117 | 5013 118 | 5017 119 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/valid_scenes.txt: -------------------------------------------------------------------------------- 1 | 0016 2 | 0033 3 | 0034 4 | 0041 5 | 0044 6 | 0047 7 | 0049 8 | 0058 9 | 0062 10 | 0064 11 | 0067 12 | 0071 13 | 0076 14 | 0078 15 | 0090 16 | 0094 17 | 0099 18 | 0102 19 | 0121 20 | 0129 21 | 0133 22 | 0141 23 | 0151 24 | 0162 25 | 0168 26 | 0175 27 | 0177 28 | 0178 29 | 0181 30 | 0185 31 | 0186 32 | 0197 33 | 0204 34 | 0205 35 | 0209 36 | 0212 37 | 0217 38 | 0223 39 | 0229 40 | 0231 41 | 0238 42 | 0252 43 | 0257 44 | 0271 45 | 0275 46 | 0277 47 | 0281 48 | 0285 49 | 0286 50 | 0290 51 | 0294 52 | 0303 53 | 0306 54 | 0307 55 | 0323 56 | 0349 57 | 0360 58 | 0387 59 | 0389 60 | 0402 61 | 0406 62 | 0412 63 | 0443 64 | 0482 65 | 0768 66 | 
1001 67 | 3346 68 | 5000 69 | 5001 70 | 5002 71 | 5003 72 | 5008 73 | 5011 74 | 5014 75 | 5015 76 | 5016 77 | 5018 78 | -------------------------------------------------------------------------------- /pixloc/pixlib/datasets/view.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import cv2 4 | # TODO: consider using PIL instead of OpenCV as it is heavy and only used here 5 | import torch 6 | 7 | from ..geometry import Camera, Pose 8 | 9 | 10 | def numpy_image_to_torch(image): 11 | """Normalize the image tensor and reorder the dimensions.""" 12 | if image.ndim == 3: 13 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 14 | elif image.ndim == 2: 15 | image = image[None] # add channel axis 16 | else: 17 | raise ValueError(f'Not an image: {image.shape}') 18 | return torch.from_numpy(image / 255.).float() 19 | 20 | 21 | def read_image(path, grayscale=False): 22 | mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR 23 | image = cv2.imread(str(path), mode) 24 | if image is None: 25 | raise IOError(f'Could not read image at {path}.') 26 | if not grayscale: 27 | image = image[..., ::-1] 28 | return image 29 | 30 | 31 | def resize(image, size, fn=None, interp='linear'): 32 | """Resize an image to a fixed size, or according to max or min edge.""" 33 | h, w = image.shape[:2] 34 | if isinstance(size, int): 35 | scale = size / fn(h, w) 36 | h_new, w_new = int(round(h*scale)), int(round(w*scale)) 37 | # TODO: we should probably recompute the scale like in the second case 38 | scale = (scale, scale) 39 | elif isinstance(size, (tuple, list)): 40 | h_new, w_new = size 41 | scale = (w_new / w, h_new / h) 42 | else: 43 | raise ValueError(f'Incorrect new size: {size}') 44 | mode = { 45 | 'linear': cv2.INTER_LINEAR, 46 | 'cubic': cv2.INTER_CUBIC, 47 | 'nearest': cv2.INTER_NEAREST}[interp] 48 | return cv2.resize(image, (w_new, h_new), interpolation=mode), scale 49 | 50 | 51 | def crop(image, size, *, random=True, other=None, camera=None, 52 | return_bbox=False, centroid=None): 53 | """Random or deterministic crop of an image, adjust depth and intrinsics. 
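Args:
    image: the input image, as a (H, W, C) or (H, W) numpy array.
    size: the side of a square crop (int) or the target (h_new, w_new) shape.
    random: pick the top-left corner uniformly at random.
    other: an optional second map (e.g. a depth map) cropped identically.
    camera: an optional Camera whose intrinsics are adjusted to the crop.
    return_bbox: also return the crop bounds (top, bottom, left, right).
    centroid: when random is False, center the crop on this (x, y) point.
Returns:
    A list [image, other?, camera?, bbox?] whose optional entries are present
    only if the corresponding argument was provided.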
54 | """ 55 | h, w = image.shape[:2] 56 | h_new, w_new = (size, size) if isinstance(size, int) else size 57 | if random: 58 | top = np.random.randint(0, h - h_new + 1) 59 | left = np.random.randint(0, w - w_new + 1) 60 | elif centroid is not None: 61 | x, y = centroid 62 | top = np.clip(int(y) - h_new // 2, 0, h - h_new) 63 | left = np.clip(int(x) - w_new // 2, 0, w - w_new) 64 | else: 65 | top = left = 0 66 | 67 | image = image[top:top+h_new, left:left+w_new] 68 | ret = [image] 69 | if other is not None: 70 | ret += [other[top:top+h_new, left:left+w_new]] 71 | if camera is not None: 72 | ret += [camera.crop((left, top), (w_new, h_new))] 73 | if return_bbox: 74 | ret += [(top, top+h_new, left, left+w_new)] 75 | return ret 76 | 77 | 78 | def zero_pad(size, *images): 79 | ret = [] 80 | for image in images: 81 | h, w = image.shape[:2] 82 | padded = np.zeros((size, size)+image.shape[2:], dtype=image.dtype) 83 | padded[:h, :w] = image 84 | ret.append(padded) 85 | return ret 86 | 87 | 88 | def read_view(conf, image_path: Path, camera: Camera, T_w2cam: Pose, 89 | p3D: np.ndarray, p3D_idxs: np.ndarray, *, 90 | rotation=0, random=False): 91 | 92 | img = read_image(image_path, conf.grayscale) 93 | img = img.astype(np.float32) 94 | name = image_path.name 95 | 96 | # we assume that the pose and camera were already rotated during preprocess 97 | if rotation != 0: 98 | img = np.rot90(img, rotation) 99 | 100 | if conf.resize: 101 | scales = (1, 1) 102 | if conf.resize_by == 'max': 103 | img, scales = resize(img, conf.resize, fn=max) 104 | elif (conf.resize_by == 'min' or 105 | (conf.resize_by == 'min_if' 106 | and min(*img.shape[:2]) < conf.resize)): 107 | img, scales = resize(img, conf.resize, fn=min) 108 | if scales != (1, 1): 109 | camera = camera.scale(scales) 110 | 111 | if conf.crop: 112 | if conf.optimal_crop: 113 | p2D, valid = camera.world2image(T_w2cam * p3D[p3D_idxs]) 114 | p2D = p2D[valid].numpy() 115 | centroid = tuple(p2D.mean(0)) if len(p2D) > 0 else None 116 | random = False 117 | else: 118 | centroid = None 119 | img, camera, bbox = crop( 120 | img, conf.crop, random=random, 121 | camera=camera, return_bbox=True, centroid=centroid) 122 | elif conf.pad: 123 | img, = zero_pad(conf.pad, img) 124 | # we purposefully do not update the image size in the camera object 125 | 126 | data = { 127 | 'name': name, 128 | 'image': numpy_image_to_torch(img), 129 | 'camera': camera.float(), 130 | 'T_w2cam': T_w2cam.float(), 131 | } 132 | return data 133 | -------------------------------------------------------------------------------- /pixloc/pixlib/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrappers import Pose, Camera # noqa 2 | -------------------------------------------------------------------------------- /pixloc/pixlib/geometry/check_jacobians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from . 
import Pose, Camera 5 | from .costs import DirectAbsoluteCost 6 | from .interpolation import Interpolator 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def compute_J(fn_J, inp): 12 | with torch.enable_grad(): 13 | return torch.autograd.functional.jacobian(fn_J, inp) 14 | 15 | 16 | def compute_J_batched(fn, inp): 17 | inp_ = inp.reshape(-1) 18 | fn_ = lambda x: fn(x.reshape(inp.shape)) # noqa 19 | J = compute_J(fn_, inp_) 20 | if len(J.shape) != 3: 21 | raise ValueError('Only supports a single leading batch dimension.') 22 | J = J.reshape(J.shape[:-1] + inp.shape) 23 | J = J.diagonal(dim1=0, dim2=-2).permute(2, 0, 1) 24 | return J 25 | 26 | 27 | def local_param(delta): 28 | dt, dw = delta.split(3, dim=-1) 29 | return Pose.from_aa(dw, dt) 30 | 31 | 32 | def toy_problem(seed=0, n_points=500): 33 | torch.random.manual_seed(seed) 34 | aa = torch.randn(3) / 10 35 | t = torch.randn(3) / 5 36 | T_w2q = Pose.from_aa(aa, t) 37 | 38 | w, h = 640, 480 39 | fx, fy = 300., 350. 40 | cx, cy = w/2, h/2 41 | radial = [0.1, 0.01] 42 | camera = Camera(torch.tensor([w, h, fx, fy, cx, cy] + radial)).float() 43 | torch.testing.assert_allclose((w, h), camera.size.long()) 44 | torch.testing.assert_allclose((fx, fy), camera.f) 45 | torch.testing.assert_allclose((cx, cy), camera.c) 46 | 47 | p3D = torch.randn(n_points, 3) 48 | p3D[:, -1] += 2 49 | 50 | dim = 16 51 | F_ref = torch.randn(n_points, dim) 52 | F_query = torch.randn(dim, h, w) 53 | 54 | return T_w2q, camera, p3D, F_ref, F_query 55 | 56 | 57 | def print_J_diff(prefix, J, J_auto): 58 | logger.info('Check J %s: pass=%r, max_diff=%e, shape=%r', 59 | prefix, 60 | torch.allclose(J, J_auto), 61 | torch.abs(J-J_auto).max(), 62 | tuple(J.shape)) 63 | 64 | 65 | def test_J_pose(T: Pose, p3D: torch.Tensor): 66 | J = T.J_transform(T * p3D) 67 | fn = lambda d: (local_param(d) @ T) * p3D # noqa 68 | delta = torch.zeros(6).to(p3D) 69 | J_auto = compute_J(fn, delta) 70 | print_J_diff('pose transform', J, J_auto) 71 | 72 | 73 | def test_J_undistort(camera: Camera, p3D: torch.Tensor): 74 | p2D, valid = camera.project(p3D) 75 | J = camera.J_undistort(p2D) 76 | J_auto = compute_J_batched(camera.undistort, p2D) 77 | J, J_auto = J[valid], J_auto[valid] 78 | print_J_diff('undistort', J, J_auto) 79 | 80 | 81 | def test_J_world2image(camera: Camera, p3D: torch.Tensor): 82 | _, valid = camera.world2image(p3D) 83 | J, _ = camera.J_world2image(p3D) 84 | J_auto = compute_J_batched(lambda x: camera.world2image(x)[0], p3D) 85 | J, J_auto = J[valid], J_auto[valid] 86 | print_J_diff('world2image', J, J_auto) 87 | 88 | 89 | def test_J_geometric_cost(T_w2q: Pose, camera: Camera, p3D: torch.Tensor): 90 | def forward(T): 91 | p3D_q = T * p3D 92 | p2D, visible = camera.world2image(p3D_q) 93 | return p2D, visible, p3D_q 94 | 95 | _, valid, p3D_q = forward(T_w2q) 96 | J = camera.J_world2image(p3D_q)[0] @ T_w2q.J_transform(p3D_q) 97 | delta = torch.zeros(6).to(p3D) 98 | fn = lambda d: forward(local_param(d) @ T_w2q)[0] # noqa 99 | J_auto = compute_J(fn, delta) 100 | J, J_auto = J[valid], J_auto[valid] 101 | print_J_diff('geometric cost', J, J_auto) 102 | 103 | 104 | def test_J_direct_absolute_cost(T_w2q: Pose, camera: Camera, p3D: torch.Tensor, 105 | F_ref, F_query): 106 | interpolator = Interpolator(mode='cubic', pad=2) 107 | cost = DirectAbsoluteCost(interpolator, normalize=True) 108 | 109 | args = (camera, p3D, F_ref, F_query) 110 | res, valid, weight, F_q_p2D, info = cost.residuals( 111 | T_w2q, *args, do_gradients=True) 112 | J, _ = cost.jacobian(T_w2q, camera, *info) 113 | 114 
| delta = torch.zeros(6).to(p3D) 115 | fn = lambda d: cost.residuals(local_param(d) @ T_w2q, *args)[0] # noqa 116 | J_auto = compute_J(fn, delta) 117 | 118 | J, J_auto = J[valid], J_auto[valid] 119 | print_J_diff('direct absolute cost', J, J_auto) 120 | 121 | 122 | def main(): 123 | T_w2q, camera, p3D, F_ref, F_query = toy_problem() 124 | test_J_pose(T_w2q, p3D) 125 | test_J_undistort(camera, p3D) 126 | test_J_world2image(camera, p3D) 127 | 128 | # perform the checsk in double precision to factor out numerical errors 129 | T_w2q, camera, p3D, F_ref, F_query = ( 130 | x.to(torch.double) for x in (T_w2q, camera, p3D, F_ref, F_query)) 131 | 132 | test_J_geometric_cost(T_w2q, camera, p3D) 133 | test_J_direct_absolute_cost(T_w2q, camera, p3D, F_ref, F_query) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /pixloc/pixlib/geometry/costs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | from torch import Tensor 4 | 5 | from . import Pose, Camera 6 | from .optimization import J_normalization 7 | from .interpolation import Interpolator 8 | 9 | 10 | class DirectAbsoluteCost: 11 | def __init__(self, interpolator: Interpolator, normalize: bool = False): 12 | self.interpolator = interpolator 13 | self.normalize = normalize 14 | 15 | def residuals( 16 | self, T_w2q: Pose, camera: Camera, p3D: Tensor, 17 | F_ref: Tensor, F_query: Tensor, 18 | confidences: Optional[Tuple[Tensor, Tensor]] = None, 19 | do_gradients: bool = False): 20 | 21 | p3D_q = T_w2q * p3D 22 | p2D, visible = camera.world2image(p3D_q) 23 | F_p2D_raw, valid, gradients = self.interpolator( 24 | F_query, p2D, return_gradients=do_gradients) 25 | valid = valid & visible 26 | 27 | if confidences is not None: 28 | C_ref, C_query = confidences 29 | C_query_p2D, _, _ = self.interpolator( 30 | C_query, p2D, return_gradients=False) 31 | weight = C_ref * C_query_p2D 32 | weight = weight.squeeze(-1).masked_fill(~valid, 0.) 33 | else: 34 | weight = None 35 | 36 | if self.normalize: 37 | F_p2D = torch.nn.functional.normalize(F_p2D_raw, dim=-1) 38 | else: 39 | F_p2D = F_p2D_raw 40 | 41 | res = F_p2D - F_ref 42 | info = (p3D_q, F_p2D_raw, gradients) 43 | return res, valid, weight, F_p2D, info 44 | 45 | def jacobian( 46 | self, T_w2q: Pose, camera: Camera, 47 | p3D_q: Tensor, F_p2D_raw: Tensor, J_f_p2D: Tensor): 48 | 49 | J_p3D_T = T_w2q.J_transform(p3D_q) 50 | J_p2D_p3D, _ = camera.J_world2image(p3D_q) 51 | 52 | if self.normalize: 53 | J_f_p2D = J_normalization(F_p2D_raw) @ J_f_p2D 54 | 55 | J_p2D_T = J_p2D_p3D @ J_p3D_T 56 | J = J_f_p2D @ J_p2D_T 57 | return J, J_p2D_T 58 | 59 | def residual_jacobian( 60 | self, T_w2q: Pose, camera: Camera, p3D: Tensor, 61 | F_ref: Tensor, F_query: Tensor, 62 | confidences: Optional[Tuple[Tensor, Tensor]] = None): 63 | 64 | res, valid, weight, F_p2D, info = self.residuals( 65 | T_w2q, camera, p3D, F_ref, F_query, confidences, True) 66 | J, _ = self.jacobian(T_w2q, camera, *info) 67 | return res, valid, weight, F_p2D, J 68 | -------------------------------------------------------------------------------- /pixloc/pixlib/geometry/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic losses and error functions for optimization or training deep networks. 
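Each loss takes the already-squared residual `x = r**2` and returns the triplet
(value, first derivative, second derivative); the optimizers use the first
derivative as a robust, IRLS-style weight, for example:

    cost = (res**2).sum(-1)
    cost, weight, _ = scaled_barron(1., 2.)(cost)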
3 | """ 4 | 5 | import torch 6 | 7 | 8 | def scaled_loss(x, fn, a): 9 | """Apply a loss function to a tensor and pre- and post-scale it. 10 | Args: 11 | x: the data tensor, should already be squared: `x = y**2`. 12 | fn: the loss function, with signature `fn(x) -> y`. 13 | a: the scale parameter. 14 | Returns: 15 | The value of the loss, and its first and second derivatives. 16 | """ 17 | a2 = a**2 18 | loss, loss_d1, loss_d2 = fn(x/a2) 19 | return loss*a2, loss_d1, loss_d2/a2 20 | 21 | 22 | def squared_loss(x): 23 | """A dummy squared loss.""" 24 | return x, torch.ones_like(x), torch.zeros_like(x) 25 | 26 | 27 | def huber_loss(x): 28 | """The classical robust Huber loss, with first and second derivatives.""" 29 | mask = x <= 1 30 | sx = torch.sqrt(x) 31 | isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1/sx) 32 | loss = torch.where(mask, x, 2*sx-1) 33 | loss_d1 = torch.where(mask, torch.ones_like(x), isx) 34 | loss_d2 = torch.where(mask, torch.zeros_like(x), -isx/(2*x)) 35 | return loss, loss_d1, loss_d2 36 | 37 | 38 | def barron_loss(x, alpha, derivatives: bool = True, eps: float = 1e-7): 39 | """Parameterized & adaptive robust loss function. 40 | Described in: 41 | A General and Adaptive Robust Loss Function, Barron, CVPR 2019 42 | 43 | Contrary to the original implementation, assume the the input is already 44 | squared and scaled (basically scale=1). Computes the first derivative, but 45 | not the second (TODO if needed). 46 | """ 47 | loss_two = x 48 | loss_zero = 2 * torch.log1p(torch.clamp(0.5*x, max=33e37)) 49 | 50 | # The loss when not in one of the above special cases. 51 | # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by. 52 | beta_safe = torch.abs(alpha - 2.).clamp(min=eps) 53 | # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by. 54 | alpha_safe = torch.where( 55 | alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha)) 56 | alpha_safe = alpha_safe * torch.abs(alpha).clamp(min=eps) 57 | 58 | loss_otherwise = 2 * (beta_safe / alpha_safe) * ( 59 | torch.pow(x / beta_safe + 1., 0.5 * alpha) - 1.) 60 | 61 | # Select which of the cases of the loss to return. 62 | loss = torch.where( 63 | alpha == 0, loss_zero, 64 | torch.where(alpha == 2, loss_two, loss_otherwise)) 65 | dummy = torch.zeros_like(x) 66 | 67 | if derivatives: 68 | loss_two_d1 = torch.ones_like(x) 69 | loss_zero_d1 = 2 / (x + 2) 70 | loss_otherwise_d1 = torch.pow(x / beta_safe + 1., 0.5 * alpha - 1.) 71 | loss_d1 = torch.where( 72 | alpha == 0, loss_zero_d1, 73 | torch.where(alpha == 2, loss_two_d1, loss_otherwise_d1)) 74 | 75 | return loss, loss_d1, dummy 76 | else: 77 | return loss, dummy, dummy 78 | 79 | 80 | def scaled_barron(a, c): 81 | return lambda x: scaled_loss( 82 | x, lambda y: barron_loss(y, y.new_tensor(a)), c) 83 | -------------------------------------------------------------------------------- /pixloc/pixlib/geometry/optimization.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | import torch 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | if version.parse(torch.__version__) >= version.parse('1.9'): 8 | cholesky = torch.linalg.cholesky 9 | else: 10 | cholesky = torch.cholesky 11 | 12 | 13 | def optimizer_step(g, H, lambda_=0, mute=False, mask=None, eps=1e-6): 14 | """One optimization step with Gauss-Newton or Levenberg-Marquardt. 15 | Args: 16 | g: batched gradient tensor of size (..., N). 17 | H: batched hessian tensor of size (..., N, N). 
18 | lambda_: damping factor for LM (use GN if lambda_=0). 19 | mask: denotes valid elements of the batch (optional). 20 | """ 21 | if lambda_ is 0: # noqa 22 | diag = torch.zeros_like(g) 23 | else: 24 | diag = H.diagonal(dim1=-2, dim2=-1) * lambda_ 25 | H = H + diag.clamp(min=eps).diag_embed() 26 | 27 | if mask is not None: 28 | # make sure that masked elements are not singular 29 | H = torch.where(mask[..., None, None], H, torch.eye(H.shape[-1]).to(H)) 30 | # set g to 0 to delta is 0 for masked elements 31 | g = g.masked_fill(~mask[..., None], 0.) 32 | 33 | H_, g_ = H.cpu(), g.cpu() 34 | try: 35 | U = cholesky(H_) 36 | except RuntimeError as e: 37 | if 'singular U' in str(e): 38 | if not mute: 39 | logger.debug( 40 | 'Cholesky decomposition failed, fallback to LU.') 41 | delta = -torch.solve(g_[..., None], H_)[0][..., 0] 42 | else: 43 | raise 44 | else: 45 | delta = -torch.cholesky_solve(g_[..., None], U)[..., 0] 46 | 47 | return delta.to(H.device) 48 | 49 | 50 | def skew_symmetric(v): 51 | """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). 52 | """ 53 | z = torch.zeros_like(v[..., 0]) 54 | M = torch.stack([ 55 | z, -v[..., 2], v[..., 1], 56 | v[..., 2], z, -v[..., 0], 57 | -v[..., 1], v[..., 0], z, 58 | ], dim=-1).reshape(v.shape[:-1]+(3, 3)) 59 | return M 60 | 61 | 62 | def so3exp_map(w, eps: float = 1e-7): 63 | """Compute rotation matrices from batched twists. 64 | Args: 65 | w: batched 3D axis-angle vectors of size (..., 3). 66 | Returns: 67 | A batch of rotation matrices of size (..., 3, 3). 68 | """ 69 | theta = w.norm(p=2, dim=-1, keepdim=True) 70 | small = theta < eps 71 | div = torch.where(small, torch.ones_like(theta), theta) 72 | W = skew_symmetric(w / div) 73 | theta = theta[..., None] # ... x 1 x 1 74 | res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta)) 75 | res = torch.where(small[..., None], W, res) # first-order Taylor approx 76 | return torch.eye(3).to(W) + res 77 | 78 | 79 | def J_normalization(x): 80 | """Jacobian of the L2 normalization, assuming that we normalize 81 | along the last dimension. 82 | """ 83 | x_normed = torch.nn.functional.normalize(x, dim=-1) 84 | norm = torch.norm(x, p=2, dim=-1, keepdim=True) 85 | 86 | Id = torch.diag_embed(torch.ones_like(x_normed)) 87 | J = (Id - x_normed.unsqueeze(-1) @ x_normed.unsqueeze(-2)) 88 | J = J / norm.unsqueeze(-1) 89 | return J 90 | -------------------------------------------------------------------------------- /pixloc/pixlib/geometry/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | A set of geometry tools for PyTorch tensors and sometimes NumPy arrays. 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def to_homogeneous(points): 10 | """Convert N-dimensional points to homogeneous coordinates. 11 | Args: 12 | points: torch.Tensor or numpy.ndarray with size (..., N). 13 | Returns: 14 | A torch.Tensor or numpy.ndarray with size (..., N+1). 15 | """ 16 | if isinstance(points, torch.Tensor): 17 | pad = points.new_ones(points.shape[:-1]+(1,)) 18 | return torch.cat([points, pad], dim=-1) 19 | elif isinstance(points, np.ndarray): 20 | pad = np.ones((points.shape[:-1]+(1,)), dtype=points.dtype) 21 | return np.concatenate([points, pad], axis=-1) 22 | else: 23 | raise ValueError 24 | 25 | 26 | def from_homogeneous(points): 27 | """Remove the homogeneous dimension of N-dimensional points. 28 | Args: 29 | points: torch.Tensor or numpy.ndarray with size (..., N+1). 
30 | Returns: 31 | A torch.Tensor or numpy ndarray with size (..., N). 32 | """ 33 | return points[..., :-1] / points[..., -1:] 34 | 35 | 36 | @torch.jit.script 37 | def undistort_points(pts, dist): 38 | '''Undistort normalized 2D coordinates 39 | and check for validity of the distortion model. 40 | ''' 41 | dist = dist.unsqueeze(-2) # add point dimension 42 | ndist = dist.shape[-1] 43 | undist = pts 44 | valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool) 45 | if ndist > 0: 46 | k1, k2 = dist[..., :2].split(1, -1) 47 | r2 = torch.sum(pts**2, -1, keepdim=True) 48 | radial = k1*r2 + k2*r2**2 49 | undist = undist + pts * radial 50 | 51 | # The distortion model is supposedly only valid within the image 52 | # boundaries. Because of the negative radial distortion, points that 53 | # are far outside of the boundaries might actually be mapped back 54 | # within the image. To account for this, we discard points that are 55 | # beyond the inflection point of the distortion model, 56 | # e.g. such that d(r + k_1 r^3 + k2 r^5)/dr = 0 57 | limited = ((k2 > 0) & ((9*k1**2-20*k2) > 0)) | ((k2 <= 0) & (k1 > 0)) 58 | limit = torch.abs(torch.where( 59 | k2 > 0, (torch.sqrt(9*k1**2-20*k2)-3*k1)/(10*k2), 1/(3*k1))) 60 | valid = valid & torch.squeeze(~limited | (r2 < limit), -1) 61 | 62 | if ndist > 2: 63 | p12 = dist[..., 2:] 64 | p21 = p12.flip(-1) 65 | uv = torch.prod(pts, -1, keepdim=True) 66 | undist = undist + 2*p12*uv + p21*(r2 + 2*pts**2) 67 | # TODO: handle tangential boundaries 68 | 69 | return undist, valid 70 | 71 | 72 | @torch.jit.script 73 | def J_undistort_points(pts, dist): 74 | dist = dist.unsqueeze(-2) # add point dimension 75 | ndist = dist.shape[-1] 76 | 77 | J_diag = torch.ones_like(pts) 78 | J_cross = torch.zeros_like(pts) 79 | if ndist > 0: 80 | k1, k2 = dist[..., :2].split(1, -1) 81 | r2 = torch.sum(pts**2, -1, keepdim=True) 82 | uv = torch.prod(pts, -1, keepdim=True) 83 | radial = k1*r2 + k2*r2**2 84 | d_radial = (2*k1 + 4*k2*r2) 85 | J_diag += radial + (pts**2)*d_radial 86 | J_cross += uv*d_radial 87 | 88 | if ndist > 2: 89 | p12 = dist[..., 2:] 90 | p21 = p12.flip(-1) 91 | J_diag += 2*p12*pts.flip(-1) + 6*p21*pts 92 | J_cross += 2*p12*pts + 2*p21*pts.flip(-1) 93 | 94 | J = torch.diag_embed(J_diag) + torch.diag_embed(J_cross).flip(-1) 95 | return J 96 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils.tools import get_class 2 | from .base_model import BaseModel 3 | 4 | 5 | def get_model(name): 6 | return get_class(name, __name__, BaseModel) 7 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for trainable models. 
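A minimal sketch of a child model (ToyModel and its keys are illustrative,
not part of the package):

    class ToyModel(BaseModel):
        default_conf = {'dim': 16}
        required_data_keys = ['input']

        def _init(self, conf):
            self.linear = nn.Linear(conf.dim, conf.dim)

        def _forward(self, data):
            return {'output': self.linear(data['input'])}

        def loss(self, pred, data):
            return {'total': (pred['output'] - data['input']).abs().mean(-1)}

        def metrics(self, pred, data):
            return {'l1': (pred['output'] - data['input']).abs().mean(-1)}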
3 | """ 4 | 5 | from abc import ABCMeta, abstractmethod 6 | import omegaconf 7 | from omegaconf import OmegaConf 8 | from torch import nn 9 | from copy import copy 10 | 11 | 12 | class MetaModel(ABCMeta): 13 | def __prepare__(name, bases, **kwds): 14 | total_conf = OmegaConf.create() 15 | for base in bases: 16 | for key in ('base_default_conf', 'default_conf'): 17 | update = getattr(base, key, {}) 18 | if isinstance(update, dict): 19 | update = OmegaConf.create(update) 20 | total_conf = OmegaConf.merge(total_conf, update) 21 | return dict(base_default_conf=total_conf) 22 | 23 | 24 | class BaseModel(nn.Module, metaclass=MetaModel): 25 | """ 26 | What the child model is expect to declare: 27 | default_conf: dictionary of the default configuration of the model. 28 | It recursively updates the default_conf of all parent classes, and 29 | it is updated by the user-provided configuration passed to __init__. 30 | Configurations can be nested. 31 | 32 | required_data_keys: list of expected keys in the input data dictionary. 33 | 34 | strict_conf (optional): boolean. If false, BaseModel does not raise 35 | an error when the user provides an unknown configuration entry. 36 | 37 | _init(self, conf): initialization method, where conf is the final 38 | configuration object (also accessible with `self.conf`). Accessing 39 | unkown configuration entries will raise an error. 40 | 41 | _forward(self, data): method that returns a dictionary of batched 42 | prediction tensors based on a dictionary of batched input data tensors. 43 | 44 | loss(self, pred, data): method that returns a dictionary of losses, 45 | computed from model predictions and input data. Each loss is a batch 46 | of scalars, i.e. a torch.Tensor of shape (B,). 47 | The total loss to be optimized has the key `'total'`. 48 | 49 | metrics(self, pred, data): method that returns a dictionary of metrics, 50 | each as a batch of scalars. 51 | """ 52 | default_conf = { 53 | 'name': None, 54 | 'trainable': True, # if false: do not optimize this model parameters 55 | 'freeze_batch_normalization': False, # use test-time statistics 56 | } 57 | required_data_keys = [] 58 | strict_conf = True 59 | 60 | def __init__(self, conf): 61 | """Perform some logic and call the _init method of the child model.""" 62 | super().__init__() 63 | default_conf = OmegaConf.merge( 64 | self.base_default_conf, OmegaConf.create(self.default_conf)) 65 | if self.strict_conf: 66 | OmegaConf.set_struct(default_conf, True) 67 | 68 | # fixme: backward compatibility 69 | if 'pad' in conf and 'pad' not in default_conf: # backward compat. 
70 | with omegaconf.read_write(conf): 71 | with omegaconf.open_dict(conf): 72 | conf['interpolation'] = {'pad': conf.pop('pad')} 73 | 74 | if isinstance(conf, dict): 75 | conf = OmegaConf.create(conf) 76 | self.conf = conf = OmegaConf.merge(default_conf, conf) 77 | OmegaConf.set_readonly(conf, True) 78 | OmegaConf.set_struct(conf, True) 79 | self.required_data_keys = copy(self.required_data_keys) 80 | self._init(conf) 81 | 82 | if not conf.trainable: 83 | for p in self.parameters(): 84 | p.requires_grad = False 85 | 86 | def train(self, mode=True): 87 | super().train(mode) 88 | 89 | def freeze_bn(module): 90 | if isinstance(module, nn.modules.batchnorm._BatchNorm): 91 | module.eval() 92 | if self.conf.freeze_batch_normalization: 93 | self.apply(freeze_bn) 94 | 95 | return self 96 | 97 | def forward(self, data): 98 | """Check the data and call the _forward method of the child model.""" 99 | def recursive_key_check(expected, given): 100 | for key in expected: 101 | assert key in given, f'Missing key {key} in data' 102 | if isinstance(expected, dict): 103 | recursive_key_check(expected[key], given[key]) 104 | 105 | recursive_key_check(self.required_data_keys, data) 106 | return self._forward(data) 107 | 108 | @abstractmethod 109 | def _init(self, conf): 110 | """To be implemented by the child class.""" 111 | raise NotImplementedError 112 | 113 | @abstractmethod 114 | def _forward(self, data): 115 | """To be implemented by the child class.""" 116 | raise NotImplementedError 117 | 118 | @abstractmethod 119 | def loss(self, pred, data): 120 | """To be implemented by the child class.""" 121 | raise NotImplementedError 122 | 123 | @abstractmethod 124 | def metrics(self, pred, data): 125 | """To be implemented by the child class.""" 126 | raise NotImplementedError 127 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/base_optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements a simple differentiable optimizer based on Levenberg-Marquardt 3 | with a constant, scalar damping factor and a fixed number of iterations. 4 | """ 5 | 6 | import logging 7 | from typing import Tuple, Dict, Optional 8 | import torch 9 | from torch import Tensor 10 | 11 | from .base_model import BaseModel 12 | from .utils import masked_mean 13 | from ..geometry import Camera, Pose 14 | from ..geometry.optimization import optimizer_step 15 | from ..geometry.interpolation import Interpolator 16 | from ..geometry.costs import DirectAbsoluteCost 17 | from ..geometry import losses # noqa 18 | from ...utils.tools import torchify 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class BaseOptimizer(BaseModel): 24 | default_conf = dict( 25 | num_iters=100, 26 | loss_fn='squared_loss', 27 | jacobi_scaling=False, 28 | normalize_features=False, 29 | lambda_=0, # Gauss-Newton 30 | interpolation=dict( 31 | mode='linear', 32 | pad=4, 33 | ), 34 | grad_stop_criteria=1e-4, 35 | dt_stop_criteria=5e-3, # in meters 36 | dR_stop_criteria=5e-2, # in degrees 37 | 38 | # deprecated entries 39 | sqrt_diag_damping=False, 40 | bound_confidence=True, 41 | no_conditions=True, 42 | verbose=False, 43 | ) 44 | logging_fn = None 45 | 46 | def _init(self, conf): 47 | self.loss_fn = eval('losses.' + conf.loss_fn) 48 | self.interpolator = Interpolator(**conf.interpolation) 49 | self.cost_fn = DirectAbsoluteCost(self.interpolator, 50 | normalize=conf.normalize_features) 51 | assert conf.lambda_ >= 0. 
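# lambda_ is the constant damping factor passed to optimizer_step:
# lambda_ == 0 gives plain Gauss-Newton steps, while lambda_ > 0 adds
# lambda_ * diag(H) (clamped to a small epsilon) to the Hessian, i.e.
# Levenberg-Marquardt with a fixed, scalar damping.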
52 | # deprecated entries 53 | assert not conf.sqrt_diag_damping 54 | assert conf.bound_confidence 55 | assert conf.no_conditions 56 | assert not conf.verbose 57 | 58 | def log(self, **args): 59 | if self.logging_fn is not None: 60 | self.logging_fn(**args) 61 | 62 | def early_stop(self, **args): 63 | stop = False 64 | if not self.training and (args['i'] % 10) == 0: 65 | T_delta, grad = args['T_delta'], args['grad'] 66 | grad_norm = torch.norm(grad.detach(), dim=-1) 67 | small_grad = grad_norm < self.conf.grad_stop_criteria 68 | dR, dt = T_delta.magnitude() 69 | small_step = ((dt < self.conf.dt_stop_criteria) 70 | & (dR < self.conf.dR_stop_criteria)) 71 | if torch.all(small_step | small_grad): 72 | stop = True 73 | return stop 74 | 75 | def J_scaling(self, J: Tensor, J_scaling: Tensor, valid: Tensor): 76 | if J_scaling is None: 77 | J_norm = torch.norm(J.detach(), p=2, dim=(-2)) 78 | J_norm = masked_mean(J_norm, valid[..., None], -2) 79 | J_scaling = 1 / (1 + J_norm) 80 | J = J * J_scaling[..., None, None, :] 81 | return J, J_scaling 82 | 83 | def build_system(self, J: Tensor, res: Tensor, weights: Tensor): 84 | grad = torch.einsum('...ndi,...nd->...ni', J, res) # ... x N x 6 85 | grad = weights[..., None] * grad 86 | grad = grad.sum(-2) # ... x 6 87 | 88 | Hess = torch.einsum('...ijk,...ijl->...ikl', J, J) # ... x N x 6 x 6 89 | Hess = weights[..., None, None] * Hess 90 | Hess = Hess.sum(-3) # ... x 6 x6 91 | 92 | return grad, Hess 93 | 94 | def _forward(self, data: Dict): 95 | return self._run( 96 | data['p3D'], data['F_ref'], data['F_q'], data['T_init'], 97 | data['cam_q'], data['mask'], data.get('W_ref_q')) 98 | 99 | @torchify 100 | def run(self, *args, **kwargs): 101 | return self._run(*args, **kwargs) 102 | 103 | def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor, 104 | T_init: Pose, camera: Camera, mask: Optional[Tensor] = None, 105 | W_ref_query: Optional[Tuple[Tensor, Tensor]] = None): 106 | 107 | T = T_init 108 | J_scaling = None 109 | if self.conf.normalize_features: 110 | F_ref = torch.nn.functional.normalize(F_ref, dim=-1) 111 | args = (camera, p3D, F_ref, F_query, W_ref_query) 112 | failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device) 113 | 114 | for i in range(self.conf.num_iters): 115 | res, valid, w_unc, _, J = self.cost_fn.residual_jacobian(T, *args) 116 | if mask is not None: 117 | valid &= mask 118 | failed = failed | (valid.long().sum(-1) < 10) # too few points 119 | 120 | # compute the cost and aggregate the weights 121 | cost = (res**2).sum(-1) 122 | cost, w_loss, _ = self.loss_fn(cost) 123 | weights = w_loss * valid.float() 124 | if w_unc is not None: 125 | weights *= w_unc 126 | if self.conf.jacobi_scaling: 127 | J, J_scaling = self.J_scaling(J, J_scaling, valid) 128 | 129 | # solve the linear system 130 | g, H = self.build_system(J, res, weights) 131 | delta = optimizer_step(g, H, self.conf.lambda_, mask=~failed) 132 | if self.conf.jacobi_scaling: 133 | delta = delta * J_scaling 134 | 135 | # compute the pose update 136 | dt, dw = delta.split([3, 3], dim=-1) 137 | T_delta = Pose.from_aa(dw, dt) 138 | T = T_delta @ T 139 | 140 | self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost, 141 | valid=valid, w_unc=w_unc, w_loss=w_loss, H=H, J=J) 142 | if self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost): 143 | break 144 | 145 | if failed.any(): 146 | logger.debug('One batch element had too few valid points.') 147 | 148 | return T, failed 149 | 150 | def loss(self, pred, data): 151 | raise NotImplementedError 152 | 153 | def 
metrics(self, pred, data):
154 | raise NotImplementedError
155 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/classic_optimizer.py: --------------------------------------------------------------------------------
1 | import logging
2 | from typing import Tuple, Optional
3 | import torch
4 | from torch import Tensor
5 | 
6 | from .base_optimizer import BaseOptimizer
7 | from .utils import masked_mean
8 | from ..geometry import Camera, Pose
9 | from ..geometry.optimization import optimizer_step
10 | from ..geometry import losses # noqa
11 | 
12 | logger = logging.getLogger(__name__)
13 | 
14 | 
15 | class ClassicOptimizer(BaseOptimizer):
16 | default_conf = dict(
17 | lambda_=1e-2,
18 | lambda_max=1e4,
19 | )
20 | 
21 | def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor,
22 | T_init: Pose, camera: Camera, mask: Optional[Tensor] = None,
23 | W_ref_query: Optional[Tuple[Tensor, Tensor]] = None):
24 | 
25 | T = T_init
26 | J_scaling = None
27 | if self.conf.normalize_features:
28 | F_ref = torch.nn.functional.normalize(F_ref, dim=-1)
29 | args = (camera, p3D, F_ref, F_query, W_ref_query)
30 | failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device)
31 | 
32 | lambda_ = torch.full_like(failed, self.conf.lambda_, dtype=T.dtype)
33 | mult = torch.full_like(lambda_, 10)
34 | recompute = True
35 | 
36 | # compute the initial cost
37 | with torch.no_grad():
38 | res, valid_i, w_unc_i = self.cost_fn.residuals(T_init, *args)[:3]
39 | cost_i = self.loss_fn((res.detach()**2).sum(-1))[0]
40 | if w_unc_i is not None:
41 | cost_i *= w_unc_i.detach()
42 | if mask is not None: valid_i &= mask
43 | cost_best = masked_mean(cost_i, valid_i, -1)
44 | 
45 | for i in range(self.conf.num_iters):
46 | if recompute:
47 | res, valid, w_unc, _, J = self.cost_fn.residual_jacobian(
48 | T, *args)
49 | if mask is not None:
50 | valid &= mask
51 | failed = failed | (valid.long().sum(-1) < 10) # too few points
52 | 
53 | cost = (res**2).sum(-1)
54 | cost, w_loss, _ = self.loss_fn(cost)
55 | weights = w_loss * valid.float()
56 | if w_unc is not None:
57 | weights *= w_unc
58 | if self.conf.jacobi_scaling:
59 | J, J_scaling = self.J_scaling(J, J_scaling, valid)
60 | g, H = self.build_system(J, res, weights)
61 | 
62 | delta = optimizer_step(g, H, lambda_.unsqueeze(-1), mask=~failed)
63 | if self.conf.jacobi_scaling:
64 | delta = delta * J_scaling
65 | 
66 | dt, dw = delta.split([3, 3], dim=-1)
67 | T_delta = Pose.from_aa(dw, dt)
68 | T_new = T_delta @ T
69 | 
70 | # compute the new cost and update if it decreased
71 | with torch.no_grad():
72 | res = self.cost_fn.residuals(T_new, *args)[0]
73 | cost_new = self.loss_fn((res**2).sum(-1))[0]
74 | cost_new = masked_mean(cost_new, valid, -1)
75 | accept = cost_new < cost_best
76 | lambda_ = lambda_ * torch.where(accept, 1/mult, mult)
77 | lambda_ = lambda_.clamp(max=self.conf.lambda_max, min=1e-7)
78 | T = Pose(torch.where(accept[..., None], T_new._data, T._data))
79 | cost_best = torch.where(accept, cost_new, cost_best)
80 | recompute = accept.any()
81 | 
82 | self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost,
83 | valid=valid, w_unc=w_unc, w_loss=w_loss, accept=accept,
84 | lambda_=lambda_, H=H, J=J)
85 | 
86 | stop = self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost)
87 | if self.conf.lambda_ == 0: # Gauss-Newton
88 | stop |= (~recompute)
89 | else: # LM saturates
90 | stop |= bool(torch.all(lambda_ >= self.conf.lambda_max))
91 | if stop:
92 | break
93 | 
94 | if failed.any():
95 | logger.debug('One batch element had too
few valid points.') 96 | 97 | return T, failed 98 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/gaussiannet.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dummy model that computes an image pyramid with appropriate blurring. 3 | """ 4 | 5 | import torch 6 | import kornia 7 | 8 | from .base_model import BaseModel 9 | 10 | 11 | class GaussianNet(BaseModel): 12 | default_conf = { 13 | 'output_scales': [1, 4, 16], # what scales to adapt and output 14 | 'kernel_size_factor': 3, 15 | } 16 | 17 | def _init(self, conf): 18 | self.scales = conf['output_scales'] 19 | 20 | def _forward(self, data): 21 | image = data['image'] 22 | scale_prev = 1 23 | pyramid = [] 24 | for scale in self.conf.output_scales: 25 | sigma = scale / scale_prev 26 | ksize = self.conf.kernel_size_factor * sigma 27 | image = kornia.filter.gaussian_blur2d( 28 | image, kernel_size=ksize, sigma=sigma) 29 | if sigma != 1: 30 | image = torch.nn.functional.interpolate( 31 | image, scale_factor=1/sigma, mode='bilinear', 32 | align_corners=False) 33 | pyramid.append(image) 34 | scale_prev = scale 35 | return {'feature_maps': pyramid} 36 | 37 | def loss(self, pred, data): 38 | raise NotImplementedError 39 | 40 | def metrics(self, pred, data): 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/learned_optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, Optional 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | from .base_optimizer import BaseOptimizer 7 | from ..geometry import Camera, Pose 8 | from ..geometry.optimization import optimizer_step 9 | from ..geometry import losses # noqa 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class DampingNet(nn.Module): 15 | def __init__(self, conf, num_params=6): 16 | super().__init__() 17 | self.conf = conf 18 | if conf.type == 'constant': 19 | const = torch.zeros(num_params) 20 | self.register_parameter('const', torch.nn.Parameter(const)) 21 | else: 22 | raise ValueError(f'Unsupported type of damping: {conf.type}.') 23 | 24 | def forward(self): 25 | min_, max_ = self.conf.log_range 26 | lambda_ = 10.**(min_ + self.const.sigmoid()*(max_ - min_)) 27 | return lambda_ 28 | 29 | 30 | class LearnedOptimizer(BaseOptimizer): 31 | default_conf = dict( 32 | damping=dict( 33 | type='constant', 34 | log_range=[-6, 5], 35 | ), 36 | feature_dim=None, 37 | 38 | # deprecated entries 39 | lambda_=0., 40 | learned_damping=True, 41 | ) 42 | 43 | def _init(self, conf): 44 | self.dampingnet = DampingNet(conf.damping) 45 | assert conf.learned_damping 46 | super()._init(conf) 47 | 48 | def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor, 49 | T_init: Pose, camera: Camera, mask: Optional[Tensor] = None, 50 | W_ref_query: Optional[Tuple[Tensor, Tensor]] = None): 51 | 52 | T = T_init 53 | J_scaling = None 54 | if self.conf.normalize_features: 55 | F_ref = torch.nn.functional.normalize(F_ref, dim=-1) 56 | args = (camera, p3D, F_ref, F_query, W_ref_query) 57 | failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device) 58 | 59 | lambda_ = self.dampingnet() 60 | 61 | for i in range(self.conf.num_iters): 62 | res, valid, w_unc, _, J = self.cost_fn.residual_jacobian(T, *args) 63 | if mask is not None: 64 | valid &= mask 65 | failed = failed | (valid.long().sum(-1) < 10) # too few points 66 | 67 | # 
compute the cost and aggregate the weights 68 | cost = (res**2).sum(-1) 69 | cost, w_loss, _ = self.loss_fn(cost) 70 | weights = w_loss * valid.float() 71 | if w_unc is not None: 72 | weights *= w_unc 73 | if self.conf.jacobi_scaling: 74 | J, J_scaling = self.J_scaling(J, J_scaling, valid) 75 | 76 | # solve the linear system 77 | g, H = self.build_system(J, res, weights) 78 | delta = optimizer_step(g, H, lambda_, mask=~failed) 79 | if self.conf.jacobi_scaling: 80 | delta = delta * J_scaling 81 | 82 | # compute the pose update 83 | dt, dw = delta.split([3, 3], dim=-1) 84 | T_delta = Pose.from_aa(dw, dt) 85 | T = T_delta @ T 86 | 87 | self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost, 88 | valid=valid, w_unc=w_unc, w_loss=w_loss, H=H, J=J) 89 | if self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost): 90 | break 91 | 92 | if failed.any(): 93 | logger.debug('One batch element had too few valid points.') 94 | 95 | return T, failed 96 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/s2dnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of 3 | S2DNet: Learning Image Features for Accurate Sparse-to-Dense Matching 4 | Hugo Germain, Guillaume Bourmaud, Vincent Lepetit 5 | European Conference on Computer Vision (ECCV) 2020 6 | 7 | Adapted from https://github.com/germain-hug/S2DNet-Minimal 8 | """ 9 | 10 | from typing import List 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import models 14 | import logging 15 | 16 | from .base_model import BaseModel 17 | from ...settings import DATA_PATH 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # VGG-16 Layer Names and Channels 23 | vgg16_layers = { 24 | "conv1_1": 64, 25 | "relu1_1": 64, 26 | "conv1_2": 64, 27 | "relu1_2": 64, 28 | "pool1": 64, 29 | "conv2_1": 128, 30 | "relu2_1": 128, 31 | "conv2_2": 128, 32 | "relu2_2": 128, 33 | "pool2": 128, 34 | "conv3_1": 256, 35 | "relu3_1": 256, 36 | "conv3_2": 256, 37 | "relu3_2": 256, 38 | "conv3_3": 256, 39 | "relu3_3": 256, 40 | "pool3": 256, 41 | "conv4_1": 512, 42 | "relu4_1": 512, 43 | "conv4_2": 512, 44 | "relu4_2": 512, 45 | "conv4_3": 512, 46 | "relu4_3": 512, 47 | "pool4": 512, 48 | "conv5_1": 512, 49 | "relu5_1": 512, 50 | "conv5_2": 512, 51 | "relu5_2": 512, 52 | "conv5_3": 512, 53 | "relu5_3": 512, 54 | "pool5": 512, 55 | } 56 | 57 | 58 | class AdapLayers(nn.Module): 59 | """Small adaptation layers. 60 | """ 61 | 62 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128): 63 | """Initialize one adaptation layer for every extraction point. 64 | 65 | Args: 66 | hypercolumn_layers: The list of the hypercolumn layer names. 67 | output_dim: The output channel dimension. 68 | """ 69 | super(AdapLayers, self).__init__() 70 | self.layers = [] 71 | channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers] 72 | for i, l in enumerate(channel_sizes): 73 | layer = nn.Sequential( 74 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0), 75 | nn.ReLU(), 76 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2), 77 | nn.BatchNorm2d(output_dim), 78 | ) 79 | self.layers.append(layer) 80 | self.add_module("adap_layer_{}".format(i), layer) 81 | 82 | def forward(self, features: List[torch.tensor]): 83 | """Apply adaptation layers. 
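Args:
    features: a list of feature maps, one per hypercolumn layer.
Returns:
    The same list, with each map transformed in place by its adaptation
    layer to `output_dim` channels.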
84 | """ 85 | for i, _ in enumerate(features): 86 | features[i] = getattr(self, "adap_layer_{}".format(i))(features[i]) 87 | return features 88 | 89 | 90 | class S2DNet(BaseModel): 91 | default_conf = { 92 | 'hypercolumn_layers': ["conv1_2", "conv3_3", "conv5_3"], 93 | 'checkpointing': None, 94 | 'output_dim': 128, 95 | 'pretrained': 's2dnet', 96 | } 97 | mean = [0.485, 0.456, 0.406] 98 | std = [0.229, 0.224, 0.225] 99 | 100 | def _init(self, conf): 101 | assert conf.pretrained in ['s2dnet', 'imagenet', None] 102 | 103 | self.layer_to_index = {k: v for v, k in enumerate(vgg16_layers.keys())} 104 | self.hypercolumn_indices = [ 105 | self.layer_to_index[n] for n in conf.hypercolumn_layers] 106 | num_layers = self.hypercolumn_indices[-1] + 1 107 | 108 | # Initialize architecture 109 | vgg16 = models.vgg16(pretrained=conf.pretrained == 'imagenet') 110 | layers = list(vgg16.features.children())[:num_layers] 111 | self.encoder = nn.ModuleList(layers) 112 | 113 | self.scales = [] 114 | current_scale = 0 115 | for i, layer in enumerate(layers): 116 | if isinstance(layer, torch.nn.MaxPool2d): 117 | current_scale += 1 118 | if i in self.hypercolumn_indices: 119 | self.scales.append(2**current_scale) 120 | 121 | self.adaptation_layers = AdapLayers( 122 | conf.hypercolumn_layers, conf.output_dim) 123 | 124 | if conf.pretrained == 's2dnet': 125 | path = DATA_PATH / 's2dnet_weights.pth' 126 | logger.info(f'Loading S2DNet checkpoint at {path}.') 127 | state_dict = torch.load(path, map_location='cpu')['state_dict'] 128 | params = self.state_dict() 129 | state_dict = {k: v for k, v in state_dict.items() 130 | if v.shape == params[k].shape} 131 | self.load_state_dict(state_dict, strict=False) 132 | 133 | def _forward(self, data): 134 | image = data['image'] 135 | mean, std = image.new_tensor(self.mean), image.new_tensor(self.std) 136 | image = (image - mean[:, None, None]) / std[:, None, None] 137 | 138 | feature_map = image 139 | feature_maps = [] 140 | start = 0 141 | for idx in self.hypercolumn_indices: 142 | if self.conf.checkpointing: 143 | blocks = list(range(start, idx+2, self.conf.checkpointing)) 144 | if blocks[-1] != idx+1: 145 | blocks.append(idx+1) 146 | for start_, end_ in zip(blocks[:-1], blocks[1:]): 147 | feature_map = torch.utils.checkpoint.checkpoint( 148 | nn.Sequential(*self.encoder[start_:end_]), feature_map) 149 | else: 150 | for i in range(start, idx + 1): 151 | feature_map = self.encoder[i](feature_map) 152 | feature_maps.append(feature_map) 153 | start = idx + 1 154 | 155 | feature_maps = self.adaptation_layers(feature_maps) 156 | return {'feature_maps': feature_maps} 157 | 158 | def loss(self, pred, data): 159 | raise NotImplementedError 160 | 161 | def metrics(self, pred, data): 162 | raise NotImplementedError 163 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/two_view_refiner.py: -------------------------------------------------------------------------------- 1 | """ 2 | The top-level model of training-time PixLoc. 3 | Encapsulates the feature extraction, pose optimization, loss and metrics. 4 | """ 5 | import torch 6 | from torch.nn import functional as nnF 7 | import logging 8 | from copy import deepcopy 9 | import omegaconf 10 | 11 | from .base_model import BaseModel 12 | from . 
import get_model 13 | from .utils import masked_mean 14 | from ..geometry.losses import scaled_barron 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class TwoViewRefiner(BaseModel): 20 | default_conf = { 21 | 'extractor': { 22 | 'name': 's2dnet', 23 | }, 24 | 'optimizer': { 25 | 'name': 'basic_optimizer', 26 | }, 27 | 'duplicate_optimizer_per_scale': False, 28 | 'success_thresh': 2, 29 | 'clamp_error': 50, 30 | 'normalize_features': True, 31 | 'normalize_dt': True, 32 | 33 | # deprecated entries 34 | 'init_target_offset': None, 35 | } 36 | required_data_keys = { 37 | 'ref': ['image', 'camera', 'T_w2cam'], 38 | 'query': ['image', 'camera', 'T_w2cam'], 39 | } 40 | strict_conf = False # need to pass new confs to children models 41 | 42 | def _init(self, conf): 43 | self.extractor = get_model(conf.extractor.name)(conf.extractor) 44 | assert hasattr(self.extractor, 'scales') 45 | 46 | Opt = get_model(conf.optimizer.name) 47 | if conf.duplicate_optimizer_per_scale: 48 | oconfs = [deepcopy(conf.optimizer) for _ in self.extractor.scales] 49 | feature_dim = self.extractor.conf.output_dim 50 | if not isinstance(feature_dim, int): 51 | for d, oconf in zip(feature_dim, oconfs): 52 | with omegaconf.read_write(oconf): 53 | with omegaconf.open_dict(oconf): 54 | oconf.feature_dim = d 55 | self.optimizer = torch.nn.ModuleList([Opt(c) for c in oconfs]) 56 | else: 57 | self.optimizer = Opt(conf.optimizer) 58 | 59 | if conf.init_target_offset is not None: 60 | raise ValueError('This entry has been deprecated. Please instead ' 61 | 'use the `init_pose` config of the dataloader.') 62 | 63 | def _forward(self, data): 64 | def process_siamese(data_i): 65 | pred_i = self.extractor(data_i) 66 | pred_i['camera_pyr'] = [data_i['camera'].scale(1/s) 67 | for s in self.extractor.scales] 68 | return pred_i 69 | 70 | pred = {i: process_siamese(data[i]) for i in ['ref', 'query']} 71 | p3D_ref = data['ref']['points3D'] 72 | T_init = data['T_r2q_init'] 73 | 74 | pred['T_r2q_init'] = [] 75 | pred['T_r2q_opt'] = [] 76 | pred['valid_masks'] = [] 77 | for i in reversed(range(len(self.extractor.scales))): 78 | F_ref = pred['ref']['feature_maps'][i] 79 | F_q = pred['query']['feature_maps'][i] 80 | cam_ref = pred['ref']['camera_pyr'][i] 81 | cam_q = pred['query']['camera_pyr'][i] 82 | if self.conf.duplicate_optimizer_per_scale: 83 | opt = self.optimizer[i] 84 | else: 85 | opt = self.optimizer 86 | 87 | p2D_ref, visible = cam_ref.world2image(p3D_ref) 88 | F_ref, mask, _ = opt.interpolator(F_ref, p2D_ref) 89 | mask &= visible 90 | 91 | W_ref_q = None 92 | if self.extractor.conf.get('compute_uncertainty', False): 93 | W_ref = pred['ref']['confidences'][i] 94 | W_q = pred['query']['confidences'][i] 95 | W_ref, _, _ = opt.interpolator(W_ref, p2D_ref) 96 | W_ref_q = (W_ref, W_q) 97 | 98 | if self.conf.normalize_features: 99 | F_ref = nnF.normalize(F_ref, dim=2) # B x N x C 100 | F_q = nnF.normalize(F_q, dim=1) # B x C x W x H 101 | 102 | T_opt, failed = opt(dict( 103 | p3D=p3D_ref, F_ref=F_ref, F_q=F_q, T_init=T_init, cam_q=cam_q, 104 | mask=mask, W_ref_q=W_ref_q)) 105 | 106 | pred['T_r2q_init'].append(T_init) 107 | pred['T_r2q_opt'].append(T_opt) 108 | T_init = T_opt.detach() 109 | 110 | return pred 111 | 112 | def loss(self, pred, data): 113 | cam_q = data['query']['camera'] 114 | 115 | def project(T_r2q): 116 | return cam_q.world2image(T_r2q * data['ref']['points3D']) 117 | 118 | p2D_q_gt, mask = project(data['T_r2q_gt']) 119 | p2D_q_i, mask_i = project(data['T_r2q_init']) 120 | mask = (mask & mask_i).float() 121 | 122 
| too_few = torch.sum(mask, -1) < 10 123 | if torch.any(too_few): 124 | logger.warning( 125 | 'Few points in batch '+str([ 126 | (data['scene'][i], data['ref']['index'][i].item(), 127 | data['query']['index'][i].item()) 128 | for i in torch.where(too_few)[0]])) 129 | 130 | def reprojection_error(T_r2q): 131 | p2D_q, _ = project(T_r2q) 132 | err = torch.sum((p2D_q_gt - p2D_q)**2, dim=-1) 133 | err = scaled_barron(1., 2.)(err)[0]/4 134 | err = masked_mean(err, mask, -1) 135 | return err 136 | 137 | num_scales = len(self.extractor.scales) 138 | success = None 139 | losses = {'total': 0.} 140 | for i, T_opt in enumerate(pred['T_r2q_opt']): 141 | err = reprojection_error(T_opt).clamp(max=self.conf.clamp_error) 142 | loss = err / num_scales 143 | if i > 0: 144 | loss = loss * success.float() 145 | thresh = self.conf.success_thresh * self.extractor.scales[-1-i] 146 | success = err < thresh 147 | losses[f'reprojection_error/{i}'] = err 148 | losses['total'] += loss 149 | losses['reprojection_error'] = err 150 | losses['total'] *= (~too_few).float() 151 | 152 | err_init = reprojection_error(pred['T_r2q_init'][0]) 153 | losses['reprojection_error/init'] = err_init 154 | 155 | return losses 156 | 157 | def metrics(self, pred, data): 158 | T_q2r_gt = data['ref']['T_w2cam'] @ data['query']['T_w2cam'].inv() 159 | 160 | @torch.no_grad() 161 | def scaled_pose_error(T_r2q): 162 | err_R, err_t = (T_r2q @ T_q2r_gt).magnitude() 163 | if self.conf.normalize_dt: 164 | err_t /= torch.norm(T_q2r_gt.t, dim=-1) 165 | return err_R, err_t 166 | 167 | metrics = {} 168 | for i, T_opt in enumerate(pred['T_r2q_opt']): 169 | err = scaled_pose_error(T_opt) 170 | metrics[f'R_error/{i}'], metrics[f't_error/{i}'] = err 171 | metrics['R_error'], metrics['t_error'] = err 172 | 173 | err_init = scaled_pose_error(pred['T_r2q_init'][0]) 174 | metrics['R_error/init'], metrics['t_error/init'] = err_init 175 | 176 | return metrics 177 | -------------------------------------------------------------------------------- /pixloc/pixlib/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def masked_mean(x, mask, dim): 5 | mask = mask.float() 6 | return (mask * x).sum(dim) / mask.sum(dim).clamp(min=1) 7 | 8 | 9 | def checkpointed(cls, do=True): 10 | '''Adapted from the DISK implementation of Michał Tyszkiewicz.''' 11 | assert issubclass(cls, torch.nn.Module) 12 | 13 | class Checkpointed(cls): 14 | def forward(self, *args, **kwargs): 15 | super_fwd = super(Checkpointed, self).forward 16 | if any((torch.is_tensor(a) and a.requires_grad) for a in args): 17 | return torch.utils.checkpoint.checkpoint( 18 | super_fwd, *args, **kwargs) 19 | else: 20 | return super_fwd(*args, **kwargs) 21 | 22 | return Checkpointed if do else cls 23 | -------------------------------------------------------------------------------- /pixloc/pixlib/preprocess_cmu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | from multiprocessing import Pool 5 | import functools 6 | import pickle 7 | import scipy.spatial 8 | 9 | from .. 
import logger 10 | from ..settings import DATA_PATH, LOC_PATH 11 | from ..utils.colmap import read_model 12 | from ..utils.io import parse_image_lists 13 | 14 | 15 | def preprocess_slice(slice_, root, sfm_path, min_common=50, verbose=False): 16 | logger.info(f'Preprocessing {slice_}.') 17 | root = root / slice_ 18 | 19 | sfm = Path(str(sfm_path).format(slice_)) 20 | assert sfm.exists(), sfm 21 | cameras, images, points3D = read_model(sfm, ext='.bin') 22 | 23 | query_poses_paths = root / 'camera-poses/*.txt' 24 | query_images = parse_image_lists(query_poses_paths, with_poses=True) 25 | assert len(query_images) > 0 26 | 27 | p3D_ids = sorted(points3D.keys()) 28 | p3D_id_to_idx = dict(zip(p3D_ids, range(len(points3D)))) 29 | p3D_xyz = np.stack([points3D[i].xyz for i in p3D_ids]) 30 | track_lengths = np.stack([len(points3D[i].image_ids) for i in p3D_ids]) 31 | 32 | ref_ids = sorted(images.keys()) 33 | n_ref = len(images) 34 | if verbose: 35 | logger.info(f'Found {n_ref} ref images and {len(p3D_ids)} points.') 36 | 37 | ref_poses = [] 38 | ref_image_names = [] 39 | p3D_observed = [] 40 | for i in ref_ids: 41 | image = images[i] 42 | R = image.qvec2rotmat() 43 | t = image.tvec 44 | obs = np.stack([ 45 | p3D_id_to_idx[i] for i in image.point3D_ids if i != -1]) 46 | 47 | assert (root / 'database' / image.name).exists() 48 | 49 | ref_poses.append((R, t)) 50 | ref_image_names.append(image.name) 51 | p3D_observed.append(obs) 52 | 53 | query_poses = [] 54 | query_image_names = [] 55 | for _, image in query_images: 56 | R = image.qvec2rotmat() 57 | t = -R @ image.tvec 58 | query_poses.append((R, t)) 59 | query_image_names.append(image.name) 60 | 61 | assert (root / 'query' / image.name).exists() 62 | 63 | p3D_observed_sets = [set(p) for p in p3D_observed] 64 | ref_overlaps = np.full([n_ref]*2, -1.) 65 | for idx1 in tqdm(range(n_ref), disable=not verbose): 66 | for idx2 in range(n_ref): 67 | if idx1 == idx2: 68 | continue 69 | 70 | common = p3D_observed_sets[idx1] & p3D_observed_sets[idx2] 71 | if len(common) < min_common: 72 | continue 73 | ref_overlaps[idx1, idx2] = len(common)/len(p3D_observed_sets[idx1]) 74 | 75 | Rs_r = np.stack([p[0] for p in ref_poses]) 76 | ts_r = np.stack([p[1] for p in ref_poses]) 77 | Rs_q = np.stack([p[0] for p in query_poses]) 78 | ts_q = np.stack([p[1] for p in query_poses]) 79 | distances = scipy.spatial.distance.cdist( 80 | -np.einsum('nij,ni->nj', Rs_q, ts_q), 81 | -np.einsum('nij,ni->nj', Rs_r, ts_r)) 82 | trace = np.einsum('nij,mij->nm', Rs_q, Rs_r, optimize=True) 83 | dR = np.clip((trace - 1) / 2, -1., 1.) 84 | dR = np.rad2deg(np.abs(np.arccos(dR))) 85 | mask = (dR < 30) 86 | masked_distances = np.where(mask, distances, np.inf) 87 | 88 | closest = np.argmin(masked_distances, 1) 89 | dist_closest = masked_distances.min(1) 90 | query_overlaps = np.stack([ref_overlaps[c] for c in closest], 0) 91 | query_overlaps[dist_closest > 1.] 
= -1 92 | 93 | data = { 94 | 'points3D': p3D_xyz, 95 | 'track_lengths': track_lengths, 96 | 'ref_poses': ref_poses, 97 | 'ref_image_names': ref_image_names, 98 | 'p3D_observed': p3D_observed, 99 | 'query_poses': query_poses, 100 | 'query_image_names': query_image_names, 101 | 'query_closest_indices': closest, 102 | 'ref_overlap_matrix': ref_overlaps, 103 | 'query_overlap_matrix': query_overlaps, 104 | 'query_to_ref_distance_matrix': distances, 105 | } 106 | return data 107 | 108 | 109 | def preprocess_and_write(slice_, root, out_dir, **kwargs): 110 | path = out_dir / (slice_ + '.pkl') 111 | if path.exists(): 112 | return 113 | 114 | try: 115 | data = preprocess_slice(slice_, root, **kwargs) 116 | except: # noqa E722 117 | logger.info(f'Error for slice {slice_}.') 118 | raise 119 | if data is None: 120 | return 121 | 122 | logger.info(f'Writing slice {slice_} to {path}.') 123 | with open(path, 'wb') as f: 124 | pickle.dump(data, f) 125 | 126 | 127 | if __name__ == '__main__': 128 | root = DATA_PATH / 'CMU/' 129 | sfm = LOC_PATH / 'CMU/{}/sfm_superpoint+superglue/model/' 130 | out_dir = DATA_PATH / 'cmu_pixloc_training/' 131 | out_dir.mkdir(exist_ok=True) 132 | 133 | slices = [6, 7, 8, 9, 10, 11, 12, 13, 21, 22, 23, 24, 25] 134 | slices = [f'slice{i}' for i in slices] 135 | logger.info(f'Found {len(slices)} slices.') 136 | 137 | fn = functools.partial( 138 | preprocess_and_write, root=root, sfm_path=sfm, out_dir=out_dir) 139 | with Pool(5) as p: 140 | p.map(fn, slices) 141 | -------------------------------------------------------------------------------- /pixloc/pixlib/preprocess_megadepth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import logging 4 | from multiprocessing import Pool 5 | import functools 6 | import pickle 7 | 8 | from ..settings import DATA_PATH 9 | from ..utils.colmap import read_model 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def assemble_intrinsics(fx, fy, cx, cy): 15 | K = np.eye(3) 16 | K[0, 0] = fx 17 | K[1, 1] = fy 18 | K[0, 2] = cx - 0.5 # COLMAP convention 19 | K[1, 2] = cy - 0.5 20 | return K 21 | 22 | 23 | def get_camera_angles(R_c_to_w): 24 | trace = np.einsum('nji,mji->mn', R_c_to_w, R_c_to_w, optimize=True) 25 | dR = np.clip((trace - 1) / 2, -1., 1.) 26 | dR = np.rad2deg(np.abs(np.arccos(dR))) 27 | return dR 28 | 29 | 30 | def in_plane_rotation_matrix(rot): 31 | a = np.deg2rad(-90*rot) 32 | R = np.array([ 33 | [np.cos(a), -np.sin(a), 0], 34 | [np.sin(a), np.cos(a), 0], 35 | [0, 0, 1]]) 36 | return R 37 | 38 | 39 | def rotate_intrinsics(K, image_shape, rot): 40 | """Correct the intrinsics after in-plane rotation. 41 | Args: 42 | K: the original (3, 3) intrinsic matrix. 43 | image_shape: shape of the image after rotation `[H, W]`. 44 | rot: the number of clockwise 90deg rotations. 
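    Example (sketch; hypothetical camera of a 640x480 image, rotated once,
        so that the rotated shape is (640, 480)):
        K = np.array([[500., 0., 320.], [0., 600., 240.], [0., 0., 1.]])
        rotate_intrinsics(K, image_shape=(640, 480), rot=1)
        # -> [[600., 0., 240.], [0., 500., 159.], [0., 0., 1.]]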
45 | """ 46 | h, w = image_shape[:2] 47 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 48 | rot = rot % 4 49 | if rot == 0: 50 | return K 51 | elif rot == 1: 52 | return np.array([[fy, 0., cy], 53 | [0., fx, w-1-cx], 54 | [0., 0., 1.]], dtype=K.dtype) 55 | elif rot == 2: 56 | return np.array([[fx, 0., w-1-cx], 57 | [0., fy, h-1-cy], 58 | [0., 0., 1.]], dtype=K.dtype) 59 | elif rot == 3: 60 | return np.array([[fy, 0., h-1-cy], 61 | [0., fx, cx], 62 | [0., 0., 1.]], dtype=K.dtype) 63 | else: 64 | raise ValueError 65 | 66 | 67 | def find_in_plane_rotations(R_c_to_w): 68 | gravity = np.median(R_c_to_w @ np.array([0, 1, 0]), 0) 69 | gravity_2D = (R_c_to_w.transpose(0, 2, 1) @ gravity)[:, :2] 70 | gravity_angle = np.rad2deg(np.arctan2(gravity_2D[:, 0], gravity_2D[:, 1])) 71 | rotated = np.abs(gravity_angle) > 60 72 | 73 | rot90 = np.array([-90, 180, -180, 90]) 74 | rot90_indices = np.array([1, 2, 2, 3]) 75 | rotations = np.zeros(len(rotated), int) 76 | if np.any(rotated): 77 | rots = np.argmin( 78 | np.abs(rot90[None] - gravity_angle[rotated][:, None]), -1) 79 | rotations[rotated] = rot90_indices[rots] 80 | return rotations 81 | 82 | 83 | def preprocess_scene(scene, root, min_common=50, verbose=False): 84 | logger.info(f'Preprocessing scene {scene}.') 85 | sfm = root / scene / 'sparse' 86 | if not sfm.exists(): # empty model 87 | logger.warning(f'Scene {scene} is empty.') 88 | return None 89 | cameras, images, points3D = read_model(sfm, ext='.bin') 90 | 91 | p3D_ids = sorted(points3D.keys()) 92 | p3D_id_to_idx = dict(zip(p3D_ids, range(len(points3D)))) 93 | p3D_xyz = np.stack([points3D[i].xyz for i in p3D_ids]) 94 | track_lengths = np.stack([len(points3D[i].image_ids) for i in p3D_ids]) 95 | 96 | images_ids = sorted(images.keys()) 97 | n_images = len(images) 98 | if n_images == 0: 99 | return None 100 | if verbose: 101 | logger.info(f'Found {n_images} images and {len(p3D_ids)} points.') 102 | 103 | intrinsics = [] 104 | poses = [] 105 | p3D_observed = [] 106 | image_names = [] 107 | too_small = [] 108 | for i in tqdm(images_ids, disable=not verbose): 109 | image = images[i] 110 | camera = cameras[image.camera_id] 111 | assert camera.model == 'PINHOLE', camera.model 112 | K = assemble_intrinsics(*camera.params) 113 | R = image.qvec2rotmat() 114 | t = image.tvec 115 | obs = np.stack([ 116 | p3D_id_to_idx[i] for i in image.point3D_ids if i != -1]) 117 | 118 | intrinsics.append(K) 119 | poses.append((R, t)) 120 | p3D_observed.append(obs) 121 | image_names.append(image.name) 122 | too_small.append(min(camera.height, camera.width) < 480) 123 | 124 | R_w_to_c = np.stack([R for R, _ in poses], 0) 125 | R_c_to_w = R_w_to_c.transpose(0, 2, 1) 126 | rotations = find_in_plane_rotations(R_c_to_w) 127 | for idx, rot in enumerate(rotations): 128 | if rot == 0: 129 | continue 130 | R_rot = in_plane_rotation_matrix(rot) 131 | R, t = poses[idx] 132 | poses[idx] = (R_rot@R, R_rot@t) 133 | 134 | image = images[images_ids[idx]] 135 | camera = cameras[image.camera_id] 136 | shape = (camera.height, camera.width) 137 | K = intrinsics[idx] 138 | intrinsics[idx] = rotate_intrinsics(K, shape, rot) 139 | 140 | R_w_to_c = np.stack([R for R, _ in poses], 0) 141 | R_c_to_w = R_w_to_c.transpose(0, 2, 1) 142 | camera_angles = get_camera_angles(R_c_to_w) 143 | 144 | p3D_observed_sets = [set(p) for p in p3D_observed] 145 | overlaps = np.full([n_images]*2, -1.) 
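    # overlaps[i, j]: fraction of the 3D points observed in image i that are
    # also observed in image j; -1 marks pairs with fewer than min_common
    # shared points or involving an image that is too small.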
146 | for idx1 in tqdm(range(n_images), disable=not verbose): 147 | for idx2 in range(idx1+1, n_images): 148 | if too_small[idx1] or too_small[idx2]: 149 | continue 150 | n_common = len(p3D_observed_sets[idx1] & p3D_observed_sets[idx2]) 151 | if n_common < min_common: 152 | continue 153 | overlaps[idx1, idx2] = n_common / len(p3D_observed[idx1]) 154 | overlaps[idx2, idx1] = n_common / len(p3D_observed[idx2]) 155 | 156 | data = { 157 | 'points3D': p3D_xyz, 158 | 'track_lengths': track_lengths, 159 | 'intrinsics': intrinsics, 160 | 'poses': poses, 161 | 'p3D_observed': p3D_observed, 162 | 'image_names': image_names, 163 | 'rotations': rotations, 164 | 'overlap_matrix': overlaps, 165 | 'angle_matrix': camera_angles, 166 | } 167 | return data 168 | 169 | 170 | def preprocess_and_write(scene, root, out_dir, **kwargs): 171 | path = out_dir / (scene + '.pkl') 172 | if path.exists(): 173 | return 174 | 175 | try: 176 | data = preprocess_scene(scene, root, **kwargs) 177 | except: # noqa E722 178 | logger.info(f'Error for scene {scene}.') 179 | raise 180 | if data is None: 181 | return 182 | 183 | logger.info(f'Writing scene {scene} to {path}.') 184 | with open(path, 'wb') as f: 185 | pickle.dump(data, f) 186 | 187 | 188 | if __name__ == '__main__': 189 | root = DATA_PATH / 'megadepth/Undistorted_SfM/' 190 | out_dir = DATA_PATH / 'megadepth_pixloc_training/' 191 | out_dir.mkdir(exist_ok=True) 192 | 193 | scenes = sorted([s.name for s in root.iterdir() if s.is_dir()]) 194 | logger.info(f'Found {len(scenes)} scenes.') 195 | 196 | fn = functools.partial(preprocess_and_write, root=root, out_dir=out_dir) 197 | with Pool(5) as p: 198 | p.map(fn, scenes) 199 | -------------------------------------------------------------------------------- /pixloc/pixlib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/pixloc/pixlib/utils/__init__.py -------------------------------------------------------------------------------- /pixloc/pixlib/utils/experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | A set of utilities to manage and load checkpoints of training experiments. 
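Typical usage (sketch; 'pixloc_megadepth' stands for any trained experiment):
    from pixloc.pixlib.utils.experiments import load_experiment
    model = load_experiment('pixloc_megadepth')  # best checkpoint by default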
3 | """ 4 | 5 | from pathlib import Path 6 | import logging 7 | import re 8 | from omegaconf import OmegaConf 9 | import torch 10 | import os 11 | 12 | from ...settings import TRAINING_PATH 13 | from ..models import get_model 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def list_checkpoints(dir_): 19 | """List all valid checkpoints in a given directory.""" 20 | checkpoints = [] 21 | for p in dir_.glob('checkpoint_*.tar'): 22 | numbers = re.findall(r'(\d+)', p.name) 23 | if len(numbers) == 0: 24 | continue 25 | assert len(numbers) == 1 26 | checkpoints.append((int(numbers[0]), p)) 27 | return checkpoints 28 | 29 | 30 | def get_last_checkpoint(exper, allow_interrupted=True): 31 | """Get the last saved checkpoint for a given experiment name.""" 32 | ckpts = list_checkpoints(Path(TRAINING_PATH, exper)) 33 | if not allow_interrupted: 34 | ckpts = [(n, p) for (n, p) in ckpts if '_interrupted' not in p.name] 35 | assert len(ckpts) > 0 36 | return sorted(ckpts)[-1][1] 37 | 38 | 39 | def get_best_checkpoint(exper): 40 | """Get the checkpoint with the best loss, for a given experiment name.""" 41 | p = Path(TRAINING_PATH, exper, 'checkpoint_best.tar') 42 | return p 43 | 44 | 45 | def delete_old_checkpoints(dir_, num_keep): 46 | """Delete all but the num_keep last saved checkpoints.""" 47 | ckpts = list_checkpoints(dir_) 48 | ckpts = sorted(ckpts)[::-1] 49 | kept = 0 50 | for ckpt in ckpts: 51 | if ('_interrupted' in str(ckpt[1]) and kept > 0) or kept >= num_keep: 52 | logger.info(f'Deleting checkpoint {ckpt[1].name}') 53 | ckpt[1].unlink() 54 | else: 55 | kept += 1 56 | 57 | 58 | def load_experiment(exper, conf={}, get_last=False): 59 | """Load and return the model of a given experiment.""" 60 | if get_last: 61 | ckpt = get_last_checkpoint(exper) 62 | else: 63 | ckpt = get_best_checkpoint(exper) 64 | logger.info(f'Loading checkpoint {ckpt.name}') 65 | ckpt = torch.load(str(ckpt), map_location='cpu') 66 | 67 | loaded_conf = OmegaConf.create(ckpt['conf']) 68 | OmegaConf.set_struct(loaded_conf, False) 69 | conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf)) 70 | model = get_model(conf.name)(conf).eval() 71 | 72 | state_dict = ckpt['model'] 73 | dict_params = set(state_dict.keys()) 74 | model_params = set(map(lambda n: n[0], model.named_parameters())) 75 | diff = model_params - dict_params 76 | if len(diff) > 0: 77 | subs = os.path.commonprefix(list(diff)).rstrip('.') 78 | logger.warning(f'Missing {len(diff)} parameters in {subs}') 79 | model.load_state_dict(state_dict, strict=False) 80 | return model 81 | 82 | 83 | def flexible_load(state_dict, model): 84 | """TODO: fix a probable nasty bug, and move to BaseModel.""" 85 | dict_params = set(state_dict.keys()) 86 | model_params = set(map(lambda n: n[0], model.named_parameters())) 87 | 88 | if dict_params == model_params: # prefect fit 89 | logger.info('Loading all parameters of the checkpoint.') 90 | model.load_state_dict(state_dict, strict=True) 91 | return 92 | elif len(dict_params & model_params) == 0: # perfect mismatch 93 | strip_prefix = lambda x: '.'.join(x.split('.')[:1]+x.split('.')[2:]) 94 | state_dict = {strip_prefix(n): p for n, p in state_dict.items()} 95 | dict_params = set(state_dict.keys()) 96 | if len(dict_params & model_params) == 0: 97 | raise ValueError('Could not manage to load the checkpoint with' 98 | 'parameters:' + '\n\t'.join(sorted(dict_params))) 99 | common_params = dict_params & model_params 100 | left_params = dict_params - model_params 101 | logger.info('Loading 
parameters:\n\t'+'\n\t'.join(sorted(common_params))) 102 | if len(left_params) > 0: 103 | logger.info('Could not load parameters:\n\t' 104 | + '\n\t'.join(sorted(left_params))) 105 | model.load_state_dict(state_dict, strict=False) 106 | -------------------------------------------------------------------------------- /pixloc/pixlib/utils/stdout_capturing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on sacred/stdout_capturing.py in project Sacred 3 | https://github.com/IDSIA/sacred 4 | """ 5 | 6 | from __future__ import division, print_function, unicode_literals 7 | import os 8 | import sys 9 | import subprocess 10 | from threading import Timer 11 | from contextlib import contextmanager 12 | 13 | 14 | def apply_backspaces_and_linefeeds(text): 15 | """ 16 | Interpret backspaces and linefeeds in text like a terminal would. 17 | Interpret text like a terminal by removing backspace and linefeed 18 | characters and applying them line by line. 19 | If final line ends with a carriage it keeps it to be concatenable with next 20 | output chunk. 21 | """ 22 | orig_lines = text.split('\n') 23 | orig_lines_len = len(orig_lines) 24 | new_lines = [] 25 | for orig_line_idx, orig_line in enumerate(orig_lines): 26 | chars, cursor = [], 0 27 | orig_line_len = len(orig_line) 28 | for orig_char_idx, orig_char in enumerate(orig_line): 29 | if orig_char == '\r' and (orig_char_idx != orig_line_len - 1 or 30 | orig_line_idx != orig_lines_len - 1): 31 | cursor = 0 32 | elif orig_char == '\b': 33 | cursor = max(0, cursor - 1) 34 | else: 35 | if (orig_char == '\r' and 36 | orig_char_idx == orig_line_len - 1 and 37 | orig_line_idx == orig_lines_len - 1): 38 | cursor = len(chars) 39 | if cursor == len(chars): 40 | chars.append(orig_char) 41 | else: 42 | chars[cursor] = orig_char 43 | cursor += 1 44 | new_lines.append(''.join(chars)) 45 | return '\n'.join(new_lines) 46 | 47 | 48 | def flush(): 49 | """Try to flush all stdio buffers, both from python and from C.""" 50 | try: 51 | sys.stdout.flush() 52 | sys.stderr.flush() 53 | except (AttributeError, ValueError, IOError): 54 | pass # unsupported 55 | 56 | 57 | # Duplicate stdout and stderr to a file. 
Inspired by: 58 | # http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ 59 | # http://stackoverflow.com/a/651718/1388435 60 | # http://stackoverflow.com/a/22434262/1388435 61 | @contextmanager 62 | def capture_outputs(filename): 63 | """Duplicate stdout and stderr to a file on the file descriptor level.""" 64 | with open(str(filename), 'a+') as target: 65 | original_stdout_fd = 1 66 | original_stderr_fd = 2 67 | target_fd = target.fileno() 68 | 69 | # Save a copy of the original stdout and stderr file descriptors 70 | saved_stdout_fd = os.dup(original_stdout_fd) 71 | saved_stderr_fd = os.dup(original_stderr_fd) 72 | 73 | tee_stdout = subprocess.Popen( 74 | ['tee', '-a', '-i', '/dev/stderr'], start_new_session=True, 75 | stdin=subprocess.PIPE, stderr=target_fd, stdout=1) 76 | tee_stderr = subprocess.Popen( 77 | ['tee', '-a', '-i', '/dev/stderr'], start_new_session=True, 78 | stdin=subprocess.PIPE, stderr=target_fd, stdout=2) 79 | 80 | flush() 81 | os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd) 82 | os.dup2(tee_stderr.stdin.fileno(), original_stderr_fd) 83 | 84 | try: 85 | yield 86 | finally: 87 | flush() 88 | 89 | # then redirect stdout back to the saved fd 90 | tee_stdout.stdin.close() 91 | tee_stderr.stdin.close() 92 | 93 | # restore original fds 94 | os.dup2(saved_stdout_fd, original_stdout_fd) 95 | os.dup2(saved_stderr_fd, original_stderr_fd) 96 | 97 | # wait for completion of the tee processes with timeout 98 | # implemented using a timer because timeout support is py3 only 99 | def kill_tees(): 100 | tee_stdout.kill() 101 | tee_stderr.kill() 102 | 103 | tee_timer = Timer(1, kill_tees) 104 | try: 105 | tee_timer.start() 106 | tee_stdout.wait() 107 | tee_stderr.wait() 108 | finally: 109 | tee_timer.cancel() 110 | 111 | os.close(saved_stdout_fd) 112 | os.close(saved_stderr_fd) 113 | 114 | # Cleanup log file 115 | with open(str(filename), 'r') as target: 116 | text = target.read() 117 | text = apply_backspaces_and_linefeeds(text) 118 | with open(str(filename), 'w') as target: 119 | target.write(text) 120 | -------------------------------------------------------------------------------- /pixloc/pixlib/utils/tensor.py: -------------------------------------------------------------------------------- 1 | from torch._six import string_classes 2 | import collections.abc as collections 3 | 4 | 5 | def map_tensor(input_, func): 6 | if isinstance(input_, string_classes): 7 | return input_ 8 | elif isinstance(input_, collections.Mapping): 9 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 10 | elif isinstance(input_, collections.Sequence): 11 | return [map_tensor(sample, func) for sample in input_] 12 | else: 13 | return func(input_) 14 | 15 | 16 | def batch_to_numpy(batch): 17 | return map_tensor(batch, lambda tensor: tensor.cpu().numpy()) 18 | 19 | 20 | def batch_to_device(batch, device, non_blocking=True): 21 | def _func(tensor): 22 | return tensor.to(device=device, non_blocking=non_blocking) 23 | 24 | return map_tensor(batch, _func) 25 | -------------------------------------------------------------------------------- /pixloc/pixlib/utils/tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various handy Python and PyTorch utils. 
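Example (sketch): seed the random generators locally without touching the
global state:
    with fork_rng(seed=0):
        noise = torch.randn(3)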
3 | """ 4 | 5 | import time 6 | import inspect 7 | import numpy as np 8 | import os 9 | import torch 10 | import random 11 | from contextlib import contextmanager 12 | 13 | 14 | class AverageMetric: 15 | def __init__(self): 16 | self._sum = 0 17 | self._num_examples = 0 18 | 19 | def update(self, tensor): 20 | assert tensor.dim() == 1 21 | tensor = tensor[~torch.isnan(tensor)] 22 | self._sum += tensor.sum().item() 23 | self._num_examples += len(tensor) 24 | 25 | def compute(self): 26 | if self._num_examples == 0: 27 | return np.nan 28 | else: 29 | return self._sum / self._num_examples 30 | 31 | 32 | class MedianMetric: 33 | def __init__(self): 34 | self._elements = [] 35 | 36 | def update(self, tensor): 37 | assert tensor.dim() == 1 38 | self._elements += tensor.cpu().numpy().tolist() 39 | 40 | def compute(self): 41 | if len(self._elements) == 0: 42 | return np.nan 43 | else: 44 | return np.nanmedian(self._elements) 45 | 46 | 47 | def get_class(mod_name, base_path, BaseClass): 48 | """Get the class object which inherits from BaseClass and is defined in 49 | the module named mod_name, child of base_path. 50 | """ 51 | mod_path = '{}.{}'.format(base_path, mod_name) 52 | mod = __import__(mod_path, fromlist=['']) 53 | classes = inspect.getmembers(mod, inspect.isclass) 54 | # Filter classes defined in the module 55 | classes = [c for c in classes if c[1].__module__ == mod_path] 56 | # Filter classes inherited from BaseModel 57 | classes = [c for c in classes if issubclass(c[1], BaseClass)] 58 | assert len(classes) == 1, classes 59 | return classes[0][1] 60 | 61 | 62 | class Timer(object): 63 | """A simpler timer context object. 64 | Usage: 65 | ``` 66 | > with Timer('mytimer'): 67 | > # some computations 68 | [mytimer] Elapsed: X 69 | ``` 70 | """ 71 | def __init__(self, name=None): 72 | self.name = name 73 | 74 | def __enter__(self): 75 | self.tstart = time.time() 76 | return self 77 | 78 | def __exit__(self, type, value, traceback): 79 | self.duration = time.time() - self.tstart 80 | if self.name is not None: 81 | print('[%s] Elapsed: %s' % (self.name, self.duration)) 82 | 83 | 84 | def set_num_threads(nt): 85 | """Force numpy and other libraries to use a limited number of threads.""" 86 | try: 87 | import mkl 88 | except ImportError: 89 | pass 90 | else: 91 | mkl.set_num_threads(nt) 92 | torch.set_num_threads(1) 93 | os.environ['IPC_ENABLE'] = '1' 94 | for o in ['OPENBLAS_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 95 | 'OMP_NUM_THREADS', 'MKL_NUM_THREADS']: 96 | os.environ[o] = str(nt) 97 | 98 | 99 | def set_seed(seed): 100 | random.seed(seed) 101 | torch.manual_seed(seed) 102 | np.random.seed(seed) 103 | if torch.cuda.is_available(): 104 | torch.cuda.manual_seed(seed) 105 | torch.cuda.manual_seed_all(seed) 106 | 107 | 108 | def get_random_state(): 109 | pth_state = torch.get_rng_state() 110 | np_state = np.random.get_state() 111 | py_state = random.getstate() 112 | if torch.cuda.is_available(): 113 | cuda_state = torch.cuda.get_rng_state_all() 114 | else: 115 | cuda_state = None 116 | return pth_state, np_state, py_state, cuda_state 117 | 118 | 119 | def set_random_state(state): 120 | pth_state, np_state, py_state, cuda_state = state 121 | torch.set_rng_state(pth_state) 122 | np.random.set_state(np_state) 123 | random.setstate(py_state) 124 | if (cuda_state is not None 125 | and torch.cuda.is_available() 126 | and len(cuda_state) == torch.cuda.device_count()): 127 | torch.cuda.set_rng_state_all(cuda_state) 128 | 129 | 130 | @contextmanager 131 | def fork_rng(seed=None): 132 | state = 
get_random_state() 133 | if seed is not None: 134 | set_seed(seed) 135 | try: 136 | yield 137 | finally: 138 | set_random_state(state) 139 | -------------------------------------------------------------------------------- /pixloc/run_7Scenes.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from . import set_logging_debug, logger 4 | from .localization import RetrievalLocalizer, PoseLocalizer 5 | from .utils.data import Paths, create_argparser, parse_paths, parse_conf 6 | from .utils.io import write_pose_results 7 | from .utils.eval import evaluate 8 | 9 | default_paths = Paths( 10 | query_images='{scene}/', 11 | reference_images='{scene}/', 12 | reference_sfm='{scene}/sfm_superpoint+superglue+depth/', 13 | query_list='{scene}/query_list_with_intrinsics.txt', 14 | retrieval_pairs='7scenes_densevlad_retrieval/{scene}_top10.txt', 15 | ground_truth='7scenes_sfm_triangulated/{scene}/triangulated/', 16 | results='pixloc_7scenes_{scene}.txt', 17 | ) 18 | 19 | experiment = 'pixloc_megadepth' 20 | 21 | default_confs = { 22 | 'from_retrieval': { 23 | 'experiment': experiment, 24 | 'features': {}, 25 | 'optimizer': { 26 | 'num_iters': 100, 27 | 'pad': 2, # to 1? 28 | }, 29 | 'refinement': { 30 | 'num_dbs': 5, 31 | 'multiscale': [4, 1], 32 | 'point_selection': 'all', 33 | 'normalize_descriptors': True, 34 | 'average_observations': False, 35 | 'filter_covisibility': False, 36 | 'do_pose_approximation': False, 37 | }, 38 | }, 39 | 'from_poses': { 40 | 'experiment': experiment, 41 | 'features': {}, 42 | 'optimizer': { 43 | 'num_iters': 100, 44 | 'pad': 2, 45 | }, 46 | 'refinement': { 47 | 'num_dbs': 5, 48 | 'min_points_opt': 100, 49 | 'point_selection': 'inliers', 50 | 'normalize_descriptors': True, 51 | 'average_observations': True, 52 | 'layer_indices': [0, 1], 53 | }, 54 | }, 55 | } 56 | 57 | SCENES = ['chess', 'fire', 'heads', 'office', 'pumpkin', 58 | 'redkitchen', 'stairs'] 59 | 60 | 61 | def main(): 62 | parser = create_argparser('7Scenes') 63 | parser.add_argument('--scenes', default=SCENES, choices=SCENES, nargs='+') 64 | parser.add_argument('--eval_only', action='store_true') 65 | args = parser.parse_args() 66 | 67 | set_logging_debug(args.verbose) 68 | paths = parse_paths(args, default_paths) 69 | conf = parse_conf(args, default_confs) 70 | 71 | all_poses = {} 72 | for scene in args.scenes: 73 | logger.info('Working on scene %s.', scene) 74 | paths_scene = paths.interpolate(scene=scene) 75 | if args.eval_only and paths_scene.results.exists(): 76 | all_poses[scene] = paths_scene.results 77 | continue 78 | 79 | if args.from_poses: 80 | localizer = PoseLocalizer(paths_scene, conf) 81 | else: 82 | localizer = RetrievalLocalizer(paths_scene, conf) 83 | poses, logs = localizer.run_batched(skip=args.skip) 84 | write_pose_results(poses, paths_scene.results, 85 | prepend_camera_name=True) 86 | with open(f'{paths_scene.results}_logs.pkl', 'wb') as f: 87 | pickle.dump(logs, f) 88 | all_poses[scene] = poses 89 | 90 | for scene in args.scenes: 91 | paths_scene = paths.interpolate(scene=scene) 92 | logger.info('Evaluate scene %s: %s', scene, paths_scene.results) 93 | evaluate(paths_scene.ground_truth, all_poses[scene], 94 | paths_scene.ground_truth / 'list_test.txt', 95 | only_localized=(args.skip is not None and args.skip > 1)) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /pixloc/run_Aachen.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from . import set_logging_debug 4 | from .localization import RetrievalLocalizer, PoseLocalizer 5 | from .utils.data import Paths, create_argparser, parse_paths, parse_conf 6 | from .utils.io import write_pose_results 7 | 8 | 9 | default_paths = Paths( 10 | query_images='images/images_upright/', 11 | reference_images='images/images_upright/', 12 | reference_sfm='sfm_superpoint+superglue/', 13 | query_list='*_time_queries_with_intrinsics.txt', 14 | global_descriptors='aachen_tf-netvlad.h5', 15 | retrieval_pairs='pairs-query-netvlad50.txt', 16 | results='pixloc_Aachen.txt', 17 | ) 18 | 19 | experiment = 'pixloc_megadepth' 20 | 21 | default_confs = { 22 | 'from_retrieval': { 23 | 'experiment': experiment, 24 | 'features': {}, 25 | 'optimizer': { 26 | 'num_iters': 150, 27 | 'pad': 1, 28 | }, 29 | 'refinement': { 30 | 'num_dbs': 3, 31 | 'multiscale': [4, 1], 32 | 'point_selection': 'all', 33 | 'normalize_descriptors': True, 34 | 'average_observations': False, 35 | 'do_pose_approximation': True, 36 | }, 37 | }, 38 | 'from_poses': { 39 | 'experiment': experiment, 40 | 'features': {'preprocessing': {'resize': 1600}}, 41 | 'optimizer': { 42 | 'num_iters': 50, 43 | 'pad': 1, 44 | }, 45 | 'refinement': { 46 | 'num_dbs': 5, 47 | 'min_points_opt': 100, 48 | 'point_selection': 'inliers', 49 | 'normalize_descriptors': True, 50 | 'average_observations': True, 51 | 'layer_indices': [0, 1], 52 | }, 53 | }, 54 | } 55 | 56 | 57 | def main(): 58 | parser = create_argparser('Aachen') 59 | args = parser.parse_args() 60 | set_logging_debug(args.verbose) 61 | paths = parse_paths(args, default_paths) 62 | conf = parse_conf(args, default_confs) 63 | 64 | if args.from_poses: 65 | localizer = PoseLocalizer(paths, conf) 66 | else: 67 | localizer = RetrievalLocalizer(paths, conf) 68 | poses, logs = localizer.run_batched(skip=args.skip) 69 | 70 | write_pose_results(poses, paths.results) 71 | with open(f'{paths.results}_logs.pkl', 'wb') as f: 72 | pickle.dump(logs, f) 73 | 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /pixloc/run_CMU.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from . 
import set_logging_debug, logger 4 | from .localization import RetrievalLocalizer, PoseLocalizer 5 | from .utils.data import Paths, create_argparser, parse_paths, parse_conf 6 | from .utils.io import write_pose_results, concat_results 7 | 8 | 9 | default_paths = Paths( 10 | query_images='slice{slice}/query/', 11 | reference_images='slice{slice}/database', 12 | reference_sfm='slice{slice}/sfm_superpoint+superglue/model/', 13 | query_list='slice{slice}/queries_with_intrinsics.txt', 14 | global_descriptors='slice{slice}/cmu-slice{slice}_tf-netvlad.h5', 15 | retrieval_pairs='slice{slice}/pairs-query-netvlad10.txt', 16 | hloc_logs='slice{slice}/CMU_hloc_superpoint+superglue_netvlad10.txt_logs.pkl', 17 | results='pixloc_CMU_slice{slice}.txt', 18 | ) 19 | 20 | experiment = 'pixloc_cmu' 21 | 22 | default_confs = { 23 | 'from_retrieval': { 24 | 'experiment': experiment, 25 | 'features': {}, 26 | 'optimizer': { 27 | 'num_iters': 100, 28 | 'pad': 2, 29 | }, 30 | 'refinement': { 31 | 'num_dbs': 2, 32 | 'point_selection': 'all', 33 | 'normalize_descriptors': True, 34 | 'average_observations': False, 35 | 'filter_covisibility': False, 36 | 'do_pose_approximation': False, 37 | }, 38 | }, 39 | 'from_poses': { 40 | 'experiment': experiment, 41 | 'features': {}, 42 | 'optimizer': { 43 | 'num_iters': 100, 44 | 'pad': 2, 45 | }, 46 | 'refinement': { 47 | 'num_dbs': 5, 48 | 'min_points_opt': 100, 49 | 'point_selection': 'inliers', 50 | 'normalize_descriptors': True, 51 | 'average_observations': False, 52 | 'layer_indices': [0, 1], 53 | }, 54 | }, 55 | } 56 | 57 | TEST_URBAN = [2, 3, 4, 5, 6] 58 | TEST_SUBURBAN = [13, 14, 15, 16, 17] 59 | TEST_PARK = [18, 19, 20, 21] 60 | TEST_SLICES_CMU = TEST_URBAN + TEST_SUBURBAN + TEST_PARK 61 | TRAINING_SLICES_CMU = [7, 8, 9, 10, 11, 12, 22, 23, 24, 25] 62 | 63 | 64 | def generate_query_list(paths, slice_): 65 | cameras = {} 66 | with open(paths.dataset / 'intrinsics.txt', 'r') as f: 67 | for line in f.readlines(): 68 | if line[0] == '#' or line == '\n': 69 | continue 70 | data = line.split() 71 | cameras[data[0]] = data[1:] 72 | assert len(cameras) == 2 73 | 74 | queries = paths.dataset / f'slice{slice_}/test-images-slice{slice_}.txt' 75 | with open(queries, 'r') as f: 76 | queries = [q.rstrip('\n') for q in f.readlines()] 77 | 78 | out = [[q] + cameras[q.split('_')[2]] for q in queries] 79 | with open(paths.query_list, 'w') as f: 80 | f.write('\n'.join(map(' '.join, out))) 81 | 82 | 83 | def parse_slice_arg(slice_str): 84 | if slice_str is None: 85 | slices = TEST_SLICES_CMU 86 | logger.info( 87 | 'No slice list given, will evaluate all %d test slices; ' 88 | 'this might take a long time.', len(slices)) 89 | elif '-' in slice_str: 90 | min_, max_ = slice_str.split('-') 91 | slices = list(range(int(min_), int(max_)+1)) 92 | else: 93 | slices = eval(slice_str) 94 | if isinstance(slices, int): 95 | slices = [slices] 96 | return slices 97 | 98 | 99 | def main(): 100 | parser = create_argparser('CMU') 101 | parser.add_argument('--slices', type=str, 102 | help='a single number, an interval (e.g. 2-6), ' 103 | 'or a Python-style list or int (e.g. 
[2, 3, 4]') 104 | args = parser.parse_args() 105 | 106 | set_logging_debug(args.verbose) 107 | paths = parse_paths(args, default_paths) 108 | conf = parse_conf(args, default_confs) 109 | slices = parse_slice_arg(args.slices) 110 | 111 | all_results = [] 112 | logger.info('Will evaluate slices %s.', slices) 113 | for slice_ in slices: 114 | logger.info('Working on slice %s.', slice_) 115 | paths_slice = paths.interpolate(slice=slice_) 116 | all_results.append(paths_slice.results) 117 | if paths_slice.results.exists(): 118 | continue 119 | if not paths_slice.query_list.exists(): 120 | generate_query_list(paths_slice, slice_) 121 | 122 | if args.from_poses: 123 | localizer = PoseLocalizer(paths_slice, conf) 124 | else: 125 | localizer = RetrievalLocalizer(paths_slice, conf) 126 | poses, logs = localizer.run_batched(skip=args.skip) 127 | write_pose_results(poses, paths_slice.results) 128 | with open(f'{paths_slice.results}_logs.pkl', 'wb') as f: 129 | pickle.dump(logs, f) 130 | 131 | output_path = concat_results(all_results, slices, paths.results, 'slice') 132 | logger.info( 133 | 'Finished evaluating all slices, you can now submit the file %s to ' 134 | 'https://www.visuallocalization.net/submission/', output_path) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /pixloc/run_Cambridge.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from . import set_logging_debug, logger 4 | from .localization import RetrievalLocalizer, PoseLocalizer 5 | from .utils.data import Paths, create_argparser, parse_paths, parse_conf 6 | from .utils.io import write_pose_results 7 | from .utils.eval import evaluate 8 | 9 | default_paths = Paths( 10 | query_images='{scene}/', 11 | reference_images='{scene}/', 12 | reference_sfm='{scene}/sfm_superpoint+superglue/', 13 | query_list='{scene}/query_list_with_intrinsics.txt', 14 | retrieval_pairs='{scene}/pairs-query-netvlad10.txt', 15 | ground_truth='CambridgeLandmarks_Colmap_Retriangulated_1024px/{scene}/', 16 | results='pixloc_Cambridge_{scene}.txt', 17 | ) 18 | 19 | experiment = 'pixloc_megadepth' 20 | 21 | default_confs = { 22 | 'from_retrieval': { 23 | 'experiment': experiment, 24 | 'features': {}, 25 | 'optimizer': { 26 | 'num_iters': 100, 27 | 'pad': 2, # to 1? 
28 | }, 29 | 'refinement': { 30 | 'num_dbs': 5, 31 | 'multiscale': [4, 1], 32 | 'point_selection': 'all', 33 | 'normalize_descriptors': True, 34 | 'average_observations': True, 35 | 'filter_covisibility': False, 36 | 'do_pose_approximation': False, 37 | }, 38 | }, 39 | 'from_poses': { 40 | 'experiment': experiment, 41 | 'features': {}, 42 | 'optimizer': { 43 | 'num_iters': 100, 44 | 'pad': 2, 45 | }, 46 | 'refinement': { 47 | 'num_dbs': 5, 48 | 'min_points_opt': 100, 49 | 'point_selection': 'inliers', 50 | 'normalize_descriptors': True, 51 | 'average_observations': True, 52 | 'layer_indices': [0, 1], 53 | }, 54 | }, 55 | } 56 | 57 | SCENES = ['ShopFacade', 'KingsCollege', 'GreatCourt', 'OldHospital', 58 | 'StMarysChurch'] 59 | 60 | 61 | def main(): 62 | parser = create_argparser('Cambridge') 63 | parser.add_argument('--scenes', default=SCENES, choices=SCENES, nargs='+') 64 | parser.add_argument('--eval_only', action='store_true') 65 | args = parser.parse_args() 66 | 67 | set_logging_debug(args.verbose) 68 | paths = parse_paths(args, default_paths) 69 | conf = parse_conf(args, default_confs) 70 | 71 | all_poses = {} 72 | for scene in args.scenes: 73 | logger.info('Working on scene %s.', scene) 74 | paths_scene = paths.interpolate(scene=scene) 75 | if args.eval_only and paths_scene.results.exists(): 76 | all_poses[scene] = paths_scene.results 77 | continue 78 | 79 | if args.from_poses: 80 | localizer = PoseLocalizer(paths_scene, conf) 81 | else: 82 | localizer = RetrievalLocalizer(paths_scene, conf) 83 | poses, logs = localizer.run_batched(skip=args.skip) 84 | write_pose_results(poses, paths_scene.results, 85 | prepend_camera_name=True) 86 | with open(f'{paths_scene.results}_logs.pkl', 'wb') as f: 87 | pickle.dump(logs, f) 88 | all_poses[scene] = poses 89 | 90 | for scene in args.scenes: 91 | paths_scene = paths.interpolate(scene=scene) 92 | logger.info('Evaluate scene %s: %s', scene, paths_scene.results) 93 | evaluate(paths_scene.ground_truth / 'empty_all', all_poses[scene], 94 | paths_scene.ground_truth / 'list_query.txt', 95 | only_localized=(args.skip is not None and args.skip > 1)) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /pixloc/run_RobotCar.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | from . 
import set_logging_debug, logger 5 | from .localization import RetrievalLocalizer, PoseLocalizer 6 | from .utils.data import Paths, create_argparser, parse_paths, parse_conf 7 | from .utils.io import write_pose_results, concat_results 8 | 9 | default_paths = Paths( 10 | query_images='images/', 11 | reference_images='images/', 12 | reference_sfm='sfm_superpoint+superglue/', 13 | query_list='{condition}_queries_with_intrinsics.txt', 14 | global_descriptors='robotcar_ov-ref_tf-netvlad.h5', 15 | retrieval_pairs='pairs-query-netvlad10-percam-perloc.txt', 16 | results='pixloc_RobotCar_{condition}.txt', 17 | ) 18 | 19 | experiment = 'pixloc_cmu' 20 | 21 | default_confs = { 22 | 'from_retrieval': { 23 | 'experiment': experiment, 24 | 'features': {}, 25 | 'optimizer': { 26 | 'num_iters': 100, 27 | 'pad': 2, 28 | }, 29 | 'refinement': { 30 | 'num_dbs': 2, 31 | 'point_selection': 'all', 32 | 'normalize_descriptors': True, 33 | 'average_observations': False, 34 | 'filter_covisibility': False, 35 | 'do_pose_approximation': False, 36 | }, 37 | }, 38 | 'from_poses': { 39 | 'experiment': experiment, 40 | 'features': {}, 41 | 'optimizer': { 42 | 'num_iters': 100, 43 | 'pad': 2, 44 | }, 45 | 'refinement': { 46 | 'num_dbs': 5, 47 | 'min_points_opt': 100, 48 | 'point_selection': 'inliers', 49 | 'normalize_descriptors': True, 50 | 'average_observations': False, 51 | 'layer_indices': [0, 1], 52 | }, 53 | }, 54 | } 55 | 56 | 57 | CONDITIONS = ['dawn', 'dusk', 'night', 'night-rain', 'overcast-summer', 58 | 'overcast-winter', 'rain', 'snow', 'sun'] 59 | 60 | 61 | def generate_query_list(paths, condition): 62 | h, w = 1024, 1024 63 | intrinsics_filename = 'intrinsics/{}_intrinsics.txt' 64 | cameras = {} 65 | for side in ['left', 'right', 'rear']: 66 | with open(paths.dataset / intrinsics_filename.format(side), 'r') as f: 67 | fx = f.readline().split()[1] 68 | fy = f.readline().split()[1] 69 | cx = f.readline().split()[1] 70 | cy = f.readline().split()[1] 71 | assert fx == fy 72 | params = ['SIMPLE_RADIAL', w, h, fx, cx, cy, 0.0] 73 | cameras[side] = [str(p) for p in params] 74 | 75 | queries = sorted((paths.query_images / condition).glob('**/*.jpg')) 76 | queries = [str(q.relative_to(paths.query_images)) for q in queries] 77 | 78 | out = [[q] + cameras[Path(q).parent.name] for q in queries] 79 | with open(paths.query_list, 'w') as f: 80 | f.write('\n'.join(map(' '.join, out))) 81 | 82 | 83 | def main(): 84 | parser = create_argparser('RobotCar') 85 | parser.add_argument('--conditions', default=CONDITIONS, choices=CONDITIONS, 86 | nargs='+') 87 | args = parser.parse_args() 88 | 89 | set_logging_debug(args.verbose) 90 | paths = parse_paths(args, default_paths) 91 | conf = parse_conf(args, default_confs) 92 | logger.info('Will evaluate %s conditions.', len(args.conditions)) 93 | 94 | all_results = [] 95 | for condition in args.conditions: 96 | logger.info('Working on condition %s.', condition) 97 | paths_cond = paths.interpolate(condition=condition) 98 | all_results.append(paths_cond.results) 99 | if paths_cond.results.exists(): 100 | continue 101 | if not paths_cond.query_list.exists(): 102 | generate_query_list(paths_cond, condition) 103 | 104 | if args.from_poses: 105 | localizer = PoseLocalizer(paths_cond, conf) 106 | else: 107 | localizer = RetrievalLocalizer(paths_cond, conf) 108 | poses, logs = localizer.run_batched(skip=args.skip) 109 | write_pose_results(poses, paths_cond.results, prepend_camera_name=True) 110 | with open(f'{paths_cond.results}_logs.pkl', 'wb') as f: 111 | pickle.dump(logs, f) 112 | 113 | 
output_path = concat_results( 114 | all_results, args.conditions, paths.results, 'condition') 115 | logger.info( 116 | 'Finished evaluating all conditions, you can now submit the file %s to' 117 | ' https://www.visuallocalization.net/submission/', output_path) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /pixloc/settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | root = Path(__file__).parent.parent # top-level directory 4 | DATA_PATH = root / 'datasets/' # datasets and pretrained weights 5 | TRAINING_PATH = root / 'outputs/training/' # training checkpoints 6 | LOC_PATH = root / 'outputs/hloc/' # localization logs 7 | EVAL_PATH = root / 'outputs/results/' # evaluation results 8 | -------------------------------------------------------------------------------- /pixloc/utils/data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | from pathlib import Path 4 | from typing import Dict, List, Optional 5 | from omegaconf import DictConfig, OmegaConf as oc 6 | 7 | from .. import settings, logger 8 | 9 | 10 | @dataclasses.dataclass 11 | class Paths: 12 | query_images: Path 13 | reference_images: Path 14 | reference_sfm: Path 15 | query_list: Path 16 | 17 | dataset: Optional[Path] = None 18 | dumps: Optional[Path] = None 19 | 20 | retrieval_pairs: Optional[Path] = None 21 | results: Optional[Path] = None 22 | global_descriptors: Optional[Path] = None 23 | hloc_logs: Optional[Path] = None 24 | ground_truth: Optional[Path] = None 25 | 26 | def interpolate(self, **kwargs) -> 'Paths': 27 | args = {} 28 | for f in dataclasses.fields(self): 29 | val = getattr(self, f.name) 30 | if val is not None: 31 | val = str(val) 32 | for k, v in kwargs.items(): 33 | val = val.replace(f'{{{k}}}', str(v)) 34 | val = Path(val) 35 | args[f.name] = val 36 | return self.__class__(**args) 37 | 38 | def asdict(self) -> Dict[str, Path]: 39 | return dataclasses.asdict(self) 40 | 41 | @classmethod 42 | def fields(cls) -> List[str]: 43 | return [f.name for f in dataclasses.fields(cls)] 44 | 45 | def add_prefixes(self, dataset: Path, dumps: Path, 46 | eval_dir: Optional[Path] = Path('.')) -> 'Paths': 47 | paths = {} 48 | for attr in self.fields(): 49 | val = getattr(self, attr) 50 | if val is not None: 51 | if attr in {'dataset', 'dumps'}: 52 | paths[attr] = val 53 | elif attr in {'query_images', 54 | 'reference_images', 55 | 'ground_truth'}: 56 | paths[attr] = dataset / val 57 | elif attr in {'results'}: 58 | paths[attr] = eval_dir / val 59 | else: # everything else is part of the hloc dumps 60 | paths[attr] = dumps / val 61 | paths['dataset'] = dataset 62 | paths['dumps'] = dumps 63 | return self.__class__(**paths) 64 | 65 | 66 | def create_argparser(dataset: str) -> argparse.ArgumentParser: 67 | parser = argparse.ArgumentParser( 68 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | 70 | parser.add_argument('--results', type=Path) 71 | parser.add_argument('--reference_sfm', type=Path) 72 | parser.add_argument('--retrieval', type=Path) 73 | parser.add_argument('--global_descriptors', type=Path) 74 | parser.add_argument('--hloc_logs', type=Path) 75 | 76 | parser.add_argument('--dataset', type=Path, 77 | default=settings.DATA_PATH / dataset) 78 | parser.add_argument('--dumps', type=Path, 79 | default=settings.LOC_PATH / dataset) 80 | 
parser.add_argument('--eval_dir', type=Path, 81 | default=settings.EVAL_PATH) 82 | 83 | parser.add_argument('--from_poses', action='store_true') 84 | parser.add_argument('--inlier_ranking', action='store_true') 85 | parser.add_argument('--skip', type=int) 86 | parser.add_argument('--verbose', action='store_true') 87 | parser.add_argument('dotlist', nargs='*') 88 | 89 | return parser 90 | 91 | 92 | def parse_paths(args, default_paths: Paths) -> Paths: 93 | default_paths = default_paths.add_prefixes( 94 | args.dataset, args.dumps, args.eval_dir) 95 | paths = {} 96 | for attr in Paths.fields(): 97 | val = getattr(args, attr, None) 98 | if val is None: 99 | val = getattr(default_paths, attr, None) 100 | if val is None: 101 | continue 102 | paths[attr] = val 103 | return Paths(**paths) 104 | 105 | 106 | def parse_conf(args, default_confs: Dict) -> DictConfig: 107 | conf = default_confs['from_poses' if args.from_poses else 'from_retrieval'] 108 | conf = oc.merge(oc.create(conf), oc.from_cli(args.dotlist)) 109 | logger.info('Parsed configuration:\n%s', oc.to_yaml(conf)) 110 | return conf 111 | -------------------------------------------------------------------------------- /pixloc/utils/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pathlib import Path 4 | from typing import Union, Dict, Tuple, Optional 5 | import numpy as np 6 | from .io import parse_image_list 7 | from .colmap import qvec2rotmat, read_images_binary, read_images_text 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def evaluate(gt_sfm_model: Path, predictions: Union[Dict, Path], 13 | test_file_list: Optional[Path] = None, 14 | only_localized: bool = False): 15 | """Compute the evaluation metrics for 7Scenes and Cambridge Landmarks. 16 | The other datasets are evaluated on visuallocalization.net 17 | """ 18 | if not isinstance(predictions, dict): 19 | predictions = parse_image_list(predictions, with_poses=True) 20 | predictions = {n: (im.qvec, im.tvec) for n, im in predictions} 21 | 22 | # ground truth poses from the sfm model 23 | images_bin = gt_sfm_model / 'images.bin' 24 | images_txt = gt_sfm_model / 'images.txt' 25 | if images_bin.exists(): 26 | images = read_images_binary(images_bin) 27 | elif images_txt.exists(): 28 | images = read_images_text(images_txt) 29 | else: 30 | raise ValueError(gt_sfm_model) 31 | name2id = {image.name: i for i, image in images.items()} 32 | 33 | if test_file_list is None: 34 | test_names = list(name2id) 35 | else: 36 | with open(test_file_list, 'r') as f: 37 | test_names = f.read().rstrip().split('\n') 38 | 39 | # translation and rotation errors 40 | errors_t = [] 41 | errors_R = [] 42 | for name in test_names: 43 | if name not in predictions: 44 | if only_localized: 45 | continue 46 | e_t = np.inf 47 | e_R = 180. 48 | else: 49 | image = images[name2id[name]] 50 | R_gt, t_gt = image.qvec2rotmat(), image.tvec 51 | qvec, t = predictions[name] 52 | R = qvec2rotmat(qvec) 53 | e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0) 54 | cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1., 1.) 
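            # angular error in degrees from the trace identity
            # cos(dR) = (trace(R_gt^T R) - 1) / 2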
55 | e_R = np.rad2deg(np.abs(np.arccos(cos))) 56 | errors_t.append(e_t) 57 | errors_R.append(e_R) 58 | 59 | errors_t = np.array(errors_t) 60 | errors_R = np.array(errors_R) 61 | med_t = np.median(errors_t) 62 | med_R = np.median(errors_R) 63 | out = f'\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg' 64 | 65 | out += '\nPercentage of test images localized within:' 66 | threshs_t = [0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5.0] 67 | threshs_R = [1.0, 2.0, 3.0, 5.0, 2.0, 5.0, 10.0] 68 | for th_t, th_R in zip(threshs_t, threshs_R): 69 | ratio = np.mean((errors_t < th_t) & (errors_R < th_R)) 70 | out += f'\n\t{th_t*100:.0f}cm, {th_R:.0f}deg : {ratio*100:.2f}%' 71 | logger.info(out) 72 | 73 | 74 | def cumulative_recall(errors: np.ndarray) -> Tuple[np.ndarray]: 75 | sort_idx = np.argsort(errors) 76 | errors = np.array(errors.copy())[sort_idx] 77 | recall = (np.arange(len(errors)) + 1) / len(errors) 78 | errors = np.r_[0., errors] 79 | recall = np.r_[0., recall] 80 | return errors, recall*100 81 | -------------------------------------------------------------------------------- /pixloc/utils/io.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Union, Any 3 | from pathlib import Path 4 | from collections import defaultdict 5 | import numpy as np 6 | import h5py 7 | 8 | from .colmap import Camera, Image 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def parse_image_list(path: Path, with_intrinsics: bool = False, 14 | with_poses: bool = False) -> List: 15 | images = [] 16 | with open(path, 'r') as f: 17 | for line in f: 18 | line = line.strip('\n') 19 | if len(line) == 0 or line[0] == '#': 20 | continue 21 | name, *data = line.split() 22 | if with_intrinsics: 23 | camera_model, width, height, *params = data 24 | params = np.array(params, float) 25 | camera = Camera( 26 | None, camera_model, int(width), int(height), params) 27 | images.append((name, camera)) 28 | elif with_poses: 29 | qvec, tvec = np.split(np.array(data, float), [4]) 30 | image = Image( 31 | id=None, qvec=qvec, tvec=tvec, camera_id=None, name=name, 32 | xys=None, point3D_ids=None) 33 | images.append((name, image)) 34 | else: 35 | images.append(name) 36 | 37 | logger.info(f'Imported {len(images)} images from {path.name}') 38 | return images 39 | 40 | 41 | def parse_image_lists(paths: Path, **kwargs) -> List: 42 | images = [] 43 | files = list(Path(paths.parent).glob(paths.name)) 44 | assert len(files) > 0, paths 45 | for lfile in files: 46 | images += parse_image_list(lfile, **kwargs) 47 | return images 48 | 49 | 50 | def parse_retrieval(path: Path) -> Dict[str, List[str]]: 51 | retrieval = defaultdict(list) 52 | with open(path, 'r') as f: 53 | for p in f.read().rstrip('\n').split('\n'): 54 | q, r = p.split() 55 | retrieval[q].append(r) 56 | return dict(retrieval) 57 | 58 | 59 | def load_hdf5(path: Path) -> Dict[str, Any]: 60 | with h5py.File(path, 'r') as hfile: 61 | data = {} 62 | def collect(_, obj): # noqa 63 | if isinstance(obj, h5py.Dataset): 64 | name = obj.parent.name.strip('/') 65 | data[name] = obj.__array__() 66 | hfile.visititems(collect) 67 | return data 68 | 69 | 70 | def write_pose_results(pose_dict: Dict, outfile: Path, 71 | prepend_camera_name: bool = False): 72 | logger.info('Writing the localization results to %s.', outfile) 73 | outfile.parent.mkdir(parents=True, exist_ok=True) 74 | with open(str(outfile), 'w') as f: 75 | for imgname, (qvec, tvec) in pose_dict.items(): 76 | qvec = ' '.join(map(str, qvec)) 77 | tvec = ' 
'.join(map(str, tvec)) 78 | name = imgname.split('/')[-1] 79 | if prepend_camera_name: 80 | name = imgname.split('/')[-2] + '/' + name 81 | f.write(f'{name} {qvec} {tvec}\n') 82 | 83 | 84 | def concat_results(paths: List[Path], names: List[Union[int, str]], 85 | output_path: Path, key: str) -> Path: 86 | results = [] 87 | for path in sorted(paths): 88 | with open(path, 'r') as fp: 89 | results.append(fp.read().rstrip('\n')) 90 | output_path = str(output_path).replace( 91 | f'{{{key}}}', '-'.join(str(n)[:3] for n in names)) 92 | with open(output_path, 'w') as fp: 93 | fp.write('\n'.join(results)) 94 | return Path(output_path) 95 | -------------------------------------------------------------------------------- /pixloc/utils/quaternions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def qvec2rotmat(qvec): 5 | return np.array([ 6 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 7 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 8 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 9 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 10 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 11 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 12 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 13 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 14 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 15 | 16 | 17 | def rotmat2qvec(R): 18 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 19 | K = np.array([ 20 | [Rxx - Ryy - Rzz, 0, 0, 0], 21 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 22 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 23 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 24 | eigvals, eigvecs = np.linalg.eigh(K) 25 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 26 | if qvec[0] < 0: 27 | qvec *= -1 28 | return qvec 29 | 30 | 31 | def weighted_qvecs(qvecs, weights): 32 | """Adapted from Tolga Birdal: 33 | https://github.com/tolgabirdal/averaging_quaternions/blob/master/wavg_quaternion_markley.m 34 | """ 35 | outer = np.einsum('ni,nj,n->ij', qvecs, qvecs, weights) 36 | avg = np.linalg.eigh(outer)[1][:, -1] # eigenvector of largest eigenvalue 37 | avg *= np.sign(avg[0]) 38 | return avg 39 | 40 | 41 | def weighted_pose(t_w2c, q_w2c, weights): 42 | weights = np.array(weights) 43 | R_w2c = np.stack([qvec2rotmat(q) for q in q_w2c], 0) 44 | 45 | t_c2w = -np.einsum('nij,ni->nj', R_w2c, np.array(t_w2c)) 46 | t_approx_c2w = np.sum(t_c2w * weights[:, None], 0) 47 | 48 | q_c2w = np.array(q_w2c) * np.array([[1, -1, -1, -1]]) # invert 49 | q_c2w *= np.sign(q_c2w[:, 0])[:, None] # handle antipodal 50 | q_approx_c2w = weighted_qvecs(q_c2w, weights) 51 | 52 | # convert back to camera coordinates 53 | R_approx = qvec2rotmat(q_approx_c2w).T 54 | t_approx = -R_approx @ t_approx_c2w 55 | 56 | return R_approx, t_approx 57 | -------------------------------------------------------------------------------- /pixloc/utils/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import functools 4 | 5 | 6 | def torchify(func): 7 | """Extends to NumPy arrays a function written for PyTorch tensors. 8 | 9 | Converts input arrays to tensors and output tensors back to arrays. 10 | Supports hybrid inputs where some are arrays and others are tensors: 11 | - in this case all tensors should have the same device and float dtype; 12 | - the output is not converted. 13 | 14 | No data copy: tensors and arrays share the same underlying storage. 
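    Example (sketch; `normalize` is a hypothetical tensor-only function):
        @torchify
        def normalize(x):
            return x / x.norm(dim=-1, keepdim=True)

        out = normalize(np.random.rand(10, 3))  # out is a NumPy array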
15 | 16 | Warning: kwargs are currently not supported when using jit. 17 | """ 18 | # TODO: switch to @torch.jit.unused when is_scripting will work 19 | @torch.jit.ignore 20 | @functools.wraps(func) 21 | def wrapped(*args, **kwargs): 22 | device = None 23 | dtype = None 24 | for arg in args: 25 | if isinstance(arg, torch.Tensor): 26 | device_ = arg.device 27 | if device is not None and device != device_: 28 | raise ValueError( 29 | 'Two input tensors have different devices: ' 30 | f'{device} and {device_}') 31 | device = device_ 32 | if torch.is_floating_point(arg): 33 | dtype_ = arg.dtype 34 | if dtype is not None and dtype != dtype_: 35 | raise ValueError( 36 | 'Two input tensors have different float dtypes: ' 37 | f'{dtype} and {dtype_}') 38 | dtype = dtype_ 39 | 40 | args_converted = [] 41 | for arg in args: 42 | if isinstance(arg, np.ndarray): 43 | arg = torch.from_numpy(arg).to(device) 44 | if torch.is_floating_point(arg): 45 | arg = arg.to(dtype) 46 | args_converted.append(arg) 47 | 48 | rets = func(*args_converted, **kwargs) 49 | 50 | def convert_back(ret): 51 | if isinstance(ret, torch.Tensor): 52 | if device is None: # no input was torch.Tensor 53 | ret = ret.cpu().numpy() 54 | return ret 55 | 56 | # TODO: handle nested struct with map tensor 57 | if not isinstance(rets, tuple): 58 | rets = convert_back(rets) 59 | else: 60 | rets = tuple(convert_back(ret) for ret in rets) 61 | return rets 62 | 63 | # BUG: is_scripting does not work in 1.6 so wrapped is always called 64 | if torch.jit.is_scripting(): 65 | return func 66 | else: 67 | return wrapped 68 | -------------------------------------------------------------------------------- /pixloc/visualization/animation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | import logging 4 | import shutil 5 | import json 6 | import io 7 | import base64 8 | import cv2 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from .viz_2d import save_plot 13 | from ..localization import Model3D 14 | from ..pixlib.geometry import Pose, Camera 15 | from ..utils.quaternions import rotmat2qvec 16 | 17 | logger = logging.getLogger(__name__) 18 | try: 19 | import ffmpeg 20 | except ImportError: 21 | logger.info('Cannot import ffmpeg.') 22 | 23 | 24 | def subsample_steps(T_w2q: Pose, p2d_q: np.ndarray, mask_q: np.ndarray, 25 | camera_size: np.ndarray, thresh_dt: float = 0.1, 26 | thresh_px: float = 0.005) -> List[int]: 27 | """Subsample steps of the optimization based on camera or point 28 | displacements. Main use case: compress an animation 29 | but keep it smooth and interesting. 
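    Example (sketch; `w, h` stand for the query image size in pixels):
        keep = subsample_steps(T_w2q, p2d_q, mask_q, np.array([w, h]))
        # `keep` lists the indices of the optimization steps to retain.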
30 | """ 31 | mask = mask_q.any(0) 32 | dp2ds = np.linalg.norm(np.diff(p2d_q, axis=0), axis=-1) 33 | dp2ds = np.median(dp2ds[:, mask], 1) 34 | dts = (T_w2q[:-1] @ T_w2q[1:].inv()).magnitude()[0].numpy() 35 | assert len(dts) == len(dp2ds) 36 | 37 | thresh_dp2 = camera_size.min()*thresh_px # from percent to pixel 38 | 39 | num = len(dp2ds) 40 | keep = [] 41 | count_dp2 = 0 42 | count_dt = 0 43 | for i, dp2 in enumerate(dp2ds): 44 | count_dp2 += dp2 45 | count_dt += dts[i] 46 | if (i == 0 or i == (num-1) 47 | or count_dp2 >= thresh_dp2 or count_dt >= thresh_dt): 48 | count_dp2 = 0 49 | count_dt = 0 50 | keep.append(i) 51 | return keep 52 | 53 | 54 | class VideoWriter: 55 | """Write frames sequentially as images, create a video, and clean up.""" 56 | def __init__(self, tmp_dir: Path, ext='.jpg'): 57 | self.tmp_dir = Path(tmp_dir) 58 | self.ext = ext 59 | self.count = 0 60 | if self.tmp_dir.exists(): 61 | shutil.rmtree(self.tmp_dir) 62 | self.tmp_dir.mkdir(parents=True) 63 | 64 | def add_frame(self): 65 | save_plot(self.tmp_dir / f'{self.count:0>5}{self.ext}') 66 | plt.close() 67 | self.count += 1 68 | 69 | def to_video(self, out_path: Path, duration: Optional[float] = None, 70 | fps: int = 5, crf: int = 23, verbose: bool = False): 71 | assert self.count > 0 72 | if duration is not None: 73 | fps = self.count / duration 74 | frames = self.tmp_dir / f'*{self.ext}' 75 | logger.info('Running ffmpeg.') 76 | ( 77 | ffmpeg 78 | .input(frames, pattern_type='glob', framerate=fps) 79 | .filter('crop', 'trunc(iw/2)*2', 'trunc(ih/2)*2') 80 | .output(out_path, crf=crf, vcodec='libx264', pix_fmt='yuv420p') 81 | .run(overwrite_output=True, quiet=not verbose) 82 | ) 83 | shutil.rmtree(self.tmp_dir) 84 | 85 | 86 | def display_video(path: Path): 87 | from IPython.display import HTML 88 | # prevent jupyter from caching the video file 89 | data = io.open(path, 'r+b').read() 90 | encoded = base64.b64encode(data).decode('ascii') 91 | return HTML(f""" 92 | 95 | """) 96 | 97 | 98 | def frustum_points(camera: Camera) -> np.ndarray: 99 | """Compute the corners of the frustum of a camera object.""" 100 | W, H = camera.size.numpy() 101 | corners = np.array([[0, 0], [W, 0], [W, H], [0, H], 102 | [0, 0], [W/2, -H/5], [W, 0]]) 103 | corners = (corners - camera.c.numpy()) / camera.f.numpy() 104 | return corners 105 | 106 | 107 | def copy_compress_image(source: Path, target: Path, quality: int = 50): 108 | """Read an image and write it to a low-quality jpeg.""" 109 | image = cv2.imread(str(source)) 110 | cv2.imwrite(str(target), image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) 111 | 112 | 113 | def format_json(x, decimals: int = 3): 114 | """Control the precision of numpy float arrays, convert boolean to int.""" 115 | if isinstance(x, np.ndarray): 116 | if np.issubdtype(x.dtype, np.floating): 117 | if x.shape != (4,): # qvec 118 | x = np.round(x, decimals=decimals) 119 | elif x.dtype == np.bool: 120 | x = x.astype(int) 121 | return x.tolist() 122 | if isinstance(x, float): 123 | return round(x, decimals) 124 | if isinstance(x, dict): 125 | return {k: format_json(v) for k, v in x.items()} 126 | if isinstance(x, (list, tuple)): 127 | return [format_json(v) for v in x] 128 | return x 129 | 130 | 131 | def create_viz_dump(assets: Path, paths: Path, cam_q: Camera, name_q: str, 132 | T_w2q: Pose, mask_q: np.ndarray, p2d_q: np.ndarray, 133 | ref_ids: List[int], model3d: Model3D, p3d_ids: np.ndarray, 134 | tfm: np.ndarray = np.eye(3)): 135 | assets.mkdir(parents=True, exist_ok=True) 136 | 137 | dump = { 138 | 'p3d': {}, 139 | 'T': 
{}, 140 | 'camera': {}, 141 | 'image': {}, 142 | 'p2d': {}, 143 | } 144 | 145 | p3d = np.stack([model3d.points3D[i].xyz for i in p3d_ids], 0) 146 | dump['p3d']['colors'] = [model3d.points3D[i].rgb for i in p3d_ids] 147 | dump['p3d']['xyz'] = p3d @ tfm.T 148 | 149 | dump['T']['refs'] = [] 150 | dump['camera']['refs'] = [] 151 | dump['image']['refs'] = [] 152 | dump['p2d']['refs'] = [] 153 | for idx, ref_id in enumerate(ref_ids): 154 | ref = model3d.dbs[ref_id] 155 | cam_r = Camera.from_colmap(model3d.cameras[ref.camera_id]) 156 | T_w2r = Pose.from_colmap(ref) 157 | 158 | qtvec = (rotmat2qvec(T_w2r.R.numpy() @ tfm.T), T_w2r.t.numpy()) 159 | dump['T']['refs'].append(qtvec) 160 | dump['camera']['refs'].append(frustum_points(cam_r)) 161 | 162 | tmp_name = f'ref{idx}.jpg' 163 | dump['image']['refs'].append(tmp_name) 164 | copy_compress_image( 165 | paths.reference_images / ref.name, assets / tmp_name) 166 | 167 | p2d_, valid_ = cam_r.world2image(T_w2r * p3d) 168 | p2d_ = p2d_[valid_ & mask_q.any(0)] / cam_r.size 169 | dump['p2d']['refs'].append(p2d_.numpy()) 170 | 171 | qtvec_q = [(rotmat2qvec(T.R.numpy() @ tfm.T), T.t.numpy()) for T in T_w2q] 172 | dump['T']['query'] = qtvec_q 173 | dump['camera']['query'] = frustum_points(cam_q) 174 | 175 | p2d_q_norm = [np.asarray(p[v]/cam_q.size) for p, v in zip(p2d_q, mask_q)] 176 | dump['p2d']['query'] = p2d_q_norm[-1] 177 | 178 | tmp_name = 'query.jpg' 179 | dump['image']['query'] = tmp_name 180 | copy_compress_image(paths.query_images / name_q, assets / tmp_name) 181 | 182 | with open(assets / 'dump.json', 'w') as fid: 183 | json.dump(format_json(dump), fid, separators=(',', ':')) 184 | 185 | # We dump 2D points as a separate json because it is much heavier 186 | # and thus slower to load. 187 | dump_p2d = { 188 | 'query': p2d_q_norm, 189 | 'masks': np.asarray(mask_q), 190 | } 191 | with open(assets / 'dump_p2d.json', 'w') as fid: 192 | json.dump(format_json(dump_p2d), fid, separators=(',', ':')) 193 | -------------------------------------------------------------------------------- /pixloc/visualization/viz_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2D visualization primitives based on Matplotlib. 3 | 4 | 1) Plot images with `plot_images`. 5 | 2) Call `plot_keypoints` or `plot_matches` any number of times. 6 | 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. 7 | """ 8 | 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | import matplotlib.patheffects as path_effects 12 | import numpy as np 13 | 14 | 15 | def cm_RdGn(x): 16 | """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" 17 | x = np.clip(x, 0, 1)[..., None]*2 18 | c = x*np.array([[0, 1., 0]]) + (2-x)*np.array([[1., 0, 0]]) 19 | return np.clip(c, 0, 1) 20 | 21 | 22 | def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, 23 | adaptive=True, autoscale=True): 24 | """Plot a set of images horizontally. 25 | Args: 26 | imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). 27 | titles: a list of strings, as titles for each image. 28 | cmaps: colormaps for monochrome images. 29 | adaptive: whether the figure size should fit the image aspect ratios. 
30 | """ 31 | n = len(imgs) 32 | if not isinstance(cmaps, (list, tuple)): 33 | cmaps = [cmaps] * n 34 | 35 | if adaptive: 36 | ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H 37 | else: 38 | ratios = [4/3] * n 39 | figsize = [sum(ratios)*4.5, 4.5] 40 | fig, ax = plt.subplots( 41 | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios}) 42 | if n == 1: 43 | ax = [ax] 44 | for i in range(n): 45 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) 46 | ax[i].get_yaxis().set_ticks([]) 47 | ax[i].get_xaxis().set_ticks([]) 48 | ax[i].set_axis_off() 49 | for spine in ax[i].spines.values(): # remove frame 50 | spine.set_visible(False) 51 | if titles: 52 | ax[i].set_title(titles[i]) 53 | if not autoscale: 54 | ax[i].autoscale(False) 55 | fig.tight_layout(pad=pad) 56 | 57 | 58 | def plot_keypoints(kpts, colors='lime', ps=6): 59 | """Plot keypoints for existing images. 60 | Args: 61 | kpts: list of ndarrays of size (N, 2). 62 | colors: string, or list of list of tuples (one for each keypoints). 63 | ps: size of the keypoints as float. 64 | """ 65 | if not isinstance(colors, list): 66 | colors = [colors] * len(kpts) 67 | axes = plt.gcf().axes 68 | for a, k, c in zip(axes, kpts, colors): 69 | if k is not None: 70 | a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) 71 | 72 | 73 | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): 74 | """Plot matches for a pair of existing images. 75 | Args: 76 | kpts0, kpts1: corresponding keypoints of size (N, 2). 77 | color: color of each match, string or RGB tuple. Random if not given. 78 | lw: width of the lines. 79 | ps: size of the end points (no endpoint if ps=0) 80 | indices: indices of the images to draw the matches on. 81 | a: alpha opacity of the match lines. 82 | """ 83 | fig = plt.gcf() 84 | ax = fig.axes 85 | assert len(ax) > max(indices) 86 | ax0, ax1 = ax[indices[0]], ax[indices[1]] 87 | fig.canvas.draw() 88 | 89 | assert len(kpts0) == len(kpts1) 90 | if color is None: 91 | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() 92 | elif len(color) > 0 and not isinstance(color[0], (tuple, list)): 93 | color = [color] * len(kpts0) 94 | 95 | if lw > 0: 96 | # transform the points into the figure coordinate system 97 | transFigure = fig.transFigure.inverted() 98 | fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) 99 | fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) 100 | fig.lines += [matplotlib.lines.Line2D( 101 | (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), 102 | zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, 103 | alpha=a) 104 | for i in range(len(kpts0))] 105 | 106 | # freeze the axes to prevent the transform to change 107 | ax0.autoscale(enable=False) 108 | ax1.autoscale(enable=False) 109 | 110 | if ps > 0: 111 | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 112 | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 113 | 114 | 115 | def add_text(idx, text, pos=(0.01, 0.99), fs=15, color='w', 116 | lcolor='k', lwidth=2): 117 | ax = plt.gcf().axes[idx] 118 | t = ax.text(*pos, text, fontsize=fs, va='top', ha='left', 119 | color=color, transform=ax.transAxes) 120 | if lcolor is not None: 121 | t.set_path_effects([ 122 | path_effects.Stroke(linewidth=lwidth, foreground=lcolor), 123 | path_effects.Normal()]) 124 | 125 | 126 | def save_plot(path, **kw): 127 | """Save the current figure without any white margin.""" 128 | plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw) 129 | 130 | 131 | def features_to_RGB(*Fs, skip=1): 
132 | """Project a list of d-dimensional feature maps to RGB colors using PCA.""" 133 | from sklearn.decomposition import PCA 134 | 135 | def normalize(x): 136 | return x / np.linalg.norm(x, axis=-1, keepdims=True) 137 | flatten = [] 138 | shapes = [] 139 | for F in Fs: 140 | c, h, w = F.shape 141 | F = np.rollaxis(F, 0, 3) 142 | F = F.reshape(-1, c) 143 | flatten.append(F) 144 | shapes.append((h, w)) 145 | flatten = np.concatenate(flatten, axis=0) 146 | 147 | pca = PCA(n_components=3) 148 | if skip > 1: 149 | pca.fit(normalize(flatten[::skip])) 150 | flatten = normalize(pca.transform(normalize(flatten))) 151 | else: 152 | flatten = normalize(pca.fit_transform(normalize(flatten))) 153 | flatten = (flatten + 1) / 2 154 | 155 | Fs = [] 156 | for h, w in shapes: 157 | F, flatten = np.split(flatten, [h*w], axis=0) 158 | F = F.reshape((h, w, 3)) 159 | Fs.append(F) 160 | assert flatten.shape[0] == 0 161 | return Fs 162 | -------------------------------------------------------------------------------- /pixloc/visualization/viz_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D visualization primitives based on Plotly. 3 | We might want to instead use a more powerful library like Open3D. 4 | Plotly however supports animations, buttons and sliders. 5 | 6 | 1) Initialize a figure with `fig = init_figure()` 7 | 2) Plot points, cameras, lines, or create a slider animation. 8 | 3) Call `fig.show()` to render the figure. 9 | """ 10 | 11 | import plotly.graph_objects as go 12 | import numpy as np 13 | 14 | from ..pixlib.geometry.utils import to_homogeneous 15 | 16 | 17 | def init_figure(height=800): 18 | """Initialize a 3D figure.""" 19 | fig = go.Figure() 20 | fig.update_layout( 21 | height=height, 22 | scene_camera=dict( 23 | eye=dict(x=0., y=-.1, z=-2), up=dict(x=0, y=-1., z=0)), 24 | scene=dict( 25 | xaxis=dict(showbackground=False), 26 | yaxis=dict(showbackground=False), 27 | aspectmode='data', dragmode='orbit'), 28 | margin=dict(l=0, r=0, b=0, t=0, pad=0)) # noqa E741 29 | return fig 30 | 31 | 32 | def plot_points(fig, pts, color='rgba(255, 0, 0, 1)', ps=2): 33 | """Plot a set of 3D points.""" 34 | x, y, z = pts.T 35 | tr = go.Scatter3d( 36 | x=x, y=y, z=z, mode='markers', marker_size=ps, 37 | marker_color=color, marker_line_width=.2) 38 | fig.add_trace(tr) 39 | 40 | 41 | def plot_camera(fig, R, t, K, color='rgb(0, 0, 255)'): 42 | """Plot a camera as a cone with camera frustum.""" 43 | x, y, z = t 44 | u, v, w = R @ -np.array([0, 0, 1]) 45 | tr = go.Cone( 46 | x=[x], y=[y], z=[z], u=[u], v=[v], w=[w], anchor='tip', 47 | showscale=False, colorscale=[[0, color], [1, color]], 48 | sizemode='absolute') 49 | fig.add_trace(tr) 50 | 51 | W, H = K[0, 2]*2, K[1, 2]*2 52 | corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]]) 53 | corners = to_homogeneous(corners) @ np.linalg.inv(K).T 54 | corners = (corners/2) @ R.T + t 55 | x, y, z = corners.T 56 | tr = go.Scatter3d( 57 | x=x, y=y, z=z, line=dict(color='rgba(0, 0, 0, .5)'), 58 | marker=dict(size=0.0001), showlegend=False) 59 | fig.add_trace(tr) 60 | 61 | 62 | def create_slider_animation(fig, traces): 63 | """Create a slider that animates a list of traces (e.g. 
3D points).""" 64 | slider = {'steps': []} 65 | frames = [] 66 | fig.add_trace(traces[0]) 67 | idx = len(fig.data) - 1 68 | for i, tr in enumerate(traces): 69 | frames.append(go.Frame(name=str(i), traces=[idx], data=[tr])) 70 | step = {"args": [ 71 | [str(i)], 72 | {"frame": {"redraw": True}, 73 | "mode": "immediate"}], 74 | "label": i, 75 | "method": "animate"} 76 | slider['steps'].append(step) 77 | fig.frames = tuple(frames) 78 | fig.layout.sliders = (slider,) 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | torchvision>=0.8 3 | numpy 4 | opencv-python 5 | tqdm 6 | matplotlib 7 | scipy 8 | h5py 9 | omegaconf 10 | tensorboard 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup 3 | 4 | description = ['Training and evaluation of the CVPR 2021 paper Back to the Feature'] 5 | 6 | with open(str(Path(__file__).parent / 'README.md'), 'r', encoding='utf-8') as f: 7 | readme = f.read() 8 | 9 | with open(str(Path(__file__).parent / 'requirements.txt'), 'r') as f: 10 | dependencies = f.read().split('\n') 11 | 12 | extra_dependencies = ['jupyter', 'scikit-learn', 'ffmpeg-python', 'kornia'] 13 | 14 | setup( 15 | name='pixloc', 16 | version='1.0', 17 | packages=['pixloc'], 18 | python_requires='>=3.6', 19 | install_requires=dependencies, 20 | extras_require={'extra': extra_dependencies}, 21 | author='Paul-Edouard Sarlin', 22 | description=description, 23 | long_description=readme, 24 | long_description_content_type="text/markdown", 25 | url='https://github.com/cvg/pixloc/', 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: Apache Software License", 29 | "Operating System :: OS Independent", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /viewer/disc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/viewer/disc.png -------------------------------------------------------------------------------- /viewer/dumps/sample/query.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/viewer/dumps/sample/query.jpg -------------------------------------------------------------------------------- /viewer/dumps/sample/ref0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/viewer/dumps/sample/ref0.jpg -------------------------------------------------------------------------------- /viewer/dumps/sample/ref1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/viewer/dumps/sample/ref1.jpg -------------------------------------------------------------------------------- /viewer/dumps/sample/ref2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/pixloc/65a51a7300a55d0b933dd13b6d1d7c1e6ef775d5/viewer/dumps/sample/ref2.jpg 
-------------------------------------------------------------------------------- /viewer/jsm/lib3d.js: -------------------------------------------------------------------------------- 1 | import * as THREE from './three.module.js'; 2 | import { Line2 } from './lines/Line2.js'; 3 | import { LineMaterial } from './lines/LineMaterial.js'; 4 | import { LineGeometry } from './lines/LineGeometry.js'; 5 | 6 | function pcdFromArrays(xyz, colors, point_size, onLoad = undefined) { 7 | var pcd = new THREE.BufferGeometry(); 8 | var xyz_ = [].concat.apply([], xyz); 9 | var colors_ = [].concat.apply([], colors).map(x => x / 255.0); 10 | pcd.setAttribute('position', new THREE.Float32BufferAttribute(xyz_, 3)); 11 | pcd.setAttribute('color', new THREE.Float32BufferAttribute(colors_, 3)); 12 | var material = new THREE.PointsMaterial({ 13 | size: point_size, 14 | vertexColors: THREE.VertexColors, 15 | sizeAttenuation: true, 16 | alphaTest: 0.5, 17 | transparent: true 18 | }); 19 | new THREE.TextureLoader().load('./disc.png', function (texture) { 20 | material.map = texture; 21 | material.needsUpdate = true; 22 | if (onLoad != undefined) 23 | onLoad(); 24 | }); 25 | return new THREE.Points(pcd, material); 26 | } 27 | 28 | function parsePose(tuple) { 29 | var q = tuple[0]; 30 | var tvec = tuple[1]; 31 | var T = new THREE.Matrix4(); 32 | T.makeRotationFromQuaternion(new THREE.Quaternion(q[1], q[2], q[3], q[0])); 33 | T.setPosition(tvec[0], tvec[1], tvec[2]); 34 | return T.invert(); 35 | } 36 | 37 | function setTransform(obj, T) { 38 | obj.position.setFromMatrixPosition(T); 39 | obj.quaternion.setFromRotationMatrix(T); 40 | } 41 | 42 | function drawCamera(T, corners, image_path, scale = 1., color = 0x000000, light = false) { 43 | var camera = new THREE.Group(); 44 | camera.add(drawFrustum(corners, scale, color, light)); 45 | if (image_path != undefined) 46 | camera.add(drawImagePlane(corners, image_path, scale)); 47 | setTransform(camera, T); 48 | return camera; 49 | } 50 | 51 | function drawFrustum(corners, scale, color = 0x000000, light = false) { 52 | var c = new THREE.Vector3(); 53 | var p3d = corners.map(x => new THREE.Vector3(x[0], x[1], 1.)); 54 | p3d.push(c, p3d[3], p3d[0], c, p3d[2]); 55 | if (light) { 56 | var frustum = new THREE.Line( 57 | new THREE.BufferGeometry().setFromPoints(p3d), 58 | new THREE.LineBasicMaterial({color: color, linewidth: 10})); 59 | } else { 60 | const geometry = new LineGeometry(); 61 | const positions = []; 62 | p3d.forEach(p => positions.push(p.x, p.y, p.z)); 63 | geometry.setPositions(positions); 64 | frustum = new Line2( 65 | geometry, new LineMaterial({ 66 | color: color, 67 | linewidth: 0.002, 68 | alphaToCoverage: true})); 69 | } 70 | frustum.geometry.scale(scale, scale, scale); 71 | return frustum; 72 | } 73 | 74 | function drawImagePlane(corners, image_path, scale) { 75 | var img = new THREE.MeshBasicMaterial({ 76 | map: new THREE.TextureLoader().load(image_path), 77 | side: THREE.DoubleSide, 78 | }); 79 | img.map.needsUpdate = true; 80 | var [w, h] = corners[2].map((x, i) => (x-corners[0][i])*scale); 81 | var plane = new THREE.Mesh(new THREE.PlaneGeometry(w, h), img); 82 | plane.overdraw = true; 83 | plane.rotateX(THREE.MathUtils.degToRad(180.)); 84 | plane.translateZ(-scale); 85 | return plane; 86 | } 87 | 88 | function drawTrajectoryLine(T, T_prev, radius=0.02, color = 0x00ff00) { 89 | const start = new THREE.Vector3().setFromMatrixPosition(T); 90 | const end = new THREE.Vector3().setFromMatrixPosition(T_prev); 91 | const geometry = new LineGeometry(); 92 | 
geometry.setPositions([start.x, start.y, start.z, end.x, end.y, end.z]); 93 | var line = new Line2( 94 | geometry, 95 | new LineMaterial({color: color, linewidth: 0.002, alphaToCoverage: true})); 96 | return line; 97 | } 98 | 99 | function drawRays(xyz, mask, position) { 100 | var skip = Math.max(Math.floor(xyz.length/40), 1); 101 | var ray_p3d = []; 102 | for (let i = 0; i < xyz.length; i += skip) { 103 | if ((mask[i] == 1) || (mask[i] === true)) { 104 | ray_p3d.push(new THREE.Vector3().fromArray(xyz[i])); 105 | ray_p3d.push(position); 106 | } 107 | } 108 | var rays = new THREE.LineSegments( 109 | new THREE.BufferGeometry().setFromPoints(ray_p3d), 110 | new THREE.LineBasicMaterial({color: 0xff0000, transparent: true, opacity: 0.5})); 111 | return rays; 112 | } 113 | 114 | export { parsePose, setTransform, drawCamera, drawTrajectoryLine, drawRays, pcdFromArrays}; 115 | -------------------------------------------------------------------------------- /viewer/jsm/lines/Line2.js: -------------------------------------------------------------------------------- 1 | import { LineSegments2 } from '../lines/LineSegments2.js'; 2 | import { LineGeometry } from '../lines/LineGeometry.js'; 3 | import { LineMaterial } from '../lines/LineMaterial.js'; 4 | 5 | class Line2 extends LineSegments2 { 6 | 7 | constructor( geometry = new LineGeometry(), material = new LineMaterial( { color: Math.random() * 0xffffff } ) ) { 8 | 9 | super( geometry, material ); 10 | 11 | this.type = 'Line2'; 12 | 13 | } 14 | 15 | } 16 | 17 | Line2.prototype.isLine2 = true; 18 | 19 | export { Line2 }; 20 | -------------------------------------------------------------------------------- /viewer/jsm/lines/LineGeometry.js: -------------------------------------------------------------------------------- 1 | import { LineSegmentsGeometry } from '../lines/LineSegmentsGeometry.js'; 2 | 3 | class LineGeometry extends LineSegmentsGeometry { 4 | 5 | constructor() { 6 | 7 | super(); 8 | this.type = 'LineGeometry'; 9 | 10 | } 11 | 12 | setPositions( array ) { 13 | 14 | // converts [ x1, y1, z1, x2, y2, z2, ... ] to pairs format 15 | 16 | var length = array.length - 3; 17 | var points = new Float32Array( 2 * length ); 18 | 19 | for ( var i = 0; i < length; i += 3 ) { 20 | 21 | points[ 2 * i ] = array[ i ]; 22 | points[ 2 * i + 1 ] = array[ i + 1 ]; 23 | points[ 2 * i + 2 ] = array[ i + 2 ]; 24 | 25 | points[ 2 * i + 3 ] = array[ i + 3 ]; 26 | points[ 2 * i + 4 ] = array[ i + 4 ]; 27 | points[ 2 * i + 5 ] = array[ i + 5 ]; 28 | 29 | } 30 | 31 | super.setPositions( points ); 32 | 33 | return this; 34 | 35 | } 36 | 37 | setColors( array ) { 38 | 39 | // converts [ r1, g1, b1, r2, g2, b2, ... ] to pairs format 40 | 41 | var length = array.length - 3; 42 | var colors = new Float32Array( 2 * length ); 43 | 44 | for ( var i = 0; i < length; i += 3 ) { 45 | 46 | colors[ 2 * i ] = array[ i ]; 47 | colors[ 2 * i + 1 ] = array[ i + 1 ]; 48 | colors[ 2 * i + 2 ] = array[ i + 2 ]; 49 | 50 | colors[ 2 * i + 3 ] = array[ i + 3 ]; 51 | colors[ 2 * i + 4 ] = array[ i + 4 ]; 52 | colors[ 2 * i + 5 ] = array[ i + 5 ]; 53 | 54 | } 55 | 56 | super.setColors( colors ); 57 | 58 | return this; 59 | 60 | } 61 | 62 | fromLine( line ) { 63 | 64 | var geometry = line.geometry; 65 | 66 | if ( geometry.isGeometry ) { 67 | 68 | console.error( 'THREE.LineGeometry no longer supports Geometry. Use THREE.BufferGeometry instead.' 
); 69 | return; 70 | 71 | } else if ( geometry.isBufferGeometry ) { 72 | 73 | this.setPositions( geometry.attributes.position.array ); // assumes non-indexed 74 | 75 | } 76 | 77 | // set colors, maybe 78 | 79 | return this; 80 | 81 | } 82 | 83 | copy( /* source */ ) { 84 | 85 | // todo 86 | 87 | return this; 88 | 89 | } 90 | 91 | } 92 | 93 | LineGeometry.prototype.isLineGeometry = true; 94 | 95 | export { LineGeometry }; 96 | -------------------------------------------------------------------------------- /viewer/jsm/lines/LineSegmentsGeometry.js: -------------------------------------------------------------------------------- 1 | import { 2 | Box3, 3 | Float32BufferAttribute, 4 | InstancedBufferGeometry, 5 | InstancedInterleavedBuffer, 6 | InterleavedBufferAttribute, 7 | Sphere, 8 | Vector3, 9 | WireframeGeometry 10 | } from '../three.module.js'; 11 | 12 | const _box = new Box3(); 13 | const _vector = new Vector3(); 14 | 15 | class LineSegmentsGeometry extends InstancedBufferGeometry { 16 | 17 | constructor() { 18 | 19 | super(); 20 | 21 | this.type = 'LineSegmentsGeometry'; 22 | 23 | const positions = [ - 1, 2, 0, 1, 2, 0, - 1, 1, 0, 1, 1, 0, - 1, 0, 0, 1, 0, 0, - 1, - 1, 0, 1, - 1, 0 ]; 24 | const uvs = [ - 1, 2, 1, 2, - 1, 1, 1, 1, - 1, - 1, 1, - 1, - 1, - 2, 1, - 2 ]; 25 | const index = [ 0, 2, 1, 2, 3, 1, 2, 4, 3, 4, 5, 3, 4, 6, 5, 6, 7, 5 ]; 26 | 27 | this.setIndex( index ); 28 | this.setAttribute( 'position', new Float32BufferAttribute( positions, 3 ) ); 29 | this.setAttribute( 'uv', new Float32BufferAttribute( uvs, 2 ) ); 30 | 31 | } 32 | 33 | applyMatrix4( matrix ) { 34 | 35 | const start = this.attributes.instanceStart; 36 | const end = this.attributes.instanceEnd; 37 | 38 | if ( start !== undefined ) { 39 | 40 | start.applyMatrix4( matrix ); 41 | 42 | end.applyMatrix4( matrix ); 43 | 44 | start.needsUpdate = true; 45 | 46 | } 47 | 48 | if ( this.boundingBox !== null ) { 49 | 50 | this.computeBoundingBox(); 51 | 52 | } 53 | 54 | if ( this.boundingSphere !== null ) { 55 | 56 | this.computeBoundingSphere(); 57 | 58 | } 59 | 60 | return this; 61 | 62 | } 63 | 64 | setPositions( array ) { 65 | 66 | let lineSegments; 67 | 68 | if ( array instanceof Float32Array ) { 69 | 70 | lineSegments = array; 71 | 72 | } else if ( Array.isArray( array ) ) { 73 | 74 | lineSegments = new Float32Array( array ); 75 | 76 | } 77 | 78 | const instanceBuffer = new InstancedInterleavedBuffer( lineSegments, 6, 1 ); // xyz, xyz 79 | 80 | this.setAttribute( 'instanceStart', new InterleavedBufferAttribute( instanceBuffer, 3, 0 ) ); // xyz 81 | this.setAttribute( 'instanceEnd', new InterleavedBufferAttribute( instanceBuffer, 3, 3 ) ); // xyz 82 | 83 | // 84 | 85 | this.computeBoundingBox(); 86 | this.computeBoundingSphere(); 87 | 88 | return this; 89 | 90 | } 91 | 92 | setColors( array ) { 93 | 94 | let colors; 95 | 96 | if ( array instanceof Float32Array ) { 97 | 98 | colors = array; 99 | 100 | } else if ( Array.isArray( array ) ) { 101 | 102 | colors = new Float32Array( array ); 103 | 104 | } 105 | 106 | const instanceColorBuffer = new InstancedInterleavedBuffer( colors, 6, 1 ); // rgb, rgb 107 | 108 | this.setAttribute( 'instanceColorStart', new InterleavedBufferAttribute( instanceColorBuffer, 3, 0 ) ); // rgb 109 | this.setAttribute( 'instanceColorEnd', new InterleavedBufferAttribute( instanceColorBuffer, 3, 3 ) ); // rgb 110 | 111 | return this; 112 | 113 | } 114 | 115 | fromWireframeGeometry( geometry ) { 116 | 117 | this.setPositions( geometry.attributes.position.array ); 118 | 119 | return this; 120 | 
121 | } 122 | 123 | fromEdgesGeometry( geometry ) { 124 | 125 | this.setPositions( geometry.attributes.position.array ); 126 | 127 | return this; 128 | 129 | } 130 | 131 | fromMesh( mesh ) { 132 | 133 | this.fromWireframeGeometry( new WireframeGeometry( mesh.geometry ) ); 134 | 135 | // set colors, maybe 136 | 137 | return this; 138 | 139 | } 140 | 141 | romLineSegments( lineSegments ) { 142 | 143 | const geometry = lineSegments.geometry; 144 | 145 | if ( geometry.isGeometry ) { 146 | 147 | console.error( 'THREE.LineSegmentsGeometry no longer supports Geometry. Use THREE.BufferGeometry instead.' ); 148 | return; 149 | 150 | } else if ( geometry.isBufferGeometry ) { 151 | 152 | this.setPositions( geometry.attributes.position.array ); // assumes non-indexed 153 | 154 | } 155 | 156 | // set colors, maybe 157 | 158 | return this; 159 | 160 | } 161 | 162 | computeBoundingBox() { 163 | 164 | if ( this.boundingBox === null ) { 165 | 166 | this.boundingBox = new Box3(); 167 | 168 | } 169 | 170 | const start = this.attributes.instanceStart; 171 | const end = this.attributes.instanceEnd; 172 | 173 | if ( start !== undefined && end !== undefined ) { 174 | 175 | this.boundingBox.setFromBufferAttribute( start ); 176 | 177 | _box.setFromBufferAttribute( end ); 178 | 179 | this.boundingBox.union( _box ); 180 | 181 | } 182 | 183 | } 184 | 185 | computeBoundingSphere() { 186 | 187 | if ( this.boundingSphere === null ) { 188 | 189 | this.boundingSphere = new Sphere(); 190 | 191 | } 192 | 193 | if ( this.boundingBox === null ) { 194 | 195 | this.computeBoundingBox(); 196 | 197 | } 198 | 199 | const start = this.attributes.instanceStart; 200 | const end = this.attributes.instanceEnd; 201 | 202 | if ( start !== undefined && end !== undefined ) { 203 | 204 | const center = this.boundingSphere.center; 205 | 206 | this.boundingBox.getCenter( center ); 207 | 208 | let maxRadiusSq = 0; 209 | 210 | for ( let i = 0, il = start.count; i < il; i ++ ) { 211 | 212 | _vector.fromBufferAttribute( start, i ); 213 | maxRadiusSq = Math.max( maxRadiusSq, center.distanceToSquared( _vector ) ); 214 | 215 | _vector.fromBufferAttribute( end, i ); 216 | maxRadiusSq = Math.max( maxRadiusSq, center.distanceToSquared( _vector ) ); 217 | 218 | } 219 | 220 | this.boundingSphere.radius = Math.sqrt( maxRadiusSq ); 221 | 222 | if ( isNaN( this.boundingSphere.radius ) ) { 223 | 224 | console.error( 'THREE.LineSegmentsGeometry.computeBoundingSphere(): Computed radius is NaN. The instanced position data is likely to have NaN values.', this ); 225 | 226 | } 227 | 228 | } 229 | 230 | } 231 | 232 | toJSON() { 233 | 234 | // todo 235 | 236 | } 237 | 238 | applyMatrix( matrix ) { 239 | 240 | console.warn( 'THREE.LineSegmentsGeometry: applyMatrix() has been renamed to applyMatrix4().' 
); 241 | 242 | return this.applyMatrix4( matrix ); 243 | 244 | } 245 | 246 | } 247 | 248 | LineSegmentsGeometry.prototype.isLineSegmentsGeometry = true; 249 | 250 | export { LineSegmentsGeometry }; 251 | -------------------------------------------------------------------------------- /viewer/jsm/lines/Wireframe.js: -------------------------------------------------------------------------------- 1 | import { 2 | InstancedInterleavedBuffer, 3 | InterleavedBufferAttribute, 4 | Mesh, 5 | Vector3 6 | } from '../../../build/three.module.js'; 7 | import { LineSegmentsGeometry } from '../lines/LineSegmentsGeometry.js'; 8 | import { LineMaterial } from '../lines/LineMaterial.js'; 9 | 10 | const _start = new Vector3(); 11 | const _end = new Vector3(); 12 | 13 | class Wireframe extends Mesh { 14 | 15 | constructor( geometry = new LineSegmentsGeometry(), material = new LineMaterial( { color: Math.random() * 0xffffff } ) ) { 16 | 17 | super( geometry, material ); 18 | 19 | this.type = 'Wireframe'; 20 | 21 | } 22 | 23 | // for backwards-compatability, but could be a method of LineSegmentsGeometry... 24 | 25 | computeLineDistances() { 26 | 27 | const geometry = this.geometry; 28 | 29 | const instanceStart = geometry.attributes.instanceStart; 30 | const instanceEnd = geometry.attributes.instanceEnd; 31 | const lineDistances = new Float32Array( 2 * instanceStart.count ); 32 | 33 | for ( let i = 0, j = 0, l = instanceStart.count; i < l; i ++, j += 2 ) { 34 | 35 | _start.fromBufferAttribute( instanceStart, i ); 36 | _end.fromBufferAttribute( instanceEnd, i ); 37 | 38 | lineDistances[ j ] = ( j === 0 ) ? 0 : lineDistances[ j - 1 ]; 39 | lineDistances[ j + 1 ] = lineDistances[ j ] + _start.distanceTo( _end ); 40 | 41 | } 42 | 43 | const instanceDistanceBuffer = new InstancedInterleavedBuffer( lineDistances, 2, 1 ); // d0, d1 44 | 45 | geometry.setAttribute( 'instanceDistanceStart', new InterleavedBufferAttribute( instanceDistanceBuffer, 1, 0 ) ); // d0 46 | geometry.setAttribute( 'instanceDistanceEnd', new InterleavedBufferAttribute( instanceDistanceBuffer, 1, 1 ) ); // d1 47 | 48 | return this; 49 | 50 | } 51 | 52 | } 53 | 54 | Wireframe.prototype.isWireframe = true; 55 | 56 | export { Wireframe }; 57 | -------------------------------------------------------------------------------- /viewer/jsm/lines/WireframeGeometry2.js: -------------------------------------------------------------------------------- 1 | import { 2 | WireframeGeometry 3 | } from '../../../build/three.module.js'; 4 | import { LineSegmentsGeometry } from '../lines/LineSegmentsGeometry.js'; 5 | 6 | class WireframeGeometry2 extends LineSegmentsGeometry { 7 | 8 | constructor( geometry ) { 9 | 10 | super(); 11 | 12 | this.type = 'WireframeGeometry2'; 13 | 14 | this.fromWireframeGeometry( new WireframeGeometry( geometry ) ); 15 | 16 | // set colors, maybe 17 | 18 | } 19 | 20 | } 21 | 22 | WireframeGeometry2.prototype.isWireframeGeometry2 = true; 23 | 24 | export { WireframeGeometry2 }; 25 | -------------------------------------------------------------------------------- /viewer/jupyter.html: -------------------------------------------------------------------------------- 1 | 2 | 14 |
15-38 | [markup lost; only the page's visible text survives]
        Mouse left/middle/right to orbit/zoom/pan
        Press + or - to increase or decrease the point size
        Press [ or ] to increase or decrease the frustum size
        Press i and o to initialize and step through the optimization
        Press p to play or clear the animation
        Press r to start or stop the auto rotation
        Press h to hide this help
        Could not load.
        Query
        Could not load.
        Reference
39 | 46 | -------------------------------------------------------------------------------- /viewer/server.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import http.server 4 | 5 | PORT = 8000 6 | 7 | 8 | class HttpRequestHandler(http.server.SimpleHTTPRequestHandler): 9 | extensions_map = { 10 | '': 'application/octet-stream', 11 | '.manifest': 'text/cache-manifest', 12 | '.html': 'text/html', 13 | '.png': 'image/png', 14 | '.jpg': 'image/jpg', 15 | '.svg': 'image/svg+xml', 16 | '.css': 'text/css', 17 | '.js': 'text/javascript', 18 | '.wasm': 'application/wasm', 19 | '.json': 'application/json', 20 | '.xml': 'application/xml', 21 | } 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--port', type=int, default=PORT, 26 | help='Server local port.') 27 | args = parser.parse_args() 28 | port = args.port 29 | 30 | httpd = http.server.HTTPServer(('localhost', port), HttpRequestHandler) 31 | 32 | try: 33 | relpath = Path(__file__).parent.resolve().relative_to(Path.cwd()) 34 | print(f'Open the viewer at http://localhost:{port}/{relpath}/viewer.html') 35 | httpd.serve_forever() 36 | except KeyboardInterrupt: 37 | pass 38 | -------------------------------------------------------------------------------- /viewer/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | } 4 | html, body { 5 | height: 100%; 6 | } 7 | 8 | .viewer-row { 9 | display: flex; 10 | width: 100%; 11 | height: 100%; 12 | } 13 | .column { 14 | padding: 0px; 15 | margin: 0px; 16 | } 17 | .pane3d { 18 | flex: 1; 19 | height: 100%; 20 | overflow: auto; 21 | position: relative; 22 | background-color: #cccccc; 23 | border-style: solid; 24 | border-color: #000000; 25 | border-width: 0.5px; 26 | box-sizing: border-box; 27 | } 28 | .pane2d { 29 | flex-shrink: 0 30 | vertical-align:top; 31 | max-width: 35%; 32 | width: 30%; 33 | height: 100%; 34 | max-height: 100%; 35 | } 36 | .threeCanvas { 37 | margin: 0px; 38 | height: 100%; 39 | width: 100%; 40 | } 41 | .threeCanvas > canvas:focus { 42 | outline: none; 43 | } 44 | #info { 45 | color: black; 46 | font-size: 15px; 47 | position: absolute; 48 | bottom: 0; 49 | left: 0; 50 | margin: 20px; 51 | user-select: none; 52 | pointer-events: none; 53 | z-index: 1; 54 | } 55 | .animContainer { 56 | position: relative; 57 | padding: 5px 0px 5px 10px; 58 | box-sizing: border-box; 59 | height: 50%; 60 | } 61 | .animContainer > canvas { 62 | background-color: #cccccc; 63 | position: absolute; 64 | margin: auto; 65 | display: block; 66 | } 67 | .animContainer > p { 68 | color: red; 69 | font-family: Arial; 70 | font-size: 20px; 71 | position: absolute; 72 | margin: 16px; 73 | user-select: none; 74 | pointer-events: none; 75 | } 76 | -------------------------------------------------------------------------------- /viewer/viewer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | PixLoc 2D-3D viewer 6 | 7 | 8 | 13 | 14 | 15 |
16-39 | [markup lost; only the page's visible text survives]
        Mouse left/middle/right to orbit/zoom/pan
        Press + or - to increase or decrease the point size
        Press [ or ] to increase or decrease the frustum size
        Press i and o to initialize and step through the optimization
        Press p to play or clear the animation
        Press r to start or stop the auto rotation
        Press h to hide this help
        Could not load.
        Query
        Could not load.
        Reference
40 | 44 | 45 | 46 | --------------------------------------------------------------------------------
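
To inspect a dump such as viewer/dumps/sample in the browser, the static server above can be launched from the repository root. A minimal sketch follows; the port is arbitrary, and this is simply the programmatic equivalent of running python viewer/server.py --port 8000 in a shell.

import subprocess
import sys

# Serve the repository over HTTP; server.py prints the URL to open,
# e.g. http://localhost:8000/viewer/viewer.html when started from the repository root.
subprocess.run([sys.executable, 'viewer/server.py', '--port', '8000'], check=True)
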