├── .gitignore
├── LICENSE
├── README.md
├── checkpoints-pretrained
│   └── config.json
├── data
│   ├── eval
│   │   └── .keep
│   └── train
│       └── .keep
├── env.yml
├── feature-utils
│   ├── convert_database_to_numpy.py
│   ├── extract_descriptors.py
│   └── extract_sift.py
├── lib
│   ├── database.py
│   ├── datasets.py
│   ├── losses.py
│   ├── networks.py
│   └── utils.py
├── local-feature-evaluation
│   ├── align_and_compare.py
│   ├── reconstruction_pipeline_embed.py
│   ├── reconstruction_pipeline_progressive.py
│   ├── reconstruction_pipeline_subset.py
│   └── utils.py
├── scripts
│   ├── download_checkpoints.sh
│   ├── download_evaluation_data.sh
│   ├── download_processed_training_data.sh
│   ├── download_training_data.sh
│   ├── process_LFE_data.sh
│   └── process_training_data.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | data
3 | *.pth
4 | checkpoints
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, ETH Zurich.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-Descriptor Visual Localization and Mapping
2 |
3 | This repository contains the implementation of the following paper:
4 |
5 | ```text
6 | "Cross-Descriptor Visual Localization and Mapping".
7 | M. Dusmanu, O. Miksik, J.L. Schönberger, and M. Pollefeys. ICCV 2021.
8 | ```
9 |
10 | [[Paper on arXiv]](https://arxiv.org/abs/2012.01377)
11 |
12 |
13 | ## Requirements
14 |
15 | ### COLMAP
16 |
17 | We use COLMAP for DoG keypoint extraction as well as localization and mapping.
18 | Please follow the installation instructions available on the [official webpage](https://colmap.github.io).
19 | Before proceeding, we recommend setting an environment variable pointing to the COLMAP executable folder by running `export COLMAP_PATH=path_to_colmap_executable_folder`.
20 |
21 | ### Python
22 |
23 | The environment can be set up directly using conda:
24 | ```
25 | conda env create -f env.yml
26 | conda activate cross-descriptor-vis-loc-map
27 | ```
28 |
29 | ### Training data
30 |
31 | We provide a script for downloading the raw training data:
32 | ```
33 | bash scripts/download_training_data.sh
34 | ```
35 |
36 | ### Evaluation data
37 |
38 | We provide a script for downloading the LFE dataset (along with the ground truth used for evaluation) as well as the Aachen Day-Night dataset:
39 | ```
40 | bash scripts/download_evaluation_data.sh
41 | ```
42 |
43 |
44 | ## Training
45 |
46 | ### Data preprocessing
47 |
48 | The first step is extracting keypoints and descriptors from the training data downloaded above.
49 | ```
50 | bash scripts/process_training_data.sh
51 | ```
52 | Alternatively, you can directly download the processed training data by running:
53 | ```
54 | bash scripts/download_processed_training_data.sh
55 | ```
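The processed data consists of one `<feature>-features.npy` descriptor array per feature type, produced by `feature-utils/convert_database_to_numpy.py`. As a minimal sketch of how this data is consumed, assuming the processed descriptors live under `data/train/colmap` as in the training command below, the dataset class in `lib/datasets.py` can be used as follows:

```python
from lib.datasets import TranslationDataset

# Loads '<feature>-features.npy' for every requested feature type.
dataset = TranslationDataset(
    base_path='data/train/colmap',
    features=['brief', 'sift-kornia', 'hardnet', 'sosnet'],
)
sample = dataset[0]  # dict mapping feature name -> one descriptor as a float tensor
```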
56 |
57 | ### Training
58 |
59 | To run training with the default architecture and hyper-parameters, execute the following:
60 | ```
61 | python train.py \
62 | --dataset_path data/train/colmap \
63 | --features brief sift-kornia hardnet sosnet
64 | ```
65 |
66 | ### Pretrained models
67 |
68 | We provide two pretrained models trained on descriptors extracted from COLMAP SIFT and OpenCV SIFT keypoints, respectively.
69 | These models can be downloaded by running:
70 | ```
71 | bash scripts/download_checkpoints.sh
72 | ```
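The evaluation pipelines load the checkpoint from `checkpoints-pretrained/model.pth` by default and read the encoder/decoder weights as lists of `(feature, state_dict)` pairs. A minimal sketch of restoring one encoder/decoder pair (run from the repository root so that `lib` is importable):

```python
import json

import torch

from lib.utils import create_network_for_feature

with open('checkpoints-pretrained/config.json', 'r') as f:
    config = json.load(f)
checkpoint = torch.load('checkpoints-pretrained/model.pth')

feature = 'sift-kornia'
encoder, decoder = create_network_for_feature(feature, config, torch.cuda.is_available())
encoder.load_state_dict(dict(checkpoint['encoders'])[feature])
decoder.load_state_dict(dict(checkpoint['decoders'])[feature])
encoder.eval()
decoder.eval()
```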
73 |
74 |
75 | ## Evaluation
76 |
77 | ### Demo Notebook
78 |
79 |
80 |
81 |
82 |
83 |
84 | ### Local Feature Evaluation Benchmark
85 |
86 |
87 |
88 |
89 | The first step is extracting descriptors on all datasets:
90 | ```
91 | bash scripts/process_LFE_data.sh
92 | ```
93 |
94 | We provide examples below for running the reconstruction on Madrid Metropolis in each evaluation scenario.
95 |
96 | #### Reconstruction using a single descriptor (standard)
97 |
98 | ```
99 | python local-feature-evaluation/reconstruction_pipeline_progressive.py \
100 | --dataset_path data/eval/LFE-release/Madrid_Metropolis \
101 | --colmap_path $COLMAP_PATH \
102 | --features sift-kornia \
103 | --exp_name sift-kornia-single
104 | ```
105 |
106 | #### Reconstruction using the progressive approach (ours)
107 |
108 | ```
109 | python local-feature-evaluation/reconstruction_pipeline_progressive.py \
110 | --dataset_path data/eval/LFE-release/Madrid_Metropolis \
111 | --colmap_path $COLMAP_PATH \
112 | --features brief sift-kornia hardnet sosnet \
113 | --exp_name progressive
114 | ```
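Conceptually, the progressive approach matches image pairs across descriptor types by translating descriptors from one type to another: the source encoder maps descriptors into the shared embedding space, and the target decoder maps them back out (the batched implementation is `translate_descriptors` in `local-feature-evaluation/utils.py`). A hypothetical single-step sketch, reusing the networks loaded as in the pretrained-models section above:

```python
import torch

def translate(descriptors, source, target, encoders, decoders, device):
    # descriptors: (N, dim) float32 numpy array of `source`-type descriptors.
    x = torch.from_numpy(descriptors).float().to(device)
    with torch.no_grad():
        # Encode into the shared embedding space, then decode into the
        # target descriptor space, e.g. 'brief' -> 'sosnet'.
        return decoders[target](encoders[source](x)).cpu().numpy()
```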
115 |
116 | #### Reconstruction using the joint embedding approach (ours)
117 |
118 | ```
119 | python local-feature-evaluation/reconstruction_pipeline_embed.py \
120 | --dataset_path data/eval/LFE-release/Madrid_Metropolis \
121 | --colmap_path $COLMAP_PATH \
122 | --features brief sift-kornia hardnet sosnet \
123 | --exp_name embed
124 | ```
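In contrast, the joint embedding approach matches images directly in the shared embedding space: `translate_database` in `local-feature-evaluation/reconstruction_pipeline_embed.py` replaces every image's native descriptors with their encoder outputs, which are 128-dimensional and L2-normalized (see `emb_dim` and `emb_l2_norm` in `checkpoints-pretrained/config.json`). A minimal sketch for a single image, with `encoders`, `feature`, and `device` as in the previous sketches:

```python
import torch

# `descriptors` is an (N, dim) float32 numpy array of native descriptors.
with torch.no_grad():
    embeddings = encoders[feature](
        torch.from_numpy(descriptors).float().to(device)
    ).cpu().numpy()  # (N, 128), unit-length rows
```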
125 |
126 | #### Reconstruction using a single descriptor on the associated split (real-world)
127 |
128 | ```
129 | python local-feature-evaluation/reconstruction_pipeline_subset.py \
130 | --dataset_path data/eval/LFE-release/Madrid_Metropolis/ \
131 | --colmap_path $COLMAP_PATH \
132 | --features brief sift-kornia hardnet sosnet \
133 | --feature sift-kornia \
134 | --exp_name sift-kornia-subset
135 | ```
136 |
137 | #### Evaluation of a reconstruction w.r.t. metric pseudo-ground-truth
138 |
139 | ```
140 | python local-feature-evaluation/align_and_compare.py \
141 | --colmap_path $COLMAP_PATH \
142 | --reference_model_path data/eval/LFE-release/Madrid_Metropolis/sparse-reference/filtered-metric/ \
143 | --model_path data/eval/LFE-release/Madrid_Metropolis/sparse-sift-kornia-single/0/
144 | ```
145 |
146 |
147 |
148 | ### Aachen Day-Night
149 |
150 |
151 |
152 |
153 |
154 |
155 | ## BibTeX
156 |
157 | If you use this code in your project, please cite the following paper:
158 | ```
159 | @InProceedings{Dusmanu2021Cross,
160 | author = {Dusmanu, Mihai and Miksik, Ondrej and Sch\"onberger, Johannes L. and Pollefeys, Marc},
161 |     title = {{Cross-Descriptor Visual Localization and Mapping}},
162 | booktitle = {Proceedings of the International Conference on Computer Vision},
163 | year = {2021}
164 | }
165 | ```
166 |
--------------------------------------------------------------------------------
/checkpoints-pretrained/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "use_bn": true,
3 | "emb_dim": 128,
4 | "emb_l2_norm": true,
5 | "emb_last_activation": null,
6 | "brief": {
7 | "descriptor_dim": 512,
8 | "hidden_dims": [1024, 1024],
9 | "l2_norm": false,
10 | "last_activation": "sigmoid"
11 | },
12 | "sift-kornia": {
13 | "descriptor_dim": 128,
14 | "hidden_dims": [1024, 1024],
15 | "l2_norm": true,
16 | "last_activation": "relu"
17 | },
18 | "hardnet": {
19 | "descriptor_dim": 128,
20 | "hidden_dims": [256, 256],
21 | "l2_norm": true,
22 | "last_activation": null
23 | },
24 | "sosnet": {
25 | "descriptor_dim": 128,
26 | "hidden_dims": [256, 256],
27 | "l2_norm": true,
28 | "last_activation": null
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/data/eval/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mihaidusmanu/cross-descriptor-vis-loc-map/2d3c7eb716706c65710480b0e82eee4b653ddabd/data/eval/.keep
--------------------------------------------------------------------------------
/data/train/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mihaidusmanu/cross-descriptor-vis-loc-map/2d3c7eb716706c65710480b0e82eee4b653ddabd/data/train/.keep
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: cross-descriptor-vis-loc-map
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=10.2
8 | - pip
9 | - python=3.9
10 | - pytorch=1.9
11 | - torchaudio
12 | - torchvision
13 | - tqdm
14 | - pip:
15 | - extract-patches
16 | - kornia
17 | - opencv-python
18 |
--------------------------------------------------------------------------------
/feature-utils/convert_database_to_numpy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 |
5 | import os
6 |
7 | import sqlite3
8 |
9 | from tqdm import tqdm
10 |
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument(
16 | '--dataset_path', type=str, required=True,
17 | help='path to the dataset'
18 | )
19 | parser.add_argument(
20 | '--feature', type=str, required=True,
21 | help='descriptor'
22 | )
23 |
24 | args = parser.parse_args()
25 |
26 | database_path = os.path.join(
27 | args.dataset_path, '%s-features.db' % args.feature
28 | )
29 | output_path = os.path.join(
30 | args.dataset_path, '%s-features.npy' % args.feature
31 | )
32 |
33 | connection = sqlite3.connect(database_path)
34 | cursor = connection.cursor()
35 |
36 | all_descriptors = []
37 | for (image_id,) in tqdm(cursor.execute('SELECT image_id FROM images').fetchall()):
38 |         r, c, blob = cursor.execute('SELECT rows, cols, data FROM descriptors WHERE image_id=?', (image_id,)).fetchone()
39 |         try:
40 |             descriptors = np.frombuffer(blob, dtype=np.float32).reshape(r, c)  # Real-valued descriptors (SIFT, HardNet, SOSNet).
41 |         except ValueError:
42 |             descriptors = np.frombuffer(blob, dtype=bool).reshape(r, c)  # Binary descriptors (BRIEF) are stored as booleans.
43 | all_descriptors.append(descriptors)
44 | all_descriptors = np.concatenate(all_descriptors, axis=0)
45 |
46 |     # Randomly shuffle the descriptors (fixed seed for reproducibility) - not strictly required.
47 | random = np.random.RandomState(seed=1)
48 | random.shuffle(all_descriptors)
49 |
50 | np.save(output_path, all_descriptors)
51 |
52 | cursor.close()
53 | connection.close()
--------------------------------------------------------------------------------
/feature-utils/extract_descriptors.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/vcg-uvic/image-matching-benchmark-baselines/blob/master/extract_descriptors_hardnet.py.
2 |
3 | import argparse
4 |
5 | import numpy as np
6 |
7 | import os
8 |
9 | import cv2
10 |
11 | import kornia
12 |
13 | import shutil
14 |
15 | import sqlite3
16 |
17 | import types
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | import torchvision.transforms as transforms
24 |
25 | import tqdm
26 |
27 | from extract_patches.core import extract_patches
28 |
29 |
30 | def get_transforms():
31 | transform = transforms.Compose([
32 | transforms.Lambda(lambda x: cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)),
33 | transforms.Lambda(lambda x: np.reshape(x, (32, 32, 1))),
34 | transforms.ToTensor(),
35 | ])
36 | return transform
37 |
38 |
39 | class BRIEFDescriptor(nn.Module):
40 | # Adapted from https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.BRIEF.
41 | def __init__(self, desc_size=512, patch_size=32, seed=1):
42 | super(BRIEFDescriptor, self).__init__()
43 |
44 | # Sampling pattern.
45 | random = np.random.RandomState()
46 | random.seed(seed)
47 | samples = (patch_size / 5.0) * random.randn(desc_size * 8)
48 | samples = np.array(samples, dtype=np.int32)
49 | samples = samples[
50 | (samples <= (patch_size // 2)) & (samples >= - (patch_size - 2) // 2)
51 | ]
52 | samples += (patch_size // 2 - 1)
53 | pos1 = samples[: desc_size * 2].reshape(desc_size, 2)
54 | pos2 = samples[desc_size * 2 : desc_size * 4].reshape(desc_size, 2)
55 |
56 | # Create tensors.
57 | self.pos1 = torch.from_numpy(pos1).long()
58 | self.pos2 = torch.from_numpy(pos2).long()
59 |
60 | def forward(self, patches):
61 | pixel_values1 = patches[:, 0, self.pos1[:, 0], self.pos1[:, 1]]
62 | pixel_values2 = patches[:, 0, self.pos2[:, 0], self.pos2[:, 1]]
63 | descriptors = (pixel_values1 < pixel_values2)
64 | return descriptors
65 |
66 |
67 | def recover_database_images_and_ids(database_path):
68 | # Connect to the database.
69 | connection = sqlite3.connect(database_path)
70 | cursor = connection.cursor()
71 |
72 | # Recover database images and ids.
73 | images = {}
74 | cursor.execute('SELECT name, image_id FROM images;')
75 | for row in cursor:
76 | images[row[1]] = row[0]
77 |
78 | # Close the connection to the database.
79 | cursor.close()
80 | connection.close()
81 |
82 | return images
83 |
84 |
85 | if __name__ == '__main__':
86 | parser = argparse.ArgumentParser()
87 |
88 | parser.add_argument(
89 | '--dataset_path', type=str, required=True,
90 | help='path to the dataset'
91 | )
92 | parser.add_argument(
93 | '--image_path', type=str, default=None,
94 | help='path to the images'
95 | )
96 | parser.add_argument(
97 | '--feature', type=str, required=True,
98 | choices=['brief', 'sift-kornia', 'hardnet', 'sosnet'],
99 | help='descriptors to be extracted'
100 | )
101 | parser.add_argument(
102 | '--mr_size', type=float, default=12.0,
103 | help='patch size in image is mr_size * pt.size'
104 | )
105 | parser.add_argument(
106 | '--batch_size', type=int, default=512,
107 |         help='batch size'
108 | )
109 |
110 | args = parser.parse_args()
111 |
112 | if args.image_path is None:
113 | args.image_path = args.dataset_path
114 |
115 | # Dataset paths.
116 | paths = types.SimpleNamespace()
117 | paths.sift_database_path = os.path.join(args.dataset_path, 'sift-features.db')
118 | paths.database_path = os.path.join(args.dataset_path, '%s-features.db' % args.feature)
119 |
120 | # Copy SIFT database.
121 | if os.path.exists(paths.database_path):
122 | raise FileExistsError('Database already exists at %s.' % paths.database_path)
123 | shutil.copy(paths.sift_database_path, paths.database_path)
124 |
125 | # PyTorch settings.
126 | use_cuda = torch.cuda.is_available()
127 | device = torch.device("cuda:0" if use_cuda else "cpu")
128 | torch.set_grad_enabled(False)
129 |
130 | # Network and input transforms.
131 | dim = 128
132 | dtype = np.float32
133 | if args.feature == 'brief':
134 | model = BRIEFDescriptor()
135 | model = model.to(device)
136 | dim = 512
137 | dtype = bool
138 | elif args.feature == 'sift-kornia':
139 | model = kornia.feature.SIFTDescriptor(patch_size=32, rootsift=False)
140 | model = model.to(device)
141 | elif args.feature == 'hardnet':
142 | model = kornia.feature.HardNet(pretrained=True)
143 | model = model.to(device)
144 | model.eval()
145 | elif args.feature == 'sosnet':
146 | model = kornia.feature.SOSNet(pretrained=True)
147 | model = model.to(device)
148 | model.eval()
149 | transform = get_transforms()
150 |
151 | # Recover list of images.
152 | images = recover_database_images_and_ids(paths.database_path)
153 |
154 | # Connect to database.
155 | connection = sqlite3.connect(paths.database_path)
156 | cursor = connection.cursor()
157 |
158 | cursor.execute('DELETE FROM descriptors;')
159 | connection.commit()
160 |
161 | cursor.execute('SELECT image_id, rows, cols, data FROM keypoints;')
162 | raw_keypoints = cursor.fetchall()
163 | for row in tqdm.tqdm(raw_keypoints):
164 |         assert(row[2] == 6)  # Expect 6-D affine keypoints: (x, y, a11, a12, a21, a22).
165 | image_id = row[0]
166 | image_relative_path = images[image_id]
167 | if row[1] == 0:
168 | keypoints = np.zeros([0, 6])
169 | else:
170 | keypoints = np.frombuffer(row[-1], dtype=np.float32).reshape(row[1], row[2])
171 |
172 | keypoints = np.copy(keypoints)
173 | # In COLMAP, the upper left pixel has the coordinate (0.5, 0.5).
174 | keypoints[:, 0] = keypoints[:, 0] - .5
175 | keypoints[:, 1] = keypoints[:, 1] - .5
176 |
177 | # Extract patches.
178 | image = cv2.cvtColor(
179 | cv2.imread(os.path.join(args.image_path, image_relative_path)),
180 | cv2.COLOR_BGR2RGB
181 | )
182 |
183 | patches = extract_patches(
184 | keypoints, image, 32, args.mr_size, 'xyA'
185 | )
186 |
187 | # Extract descriptors.
188 | descriptors = np.zeros((len(patches), dim), dtype=dtype)
189 | for i in range(0, len(patches), args.batch_size):
190 | data_a = patches[i : i + args.batch_size]
191 | data_a = torch.stack(
192 | [transform(patch) for patch in data_a]
193 | ).to(device)
194 | # Predict
195 | out_a = model(data_a)
196 | descriptors[i : i + args.batch_size] = out_a.cpu().detach().numpy()
197 |
198 | # Insert into database.
199 | cursor.execute(
200 | 'INSERT INTO descriptors(image_id, rows, cols, data) VALUES(?, ?, ?, ?);',
201 | (image_id, descriptors.shape[0], descriptors.shape[1], descriptors.tobytes())
202 | )
203 | connection.commit()
204 |
205 | # Close connection to database.
206 | cursor.close()
207 | connection.close()
208 |
--------------------------------------------------------------------------------
/feature-utils/extract_sift.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import os
4 |
5 | import subprocess
6 |
7 | import types
8 |
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser()
12 |
13 | parser.add_argument(
14 | '--colmap_path', type=str, required=True,
15 | help='path to the COLMAP executable folder'
16 | )
17 | parser.add_argument(
18 | '--dataset_path', type=str, required=True,
19 | help='path to the dataset'
20 | )
21 | parser.add_argument(
22 | '--image_path', type=str, default=None,
23 | help='path to the images'
24 | )
25 |
26 | args = parser.parse_args()
27 |
28 | if args.image_path is None:
29 | args.image_path = args.dataset_path
30 |
31 | # Dataset paths.
32 | paths = types.SimpleNamespace()
33 | paths.database_path = os.path.join(args.dataset_path, 'sift-features.db')
34 |
35 | if os.path.exists(paths.database_path):
36 | raise FileExistsError('Database already exists at %s.' % paths.database_path)
37 |
38 | # Extract SIFT features.
39 | subprocess.call([
40 | os.path.join(args.colmap_path, 'colmap'), 'feature_extractor',
41 | '--database_path', paths.database_path,
42 | '--image_path', args.image_path,
43 | '--SiftExtraction.first_octave', str(0),
44 | '--SiftExtraction.num_threads', str(1)
45 | ])
46 |
--------------------------------------------------------------------------------
/lib/database.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31 |
32 | # This script is based on an original implementation by True Price.
33 |
34 | import sys
35 | import sqlite3
36 | import numpy as np
37 |
38 |
39 | IS_PYTHON3 = sys.version_info[0] >= 3
40 |
41 | MAX_IMAGE_ID = 2**31 - 1
42 |
43 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
44 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
45 | model INTEGER NOT NULL,
46 | width INTEGER NOT NULL,
47 | height INTEGER NOT NULL,
48 | params BLOB,
49 | prior_focal_length INTEGER NOT NULL)"""
50 |
51 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
52 | image_id INTEGER PRIMARY KEY NOT NULL,
53 | rows INTEGER NOT NULL,
54 | cols INTEGER NOT NULL,
55 | data BLOB,
56 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
57 |
58 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
59 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
60 | name TEXT NOT NULL UNIQUE,
61 | camera_id INTEGER NOT NULL,
62 | prior_qw REAL,
63 | prior_qx REAL,
64 | prior_qy REAL,
65 | prior_qz REAL,
66 | prior_tx REAL,
67 | prior_ty REAL,
68 | prior_tz REAL,
69 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
70 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
71 | """.format(MAX_IMAGE_ID)
72 |
73 | CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
74 | CREATE TABLE IF NOT EXISTS two_view_geometries (
75 | pair_id INTEGER PRIMARY KEY NOT NULL,
76 | rows INTEGER NOT NULL,
77 | cols INTEGER NOT NULL,
78 | data BLOB,
79 | config INTEGER NOT NULL,
80 | F BLOB,
81 | E BLOB,
82 | H BLOB)
83 | """
84 |
85 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
86 | image_id INTEGER PRIMARY KEY NOT NULL,
87 | rows INTEGER NOT NULL,
88 | cols INTEGER NOT NULL,
89 | data BLOB,
90 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
91 | """
92 |
93 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
94 | pair_id INTEGER PRIMARY KEY NOT NULL,
95 | rows INTEGER NOT NULL,
96 | cols INTEGER NOT NULL,
97 | data BLOB)"""
98 |
99 | CREATE_NAME_INDEX = \
100 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
101 |
102 | CREATE_ALL = "; ".join([
103 | CREATE_CAMERAS_TABLE,
104 | CREATE_IMAGES_TABLE,
105 | CREATE_KEYPOINTS_TABLE,
106 | CREATE_DESCRIPTORS_TABLE,
107 | CREATE_MATCHES_TABLE,
108 | CREATE_TWO_VIEW_GEOMETRIES_TABLE,
109 | CREATE_NAME_INDEX
110 | ])
111 |
112 |
113 | def image_ids_to_pair_id(image_id1, image_id2):
114 | if image_id1 > image_id2:
115 | image_id1, image_id2 = image_id2, image_id1
116 | return image_id1 * MAX_IMAGE_ID + image_id2
117 |
118 |
119 | def pair_id_to_image_ids(pair_id):
120 | image_id2 = pair_id % MAX_IMAGE_ID
121 |     image_id1 = (pair_id - image_id2) // MAX_IMAGE_ID
122 | return image_id1, image_id2
123 |
124 |
125 | def array_to_blob(array):
126 | if IS_PYTHON3:
127 |         return array.tobytes()
128 | else:
129 | return np.getbuffer(array)
130 |
131 |
132 | def blob_to_array(blob, dtype, shape=(-1,)):
133 | if blob is None:
134 | return np.zeros(shape, dtype=dtype)
135 | if IS_PYTHON3:
136 |         return np.frombuffer(blob, dtype=dtype).reshape(*shape)
137 | else:
138 | return np.frombuffer(blob, dtype=dtype).reshape(*shape)
139 |
140 |
141 | class COLMAPDatabase(sqlite3.Connection):
142 |
143 | @staticmethod
144 | def connect(database_path):
145 | db = sqlite3.connect(database_path, factory=COLMAPDatabase)
146 | db.path = database_path
147 | return db
148 |
149 |
150 | def __init__(self, *args, **kwargs):
151 | super(COLMAPDatabase, self).__init__(*args, **kwargs)
152 |
153 | self.path = None
154 |
155 | self.create_tables = lambda: self.executescript(CREATE_ALL)
156 | self.create_cameras_table = \
157 | lambda: self.executescript(CREATE_CAMERAS_TABLE)
158 | self.create_descriptors_table = \
159 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
160 | self.create_images_table = \
161 | lambda: self.executescript(CREATE_IMAGES_TABLE)
162 | self.create_two_view_geometries_table = \
163 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
164 | self.create_keypoints_table = \
165 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
166 | self.create_matches_table = \
167 | lambda: self.executescript(CREATE_MATCHES_TABLE)
168 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
169 |
170 | def add_camera(self, model, width, height, params,
171 | prior_focal_length=False, camera_id=None):
172 | params = np.asarray(params, np.float64)
173 | cursor = self.execute(
174 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
175 | (camera_id, model, width, height, array_to_blob(params),
176 | prior_focal_length))
177 | return cursor.lastrowid
178 |
179 | def add_image(self, name, camera_id,
180 | prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None):
181 | cursor = self.execute(
182 | "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
183 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
184 | prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
185 | return cursor.lastrowid
186 |
187 | def add_keypoints(self, image_id, keypoints):
188 | assert(len(keypoints.shape) == 2)
189 | assert(keypoints.shape[1] in [2, 4, 6])
190 |
191 | keypoints = np.asarray(keypoints, np.float32)
192 | self.execute(
193 | "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
194 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
195 |
196 | def add_descriptors(self, image_id, descriptors):
197 | descriptors = np.ascontiguousarray(descriptors, np.uint8)
198 | self.execute(
199 | "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
200 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
201 |
202 | def add_matches(self, image_id1, image_id2, matches):
203 | assert(len(matches.shape) == 2)
204 | assert(matches.shape[1] == 2)
205 |
206 | if image_id1 > image_id2:
207 | matches = matches[:,::-1]
208 |
209 | pair_id = image_ids_to_pair_id(image_id1, image_id2)
210 | matches = np.asarray(matches, np.uint32)
211 | self.execute(
212 | "INSERT INTO matches VALUES (?, ?, ?, ?)",
213 | (pair_id,) + matches.shape + (array_to_blob(matches),))
214 |
215 | def add_two_view_geometry(self, image_id1, image_id2, matches,
216 | F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
217 | assert(len(matches.shape) == 2)
218 | assert(matches.shape[1] == 2)
219 |
220 | if image_id1 > image_id2:
221 | matches = matches[:,::-1]
222 |
223 | pair_id = image_ids_to_pair_id(image_id1, image_id2)
224 | matches = np.asarray(matches, np.uint32)
225 | F = np.asarray(F, dtype=np.float64)
226 | E = np.asarray(E, dtype=np.float64)
227 | H = np.asarray(H, dtype=np.float64)
228 | self.execute(
229 | "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
230 | (pair_id,) + matches.shape + (array_to_blob(matches), config,
231 | array_to_blob(F), array_to_blob(E), array_to_blob(H)))
232 |
233 |
234 | def example_usage():
235 | import os
236 | import argparse
237 |
238 | parser = argparse.ArgumentParser()
239 | parser.add_argument("--database_path", default="database.db")
240 | args = parser.parse_args()
241 |
242 | if os.path.exists(args.database_path):
243 | print("ERROR: database path already exists -- will not modify it.")
244 | return
245 |
246 | # Open the database.
247 |
248 | db = COLMAPDatabase.connect(args.database_path)
249 |
250 | # For convenience, try creating all the tables upfront.
251 |
252 | db.create_tables()
253 |
254 | # Create dummy cameras.
255 |
256 | model1, width1, height1, params1 = \
257 | 0, 1024, 768, np.array((1024., 512., 384.))
258 | model2, width2, height2, params2 = \
259 | 2, 1024, 768, np.array((1024., 512., 384., 0.1))
260 |
261 | camera_id1 = db.add_camera(model1, width1, height1, params1)
262 | camera_id2 = db.add_camera(model2, width2, height2, params2)
263 |
264 | # Create dummy images.
265 |
266 | image_id1 = db.add_image("image1.png", camera_id1)
267 | image_id2 = db.add_image("image2.png", camera_id1)
268 | image_id3 = db.add_image("image3.png", camera_id2)
269 | image_id4 = db.add_image("image4.png", camera_id2)
270 |
271 | # Create dummy keypoints.
272 | #
273 | # Note that COLMAP supports:
274 | # - 2D keypoints: (x, y)
275 | # - 4D keypoints: (x, y, theta, scale)
276 | # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
277 |
278 | num_keypoints = 1000
279 | keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
280 | keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
281 | keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
282 | keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
283 |
284 | db.add_keypoints(image_id1, keypoints1)
285 | db.add_keypoints(image_id2, keypoints2)
286 | db.add_keypoints(image_id3, keypoints3)
287 | db.add_keypoints(image_id4, keypoints4)
288 |
289 | # Create dummy matches.
290 |
291 | M = 50
292 | matches12 = np.random.randint(num_keypoints, size=(M, 2))
293 | matches23 = np.random.randint(num_keypoints, size=(M, 2))
294 | matches34 = np.random.randint(num_keypoints, size=(M, 2))
295 |
296 | db.add_matches(image_id1, image_id2, matches12)
297 | db.add_matches(image_id2, image_id3, matches23)
298 | db.add_matches(image_id3, image_id4, matches34)
299 |
300 | # Commit the data to the file.
301 |
302 | db.commit()
303 |
304 | # Read and check cameras.
305 |
306 | rows = db.execute("SELECT * FROM cameras")
307 |
308 | camera_id, model, width, height, params, prior = next(rows)
309 | params = blob_to_array(params, np.float64)
310 | assert camera_id == camera_id1
311 | assert model == model1 and width == width1 and height == height1
312 | assert np.allclose(params, params1)
313 |
314 | camera_id, model, width, height, params, prior = next(rows)
315 | params = blob_to_array(params, np.float64)
316 | assert camera_id == camera_id2
317 | assert model == model2 and width == width2 and height == height2
318 | assert np.allclose(params, params2)
319 |
320 | # Read and check keypoints.
321 |
322 | keypoints = dict(
323 | (image_id, blob_to_array(data, np.float32, (-1, 2)))
324 | for image_id, data in db.execute(
325 | "SELECT image_id, data FROM keypoints"))
326 |
327 | assert np.allclose(keypoints[image_id1], keypoints1)
328 | assert np.allclose(keypoints[image_id2], keypoints2)
329 | assert np.allclose(keypoints[image_id3], keypoints3)
330 | assert np.allclose(keypoints[image_id4], keypoints4)
331 |
332 | # Read and check matches.
333 |
334 | pair_ids = [image_ids_to_pair_id(*pair) for pair in
335 | ((image_id1, image_id2),
336 | (image_id2, image_id3),
337 | (image_id3, image_id4))]
338 |
339 | matches = dict(
340 | (pair_id_to_image_ids(pair_id),
341 | blob_to_array(data, np.uint32, (-1, 2)))
342 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
343 | )
344 |
345 | assert np.all(matches[(image_id1, image_id2)] == matches12)
346 | assert np.all(matches[(image_id2, image_id3)] == matches23)
347 | assert np.all(matches[(image_id3, image_id4)] == matches34)
348 |
349 | # Clean up.
350 |
351 | db.close()
352 |
353 | if os.path.exists(args.database_path):
354 | os.remove(args.database_path)
355 |
356 |
357 | if __name__ == "__main__":
358 | example_usage()
359 |
--------------------------------------------------------------------------------
/lib/datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import os
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class TranslationDataset(Dataset):
10 | def __init__(
11 | self,
12 | base_path=None,
13 | features=['brief', 'sift-kornia', 'hardnet', 'sosnet'],
14 | subsampling_ratio=1.0
15 | ):
16 | self.features = features
17 |
18 | self.arrays = {}
19 | for feature in self.features:
20 | npy_path = os.path.join(
21 | base_path, '%s-features.npy' % feature
22 | )
23 | descriptors = np.load(npy_path)
24 | # Deterministic subsampling.
25 | if subsampling_ratio < 1.0:
26 | num_descriptors = int(np.ceil(subsampling_ratio * descriptors.shape[0]))
27 | random = np.random.RandomState(seed=1)
28 | selected_ids = random.choice(descriptors.shape[0], num_descriptors, replace=False)
29 | else:
30 | selected_ids = np.arange(descriptors.shape[0])
31 | self.arrays[feature] = descriptors[selected_ids, :]
32 |
33 | self.len = self.arrays[features[0]].shape[0]
34 |
35 | def __len__(self):
36 | return self.len
37 |
38 | def __getitem__(self, idx):
39 | sample = {}
40 | for feature in self.features:
41 | sample[feature] = torch.from_numpy(
42 | self.arrays[feature][idx, :]
43 | ).float()
44 | return sample
45 |
--------------------------------------------------------------------------------
/lib/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 |
8 | def exhaustive_loss(encoders, decoders, batch, device, alpha=0.1, margin=1.0):
9 | # Translation loss.
10 | embeddings = {}
11 | for source_feature in batch.keys():
12 | source_descriptors = batch[source_feature]
13 | embeddings[source_feature] = encoders[source_feature](source_descriptors)
14 |
15 | all_embeddings = torch.cat(list(embeddings.values()), dim=0)
16 |
17 | t_loss = torch.tensor(0.).float().to(device)
18 | for target_feature in batch.keys():
19 | target_descriptors = batch[target_feature]
20 | output_descriptors = decoders[target_feature](all_embeddings)
21 | if target_feature == 'brief':
22 | current_loss = F.binary_cross_entropy(
23 | output_descriptors,
24 | torch.cat([target_descriptors] * len(batch), dim=0)
25 | )
26 | else:
27 | current_loss = torch.mean(
28 | torch.norm(output_descriptors - torch.cat([target_descriptors] * len(batch), dim=0), dim=1)
29 | )
30 | t_loss += current_loss
31 | t_loss /= len(batch)
32 |
33 | # Triplet loss in embedding space.
34 | e_loss = torch.tensor(0.).float().to(device)
35 | if alpha > 0:
36 | for source_feature in embeddings.keys():
37 | for target_feature in embeddings.keys():
38 | # TODO: Implement symmetric negative mining.
39 |                 sqdist_matrix = 2 - 2 * embeddings[source_feature] @ embeddings[target_feature].T  # For L2-normalized embeddings, ||a - b||^2 = 2 - 2 * (a . b).
40 | pos_dist = torch.norm(torch.diag(sqdist_matrix).unsqueeze(-1), dim=-1)
41 | sqdist_matrix = sqdist_matrix + torch.diag(torch.full((sqdist_matrix.shape[0],), np.inf)).to(device)
42 | # neg_sqdist = torch.min(torch.min(sqdist_matrix, dim=-1)[0], torch.min(sqdist_matrix, dim=0)[0])
43 | neg_sqdist = torch.min(sqdist_matrix, dim=-1)[0]
44 | neg_dist = torch.norm(neg_sqdist.unsqueeze(-1), dim=-1)
45 | e_loss = e_loss + torch.mean(
46 | F.relu(margin + pos_dist - neg_dist)
47 | )
48 | e_loss /= (len(batch) ** 2)
49 |
50 | # Final loss.
51 | if alpha > 0:
52 | loss = t_loss + alpha * e_loss
53 | else:
54 | loss = t_loss
55 |
56 | return loss, (t_loss.detach(), e_loss.detach())
57 |
58 |
59 | def autoencoder_loss(encoders, decoders, batch, device, alpha=0.1, margin=1.0):
60 | # AE loss.
61 | embeddings = {}
62 | t_loss = torch.tensor(0.).float().to(device)
63 | for source_feature in batch.keys():
64 | source_descriptors = batch[source_feature]
65 | current_embeddings = encoders[source_feature](source_descriptors)
66 | embeddings[source_feature] = current_embeddings
67 | output_descriptors = decoders[source_feature](current_embeddings)
68 | if source_feature == 'brief':
69 | current_loss = F.binary_cross_entropy(
70 | output_descriptors, source_descriptors
71 | )
72 | else:
73 | current_loss = torch.mean(
74 | torch.norm(output_descriptors - source_descriptors, dim=1)
75 | )
76 | t_loss += current_loss
77 | t_loss /= len(batch)
78 |
79 | # Triplet loss in embedding space.
80 | e_loss = torch.tensor(0.).float().to(device)
81 | if alpha > 0:
82 | for source_feature in embeddings.keys():
83 | for target_feature in embeddings.keys():
84 | # TODO: Implement symmetric negative mining.
85 | sqdist_matrix = 2 - 2 * embeddings[source_feature] @ embeddings[target_feature].T
86 | pos_dist = torch.norm(torch.diag(sqdist_matrix).unsqueeze(-1), dim=-1)
87 | sqdist_matrix = sqdist_matrix + torch.diag(torch.full((sqdist_matrix.shape[0],), np.inf)).to(device)
88 | # neg_sqdist = torch.min(torch.min(sqdist_matrix, dim=-1)[0], torch.min(sqdist_matrix, dim=0)[0])
89 | neg_sqdist = torch.min(sqdist_matrix, dim=-1)[0]
90 | neg_dist = torch.norm(neg_sqdist.unsqueeze(-1), dim=-1)
91 | e_loss = e_loss + torch.mean(
92 | F.relu(margin + pos_dist - neg_dist)
93 | )
94 | e_loss /= (len(batch) ** 2)
95 |
96 | # Final loss.
97 | if alpha > 0:
98 | loss = t_loss + alpha * e_loss
99 | else:
100 | loss = t_loss
101 |
102 | return loss, (t_loss.detach(), e_loss.detach())
103 |
104 |
105 | def subset_loss(encoders, decoders, batch, device, alpha=0.1, margin=1.0):
106 |     # Randomly pair each source feature with a target feature.
107 |     # Make sure that all encoders and all decoders are used.
108 | source_features = np.array(list(batch.keys()))
109 | target_features = np.array(source_features)
110 | np.random.shuffle(target_features)
111 |
112 | # Translation loss.
113 | embeddings = {}
114 | t_loss = torch.tensor(0.).float().to(device)
115 | for source_feature, target_feature in zip(source_features, target_features):
116 | source_descriptors = batch[source_feature]
117 | target_descriptors = batch[target_feature]
118 | current_embeddings = encoders[source_feature](source_descriptors)
119 | embeddings[source_feature] = current_embeddings
120 | output_descriptors = decoders[target_feature](current_embeddings)
121 | if target_feature == 'brief':
122 | current_loss = F.binary_cross_entropy(
123 | output_descriptors, target_descriptors
124 | )
125 | else:
126 | current_loss = torch.mean(
127 | torch.norm(output_descriptors - target_descriptors, dim=1)
128 | )
129 | t_loss += current_loss
130 | t_loss /= len(batch)
131 |
132 | # Triplet loss in embedding space.
133 | e_loss = torch.tensor(0.).float().to(device)
134 | if alpha > 0:
135 | for source_feature, target_feature in zip(source_features, target_features):
136 | # TODO: Implement symmetric negative mining.
137 | sqdist_matrix = 2 - 2 * embeddings[source_feature] @ embeddings[target_feature].T
138 | pos_dist = torch.norm(torch.diag(sqdist_matrix).unsqueeze(-1), dim=-1)
139 | sqdist_matrix = sqdist_matrix + torch.diag(torch.full((sqdist_matrix.shape[0],), np.inf)).to(device)
140 | # neg_sqdist = torch.min(torch.min(sqdist_matrix, dim=-1)[0], torch.min(sqdist_matrix, dim=0)[0])
141 | neg_sqdist = torch.min(sqdist_matrix, dim=-1)[0]
142 | neg_dist = torch.norm(neg_sqdist.unsqueeze(-1), dim=-1)
143 | e_loss = e_loss + torch.mean(
144 | F.relu(margin + pos_dist - neg_dist)
145 | )
146 | e_loss /= len(batch)
147 |
148 | # Final loss.
149 | if alpha > 0:
150 | loss = t_loss + alpha * e_loss
151 | else:
152 | loss = t_loss
153 |
154 | return loss, (t_loss.detach(), e_loss.detach())
155 |
--------------------------------------------------------------------------------
/lib/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MLP(nn.Module):
7 | def __init__(
8 | self,
9 | num_channels,
10 | use_cuda=True,
11 | use_bn=False,
12 | last_activation=None,
13 | l2_norm=False
14 | ):
15 | super(MLP, self).__init__()
16 |
17 | # Retrieve from num_channels.
18 | source_num_channels = num_channels[0]
19 | target_num_channels = num_channels[-1]
20 | hidden_num_channels = num_channels[1 : -1]
21 |
22 | # Prepare layers.
23 | layers = []
24 | previous_num_channels = source_num_channels
25 | for current_num_channels in hidden_num_channels:
26 | layers.append(
27 | nn.Linear(previous_num_channels, current_num_channels)
28 | )
29 | if use_bn:
30 | layers.append(
31 | nn.BatchNorm1d(current_num_channels)
32 | )
33 | layers.append(
34 | nn.ReLU(inplace=True)
35 | )
36 | previous_num_channels = current_num_channels
37 | layers.append(nn.Linear(previous_num_channels, target_num_channels))
38 |
39 | # Make a sequential model.
40 | self.network = nn.Sequential(*layers)
41 |
42 | # Last activation.
43 | if last_activation is None:
44 | self.last_activation = None
45 | elif last_activation.lower() == 'relu':
46 | self.last_activation = nn.ReLU()
47 | elif last_activation.lower() == 'sigmoid':
48 | self.last_activation = nn.Sigmoid()
49 | else:
50 | raise NotImplementedError('Unknown activation "%s".' % last_activation)
51 |
52 | # L2-normalize output.
53 | self.l2_norm = l2_norm
54 |
55 | # Move to GPU if needed.
56 | if use_cuda:
57 | self.cuda()
58 |
59 | def forward(self, batch):
60 | output = self.network(batch)
61 | if self.last_activation is not None:
62 | output = self.last_activation(output)
63 | if self.l2_norm:
64 | output = F.normalize(output, dim=1)
65 | return output
66 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | from lib.networks import MLP
2 |
3 |
4 | def create_network_for_feature(feature, config, use_cuda):
5 | use_bn = config['use_bn']
6 | emb_dim = config['emb_dim']
7 | emb_l2_norm = config['emb_l2_norm']
8 | emb_last_activation = config['emb_last_activation']
9 | descriptor_dim = config[feature]['descriptor_dim']
10 | hidden_dims = config[feature]['hidden_dims']
11 | l2_norm = config[feature]['l2_norm']
12 | last_activation = config[feature]['last_activation']
13 |
14 | # Define encoder.
15 | encoder = MLP(
16 | num_channels=[descriptor_dim] + hidden_dims + [emb_dim], use_cuda=use_cuda,
17 | use_bn=use_bn, last_activation=emb_last_activation, l2_norm=emb_l2_norm
18 | )
19 |
20 | # Define decoder.
21 | decoder = MLP(
22 | num_channels=[emb_dim] + hidden_dims[:: -1] + [descriptor_dim], use_cuda=use_cuda,
23 | use_bn=use_bn, last_activation=last_activation, l2_norm=l2_norm
24 | )
25 |
26 | return encoder, decoder
--------------------------------------------------------------------------------
/local-feature-evaluation/align_and_compare.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 |
5 | import os
6 |
7 | import subprocess
8 |
9 |
10 | def qvec_to_rotmat(qvec):
11 | w, x, y, z = qvec
12 | R = np.array([
13 | [
14 | 1 - 2 * y * y - 2 * z * z,
15 | 2 * x * y - 2 * z * w,
16 | 2 * x * z + 2 * y * w
17 | ],
18 | [
19 | 2 * x * y + 2 * z * w,
20 | 1 - 2 * x * x - 2 * z * z,
21 | 2 * y * z - 2 * x * w
22 | ],
23 | [
24 | 2 * x * z - 2 * y * w,
25 | 2 * y * z + 2 * x * w,
26 | 1 - 2 * x * x - 2 * y * y
27 | ]
28 | ])
29 | return R
30 |
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser()
34 |
35 | parser.add_argument(
36 | '--model_path', required=True,
37 | help='path to the model'
38 | )
39 | parser.add_argument(
40 | '--reference_model_path', required=True,
41 | help='path to the reference model'
42 | )
43 | parser.add_argument(
44 | '--colmap_path', required=True,
45 | help='path to the COLMAP executable folder'
46 | )
47 |
48 | args = parser.parse_args()
49 |
50 | # Create the output path.
51 | aligned_model_path = os.path.join(args.model_path, 'aligned')
52 |
53 | if not os.path.exists(aligned_model_path):
54 | os.mkdir(aligned_model_path)
55 |
56 | # Read and cache the reference model.
57 | with open(os.path.join(args.reference_model_path, 'images.txt'), 'r') as f:
58 | lines = f.readlines()
59 | reference_poses = {}
60 | for line in lines[4 :: 2]:
61 | line = line.strip('\n').split(' ')
62 | image_name = line[-1]
63 | qvec = np.array(list(map(float, line[1 : 5])))
64 | t = np.array(list(map(float, line[5 : 8])))
65 | R = qvec_to_rotmat(qvec)
66 | reference_poses[image_name] = [R, t]
67 |
68 | # Run the model aligner.
69 | subprocess.call([
70 | os.path.join(args.colmap_path, 'colmap'), 'model_aligner',
71 | '--input_path', args.model_path,
72 | '--output_path', aligned_model_path,
73 | '--ref_images_path', os.path.join(args.reference_model_path, 'geo.txt'),
74 | '--robust_alignment_max_error', str(0.25)
75 | ])
76 |
77 | subprocess.call([
78 | os.path.join(args.colmap_path, 'colmap'), 'model_converter',
79 | '--input_path', aligned_model_path,
80 | '--output_path', aligned_model_path,
81 | '--output_type', 'TXT'
82 | ])
83 |
84 | # Parse the aligned model.
85 | with open(os.path.join(aligned_model_path, 'images.txt'), 'r') as f:
86 | lines = f.readlines()
87 | ori_errors = []
88 | center_errors = []
89 | image_ids = []
90 | for line in lines[4 :: 2]:
91 | line = line.strip('\n').split(' ')
92 | image_id = int(line[0])
93 | image_name = line[-1]
94 | qvec = np.array(list(map(float, line[1 : 5])))
95 | t = np.array(list(map(float, line[5 : 8])))
96 | R = qvec_to_rotmat(qvec)
97 | # Compute the error.
98 | annotated_R, annotated_t = reference_poses[image_name]
99 |
100 | rotation_difference = R @ annotated_R.transpose()
101 | ori_error = np.rad2deg(np.arccos(np.clip((np.trace(rotation_difference) - 1) / 2, -1, 1)))
102 |
103 | annotated_C = (-1) * annotated_R.transpose() @ annotated_t
104 | C = (-1) * R.transpose() @ t
105 | center_error = np.linalg.norm(C - annotated_C)
106 | if center_error > 0.10:
107 | image_ids.append((image_id, image_name))
108 |
109 | ori_errors.append(ori_error)
110 | center_errors.append(center_error)
111 | ori_errors = np.array(ori_errors)
112 | center_errors = np.array(center_errors)
113 |
114 | print('Registered %d out of %d images.' % (len(ori_errors), len(reference_poses)))
115 | print('0.25m, 2 deg:', np.sum(np.logical_and(ori_errors <= 2, center_errors <= 0.25)) / len(reference_poses))
116 | print('0.50m, 5 deg:', np.sum(np.logical_and(ori_errors <= 5, center_errors <= 0.50)) / len(reference_poses))
117 | print('inf :', len(ori_errors) / len(reference_poses))
118 |
--------------------------------------------------------------------------------
/local-feature-evaluation/reconstruction_pipeline_embed.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import json
4 |
5 | import numpy as np
6 |
7 | import os
8 |
9 | import shutil
10 |
11 | import sqlite3
12 |
13 | import sys
14 |
15 | import types
16 |
17 | import torch
18 |
19 | from utils import translate_descriptors, build_hybrid_database, match_features, geometric_verification, reconstruct, blob_to_array, compute_extra_stats
20 |
21 | sys.path.append(os.getcwd())
22 | from lib.utils import create_network_for_feature
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 |
28 | parser.add_argument(
29 | '--dataset_path', type=str, required=True,
30 | help='path to the dataset'
31 | )
32 |
33 | parser.add_argument(
34 | '--colmap_path', type=str, required=True,
35 | help='path to the COLMAP executable folder'
36 | )
37 |
38 | parser.add_argument(
39 | '--features', nargs='+', type=str, required=True,
40 | help='list of descriptors to consider'
41 | )
42 |
43 | parser.add_argument(
44 | '--exp_name', type=str, required=True,
45 | help='name of the experiment'
46 | )
47 |
48 | parser.add_argument(
49 | '--checkpoint', type=str, default='checkpoints-pretrained/model.pth',
50 | help='path to the checkpoint'
51 | )
52 |
53 | parser.add_argument(
54 | '--batch_size', type=int, default=4096,
55 | help='batch size'
56 | )
57 |
58 | args = parser.parse_args()
59 | return args
60 |
61 |
62 | def translate_database(image_features, database_path, encoders, batch_size, device):
63 | shutil.copy(database_path, database_path + '.aux')
64 |
65 | connection = sqlite3.connect(database_path + '.aux')
66 | cursor = connection.cursor()
67 |
68 | output_connection = sqlite3.connect(database_path)
69 | output_cursor = output_connection.cursor()
70 | output_cursor.execute(
71 | 'DELETE FROM descriptors'
72 | )
73 | output_connection.commit()
74 |
75 | for image_id, feature in image_features.items():
76 | image_id = int(image_id)
77 | result = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id,))).fetchall()
78 | data, rows, cols = result[0]
79 | try:
80 | descriptors = blob_to_array(data, np.float32, (rows, cols))
81 | except ValueError:
82 | descriptors = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
83 |
84 | descriptors = translate_descriptors(descriptors, feature, None, encoders, {}, batch_size, device)
85 |
86 | output_cursor.execute(
87 | 'INSERT INTO descriptors VALUES (?, ?, ?, ?)',
88 | (image_id,) + descriptors.shape + (descriptors.tobytes(),)
89 | )
90 | output_connection.commit()
91 |
92 | cursor.close()
93 | connection.close()
94 | os.remove(database_path + '.aux')
95 |
96 | output_cursor.close()
97 | output_connection.close()
98 |
99 |
100 | def main():
101 | # Set CUDA.
102 | use_cuda = torch.cuda.is_available()
103 | device = torch.device("cuda:0" if use_cuda else "cpu")
104 | torch.set_grad_enabled(False)
105 |
106 | # Load config json.
107 | with open('checkpoints-pretrained/config.json', 'r') as f:
108 | config = json.load(f)
109 |
110 | # Parse arguments.
111 | args = parse_args()
112 | assert(len(args.features) > 1)
113 |
114 | paths = types.SimpleNamespace()
115 | paths.database_path = os.path.join(
116 | args.dataset_path, '%s.db' % args.exp_name
117 | )
118 | paths.image_path = os.path.join(
119 | args.dataset_path, 'images'
120 | )
121 | paths.match_list_path = os.path.join(
122 | args.dataset_path, 'match-list-exh.txt'
123 | )
124 | paths.sparse_path = os.path.join(
125 | args.dataset_path, 'sparse-%s' % args.exp_name
126 | )
127 | paths.output_path = os.path.join(
128 | args.dataset_path, 'stats-%s.txt' % args.exp_name
129 | )
130 |
131 | # Copy reference database.
132 | if os.path.exists(paths.database_path):
133 | raise FileExistsError('Database file already exists.')
134 | shutil.copy(
135 | os.path.join(args.dataset_path, 'database.db'),
136 | paths.database_path
137 | )
138 |
139 | # Create networks.
140 | checkpoint = torch.load(args.checkpoint)
141 | encoders = {}
142 | for feature in args.features:
143 | encoder, _ = create_network_for_feature(feature, config, use_cuda)
144 | state_dict = list(filter(lambda x: x[0] == feature, checkpoint['encoders']))[0]
145 | encoder.load_state_dict(state_dict[1])
146 | encoder.eval()
147 | encoders[feature] = encoder
148 |
149 | # Build and translate database.
150 | image_features = build_hybrid_database(
151 | args.features,
152 | args.dataset_path,
153 | paths.database_path
154 | )
155 | np.save(os.path.join(args.dataset_path, 'features.npy'), image_features)
156 | translate_database(
157 | image_features,
158 | paths.database_path,
159 | encoders, args.batch_size, device
160 | )
161 |
162 | # Matching + GV + reconstruction.
163 | match_features(
164 | args.colmap_path,
165 | paths.database_path, paths.image_path, paths.match_list_path
166 | )
167 | torch.cuda.empty_cache()
168 | matching_stats = geometric_verification(
169 | args.colmap_path,
170 | paths.database_path, paths.match_list_path
171 | )
172 | largest_model_path, reconstruction_stats = reconstruct(
173 | args.colmap_path,
174 | paths.database_path, paths.image_path, paths.sparse_path
175 | )
176 | extra_stats = compute_extra_stats(image_features, largest_model_path)
177 |
178 | with open(paths.output_path, 'w') as f:
179 | f.write(json.dumps(matching_stats))
180 | f.write('\n')
181 | f.write(json.dumps(reconstruction_stats))
182 | f.write('\n')
183 | f.write(json.dumps(extra_stats))
184 | f.write('\n')
185 |
186 |
187 | if __name__ == '__main__':
188 | main()
189 |
--------------------------------------------------------------------------------
/local-feature-evaluation/reconstruction_pipeline_progressive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import json
4 |
5 | import numpy as np
6 |
7 | import os
8 |
9 | import shutil
10 |
11 | import sys
12 |
13 | import types
14 |
15 | import torch
16 |
17 | from utils import build_hybrid_database, match_features_hybrid, geometric_verification, reconstruct, compute_extra_stats
18 |
19 | sys.path.append(os.getcwd())
20 | from lib.utils import create_network_for_feature
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser()
25 |
26 | parser.add_argument(
27 | '--dataset_path', type=str, required=True,
28 | help='path to the dataset'
29 | )
30 |
31 | parser.add_argument(
32 | '--colmap_path', type=str, required=True,
33 | help='path to the COLMAP executable folder'
34 | )
35 |
36 | parser.add_argument(
37 | '--features', nargs='+', type=str, required=True,
38 | help='list of descriptors to consider'
39 | )
40 |
41 | parser.add_argument(
42 | '--exp_name', type=str, required=True,
43 | help='name of the experiment'
44 | )
45 |
46 | parser.add_argument(
47 | '--checkpoint', type=str, default='checkpoints-pretrained/model.pth',
48 | help='path to the checkpoint'
49 | )
50 |
51 | parser.add_argument(
52 | '--batch_size', type=int, default=4096,
53 | help='batch size'
54 | )
55 |
56 | args = parser.parse_args()
57 | return args
58 |
59 |
60 | def main():
61 | # Set CUDA.
62 | use_cuda = torch.cuda.is_available()
63 | device = torch.device("cuda:0" if use_cuda else "cpu")
64 | torch.set_grad_enabled(False)
65 |
66 | # Load config json.
67 | with open('checkpoints-pretrained/config.json', 'r') as f:
68 | config = json.load(f)
69 |
70 | # Parse arguments.
71 | args = parse_args()
72 |
73 | paths = types.SimpleNamespace()
74 | paths.database_path = os.path.join(
75 | args.dataset_path, '%s.db' % args.exp_name
76 | )
77 | paths.image_path = os.path.join(
78 | args.dataset_path, 'images'
79 | )
80 | paths.match_list_path = os.path.join(
81 | args.dataset_path, 'match-list-exh.txt'
82 | )
83 | paths.sparse_path = os.path.join(
84 | args.dataset_path, 'sparse-%s' % args.exp_name
85 | )
86 | paths.output_path = os.path.join(
87 | args.dataset_path, 'stats-%s.txt' % args.exp_name
88 | )
89 |
90 | # Copy reference database.
91 | if os.path.exists(paths.database_path):
92 | raise FileExistsError('Database file already exists.')
93 | shutil.copy(
94 | os.path.join(args.dataset_path, 'database.db'),
95 | paths.database_path
96 | )
97 |
98 | # Create networks.
99 | encoders = {}
100 | decoders = {}
101 | if len(args.features) > 1:
102 | checkpoint = torch.load(args.checkpoint)
103 |
104 | for feature in args.features:
105 | encoder, decoder = create_network_for_feature(feature, config, use_cuda)
106 |
107 | state_dict = list(filter(lambda x: x[0] == feature, checkpoint['encoders']))[0]
108 | encoder.load_state_dict(state_dict[1])
109 | encoder.eval()
110 | encoders[feature] = encoder
111 |
112 | state_dict = list(filter(lambda x: x[0] == feature, checkpoint['decoders']))[0]
113 | decoder.load_state_dict(state_dict[1])
114 | decoder.eval()
115 | decoders[feature] = decoder
116 | else:
117 | encoders[args.features[0]] = None
118 | decoders[args.features[0]] = None
119 |
120 | # Build and translate database.
121 | image_features = build_hybrid_database(
122 | args.features,
123 | args.dataset_path,
124 | paths.database_path
125 | )
126 |
127 | # Matching + GV + reconstruction.
128 | match_features_hybrid(
129 | args.features,
130 | image_features,
131 | args.colmap_path,
132 | paths.database_path, paths.image_path, paths.match_list_path,
133 | encoders, decoders, args.batch_size, device
134 | )
135 | torch.cuda.empty_cache()
136 | matching_stats = geometric_verification(
137 | args.colmap_path,
138 | paths.database_path, paths.match_list_path
139 | )
140 | largest_model_path, reconstruction_stats = reconstruct(
141 | args.colmap_path,
142 | paths.database_path, paths.image_path, paths.sparse_path
143 | )
144 | extra_stats = compute_extra_stats(image_features, largest_model_path)
145 |
146 | with open(paths.output_path, 'w') as f:
147 | f.write(json.dumps(matching_stats))
148 | f.write('\n')
149 | f.write(json.dumps(reconstruction_stats))
150 | f.write('\n')
151 | f.write(json.dumps(extra_stats))
152 | f.write('\n')
153 |
154 |
155 | if __name__ == '__main__':
156 | main()
157 |
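For reference, this pipeline writes its stats file as three JSON lines (matching stats, reconstruction stats, and per-feature registration counts). A minimal sketch for reading it back, with a hypothetical dataset path and experiment name:

```python
import json

# The stats file holds one JSON object per line: matching stats,
# reconstruction stats, and per-feature registration counts.
with open('data/eval/LFE-release/Madrid_Metropolis/stats-myexp.txt', 'r') as f:
    matching_stats, reconstruction_stats, extra_stats = [
        json.loads(line) for line in f if line.strip()
    ]

print(reconstruction_stats['num_reg_images'], extra_stats)
```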
--------------------------------------------------------------------------------
/local-feature-evaluation/reconstruction_pipeline_subset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import json
4 |
5 | import os
6 |
7 | import shutil
8 |
9 | import types
10 |
11 | import torch
12 |
13 | from utils import build_hybrid_database, match_features_subset, geometric_verification, reconstruct, compute_extra_stats
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument(
20 | '--dataset_path', type=str, required=True,
21 | help='path to the dataset'
22 | )
23 |
24 | parser.add_argument(
25 | '--colmap_path', type=str, required=True,
26 | help='path to the COLMAP executable folder'
27 | )
28 |
29 | parser.add_argument(
30 | '--features', nargs='+', type=str, required=True,
31 | help='list of descriptors to consider'
32 | )
33 |
34 | parser.add_argument(
35 | '--feature', type=str, required=True,
36 |         help='descriptor to use for mapping'
37 | )
38 |
39 | parser.add_argument(
40 | '--exp_name', type=str, required=True,
41 | help='name of the experiment'
42 | )
43 |
44 | args = parser.parse_args()
45 | return args
46 |
47 |
48 | def main():
49 | # Set CUDA.
50 | use_cuda = torch.cuda.is_available()
51 | device = torch.device("cuda:0" if use_cuda else "cpu")
52 | torch.set_grad_enabled(False)
53 |
54 | # Parse arguments.
55 | args = parse_args()
56 |
57 | paths = types.SimpleNamespace()
58 | paths.database_path = os.path.join(
59 | args.dataset_path, '%s.db' % args.exp_name
60 | )
61 | paths.image_path = os.path.join(
62 | args.dataset_path, 'images'
63 | )
64 | paths.match_list_path = os.path.join(
65 | args.dataset_path, 'match-list-exh.txt'
66 | )
67 | paths.sparse_path = os.path.join(
68 | args.dataset_path, 'sparse-%s' % args.exp_name
69 | )
70 | paths.output_path = os.path.join(
71 | args.dataset_path, 'stats-%s.txt' % args.exp_name
72 | )
73 |
74 | # Copy reference database.
75 | if os.path.exists(paths.database_path):
76 | raise FileExistsError('Database file already exists.')
77 | shutil.copy(
78 | os.path.join(args.dataset_path, 'database.db'),
79 | paths.database_path
80 | )
81 |
82 | # Build and translate database.
83 | image_features = build_hybrid_database(
84 | args.features,
85 | args.dataset_path,
86 | paths.database_path
87 | )
88 |
89 | # Matching + GV + reconstruction.
90 | match_features_subset(
91 | args.feature,
92 | image_features,
93 | args.colmap_path,
94 | paths.database_path, paths.image_path, paths.match_list_path
95 | )
96 | torch.cuda.empty_cache()
97 | matching_stats = geometric_verification(
98 | args.colmap_path,
99 | paths.database_path, paths.match_list_path + '.aux'
100 | )
101 | os.remove(paths.match_list_path + '.aux')
102 | largest_model_path, reconstruction_stats = reconstruct(
103 | args.colmap_path,
104 | paths.database_path, paths.image_path, paths.sparse_path
105 | )
106 | extra_stats = compute_extra_stats(image_features, largest_model_path)
107 |
108 | with open(paths.output_path, 'w') as f:
109 | f.write(json.dumps(matching_stats))
110 | f.write('\n')
111 | f.write(json.dumps(reconstruction_stats))
112 | f.write('\n')
113 | f.write(json.dumps(extra_stats))
114 | f.write('\n')
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
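Note that this pipeline only matches image pairs whose images were both assigned `--feature` by the seeded round-robin in `build_hybrid_database` (see utils.py below). A small sketch of that assignment logic in isolation, with a hypothetical image count:

```python
import numpy as np

# Reproduce the seeded assignment from build_hybrid_database: each feature
# gets an equal share of images, the remainder is drawn at random, and the
# result is shuffled with a fixed seed for reproducibility.
features = sorted(['brief', 'sift-kornia', 'hardnet', 'sosnet'])
num_images = 10
rng = np.random.RandomState(1)
assigned = np.array(
    list(range(len(features))) * (num_images // len(features)) +
    list(rng.randint(0, len(features), num_images % len(features)))
)
rng.shuffle(assigned)
print([features[i] for i in assigned])
```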
--------------------------------------------------------------------------------
/local-feature-evaluation/utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/ahojnnes/local-feature-evaluation/blob/master/scripts/reconstruction_pipeline.py.
2 | # Copyright 2017, Johannes L. Schoenberger.
3 | import numpy as np
4 |
5 | import os
6 |
7 | import subprocess
8 |
9 | import sqlite3
10 |
11 | import torch
12 |
13 | from tqdm import tqdm
14 |
15 |
16 | def mnn_ratio_matcher(descriptors1, descriptors2, ratio=0.9):
17 |     # Mutual NN + symmetric Lowe's ratio test matcher (see the usage sketch at the end of this file).
18 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 |     descriptors1 = torch.from_numpy(np.array(descriptors1)).float().to(device)
20 |     descriptors2 = torch.from_numpy(np.array(descriptors2)).float().to(device)
21 |     # L2 normalize descriptors.
22 |     descriptors1 /= torch.norm(descriptors1, dim=-1).unsqueeze(-1)
23 |     descriptors2 /= torch.norm(descriptors2, dim=-1).unsqueeze(-1)
24 |
25 | # Similarity matrix.
26 | device = descriptors1.device
27 | sim = descriptors1 @ descriptors2.t()
28 |
29 | # Retrieve top 2 nearest neighbors 1->2.
30 | nns_sim, nns = torch.topk(sim, 2, dim=1)
31 | nns_dist = torch.sqrt(2 - 2 * nns_sim)
32 | # Compute Lowe's ratio.
33 | ratios12 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8)
34 | # Save first NN and match similarity.
35 | nn12 = nns[:, 0]
36 | match_sim = nns_sim[:, 0]
37 |
38 |     # Retrieve top 2 nearest neighbors 2->1.
39 | nns_sim, nns = torch.topk(sim.t(), 2, dim=1)
40 | nns_dist = torch.sqrt(2 - 2 * nns_sim)
41 | # Compute Lowe's ratio.
42 | ratios21 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8)
43 | # Save first NN.
44 | nn21 = nns[:, 0]
45 |
46 | # Mutual NN + symmetric ratio test.
47 | ids1 = torch.arange(0, sim.shape[0], device=device)
48 | mask = torch.min(ids1 == nn21[nn12], torch.min(ratios12 <= ratio, ratios21[nn12] <= ratio))
49 |
50 | # Final matches.
51 | matches = torch.stack([ids1[mask], nn12[mask]], dim=-1)
52 | match_sim = match_sim[mask]
53 |
54 | return (
55 | matches.data.cpu().numpy(),
56 | match_sim.data.cpu().numpy()
57 | )
58 |
59 |
60 | def translate_descriptors(descriptors, source_feature, target_feature, encoders, decoders, batch_size, device):
61 | descriptors_torch = torch.from_numpy(descriptors).to(device)
62 | start_idx = 0
63 | translated_descriptors_torch = torch.zeros([0, 128]).to(device)
64 | while start_idx < descriptors_torch.shape[0]:
65 |         aux = encoders[source_feature](descriptors_torch[start_idx : start_idx + batch_size])  # encode into the joint embedding space
66 |         if target_feature is not None:
67 |             aux = decoders[target_feature](aux)  # decode into the target descriptor space
68 | translated_descriptors_torch = torch.cat([translated_descriptors_torch, aux], dim=0)
69 | start_idx += batch_size
70 | translated_descriptors = translated_descriptors_torch.cpu().numpy()
71 | return translated_descriptors
72 |
73 |
74 | def image_ids_to_pair_id(image_id1, image_id2):
75 | if image_id1 > image_id2:
76 | return 2147483647 * image_id2 + image_id1
77 | else:
78 | return 2147483647 * image_id1 + image_id2
79 |
80 |
81 | def array_to_blob(array):
82 | return array.tobytes()
83 |
84 |
85 | def blob_to_array(blob, dtype, shape=(-1,)):
86 | if blob is None:
87 | return np.zeros(shape, dtype=dtype)
88 | return np.frombuffer(blob, dtype=dtype).reshape(*shape)
89 |
90 |
91 | def build_hybrid_database(features, dataset_path, database_path):
92 | sorted_features = sorted(features)
93 | n_features = len(sorted_features)
94 |
95 | # Connect to database.
96 | connection = sqlite3.connect(database_path)
97 | cursor = connection.cursor()
98 |
99 | cursor.execute('SELECT name, image_id FROM images;')
100 | images = {}
101 | image_names = {}
102 | for row in cursor:
103 | images[row[0]] = row[1]
104 | image_names[row[1]] = row[0]
105 |
106 |     # Assign one feature type per image: round-robin over the sorted features,
107 |     # with a random draw for the remainder, shuffled with a fixed seed.
108 |     rng = np.random.RandomState(1)
109 |     assigned_features = np.array(
110 |         list(range(n_features)) * (len(images) // n_features) +
111 |         list(rng.randint(0, n_features, len(images) % n_features))
112 |     )
113 |     rng.shuffle(assigned_features)
114 |
115 | image_ids = np.array(list(images.values()))
116 | for feature_idx, feature in enumerate(sorted_features):
117 | connection_aux = sqlite3.connect(os.path.join(
118 | dataset_path, '%s-features.db' % feature
119 | ))
120 | cursor_aux = connection_aux.cursor()
121 |
122 | image_indices = np.where(assigned_features == feature_idx)[0]
123 | for image_id in image_ids[image_indices]:
124 | image_id = int(image_id)
125 | assert image_names[image_id] == cursor_aux.execute("SELECT name FROM images WHERE image_id=?", (image_id,)).fetchall()[0][0]
126 | data, rows, cols = (cursor_aux.execute("SELECT data, rows, cols FROM keypoints WHERE image_id=?", (image_id,))).fetchall()[0]
127 | cursor.execute(
128 | 'INSERT INTO keypoints(image_id, rows, cols, data) VALUES(?, ?, ?, ?);',
129 | (image_id, rows, cols, data)
130 | )
131 | data, rows, cols = (cursor_aux.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id,))).fetchall()[0]
132 | cursor.execute(
133 | 'INSERT INTO descriptors(image_id, rows, cols, data) VALUES(?, ?, ?, ?);',
134 | (image_id, rows, cols, data)
135 | )
136 |
137 | cursor_aux.close()
138 | connection_aux.close()
139 | connection.commit()
140 |
141 | cursor.close()
142 | connection.close()
143 |
144 | image_features = {}
145 | for image_id, feature in zip(image_ids, np.array(sorted_features)[assigned_features]):
146 | image_features[image_id] = feature
147 |
148 | return image_features
149 |
150 |
151 | def match_features(colmap_path, database_path, image_path, match_list_path):
152 | connection = sqlite3.connect(database_path)
153 | cursor = connection.cursor()
154 |
155 | cursor.execute(
156 | 'SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'inlier_matches\';'
157 | )
158 | try:
159 | inlier_matches_table_exists = bool(next(cursor)[0])
160 | except StopIteration:
161 | inlier_matches_table_exists = False
162 |
163 | cursor.execute('DELETE FROM matches;')
164 | if inlier_matches_table_exists:
165 | cursor.execute('DELETE FROM inlier_matches;')
166 | else:
167 | cursor.execute('DELETE FROM two_view_geometries;')
168 | connection.commit()
169 |
170 | images = {}
171 | cursor.execute('SELECT name, image_id FROM images;')
172 | for row in cursor:
173 | images[row[0]] = row[1]
174 |
175 | with open(match_list_path, 'r') as f:
176 | raw_image_pairs = f.readlines()
177 | image_pairs = list(map(lambda s: s.strip('\n').split(' '), raw_image_pairs))
178 |
179 | image_pair_ids = set()
180 | for image_name1, image_name2 in tqdm(image_pairs):
181 | image_id1, image_id2 = images[image_name1], images[image_name2]
182 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2)
183 | if image_pair_id in image_pair_ids:
184 | continue
185 | image_pair_ids.add(image_pair_id)
186 |
187 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id1,))).fetchall()[0]
188 | try:
189 | descriptors1 = blob_to_array(data, np.float32, (rows, cols))
190 | except ValueError:
191 | descriptors1 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
192 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id2,))).fetchall()[0]
193 | try:
194 | descriptors2 = blob_to_array(data, np.float32, (rows, cols))
195 | except ValueError:
196 | descriptors2 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
197 |
198 | # Match.
199 |         if descriptors1.shape[0] == 0 or descriptors2.shape[0] == 0:
200 |             matches = np.zeros([0, 2], dtype=int)
201 |         else:
202 |             matches, _ = mnn_ratio_matcher(descriptors1, descriptors2)
203 | 
204 |         matches = np.array(matches).astype(np.uint32)
205 |         if matches.shape[0] == 0:
206 |             matches = np.zeros([0, 2], dtype=np.uint32)
207 |         assert matches.shape[1] == 2
208 |         if image_id1 > image_id2:
209 |             matches = matches[:, [1, 0]]
210 |         cursor.execute(
211 |             'INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);',
212 |             (image_pair_id, matches.shape[0], matches.shape[1], matches.tobytes())
213 |         )
214 | connection.commit()
215 |
216 | cursor.close()
217 | connection.close()
218 |
219 |
220 | def match_features_hybrid(features, image_features, colmap_path, database_path, image_path, match_list_path, encoders, decoders, batch_size, device):
221 | connection = sqlite3.connect(database_path)
222 | cursor = connection.cursor()
223 |
224 | cursor.execute(
225 | 'SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'inlier_matches\';'
226 | )
227 | try:
228 | inlier_matches_table_exists = bool(next(cursor)[0])
229 | except StopIteration:
230 | inlier_matches_table_exists = False
231 |
232 | cursor.execute('DELETE FROM matches;')
233 | if inlier_matches_table_exists:
234 | cursor.execute('DELETE FROM inlier_matches;')
235 | else:
236 | cursor.execute('DELETE FROM two_view_geometries;')
237 | connection.commit()
238 |
239 | images = {}
240 | cursor.execute('SELECT name, image_id FROM images;')
241 | for row in cursor:
242 | images[row[0]] = row[1]
243 |
244 | with open(match_list_path, 'r') as f:
245 | raw_image_pairs = f.readlines()
246 | image_pairs = list(map(lambda s: s.strip('\n').split(' '), raw_image_pairs))
247 |
248 | image_pair_ids = set()
249 | for image_name1, image_name2 in tqdm(image_pairs):
250 | image_id1, image_id2 = images[image_name1], images[image_name2]
251 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2)
252 | if image_pair_id in image_pair_ids:
253 | continue
254 | image_pair_ids.add(image_pair_id)
255 |
256 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id1,))).fetchall()[0]
257 | try:
258 | descriptors1 = blob_to_array(data, np.float32, (rows, cols))
259 | except ValueError:
260 | descriptors1 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
261 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id2,))).fetchall()[0]
262 | try:
263 | descriptors2 = blob_to_array(data, np.float32, (rows, cols))
264 | except ValueError:
265 | descriptors2 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
266 |
267 | # Check feature consistency.
268 | feature1, feature2 = image_features[image_id1], image_features[image_id2]
269 |
270 | if feature1 != feature2:
271 | ford1 = features.index(feature1)
272 | ford2 = features.index(feature2)
273 | if ford1 > ford2:
274 | feature = feature1
275 | descriptors2 = translate_descriptors(descriptors2, feature2, feature1, encoders, decoders, batch_size, device)
276 | else:
277 | feature = feature2
278 | descriptors1 = translate_descriptors(descriptors1, feature1, feature2, encoders, decoders, batch_size, device)
279 | else:
280 | feature = feature1
281 |
282 | # Match.
283 |         if descriptors1.shape[0] == 0 or descriptors2.shape[0] == 0:
284 |             matches = np.zeros([0, 2], dtype=int)
285 |         else:
286 |             matches, _ = mnn_ratio_matcher(descriptors1, descriptors2)
287 | 
288 |         matches = np.array(matches).astype(np.uint32)
289 |         if matches.shape[0] == 0:
290 |             matches = np.zeros([0, 2], dtype=np.uint32)
291 |         assert matches.shape[1] == 2
292 |         if image_id1 > image_id2:
293 |             matches = matches[:, [1, 0]]
294 |         cursor.execute(
295 |             'INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);',
296 |             (image_pair_id, matches.shape[0], matches.shape[1], matches.tobytes())
297 |         )
298 | connection.commit()
299 |
300 | cursor.close()
301 | connection.close()
302 |
303 |
304 | def match_features_subset(feature, image_features, colmap_path, database_path, image_path, match_list_path):
305 | connection = sqlite3.connect(database_path)
306 | cursor = connection.cursor()
307 |
308 | cursor.execute(
309 | 'SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'inlier_matches\';'
310 | )
311 | try:
312 | inlier_matches_table_exists = bool(next(cursor)[0])
313 | except StopIteration:
314 | inlier_matches_table_exists = False
315 |
316 | cursor.execute('DELETE FROM matches;')
317 | if inlier_matches_table_exists:
318 | cursor.execute('DELETE FROM inlier_matches;')
319 | else:
320 | cursor.execute('DELETE FROM two_view_geometries;')
321 | connection.commit()
322 |
323 | images = {}
324 | cursor.execute('SELECT name, image_id FROM images;')
325 | for row in cursor:
326 | images[row[0]] = row[1]
327 |
328 | image_pairs = set()
329 | for image_name1, image_id1 in images.items():
330 | for image_name2, image_id2 in images.items():
331 | if image_features[image_id1] != feature or image_features[image_id2] != feature:
332 | continue
333 | if image_name1 == image_name2:
334 | continue
335 | if (image_name2, image_name1) not in image_pairs:
336 | image_pairs.add((image_name1, image_name2))
337 |
338 | f = open(match_list_path + '.aux', 'w')
339 | image_pair_ids = set()
340 | for image_name1, image_name2 in tqdm(image_pairs):
341 | image_id1, image_id2 = images[image_name1], images[image_name2]
342 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2)
343 | if image_pair_id in image_pair_ids:
344 | continue
345 | image_pair_ids.add(image_pair_id)
346 |
347 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id1,))).fetchall()[0]
348 | try:
349 | descriptors1 = blob_to_array(data, np.float32, (rows, cols))
350 | except ValueError:
351 | descriptors1 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
352 | data, rows, cols = (cursor.execute("SELECT data, rows, cols FROM descriptors WHERE image_id=?", (image_id2,))).fetchall()[0]
353 | try:
354 | descriptors2 = blob_to_array(data, np.float32, (rows, cols))
355 | except ValueError:
356 | descriptors2 = blob_to_array(data, bool, (rows, cols)).astype(np.float32)
357 |
358 | # Match.
359 |         if descriptors1.shape[0] == 0 or descriptors2.shape[0] == 0:
360 |             matches = np.zeros([0, 2], dtype=int)
361 |         else:
362 |             matches, _ = mnn_ratio_matcher(descriptors1, descriptors2)
363 | 
364 |         matches = np.array(matches).astype(np.uint32)
365 |         if matches.shape[0] == 0:
366 |             matches = np.zeros([0, 2], dtype=np.uint32)
367 |         assert matches.shape[1] == 2
368 |         if image_id1 > image_id2:
369 |             matches = matches[:, [1, 0]]
370 |         cursor.execute(
371 |             'INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);',
372 |             (image_pair_id, matches.shape[0], matches.shape[1], matches.tobytes())
373 |         )
374 | f.write('%s %s\n' % (image_name1, image_name2))
375 | connection.commit()
376 |     f.close()
377 | cursor.close()
378 | connection.close()
379 |
380 |
381 | def geometric_verification(colmap_path, database_path, match_list_path):
382 | subprocess.call([
383 | os.path.join(colmap_path, 'colmap'), 'matches_importer',
384 | '--database_path', database_path,
385 | '--match_list_path', match_list_path,
386 | '--match_type', 'pairs',
387 | '--SiftMatching.num_threads', str(8),
388 | '--SiftMatching.use_gpu', '0',
389 | '--SiftMatching.min_inlier_ratio', '0.1'
390 | ])
391 |
392 | connection = sqlite3.connect(database_path)
393 | cursor = connection.cursor()
394 |
395 | cursor.execute('SELECT count(*) FROM images;')
396 | num_images = next(cursor)[0]
397 |
398 | cursor.execute('SELECT count(*) FROM two_view_geometries WHERE rows > 0;')
399 | num_inlier_pairs = next(cursor)[0]
400 |
401 | cursor.execute('SELECT sum(rows) FROM two_view_geometries WHERE rows > 0;')
402 | num_inlier_matches = next(cursor)[0]
403 |
404 | cursor.close()
405 | connection.close()
406 |
407 | return dict(
408 | num_images=num_images,
409 | num_inlier_pairs=num_inlier_pairs,
410 | num_inlier_matches=num_inlier_matches
411 | )
412 |
413 | def reconstruct(colmap_path, database_path, image_path, sparse_path, refine_intrinsics=False):
414 | # Run the sparse reconstruction.
415 | if not os.path.exists(sparse_path):
416 | os.mkdir(sparse_path)
417 | if not refine_intrinsics:
418 | extra_mapper_params = [
419 | '--Mapper.ba_refine_focal_length', str(0),
420 | '--Mapper.ba_refine_principal_point', str(0),
421 | '--Mapper.ba_refine_extra_params', str(0)
422 | ]
423 | else:
424 | extra_mapper_params = [
425 | '--Mapper.ba_refine_focal_length', str(1),
426 | '--Mapper.ba_refine_principal_point', str(0),
427 | '--Mapper.ba_refine_extra_params', str(1)
428 | ]
429 | subprocess.call([
430 | os.path.join(colmap_path, 'colmap'), 'mapper',
431 | '--database_path', database_path,
432 | '--image_path', image_path,
433 | '--output_path', sparse_path,
434 | '--Mapper.abs_pose_min_inlier_ratio', str(0.05),
435 | '--Mapper.num_threads', str(16),
436 | ] + extra_mapper_params)
437 |
438 | # Find the largest reconstructed sparse model.
439 | models = os.listdir(sparse_path)
440 | if len(models) == 0:
441 |         # Callers unpack two return values, so fail loudly rather than returning None.
442 |         raise RuntimeError('Could not reconstruct any model.')
443 |
444 | largest_model = None
445 | largest_model_num_images = 0
446 | for model in models:
447 | subprocess.call([
448 | os.path.join(colmap_path, 'colmap'), 'model_converter',
449 | '--input_path', os.path.join(sparse_path, model),
450 | '--output_path', os.path.join(sparse_path, model),
451 | '--output_type', 'TXT'
452 | ])
453 | with open(os.path.join(sparse_path, model, 'cameras.txt'), 'r') as fid:
454 | for line in fid:
455 | if line.startswith('# Number of cameras'):
456 |                     num_images = int(line.split()[-1])  # camera count, used as a proxy for model size
457 | if num_images > largest_model_num_images:
458 | largest_model = model
459 | largest_model_num_images = num_images
460 | break
461 |
462 | assert largest_model_num_images > 0
463 |
464 | largest_model_path = os.path.join(sparse_path, largest_model)
465 |
466 | # Convert largest model to ply.
467 | subprocess.call([
468 | os.path.join(colmap_path, 'colmap'), 'model_converter',
469 | '--input_path', largest_model_path,
470 | '--output_path', os.path.join(sparse_path, 'pointcloud.ply'),
471 | '--output_type', 'PLY'
472 | ])
473 |
474 | # Recover model statistics.
475 | stats = subprocess.check_output([
476 | os.path.join(colmap_path, 'colmap'), 'model_analyzer',
477 | '--path', largest_model_path
478 | ])
479 |
480 | stats = stats.decode().split('\n')
481 | for stat in stats:
482 | if stat.startswith('Registered images'):
483 | num_reg_images = int(stat.split()[-1])
484 | elif stat.startswith('Points'):
485 | num_sparse_points = int(stat.split()[-1])
486 | elif stat.startswith('Observations'):
487 | num_observations = int(stat.split()[-1])
488 | elif stat.startswith('Mean track length'):
489 | mean_track_length = float(stat.split()[-1])
490 | elif stat.startswith('Mean observations per image'):
491 | num_observations_per_image = float(stat.split()[-1])
492 | elif stat.startswith('Mean reprojection error'):
493 | mean_reproj_error = float(stat.split()[-1][:-2])
494 |
495 | return largest_model_path, dict(
496 | num_reg_images=num_reg_images,
497 | num_sparse_points=num_sparse_points,
498 | num_observations=num_observations,
499 | mean_track_length=mean_track_length,
500 | num_observations_per_image=num_observations_per_image,
501 | mean_reproj_error=mean_reproj_error
502 | )
503 |
504 |
505 | def compute_extra_stats(image_features, largest_model_path):
506 | with open(os.path.join(largest_model_path, 'images.txt'), 'r') as f:
507 | raw_images = f.readlines()
508 |     raw_images = raw_images[4:][::2]  # skip the 4-line header; each image spans two lines
509 |
510 | counter = {}
511 | for raw_image in raw_images:
512 | image = raw_image.strip('\n').split(' ')
513 | feature = image_features[int(image[0])]
514 | if feature not in counter:
515 | counter[feature] = 0
516 | counter[feature] += 1
517 |
518 | return counter
519 |
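A minimal usage sketch for `mnn_ratio_matcher` and the COLMAP pair-id convention defined above (descriptor counts are arbitrary; random descriptors will yield few, if any, matches):

```python
import numpy as np

from utils import image_ids_to_pair_id, mnn_ratio_matcher

# Match two random 128-D descriptor sets (mutual NN + symmetric ratio test).
descriptors1 = np.random.randn(512, 128).astype(np.float32)
descriptors2 = np.random.randn(640, 128).astype(np.float32)
matches, similarities = mnn_ratio_matcher(descriptors1, descriptors2, ratio=0.9)
print(matches.shape)  # (num_matches, 2): indices into descriptors1 / descriptors2

# Recover the image ids from a pair id (mirrors COLMAP's database convention).
pair_id = image_ids_to_pair_id(7, 3)
image_id2 = pair_id % 2147483647
image_id1 = (pair_id - image_id2) // 2147483647
assert (image_id1, image_id2) == (3, 7)
```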
--------------------------------------------------------------------------------
/scripts/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | cd checkpoints-pretrained
2 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/checkpoints/model.pth
3 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/checkpoints/model-cvsift.pth
4 |
--------------------------------------------------------------------------------
/scripts/download_evaluation_data.sh:
--------------------------------------------------------------------------------
1 | cd data/eval
2 |
3 | # Aachen Day-Night.
4 |
5 | # Local Feature Evaluation.
6 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/data/LFE-release.tar.gz
7 | tar xvzf LFE-release.tar.gz
8 | rm LFE-release.tar.gz
9 |
--------------------------------------------------------------------------------
/scripts/download_processed_training_data.sh:
--------------------------------------------------------------------------------
1 | cd data/train
2 | wget https://cvg-data.inf.ethz.ch/cross-descriptor-vis-loc-map-ICCV2021/data/training-data-colmap.tar.gz
3 | tar xvzf training-data-colmap.tar.gz
4 | rm training-data-colmap.tar.gz
5 |
--------------------------------------------------------------------------------
/scripts/download_training_data.sh:
--------------------------------------------------------------------------------
1 | cd data/train
2 | for i in 1 2 3 4 5; do
3 |     wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.$i.tar.gz
4 |     tar -xzvf revisitop1m.$i.tar.gz
5 | done
6 | rm -f revisitop1m.*.tar.gz
--------------------------------------------------------------------------------
/scripts/process_LFE_data.sh:
--------------------------------------------------------------------------------
1 | LFE_PATH=data/eval/LFE-release
2 | for dataset in 'Gendarmenmarkt' 'Madrid_Metropolis' 'Tower_of_London'; do
3 | echo $dataset
4 | # Dataset already provides keypoints consistent with the reference database.
5 | # python feature-utils/extract_sift.py --colmap_path $COLMAP_PATH --dataset_path $LFE_PATH/$dataset --image_path $LFE_PATH/$dataset/images
6 | for feature in 'brief' 'sift-kornia' 'hardnet' 'sosnet'; do
7 | echo $feature
8 | python feature-utils/extract_descriptors.py --dataset_path $LFE_PATH/$dataset --image_path $LFE_PATH/$dataset/images --feature $feature
9 | done
10 | done
11 |
--------------------------------------------------------------------------------
/scripts/process_training_data.sh:
--------------------------------------------------------------------------------
1 | python feature-utils/extract_sift.py --colmap_path $COLMAP_PATH --dataset_path data/train/
2 | for feature in 'brief' 'sift-kornia' 'hardnet' 'sosnet'; do
3 | echo $feature
4 | python feature-utils/extract_descriptors.py --dataset_path data/train/ --feature $feature
5 | python feature-utils/convert_database_to_numpy.py --dataset_path data/train/ --feature $feature
6 | done
7 | mkdir -p data/train/colmap
8 | mv data/train/*.db data/train/colmap
9 | mv data/train/*.npy data/train/colmap
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import json
4 |
5 | import numpy as np
6 |
7 | import os
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 |
14 | from torch.utils.data import DataLoader
15 |
16 | from tqdm import tqdm
17 |
18 | from lib.datasets import TranslationDataset as Dataset
19 | from lib.losses import exhaustive_loss
20 | from lib.utils import create_network_for_feature
21 |
22 |
23 | def parse_arguments():
24 | parser = argparse.ArgumentParser(description='Training script')
25 |
26 | parser.add_argument(
27 | '--random_seed', type=int, default=1,
28 | help='random seed for numpy and PyTorch'
29 | )
30 |
31 | parser.add_argument(
32 | '--dataset_path', type=str, required=True,
33 | help='path to the dataset'
34 | )
35 |
36 | parser.add_argument(
37 | '--features', nargs='+', type=str, required=True,
38 | help='list of descriptors to consider'
39 | )
40 |
41 | parser.add_argument(
42 | '--initial_checkpoint', type=str, default=None,
43 | help='path to the initial checkpoint'
44 | )
45 |
46 | parser.add_argument(
47 | '--num_epochs', type=int, default=5,
48 | help='number of training epochs'
49 | )
50 | parser.add_argument(
51 | '--lr', type=float, default=1e-3,
52 | help='learning rate'
53 | )
54 | parser.add_argument(
55 | '--batch_size', type=int, default=1024,
56 | help='batch size'
57 | )
58 | parser.add_argument(
59 | '--num_workers', type=int, default=4,
60 | help='number of workers for data loading'
61 | )
62 |
63 | parser.add_argument(
64 | '--log_interval', type=int, default=1000,
65 | help='loss logging interval'
66 | )
67 |
68 | parser.add_argument(
69 | '--checkpoint_directory', type=str, default='checkpoints',
70 | help='directory for training checkpoints'
71 | )
72 | parser.add_argument(
73 | '--checkpoint_prefix', type=str, default='multi',
74 | help='prefix for training checkpoints'
75 | )
76 |
77 | parser.add_argument(
78 | '--alpha', type=float, default=0.1,
79 | help='consistency loss weight'
80 | )
81 | parser.add_argument(
82 | '--margin', type=float, default=1.0,
83 | help='margin for the negative margin loss'
84 | )
85 |
86 | args = parser.parse_args()
87 |
88 | print(args)
89 |
90 | return args
91 |
92 |
93 | # Updating mean class for loss aggregation.
94 | class UpdatingMean:
95 | def __init__(self):
96 | self.sum = 0
97 | self.n = 0
98 |
99 | def mean(self):
100 | return self.sum / self.n
101 |
102 | def add(self, loss):
103 | self.sum += loss
104 | self.n += 1
105 |
106 |
107 | # Epoch training / validation loop.
108 | def run_epoch(
109 | encoders,
110 | decoders,
111 | loss_function,
112 | optimizer,
113 | dataloader,
114 | device,
115 | log_file, train=True
116 | ):
117 | epoch_loss = UpdatingMean()
118 | epoch_t_loss = UpdatingMean()
119 | epoch_e_loss = UpdatingMean()
120 |
121 | torch.set_grad_enabled(train)
122 |
123 | progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
124 | for batch_idx, batch in progress_bar:
125 | # Move batch to device.
126 | for key in batch.keys():
127 | batch[key] = batch[key].to(device)
128 |
129 | # Reset gradient if needed.
130 | if train:
131 | optimizer.zero_grad()
132 |
133 | # Compute loss.
134 | loss, (t_loss, e_loss) = loss_function(encoders, decoders, batch, device)
135 |
136 | # Add loss to history.
137 | epoch_loss.add(loss.data.cpu().numpy())
138 | epoch_t_loss.add(t_loss)
139 | epoch_e_loss.add(e_loss)
140 |
141 | # Update progress bar.
142 | progress_bar.set_postfix(
143 | loss=('%.4f' % epoch_loss.mean()),
144 | t_loss=('%.4f' % epoch_t_loss.mean()),
145 | e_loss=('%.4f' % epoch_e_loss.mean())
146 | )
147 |
148 | # Update logs.
149 |         if batch_idx % args.log_interval == 0:  # args and epoch_idx are globals set in the __main__ block below
150 | log_file.write('[%s] epoch %02d - batch %04d / %04d - avg_loss: %f, avg_t_loss: %f, avg_e_loss: %f\n' % (
151 | 'train' if train else 'valid',
152 | epoch_idx, batch_idx, len(dataloader),
153 | epoch_loss.mean(), epoch_t_loss.mean(), epoch_e_loss.mean()
154 | ))
155 |
156 | # Backprop.
157 | if train:
158 | loss.backward()
159 | optimizer.step()
160 |
161 | # Update logs.
162 | log_file.write('[%s] epoch %02d - avg_loss: %f, avg_t_loss: %f, avg_e_loss: %f\n' % (
163 | 'train' if train else 'valid',
164 | epoch_idx,
165 | epoch_loss.mean(), epoch_t_loss.mean(), epoch_e_loss.mean()
166 | ))
167 | log_file.flush()
168 |
169 | return epoch_loss.mean()
170 |
171 |
172 | if __name__ == '__main__':
173 | # Set CUDA.
174 | use_cuda = torch.cuda.is_available()
175 | device = torch.device("cuda:0" if use_cuda else "cpu")
176 |
177 | # Load config json.
178 | with open('checkpoints-pretrained/config.json', 'r') as f:
179 | config = json.load(f)
180 |
181 | # Command line arguments.
182 | args = parse_arguments()
183 |
184 | # Fix random seed.
185 | torch.manual_seed(args.random_seed)
186 | if use_cuda:
187 | torch.cuda.manual_seed(args.random_seed)
188 | np.random.seed(args.random_seed)
189 |
190 | # Networks.
191 | encoders = {}
192 | decoders = {}
193 | for feature in args.features:
194 | encoder, decoder = create_network_for_feature(feature, config, use_cuda)
195 |
196 | encoders[feature] = encoder
197 | decoders[feature] = decoder
198 |
199 | # Load initial checkpoint if needed.
200 | if args.initial_checkpoint is not None:
201 | checkpoint = torch.load(args.initial_checkpoint)
202 | for feature, state_dict in checkpoint['encoders']:
203 | encoders[feature].load_state_dict(state_dict)
204 | for feature, state_dict in checkpoint['decoders']:
205 | decoders[feature].load_state_dict(state_dict)
206 |
207 | # Dataset.
208 | training_dataset = Dataset(
209 | base_path=args.dataset_path,
210 | features=args.features
211 | )
212 | training_dataloader = DataLoader(
213 | training_dataset,
214 | batch_size=args.batch_size,
215 | num_workers=args.num_workers,
216 | shuffle=True
217 | )
218 |
219 | # Optimizer and loss.
220 | optimizer = optim.Adam(
221 | filter(
222 | lambda p: p.requires_grad,
223 | [param for _, enc in encoders.items() for param in enc.parameters()] +
224 | [param for _, dec in decoders.items() for param in dec.parameters()]
225 | ),
226 | lr=args.lr
227 | )
228 | loss_function = lambda encoders, decoders, batch, device: exhaustive_loss(
229 | encoders, decoders, batch, device,
230 | alpha=args.alpha, margin=args.margin
231 | )
232 |
233 | # Create the checkpoint directory.
234 | if os.path.isdir(args.checkpoint_directory):
235 | print('[Warning] Checkpoint directory already exists.')
236 | else:
237 | os.mkdir(args.checkpoint_directory)
238 |
239 |     # Open the log file for writing.
240 | if os.path.exists(os.path.join(args.checkpoint_directory, 'log.txt')):
241 | print('[Warning] Log file already exists.')
242 | log_file = open(os.path.join(args.checkpoint_directory, 'log.txt'), 'a+')
243 |
244 | # Training loop.
245 | train_loss_history = []
246 | for epoch_idx in range(1, args.num_epochs + 1):
247 | # Run training epoch.
248 | train_loss_history.append(
249 | run_epoch(
250 | encoders, decoders,
251 | loss_function,
252 | optimizer,
253 | training_dataloader,
254 | device,
255 | log_file
256 | )
257 | )
258 |
259 |         # Save the current checkpoint.
260 | checkpoint_path = os.path.join(
261 | args.checkpoint_directory,
262 | '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx)
263 | )
264 | checkpoint = {
265 | 'args': args,
266 | 'epoch_idx': epoch_idx,
267 | 'encoders': [(feature, enc.state_dict()) for feature, enc in encoders.items()],
268 | 'decoders': [(feature, dec.state_dict()) for feature, dec in decoders.items()],
269 | 'optimizer': optimizer.state_dict(),
270 | 'train_loss_history': train_loss_history
271 | }
272 | torch.save(checkpoint, checkpoint_path)
273 |
274 |     # Close the log file.
275 | log_file.close()
276 |
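Checkpoints saved by the loop above store the per-feature encoder and decoder weights as lists of `(feature, state_dict)` pairs. A minimal sketch, with a hypothetical checkpoint path, for inspecting one offline:

```python
import torch

# Load a training checkpoint on CPU and list its per-feature encoders.
checkpoint = torch.load('checkpoints/multi.05.pth', map_location='cpu')
print('epoch:', checkpoint['epoch_idx'])
print('loss history:', checkpoint['train_loss_history'])
for feature, state_dict in checkpoint['encoders']:
    num_params = sum(p.numel() for p in state_dict.values())
    print('encoder[%s]: %d parameters' % (feature, num_params))
```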
--------------------------------------------------------------------------------