├── .gitignore
├── LICENSE
├── README.md
├── checkpoints-pretrained
│   └── config.json
├── data
│   ├── eval
│   │   └── .keep
│   └── train
│       └── .keep
├── env.yml
├── feature-utils
│   ├── convert_database_to_numpy.py
│   ├── extract_descriptors.py
│   └── extract_sift.py
├── lib
│   ├── database.py
│   ├── datasets.py
│   ├── losses.py
│   ├── networks.py
│   └── utils.py
├── local-feature-evaluation
│   ├── align_and_compare.py
│   ├── reconstruction_pipeline_embed.py
│   ├── reconstruction_pipeline_progressive.py
│   ├── reconstruction_pipeline_subset.py
│   └── utils.py
├── scripts
│   ├── download_checkpoints.sh
│   ├── download_evaluation_data.sh
│   ├── download_processed_training_data.sh
│   ├── download_training_data.sh
│   ├── process_LFE_data.sh
│   └── process_training_data.sh
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | data
3 | *.pth
4 | checkpoints
5 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2021, ETH Zurich.
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-Descriptor Visual Localization and Mapping
2 | 
3 | This repository contains the implementation of the following paper:
4 | 
5 | ```text
6 | "Cross-Descriptor Visual Localization and Mapping".
7 | M. Dusmanu, O. Miksik, J.L. Schönberger, and M. Pollefeys. ICCV 2021.
8 | ```
9 | 
10 | [[Paper on arXiv]](https://arxiv.org/abs/2012.01377)
11 | 
12 | 
13 | ## Requirements
14 | 
15 | ### COLMAP
16 | 
17 | We use COLMAP for DoG keypoint extraction as well as localization and mapping.
18 | Please follow the installation instructions available on the [official webpage](https://colmap.github.io).
19 | Before proceeding, we recommend setting an environment variable to the COLMAP executable folder by running `export COLMAP_PATH=path_to_colmap_executable_folder`.
20 | 
21 | ### Python
22 | 
23 | The environment can be set up directly using conda:
24 | ```
25 | conda env create -f env.yml
26 | conda activate cross-descriptor-vis-loc-map
27 | ```
28 | 
29 | ### Training data
30 | 
31 | We provide a script for downloading the raw training data:
32 | ```
33 | bash scripts/download_training_data.sh
34 | ```
35 | 
36 | ### Evaluation data
37 | 
38 | We provide a script for downloading the LFE dataset along with the ground truth used for evaluation, as well as the Aachen Day-Night dataset:
39 | ```
40 | bash scripts/download_evaluation_data.sh
41 | ```
42 | 
43 | 
44 | ## Training
45 | 
46 | ### Data preprocessing
47 | 
48 | The first step is to extract keypoints and descriptors from the training data downloaded above:
49 | ```
50 | bash scripts/process_training_data.sh
51 | ```
52 | Alternatively, you can directly download the processed training data by running:
53 | ```
54 | bash scripts/download_processed_training_data.sh
55 | ```
56 | 
57 | ### Training
58 | 
59 | To run training with the default architecture and hyper-parameters, execute the following:
60 | ```
61 | python train.py \
62 |     --dataset_path data/train/colmap \
63 |     --features brief sift-kornia hardnet sosnet
64 | ```
65 | 
66 | ### Pretrained models
67 | 
68 | We provide two pretrained models trained on descriptors extracted from COLMAP SIFT and OpenCV SIFT keypoints, respectively.
69 | These models can be downloaded by running:
70 | ```
71 | bash scripts/download_checkpoints.sh
72 | ```
73 | 
74 | 
75 | ## Evaluation
76 | 
77 | ### Demo Notebook
78 | 
79 | 
80 | 
81 | 
82 | 
83 | 84 | ### Local Feature Evaluation Benchmark 85 | 86 |
87 | 
88 | 
89 | The first step is to extract descriptors on all datasets:
90 | ```
91 | bash scripts/process_LFE_data.sh
92 | ```
93 | 
94 | We provide examples below for running reconstruction on Madrid Metropolis in each evaluation scenario.
95 | 
96 | #### Reconstruction using a single descriptor (standard)
97 | 
98 | ```
99 | python local-feature-evaluation/reconstruction_pipeline_progressive.py \
100 |     --dataset_path data/eval/LFE-release/Madrid_Metropolis \
101 |     --colmap_path $COLMAP_PATH \
102 |     --features sift-kornia \
103 |     --exp_name sift-kornia-single
104 | ```
105 | 
106 | #### Reconstruction using the progressive approach (ours)
107 | 
108 | ```
109 | python local-feature-evaluation/reconstruction_pipeline_progressive.py \
110 |     --dataset_path data/eval/LFE-release/Madrid_Metropolis \
111 |     --colmap_path $COLMAP_PATH \
112 |     --features brief sift-kornia hardnet sosnet \
113 |     --exp_name progressive
114 | ```
115 | 
116 | #### Reconstruction using the joint embedding approach (ours)
117 | 
118 | ```
119 | python local-feature-evaluation/reconstruction_pipeline_embed.py \
120 |     --dataset_path data/eval/LFE-release/Madrid_Metropolis \
121 |     --colmap_path $COLMAP_PATH \
122 |     --features brief sift-kornia hardnet sosnet \
123 |     --exp_name embed
124 | ```
125 | 
126 | #### Reconstruction using a single descriptor on the associated split (real-world)
127 | 
128 | ```
129 | python local-feature-evaluation/reconstruction_pipeline_subset.py \
130 |     --dataset_path data/eval/LFE-release/Madrid_Metropolis/ \
131 |     --colmap_path $COLMAP_PATH \
132 |     --features brief sift-kornia hardnet sosnet \
133 |     --feature sift-kornia \
134 |     --exp_name sift-kornia-subset
135 | ```
136 | 
137 | #### Evaluation of a reconstruction w.r.t. metric pseudo-ground-truth
138 | 
139 | ```
140 | python local-feature-evaluation/align_and_compare.py \
141 |     --colmap_path $COLMAP_PATH \
142 |     --reference_model_path data/eval/LFE-release/Madrid_Metropolis/sparse-reference/filtered-metric/ \
143 |     --model_path data/eval/LFE-release/Madrid_Metropolis/sparse-sift-kornia-single/0/
144 | ```
145 | 
146 | 
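The commands above can also be scripted across datasets. Below is a minimal sketch for running the standard and progressive scenarios back-to-back; the list of dataset folder names under `data/eval/LFE-release` is an assumption, so adjust it to whatever the download script fetched:
```
for dataset in Madrid_Metropolis Gendarmenmarkt Tower_of_London; do
    # Standard single-descriptor baseline.
    python local-feature-evaluation/reconstruction_pipeline_progressive.py \
        --dataset_path data/eval/LFE-release/$dataset \
        --colmap_path $COLMAP_PATH \
        --features sift-kornia \
        --exp_name sift-kornia-single
    # Progressive cross-descriptor reconstruction (ours).
    python local-feature-evaluation/reconstruction_pipeline_progressive.py \
        --dataset_path data/eval/LFE-release/$dataset \
        --colmap_path $COLMAP_PATH \
        --features brief sift-kornia hardnet sosnet \
        --exp_name progressive
done
```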
147 | 148 | ### Aachen Day-Night 149 | 150 |
151 | 
152 | 
153 | 
154 | 155 | ## BibTeX 156 | 157 | If you use this code in your project, please cite the following paper: 158 | ``` 159 | @InProceedings{Dusmanu2021Cross, 160 | author = {Dusmanu, Mihai and Miksik, Ondrej and Sch\"onberger, Johannes L. and Pollefeys, Marc}, 161 | title = {{Cross Descriptor Visual Localization and Mapping}}, 162 | booktitle = {Proceedings of the International Conference on Computer Vision}, 163 | year = {2021} 164 | } 165 | ``` 166 | -------------------------------------------------------------------------------- /checkpoints-pretrained/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_bn": true, 3 | "emb_dim": 128, 4 | "emb_l2_norm": true, 5 | "emb_last_activation": null, 6 | "brief": { 7 | "descriptor_dim": 512, 8 | "hidden_dims": [1024, 1024], 9 | "l2_norm": false, 10 | "last_activation": "sigmoid" 11 | }, 12 | "sift-kornia": { 13 | "descriptor_dim": 128, 14 | "hidden_dims": [1024, 1024], 15 | "l2_norm": true, 16 | "last_activation": "relu" 17 | }, 18 | "hardnet": { 19 | "descriptor_dim": 128, 20 | "hidden_dims": [256, 256], 21 | "l2_norm": true, 22 | "last_activation": null 23 | }, 24 | "sosnet": { 25 | "descriptor_dim": 128, 26 | "hidden_dims": [256, 256], 27 | "l2_norm": true, 28 | "last_activation": null 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /data/eval/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mihaidusmanu/cross-descriptor-vis-loc-map/2d3c7eb716706c65710480b0e82eee4b653ddabd/data/eval/.keep -------------------------------------------------------------------------------- /data/train/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mihaidusmanu/cross-descriptor-vis-loc-map/2d3c7eb716706c65710480b0e82eee4b653ddabd/data/train/.keep -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: cross-descriptor-vis-loc-map 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=10.2 8 | - pip 9 | - python=3.9 10 | - pytorch=1.9 11 | - torchaudio 12 | - torchvision 13 | - tqdm 14 | - pip: 15 | - extract-patches 16 | - kornia 17 | - opencv-python 18 | -------------------------------------------------------------------------------- /feature-utils/convert_database_to_numpy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import os 6 | 7 | import sqlite3 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( 16 | '--dataset_path', type=str, required=True, 17 | help='path to the dataset' 18 | ) 19 | parser.add_argument( 20 | '--feature', type=str, required=True, 21 | help='descriptor' 22 | ) 23 | 24 | args = parser.parse_args() 25 | 26 | database_path = os.path.join( 27 | args.dataset_path, '%s-features.db' % args.feature 28 | ) 29 | output_path = os.path.join( 30 | args.dataset_path, '%s-features.npy' % args.feature 31 | ) 32 | 33 | connection = sqlite3.connect(database_path) 34 | cursor = connection.cursor() 35 | 36 | all_descriptors = [] 37 | for (image_id,) in tqdm(cursor.execute('SELECT image_id FROM images').fetchall()): 38 | r, c, blob = cursor.execute('SELECT rows, 
cols, data FROM descriptors WHERE image_id=?', (image_id,)).fetchall()[0] 39 | try: 40 | descriptors = np.frombuffer(blob, dtype=np.float32).reshape(r, c) 41 | except ValueError: 42 | descriptors = np.frombuffer(blob, dtype=bool).reshape(r, c) 43 | all_descriptors.append(descriptors) 44 | all_descriptors = np.concatenate(all_descriptors, axis=0) 45 | 46 | # Random shuffle - not required. 47 | random = np.random.RandomState(seed=1) 48 | random.shuffle(all_descriptors) 49 | 50 | np.save(output_path, all_descriptors) 51 | 52 | cursor.close() 53 | connection.close() -------------------------------------------------------------------------------- /feature-utils/extract_descriptors.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/vcg-uvic/image-matching-benchmark-baselines/blob/master/extract_descriptors_hardnet.py. 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | import os 8 | 9 | import cv2 10 | 11 | import kornia 12 | 13 | import shutil 14 | 15 | import sqlite3 16 | 17 | import types 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | import torchvision.transforms as transforms 24 | 25 | import tqdm 26 | 27 | from extract_patches.core import extract_patches 28 | 29 | 30 | def get_transforms(): 31 | transform = transforms.Compose([ 32 | transforms.Lambda(lambda x: cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)), 33 | transforms.Lambda(lambda x: np.reshape(x, (32, 32, 1))), 34 | transforms.ToTensor(), 35 | ]) 36 | return transform 37 | 38 | 39 | class BRIEFDescriptor(nn.Module): 40 | # Adapted from https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.BRIEF. 41 | def __init__(self, desc_size=512, patch_size=32, seed=1): 42 | super(BRIEFDescriptor, self).__init__() 43 | 44 | # Sampling pattern. 45 | random = np.random.RandomState() 46 | random.seed(seed) 47 | samples = (patch_size / 5.0) * random.randn(desc_size * 8) 48 | samples = np.array(samples, dtype=np.int32) 49 | samples = samples[ 50 | (samples <= (patch_size // 2)) & (samples >= - (patch_size - 2) // 2) 51 | ] 52 | samples += (patch_size // 2 - 1) 53 | pos1 = samples[: desc_size * 2].reshape(desc_size, 2) 54 | pos2 = samples[desc_size * 2 : desc_size * 4].reshape(desc_size, 2) 55 | 56 | # Create tensors. 57 | self.pos1 = torch.from_numpy(pos1).long() 58 | self.pos2 = torch.from_numpy(pos2).long() 59 | 60 | def forward(self, patches): 61 | pixel_values1 = patches[:, 0, self.pos1[:, 0], self.pos1[:, 1]] 62 | pixel_values2 = patches[:, 0, self.pos2[:, 0], self.pos2[:, 1]] 63 | descriptors = (pixel_values1 < pixel_values2) 64 | return descriptors 65 | 66 | 67 | def recover_database_images_and_ids(database_path): 68 | # Connect to the database. 69 | connection = sqlite3.connect(database_path) 70 | cursor = connection.cursor() 71 | 72 | # Recover database images and ids. 73 | images = {} 74 | cursor.execute('SELECT name, image_id FROM images;') 75 | for row in cursor: 76 | images[row[1]] = row[0] 77 | 78 | # Close the connection to the database. 
79 |     cursor.close()
80 |     connection.close()
81 | 
82 |     return images
83 | 
84 | 
85 | if __name__ == '__main__':
86 |     parser = argparse.ArgumentParser()
87 | 
88 |     parser.add_argument(
89 |         '--dataset_path', type=str, required=True,
90 |         help='path to the dataset'
91 |     )
92 |     parser.add_argument(
93 |         '--image_path', type=str, default=None,
94 |         help='path to the images'
95 |     )
96 |     parser.add_argument(
97 |         '--feature', type=str, required=True,
98 |         choices=['brief', 'sift-kornia', 'hardnet', 'sosnet'],
99 |         help='descriptors to be extracted'
100 |     )
101 |     parser.add_argument(
102 |         '--mr_size', type=float, default=12.0,
103 |         help='patch size in image is mr_size * pt.size'
104 |     )
105 |     parser.add_argument(
106 |         '--batch_size', type=int, default=512,
107 |         help='batch size'
108 |     )
109 | 
110 |     args = parser.parse_args()
111 | 
112 |     if args.image_path is None:
113 |         args.image_path = args.dataset_path
114 | 
115 |     # Dataset paths.
116 |     paths = types.SimpleNamespace()
117 |     paths.sift_database_path = os.path.join(args.dataset_path, 'sift-features.db')
118 |     paths.database_path = os.path.join(args.dataset_path, '%s-features.db' % args.feature)
119 | 
120 |     # Copy SIFT database.
121 |     if os.path.exists(paths.database_path):
122 |         raise FileExistsError('Database already exists at %s.' % paths.database_path)
123 |     shutil.copy(paths.sift_database_path, paths.database_path)
124 | 
125 |     # PyTorch settings.
126 |     use_cuda = torch.cuda.is_available()
127 |     device = torch.device("cuda:0" if use_cuda else "cpu")
128 |     torch.set_grad_enabled(False)
129 | 
130 |     # Network and input transforms.
131 |     dim = 128
132 |     dtype = np.float32
133 |     if args.feature == 'brief':
134 |         model = BRIEFDescriptor()
135 |         model = model.to(device)
136 |         dim = 512
137 |         dtype = bool
138 |     elif args.feature == 'sift-kornia':
139 |         model = kornia.feature.SIFTDescriptor(patch_size=32, rootsift=False)
140 |         model = model.to(device)
141 |     elif args.feature == 'hardnet':
142 |         model = kornia.feature.HardNet(pretrained=True)
143 |         model = model.to(device)
144 |         model.eval()
145 |     elif args.feature == 'sosnet':
146 |         model = kornia.feature.SOSNet(pretrained=True)
147 |         model = model.to(device)
148 |         model.eval()
149 |     transform = get_transforms()
150 | 
151 |     # Recover list of images.
152 |     images = recover_database_images_and_ids(paths.database_path)
153 | 
154 |     # Connect to database.
155 |     connection = sqlite3.connect(paths.database_path)
156 |     cursor = connection.cursor()
157 | 
158 |     cursor.execute('DELETE FROM descriptors;')
159 |     connection.commit()
160 | 
161 |     cursor.execute('SELECT image_id, rows, cols, data FROM keypoints;')
162 |     raw_keypoints = cursor.fetchall()
163 |     for row in tqdm.tqdm(raw_keypoints):
164 |         assert(row[2] == 6)
165 |         image_id = row[0]
166 |         image_relative_path = images[image_id]
167 |         if row[1] == 0:
168 |             keypoints = np.zeros([0, 6])
169 |         else:
170 |             keypoints = np.frombuffer(row[-1], dtype=np.float32).reshape(row[1], row[2])
171 | 
172 |         keypoints = np.copy(keypoints)
173 |         # In COLMAP, the upper left pixel has the coordinate (0.5, 0.5).
174 |         keypoints[:, 0] = keypoints[:, 0] - .5
175 |         keypoints[:, 1] = keypoints[:, 1] - .5
176 | 
177 |         # Extract patches.
178 |         image = cv2.cvtColor(
179 |             cv2.imread(os.path.join(args.image_path, image_relative_path)),
180 |             cv2.COLOR_BGR2RGB
181 |         )
182 | 
183 |         patches = extract_patches(
184 |             keypoints, image, 32, args.mr_size, 'xyA'
185 |         )
186 | 
187 |         # Extract descriptors.
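        # The patches are pushed through the network in chunks of
        # args.batch_size; BRIEF yields 512-D boolean descriptors, while
        # the remaining descriptors are 128-D float32 (see dim/dtype above).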
188 | descriptors = np.zeros((len(patches), dim), dtype=dtype) 189 | for i in range(0, len(patches), args.batch_size): 190 | data_a = patches[i : i + args.batch_size] 191 | data_a = torch.stack( 192 | [transform(patch) for patch in data_a] 193 | ).to(device) 194 | # Predict 195 | out_a = model(data_a) 196 | descriptors[i : i + args.batch_size] = out_a.cpu().detach().numpy() 197 | 198 | # Insert into database. 199 | cursor.execute( 200 | 'INSERT INTO descriptors(image_id, rows, cols, data) VALUES(?, ?, ?, ?);', 201 | (image_id, descriptors.shape[0], descriptors.shape[1], descriptors.tobytes()) 202 | ) 203 | connection.commit() 204 | 205 | # Close connection to database. 206 | cursor.close() 207 | connection.close() 208 | -------------------------------------------------------------------------------- /feature-utils/extract_sift.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | 5 | import subprocess 6 | 7 | import types 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument( 14 | '--colmap_path', type=str, required=True, 15 | help='path to the COLMAP executable folder' 16 | ) 17 | parser.add_argument( 18 | '--dataset_path', type=str, required=True, 19 | help='path to the dataset' 20 | ) 21 | parser.add_argument( 22 | '--image_path', type=str, default=None, 23 | help='path to the images' 24 | ) 25 | 26 | args = parser.parse_args() 27 | 28 | if args.image_path is None: 29 | args.image_path = args.dataset_path 30 | 31 | # Dataset paths. 32 | paths = types.SimpleNamespace() 33 | paths.database_path = os.path.join(args.dataset_path, 'sift-features.db') 34 | 35 | if os.path.exists(paths.database_path): 36 | raise FileExistsError('Database already exists at %s.' % paths.database_path) 37 | 38 | # Extract SIFT features. 39 | subprocess.call([ 40 | os.path.join(args.colmap_path, 'colmap'), 'feature_extractor', 41 | '--database_path', paths.database_path, 42 | '--image_path', args.image_path, 43 | '--SiftExtraction.first_octave', str(0), 44 | '--SiftExtraction.num_threads', str(1) 45 | ]) 46 | -------------------------------------------------------------------------------- /lib/database.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script is based on an original implementation by True Price. 33 | 34 | import sys 35 | import sqlite3 36 | import numpy as np 37 | 38 | 39 | IS_PYTHON3 = sys.version_info[0] >= 3 40 | 41 | MAX_IMAGE_ID = 2**31 - 1 42 | 43 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( 44 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 45 | model INTEGER NOT NULL, 46 | width INTEGER NOT NULL, 47 | height INTEGER NOT NULL, 48 | params BLOB, 49 | prior_focal_length INTEGER NOT NULL)""" 50 | 51 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( 52 | image_id INTEGER PRIMARY KEY NOT NULL, 53 | rows INTEGER NOT NULL, 54 | cols INTEGER NOT NULL, 55 | data BLOB, 56 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" 57 | 58 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( 59 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 60 | name TEXT NOT NULL UNIQUE, 61 | camera_id INTEGER NOT NULL, 62 | prior_qw REAL, 63 | prior_qx REAL, 64 | prior_qy REAL, 65 | prior_qz REAL, 66 | prior_tx REAL, 67 | prior_ty REAL, 68 | prior_tz REAL, 69 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), 70 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) 71 | """.format(MAX_IMAGE_ID) 72 | 73 | CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ 74 | CREATE TABLE IF NOT EXISTS two_view_geometries ( 75 | pair_id INTEGER PRIMARY KEY NOT NULL, 76 | rows INTEGER NOT NULL, 77 | cols INTEGER NOT NULL, 78 | data BLOB, 79 | config INTEGER NOT NULL, 80 | F BLOB, 81 | E BLOB, 82 | H BLOB) 83 | """ 84 | 85 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( 86 | image_id INTEGER PRIMARY KEY NOT NULL, 87 | rows INTEGER NOT NULL, 88 | cols INTEGER NOT NULL, 89 | data BLOB, 90 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) 91 | """ 92 | 93 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( 94 | pair_id INTEGER PRIMARY KEY NOT NULL, 95 | rows INTEGER NOT NULL, 96 | cols INTEGER NOT NULL, 97 | data BLOB)""" 98 | 99 | CREATE_NAME_INDEX = \ 100 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" 101 | 102 | CREATE_ALL = "; ".join([ 103 | CREATE_CAMERAS_TABLE, 104 | CREATE_IMAGES_TABLE, 105 | CREATE_KEYPOINTS_TABLE, 106 | CREATE_DESCRIPTORS_TABLE, 107 | CREATE_MATCHES_TABLE, 108 | CREATE_TWO_VIEW_GEOMETRIES_TABLE, 109 | CREATE_NAME_INDEX 110 | ]) 111 | 112 | 113 | def image_ids_to_pair_id(image_id1, image_id2): 114 | if image_id1 > image_id2: 115 | image_id1, image_id2 = image_id2, image_id1 116 | return image_id1 * MAX_IMAGE_ID + image_id2 117 | 118 | 119 | def pair_id_to_image_ids(pair_id): 120 | image_id2 = pair_id % MAX_IMAGE_ID 121 | image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID 122 | return image_id1, image_id2 123 | 124 | 125 | def array_to_blob(array): 126 | if IS_PYTHON3: 127 | return array.tostring() 128 | else: 129 | return 
np.getbuffer(array) 130 | 131 | 132 | def blob_to_array(blob, dtype, shape=(-1,)): 133 | if blob is None: 134 | return np.zeros(shape, dtype=dtype) 135 | if IS_PYTHON3: 136 | return np.fromstring(blob, dtype=dtype).reshape(*shape) 137 | else: 138 | return np.frombuffer(blob, dtype=dtype).reshape(*shape) 139 | 140 | 141 | class COLMAPDatabase(sqlite3.Connection): 142 | 143 | @staticmethod 144 | def connect(database_path): 145 | db = sqlite3.connect(database_path, factory=COLMAPDatabase) 146 | db.path = database_path 147 | return db 148 | 149 | 150 | def __init__(self, *args, **kwargs): 151 | super(COLMAPDatabase, self).__init__(*args, **kwargs) 152 | 153 | self.path = None 154 | 155 | self.create_tables = lambda: self.executescript(CREATE_ALL) 156 | self.create_cameras_table = \ 157 | lambda: self.executescript(CREATE_CAMERAS_TABLE) 158 | self.create_descriptors_table = \ 159 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) 160 | self.create_images_table = \ 161 | lambda: self.executescript(CREATE_IMAGES_TABLE) 162 | self.create_two_view_geometries_table = \ 163 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) 164 | self.create_keypoints_table = \ 165 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 166 | self.create_matches_table = \ 167 | lambda: self.executescript(CREATE_MATCHES_TABLE) 168 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 169 | 170 | def add_camera(self, model, width, height, params, 171 | prior_focal_length=False, camera_id=None): 172 | params = np.asarray(params, np.float64) 173 | cursor = self.execute( 174 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", 175 | (camera_id, model, width, height, array_to_blob(params), 176 | prior_focal_length)) 177 | return cursor.lastrowid 178 | 179 | def add_image(self, name, camera_id, 180 | prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None): 181 | cursor = self.execute( 182 | "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 183 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], 184 | prior_q[3], prior_t[0], prior_t[1], prior_t[2])) 185 | return cursor.lastrowid 186 | 187 | def add_keypoints(self, image_id, keypoints): 188 | assert(len(keypoints.shape) == 2) 189 | assert(keypoints.shape[1] in [2, 4, 6]) 190 | 191 | keypoints = np.asarray(keypoints, np.float32) 192 | self.execute( 193 | "INSERT INTO keypoints VALUES (?, ?, ?, ?)", 194 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) 195 | 196 | def add_descriptors(self, image_id, descriptors): 197 | descriptors = np.ascontiguousarray(descriptors, np.uint8) 198 | self.execute( 199 | "INSERT INTO descriptors VALUES (?, ?, ?, ?)", 200 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) 201 | 202 | def add_matches(self, image_id1, image_id2, matches): 203 | assert(len(matches.shape) == 2) 204 | assert(matches.shape[1] == 2) 205 | 206 | if image_id1 > image_id2: 207 | matches = matches[:,::-1] 208 | 209 | pair_id = image_ids_to_pair_id(image_id1, image_id2) 210 | matches = np.asarray(matches, np.uint32) 211 | self.execute( 212 | "INSERT INTO matches VALUES (?, ?, ?, ?)", 213 | (pair_id,) + matches.shape + (array_to_blob(matches),)) 214 | 215 | def add_two_view_geometry(self, image_id1, image_id2, matches, 216 | F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2): 217 | assert(len(matches.shape) == 2) 218 | assert(matches.shape[1] == 2) 219 | 220 | if image_id1 > image_id2: 221 | matches = matches[:,::-1] 222 | 223 | pair_id = image_ids_to_pair_id(image_id1, image_id2) 224 | matches = 
np.asarray(matches, np.uint32) 225 | F = np.asarray(F, dtype=np.float64) 226 | E = np.asarray(E, dtype=np.float64) 227 | H = np.asarray(H, dtype=np.float64) 228 | self.execute( 229 | "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)", 230 | (pair_id,) + matches.shape + (array_to_blob(matches), config, 231 | array_to_blob(F), array_to_blob(E), array_to_blob(H))) 232 | 233 | 234 | def example_usage(): 235 | import os 236 | import argparse 237 | 238 | parser = argparse.ArgumentParser() 239 | parser.add_argument("--database_path", default="database.db") 240 | args = parser.parse_args() 241 | 242 | if os.path.exists(args.database_path): 243 | print("ERROR: database path already exists -- will not modify it.") 244 | return 245 | 246 | # Open the database. 247 | 248 | db = COLMAPDatabase.connect(args.database_path) 249 | 250 | # For convenience, try creating all the tables upfront. 251 | 252 | db.create_tables() 253 | 254 | # Create dummy cameras. 255 | 256 | model1, width1, height1, params1 = \ 257 | 0, 1024, 768, np.array((1024., 512., 384.)) 258 | model2, width2, height2, params2 = \ 259 | 2, 1024, 768, np.array((1024., 512., 384., 0.1)) 260 | 261 | camera_id1 = db.add_camera(model1, width1, height1, params1) 262 | camera_id2 = db.add_camera(model2, width2, height2, params2) 263 | 264 | # Create dummy images. 265 | 266 | image_id1 = db.add_image("image1.png", camera_id1) 267 | image_id2 = db.add_image("image2.png", camera_id1) 268 | image_id3 = db.add_image("image3.png", camera_id2) 269 | image_id4 = db.add_image("image4.png", camera_id2) 270 | 271 | # Create dummy keypoints. 272 | # 273 | # Note that COLMAP supports: 274 | # - 2D keypoints: (x, y) 275 | # - 4D keypoints: (x, y, theta, scale) 276 | # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) 277 | 278 | num_keypoints = 1000 279 | keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) 280 | keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) 281 | keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) 282 | keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) 283 | 284 | db.add_keypoints(image_id1, keypoints1) 285 | db.add_keypoints(image_id2, keypoints2) 286 | db.add_keypoints(image_id3, keypoints3) 287 | db.add_keypoints(image_id4, keypoints4) 288 | 289 | # Create dummy matches. 290 | 291 | M = 50 292 | matches12 = np.random.randint(num_keypoints, size=(M, 2)) 293 | matches23 = np.random.randint(num_keypoints, size=(M, 2)) 294 | matches34 = np.random.randint(num_keypoints, size=(M, 2)) 295 | 296 | db.add_matches(image_id1, image_id2, matches12) 297 | db.add_matches(image_id2, image_id3, matches23) 298 | db.add_matches(image_id3, image_id4, matches34) 299 | 300 | # Commit the data to the file. 301 | 302 | db.commit() 303 | 304 | # Read and check cameras. 305 | 306 | rows = db.execute("SELECT * FROM cameras") 307 | 308 | camera_id, model, width, height, params, prior = next(rows) 309 | params = blob_to_array(params, np.float64) 310 | assert camera_id == camera_id1 311 | assert model == model1 and width == width1 and height == height1 312 | assert np.allclose(params, params1) 313 | 314 | camera_id, model, width, height, params, prior = next(rows) 315 | params = blob_to_array(params, np.float64) 316 | assert camera_id == camera_id2 317 | assert model == model2 and width == width2 and height == height2 318 | assert np.allclose(params, params2) 319 | 320 | # Read and check keypoints. 
321 | 322 | keypoints = dict( 323 | (image_id, blob_to_array(data, np.float32, (-1, 2))) 324 | for image_id, data in db.execute( 325 | "SELECT image_id, data FROM keypoints")) 326 | 327 | assert np.allclose(keypoints[image_id1], keypoints1) 328 | assert np.allclose(keypoints[image_id2], keypoints2) 329 | assert np.allclose(keypoints[image_id3], keypoints3) 330 | assert np.allclose(keypoints[image_id4], keypoints4) 331 | 332 | # Read and check matches. 333 | 334 | pair_ids = [image_ids_to_pair_id(*pair) for pair in 335 | ((image_id1, image_id2), 336 | (image_id2, image_id3), 337 | (image_id3, image_id4))] 338 | 339 | matches = dict( 340 | (pair_id_to_image_ids(pair_id), 341 | blob_to_array(data, np.uint32, (-1, 2))) 342 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches") 343 | ) 344 | 345 | assert np.all(matches[(image_id1, image_id2)] == matches12) 346 | assert np.all(matches[(image_id2, image_id3)] == matches23) 347 | assert np.all(matches[(image_id3, image_id4)] == matches34) 348 | 349 | # Clean up. 350 | 351 | db.close() 352 | 353 | if os.path.exists(args.database_path): 354 | os.remove(args.database_path) 355 | 356 | 357 | if __name__ == "__main__": 358 | example_usage() 359 | -------------------------------------------------------------------------------- /lib/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class TranslationDataset(Dataset): 10 | def __init__( 11 | self, 12 | base_path=None, 13 | features=['brief', 'sift-kornia', 'hardnet', 'sosnet'], 14 | subsampling_ratio=1.0 15 | ): 16 | self.features = features 17 | 18 | self.arrays = {} 19 | for feature in self.features: 20 | npy_path = os.path.join( 21 | base_path, '%s-features.npy' % feature 22 | ) 23 | descriptors = np.load(npy_path) 24 | # Deterministic subsampling. 25 | if subsampling_ratio < 1.0: 26 | num_descriptors = int(np.ceil(subsampling_ratio * descriptors.shape[0])) 27 | random = np.random.RandomState(seed=1) 28 | selected_ids = random.choice(descriptors.shape[0], num_descriptors, replace=False) 29 | else: 30 | selected_ids = np.arange(descriptors.shape[0]) 31 | self.arrays[feature] = descriptors[selected_ids, :] 32 | 33 | self.len = self.arrays[features[0]].shape[0] 34 | 35 | def __len__(self): 36 | return self.len 37 | 38 | def __getitem__(self, idx): 39 | sample = {} 40 | for feature in self.features: 41 | sample[feature] = torch.from_numpy( 42 | self.arrays[feature][idx, :] 43 | ).float() 44 | return sample 45 | -------------------------------------------------------------------------------- /lib/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | def exhaustive_loss(encoders, decoders, batch, device, alpha=0.1, margin=1.0): 9 | # Translation loss. 
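    # Each descriptor is first encoded into the shared embedding space; the
    # stacked embeddings are then decoded into every target descriptor space,
    # with a binary cross-entropy loss for the binary BRIEF descriptors and
    # an L2 reconstruction loss for the real-valued ones.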
10 | embeddings = {} 11 | for source_feature in batch.keys(): 12 | source_descriptors = batch[source_feature] 13 | embeddings[source_feature] = encoders[source_feature](source_descriptors) 14 | 15 | all_embeddings = torch.cat(list(embeddings.values()), dim=0) 16 | 17 | t_loss = torch.tensor(0.).float().to(device) 18 | for target_feature in batch.keys(): 19 | target_descriptors = batch[target_feature] 20 | output_descriptors = decoders[target_feature](all_embeddings) 21 | if target_feature == 'brief': 22 | current_loss = F.binary_cross_entropy( 23 | output_descriptors, 24 | torch.cat([target_descriptors] * len(batch), dim=0) 25 | ) 26 | else: 27 | current_loss = torch.mean( 28 | torch.norm(output_descriptors - torch.cat([target_descriptors] * len(batch), dim=0), dim=1) 29 | ) 30 | t_loss += current_loss 31 | t_loss /= len(batch) 32 | 33 | # Triplet loss in embedding space. 34 | e_loss = torch.tensor(0.).float().to(device) 35 | if alpha > 0: 36 | for source_feature in embeddings.keys(): 37 | for target_feature in embeddings.keys(): 38 | # TODO: Implement symmetric negative mining. 39 | sqdist_matrix = 2 - 2 * embeddings[source_feature] @ embeddings[target_feature].T 40 | pos_dist = torch.norm(torch.diag(sqdist_matrix).unsqueeze(-1), dim=-1) 41 | sqdist_matrix = sqdist_matrix + torch.diag(torch.full((sqdist_matrix.shape[0],), np.inf)).to(device) 42 | # neg_sqdist = torch.min(torch.min(sqdist_matrix, dim=-1)[0], torch.min(sqdist_matrix, dim=0)[0]) 43 | neg_sqdist = torch.min(sqdist_matrix, dim=-1)[0] 44 | neg_dist = torch.norm(neg_sqdist.unsqueeze(-1), dim=-1) 45 | e_loss = e_loss + torch.mean( 46 | F.relu(margin + pos_dist - neg_dist) 47 | ) 48 | e_loss /= (len(batch) ** 2) 49 | 50 | # Final loss. 51 | if alpha > 0: 52 | loss = t_loss + alpha * e_loss 53 | else: 54 | loss = t_loss 55 | 56 | return loss, (t_loss.detach(), e_loss.detach()) 57 | 58 | 59 | def autoencoder_loss(encoders, decoders, batch, device, alpha=0.1, margin=1.0): 60 | # AE loss. 61 | embeddings = {} 62 | t_loss = torch.tensor(0.).float().to(device) 63 | for source_feature in batch.keys(): 64 | source_descriptors = batch[source_feature] 65 | current_embeddings = encoders[source_feature](source_descriptors) 66 | embeddings[source_feature] = current_embeddings 67 | output_descriptors = decoders[source_feature](current_embeddings) 68 | if source_feature == 'brief': 69 | current_loss = F.binary_cross_entropy( 70 | output_descriptors, source_descriptors 71 | ) 72 | else: 73 | current_loss = torch.mean( 74 | torch.norm(output_descriptors - source_descriptors, dim=1) 75 | ) 76 | t_loss += current_loss 77 | t_loss /= len(batch) 78 | 79 | # Triplet loss in embedding space. 80 | e_loss = torch.tensor(0.).float().to(device) 81 | if alpha > 0: 82 | for source_feature in embeddings.keys(): 83 | for target_feature in embeddings.keys(): 84 | # TODO: Implement symmetric negative mining. 85 | sqdist_matrix = 2 - 2 * embeddings[source_feature] @ embeddings[target_feature].T 86 | pos_dist = torch.norm(torch.diag(sqdist_matrix).unsqueeze(-1), dim=-1) 87 | sqdist_matrix = sqdist_matrix + torch.diag(torch.full((sqdist_matrix.shape[0],), np.inf)).to(device) 88 | # neg_sqdist = torch.min(torch.min(sqdist_matrix, dim=-1)[0], torch.min(sqdist_matrix, dim=0)[0]) 89 | neg_sqdist = torch.min(sqdist_matrix, dim=-1)[0] 90 | neg_dist = torch.norm(neg_sqdist.unsqueeze(-1), dim=-1) 91 | e_loss = e_loss + torch.mean( 92 | F.relu(margin + pos_dist - neg_dist) 93 | ) 94 | e_loss /= (len(batch) ** 2) 95 | 96 | # Final loss. 
97 |     if alpha > 0:
98 |         loss = t_loss + alpha * e_loss
99 |     else:
100 |         loss = t_loss
101 | 
102 |     return loss, (t_loss.detach(), e_loss.detach())
103 | 
104 | 
105 | def subset_loss(encoders, decoders, batch, device, alpha=0.1, margin=1.0):
106 |     # Select random pairs of descriptors.
107 |     # Make sure that all encoders and all decoders are used.
108 |     source_features = np.array(list(batch.keys()))
109 |     target_features = np.array(source_features)
110 |     np.random.shuffle(target_features)
111 | 
112 |     # Translation loss.
113 |     embeddings = {}
114 |     t_loss = torch.tensor(0.).float().to(device)
115 |     for source_feature, target_feature in zip(source_features, target_features):
116 |         source_descriptors = batch[source_feature]
117 |         target_descriptors = batch[target_feature]
118 |         current_embeddings = encoders[source_feature](source_descriptors)
119 |         embeddings[source_feature] = current_embeddings
120 |         output_descriptors = decoders[target_feature](current_embeddings)
121 |         if target_feature == 'brief':
122 |             current_loss = F.binary_cross_entropy(
123 |                 output_descriptors, target_descriptors
124 |             )
125 |         else:
126 |             current_loss = torch.mean(
127 |                 torch.norm(output_descriptors - target_descriptors, dim=1)
128 |             )
129 |         t_loss += current_loss
130 |     t_loss /= len(batch)
131 | 
132 |     # Triplet loss in embedding space.
133 |     e_loss = torch.tensor(0.).float().to(device)
134 |     if alpha > 0:
135 |         for source_feature, target_feature in zip(source_features, target_features):
136 |             # TODO: Implement symmetric negative mining.
137 |             sqdist_matrix = 2 - 2 * embeddings[source_feature] @ embeddings[target_feature].T
138 |             pos_dist = torch.norm(torch.diag(sqdist_matrix).unsqueeze(-1), dim=-1)
139 |             sqdist_matrix = sqdist_matrix + torch.diag(torch.full((sqdist_matrix.shape[0],), np.inf)).to(device)
140 |             # neg_sqdist = torch.min(torch.min(sqdist_matrix, dim=-1)[0], torch.min(sqdist_matrix, dim=0)[0])
141 |             neg_sqdist = torch.min(sqdist_matrix, dim=-1)[0]
142 |             neg_dist = torch.norm(neg_sqdist.unsqueeze(-1), dim=-1)
143 |             e_loss = e_loss + torch.mean(
144 |                 F.relu(margin + pos_dist - neg_dist)
145 |             )
146 |         e_loss /= len(batch)
147 | 
148 |     # Final loss.
149 |     if alpha > 0:
150 |         loss = t_loss + alpha * e_loss
151 |     else:
152 |         loss = t_loss
153 | 
154 |     return loss, (t_loss.detach(), e_loss.detach())
155 | 
--------------------------------------------------------------------------------
/lib/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class MLP(nn.Module):
7 |     def __init__(
8 |         self,
9 |         num_channels,
10 |         use_cuda=True,
11 |         use_bn=False,
12 |         last_activation=None,
13 |         l2_norm=False
14 |     ):
15 |         super(MLP, self).__init__()
16 | 
17 |         # Retrieve from num_channels.
18 |         source_num_channels = num_channels[0]
19 |         target_num_channels = num_channels[-1]
20 |         hidden_num_channels = num_channels[1 : -1]
21 | 
22 |         # Prepare layers.
23 |         layers = []
24 |         previous_num_channels = source_num_channels
25 |         for current_num_channels in hidden_num_channels:
26 |             layers.append(
27 |                 nn.Linear(previous_num_channels, current_num_channels)
28 |             )
29 |             if use_bn:
30 |                 layers.append(
31 |                     nn.BatchNorm1d(current_num_channels)
32 |                 )
33 |             layers.append(
34 |                 nn.ReLU(inplace=True)
35 |             )
36 |             previous_num_channels = current_num_channels
37 |         layers.append(nn.Linear(previous_num_channels, target_num_channels))
38 | 
39 |         # Make a sequential model.
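        # Each hidden block built above is Linear (+ optional BatchNorm) + ReLU;
        # the final Linear layer maps to the target dimension.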
40 | self.network = nn.Sequential(*layers) 41 | 42 | # Last activation. 43 | if last_activation is None: 44 | self.last_activation = None 45 | elif last_activation.lower() == 'relu': 46 | self.last_activation = nn.ReLU() 47 | elif last_activation.lower() == 'sigmoid': 48 | self.last_activation = nn.Sigmoid() 49 | else: 50 | raise NotImplementedError('Unknown activation "%s".' % last_activation) 51 | 52 | # L2-normalize output. 53 | self.l2_norm = l2_norm 54 | 55 | # Move to GPU if needed. 56 | if use_cuda: 57 | self.cuda() 58 | 59 | def forward(self, batch): 60 | output = self.network(batch) 61 | if self.last_activation is not None: 62 | output = self.last_activation(output) 63 | if self.l2_norm: 64 | output = F.normalize(output, dim=1) 65 | return output 66 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | from lib.networks import MLP 2 | 3 | 4 | def create_network_for_feature(feature, config, use_cuda): 5 | use_bn = config['use_bn'] 6 | emb_dim = config['emb_dim'] 7 | emb_l2_norm = config['emb_l2_norm'] 8 | emb_last_activation = config['emb_last_activation'] 9 | descriptor_dim = config[feature]['descriptor_dim'] 10 | hidden_dims = config[feature]['hidden_dims'] 11 | l2_norm = config[feature]['l2_norm'] 12 | last_activation = config[feature]['last_activation'] 13 | 14 | # Define encoder. 15 | encoder = MLP( 16 | num_channels=[descriptor_dim] + hidden_dims + [emb_dim], use_cuda=use_cuda, 17 | use_bn=use_bn, last_activation=emb_last_activation, l2_norm=emb_l2_norm 18 | ) 19 | 20 | # Define decoder. 21 | decoder = MLP( 22 | num_channels=[emb_dim] + hidden_dims[:: -1] + [descriptor_dim], use_cuda=use_cuda, 23 | use_bn=use_bn, last_activation=last_activation, l2_norm=l2_norm 24 | ) 25 | 26 | return encoder, decoder -------------------------------------------------------------------------------- /local-feature-evaluation/align_and_compare.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import os 6 | 7 | import subprocess 8 | 9 | 10 | def qvec_to_rotmat(qvec): 11 | w, x, y, z = qvec 12 | R = np.array([ 13 | [ 14 | 1 - 2 * y * y - 2 * z * z, 15 | 2 * x * y - 2 * z * w, 16 | 2 * x * z + 2 * y * w 17 | ], 18 | [ 19 | 2 * x * y + 2 * z * w, 20 | 1 - 2 * x * x - 2 * z * z, 21 | 2 * y * z - 2 * x * w 22 | ], 23 | [ 24 | 2 * x * z - 2 * y * w, 25 | 2 * y * z + 2 * x * w, 26 | 1 - 2 * x * x - 2 * y * y 27 | ] 28 | ]) 29 | return R 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument( 36 | '--model_path', required=True, 37 | help='path to the model' 38 | ) 39 | parser.add_argument( 40 | '--reference_model_path', required=True, 41 | help='path to the reference model' 42 | ) 43 | parser.add_argument( 44 | '--colmap_path', required=True, 45 | help='path to the COLMAP executable folder' 46 | ) 47 | 48 | args = parser.parse_args() 49 | 50 | # Create the output path. 51 | aligned_model_path = os.path.join(args.model_path, 'aligned') 52 | 53 | if not os.path.exists(aligned_model_path): 54 | os.mkdir(aligned_model_path) 55 | 56 | # Read and cache the reference model. 
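    # COLMAP images.txt layout: a 4-line header, then two lines per image,
    # the first being "IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME" and the
    # second listing the 2D point observations (hence the lines[4 :: 2] stride).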
57 | with open(os.path.join(args.reference_model_path, 'images.txt'), 'r') as f: 58 | lines = f.readlines() 59 | reference_poses = {} 60 | for line in lines[4 :: 2]: 61 | line = line.strip('\n').split(' ') 62 | image_name = line[-1] 63 | qvec = np.array(list(map(float, line[1 : 5]))) 64 | t = np.array(list(map(float, line[5 : 8]))) 65 | R = qvec_to_rotmat(qvec) 66 | reference_poses[image_name] = [R, t] 67 | 68 | # Run the model aligner. 69 | subprocess.call([ 70 | os.path.join(args.colmap_path, 'colmap'), 'model_aligner', 71 | '--input_path', args.model_path, 72 | '--output_path', aligned_model_path, 73 | '--ref_images_path', os.path.join(args.reference_model_path, 'geo.txt'), 74 | '--robust_alignment_max_error', str(0.25) 75 | ]) 76 | 77 | subprocess.call([ 78 | os.path.join(args.colmap_path, 'colmap'), 'model_converter', 79 | '--input_path', aligned_model_path, 80 | '--output_path', aligned_model_path, 81 | '--output_type', 'TXT' 82 | ]) 83 | 84 | # Parse the aligned model. 85 | with open(os.path.join(aligned_model_path, 'images.txt'), 'r') as f: 86 | lines = f.readlines() 87 | ori_errors = [] 88 | center_errors = [] 89 | image_ids = [] 90 | for line in lines[4 :: 2]: 91 | line = line.strip('\n').split(' ') 92 | image_id = int(line[0]) 93 | image_name = line[-1] 94 | qvec = np.array(list(map(float, line[1 : 5]))) 95 | t = np.array(list(map(float, line[5 : 8]))) 96 | R = qvec_to_rotmat(qvec) 97 | # Compute the error. 98 | annotated_R, annotated_t = reference_poses[image_name] 99 | 100 | rotation_difference = R @ annotated_R.transpose() 101 | ori_error = np.rad2deg(np.arccos(np.clip((np.trace(rotation_difference) - 1) / 2, -1, 1))) 102 | 103 | annotated_C = (-1) * annotated_R.transpose() @ annotated_t 104 | C = (-1) * R.transpose() @ t 105 | center_error = np.linalg.norm(C - annotated_C) 106 | if center_error > 0.10: 107 | image_ids.append((image_id, image_name)) 108 | 109 | ori_errors.append(ori_error) 110 | center_errors.append(center_error) 111 | ori_errors = np.array(ori_errors) 112 | center_errors = np.array(center_errors) 113 | 114 | print('Registered %d out of %d images.' 
% (len(ori_errors), len(reference_poses))) 115 | print('0.25m, 2 deg:', np.sum(np.logical_and(ori_errors <= 2, center_errors <= 0.25)) / len(reference_poses)) 116 | print('0.50m, 5 deg:', np.sum(np.logical_and(ori_errors <= 5, center_errors <= 0.50)) / len(reference_poses)) 117 | print('inf :', len(ori_errors) / len(reference_poses)) 118 | -------------------------------------------------------------------------------- /local-feature-evaluation/reconstruction_pipeline_embed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | 5 | import numpy as np 6 | 7 | import os 8 | 9 | import shutil 10 | 11 | import sqlite3 12 | 13 | import sys 14 | 15 | import types 16 | 17 | import torch 18 | 19 | from utils import translate_descriptors, build_hybrid_database, match_features, geometric_verification, reconstruct, blob_to_array, compute_extra_stats 20 | 21 | sys.path.append(os.getcwd()) 22 | from lib.utils import create_network_for_feature 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument( 29 | '--dataset_path', type=str, required=True, 30 | help='path to the dataset' 31 | ) 32 | 33 | parser.add_argument( 34 | '--colmap_path', type=str, required=True, 35 | help='path to the COLMAP executable folder' 36 | ) 37 | 38 | parser.add_argument( 39 | '--features', nargs='+', type=str, required=True, 40 | help='list of descriptors to consider' 41 | ) 42 | 43 | parser.add_argument( 44 | '--exp_name', type=str, required=True, 45 | help='name of the experiment' 46 | ) 47 | 48 | parser.add_argument( 49 | '--checkpoint', type=str, default='checkpoints-pretrained/model.pth', 50 | help='path to the checkpoint' 51 | ) 52 | 53 | parser.add_argument( 54 | '--batch_size', type=int, default=4096, 55 | help='batch size' 56 | ) 57 | 58 | args = parser.parse_args() 59 | return args 60 | 61 | 62 | def translate_database(image_features, database_path, encoders, batch_size, device): 63 | shutil.copy(database_path, database_path + '.aux') 64 | 65 | connection = sqlite3.connect(database_path + '.aux') 66 | cursor = connection.cursor() 67 | 68 | output_connection = sqlite3.connect(database_path) 69 | output_cursor = output_connection.cursor() 70 | output_cursor.execute( 71 | 'DELETE FROM descriptors' 72 | ) 73 | output_connection.commit() 74 | 75 | for image_id, feature in image_features.items(): 76 | image_id = int(image_id) 77 | result = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id,))).fetchall() 78 | data, rows, cols = result[0] 79 | try: 80 | descriptors = blob_to_array(data, np.float32, (rows, cols)) 81 | except ValueError: 82 | descriptors = blob_to_array(data, bool, (rows, cols)).astype(np.float32) 83 | 84 | descriptors = translate_descriptors(descriptors, feature, None, encoders, {}, batch_size, device) 85 | 86 | output_cursor.execute( 87 | 'INSERT INTO descriptors VALUES (?, ?, ?, ?)', 88 | (image_id,) + descriptors.shape + (descriptors.tobytes(),) 89 | ) 90 | output_connection.commit() 91 | 92 | cursor.close() 93 | connection.close() 94 | os.remove(database_path + '.aux') 95 | 96 | output_cursor.close() 97 | output_connection.close() 98 | 99 | 100 | def main(): 101 | # Set CUDA. 102 | use_cuda = torch.cuda.is_available() 103 | device = torch.device("cuda:0" if use_cuda else "cpu") 104 | torch.set_grad_enabled(False) 105 | 106 | # Load config json. 
107 | with open('checkpoints-pretrained/config.json', 'r') as f: 108 | config = json.load(f) 109 | 110 | # Parse arguments. 111 | args = parse_args() 112 | assert(len(args.features) > 1) 113 | 114 | paths = types.SimpleNamespace() 115 | paths.database_path = os.path.join( 116 | args.dataset_path, '%s.db' % args.exp_name 117 | ) 118 | paths.image_path = os.path.join( 119 | args.dataset_path, 'images' 120 | ) 121 | paths.match_list_path = os.path.join( 122 | args.dataset_path, 'match-list-exh.txt' 123 | ) 124 | paths.sparse_path = os.path.join( 125 | args.dataset_path, 'sparse-%s' % args.exp_name 126 | ) 127 | paths.output_path = os.path.join( 128 | args.dataset_path, 'stats-%s.txt' % args.exp_name 129 | ) 130 | 131 | # Copy reference database. 132 | if os.path.exists(paths.database_path): 133 | raise FileExistsError('Database file already exists.') 134 | shutil.copy( 135 | os.path.join(args.dataset_path, 'database.db'), 136 | paths.database_path 137 | ) 138 | 139 | # Create networks. 140 | checkpoint = torch.load(args.checkpoint) 141 | encoders = {} 142 | for feature in args.features: 143 | encoder, _ = create_network_for_feature(feature, config, use_cuda) 144 | state_dict = list(filter(lambda x: x[0] == feature, checkpoint['encoders']))[0] 145 | encoder.load_state_dict(state_dict[1]) 146 | encoder.eval() 147 | encoders[feature] = encoder 148 | 149 | # Build and translate database. 150 | image_features = build_hybrid_database( 151 | args.features, 152 | args.dataset_path, 153 | paths.database_path 154 | ) 155 | np.save(os.path.join(args.dataset_path, 'features.npy'), image_features) 156 | translate_database( 157 | image_features, 158 | paths.database_path, 159 | encoders, args.batch_size, device 160 | ) 161 | 162 | # Matching + GV + reconstruction. 163 | match_features( 164 | args.colmap_path, 165 | paths.database_path, paths.image_path, paths.match_list_path 166 | ) 167 | torch.cuda.empty_cache() 168 | matching_stats = geometric_verification( 169 | args.colmap_path, 170 | paths.database_path, paths.match_list_path 171 | ) 172 | largest_model_path, reconstruction_stats = reconstruct( 173 | args.colmap_path, 174 | paths.database_path, paths.image_path, paths.sparse_path 175 | ) 176 | extra_stats = compute_extra_stats(image_features, largest_model_path) 177 | 178 | with open(paths.output_path, 'w') as f: 179 | f.write(json.dumps(matching_stats)) 180 | f.write('\n') 181 | f.write(json.dumps(reconstruction_stats)) 182 | f.write('\n') 183 | f.write(json.dumps(extra_stats)) 184 | f.write('\n') 185 | 186 | 187 | if __name__ == '__main__': 188 | main() 189 | -------------------------------------------------------------------------------- /local-feature-evaluation/reconstruction_pipeline_progressive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | 5 | import numpy as np 6 | 7 | import os 8 | 9 | import shutil 10 | 11 | import sys 12 | 13 | import types 14 | 15 | import torch 16 | 17 | from utils import build_hybrid_database, match_features_hybrid, geometric_verification, reconstruct, compute_extra_stats 18 | 19 | sys.path.append(os.getcwd()) 20 | from lib.utils import create_network_for_feature 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | 26 | parser.add_argument( 27 | '--dataset_path', type=str, required=True, 28 | help='path to the dataset' 29 | ) 30 | 31 | parser.add_argument( 32 | '--colmap_path', type=str, required=True, 33 | help='path to the COLMAP executable folder' 34 | ) 35 | 36 
| parser.add_argument( 37 | '--features', nargs='+', type=str, required=True, 38 | help='list of descriptors to consider' 39 | ) 40 | 41 | parser.add_argument( 42 | '--exp_name', type=str, required=True, 43 | help='name of the experiment' 44 | ) 45 | 46 | parser.add_argument( 47 | '--checkpoint', type=str, default='checkpoints-pretrained/model.pth', 48 | help='path to the checkpoint' 49 | ) 50 | 51 | parser.add_argument( 52 | '--batch_size', type=int, default=4096, 53 | help='batch size' 54 | ) 55 | 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | def main(): 61 | # Set CUDA. 62 | use_cuda = torch.cuda.is_available() 63 | device = torch.device("cuda:0" if use_cuda else "cpu") 64 | torch.set_grad_enabled(False) 65 | 66 | # Load config json. 67 | with open('checkpoints-pretrained/config.json', 'r') as f: 68 | config = json.load(f) 69 | 70 | # Parse arguments. 71 | args = parse_args() 72 | 73 | paths = types.SimpleNamespace() 74 | paths.database_path = os.path.join( 75 | args.dataset_path, '%s.db' % args.exp_name 76 | ) 77 | paths.image_path = os.path.join( 78 | args.dataset_path, 'images' 79 | ) 80 | paths.match_list_path = os.path.join( 81 | args.dataset_path, 'match-list-exh.txt' 82 | ) 83 | paths.sparse_path = os.path.join( 84 | args.dataset_path, 'sparse-%s' % args.exp_name 85 | ) 86 | paths.output_path = os.path.join( 87 | args.dataset_path, 'stats-%s.txt' % args.exp_name 88 | ) 89 | 90 | # Copy reference database. 91 | if os.path.exists(paths.database_path): 92 | raise FileExistsError('Database file already exists.') 93 | shutil.copy( 94 | os.path.join(args.dataset_path, 'database.db'), 95 | paths.database_path 96 | ) 97 | 98 | # Create networks. 99 | encoders = {} 100 | decoders = {} 101 | if len(args.features) > 1: 102 | checkpoint = torch.load(args.checkpoint) 103 | 104 | for feature in args.features: 105 | encoder, decoder = create_network_for_feature(feature, config, use_cuda) 106 | 107 | state_dict = list(filter(lambda x: x[0] == feature, checkpoint['encoders']))[0] 108 | encoder.load_state_dict(state_dict[1]) 109 | encoder.eval() 110 | encoders[feature] = encoder 111 | 112 | state_dict = list(filter(lambda x: x[0] == feature, checkpoint['decoders']))[0] 113 | decoder.load_state_dict(state_dict[1]) 114 | decoder.eval() 115 | decoders[feature] = decoder 116 | else: 117 | encoders[args.features[0]] = None 118 | decoders[args.features[0]] = None 119 | 120 | # Build and translate database. 121 | image_features = build_hybrid_database( 122 | args.features, 123 | args.dataset_path, 124 | paths.database_path 125 | ) 126 | 127 | # Matching + GV + reconstruction. 
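    # For image pairs extracted with different descriptors, the descriptors of
    # one image are translated through the trained encoder/decoder networks
    # into the feature space of the other before mutual-NN matching.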
128 | match_features_hybrid( 129 | args.features, 130 | image_features, 131 | args.colmap_path, 132 | paths.database_path, paths.image_path, paths.match_list_path, 133 | encoders, decoders, args.batch_size, device 134 | ) 135 | torch.cuda.empty_cache() 136 | matching_stats = geometric_verification( 137 | args.colmap_path, 138 | paths.database_path, paths.match_list_path 139 | ) 140 | largest_model_path, reconstruction_stats = reconstruct( 141 | args.colmap_path, 142 | paths.database_path, paths.image_path, paths.sparse_path 143 | ) 144 | extra_stats = compute_extra_stats(image_features, largest_model_path) 145 | 146 | with open(paths.output_path, 'w') as f: 147 | f.write(json.dumps(matching_stats)) 148 | f.write('\n') 149 | f.write(json.dumps(reconstruction_stats)) 150 | f.write('\n') 151 | f.write(json.dumps(extra_stats)) 152 | f.write('\n') 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /local-feature-evaluation/reconstruction_pipeline_subset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | 5 | import os 6 | 7 | import shutil 8 | 9 | import types 10 | 11 | import torch 12 | 13 | from utils import build_hybrid_database, match_features_subset, geometric_verification, reconstruct, compute_extra_stats 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument( 20 | '--dataset_path', type=str, required=True, 21 | help='path to the dataset' 22 | ) 23 | 24 | parser.add_argument( 25 | '--colmap_path', type=str, required=True, 26 | help='path to the COLMAP executable folder' 27 | ) 28 | 29 | parser.add_argument( 30 | '--features', nargs='+', type=str, required=True, 31 | help='list of descriptors to consider' 32 | ) 33 | 34 | parser.add_argument( 35 | '--feature', type=str, required=True, 36 | help='descriptors to map' 37 | ) 38 | 39 | parser.add_argument( 40 | '--exp_name', type=str, required=True, 41 | help='name of the experiment' 42 | ) 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | # Set CUDA. 50 | use_cuda = torch.cuda.is_available() 51 | device = torch.device("cuda:0" if use_cuda else "cpu") 52 | torch.set_grad_enabled(False) 53 | 54 | # Parse arguments. 55 | args = parse_args() 56 | 57 | paths = types.SimpleNamespace() 58 | paths.database_path = os.path.join( 59 | args.dataset_path, '%s.db' % args.exp_name 60 | ) 61 | paths.image_path = os.path.join( 62 | args.dataset_path, 'images' 63 | ) 64 | paths.match_list_path = os.path.join( 65 | args.dataset_path, 'match-list-exh.txt' 66 | ) 67 | paths.sparse_path = os.path.join( 68 | args.dataset_path, 'sparse-%s' % args.exp_name 69 | ) 70 | paths.output_path = os.path.join( 71 | args.dataset_path, 'stats-%s.txt' % args.exp_name 72 | ) 73 | 74 | # Copy reference database. 75 | if os.path.exists(paths.database_path): 76 | raise FileExistsError('Database file already exists.') 77 | shutil.copy( 78 | os.path.join(args.dataset_path, 'database.db'), 79 | paths.database_path 80 | ) 81 | 82 | # Build and translate database. 83 | image_features = build_hybrid_database( 84 | args.features, 85 | args.dataset_path, 86 | paths.database_path 87 | ) 88 | 89 | # Matching + GV + reconstruction. 
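    # Real-world split: only the images assigned to args.feature are matched;
    # the reduced pair list is written to match_list_path + '.aux', which is
    # consumed by the geometric verification step below and then deleted.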
90 |     match_features_subset(
91 |         args.feature,
92 |         image_features,
93 |         args.colmap_path,
94 |         paths.database_path, paths.image_path, paths.match_list_path
95 |     )
96 |     torch.cuda.empty_cache()
97 |     matching_stats = geometric_verification(
98 |         args.colmap_path,
99 |         paths.database_path, paths.match_list_path + '.aux'
100 |     )
101 |     os.remove(paths.match_list_path + '.aux')
102 |     largest_model_path, reconstruction_stats = reconstruct(
103 |         args.colmap_path,
104 |         paths.database_path, paths.image_path, paths.sparse_path
105 |     )
106 |     extra_stats = compute_extra_stats(image_features, largest_model_path)
107 | 
108 |     with open(paths.output_path, 'w') as f:
109 |         f.write(json.dumps(matching_stats))
110 |         f.write('\n')
111 |         f.write(json.dumps(reconstruction_stats))
112 |         f.write('\n')
113 |         f.write(json.dumps(extra_stats))
114 |         f.write('\n')
115 | 
116 | 
117 | if __name__ == '__main__':
118 |     main()
119 | 
--------------------------------------------------------------------------------
/local-feature-evaluation/utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/ahojnnes/local-feature-evaluation/blob/master/scripts/reconstruction_pipeline.py.
2 | # Copyright 2017, Johannes L. Schoenberger.
3 | import numpy as np
4 | 
5 | import os
6 | 
7 | import subprocess
8 | 
9 | import sqlite3
10 | 
11 | import torch
12 | 
13 | from tqdm import tqdm
14 | 
15 | 
16 | def mnn_ratio_matcher(descriptors1, descriptors2, ratio=0.9):
17 |     # Mutual NN + symmetric Lowe's ratio test matcher.
18 |     descriptors1 = torch.from_numpy(np.array(descriptors1)).float().cuda()
19 |     descriptors2 = torch.from_numpy(np.array(descriptors2)).float().cuda()
20 | 
21 |     # L2 normalize descriptors.
22 |     descriptors1 /= torch.norm(descriptors1, dim=-1).unsqueeze(-1)
23 |     descriptors2 /= torch.norm(descriptors2, dim=-1).unsqueeze(-1)
24 | 
25 |     # Similarity matrix.
26 |     device = descriptors1.device
27 |     sim = descriptors1 @ descriptors2.t()
28 | 
29 |     # Retrieve top 2 nearest neighbors 1->2.
30 |     nns_sim, nns = torch.topk(sim, 2, dim=1)
31 |     nns_dist = torch.sqrt(2 - 2 * nns_sim)
32 |     # Compute Lowe's ratio.
33 |     ratios12 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8)
34 |     # Save first NN and match similarity.
35 |     nn12 = nns[:, 0]
36 |     match_sim = nns_sim[:, 0]
37 | 
38 |     # Retrieve top 2 nearest neighbors 2->1.
39 |     nns_sim, nns = torch.topk(sim.t(), 2, dim=1)
40 |     nns_dist = torch.sqrt(2 - 2 * nns_sim)
41 |     # Compute Lowe's ratio.
42 |     ratios21 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8)
43 |     # Save first NN.
44 |     nn21 = nns[:, 0]
45 | 
46 |     # Mutual NN + symmetric ratio test.
47 |     ids1 = torch.arange(0, sim.shape[0], device=device)
48 |     mask = torch.min(ids1 == nn21[nn12], torch.min(ratios12 <= ratio, ratios21[nn12] <= ratio))
49 | 
50 |     # Final matches.
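    # Keep only pairs that are mutual nearest neighbors and pass the ratio
    # test in both directions; each row stores (index into descriptors1,
    # index into descriptors2).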
51 | matches = torch.stack([ids1[mask], nn12[mask]], dim=-1) 52 | match_sim = match_sim[mask] 53 | 54 | return ( 55 | matches.data.cpu().numpy(), 56 | match_sim.data.cpu().numpy() 57 | ) 58 | 59 | 60 | def translate_descriptors(descriptors, source_feature, target_feature, encoders, decoders, batch_size, device): 61 | descriptors_torch = torch.from_numpy(descriptors).to(device) 62 | start_idx = 0 63 | translated_descriptors_torch = torch.zeros([0, 128]).to(device) 64 | while start_idx < descriptors_torch.shape[0]: 65 | aux = encoders[source_feature](descriptors_torch[start_idx : start_idx + batch_size]) 66 | if target_feature is not None: 67 | aux = decoders[target_feature](aux) 68 | translated_descriptors_torch = torch.cat([translated_descriptors_torch, aux], dim=0) 69 | start_idx += batch_size 70 | translated_descriptors = translated_descriptors_torch.cpu().numpy() 71 | return translated_descriptors 72 | 73 | 74 | def image_ids_to_pair_id(image_id1, image_id2): 75 | if image_id1 > image_id2: 76 | return 2147483647 * image_id2 + image_id1 77 | else: 78 | return 2147483647 * image_id1 + image_id2 79 | 80 | 81 | def array_to_blob(array): 82 | return array.tobytes() 83 | 84 | 85 | def blob_to_array(blob, dtype, shape=(-1,)): 86 | if blob is None: 87 | return np.zeros(shape, dtype=dtype) 88 | return np.frombuffer(blob, dtype=dtype).reshape(*shape) 89 | 90 | 91 | def build_hybrid_database(features, dataset_path, database_path): 92 | sorted_features = sorted(features) 93 | n_features = len(sorted_features) 94 | 95 | # Connect to database. 96 | connection = sqlite3.connect(database_path) 97 | cursor = connection.cursor() 98 | 99 | cursor.execute('SELECT name, image_id FROM images;') 100 | images = {} 101 | image_names = {} 102 | for row in cursor: 103 | images[row[0]] = row[1] 104 | image_names[row[1]] = row[0] 105 | 106 | # Randomize features 107 | random = np.random.RandomState() 108 | random.seed(1) 109 | assigned_features = np.array( 110 | list(range(n_features)) * (len(images) // n_features) + 111 | list(random.randint(0, n_features, len(images) % n_features)) 112 | ) 113 | random.shuffle(assigned_features) 114 | 115 | image_ids = np.array(list(images.values())) 116 | for feature_idx, feature in enumerate(sorted_features): 117 | connection_aux = sqlite3.connect(os.path.join( 118 | dataset_path, '%s-features.db' % feature 119 | )) 120 | cursor_aux = connection_aux.cursor() 121 | 122 | image_indices = np.where(assigned_features == feature_idx)[0] 123 | for image_id in image_ids[image_indices]: 124 | image_id = int(image_id) 125 | assert image_names[image_id] == cursor_aux.execute("SELECT name FROM images WHERE image_id=?", (image_id,)).fetchall()[0][0] 126 | data, rows, cols = (cursor_aux.execute("SELECT data, rows, cols FROM keypoints WHERE image_id=?", (image_id,))).fetchall()[0] 127 | cursor.execute( 128 | 'INSERT INTO keypoints(image_id, rows, cols, data) VALUES(?, ?, ?, ?);', 129 | (image_id, rows, cols, data) 130 | ) 131 | data, rows, cols = (cursor_aux.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id,))).fetchall()[0] 132 | cursor.execute( 133 | 'INSERT INTO descriptors(image_id, rows, cols, data) VALUES(?, ?, ?, ?);', 134 | (image_id, rows, cols, data) 135 | ) 136 | 137 | cursor_aux.close() 138 | connection_aux.close() 139 | connection.commit() 140 | 141 | cursor.close() 142 | connection.close() 143 | 144 | image_features = {} 145 | for image_id, feature in zip(image_ids, np.array(sorted_features)[assigned_features]): 146 | image_features[image_id] = 
feature 147 | 148 | return image_features 149 | 150 | 151 | def match_features(colmap_path, database_path, image_path, match_list_path): 152 | connection = sqlite3.connect(database_path) 153 | cursor = connection.cursor() 154 | 155 | cursor.execute( 156 | 'SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'inlier_matches\';' 157 | ) 158 | try: 159 | inlier_matches_table_exists = bool(next(cursor)[0]) 160 | except StopIteration: 161 | inlier_matches_table_exists = False 162 | 163 | cursor.execute('DELETE FROM matches;') 164 | if inlier_matches_table_exists: 165 | cursor.execute('DELETE FROM inlier_matches;') 166 | else: 167 | cursor.execute('DELETE FROM two_view_geometries;') 168 | connection.commit() 169 | 170 | images = {} 171 | cursor.execute('SELECT name, image_id FROM images;') 172 | for row in cursor: 173 | images[row[0]] = row[1] 174 | 175 | with open(match_list_path, 'r') as f: 176 | raw_image_pairs = f.readlines() 177 | image_pairs = list(map(lambda s: s.strip('\n').split(' '), raw_image_pairs)) 178 | 179 | image_pair_ids = set() 180 | for image_name1, image_name2 in tqdm(image_pairs): 181 | image_id1, image_id2 = images[image_name1], images[image_name2] 182 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2) 183 | if image_pair_id in image_pair_ids: 184 | continue 185 | image_pair_ids.add(image_pair_id) 186 | 187 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id1,))).fetchall()[0] 188 | try: 189 | descriptors1 = blob_to_array(data, np.float32, (rows, cols)) 190 | except ValueError: 191 | descriptors1 = blob_to_array(data, bool, (rows, cols)).astype(np.float32) 192 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id2,))).fetchall()[0] 193 | try: 194 | descriptors2 = blob_to_array(data, np.float32, (rows, cols)) 195 | except ValueError: 196 | descriptors2 = blob_to_array(data, bool, (rows, cols)).astype(np.float32) 197 | 198 | # Match. 
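        # Skip images without keypoints; otherwise run the GPU mutual-NN
        # ratio matcher on the raw descriptors.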
199 |         if descriptors1.shape[0] == 0 or descriptors2.shape[0] == 0:
200 |             matches = np.zeros([0, 2], dtype=int)
201 |         else:
202 |             matches, _ = mnn_ratio_matcher(descriptors1, descriptors2)
203 | 
204 |         matches = np.array(matches).astype(np.uint32)
205 |         if matches.shape[0] == 0:
206 |             matches = np.zeros([0, 2], dtype=np.uint32)
207 |         assert matches.shape[1] == 2
208 |         if image_id1 > image_id2:
209 |             matches = matches[:, [1, 0]]
210 |         cursor.execute(
211 |             'INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);',
212 |             (image_pair_id, matches.shape[0], matches.shape[1], matches.tobytes())
213 |         )
214 |     connection.commit()
215 | 
216 |     cursor.close()
217 |     connection.close()
218 | 
219 | 
220 | def match_features_hybrid(features, image_features, colmap_path, database_path, image_path, match_list_path, encoders, decoders, batch_size, device):
221 |     connection = sqlite3.connect(database_path)
222 |     cursor = connection.cursor()
223 | 
224 |     cursor.execute(
225 |         'SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'inlier_matches\';'
226 |     )
227 |     try:
228 |         inlier_matches_table_exists = bool(next(cursor)[0])
229 |     except StopIteration:
230 |         inlier_matches_table_exists = False
231 | 
232 |     cursor.execute('DELETE FROM matches;')
233 |     if inlier_matches_table_exists:
234 |         cursor.execute('DELETE FROM inlier_matches;')
235 |     else:
236 |         cursor.execute('DELETE FROM two_view_geometries;')
237 |     connection.commit()
238 | 
239 |     images = {}
240 |     cursor.execute('SELECT name, image_id FROM images;')
241 |     for row in cursor:
242 |         images[row[0]] = row[1]
243 | 
244 |     with open(match_list_path, 'r') as f:
245 |         raw_image_pairs = f.readlines()
246 |     image_pairs = list(map(lambda s: s.strip('\n').split(' '), raw_image_pairs))
247 | 
248 |     image_pair_ids = set()
249 |     for image_name1, image_name2 in tqdm(image_pairs):
250 |         image_id1, image_id2 = images[image_name1], images[image_name2]
251 |         image_pair_id = image_ids_to_pair_id(image_id1, image_id2)
252 |         if image_pair_id in image_pair_ids:
253 |             continue
254 |         image_pair_ids.add(image_pair_id)
255 | 
256 |         data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id1,))).fetchall()[0]
257 |         try:
258 |             descriptors1 = blob_to_array(data, np.float32, (rows, cols))
259 |         except ValueError:
260 |             descriptors1 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
261 |         data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id2,))).fetchall()[0]
262 |         try:
263 |             descriptors2 = blob_to_array(data, np.float32, (rows, cols))
264 |         except ValueError:
265 |             descriptors2 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
266 | 
267 |         # Check feature consistency.
268 |         feature1, feature2 = image_features[image_id1], image_features[image_id2]
269 | 
270 |         if feature1 != feature2:
271 |             ford1 = features.index(feature1)
272 |             ford2 = features.index(feature2)
273 |             if ford1 > ford2:
274 |                 feature = feature1
275 |                 descriptors2 = translate_descriptors(descriptors2, feature2, feature1, encoders, decoders, batch_size, device)
276 |             else:
277 |                 feature = feature2
278 |                 descriptors1 = translate_descriptors(descriptors1, feature1, feature2, encoders, decoders, batch_size, device)
279 |         else:
280 |             feature = feature1
281 | 
282 |         # Match.
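        # Both descriptor sets now live in the same space (either the shared
        # raw feature or the translated one), so the standard matcher applies.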
283 |         if descriptors1.shape[0] == 0 or descriptors2.shape[0] == 0:
284 |             matches = np.zeros([0, 2], dtype=int)
285 |         else:
286 |             matches, _ = mnn_ratio_matcher(descriptors1, descriptors2)
287 | 
288 |         matches = np.array(matches).astype(np.uint32)
289 |         if matches.shape[0] == 0:
290 |             matches = np.zeros([0, 2], dtype=np.uint32)
291 |         assert matches.shape[1] == 2
292 |         if image_id1 > image_id2:
293 |             matches = matches[:, [1, 0]]
294 |         cursor.execute(
295 |             'INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);',
296 |             (image_pair_id, matches.shape[0], matches.shape[1], matches.tobytes())
297 |         )
298 |     connection.commit()
299 | 
300 |     cursor.close()
301 |     connection.close()
302 | 
303 | 
304 | def match_features_subset(feature, image_features, colmap_path, database_path, image_path, match_list_path):
305 |     connection = sqlite3.connect(database_path)
306 |     cursor = connection.cursor()
307 | 
308 |     cursor.execute(
309 |         'SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'inlier_matches\';'
310 |     )
311 |     try:
312 |         inlier_matches_table_exists = bool(next(cursor)[0])
313 |     except StopIteration:
314 |         inlier_matches_table_exists = False
315 | 
316 |     cursor.execute('DELETE FROM matches;')
317 |     if inlier_matches_table_exists:
318 |         cursor.execute('DELETE FROM inlier_matches;')
319 |     else:
320 |         cursor.execute('DELETE FROM two_view_geometries;')
321 |     connection.commit()
322 | 
323 |     images = {}
324 |     cursor.execute('SELECT name, image_id FROM images;')
325 |     for row in cursor:
326 |         images[row[0]] = row[1]
327 | 
328 |     image_pairs = set()
329 |     for image_name1, image_id1 in images.items():
330 |         for image_name2, image_id2 in images.items():
331 |             if image_features[image_id1] != feature or image_features[image_id2] != feature:
332 |                 continue
333 |             if image_name1 == image_name2:
334 |                 continue
335 |             if (image_name2, image_name1) not in image_pairs:
336 |                 image_pairs.add((image_name1, image_name2))
337 | 
338 |     f = open(match_list_path + '.aux', 'w')
339 |     image_pair_ids = set()
340 |     for image_name1, image_name2 in tqdm(image_pairs):
341 |         image_id1, image_id2 = images[image_name1], images[image_name2]
342 |         image_pair_id = image_ids_to_pair_id(image_id1, image_id2)
343 |         if image_pair_id in image_pair_ids:
344 |             continue
345 |         image_pair_ids.add(image_pair_id)
346 | 
347 |         data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id1,))).fetchall()[0]
348 |         try:
349 |             descriptors1 = blob_to_array(data, np.float32, (rows, cols))
350 |         except ValueError:
351 |             descriptors1 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
352 |         data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id2,))).fetchall()[0]
353 |         try:
354 |             descriptors2 = blob_to_array(data, np.float32, (rows, cols))
355 |         except ValueError:
356 |             descriptors2 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
357 | 
358 |         # Match.
359 |         if descriptors1.shape[0] == 0 or descriptors2.shape[0] == 0:
360 |             matches = np.zeros([0, 2], dtype=int)
361 |         else:
362 |             matches, _ = mnn_ratio_matcher(descriptors1, descriptors2)
363 | 
364 |         matches = np.array(matches).astype(np.uint32)
365 |         if matches.shape[0] == 0:
366 |             matches = np.zeros([0, 2], dtype=np.uint32)
367 |         assert matches.shape[1] == 2
368 |         if image_id1 > image_id2:
369 |             matches = matches[:, [1, 0]]
370 |         cursor.execute(
371 |             'INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);',
372 |             (image_pair_id, matches.shape[0], matches.shape[1], matches.tobytes())
373 |         )
374 |         f.write('%s %s\n' % (image_name1, image_name2))
375 |     connection.commit()
376 |     f.close()
377 |     cursor.close()
378 |     connection.close()
379 | 
380 | 
381 | def geometric_verification(colmap_path, database_path, match_list_path):
382 |     subprocess.call([
383 |         os.path.join(colmap_path, 'colmap'), 'matches_importer',
384 |         '--database_path', database_path,
385 |         '--match_list_path', match_list_path,
386 |         '--match_type', 'pairs',
387 |         '--SiftMatching.num_threads', str(8),
388 |         '--SiftMatching.use_gpu', '0',
389 |         '--SiftMatching.min_inlier_ratio', '0.1'
390 |     ])
391 | 
392 |     connection = sqlite3.connect(database_path)
393 |     cursor = connection.cursor()
394 | 
395 |     cursor.execute('SELECT count(*) FROM images;')
396 |     num_images = next(cursor)[0]
397 | 
398 |     cursor.execute('SELECT count(*) FROM two_view_geometries WHERE rows > 0;')
399 |     num_inlier_pairs = next(cursor)[0]
400 | 
401 |     cursor.execute('SELECT sum(rows) FROM two_view_geometries WHERE rows > 0;')
402 |     num_inlier_matches = next(cursor)[0]
403 | 
404 |     cursor.close()
405 |     connection.close()
406 | 
407 |     return dict(
408 |         num_images=num_images,
409 |         num_inlier_pairs=num_inlier_pairs,
410 |         num_inlier_matches=num_inlier_matches
411 |     )
412 | 
413 | def reconstruct(colmap_path, database_path, image_path, sparse_path, refine_intrinsics=False):
414 |     # Run the sparse reconstruction.
415 |     if not os.path.exists(sparse_path):
416 |         os.mkdir(sparse_path)
417 |     if not refine_intrinsics:
418 |         extra_mapper_params = [
419 |             '--Mapper.ba_refine_focal_length', str(0),
420 |             '--Mapper.ba_refine_principal_point', str(0),
421 |             '--Mapper.ba_refine_extra_params', str(0)
422 |         ]
423 |     else:
424 |         extra_mapper_params = [
425 |             '--Mapper.ba_refine_focal_length', str(1),
426 |             '--Mapper.ba_refine_principal_point', str(0),
427 |             '--Mapper.ba_refine_extra_params', str(1)
428 |         ]
429 |     subprocess.call([
430 |         os.path.join(colmap_path, 'colmap'), 'mapper',
431 |         '--database_path', database_path,
432 |         '--image_path', image_path,
433 |         '--output_path', sparse_path,
434 |         '--Mapper.abs_pose_min_inlier_ratio', str(0.05),
435 |         '--Mapper.num_threads', str(16),
436 |     ] + extra_mapper_params)
437 | 
438 |     # Find the largest reconstructed sparse model.
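    # COLMAP may split the scene into several sub-models: convert each one to
    # TXT, read the camera count from the cameras.txt header (a proxy for the
    # number of registered images), and keep the largest sub-model.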
439 |     models = os.listdir(sparse_path)
440 |     if len(models) == 0:
441 |         # A bare return here would crash the callers when unpacking the result.
442 |         raise RuntimeError('Could not reconstruct any model.')
443 | 
444 |     largest_model = None
445 |     largest_model_num_images = 0
446 |     for model in models:
447 |         subprocess.call([
448 |             os.path.join(colmap_path, 'colmap'), 'model_converter',
449 |             '--input_path', os.path.join(sparse_path, model),
450 |             '--output_path', os.path.join(sparse_path, model),
451 |             '--output_type', 'TXT'
452 |         ])
453 |         with open(os.path.join(sparse_path, model, 'cameras.txt'), 'r') as fid:
454 |             for line in fid:
455 |                 if line.startswith('# Number of cameras'):
456 |                     num_images = int(line.split()[-1])
457 |                     if num_images > largest_model_num_images:
458 |                         largest_model = model
459 |                         largest_model_num_images = num_images
460 |                     break
461 | 
462 |     assert largest_model_num_images > 0
463 | 
464 |     largest_model_path = os.path.join(sparse_path, largest_model)
465 | 
466 |     # Convert largest model to ply.
467 |     subprocess.call([
468 |         os.path.join(colmap_path, 'colmap'), 'model_converter',
469 |         '--input_path', largest_model_path,
470 |         '--output_path', os.path.join(sparse_path, 'pointcloud.ply'),
471 |         '--output_type', 'PLY'
472 |     ])
473 | 
474 |     # Recover model statistics.
475 |     stats = subprocess.check_output([
476 |         os.path.join(colmap_path, 'colmap'), 'model_analyzer',
477 |         '--path', largest_model_path
478 |     ])
479 | 
480 |     stats = stats.decode().split('\n')
481 |     for stat in stats:
482 |         if stat.startswith('Registered images'):
483 |             num_reg_images = int(stat.split()[-1])
484 |         elif stat.startswith('Points'):
485 |             num_sparse_points = int(stat.split()[-1])
486 |         elif stat.startswith('Observations'):
487 |             num_observations = int(stat.split()[-1])
488 |         elif stat.startswith('Mean track length'):
489 |             mean_track_length = float(stat.split()[-1])
490 |         elif stat.startswith('Mean observations per image'):
491 |             num_observations_per_image = float(stat.split()[-1])
492 |         elif stat.startswith('Mean reprojection error'):
493 |             mean_reproj_error = float(stat.split()[-1][:-2])
494 | 
495 |     return largest_model_path, dict(
496 |         num_reg_images=num_reg_images,
497 |         num_sparse_points=num_sparse_points,
498 |         num_observations=num_observations,
499 |         mean_track_length=mean_track_length,
500 |         num_observations_per_image=num_observations_per_image,
501 |         mean_reproj_error=mean_reproj_error
502 |     )
503 | 
504 | 
505 | def compute_extra_stats(image_features, largest_model_path):
506 |     with open(os.path.join(largest_model_path, 'images.txt'), 'r') as f:
507 |         raw_images = f.readlines()
508 |     raw_images = raw_images[4:][::2]
509 | 
510 |     counter = {}
511 |     for raw_image in raw_images:
512 |         image = raw_image.strip('\n').split(' ')
513 |         feature = image_features[int(image[0])]
514 |         if feature not in counter:
515 |             counter[feature] = 0
516 |         counter[feature] += 1
517 | 
518 |     return counter
519 | 
--------------------------------------------------------------------------------
/scripts/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | mkdir -p checkpoints && cd checkpoints
2 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/checkpoints/model.pth
3 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/checkpoints/model-cvsift.pth
--------------------------------------------------------------------------------
/scripts/download_evaluation_data.sh:
--------------------------------------------------------------------------------
1 | cd data/eval
2 | 
3 | # Aachen Day-Night.
4 | 5 | # Local Feature Evaluation. 6 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/data/LFE-release.tar.gz 7 | tar xvzf LFE-release.tar.gz 8 | rm LFE-release.tar.gz 9 | -------------------------------------------------------------------------------- /scripts/download_processed_training_data.sh: -------------------------------------------------------------------------------- 1 | cd data/train 2 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/data/training-data-colmap.tar.gz 3 | tar xvzf training-data-colmap.tar.gz 4 | rm -r training-data-colmap.tar.gz 5 | -------------------------------------------------------------------------------- /scripts/download_training_data.sh: -------------------------------------------------------------------------------- 1 | cd data/train 2 | wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.1.tar.gz 3 | tar -xzvf revisitop1m.1.tar.gz 4 | wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.2.tar.gz 5 | tar -xzvf revisitop1m.2.tar.gz 6 | wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.3.tar.gz 7 | tar -xzvf revisitop1m.3.tar.gz 8 | wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.4.tar.gz 9 | tar -xzvf revisitop1m.4.tar.gz 10 | wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.5.tar.gz 11 | tar -xzvf revisitop1m.5.tar.gz 12 | rm -f revisitop1m.*.tar.gz -------------------------------------------------------------------------------- /scripts/process_LFE_data.sh: -------------------------------------------------------------------------------- 1 | LFE_PATH=data/eval/LFE-release 2 | for dataset in 'Gendarmenmarkt' 'Madrid_Metropolis' 'Tower_of_London'; do 3 | echo $dataset 4 | # Dataset already provides keypoints consistent with the reference database. 
5 | # python feature-utils/extract_sift.py --colmap_path $COLMAP_PATH --dataset_path $LFE_PATH/$dataset --image_path $LFE_PATH/$dataset/images 6 | for feature in 'brief' 'sift-kornia' 'hardnet' 'sosnet'; do 7 | echo $feature 8 | python feature-utils/extract_descriptors.py --dataset_path $LFE_PATH/$dataset --image_path $LFE_PATH/$dataset/images --feature $feature 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /scripts/process_training_data.sh: -------------------------------------------------------------------------------- 1 | python feature-utils/extract_sift.py --colmap_path $COLMAP_PATH --dataset_path data/train/ 2 | for feature in 'brief' 'sift-kornia' 'hardnet' 'sosnet'; do 3 | echo $feature 4 | python feature-utils/extract_descriptors.py --dataset_path data/train/ --feature $feature 5 | python feature-utils/convert_database_to_numpy.py --dataset_path data/train/ --feature $feature 6 | done 7 | mkdir data/train/colmap 8 | mv data/train/*.db data/train/colmap 9 | mv data/train/*.npy data/train/colmap 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | 5 | import numpy as np 6 | 7 | import os 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | 14 | from torch.utils.data import DataLoader 15 | 16 | from tqdm import tqdm 17 | 18 | from lib.datasets import TranslationDataset as Dataset 19 | from lib.losses import exhaustive_loss 20 | from lib.utils import create_network_for_feature 21 | 22 | 23 | def parse_arguments(): 24 | parser = argparse.ArgumentParser(description='Training script') 25 | 26 | parser.add_argument( 27 | '--random_seed', type=int, default=1, 28 | help='random seed for numpy and PyTorch' 29 | ) 30 | 31 | parser.add_argument( 32 | '--dataset_path', type=str, required=True, 33 | help='path to the dataset' 34 | ) 35 | 36 | parser.add_argument( 37 | '--features', nargs='+', type=str, required=True, 38 | help='list of descriptors to consider' 39 | ) 40 | 41 | parser.add_argument( 42 | '--initial_checkpoint', type=str, default=None, 43 | help='path to the initial checkpoint' 44 | ) 45 | 46 | parser.add_argument( 47 | '--num_epochs', type=int, default=5, 48 | help='number of training epochs' 49 | ) 50 | parser.add_argument( 51 | '--lr', type=float, default=1e-3, 52 | help='learning rate' 53 | ) 54 | parser.add_argument( 55 | '--batch_size', type=int, default=1024, 56 | help='batch size' 57 | ) 58 | parser.add_argument( 59 | '--num_workers', type=int, default=4, 60 | help='number of workers for data loading' 61 | ) 62 | 63 | parser.add_argument( 64 | '--log_interval', type=int, default=1000, 65 | help='loss logging interval' 66 | ) 67 | 68 | parser.add_argument( 69 | '--checkpoint_directory', type=str, default='checkpoints', 70 | help='directory for training checkpoints' 71 | ) 72 | parser.add_argument( 73 | '--checkpoint_prefix', type=str, default='multi', 74 | help='prefix for training checkpoints' 75 | ) 76 | 77 | parser.add_argument( 78 | '--alpha', type=float, default=0.1, 79 | help='consistency loss weight' 80 | ) 81 | parser.add_argument( 82 | '--margin', type=float, default=1.0, 83 | help='margin for the negative margin loss' 84 | ) 85 | 86 | args = parser.parse_args() 87 | 88 | print(args) 89 | 90 | return args 91 | 92 | 93 | # Updating mean class for loss aggregation. 
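# Maintains a running sum and count so per-epoch loss averages can be updated
# one batch at a time without storing the whole history.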
94 | class UpdatingMean(): 95 | def __init__(self): 96 | self.sum = 0 97 | self.n = 0 98 | 99 | def mean(self): 100 | return self.sum / self.n 101 | 102 | def add(self, loss): 103 | self.sum += loss 104 | self.n += 1 105 | 106 | 107 | # Epoch training / validation loop. 108 | def run_epoch( 109 | encoders, 110 | decoders, 111 | loss_function, 112 | optimizer, 113 | dataloader, 114 | device, 115 | log_file, train=True 116 | ): 117 | epoch_loss = UpdatingMean() 118 | epoch_t_loss = UpdatingMean() 119 | epoch_e_loss = UpdatingMean() 120 | 121 | torch.set_grad_enabled(train) 122 | 123 | progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) 124 | for batch_idx, batch in progress_bar: 125 | # Move batch to device. 126 | for key in batch.keys(): 127 | batch[key] = batch[key].to(device) 128 | 129 | # Reset gradient if needed. 130 | if train: 131 | optimizer.zero_grad() 132 | 133 | # Compute loss. 134 | loss, (t_loss, e_loss) = loss_function(encoders, decoders, batch, device) 135 | 136 | # Add loss to history. 137 | epoch_loss.add(loss.data.cpu().numpy()) 138 | epoch_t_loss.add(t_loss) 139 | epoch_e_loss.add(e_loss) 140 | 141 | # Update progress bar. 142 | progress_bar.set_postfix( 143 | loss=('%.4f' % epoch_loss.mean()), 144 | t_loss=('%.4f' % epoch_t_loss.mean()), 145 | e_loss=('%.4f' % epoch_e_loss.mean()) 146 | ) 147 | 148 | # Update logs. 149 | if batch_idx % args.log_interval == 0: 150 | log_file.write('[%s] epoch %02d - batch %04d / %04d - avg_loss: %f, avg_t_loss: %f, avg_e_loss: %f\n' % ( 151 | 'train' if train else 'valid', 152 | epoch_idx, batch_idx, len(dataloader), 153 | epoch_loss.mean(), epoch_t_loss.mean(), epoch_e_loss.mean() 154 | )) 155 | 156 | # Backprop. 157 | if train: 158 | loss.backward() 159 | optimizer.step() 160 | 161 | # Update logs. 162 | log_file.write('[%s] epoch %02d - avg_loss: %f, avg_t_loss: %f, avg_e_loss: %f\n' % ( 163 | 'train' if train else 'valid', 164 | epoch_idx, 165 | epoch_loss.mean(), epoch_t_loss.mean(), epoch_e_loss.mean() 166 | )) 167 | log_file.flush() 168 | 169 | return epoch_loss.mean() 170 | 171 | 172 | if __name__ == '__main__': 173 | # Set CUDA. 174 | use_cuda = torch.cuda.is_available() 175 | device = torch.device("cuda:0" if use_cuda else "cpu") 176 | 177 | # Load config json. 178 | with open('checkpoints-pretrained/config.json', 'r') as f: 179 | config = json.load(f) 180 | 181 | # Command line arguments. 182 | args = parse_arguments() 183 | 184 | # Fix random seed. 185 | torch.manual_seed(args.random_seed) 186 | if use_cuda: 187 | torch.cuda.manual_seed(args.random_seed) 188 | np.random.seed(args.random_seed) 189 | 190 | # Networks. 191 | encoders = {} 192 | decoders = {} 193 | for feature in args.features: 194 | encoder, decoder = create_network_for_feature(feature, config, use_cuda) 195 | 196 | encoders[feature] = encoder 197 | decoders[feature] = decoder 198 | 199 | # Load initial checkpoint if needed. 200 | if args.initial_checkpoint is not None: 201 | checkpoint = torch.load(args.initial_checkpoint) 202 | for feature, state_dict in checkpoint['encoders']: 203 | encoders[feature].load_state_dict(state_dict) 204 | for feature, state_dict in checkpoint['decoders']: 205 | decoders[feature].load_state_dict(state_dict) 206 | 207 | # Dataset. 
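    # TranslationDataset (see lib/datasets.py) pairs, for every sampled
    # keypoint, the descriptors of each selected feature type produced by
    # scripts/process_training_data.sh.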
208 | training_dataset = Dataset( 209 | base_path=args.dataset_path, 210 | features=args.features 211 | ) 212 | training_dataloader = DataLoader( 213 | training_dataset, 214 | batch_size=args.batch_size, 215 | num_workers=args.num_workers, 216 | shuffle=True 217 | ) 218 | 219 | # Optimizer and loss. 220 | optimizer = optim.Adam( 221 | filter( 222 | lambda p: p.requires_grad, 223 | [param for _, enc in encoders.items() for param in enc.parameters()] + 224 | [param for _, dec in decoders.items() for param in dec.parameters()] 225 | ), 226 | lr=args.lr 227 | ) 228 | loss_function = lambda encoders, decoders, batch, device: exhaustive_loss( 229 | encoders, decoders, batch, device, 230 | alpha=args.alpha, margin=args.margin 231 | ) 232 | 233 | # Create the checkpoint directory. 234 | if os.path.isdir(args.checkpoint_directory): 235 | print('[Warning] Checkpoint directory already exists.') 236 | else: 237 | os.mkdir(args.checkpoint_directory) 238 | 239 | # Open the log file for writing 240 | if os.path.exists(os.path.join(args.checkpoint_directory, 'log.txt')): 241 | print('[Warning] Log file already exists.') 242 | log_file = open(os.path.join(args.checkpoint_directory, 'log.txt'), 'a+') 243 | 244 | # Training loop. 245 | train_loss_history = [] 246 | for epoch_idx in range(1, args.num_epochs + 1): 247 | # Run training epoch. 248 | train_loss_history.append( 249 | run_epoch( 250 | encoders, decoders, 251 | loss_function, 252 | optimizer, 253 | training_dataloader, 254 | device, 255 | log_file 256 | ) 257 | ) 258 | 259 | # Save the current checkpoint 260 | checkpoint_path = os.path.join( 261 | args.checkpoint_directory, 262 | '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx) 263 | ) 264 | checkpoint = { 265 | 'args': args, 266 | 'epoch_idx': epoch_idx, 267 | 'encoders': [(feature, enc.state_dict()) for feature, enc in encoders.items()], 268 | 'decoders': [(feature, dec.state_dict()) for feature, dec in decoders.items()], 269 | 'optimizer': optimizer.state_dict(), 270 | 'train_loss_history': train_loss_history 271 | } 272 | torch.save(checkpoint, checkpoint_path) 273 | 274 | # Close the log file 275 | log_file.close() 276 | --------------------------------------------------------------------------------