├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   └── animation.gif
├── datasets.py
├── datasets
│   ├── README.md
│   └── datasetup.py
├── environment.yml
├── losses.py
├── main.py
├── models.py
└── utils.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "colmap"]
2 | path = colmap
3 | url = https://github.com/colmap/colmap.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Homography-Based Loss Function for Camera Pose Regression
2 | In this repository, we share our implementation of several camera pose regression
3 | loss functions in a simple end-to-end network similar to
4 | [PoseNet](https://openaccess.thecvf.com/content_iccv_2015/html/Kendall_PoseNet_A_Convolutional_ICCV_2015_paper.html).
5 | We implemented the homography-based loss functions introduced in our paper alongside the PoseNet, Homoscedastic, Geometric
6 | and DSAC loss functions. We provide the code to train and test the network on the Cambridge, 7-Scenes and custom COLMAP
7 | datasets.
8 |
9 | Our paper [Homography-Based Loss Function for Camera Pose Regression](https://arxiv.org/abs/2205.01937) is published in IEEE Robotics and Automation Letters 2022.
10 |
11 | ![Convergence animation](assets/animation.gif)
12 | *Convergence of our proposed Homography loss*
13 | We show the convergence of the other losses
14 | [on our YouTube channel](https://youtube.com/playlist?list=PLe92vnufKoYIIHrW5I268RYdX6aV4gTa6).
15 |
16 | ## Installation
17 |
18 | ### COLMAP dependency
19 | This code relies on COLMAP for loading COLMAP models. To satisfy this dependency, simply run:
20 | ```shell
21 | git submodule update --init
22 | ```
23 |
24 | ### Python environment setup
25 | We share an [Anaconda](https://www.anaconda.com) environment that can be easily installed by running:
26 | ```shell
27 | conda env create -f environment.yml
28 | ```
29 | Anaconda is easy to install and also comes in a lighter distribution named
30 | [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
31 | Once the environment is installed you can activate it by running:
32 | ```shell
33 | conda activate homographyloss
34 | ```
35 |
36 | ### Dataset setup
37 | Have a look at the [datasets](datasets) folder to set up the datasets.
38 |
39 | ## Run relocalization
40 | The script [main.py](main.py) trains the network on a given scene and logs the performance of the model on the
41 | train and test sets. It requires one positional argument: the path to the scene on which to train the model.
42 | For example, for training the model on the ShopFacade scene, simply run:
43 | ```shell
44 | python main.py datasets/ShopFacade
45 | ```
46 | Let's say you have a custom dataset in `datasets/mydataset` with the structure defined in [datasets](datasets):
47 | > - mydataset
48 | > - images
49 | > - frame001.jpg
50 | > - frame002.jpg
51 | > - frame003.jpg
52 | > - ...
53 | > - cameras.bin
54 | > - images.bin
55 | > - points3D.bin
56 | > - list_db.txt
57 | > - list_query.txt
58 |
59 | Then you can run the script on your custom dataset:
60 | ```shell
61 | python main.py datasets/mydataset
62 | ```
63 |
64 | Other available training options can be listed by running `python main.py -h`.
65 |
66 | ## Monitor training and test results
67 | Training and test metrics are saved in a `logs` directory. One can monitor them using tensorboard.
68 | Simply run in a new terminal:
69 | ```shell
70 | tensorboard --logdir logs
71 | ```
72 |
73 | All estimated poses are also saved in a CSV file in `logs/[scene]/[loss]/epochs_poses_log.csv`.
74 | For each epoch, each image and each set, we save the estimated pose in the following format:
75 | - `w_t_chat` is the camera-to-world translation of the image.
76 | - `chat_q_w` is the world-to-camera quaternion representing the rotation of the image.
77 |
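   | As a minimal sketch (the scene and loss names below are placeholders), the log can be loaded with pandas:
   | ```python
   | import pandas as pd
   | 
   | # Columns follow the header written by main.py:
   | # epoch,image_file,type,w_tx_chat,w_ty_chat,w_tz_chat,chat_qw_w,chat_qx_w,chat_qy_w,chat_qz_w
   | poses = pd.read_csv('logs/ShopFacade/local_homography/epochs_poses_log.csv')
   | 
   | # Estimated test poses at the last logged epoch
   | last = poses[(poses['epoch'] == poses['epoch'].max()) & (poses['type'] == 'test')]
   | print(last[['image_file', 'w_tx_chat', 'w_ty_chat', 'w_tz_chat']].head())
   | ```
   | 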
78 | ## Acknowledgements
79 | This work was supported by [Ifremer](https://wwz.ifremer.fr/), [DYNI](https://dyni.pages.lis-lab.fr/) team of [LIS laboratory](https://www.lis-lab.fr/) and [COSMER laboratory](https://cosmer.univ-tln.fr/).
80 |
81 | ## License
82 | This code is released under the LGPLv3 license. Please have a look at the LICENSE file at the repository root.
83 |
84 | ## Citation
85 | If you use this work for your research, please cite:
86 | ```
87 | @article{boittiaux2022homographyloss,
88 | author={Boittiaux, Cl\'ementin and
89 | Marxer, Ricard and
90 | Dune, Claire and
91 | Arnaubec, Aur\'elien and
92 | Hugel, Vincent},
93 | journal={IEEE Robotics and Automation Letters},
94 | title={Homography-Based Loss Function for Camera Pose Regression},
95 | year={2022},
96 | volume={7},
97 | number={3},
98 | pages={6242-6249},
99 | }
100 | ```
101 |
--------------------------------------------------------------------------------
/assets/animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clementinboittiaux/homography-loss-function/919c255483e3e21ae3a38b4b7209772e770232e2/assets/animation.gif
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import tqdm
9 | from PIL import Image
10 | from kornia.geometry.conversions import (
11 | rotation_matrix_to_quaternion,
12 | quaternion_to_rotation_matrix,
13 | QuaternionCoeffOrder
14 | )
15 | from torch.nn.functional import normalize
16 | from torch.utils.data import Dataset
17 | from torchvision import transforms
18 |
19 | from colmap.scripts.python.read_write_model import read_model
20 |
21 | # Image preprocessing pipeline matching PyTorch's pretrained models (ImageNet normalization)
22 | preprocess = transforms.Compose([
23 | transforms.Resize(256),
24 | transforms.ToTensor(),
25 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26 | ])
27 |
28 |
29 | def collate_fn(views):
30 | """
31 | Transforms list of dicts [{key1: value1, key2:value2}, {key1: value3, key2:value4}]
32 | into a dict of lists {key1: [value1, value3], key2: [value2, value4]}.
33 |     Then stacks batch-compatible values into tensor batches.
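   |     For example, collate_fn([{'K': K1, 'image': im1}, {'K': K2, 'image': im2}]) returns
   |     {'K': torch.stack([K1, K2]), 'image': torch.stack([im1, im2])}; the keys 'w_P', 'c_p'
   |     and 'image_file' are kept as plain lists since they cannot be stacked into a single tensor.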
34 | """
35 | batch = {key: [] for key in views[0].keys()}
36 | for view in views:
37 | for key, value in view.items():
38 | batch[key].append(value)
39 | for key, value in batch.items():
40 | if key not in ['w_P', 'c_p', 'image_file']:
41 | batch[key] = torch.stack(value)
42 | return batch
43 |
44 |
45 | class RelocDataset(Dataset):
46 | """
47 | Dataset template class for use with PyTorch DataLoader class.
48 | """
49 |
50 | def __init__(self, dataset):
51 | """
52 | `dataset` must be a list of dicts providing localization data for each image.
53 | Dicts must provide:
54 | {
55 | 'image_file': name of image file
56 | 'image': torch.tensor image with shape (3, height, width)
57 | 'w_t_c': torch.tensor camera-to-world translation with shape (3, 1)
58 | 'c_q_w': torch.tensor world-to-camera quaternion rotation with shape (4,) in format wxyz
59 | 'c_R_w': torch.tensor world-to-camera rotation matrix with shape (3, 3)
60 | (can be computed with quaternion_to_R)
61 | 'K': torch.tensor camera intrinsics matrix with shape (3, 3)
62 | 'w_P': torch.tensor 3D observations of the image in the world frame with shape (*, 3)
63 | 'c_p': reprojections of the 3D observations in the camera view (in pixels) with shape (*, 2)
64 | 'xmin': minimum depth of observations
65 | 'xmax': maximum depth of observations
66 | }
67 | """
68 | self.data = dataset
69 |
70 | def __len__(self):
71 | return len(self.data)
72 |
73 | def __getitem__(self, idx):
74 | return {
75 | 'image_file': self.data[idx]['image_file'],
76 | 'image': self.data[idx]['image'],
77 | 'w_t_c': self.data[idx]['w_t_c'],
78 | 'c_q_w': self.data[idx]['c_q_w'],
79 | 'c_R_w': self.data[idx]['c_R_w'],
80 | 'K': self.data[idx]['K'],
81 | 'w_P': self.data[idx]['w_P'],
82 | 'c_p': self.data[idx]['c_p'],
83 | 'xmin': self.data[idx]['xmin'],
84 | 'xmax': self.data[idx]['xmax']
85 | }
86 |
87 |
88 | class CambridgeDataset:
89 | """
90 | Template class to load every scene of Cambridge dataset.
91 | """
92 |
93 | def __init__(self, path, xmin_percentile, xmax_percentile):
94 | """
95 | `path` is the path to the dataset directory,
96 | e.g. for King's College: "/home/data/KingsCollege".
97 | Creates 6 attributes:
98 | - 2 lists of dicts (train and test) providing localization data for each image.
99 | - 4 parameters (train and test) for minimum and maximum depths of observations.
100 | """
101 | views = []
102 | scene_coordinates = []
103 | with open(os.path.join(path, 'reconstruction.nvm'), mode='r') as file:
104 |
105 | # Skip first two lines
106 | for _ in range(2):
107 | file.readline()
108 |
109 | # `n_views` is the number of images
110 | n_views = int(file.readline())
111 |
112 |             # For each image, the NVM camera line format is:
113 |             # <File name> <focal length> <quaternion WXYZ> <camera center> <radial distortion> 0
114 | for _ in range(n_views):
115 | line = file.readline().split()
116 |
117 | f = float(line[1])
118 | K = torch.tensor([
119 | [f, 0, 1920 / 2],
120 | [0, f, 1080 / 2],
121 | [0, 0, 1]
122 | ], dtype=torch.float32)
123 | views.append({
124 | 'image_file': line[0],
125 | 'K': K,
126 | 'observations_ids': []
127 | })
128 |
129 | # Skip one line
130 | file.readline()
131 |
132 | # `n_points` is the number of scene coordinates
133 | n_points = int(file.readline())
134 |
135 |             # For each scene coordinate, the NVM point line format is:
136 |             # <XYZ> <RGB> <number of measurements> <List of measurements>
137 | for i in range(n_points):
138 |
139 | line = file.readline().split()
140 |
141 | scene_coordinates.append(torch.tensor(list(map(float, line[:3]))))
142 |
143 | # `n_obs` is the number of images where the scene coordinate is observed
144 | n_obs = int(line[6])
145 |
146 |                 # Each measurement is:
147 |                 # <Image index> <Feature index> <xy>
148 | for n in range(n_obs):
149 | views[int(line[7 + n * 4])]['observations_ids'].append(i)
150 |
151 | views = {view.pop('image_file'): view for view in views}
152 | scene_coordinates = torch.stack(scene_coordinates)
153 |
154 | train_df = pd.read_csv(os.path.join(path, 'dataset_train.txt'), sep=' ', skiprows=1)
155 | test_df = pd.read_csv(os.path.join(path, 'dataset_test.txt'), sep=' ', skiprows=1)
156 |
157 | train_data = []
158 | test_data = []
159 | train_global_depths = []
160 | test_global_depths = []
161 |
162 | print('Loading images from dataset. This may take a while...')
163 | for data, df, global_depths in [(train_data, train_df, train_global_depths),
164 | (test_data, test_df, test_global_depths)]:
165 | for line in tqdm.tqdm(df.values):
166 | image_file = line[0]
167 | image = preprocess(Image.open(os.path.join(path, image_file)))
168 | w_t_c = torch.tensor(line[1:4].tolist()).view(3, 1)
169 | c_q_w = normalize(torch.tensor(line[4:8].tolist()), dim=0)
170 | c_R_w = quaternion_to_rotation_matrix(c_q_w, order=QuaternionCoeffOrder.WXYZ)
171 | view = views[os.path.splitext(image_file)[0] + '.jpg']
172 | w_P = scene_coordinates[view['observations_ids']]
173 | c_P = c_R_w @ (w_P.T - w_t_c)
174 | c_p = view['K'] @ c_P
175 | c_p = c_p[:2] / c_p[2]
176 |
177 | args_inliers = torch.where(torch.logical_and(
178 | torch.logical_and(
179 | torch.logical_and(c_P[2] > 0.2, c_P[2] < 1000),
180 | torch.logical_and(c_P[0].abs() < 1000, c_P[1].abs() < 1000)
181 | ),
182 | torch.logical_and(
183 | torch.logical_and(c_p[0] > 0, c_p[0] < 1920),
184 | torch.logical_and(c_p[1] > 0, c_p[1] < 1080)
185 | )
186 | ))[0]
187 |
188 | if args_inliers.shape[0] < 10:
189 | tqdm.tqdm.write(f'Not using image {image_file}: [{args_inliers.shape[0]}/{w_P.shape[0]}] scene '
190 | f'coordinates inliers')
191 | elif w_t_c.abs().max() > 1000:
192 | tqdm.tqdm.write(f'Not using image {image_file}: t is {w_t_c.numpy()}')
193 | else:
194 | if args_inliers.shape[0] != w_P.shape[0]:
195 | tqdm.tqdm.write(f'Eliminating outliers in image {image_file}: '
196 | f'[{args_inliers.shape[0]}/{w_P.shape[0]}] scene coordinates inliers')
197 |
198 | depths = torch.sort(c_P.T[args_inliers][:, 2]).values
199 | global_depths.append(depths)
200 |
201 | data.append({
202 | 'image_file': image_file,
203 | 'image': image,
204 | 'w_t_c': w_t_c,
205 | 'c_q_w': c_q_w,
206 | 'c_R_w': c_R_w,
207 | 'w_P': w_P[args_inliers],
208 | 'c_p': c_p.T[args_inliers],
209 | 'K': view['K'],
210 | 'xmin': depths[int(xmin_percentile * (depths.shape[0] - 1))],
211 | 'xmax': depths[int(xmax_percentile * (depths.shape[0] - 1))]
212 | })
213 |
214 | train_global_depths = torch.sort(torch.hstack(train_global_depths)).values
215 | test_global_depths = torch.sort(torch.hstack(test_global_depths)).values
216 | self.train_global_xmin = train_global_depths[int(xmin_percentile * (train_global_depths.shape[0] - 1))]
217 | self.train_global_xmax = train_global_depths[int(xmax_percentile * (train_global_depths.shape[0] - 1))]
218 | self.test_global_xmin = test_global_depths[int(xmin_percentile * (test_global_depths.shape[0] - 1))]
219 | self.test_global_xmax = test_global_depths[int(xmax_percentile * (test_global_depths.shape[0] - 1))]
220 | self.train_data = train_data
221 | self.test_data = test_data
222 |
223 |
224 | class SevenScenesDataset:
225 | """
226 |     Template class to load every scene of the 7-Scenes dataset.
227 | """
228 |
229 | def __init__(self, path, xmin_percentile, xmax_percentile):
230 |
231 | # Camera intrinsics
232 | K = np.array([
233 | [585, 0, 320],
234 | [0, 585, 240],
235 | [0, 0, 1]
236 | ], dtype=np.float64)
237 | K_inv = np.linalg.inv(K)
238 | K_torch = torch.tensor(K, dtype=torch.float32)
239 |
240 | # Grid of pixels
241 | u = np.arange(640) + 0.5
242 | v = np.arange(480) + 0.5
243 | u, v = np.meshgrid(u, v)
244 |
245 | # Array of all pixel positions in pixels
246 | c_p_px = np.hstack([
247 | u.reshape(-1, 1),
248 | v.reshape(-1, 1),
249 | np.ones((u.size, 1))
250 | ])
251 | c_p_px_torch = torch.tensor(c_p_px[:, :2], dtype=torch.float32)
252 |
253 | # Array of all pixels in the sensor plane
254 | c_p = K_inv @ c_p_px.T
255 |
256 | train_data = []
257 | test_data = []
258 | train_global_depths = []
259 | test_global_depths = []
260 |
261 | for data, file, global_depths in [(train_data, 'TrainSplit.txt', train_global_depths),
262 | (test_data, 'TestSplit.txt', test_global_depths)]:
263 |
264 | with open(os.path.join(path, file), mode='r') as f:
265 | seqs = [int(line[8:]) for line in f]
266 |
267 | for seq in seqs:
268 |
269 | seq_dir = os.path.join(path, f'seq-{seq:02d}')
270 |
271 | print(f'Loading seq-{seq:02d}')
272 |
273 | for frame in tqdm.tqdm(glob.glob(os.path.join(seq_dir, '*.color.png'))):
274 |
275 | frame = os.path.basename(frame).split('.')[0]
276 | image_path = os.path.join(seq_dir, f'{frame}.color.png')
277 | pose_path = os.path.join(seq_dir, f'{frame}.pose.txt')
278 | depth_path = os.path.join(seq_dir, f'{frame}.depth.png')
279 |
280 | image = preprocess(Image.open(image_path))
281 |
282 | # Read camera-to-world pose
283 | w_M_c = np.zeros((4, 4))
284 | with open(pose_path, mode='r') as f:
285 | for i, line in enumerate(f):
286 | w_M_c[i] = list(map(float, line.strip().split('\t')))
287 |
288 | # Read depth map
289 | Z = np.array(Image.open(depth_path)).reshape(-1, 1)
290 |
291 | # Filter outliers
292 | args_inliers = np.logical_and(Z > 0, Z != 65535).squeeze()
293 |
294 | # Unproject pixels
295 | c_P = c_p.T[args_inliers] * (Z[args_inliers] / 1000)
296 |
297 | # Convert 3D points from camera to world frame
298 | w_P = w_M_c[:3, :3] @ c_P.T + w_M_c[:3, 3:4]
299 |
300 | # Building rotation matrix and its quaternion
301 | w_M_c = torch.tensor(w_M_c)
302 | c_R_w = w_M_c[:3, :3].T.contiguous()
303 | c_q_w = rotation_matrix_to_quaternion(c_R_w, order=QuaternionCoeffOrder.WXYZ)
304 |
305 | # Keep the quaternion on the top hypersphere
306 | if c_q_w[0] < 0:
307 | c_q_w *= -1
308 |
309 | # Sort depths
310 | depths = Z[args_inliers].flatten()
311 | global_depths.append(depths)
312 | depths = np.sort(depths)
313 |
314 | data.append({
315 | 'image_file': f'seq-{seq:02d}/{frame}.color.png',
316 | 'image': image,
317 | 'w_t_c': w_M_c[:3, 3:4].float(),
318 | 'c_q_w': c_q_w.float(),
319 | 'c_R_w': c_R_w.float(),
320 | 'w_P': torch.tensor(w_P.T, dtype=torch.float32),
321 | 'c_p': c_p_px_torch[args_inliers],
322 | 'K': K_torch,
323 | 'xmin': torch.tensor(
324 | depths[int(xmin_percentile * (depths.size - 1))] / 1000, dtype=torch.float32
325 | ),
326 | 'xmax': torch.tensor(
327 | depths[int(xmax_percentile * (depths.size - 1))] / 1000, dtype=torch.float32
328 | )
329 | })
330 |
331 | # Sort global depths
332 | print('Sorting depths, this may take a while...')
333 | train_global_depths = np.sort(np.hstack(train_global_depths))
334 | test_global_depths = np.sort(np.hstack(test_global_depths))
335 |
336 | self.train_global_xmin = torch.tensor(
337 | train_global_depths[int(xmin_percentile * (train_global_depths.size - 1))] / 1000,
338 | dtype=torch.float32
339 | )
340 | self.train_global_xmax = torch.tensor(
341 | train_global_depths[int(xmax_percentile * (train_global_depths.size - 1))] / 1000,
342 | dtype=torch.float32
343 | )
344 | self.test_global_xmin = torch.tensor(
345 | test_global_depths[int(xmin_percentile * (test_global_depths.size - 1))] / 1000,
346 | dtype=torch.float32
347 | )
348 | self.test_global_xmax = torch.tensor(
349 | test_global_depths[int(xmax_percentile * (test_global_depths.size - 1))] / 1000,
350 | dtype=torch.float32
351 | )
352 | self.train_data = train_data
353 | self.test_data = test_data
354 |
355 |
356 | class COLMAPDataset:
357 | """
358 | WIP class to load COLMAP scenes. Only RADIAL camera model is supported.
359 | """
360 |
361 | def __init__(self, path, xmin_percentile, xmax_percentile):
362 | """
363 | `path` to a folder containing:
364 | - COLMAP model
365 | - an `images` directory containing all images
366 | - two lists named `list_db.txt` and `list_query.txt` containing
367 | respectively the names of database and query images (one name per line)
368 | """
369 |
370 | print('COLMAPDataset is work in progress, only supports RADIAL camera model!')
371 |
372 | images_path = os.path.join(path, 'images')
373 | list_query = os.path.join(path, 'list_query.txt')
374 | list_db = os.path.join(path, 'list_db.txt')
375 |
376 | cameras, images, points3D = read_model(path)
377 |
378 | image_name_to_id = {image.name: i for i, image in images.items()}
379 |
380 | scene_coordinates = torch.zeros(max(points3D.keys()) + 1, 3, dtype=torch.float64)
381 | for i, point3D in points3D.items():
382 | scene_coordinates[i] = torch.tensor(point3D.xyz)
383 |
384 | train_data = []
385 | test_data = []
386 | train_global_depths = []
387 | test_global_depths = []
388 |
389 | for data, file, global_depths in zip([train_data, test_data],
390 | [list_db, list_query],
391 | [train_global_depths, test_global_depths]):
392 | with open(file, 'r') as f:
393 | image_names = f.read().splitlines()
394 |
395 | for image_name in tqdm.tqdm(image_names):
396 |
397 | image = images[image_name_to_id[image_name]]
398 | camera = cameras[image.camera_id]
399 |
400 | im = cv2.imread(os.path.join(images_path, image_name))
401 |
402 | f, u0, v0, k1, k2 = camera.params
403 | K = np.array([
404 | [f, 0, u0],
405 | [0, f, v0],
406 | [0, 0, 1]
407 | ])
408 | dist_coeffs = np.array([k1, k2, 0, 0])
409 | new_K, roi = cv2.getOptimalNewCameraMatrix(
410 | cameraMatrix=K,
411 | distCoeffs=dist_coeffs,
412 | imageSize=im.shape[:2][::-1],
413 | alpha=0,
414 | centerPrincipalPoint=True
415 | )
416 | new_K = torch.tensor(new_K)
417 | new_K[0, 2] = camera.width / 2
418 | new_K[1, 2] = camera.height / 2
419 |
420 | # Undistort image and center its principal point
421 | im = cv2.undistort(im, K, dist_coeffs, newCameraMatrix=new_K.numpy())
422 | im = preprocess(Image.fromarray(im[:, :, ::-1]))
423 |
424 | c_t_w = torch.tensor(image.tvec).view(3, 1)
425 | c_q_w = torch.tensor(image.qvec)
426 |
427 | # Keep the quaternion on the top hypersphere
428 | if c_q_w[0] < 0:
429 | c_q_w *= -1
430 |
431 | c_R_w = quaternion_to_rotation_matrix(c_q_w, order=QuaternionCoeffOrder.WXYZ)
432 | w_t_c = -c_R_w.T @ c_t_w
433 |
434 | w_P = scene_coordinates[[i for i in image.point3D_ids if i != -1]]
435 | c_P = c_R_w @ (w_P.T - w_t_c)
436 | c_p = new_K @ c_P
437 | c_p = c_p[:2] / c_p[2]
438 |
439 | depths = torch.sort(c_P[2]).values
440 | global_depths.append(depths.float())
441 |
442 | data.append({
443 | 'image_file': image_name,
444 | 'image': im,
445 | 'w_t_c': w_t_c.float(),
446 | 'c_q_w': c_q_w.float(),
447 | 'c_R_w': c_R_w.float(),
448 | 'w_P': w_P.float(),
449 | 'c_p': c_p.T.float(),
450 | 'K': new_K.float(),
451 | 'xmin': depths[int(xmin_percentile * (depths.shape[0] - 1))].float(),
452 | 'xmax': depths[int(xmax_percentile * (depths.shape[0] - 1))].float()
453 | })
454 |
455 | train_global_depths = torch.sort(torch.hstack(train_global_depths)).values
456 | test_global_depths = torch.sort(torch.hstack(test_global_depths)).values
457 | self.train_global_xmin = train_global_depths[int(xmin_percentile * (train_global_depths.shape[0] - 1))]
458 | self.train_global_xmax = train_global_depths[int(xmax_percentile * (train_global_depths.shape[0] - 1))]
459 | self.test_global_xmin = test_global_depths[int(xmin_percentile * (test_global_depths.shape[0] - 1))]
460 | self.test_global_xmax = test_global_depths[int(xmax_percentile * (test_global_depths.shape[0] - 1))]
461 | self.train_data = train_data
462 | self.test_data = test_data
463 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Dataset setup
2 |
3 | Here you can find instructions to set up the datasets for use with this code.
4 |
5 | ## Cambridge and 7-Scenes
6 |
7 | We provide the script [datasetup.py](datasetup.py) for setting up Cambridge and 7-Scenes datasets. The script can be
8 | called with either the name of the dataset to set up, *e.g.*, `7-Scenes`, or the name of a specific scene, *e.g.*,
9 | `KingsCollege`. For example, if you want to set up the whole Cambridge dataset:
10 | ```shell
11 | python datasetup.py Cambridge
12 | ```
13 | Or if you want to set up only the *chess* scene of the 7-Scenes dataset:
14 | ```shell
15 | python datasetup.py chess
16 | ```
17 | All possibilities can be accessed by running:
18 | ```shell
19 | python datasetup.py -h
20 | ```
21 |
22 |
23 | ## Custom dataset
24 |
25 | We also support custom datasets in **COLMAP** model format.
26 | ⚠️ Please note that only **RADIAL** camera models are supported for now.
27 |
28 | The custom dataset folder must contain:
29 | - The COLMAP model: `cameras`, `images` and `points3D` files in `.bin` or `.txt` format.
30 | - A folder named `images` containing all images in the model.
31 | - A file named `list_db.txt` with the name of all the images used for training, one image name per line.
32 | - A file named `list_query.txt` with the name of all the images used for testing, one image name per line.
33 |
34 | The final outline of the folder should look like this:
35 | > - mydataset
36 | > - images
37 | > - frame001.jpg
38 | > - frame002.jpg
39 | > - frame003.jpg
40 | > - ...
41 | > - cameras.bin
42 | > - images.bin
43 | > - points3D.bin
44 | > - list_db.txt
45 | > - list_query.txt
46 |
47 | An example of `list_db.txt` or `list_query.txt`:
48 | ```text
49 | frame001.jpg
50 | frame002.jpg
51 | frame003.jpg
52 | ...
53 | ```
54 |
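   | For reference, here is a minimal sketch of how such a folder is consumed by this code
   | (run from the repository root; `datasets/mydataset` is a placeholder path):
   | ```python
   | import datasets
   | 
   | # Loads the COLMAP model, undistorts the images and splits them into
   | # train (list_db.txt) and test (list_query.txt) data.
   | # The percentile values match the defaults in main.py.
   | dataset = datasets.COLMAPDataset('datasets/mydataset', xmin_percentile=0.025, xmax_percentile=0.975)
   | print(len(dataset.train_data), 'training images,', len(dataset.test_data), 'query images')
   | ```
   | 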
--------------------------------------------------------------------------------
/datasets/datasetup.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import zipfile
5 | from collections import namedtuple
6 |
7 | from torchvision.datasets.utils import download_and_extract_archive
8 |
9 | datasets = {
10 | 'Cambridge': [
11 | 'GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch', 'Street'
12 | ],
13 | '7-Scenes': [
14 | 'chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'
15 | ]
16 | }
17 | Scene = namedtuple('Scene', ['url', 'dataset'])
18 | scenes = {
19 | 'GreatCourt': Scene(
20 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251291/GreatCourt.zip',
21 | dataset='Cambridge'
22 | ),
23 | 'KingsCollege': Scene(
24 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251342/KingsCollege.zip',
25 | dataset='Cambridge'
26 | ),
27 | 'OldHospital': Scene(
28 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251340/OldHospital.zip',
29 | dataset='Cambridge'
30 | ),
31 | 'ShopFacade': Scene(
32 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251336/ShopFacade.zip',
33 | dataset='Cambridge'
34 | ),
35 | 'StMarysChurch': Scene(
36 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251294/StMarysChurch.zip',
37 | dataset='Cambridge'
38 | ),
39 | 'Street': Scene(
40 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251292/Street.zip',
41 | dataset='Cambridge'
42 | ),
43 | 'chess': Scene(
44 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/chess.zip',
45 | dataset='7-Scenes'
46 | ),
47 | 'fire': Scene(
48 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/fire.zip',
49 | dataset='7-Scenes'
50 | ),
51 | 'heads': Scene(
52 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/heads.zip',
53 | dataset='7-Scenes'
54 | ),
55 | 'office': Scene(
56 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/office.zip',
57 | dataset='7-Scenes'
58 | ),
59 | 'pumpkin': Scene(
60 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/pumpkin.zip',
61 | dataset='7-Scenes'
62 | ),
63 | 'redkitchen': Scene(
64 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/redkitchen.zip',
65 | dataset='7-Scenes'
66 | ),
67 | 'stairs': Scene(
68 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/stairs.zip',
69 | dataset='7-Scenes'
70 | )
71 | }
72 |
73 |
74 | def setup_scene(scene_str):
75 | scene = scenes[scene_str]
76 | download_and_extract_archive(scene.url, scene.dataset)
77 | os.remove(os.path.join(scene.dataset, scene.url.split('/')[-1]))
78 | if scene_str in datasets['7-Scenes']:
79 | for file in glob.glob(os.path.join(scene.dataset, scene_str, '*.zip')):
80 | with zipfile.ZipFile(file, 'r') as f:
81 | members = [m for m in f.namelist() if os.path.basename(m) != 'Thumbs.db']
82 | f.extractall(os.path.join(scene.dataset, scene_str), members=members)
83 | os.remove(file)
84 |
85 |
86 | if __name__ == '__main__':
87 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
88 | parser.add_argument(
89 | 'dataset',
90 | choices=list(datasets.keys()) + list(scenes.keys()),
91 | help='name of the dataset or single scene to setup'
92 | )
93 | args = parser.parse_args()
94 |
95 | if args.dataset in datasets:
96 | for scene_name in datasets[args.dataset]:
97 | setup_scene(scene_name)
98 | else:
99 | setup_scene(args.dataset)
100 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: homographyloss
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=1_llvm
10 | - absl-py=1.0.0=pyhd8ed1ab_0
11 | - aiohttp=3.8.1=py39h3811e60_0
12 | - aiosignal=1.2.0=pyhd8ed1ab_0
13 | - alsa-lib=1.2.3=h516909a_0
14 | - async-timeout=4.0.2=pyhd8ed1ab_0
15 | - attrs=21.4.0=pyhd8ed1ab_0
16 | - blas=1.0=mkl
17 | - blinker=1.4=py_1
18 | - bottleneck=1.3.2=py39hdd57654_1
19 | - brotlipy=0.7.0=py39h27cfd23_1003
20 | - bzip2=1.0.8=h7b6447c_0
21 | - c-ares=1.18.1=h7f98852_0
22 | - ca-certificates=2021.10.8=ha878542_0
23 | - cachetools=5.0.0=pyhd8ed1ab_0
24 | - cairo=1.16.0=ha12eb4b_1010
25 | - certifi=2021.10.8=py39hf3d152e_1
26 | - cffi=1.15.0=py39hd667e15_1
27 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
28 | - click=8.0.4=py39hf3d152e_0
29 | - colorama=0.4.4=pyh9f0ad1d_0
30 | - cryptography=36.0.0=py39h9ce1e76_0
31 | - cudatoolkit=11.3.1=h2bc3f7f_2
32 | - dbus=1.13.6=h5008d03_3
33 | - expat=2.4.7=h27087fc_0
34 | - ffmpeg=4.3.2=h37c90e5_3
35 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
36 | - font-ttf-inconsolata=3.000=h77eed37_0
37 | - font-ttf-source-code-pro=2.038=h77eed37_0
38 | - font-ttf-ubuntu=0.83=hab24e00_0
39 | - fontconfig=2.13.96=h8e229c2_1
40 | - fonts-conda-ecosystem=1=0
41 | - fonts-conda-forge=1=0
42 | - freeglut=3.2.2=h9c3ff4c_1
43 | - freetype=2.11.0=h70c0345_0
44 | - frozenlist=1.3.0=py39h3811e60_0
45 | - gettext=0.19.8.1=h73d1719_1008
46 | - giflib=5.2.1=h7b6447c_0
47 | - gmp=6.2.1=h2531618_2
48 | - gnutls=3.6.15=he1e5248_0
49 | - google-auth=2.6.0=pyh6c4a22f_1
50 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
51 | - graphite2=1.3.13=h58526e2_1001
52 | - grpcio=1.44.0=py39hff7568b_0
53 | - gst-plugins-base=1.18.5=hf529b03_3
54 | - gstreamer=1.18.5=h9f60fe5_3
55 | - harfbuzz=3.4.0=hb4a5f5f_0
56 | - hdf5=1.12.1=nompi_h2386368_104
57 | - icu=69.1=h9c3ff4c_0
58 | - idna=3.3=pyhd3eb1b0_0
59 | - importlib-metadata=4.11.2=py39hf3d152e_0
60 | - intel-openmp=2021.4.0=h06a4308_3561
61 | - jasper=2.0.33=ha77e612_0
62 | - jbig=2.1=h7f98852_2003
63 | - jpeg=9d=h7f8727e_0
64 | - keyutils=1.6.1=h166bdaf_0
65 | - kornia=0.6.3=pyhd8ed1ab_0
66 | - krb5=1.19.2=h3790be6_4
67 | - lame=3.100=h7b6447c_0
68 | - lcms2=2.12=h3be6417_0
69 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
70 | - lerc=3.0=h9c3ff4c_0
71 | - libblas=3.9.0=12_linux64_mkl
72 | - libcblas=3.9.0=12_linux64_mkl
73 | - libclang=13.0.1=default_hc23dcda_0
74 | - libcurl=7.82.0=h7bff187_0
75 | - libdeflate=1.8=h7f98852_0
76 | - libedit=3.1.20191231=he28a2e2_2
77 | - libev=4.33=h516909a_1
78 | - libevent=2.1.10=h9b69904_4
79 | - libffi=3.4.2=h7f98852_5
80 | - libgcc-ng=11.2.0=h1d223b6_13
81 | - libgfortran-ng=11.2.0=h69a702a_13
82 | - libgfortran5=11.2.0=h5c6108e_13
83 | - libglib=2.70.2=h174f98d_4
84 | - libglu=9.0.0=he1b5a44_1001
85 | - libiconv=1.16=h516909a_0
86 | - libidn2=2.3.2=h7f8727e_0
87 | - liblapack=3.9.0=12_linux64_mkl
88 | - liblapacke=3.9.0=12_linux64_mkl
89 | - libllvm13=13.0.1=hf817b99_2
90 | - libnghttp2=1.47.0=h727a467_0
91 | - libnsl=2.0.0=h7f98852_0
92 | - libogg=1.3.4=h7f98852_1
93 | - libopencv=4.5.5=py39h7d09d5f_0
94 | - libopus=1.3.1=h7f98852_1
95 | - libpng=1.6.37=hbc83047_0
96 | - libpq=14.2=hd57d9b9_0
97 | - libprotobuf=3.19.4=h780b84a_0
98 | - libssh2=1.10.0=ha56f1ee_2
99 | - libstdcxx-ng=11.2.0=he4da1e4_13
100 | - libtasn1=4.16.0=h27cfd23_0
101 | - libtiff=4.3.0=h6f004c6_2
102 | - libunistring=0.9.10=h27cfd23_0
103 | - libuuid=2.32.1=h7f98852_1000
104 | - libuv=1.40.0=h7b6447c_0
105 | - libvorbis=1.3.7=h9c3ff4c_0
106 | - libwebp=1.2.2=h55f646e_0
107 | - libwebp-base=1.2.2=h7f8727e_0
108 | - libxcb=1.13=h7f98852_1004
109 | - libxkbcommon=1.0.3=he3ba5ed_0
110 | - libxml2=2.9.12=h885dcf4_1
111 | - libzlib=1.2.11=h36c2ea0_1013
112 | - llvm-openmp=13.0.1=he0ac6c6_1
113 | - lz4-c=1.9.3=h295c915_1
114 | - markdown=3.3.6=pyhd8ed1ab_0
115 | - mkl=2021.4.0=h06a4308_640
116 | - mkl-service=2.4.0=py39h7f8727e_0
117 | - mkl_fft=1.3.1=py39hd3c417c_0
118 | - mkl_random=1.2.2=py39h51133e4_0
119 | - multidict=6.0.2=py39h3811e60_0
120 | - mysql-common=8.0.28=ha770c72_0
121 | - mysql-libs=8.0.28=hfa10184_0
122 | - ncurses=6.3=h7f8727e_2
123 | - nettle=3.7.3=hbbd107a_1
124 | - nspr=4.32=h9c3ff4c_1
125 | - nss=3.74=hb5efdd6_0
126 | - numexpr=2.8.1=py39h6abb31d_0
127 | - numpy=1.21.2=py39h20f2e39_0
128 | - numpy-base=1.21.2=py39h79a1101_0
129 | - oauthlib=3.2.0=pyhd8ed1ab_0
130 | - opencv=4.5.5=py39hf3d152e_0
131 | - openh264=2.1.1=h4ff587b_0
132 | - openssl=1.1.1l=h7f98852_0
133 | - packaging=21.3=pyhd8ed1ab_0
134 | - pandas=1.4.1=py39h295c915_0
135 | - pcre=8.45=h9c3ff4c_0
136 | - pillow=9.0.1=py39h22f2fdc_0
137 | - pip=21.2.4=py39h06a4308_0
138 | - pixman=0.40.0=h36c2ea0_0
139 | - protobuf=3.19.4=py39he80948d_0
140 | - pthread-stubs=0.4=h36c2ea0_1001
141 | - py-opencv=4.5.5=py39hef51801_0
142 | - pyasn1=0.4.8=py_0
143 | - pyasn1-modules=0.2.7=py_0
144 | - pycparser=2.21=pyhd3eb1b0_0
145 | - pyjwt=2.3.0=pyhd8ed1ab_1
146 | - pyopenssl=22.0.0=pyhd3eb1b0_0
147 | - pyparsing=3.0.7=pyhd8ed1ab_0
148 | - pysocks=1.7.1=py39h06a4308_0
149 | - python=3.9.10=h85951f9_2_cpython
150 | - python-dateutil=2.8.1=py_0
151 | - python_abi=3.9=2_cp39
152 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
153 | - pytorch-mutex=1.0=cuda
154 | - pytz=2020.1=py_0
155 | - pyu2f=0.1.5=pyhd8ed1ab_0
156 | - qt=5.12.9=ha98a1a1_5
157 | - readline=8.1.2=h7f8727e_1
158 | - requests=2.27.1=pyhd3eb1b0_0
159 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0
160 | - rsa=4.8=pyhd8ed1ab_0
161 | - setuptools=58.0.4=py39h06a4308_0
162 | - six=1.16.0=pyhd3eb1b0_1
163 | - sqlite=3.37.2=hc218d9a_0
164 | - tensorboard=2.8.0=pyhd8ed1ab_1
165 | - tensorboard-data-server=0.6.0=py39h95dcef6_1
166 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
167 | - tk=8.6.11=h1ccaba5_0
168 | - torchaudio=0.11.0=py39_cu113
169 | - torchvision=0.12.0=py39_cu113
170 | - tqdm=4.63.0=pyhd8ed1ab_0
171 | - typing-extensions=3.10.0.2=hd3eb1b0_0
172 | - typing_extensions=3.10.0.2=pyh06a4308_0
173 | - tzdata=2021e=hda174b7_0
174 | - urllib3=1.26.8=pyhd3eb1b0_0
175 | - werkzeug=2.0.3=pyhd8ed1ab_1
176 | - wheel=0.37.1=pyhd3eb1b0_0
177 | - x264=1!161.3030=h7f98852_1
178 | - xorg-fixesproto=5.0=h7f98852_1002
179 | - xorg-inputproto=2.3.2=h7f98852_1002
180 | - xorg-kbproto=1.0.7=h7f98852_1002
181 | - xorg-libice=1.0.10=h7f98852_0
182 | - xorg-libsm=1.2.3=hd9c2040_1000
183 | - xorg-libx11=1.7.2=h7f98852_0
184 | - xorg-libxau=1.0.9=h7f98852_0
185 | - xorg-libxdmcp=1.1.3=h7f98852_0
186 | - xorg-libxext=1.3.4=h7f98852_1
187 | - xorg-libxfixes=5.0.3=h7f98852_1004
188 | - xorg-libxi=1.7.10=h7f98852_0
189 | - xorg-libxrender=0.9.10=h7f98852_1003
190 | - xorg-renderproto=0.11.1=h7f98852_1002
191 | - xorg-xextproto=7.3.0=h7f98852_1002
192 | - xorg-xproto=7.0.31=h7f98852_1007
193 | - xz=5.2.5=h7b6447c_0
194 | - yarl=1.7.2=py39h3811e60_1
195 | - zipp=3.7.0=pyhd8ed1ab_1
196 | - zlib=1.2.11=h36c2ea0_1013
197 | - zstd=1.5.2=ha95c52a_0
198 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import angle_between_quaternions, l1_loss, l2_loss, compute_ABC, project
4 |
5 |
6 | class LocalHomographyLoss(torch.nn.Module):
7 | def __init__(self, device='cpu'):
8 | super().__init__()
9 |
10 | # `c_n` is the normal vector of the plane inducing the homographies in the ground-truth camera frame
11 | self.c_n = torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1)
12 |
13 | # `eye` is the (3, 3) identity matrix
14 | self.eye = torch.eye(3, device=device)
15 |
16 | def __call__(self, batch):
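   |         # `batch` must provide the ground-truth pose ('w_t_c', 'c_R_w'), the estimated
   |         # pose ('w_t_chat', 'chat_R_w') and the per-image depth bounds ('xmin', 'xmax')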
17 | A, B, C = compute_ABC(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], batch['chat_R_w'], self.c_n, self.eye)
18 |
19 | xmin = batch['xmin'].view(-1, 1, 1)
20 | xmax = batch['xmax'].view(-1, 1, 1)
21 | B_weight = torch.log(xmax / xmin) / (xmax - xmin)
22 | C_weight = xmin * xmax
23 |
24 | error = A + B * B_weight + C / C_weight
25 | error = error.diagonal(dim1=1, dim2=2).sum(dim=1).mean()
26 | return error
27 |
28 |
29 | class GlobalHomographyLoss(torch.nn.Module):
30 | def __init__(self, xmin, xmax, device='cpu'):
31 | """
32 | `xmin` is the minimum distance of observations across all frames.
33 | `xmax` is the maximum distance of observations across all frames.
34 | """
35 | super().__init__()
36 |
37 | # `xmin` is the minimum distance of observations in all frames
38 | xmin = torch.tensor(xmin, dtype=torch.float32, device=device)
39 |
40 | # `xmax` is the maximum distance of observations in all frames
41 | xmax = torch.tensor(xmax, dtype=torch.float32, device=device)
42 |
43 |         # `B_weight` and `C_weight` are the weights of matrices B and C, computed from `xmin` and `xmax`
44 |         self.B_weight = torch.log(xmax / xmin) / (xmax - xmin)
45 | self.C_weight = xmin * xmax
46 |
47 | # `c_n` is the normal vector of the plane inducing the homographies in the ground-truth camera frame
48 | self.c_n = torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1)
49 |
50 | # `eye` is the (3, 3) identity matrix
51 | self.eye = torch.eye(3, device=device)
52 |
53 | def __call__(self, batch):
54 | A, B, C = compute_ABC(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], batch['chat_R_w'], self.c_n, self.eye)
55 |
56 | error = A + B * self.B_weight + C / self.C_weight
57 | error = error.diagonal(dim1=1, dim2=2).sum(dim=1).mean()
58 | return error
59 |
60 |
61 | class PoseNetLoss(torch.nn.Module):
62 | def __init__(self, beta):
63 | super().__init__()
64 | self.beta = beta
65 |
66 | def __call__(self, batch):
67 | t_error = l2_loss(batch['w_t_chat'], batch['w_t_c'])
68 | q_error = l2_loss(batch['chat_q_w'], batch['c_q_w'])
69 | error = t_error + self.beta * q_error
70 | return error
71 |
72 |
73 | class HomoscedasticLoss(torch.nn.Module):
74 | def __init__(self, s_hat_t, s_hat_q, device='cpu'):
75 | super().__init__()
76 | self.s_hat_t = torch.nn.Parameter(torch.tensor(s_hat_t, dtype=torch.float32, device=device))
77 | self.s_hat_q = torch.nn.Parameter(torch.tensor(s_hat_q, dtype=torch.float32, device=device))
78 |
79 | def __call__(self, batch):
80 | LtI = l1_loss(batch['w_t_chat'], batch['w_t_c'])
81 | LqI = l1_loss(batch['normalized_chat_q_w'], batch['c_q_w'])
82 | error = LtI * torch.exp(-self.s_hat_t) + self.s_hat_t + LqI * torch.exp(-self.s_hat_q) + self.s_hat_q
83 | return error
84 |
85 |
86 | class GeometricLoss(torch.nn.Module):
87 | def __init__(self):
88 | super().__init__()
89 |
90 | def __call__(self, batch):
91 | error = 0
92 | for w_t_c, c_R_w, w_t_chat, chat_R_w, w_P in zip(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'],
93 | batch['chat_R_w'], batch['w_P']):
94 | c_p = project(w_t_c, c_R_w, w_P)
95 | chat_p = project(w_t_chat, chat_R_w, w_P)
96 | error += l1_loss(chat_p.T, c_p.T, reduce='none').clip(0, 100).mean()
97 | error = error / batch['w_t_c'].shape[0]
98 | return error
99 |
100 |
101 | class DSACLoss(torch.nn.Module):
102 | def __init__(self):
103 | super().__init__()
104 |
105 | def __call__(self, batch):
106 | t_error = 100 * l2_loss(batch['w_t_chat'], batch['w_t_c'], reduce='none')
107 | q_error = angle_between_quaternions(batch['normalized_chat_q_w'], batch['c_q_w']).rad2deg()
108 | error = torch.max(
109 | t_error.view(-1),
110 | q_error
111 | ).mean()
112 | return error
113 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import tqdm
8 | from torch.utils.data import DataLoader
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | import datasets
12 | import losses
13 | import models
14 | from utils import batch_to_device, batch_errors, batch_compute_utils, log_poses, log_errors
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 | parser.add_argument(
19 | 'path', metavar='DATA_PATH',
20 | help='path to the dataset directory, e.g. "/home/data/KingsCollege"'
21 | )
22 | parser.add_argument(
23 | '--loss', help='loss function for training',
24 | choices=['local_homography', 'global_homography', 'posenet', 'homoscedastic', 'geometric', 'dsac'],
25 | default='local_homography'
26 | )
27 | parser.add_argument('--epochs', help='number of epochs for training', type=int, default=5000)
28 | parser.add_argument('--batch_size', help='training batch size', type=int, default=64)
29 | parser.add_argument('--xmin_percentile', help='xmin depth percentile', type=float, default=0.025)
30 | parser.add_argument('--xmax_percentile', help='xmax depth percentile', type=float, default=0.975)
31 | parser.add_argument(
32 | '--weights', metavar='WEIGHTS_PATH',
33 | help='path to weights with which the model will be initialized'
34 | )
35 | parser.add_argument(
36 | '--device', default='cpu',
37 | help='set the device to train the model, `cuda` for GPU'
38 | )
39 | args = parser.parse_args()
40 |
41 |     # Set seed for reproducibility
42 | seed = 1
43 | random.seed(seed)
44 | np.random.seed(seed)
45 | torch.manual_seed(seed)
46 |
47 | # Load model
48 | model = models.load_model(args.weights)
49 | model.train()
50 | model.to(args.device)
51 |
52 | # Load dataset
53 | dataset_name = os.path.basename(os.path.normpath(args.path))
54 | if dataset_name in ['GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch', 'Street']:
55 | dataset = datasets.CambridgeDataset(args.path, args.xmin_percentile, args.xmax_percentile)
56 | elif dataset_name in ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']:
57 | dataset = datasets.SevenScenesDataset(args.path, args.xmin_percentile, args.xmax_percentile)
58 | else:
59 | dataset = datasets.COLMAPDataset(args.path, args.xmin_percentile, args.xmax_percentile)
60 |
61 | # Wrapper for use with PyTorch's DataLoader
62 | train_dataset = datasets.RelocDataset(dataset.train_data)
63 | test_dataset = datasets.RelocDataset(dataset.test_data)
64 |
65 | # Creating data loaders for train and test data
66 | train_loader = DataLoader(
67 | train_dataset,
68 | batch_size=args.batch_size,
69 | shuffle=True,
70 | pin_memory=True,
71 | collate_fn=datasets.collate_fn,
72 | drop_last=True
73 | )
74 | test_loader = DataLoader(
75 | test_dataset,
76 | batch_size=args.batch_size,
77 | shuffle=False,
78 | pin_memory=True,
79 | collate_fn=datasets.collate_fn
80 | )
81 |
82 | # Adam optimizer default epsilon parameter is 1e-8
83 | eps = 1e-8
84 |
85 | # Instantiate loss
86 | if args.loss == 'local_homography':
87 | criterion = losses.LocalHomographyLoss(device=args.device)
88 | eps = 1e-14 # Adam optimizer epsilon is set to 1e-14 for homography losses
89 | elif args.loss == 'global_homography':
90 | criterion = losses.GlobalHomographyLoss(
91 | xmin=dataset.train_global_xmin,
92 | xmax=dataset.train_global_xmax,
93 | device=args.device
94 | )
95 | eps = 1e-14 # Adam optimizer epsilon is set to 1e-14 for homography losses
96 | elif args.loss == 'posenet':
97 | criterion = losses.PoseNetLoss(beta=500)
98 | elif args.loss == 'homoscedastic':
99 | criterion = losses.HomoscedasticLoss(s_hat_t=0.0, s_hat_q=-3.0, device=args.device)
100 | elif args.loss == 'geometric':
101 | criterion = losses.GeometricLoss()
102 | elif args.loss == 'dsac':
103 | criterion = losses.DSACLoss()
104 | else:
105 | raise Exception(f'Loss {args.loss} not recognized...')
106 |
107 | # Instantiate adam optimizer
108 | optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()), lr=1e-4, eps=eps)
109 |
110 | # Set up tensorboard
111 | writer = SummaryWriter(os.path.join('logs', os.path.basename(os.path.normpath(args.path)), args.loss))
112 |
113 | # Set up folder to save weights
114 | if not os.path.exists(os.path.join(writer.log_dir, 'weights')):
115 | os.makedirs(os.path.join(writer.log_dir, 'weights'))
116 |
117 | # Set up file to save logs
118 | log_file_path = os.path.join(writer.log_dir, 'epochs_poses_log.csv')
119 | with open(log_file_path, mode='w') as log_file:
120 | log_file.write('epoch,image_file,type,w_tx_chat,w_ty_chat,w_tz_chat,chat_qw_w,chat_qx_w,chat_qy_w,chat_qz_w\n')
121 |
122 | print('Start training...')
123 | for epoch in tqdm.tqdm(range(args.epochs)):
124 | epoch_loss = 0
125 | errors = {}
126 |
127 | for batch in train_loader:
128 | optimizer.zero_grad()
129 |
130 | # Move all batch data to proper device
131 | batch = batch_to_device(batch, args.device)
132 |
133 | # Estimate the pose from the image
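    |             # (the 7-dim network output is split into a 3-dim translation and a 4-dim WXYZ quaternion)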
134 | batch['w_t_chat'], batch['chat_q_w'] = model(batch['image']).split([3, 4], dim=1)
135 |
136 | # Computes useful data for our batch
137 | # - Normalized quaternion
138 | # - Rotation matrix from this normalized quaternion
139 | # - Reshapes translation component to fit shape (batch_size, 3, 1)
140 | batch_compute_utils(batch)
141 |
142 | # Compute loss
143 | loss = criterion(batch)
144 |
145 | # Backprop
146 | loss.backward()
147 | optimizer.step()
148 |
149 | # Add current batch loss to epoch loss
150 | epoch_loss += loss.item() / len(train_loader)
151 |
152 | # Compute training batch errors and log poses
153 | with torch.no_grad():
154 | batch_errors(batch, errors)
155 |
156 | with open(log_file_path, mode='a') as log_file:
157 | log_poses(log_file, batch, epoch, 'train')
158 |
159 | # Log epoch loss
160 | writer.add_scalar('train loss', epoch_loss, epoch)
161 |
162 | with torch.no_grad():
163 |
164 | # Log train errors
165 | log_errors(errors, writer, epoch, 'train')
166 |
167 | # Set the model to eval mode for test data
168 | model.eval()
169 | errors = {}
170 |
171 | for batch in test_loader:
172 | # Compute test poses estimations
173 | batch = batch_to_device(batch, args.device)
174 | batch['w_t_chat'], batch['chat_q_w'] = model(batch['image']).split([3, 4], dim=1)
175 | batch_compute_utils(batch)
176 |
177 | # Log test poses
178 | with open(log_file_path, mode='a') as log_file:
179 | log_poses(log_file, batch, epoch, 'test')
180 |
181 | # Compute test errors
182 | batch_errors(batch, errors)
183 |
184 | # Log test errors
185 | log_errors(errors, writer, epoch, 'test')
186 |
187 | # Log loss parameters, if there are any
188 | for p_name, p in criterion.named_parameters():
189 | writer.add_scalar(p_name, p, epoch)
190 |
191 | writer.flush()
192 | model.train()
193 |
194 |         # Save model, optimizer and loss weights every 500 epochs and at the last epoch:
195 | if epoch % 500 == 0 or epoch == args.epochs - 1:
196 | torch.save({
197 | 'model_state_dict': model.state_dict(),
198 | 'optimizer_state_dict': optimizer.state_dict(),
199 | 'criterion_state_dict': criterion.state_dict()
200 | }, os.path.join(writer.log_dir, 'weights', f'epoch_{epoch}.pth'))
201 |
202 | writer.close()
203 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def load_model(weights_path=None):
5 | """
6 |     Loads MobileNetV2 pre-trained on ImageNet from the PyTorch hub.
7 | Modifies last layers to fit our pose regression problem.
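  |     The returned model outputs a 7-dim vector per image: 3 translation components
  |     followed by a 4-component WXYZ quaternion (split accordingly in main.py).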
8 | """
9 | # Base model is MobileNetV2 from PyTorch's hub
10 | model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
11 |
12 | # We modify the classifier of MobileNetV2 with a custom regressor
13 | in_features = list(model.classifier.children())[-1].in_features
14 | model.classifier = torch.nn.Sequential(
15 | torch.nn.ReLU(),
16 | torch.nn.Linear(
17 | in_features=in_features,
18 | out_features=2048,
19 | bias=True
20 | ),
21 | torch.nn.ReLU(),
22 | torch.nn.Linear(
23 | in_features=2048,
24 | out_features=7,
25 | bias=True
26 | )
27 | )
28 | if weights_path is not None:
29 | model.load_state_dict(torch.load(weights_path)['model_state_dict'])
30 | return model
31 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.functional import normalize
3 | from kornia.geometry.conversions import quaternion_to_rotation_matrix, QuaternionCoeffOrder
4 |
5 |
6 | def angle_between_quaternions(q, r):
7 | """
8 |     Works on batches of quaternions only.
9 |     `q` and `r` must be batches of unit quaternions with shape (n, 4).
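  |     Returns the rotation angle between each pair of quaternions, in radians.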
10 | """
11 | return 2 * torch.sum(q * r, dim=1).abs().clip(0, 1).arccos()
12 |
13 |
14 | def l1_loss(input, target, reduce='mean'):
15 | """
16 | Computes batch L1 loss with `reduce` reduction.
17 | `input` and `target` must have shape (batch_size, *).
18 |     The L1 norm is computed for each element of the batch.
19 | """
20 | loss = torch.abs(target - input).sum(dim=1)
21 | if reduce == 'none':
22 | return loss
23 | elif reduce == 'mean':
24 | return loss.mean()
25 | else:
26 | raise Exception(f'Reduction method {reduce} not known')
27 |
28 |
29 | def l2_loss(input, target, reduce='mean'):
30 | """
31 | Computes batch L2 loss with `reduce` reduction.
32 | `input` and `target` must have shape (batch_size, *).
33 |     The L2 norm is computed for each element of the batch.
34 | """
35 | loss = torch.square(target - input).sum(dim=1).sqrt()
36 | if reduce == 'none':
37 | return loss
38 | elif reduce == 'mean':
39 | return loss.mean()
40 | else:
41 | raise Exception(f'Reduction method {reduce} not known')
42 |
43 |
44 | def compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye):
45 | """
46 |     Computes the A, B and C matrices given estimated and ground-truth poses
47 |     and the plane normal vector `c_n`.
48 | `w_t_c` and `w_t_chat` must have shape (batch_size, 3, 1).
49 | `c_R_w` and `chat_R_w` must have shape (batch_size, 3, 3).
50 |     `c_n` must have shape (3, 1).
51 | `eye` is the (3, 3) identity matrix on the proper device.
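   |     A, B and C are combined by the homography losses in losses.py as
   |     `A + B * B_weight + C / C_weight`, with weights derived from the depth bounds.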
52 | """
53 | chat_t_c = chat_R_w @ (w_t_c - w_t_chat)
54 | chat_R_c = chat_R_w @ c_R_w.transpose(1, 2)
55 |
56 | A = eye - chat_R_c
57 | C = c_n @ chat_t_c.transpose(1, 2)
58 | B = C @ A
59 | A = A @ A.transpose(1, 2)
60 | B = B + B.transpose(1, 2)
61 | C = C @ C.transpose(1, 2)
62 |
63 | return A, B, C
64 |
65 |
66 | def batch_to_device(batch, device):
67 | """
68 |     If `device` is not 'cpu', moves all tensors in the batch (including those inside lists) to that device.
69 | """
70 | if device != 'cpu':
71 | for key, value in batch.items():
72 | if type(value) is torch.Tensor:
73 | batch[key] = value.to(device)
74 | elif type(value[0]) is torch.Tensor:
75 | for index_value, value_value in enumerate(value):
76 | value[index_value] = value_value.to(device)
77 | return batch
78 |
79 |
80 | def project(w_t_c, c_R_w, w_P, K=None):
81 | """
82 | Projects 3D points P expressed in frame w onto frame c camera view.
83 |     `w_t_c` is the (3, 1) camera-to-world translation (position of frame c expressed in frame w).
84 |     `c_R_w` is the (3, 3) rotation matrix from frame w to frame c.
85 |     `w_P` are the (n, 3) 3D points P expressed in the w frame.
86 |     `K` is the intrinsics matrix of camera c; if omitted, points are projected onto the normalized image plane.
87 | """
88 | c_p = c_R_w @ (w_P.T - w_t_c)
89 | if K is not None:
90 | c_p = K @ c_p
91 | c_p = c_p[:2] / c_p[2]
92 | return c_p
93 |
94 |
95 | def batch_errors(batch, errors):
96 | """
97 | Computes translation, rotation and reprojection errors for the batch.
98 | """
99 | b_errors = {
100 | 't_errors': [l2_loss(batch['w_t_chat'], batch['w_t_c'], reduce='none').squeeze()],
101 | 'q_errors': [angle_between_quaternions(batch['normalized_chat_q_w'], batch['c_q_w'])],
102 | 'reprojection_error_sum': 0,
103 | 'reprojection_distance_sum': 0,
104 | 'l1_reprojection_error_sum': 0,
105 | 'n_points': 0
106 | }
107 |
108 | for w_t_chat, chat_R_w, w_P, c_p, K in zip(batch['w_t_chat'], batch['chat_R_w'], batch['w_P'],
109 | batch['c_p'], batch['K']):
110 | chat_p = project(w_t_chat, chat_R_w, w_P, K=K)
111 | diff = chat_p.T - c_p
112 | reprojection_errors = torch.square(diff).sum(dim=1).clip(0, 1000000)
113 | b_errors['reprojection_error_sum'] += reprojection_errors.sum()
114 | b_errors['reprojection_distance_sum'] += reprojection_errors.sqrt().sum()
115 | b_errors['l1_reprojection_error_sum'] += torch.abs(diff).sum(dim=1).clip(0, 1000000).sum()
116 | b_errors['n_points'] += c_p.shape[0]
117 |
118 | if len(errors) == 0:
119 | for key, value in b_errors.items():
120 | errors[key] = value
121 | else:
122 | for key, value in b_errors.items():
123 | errors[key] += value
124 |
125 |
126 | def batch_compute_utils(batch):
127 | """
128 |     Computes useful data for the batch, in place.
129 | - Computes a normalized quaternion, and its corresponding rotation matrix.
130 |     - Reshapes the translation component to shape (batch_size, 3, 1).
131 | """
132 | batch['w_t_chat'] = batch['w_t_chat'].view(-1, 3, 1)
133 | batch['normalized_chat_q_w'] = normalize(batch['chat_q_w'], dim=1)
134 | batch['chat_R_w'] = quaternion_to_rotation_matrix(batch['chat_q_w'], order=QuaternionCoeffOrder.WXYZ)
135 |
136 |
137 | def log_poses(log_file, batch, epoch, data_type):
138 | """
139 | Logs batch estimated poses in log file.
140 | """
141 | log_file.write('\n'.join([
142 | f'{epoch},{image_file},{data_type},{",".join(map(str, w_t_chat.squeeze().tolist()))},'
143 | f'{",".join(map(str, chat_q_w.tolist()))}'
144 | for image_file, w_t_chat, chat_q_w in
145 | zip(batch['image_file'], batch['w_t_chat'], batch['chat_q_w'])]) + '\n'
146 | )
147 |
148 |
149 | def log_errors(errors, writer, epoch, data_type):
150 | """
151 |     Logs epoch pose errors in TensorBoard.
152 | """
153 | t_errors = torch.hstack(errors['t_errors'])
154 | q_errors = torch.hstack(errors['q_errors']).rad2deg()
155 |
156 | writer.add_scalar(f'{data_type} distance median', t_errors.median(), epoch)
157 | writer.add_scalar(f'{data_type} angle median', q_errors.median(), epoch)
158 | writer.add_scalar(
159 | f'{data_type} mean reprojection error',
160 | errors['reprojection_error_sum'] / errors['n_points'], epoch
161 | )
162 | writer.add_scalar(
163 | f'{data_type} mean reprojection distance',
164 | errors['reprojection_distance_sum'] / errors['n_points'], epoch
165 | )
166 | writer.add_scalar(
167 | f'{data_type} mean l1 reprojection error',
168 | errors['l1_reprojection_error_sum'] / errors['n_points'], epoch
169 | )
170 |
171 | for meter_threshold, deg_threshold in zip(
172 | [0.05, 0.15, 0.25, 0.5, 0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5],
173 | [2, 5, 10, 15, 1, 2, 3, 5, 2, 5, 10]
174 | ):
175 | score = torch.logical_and(
176 | t_errors <= meter_threshold, q_errors <= deg_threshold
177 | ).sum() / t_errors.shape[0]
178 | writer.add_scalar(
179 | f'{data_type} percentage localized within {meter_threshold}m, {deg_threshold}deg',
180 | score, epoch
181 | )
182 |
--------------------------------------------------------------------------------