├── .gitmodules ├── LICENSE ├── README.md ├── assets └── animation.gif ├── datasets.py ├── datasets ├── README.md └── datasetup.py ├── environment.yml ├── losses.py ├── main.py ├── models.py └── utils.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "colmap"] 2 | path = colmap 3 | url = https://github.com/colmap/colmap.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 
64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Homography-Based Loss Function for Camera Pose Regression 2 | In this repository, we share our implementation of several camera pose regression 3 | loss functions in a simple end-to-end network similar to 4 | [PoseNet](https://openaccess.thecvf.com/content_iccv_2015/html/Kendall_PoseNet_A_Convolutional_ICCV_2015_paper.html). 5 | We implemented the homography-based loss functions introduced in our paper alongside PoseNet, Homoscedastic, Geometric 6 | and DSAC loss functions. We provide the code to train and test the network on Cambridge, 7-Scenes and custom COLMAP 7 | datasets. 8 | 9 | Our paper [Homography-Based Loss Function for Camera Pose Regression](https://arxiv.org/abs/2205.01937) is published in IEEE Robotics and Automation Letters 2022. 10 | 11 | ![Convergence of homography loss](assets/animation.gif) 12 | *
Convergence of our proposed Homography loss
*
13 | We show the convergence of the other losses
14 | [on our YouTube channel](https://youtube.com/playlist?list=PLe92vnufKoYIIHrW5I268RYdX6aV4gTa6).
15 | 
16 | ## Installation
17 | 
18 | ### COLMAP dependency
19 | This code relies on COLMAP for loading COLMAP models. To satisfy this dependency, simply run:
20 | ```shell
21 | git submodule update --init
22 | ```
23 | 
24 | ### Python environment setup
25 | We share an [Anaconda](https://www.anaconda.com) environment that can be easily installed by running:
26 | ```shell
27 | conda env create -f environment.yml
28 | ```
29 | Anaconda is easy to install and also comes in a lighter distribution named
30 | [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
31 | Once the environment is installed, you can activate it by running:
32 | ```shell
33 | conda activate homographyloss
34 | ```
35 | 
36 | ### Dataset setup
37 | Have a look at the [datasets](datasets) folder to set up the datasets.
38 | 
39 | ## Run relocalization
40 | The script [main.py](main.py) trains the network on a given scene and logs the performance of the model on both the
41 | train and test sets. It requires one positional argument: the path to the scene on which to train the model.
42 | For example, to train the model on the ShopFacade scene, simply run:
43 | ```shell
44 | python main.py datasets/ShopFacade
45 | ```
46 | Let's say you have a custom dataset in `datasets/mydataset` with the structure defined in [datasets](datasets):
47 | > - mydataset
48 | >   - images
49 | >     - frame001.jpg
50 | >     - frame002.jpg
51 | >     - frame003.jpg
52 | >     - ...
53 | >   - cameras.bin
54 | >   - images.bin
55 | >   - points3D.bin
56 | >   - list_db.txt
57 | >   - list_query.txt
58 | 
59 | Then you can run the script on your custom dataset:
60 | ```shell
61 | python main.py datasets/mydataset
62 | ```
63 | 
64 | Other available training options can be listed by running `python main.py -h`.
65 | 
66 | ## Monitor training and test results
67 | Training and test metrics are saved in a `logs` directory. One can monitor them using TensorBoard.
68 | Simply run in a new terminal:
69 | ```shell
70 | tensorboard --logdir logs
71 | ```
72 | 
73 | All estimated poses are also saved in a CSV file in `logs/[scene]/[loss]/epochs_poses_log.csv`.
74 | For each epoch, each image and each set, we save the estimated pose in the following format (a short example of loading this file is given further below):
75 | - `w_t_chat` is the camera-to-world translation of the image.
76 | - `chat_q_w` is the world-to-camera quaternion representing the rotation of the image.
77 | 
78 | ## Acknowledgements
79 | This work was supported by [Ifremer](https://wwz.ifremer.fr/), the [DYNI](https://dyni.pages.lis-lab.fr/) team of the [LIS laboratory](https://www.lis-lab.fr/) and the [COSMER laboratory](https://cosmer.univ-tln.fr/).
80 | 
81 | ## License
82 | This code is released under the LGPLv3 license. Please have a look at the license file at the repository root.
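## Example: loading the pose log

The snippet below is a minimal sketch (not part of the original codebase) of one way to read `epochs_poses_log.csv` with pandas. The column names follow the header written by [main.py](main.py); the log path is only an example, assuming the ShopFacade scene was trained with the default `local_homography` loss.

```python
import pandas as pd

# Columns follow the header written by main.py:
# epoch,image_file,type,w_tx_chat,w_ty_chat,w_tz_chat,chat_qw_w,chat_qx_w,chat_qy_w,chat_qz_w
log = pd.read_csv('logs/ShopFacade/local_homography/epochs_poses_log.csv')  # example path

# Keep only the test-set estimates of the last logged epoch
last_epoch = log['epoch'].max()
test_poses = log[(log['epoch'] == last_epoch) & (log['type'] == 'test')]

for _, row in test_poses.iterrows():
    w_t_chat = row[['w_tx_chat', 'w_ty_chat', 'w_tz_chat']].to_numpy()               # camera-to-world translation
    chat_q_w = row[['chat_qw_w', 'chat_qx_w', 'chat_qy_w', 'chat_qz_w']].to_numpy()  # world-to-camera quaternion (wxyz)
    print(row['image_file'], w_t_chat, chat_q_w)
```

Note that the quaternion columns store the raw network output in `wxyz` order, so it is not necessarily unit-norm; normalize it before converting it to a rotation matrix.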
83 | 84 | ## Citation 85 | If you use this work for your research, please cite: 86 | ``` 87 | @article{boittiaux2022homographyloss, 88 | author={Boittiaux, Cl\'ementin and 89 | Marxer, Ricard and 90 | Dune, Claire and 91 | Arnaubec, Aur\'elien and 92 | Hugel, Vincent}, 93 | journal={IEEE Robotics and Automation Letters}, 94 | title={Homography-Based Loss Function for Camera Pose Regression}, 95 | year={2022}, 96 | volume={7}, 97 | number={3}, 98 | pages={6242-6249}, 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /assets/animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementinboittiaux/homography-loss-function/919c255483e3e21ae3a38b4b7209772e770232e2/assets/animation.gif -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import tqdm 9 | from PIL import Image 10 | from kornia.geometry.conversions import ( 11 | rotation_matrix_to_quaternion, 12 | quaternion_to_rotation_matrix, 13 | QuaternionCoeffOrder 14 | ) 15 | from torch.nn.functional import normalize 16 | from torch.utils.data import Dataset 17 | from torchvision import transforms 18 | 19 | from colmap.scripts.python.read_write_model import read_model 20 | 21 | # Image preprocessing pipeline according to PyTorch implementation 22 | preprocess = transforms.Compose([ 23 | transforms.Resize(256), 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 26 | ]) 27 | 28 | 29 | def collate_fn(views): 30 | """ 31 | Transforms list of dicts [{key1: value1, key2:value2}, {key1: value3, key2:value4}] 32 | into a dict of lists {key1: [value1, value3], key2: [value2, value4]}. 33 | Then stacks batch-compatible values into tensor batchs. 34 | """ 35 | batch = {key: [] for key in views[0].keys()} 36 | for view in views: 37 | for key, value in view.items(): 38 | batch[key].append(value) 39 | for key, value in batch.items(): 40 | if key not in ['w_P', 'c_p', 'image_file']: 41 | batch[key] = torch.stack(value) 42 | return batch 43 | 44 | 45 | class RelocDataset(Dataset): 46 | """ 47 | Dataset template class for use with PyTorch DataLoader class. 48 | """ 49 | 50 | def __init__(self, dataset): 51 | """ 52 | `dataset` must be a list of dicts providing localization data for each image. 
53 | Dicts must provide: 54 | { 55 | 'image_file': name of image file 56 | 'image': torch.tensor image with shape (3, height, width) 57 | 'w_t_c': torch.tensor camera-to-world translation with shape (3, 1) 58 | 'c_q_w': torch.tensor world-to-camera quaternion rotation with shape (4,) in format wxyz 59 | 'c_R_w': torch.tensor world-to-camera rotation matrix with shape (3, 3) 60 | (can be computed with quaternion_to_R) 61 | 'K': torch.tensor camera intrinsics matrix with shape (3, 3) 62 | 'w_P': torch.tensor 3D observations of the image in the world frame with shape (*, 3) 63 | 'c_p': reprojections of the 3D observations in the camera view (in pixels) with shape (*, 2) 64 | 'xmin': minimum depth of observations 65 | 'xmax': maximum depth of observations 66 | } 67 | """ 68 | self.data = dataset 69 | 70 | def __len__(self): 71 | return len(self.data) 72 | 73 | def __getitem__(self, idx): 74 | return { 75 | 'image_file': self.data[idx]['image_file'], 76 | 'image': self.data[idx]['image'], 77 | 'w_t_c': self.data[idx]['w_t_c'], 78 | 'c_q_w': self.data[idx]['c_q_w'], 79 | 'c_R_w': self.data[idx]['c_R_w'], 80 | 'K': self.data[idx]['K'], 81 | 'w_P': self.data[idx]['w_P'], 82 | 'c_p': self.data[idx]['c_p'], 83 | 'xmin': self.data[idx]['xmin'], 84 | 'xmax': self.data[idx]['xmax'] 85 | } 86 | 87 | 88 | class CambridgeDataset: 89 | """ 90 | Template class to load every scene of Cambridge dataset. 91 | """ 92 | 93 | def __init__(self, path, xmin_percentile, xmax_percentile): 94 | """ 95 | `path` is the path to the dataset directory, 96 | e.g. for King's College: "/home/data/KingsCollege". 97 | Creates 6 attributes: 98 | - 2 lists of dicts (train and test) providing localization data for each image. 99 | - 4 parameters (train and test) for minimum and maximum depths of observations. 
100 | """ 101 | views = [] 102 | scene_coordinates = [] 103 | with open(os.path.join(path, 'reconstruction.nvm'), mode='r') as file: 104 | 105 | # Skip first two lines 106 | for _ in range(2): 107 | file.readline() 108 | 109 | # `n_views` is the number of images 110 | n_views = int(file.readline()) 111 | 112 | # For each image, NVM format is: 113 | # 0 114 | for _ in range(n_views): 115 | line = file.readline().split() 116 | 117 | f = float(line[1]) 118 | K = torch.tensor([ 119 | [f, 0, 1920 / 2], 120 | [0, f, 1080 / 2], 121 | [0, 0, 1] 122 | ], dtype=torch.float32) 123 | views.append({ 124 | 'image_file': line[0], 125 | 'K': K, 126 | 'observations_ids': [] 127 | }) 128 | 129 | # Skip one line 130 | file.readline() 131 | 132 | # `n_points` is the number of scene coordinates 133 | n_points = int(file.readline()) 134 | 135 | # For each scene coordinate, SVM format is: 136 | # 137 | for i in range(n_points): 138 | 139 | line = file.readline().split() 140 | 141 | scene_coordinates.append(torch.tensor(list(map(float, line[:3])))) 142 | 143 | # `n_obs` is the number of images where the scene coordinate is observed 144 | n_obs = int(line[6]) 145 | 146 | # Each measurement is 147 | # 148 | for n in range(n_obs): 149 | views[int(line[7 + n * 4])]['observations_ids'].append(i) 150 | 151 | views = {view.pop('image_file'): view for view in views} 152 | scene_coordinates = torch.stack(scene_coordinates) 153 | 154 | train_df = pd.read_csv(os.path.join(path, 'dataset_train.txt'), sep=' ', skiprows=1) 155 | test_df = pd.read_csv(os.path.join(path, 'dataset_test.txt'), sep=' ', skiprows=1) 156 | 157 | train_data = [] 158 | test_data = [] 159 | train_global_depths = [] 160 | test_global_depths = [] 161 | 162 | print('Loading images from dataset. This may take a while...') 163 | for data, df, global_depths in [(train_data, train_df, train_global_depths), 164 | (test_data, test_df, test_global_depths)]: 165 | for line in tqdm.tqdm(df.values): 166 | image_file = line[0] 167 | image = preprocess(Image.open(os.path.join(path, image_file))) 168 | w_t_c = torch.tensor(line[1:4].tolist()).view(3, 1) 169 | c_q_w = normalize(torch.tensor(line[4:8].tolist()), dim=0) 170 | c_R_w = quaternion_to_rotation_matrix(c_q_w, order=QuaternionCoeffOrder.WXYZ) 171 | view = views[os.path.splitext(image_file)[0] + '.jpg'] 172 | w_P = scene_coordinates[view['observations_ids']] 173 | c_P = c_R_w @ (w_P.T - w_t_c) 174 | c_p = view['K'] @ c_P 175 | c_p = c_p[:2] / c_p[2] 176 | 177 | args_inliers = torch.where(torch.logical_and( 178 | torch.logical_and( 179 | torch.logical_and(c_P[2] > 0.2, c_P[2] < 1000), 180 | torch.logical_and(c_P[0].abs() < 1000, c_P[1].abs() < 1000) 181 | ), 182 | torch.logical_and( 183 | torch.logical_and(c_p[0] > 0, c_p[0] < 1920), 184 | torch.logical_and(c_p[1] > 0, c_p[1] < 1080) 185 | ) 186 | ))[0] 187 | 188 | if args_inliers.shape[0] < 10: 189 | tqdm.tqdm.write(f'Not using image {image_file}: [{args_inliers.shape[0]}/{w_P.shape[0]}] scene ' 190 | f'coordinates inliers') 191 | elif w_t_c.abs().max() > 1000: 192 | tqdm.tqdm.write(f'Not using image {image_file}: t is {w_t_c.numpy()}') 193 | else: 194 | if args_inliers.shape[0] != w_P.shape[0]: 195 | tqdm.tqdm.write(f'Eliminating outliers in image {image_file}: ' 196 | f'[{args_inliers.shape[0]}/{w_P.shape[0]}] scene coordinates inliers') 197 | 198 | depths = torch.sort(c_P.T[args_inliers][:, 2]).values 199 | global_depths.append(depths) 200 | 201 | data.append({ 202 | 'image_file': image_file, 203 | 'image': image, 204 | 'w_t_c': w_t_c, 205 | 'c_q_w': c_q_w, 
206 | 'c_R_w': c_R_w, 207 | 'w_P': w_P[args_inliers], 208 | 'c_p': c_p.T[args_inliers], 209 | 'K': view['K'], 210 | 'xmin': depths[int(xmin_percentile * (depths.shape[0] - 1))], 211 | 'xmax': depths[int(xmax_percentile * (depths.shape[0] - 1))] 212 | }) 213 | 214 | train_global_depths = torch.sort(torch.hstack(train_global_depths)).values 215 | test_global_depths = torch.sort(torch.hstack(test_global_depths)).values 216 | self.train_global_xmin = train_global_depths[int(xmin_percentile * (train_global_depths.shape[0] - 1))] 217 | self.train_global_xmax = train_global_depths[int(xmax_percentile * (train_global_depths.shape[0] - 1))] 218 | self.test_global_xmin = test_global_depths[int(xmin_percentile * (test_global_depths.shape[0] - 1))] 219 | self.test_global_xmax = test_global_depths[int(xmax_percentile * (test_global_depths.shape[0] - 1))] 220 | self.train_data = train_data 221 | self.test_data = test_data 222 | 223 | 224 | class SevenScenesDataset: 225 | """ 226 | Template class to load every scene from 7-Scenes dataset 227 | """ 228 | 229 | def __init__(self, path, xmin_percentile, xmax_percentile): 230 | 231 | # Camera intrinsics 232 | K = np.array([ 233 | [585, 0, 320], 234 | [0, 585, 240], 235 | [0, 0, 1] 236 | ], dtype=np.float64) 237 | K_inv = np.linalg.inv(K) 238 | K_torch = torch.tensor(K, dtype=torch.float32) 239 | 240 | # Grid of pixels 241 | u = np.arange(640) + 0.5 242 | v = np.arange(480) + 0.5 243 | u, v = np.meshgrid(u, v) 244 | 245 | # Array of all pixel positions in pixels 246 | c_p_px = np.hstack([ 247 | u.reshape(-1, 1), 248 | v.reshape(-1, 1), 249 | np.ones((u.size, 1)) 250 | ]) 251 | c_p_px_torch = torch.tensor(c_p_px[:, :2], dtype=torch.float32) 252 | 253 | # Array of all pixels in the sensor plane 254 | c_p = K_inv @ c_p_px.T 255 | 256 | train_data = [] 257 | test_data = [] 258 | train_global_depths = [] 259 | test_global_depths = [] 260 | 261 | for data, file, global_depths in [(train_data, 'TrainSplit.txt', train_global_depths), 262 | (test_data, 'TestSplit.txt', test_global_depths)]: 263 | 264 | with open(os.path.join(path, file), mode='r') as f: 265 | seqs = [int(line[8:]) for line in f] 266 | 267 | for seq in seqs: 268 | 269 | seq_dir = os.path.join(path, f'seq-{seq:02d}') 270 | 271 | print(f'Loading seq-{seq:02d}') 272 | 273 | for frame in tqdm.tqdm(glob.glob(os.path.join(seq_dir, '*.color.png'))): 274 | 275 | frame = os.path.basename(frame).split('.')[0] 276 | image_path = os.path.join(seq_dir, f'{frame}.color.png') 277 | pose_path = os.path.join(seq_dir, f'{frame}.pose.txt') 278 | depth_path = os.path.join(seq_dir, f'{frame}.depth.png') 279 | 280 | image = preprocess(Image.open(image_path)) 281 | 282 | # Read camera-to-world pose 283 | w_M_c = np.zeros((4, 4)) 284 | with open(pose_path, mode='r') as f: 285 | for i, line in enumerate(f): 286 | w_M_c[i] = list(map(float, line.strip().split('\t'))) 287 | 288 | # Read depth map 289 | Z = np.array(Image.open(depth_path)).reshape(-1, 1) 290 | 291 | # Filter outliers 292 | args_inliers = np.logical_and(Z > 0, Z != 65535).squeeze() 293 | 294 | # Unproject pixels 295 | c_P = c_p.T[args_inliers] * (Z[args_inliers] / 1000) 296 | 297 | # Convert 3D points from camera to world frame 298 | w_P = w_M_c[:3, :3] @ c_P.T + w_M_c[:3, 3:4] 299 | 300 | # Building rotation matrix and its quaternion 301 | w_M_c = torch.tensor(w_M_c) 302 | c_R_w = w_M_c[:3, :3].T.contiguous() 303 | c_q_w = rotation_matrix_to_quaternion(c_R_w, order=QuaternionCoeffOrder.WXYZ) 304 | 305 | # Keep the quaternion on the top hypersphere 306 | if 
c_q_w[0] < 0: 307 | c_q_w *= -1 308 | 309 | # Sort depths 310 | depths = Z[args_inliers].flatten() 311 | global_depths.append(depths) 312 | depths = np.sort(depths) 313 | 314 | data.append({ 315 | 'image_file': f'seq-{seq:02d}/{frame}.color.png', 316 | 'image': image, 317 | 'w_t_c': w_M_c[:3, 3:4].float(), 318 | 'c_q_w': c_q_w.float(), 319 | 'c_R_w': c_R_w.float(), 320 | 'w_P': torch.tensor(w_P.T, dtype=torch.float32), 321 | 'c_p': c_p_px_torch[args_inliers], 322 | 'K': K_torch, 323 | 'xmin': torch.tensor( 324 | depths[int(xmin_percentile * (depths.size - 1))] / 1000, dtype=torch.float32 325 | ), 326 | 'xmax': torch.tensor( 327 | depths[int(xmax_percentile * (depths.size - 1))] / 1000, dtype=torch.float32 328 | ) 329 | }) 330 | 331 | # Sort global depths 332 | print('Sorting depths, this may take a while...') 333 | train_global_depths = np.sort(np.hstack(train_global_depths)) 334 | test_global_depths = np.sort(np.hstack(test_global_depths)) 335 | 336 | self.train_global_xmin = torch.tensor( 337 | train_global_depths[int(xmin_percentile * (train_global_depths.size - 1))] / 1000, 338 | dtype=torch.float32 339 | ) 340 | self.train_global_xmax = torch.tensor( 341 | train_global_depths[int(xmax_percentile * (train_global_depths.size - 1))] / 1000, 342 | dtype=torch.float32 343 | ) 344 | self.test_global_xmin = torch.tensor( 345 | test_global_depths[int(xmin_percentile * (test_global_depths.size - 1))] / 1000, 346 | dtype=torch.float32 347 | ) 348 | self.test_global_xmax = torch.tensor( 349 | test_global_depths[int(xmax_percentile * (test_global_depths.size - 1))] / 1000, 350 | dtype=torch.float32 351 | ) 352 | self.train_data = train_data 353 | self.test_data = test_data 354 | 355 | 356 | class COLMAPDataset: 357 | """ 358 | WIP class to load COLMAP scenes. Only RADIAL camera model is supported. 
359 | """ 360 | 361 | def __init__(self, path, xmin_percentile, xmax_percentile): 362 | """ 363 | `path` to a folder containing: 364 | - COLMAP model 365 | - an `images` directory containing all images 366 | - two lists named `list_db.txt` and `list_query.txt` containing 367 | respectively the names of database and query images (one name per line) 368 | """ 369 | 370 | print('COLMAPDataset is work in progress, only supports RADIAL camera model!') 371 | 372 | images_path = os.path.join(path, 'images') 373 | list_query = os.path.join(path, 'list_query.txt') 374 | list_db = os.path.join(path, 'list_db.txt') 375 | 376 | cameras, images, points3D = read_model(path) 377 | 378 | image_name_to_id = {image.name: i for i, image in images.items()} 379 | 380 | scene_coordinates = torch.zeros(max(points3D.keys()) + 1, 3, dtype=torch.float64) 381 | for i, point3D in points3D.items(): 382 | scene_coordinates[i] = torch.tensor(point3D.xyz) 383 | 384 | train_data = [] 385 | test_data = [] 386 | train_global_depths = [] 387 | test_global_depths = [] 388 | 389 | for data, file, global_depths in zip([train_data, test_data], 390 | [list_db, list_query], 391 | [train_global_depths, test_global_depths]): 392 | with open(file, 'r') as f: 393 | image_names = f.read().splitlines() 394 | 395 | for image_name in tqdm.tqdm(image_names): 396 | 397 | image = images[image_name_to_id[image_name]] 398 | camera = cameras[image.camera_id] 399 | 400 | im = cv2.imread(os.path.join(images_path, image_name)) 401 | 402 | f, u0, v0, k1, k2 = camera.params 403 | K = np.array([ 404 | [f, 0, u0], 405 | [0, f, v0], 406 | [0, 0, 1] 407 | ]) 408 | dist_coeffs = np.array([k1, k2, 0, 0]) 409 | new_K, roi = cv2.getOptimalNewCameraMatrix( 410 | cameraMatrix=K, 411 | distCoeffs=dist_coeffs, 412 | imageSize=im.shape[:2][::-1], 413 | alpha=0, 414 | centerPrincipalPoint=True 415 | ) 416 | new_K = torch.tensor(new_K) 417 | new_K[0, 2] = camera.width / 2 418 | new_K[1, 2] = camera.height / 2 419 | 420 | # Undistort image and center its principal point 421 | im = cv2.undistort(im, K, dist_coeffs, newCameraMatrix=new_K.numpy()) 422 | im = preprocess(Image.fromarray(im[:, :, ::-1])) 423 | 424 | c_t_w = torch.tensor(image.tvec).view(3, 1) 425 | c_q_w = torch.tensor(image.qvec) 426 | 427 | # Keep the quaternion on the top hypersphere 428 | if c_q_w[0] < 0: 429 | c_q_w *= -1 430 | 431 | c_R_w = quaternion_to_rotation_matrix(c_q_w, order=QuaternionCoeffOrder.WXYZ) 432 | w_t_c = -c_R_w.T @ c_t_w 433 | 434 | w_P = scene_coordinates[[i for i in image.point3D_ids if i != -1]] 435 | c_P = c_R_w @ (w_P.T - w_t_c) 436 | c_p = new_K @ c_P 437 | c_p = c_p[:2] / c_p[2] 438 | 439 | depths = torch.sort(c_P[2]).values 440 | global_depths.append(depths.float()) 441 | 442 | data.append({ 443 | 'image_file': image_name, 444 | 'image': im, 445 | 'w_t_c': w_t_c.float(), 446 | 'c_q_w': c_q_w.float(), 447 | 'c_R_w': c_R_w.float(), 448 | 'w_P': w_P.float(), 449 | 'c_p': c_p.T.float(), 450 | 'K': new_K.float(), 451 | 'xmin': depths[int(xmin_percentile * (depths.shape[0] - 1))].float(), 452 | 'xmax': depths[int(xmax_percentile * (depths.shape[0] - 1))].float() 453 | }) 454 | 455 | train_global_depths = torch.sort(torch.hstack(train_global_depths)).values 456 | test_global_depths = torch.sort(torch.hstack(test_global_depths)).values 457 | self.train_global_xmin = train_global_depths[int(xmin_percentile * (train_global_depths.shape[0] - 1))] 458 | self.train_global_xmax = train_global_depths[int(xmax_percentile * (train_global_depths.shape[0] - 1))] 459 | self.test_global_xmin 
= test_global_depths[int(xmin_percentile * (test_global_depths.shape[0] - 1))]
460 | self.test_global_xmax = test_global_depths[int(xmax_percentile * (test_global_depths.shape[0] - 1))]
461 | self.train_data = train_data
462 | self.test_data = test_data
463 | 
-------------------------------------------------------------------------------- /datasets/README.md: --------------------------------------------------------------------------------
1 | # Dataset setup
2 | 
3 | Here you can find instructions to set up datasets for use with this code.
4 | 
5 | ## Cambridge and 7-Scenes
6 | 
7 | We provide the script [datasetup.py](datasetup.py) for setting up the Cambridge and 7-Scenes datasets. The script can be
8 | called with either the name of the dataset to set up, *e.g.*, `7-Scenes`, or the name of a specific scene, *e.g.*,
9 | `KingsCollege`. For example, if you want to set up the whole Cambridge dataset:
10 | ```shell
11 | python datasetup.py Cambridge
12 | ```
13 | Or if you only want to set up the *chess* scene of the 7-Scenes dataset:
14 | ```shell
15 | python datasetup.py chess
16 | ```
17 | All options can be listed by running:
18 | ```shell
19 | python datasetup.py -h
20 | ```
21 | 
22 | 
23 | ## Custom dataset
24 | 
25 | We also support custom datasets in **COLMAP** model format.
26 | ⚠️ Please note that only **RADIAL** camera models are supported for now.
27 | 
28 | The custom dataset folder must contain:
29 | - The COLMAP model: `cameras`, `images` and `points3D` files in `.bin` or `.txt` format.
30 | - A folder named `images` containing all images in the model.
31 | - A file named `list_db.txt` with the name of all the images used for training, one image name per line.
32 | - A file named `list_query.txt` with the name of all the images used for testing, one image name per line.
33 | 
34 | The final outline of the folder should look like this:
35 | > - mydataset
36 | >   - images
37 | >     - frame001.jpg
38 | >     - frame002.jpg
39 | >     - frame003.jpg
40 | >     - ...
41 | >   - cameras.bin
42 | >   - images.bin
43 | >   - points3D.bin
44 | >   - list_db.txt
45 | >   - list_query.txt
46 | 
47 | An example of `list_db.txt` or `list_query.txt`:
48 | ```text
49 | frame001.jpg
50 | frame002.jpg
51 | frame003.jpg
52 | ...
53 | ``` 54 | -------------------------------------------------------------------------------- /datasets/datasetup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import zipfile 5 | from collections import namedtuple 6 | 7 | from torchvision.datasets.utils import download_and_extract_archive 8 | 9 | datasets = { 10 | 'Cambridge': [ 11 | 'GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch', 'Street' 12 | ], 13 | '7-Scenes': [ 14 | 'chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs' 15 | ] 16 | } 17 | Scene = namedtuple('Scene', ['url', 'dataset']) 18 | scenes = { 19 | 'GreatCourt': Scene( 20 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251291/GreatCourt.zip', 21 | dataset='Cambridge' 22 | ), 23 | 'KingsCollege': Scene( 24 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251342/KingsCollege.zip', 25 | dataset='Cambridge' 26 | ), 27 | 'OldHospital': Scene( 28 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251340/OldHospital.zip', 29 | dataset='Cambridge' 30 | ), 31 | 'ShopFacade': Scene( 32 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251336/ShopFacade.zip', 33 | dataset='Cambridge' 34 | ), 35 | 'StMarysChurch': Scene( 36 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251294/StMarysChurch.zip', 37 | dataset='Cambridge' 38 | ), 39 | 'Street': Scene( 40 | url='https://www.repository.cam.ac.uk/bitstream/handle/1810/251292/Street.zip', 41 | dataset='Cambridge' 42 | ), 43 | 'chess': Scene( 44 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/chess.zip', 45 | dataset='7-Scenes' 46 | ), 47 | 'fire': Scene( 48 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/fire.zip', 49 | dataset='7-Scenes' 50 | ), 51 | 'heads': Scene( 52 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/heads.zip', 53 | dataset='7-Scenes' 54 | ), 55 | 'office': Scene( 56 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/office.zip', 57 | dataset='7-Scenes' 58 | ), 59 | 'pumpkin': Scene( 60 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/pumpkin.zip', 61 | dataset='7-Scenes' 62 | ), 63 | 'redkitchen': Scene( 64 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/redkitchen.zip', 65 | dataset='7-Scenes' 66 | ), 67 | 'stairs': Scene( 68 | url='http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/stairs.zip', 69 | dataset='7-Scenes' 70 | ) 71 | } 72 | 73 | 74 | def setup_scene(scene_str): 75 | scene = scenes[scene_str] 76 | download_and_extract_archive(scene.url, scene.dataset) 77 | os.remove(os.path.join(scene.dataset, scene.url.split('/')[-1])) 78 | if scene_str in datasets['7-Scenes']: 79 | for file in glob.glob(os.path.join(scene.dataset, scene_str, '*.zip')): 80 | with zipfile.ZipFile(file, 'r') as f: 81 | members = [m for m in f.namelist() if os.path.basename(m) != 'Thumbs.db'] 82 | f.extractall(os.path.join(scene.dataset, scene_str), members=members) 83 | os.remove(file) 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 88 | parser.add_argument( 89 | 'dataset', 90 | choices=list(datasets.keys()) + list(scenes.keys()), 91 | help='name of the dataset or single scene to setup' 92 | ) 
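    # Example invocations (see datasets/README.md):
    #   python datasetup.py Cambridge   downloads and extracts every Cambridge scene
    #   python datasetup.py chess       downloads and extracts only the chess scene of 7-Scenes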
93 | args = parser.parse_args() 94 | 95 | if args.dataset in datasets: 96 | for scene_name in datasets[args.dataset]: 97 | setup_scene(scene_name) 98 | else: 99 | setup_scene(args.dataset) 100 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: homographyloss 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_llvm 10 | - absl-py=1.0.0=pyhd8ed1ab_0 11 | - aiohttp=3.8.1=py39h3811e60_0 12 | - aiosignal=1.2.0=pyhd8ed1ab_0 13 | - alsa-lib=1.2.3=h516909a_0 14 | - async-timeout=4.0.2=pyhd8ed1ab_0 15 | - attrs=21.4.0=pyhd8ed1ab_0 16 | - blas=1.0=mkl 17 | - blinker=1.4=py_1 18 | - bottleneck=1.3.2=py39hdd57654_1 19 | - brotlipy=0.7.0=py39h27cfd23_1003 20 | - bzip2=1.0.8=h7b6447c_0 21 | - c-ares=1.18.1=h7f98852_0 22 | - ca-certificates=2021.10.8=ha878542_0 23 | - cachetools=5.0.0=pyhd8ed1ab_0 24 | - cairo=1.16.0=ha12eb4b_1010 25 | - certifi=2021.10.8=py39hf3d152e_1 26 | - cffi=1.15.0=py39hd667e15_1 27 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 28 | - click=8.0.4=py39hf3d152e_0 29 | - colorama=0.4.4=pyh9f0ad1d_0 30 | - cryptography=36.0.0=py39h9ce1e76_0 31 | - cudatoolkit=11.3.1=h2bc3f7f_2 32 | - dbus=1.13.6=h5008d03_3 33 | - expat=2.4.7=h27087fc_0 34 | - ffmpeg=4.3.2=h37c90e5_3 35 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 36 | - font-ttf-inconsolata=3.000=h77eed37_0 37 | - font-ttf-source-code-pro=2.038=h77eed37_0 38 | - font-ttf-ubuntu=0.83=hab24e00_0 39 | - fontconfig=2.13.96=h8e229c2_1 40 | - fonts-conda-ecosystem=1=0 41 | - fonts-conda-forge=1=0 42 | - freeglut=3.2.2=h9c3ff4c_1 43 | - freetype=2.11.0=h70c0345_0 44 | - frozenlist=1.3.0=py39h3811e60_0 45 | - gettext=0.19.8.1=h73d1719_1008 46 | - giflib=5.2.1=h7b6447c_0 47 | - gmp=6.2.1=h2531618_2 48 | - gnutls=3.6.15=he1e5248_0 49 | - google-auth=2.6.0=pyh6c4a22f_1 50 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 51 | - graphite2=1.3.13=h58526e2_1001 52 | - grpcio=1.44.0=py39hff7568b_0 53 | - gst-plugins-base=1.18.5=hf529b03_3 54 | - gstreamer=1.18.5=h9f60fe5_3 55 | - harfbuzz=3.4.0=hb4a5f5f_0 56 | - hdf5=1.12.1=nompi_h2386368_104 57 | - icu=69.1=h9c3ff4c_0 58 | - idna=3.3=pyhd3eb1b0_0 59 | - importlib-metadata=4.11.2=py39hf3d152e_0 60 | - intel-openmp=2021.4.0=h06a4308_3561 61 | - jasper=2.0.33=ha77e612_0 62 | - jbig=2.1=h7f98852_2003 63 | - jpeg=9d=h7f8727e_0 64 | - keyutils=1.6.1=h166bdaf_0 65 | - kornia=0.6.3=pyhd8ed1ab_0 66 | - krb5=1.19.2=h3790be6_4 67 | - lame=3.100=h7b6447c_0 68 | - lcms2=2.12=h3be6417_0 69 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 70 | - lerc=3.0=h9c3ff4c_0 71 | - libblas=3.9.0=12_linux64_mkl 72 | - libcblas=3.9.0=12_linux64_mkl 73 | - libclang=13.0.1=default_hc23dcda_0 74 | - libcurl=7.82.0=h7bff187_0 75 | - libdeflate=1.8=h7f98852_0 76 | - libedit=3.1.20191231=he28a2e2_2 77 | - libev=4.33=h516909a_1 78 | - libevent=2.1.10=h9b69904_4 79 | - libffi=3.4.2=h7f98852_5 80 | - libgcc-ng=11.2.0=h1d223b6_13 81 | - libgfortran-ng=11.2.0=h69a702a_13 82 | - libgfortran5=11.2.0=h5c6108e_13 83 | - libglib=2.70.2=h174f98d_4 84 | - libglu=9.0.0=he1b5a44_1001 85 | - libiconv=1.16=h516909a_0 86 | - libidn2=2.3.2=h7f8727e_0 87 | - liblapack=3.9.0=12_linux64_mkl 88 | - liblapacke=3.9.0=12_linux64_mkl 89 | - libllvm13=13.0.1=hf817b99_2 90 | - libnghttp2=1.47.0=h727a467_0 91 | - libnsl=2.0.0=h7f98852_0 92 | - libogg=1.3.4=h7f98852_1 93 | - libopencv=4.5.5=py39h7d09d5f_0 94 | - libopus=1.3.1=h7f98852_1 
95 | - libpng=1.6.37=hbc83047_0 96 | - libpq=14.2=hd57d9b9_0 97 | - libprotobuf=3.19.4=h780b84a_0 98 | - libssh2=1.10.0=ha56f1ee_2 99 | - libstdcxx-ng=11.2.0=he4da1e4_13 100 | - libtasn1=4.16.0=h27cfd23_0 101 | - libtiff=4.3.0=h6f004c6_2 102 | - libunistring=0.9.10=h27cfd23_0 103 | - libuuid=2.32.1=h7f98852_1000 104 | - libuv=1.40.0=h7b6447c_0 105 | - libvorbis=1.3.7=h9c3ff4c_0 106 | - libwebp=1.2.2=h55f646e_0 107 | - libwebp-base=1.2.2=h7f8727e_0 108 | - libxcb=1.13=h7f98852_1004 109 | - libxkbcommon=1.0.3=he3ba5ed_0 110 | - libxml2=2.9.12=h885dcf4_1 111 | - libzlib=1.2.11=h36c2ea0_1013 112 | - llvm-openmp=13.0.1=he0ac6c6_1 113 | - lz4-c=1.9.3=h295c915_1 114 | - markdown=3.3.6=pyhd8ed1ab_0 115 | - mkl=2021.4.0=h06a4308_640 116 | - mkl-service=2.4.0=py39h7f8727e_0 117 | - mkl_fft=1.3.1=py39hd3c417c_0 118 | - mkl_random=1.2.2=py39h51133e4_0 119 | - multidict=6.0.2=py39h3811e60_0 120 | - mysql-common=8.0.28=ha770c72_0 121 | - mysql-libs=8.0.28=hfa10184_0 122 | - ncurses=6.3=h7f8727e_2 123 | - nettle=3.7.3=hbbd107a_1 124 | - nspr=4.32=h9c3ff4c_1 125 | - nss=3.74=hb5efdd6_0 126 | - numexpr=2.8.1=py39h6abb31d_0 127 | - numpy=1.21.2=py39h20f2e39_0 128 | - numpy-base=1.21.2=py39h79a1101_0 129 | - oauthlib=3.2.0=pyhd8ed1ab_0 130 | - opencv=4.5.5=py39hf3d152e_0 131 | - openh264=2.1.1=h4ff587b_0 132 | - openssl=1.1.1l=h7f98852_0 133 | - packaging=21.3=pyhd8ed1ab_0 134 | - pandas=1.4.1=py39h295c915_0 135 | - pcre=8.45=h9c3ff4c_0 136 | - pillow=9.0.1=py39h22f2fdc_0 137 | - pip=21.2.4=py39h06a4308_0 138 | - pixman=0.40.0=h36c2ea0_0 139 | - protobuf=3.19.4=py39he80948d_0 140 | - pthread-stubs=0.4=h36c2ea0_1001 141 | - py-opencv=4.5.5=py39hef51801_0 142 | - pyasn1=0.4.8=py_0 143 | - pyasn1-modules=0.2.7=py_0 144 | - pycparser=2.21=pyhd3eb1b0_0 145 | - pyjwt=2.3.0=pyhd8ed1ab_1 146 | - pyopenssl=22.0.0=pyhd3eb1b0_0 147 | - pyparsing=3.0.7=pyhd8ed1ab_0 148 | - pysocks=1.7.1=py39h06a4308_0 149 | - python=3.9.10=h85951f9_2_cpython 150 | - python-dateutil=2.8.1=py_0 151 | - python_abi=3.9=2_cp39 152 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 153 | - pytorch-mutex=1.0=cuda 154 | - pytz=2020.1=py_0 155 | - pyu2f=0.1.5=pyhd8ed1ab_0 156 | - qt=5.12.9=ha98a1a1_5 157 | - readline=8.1.2=h7f8727e_1 158 | - requests=2.27.1=pyhd3eb1b0_0 159 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 160 | - rsa=4.8=pyhd8ed1ab_0 161 | - setuptools=58.0.4=py39h06a4308_0 162 | - six=1.16.0=pyhd3eb1b0_1 163 | - sqlite=3.37.2=hc218d9a_0 164 | - tensorboard=2.8.0=pyhd8ed1ab_1 165 | - tensorboard-data-server=0.6.0=py39h95dcef6_1 166 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 167 | - tk=8.6.11=h1ccaba5_0 168 | - torchaudio=0.11.0=py39_cu113 169 | - torchvision=0.12.0=py39_cu113 170 | - tqdm=4.63.0=pyhd8ed1ab_0 171 | - typing-extensions=3.10.0.2=hd3eb1b0_0 172 | - typing_extensions=3.10.0.2=pyh06a4308_0 173 | - tzdata=2021e=hda174b7_0 174 | - urllib3=1.26.8=pyhd3eb1b0_0 175 | - werkzeug=2.0.3=pyhd8ed1ab_1 176 | - wheel=0.37.1=pyhd3eb1b0_0 177 | - x264=1!161.3030=h7f98852_1 178 | - xorg-fixesproto=5.0=h7f98852_1002 179 | - xorg-inputproto=2.3.2=h7f98852_1002 180 | - xorg-kbproto=1.0.7=h7f98852_1002 181 | - xorg-libice=1.0.10=h7f98852_0 182 | - xorg-libsm=1.2.3=hd9c2040_1000 183 | - xorg-libx11=1.7.2=h7f98852_0 184 | - xorg-libxau=1.0.9=h7f98852_0 185 | - xorg-libxdmcp=1.1.3=h7f98852_0 186 | - xorg-libxext=1.3.4=h7f98852_1 187 | - xorg-libxfixes=5.0.3=h7f98852_1004 188 | - xorg-libxi=1.7.10=h7f98852_0 189 | - xorg-libxrender=0.9.10=h7f98852_1003 190 | - xorg-renderproto=0.11.1=h7f98852_1002 191 | - xorg-xextproto=7.3.0=h7f98852_1002 192 | - 
xorg-xproto=7.0.31=h7f98852_1007
193 | - xz=5.2.5=h7b6447c_0
194 | - yarl=1.7.2=py39h3811e60_1
195 | - zipp=3.7.0=pyhd8ed1ab_1
196 | - zlib=1.2.11=h36c2ea0_1013
197 | - zstd=1.5.2=ha95c52a_0
198 | 
-------------------------------------------------------------------------------- /losses.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from utils import angle_between_quaternions, l1_loss, l2_loss, compute_ABC, project
4 | 
5 | 
6 | class LocalHomographyLoss(torch.nn.Module):
7 |     def __init__(self, device='cpu'):
8 |         super().__init__()
9 | 
10 |         # `c_n` is the normal vector of the plane inducing the homographies in the ground-truth camera frame
11 |         self.c_n = torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1)
12 | 
13 |         # `eye` is the (3, 3) identity matrix
14 |         self.eye = torch.eye(3, device=device)
15 | 
16 |     def __call__(self, batch):
17 |         A, B, C = compute_ABC(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], batch['chat_R_w'], self.c_n, self.eye)
18 | 
19 |         xmin = batch['xmin'].view(-1, 1, 1)
20 |         xmax = batch['xmax'].view(-1, 1, 1)
21 |         B_weight = torch.log(xmax / xmin) / (xmax - xmin)
22 |         C_weight = xmin * xmax
23 | 
24 |         error = A + B * B_weight + C / C_weight
25 |         error = error.diagonal(dim1=1, dim2=2).sum(dim=1).mean()
26 |         return error
27 | 
28 | 
29 | class GlobalHomographyLoss(torch.nn.Module):
30 |     def __init__(self, xmin, xmax, device='cpu'):
31 |         """
32 |         `xmin` is the minimum distance of observations across all frames.
33 |         `xmax` is the maximum distance of observations across all frames.
34 |         """
35 |         super().__init__()
36 | 
37 |         # `xmin` is the minimum distance of observations in all frames
38 |         xmin = torch.tensor(xmin, dtype=torch.float32, device=device)
39 | 
40 |         # `xmax` is the maximum distance of observations in all frames
41 |         xmax = torch.tensor(xmax, dtype=torch.float32, device=device)
42 | 
43 |         # `B_weight` and `C_weight` are the weights applied to matrices B and C, computed from `xmin` and `xmax`
44 |         self.B_weight = torch.log(xmax / xmin) / (xmax - xmin)
45 |         self.C_weight = xmin * xmax
46 | 
47 |         # `c_n` is the normal vector of the plane inducing the homographies in the ground-truth camera frame
48 |         self.c_n = torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1)
49 | 
50 |         # `eye` is the (3, 3) identity matrix
51 |         self.eye = torch.eye(3, device=device)
52 | 
53 |     def __call__(self, batch):
54 |         A, B, C = compute_ABC(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], batch['chat_R_w'], self.c_n, self.eye)
55 | 
56 |         error = A + B * self.B_weight + C / self.C_weight
57 |         error = error.diagonal(dim1=1, dim2=2).sum(dim=1).mean()
58 |         return error
59 | 
60 | 
61 | class PoseNetLoss(torch.nn.Module):
62 |     def __init__(self, beta):
63 |         super().__init__()
64 |         self.beta = beta
65 | 
66 |     def __call__(self, batch):
67 |         t_error = l2_loss(batch['w_t_chat'], batch['w_t_c'])
68 |         q_error = l2_loss(batch['chat_q_w'], batch['c_q_w'])
69 |         error = t_error + self.beta * q_error
70 |         return error
71 | 
72 | 
73 | class HomoscedasticLoss(torch.nn.Module):
74 |     def __init__(self, s_hat_t, s_hat_q, device='cpu'):
75 |         super().__init__()
76 |         self.s_hat_t = torch.nn.Parameter(torch.tensor(s_hat_t, dtype=torch.float32, device=device))
77 |         self.s_hat_q = torch.nn.Parameter(torch.tensor(s_hat_q, dtype=torch.float32, device=device))
78 | 
79 |     def __call__(self, batch):
80 |         LtI = l1_loss(batch['w_t_chat'], batch['w_t_c'])
81 |         LqI = l1_loss(batch['normalized_chat_q_w'], batch['c_q_w'])
82 |         error = LtI *
torch.exp(-self.s_hat_t) + self.s_hat_t + LqI * torch.exp(-self.s_hat_q) + self.s_hat_q 83 | return error 84 | 85 | 86 | class GeometricLoss(torch.nn.Module): 87 | def __init__(self): 88 | super().__init__() 89 | 90 | def __call__(self, batch): 91 | error = 0 92 | for w_t_c, c_R_w, w_t_chat, chat_R_w, w_P in zip(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], 93 | batch['chat_R_w'], batch['w_P']): 94 | c_p = project(w_t_c, c_R_w, w_P) 95 | chat_p = project(w_t_chat, chat_R_w, w_P) 96 | error += l1_loss(chat_p.T, c_p.T, reduce='none').clip(0, 100).mean() 97 | error = error / batch['w_t_c'].shape[0] 98 | return error 99 | 100 | 101 | class DSACLoss(torch.nn.Module): 102 | def __init__(self): 103 | super().__init__() 104 | 105 | def __call__(self, batch): 106 | t_error = 100 * l2_loss(batch['w_t_chat'], batch['w_t_c'], reduce='none') 107 | q_error = angle_between_quaternions(batch['normalized_chat_q_w'], batch['c_q_w']).rad2deg() 108 | error = torch.max( 109 | t_error.view(-1), 110 | q_error 111 | ).mean() 112 | return error 113 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import datasets 12 | import losses 13 | import models 14 | from utils import batch_to_device, batch_errors, batch_compute_utils, log_poses, log_errors 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument( 19 | 'path', metavar='DATA_PATH', 20 | help='path to the dataset directory, e.g. 
"/home/data/KingsCollege"' 21 | ) 22 | parser.add_argument( 23 | '--loss', help='loss function for training', 24 | choices=['local_homography', 'global_homography', 'posenet', 'homoscedastic', 'geometric', 'dsac'], 25 | default='local_homography' 26 | ) 27 | parser.add_argument('--epochs', help='number of epochs for training', type=int, default=5000) 28 | parser.add_argument('--batch_size', help='training batch size', type=int, default=64) 29 | parser.add_argument('--xmin_percentile', help='xmin depth percentile', type=float, default=0.025) 30 | parser.add_argument('--xmax_percentile', help='xmax depth percentile', type=float, default=0.975) 31 | parser.add_argument( 32 | '--weights', metavar='WEIGHTS_PATH', 33 | help='path to weights with which the model will be initialized' 34 | ) 35 | parser.add_argument( 36 | '--device', default='cpu', 37 | help='set the device to train the model, `cuda` for GPU' 38 | ) 39 | args = parser.parse_args() 40 | 41 | # Set seed for reproductibility 42 | seed = 1 43 | random.seed(seed) 44 | np.random.seed(seed) 45 | torch.manual_seed(seed) 46 | 47 | # Load model 48 | model = models.load_model(args.weights) 49 | model.train() 50 | model.to(args.device) 51 | 52 | # Load dataset 53 | dataset_name = os.path.basename(os.path.normpath(args.path)) 54 | if dataset_name in ['GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch', 'Street']: 55 | dataset = datasets.CambridgeDataset(args.path, args.xmin_percentile, args.xmax_percentile) 56 | elif dataset_name in ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']: 57 | dataset = datasets.SevenScenesDataset(args.path, args.xmin_percentile, args.xmax_percentile) 58 | else: 59 | dataset = datasets.COLMAPDataset(args.path, args.xmin_percentile, args.xmax_percentile) 60 | 61 | # Wrapper for use with PyTorch's DataLoader 62 | train_dataset = datasets.RelocDataset(dataset.train_data) 63 | test_dataset = datasets.RelocDataset(dataset.test_data) 64 | 65 | # Creating data loaders for train and test data 66 | train_loader = DataLoader( 67 | train_dataset, 68 | batch_size=args.batch_size, 69 | shuffle=True, 70 | pin_memory=True, 71 | collate_fn=datasets.collate_fn, 72 | drop_last=True 73 | ) 74 | test_loader = DataLoader( 75 | test_dataset, 76 | batch_size=args.batch_size, 77 | shuffle=False, 78 | pin_memory=True, 79 | collate_fn=datasets.collate_fn 80 | ) 81 | 82 | # Adam optimizer default epsilon parameter is 1e-8 83 | eps = 1e-8 84 | 85 | # Instantiate loss 86 | if args.loss == 'local_homography': 87 | criterion = losses.LocalHomographyLoss(device=args.device) 88 | eps = 1e-14 # Adam optimizer epsilon is set to 1e-14 for homography losses 89 | elif args.loss == 'global_homography': 90 | criterion = losses.GlobalHomographyLoss( 91 | xmin=dataset.train_global_xmin, 92 | xmax=dataset.train_global_xmax, 93 | device=args.device 94 | ) 95 | eps = 1e-14 # Adam optimizer epsilon is set to 1e-14 for homography losses 96 | elif args.loss == 'posenet': 97 | criterion = losses.PoseNetLoss(beta=500) 98 | elif args.loss == 'homoscedastic': 99 | criterion = losses.HomoscedasticLoss(s_hat_t=0.0, s_hat_q=-3.0, device=args.device) 100 | elif args.loss == 'geometric': 101 | criterion = losses.GeometricLoss() 102 | elif args.loss == 'dsac': 103 | criterion = losses.DSACLoss() 104 | else: 105 | raise Exception(f'Loss {args.loss} not recognized...') 106 | 107 | # Instantiate adam optimizer 108 | optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()), lr=1e-4, eps=eps) 109 | 110 | # 
Set up tensorboard 111 | writer = SummaryWriter(os.path.join('logs', os.path.basename(os.path.normpath(args.path)), args.loss)) 112 | 113 | # Set up folder to save weights 114 | if not os.path.exists(os.path.join(writer.log_dir, 'weights')): 115 | os.makedirs(os.path.join(writer.log_dir, 'weights')) 116 | 117 | # Set up file to save logs 118 | log_file_path = os.path.join(writer.log_dir, 'epochs_poses_log.csv') 119 | with open(log_file_path, mode='w') as log_file: 120 | log_file.write('epoch,image_file,type,w_tx_chat,w_ty_chat,w_tz_chat,chat_qw_w,chat_qx_w,chat_qy_w,chat_qz_w\n') 121 | 122 | print('Start training...') 123 | for epoch in tqdm.tqdm(range(args.epochs)): 124 | epoch_loss = 0 125 | errors = {} 126 | 127 | for batch in train_loader: 128 | optimizer.zero_grad() 129 | 130 | # Move all batch data to proper device 131 | batch = batch_to_device(batch, args.device) 132 | 133 | # Estimate the pose from the image 134 | batch['w_t_chat'], batch['chat_q_w'] = model(batch['image']).split([3, 4], dim=1) 135 | 136 | # Computes useful data for our batch 137 | # - Normalized quaternion 138 | # - Rotation matrix from this normalized quaternion 139 | # - Reshapes translation component to fit shape (batch_size, 3, 1) 140 | batch_compute_utils(batch) 141 | 142 | # Compute loss 143 | loss = criterion(batch) 144 | 145 | # Backprop 146 | loss.backward() 147 | optimizer.step() 148 | 149 | # Add current batch loss to epoch loss 150 | epoch_loss += loss.item() / len(train_loader) 151 | 152 | # Compute training batch errors and log poses 153 | with torch.no_grad(): 154 | batch_errors(batch, errors) 155 | 156 | with open(log_file_path, mode='a') as log_file: 157 | log_poses(log_file, batch, epoch, 'train') 158 | 159 | # Log epoch loss 160 | writer.add_scalar('train loss', epoch_loss, epoch) 161 | 162 | with torch.no_grad(): 163 | 164 | # Log train errors 165 | log_errors(errors, writer, epoch, 'train') 166 | 167 | # Set the model to eval mode for test data 168 | model.eval() 169 | errors = {} 170 | 171 | for batch in test_loader: 172 | # Compute test poses estimations 173 | batch = batch_to_device(batch, args.device) 174 | batch['w_t_chat'], batch['chat_q_w'] = model(batch['image']).split([3, 4], dim=1) 175 | batch_compute_utils(batch) 176 | 177 | # Log test poses 178 | with open(log_file_path, mode='a') as log_file: 179 | log_poses(log_file, batch, epoch, 'test') 180 | 181 | # Compute test errors 182 | batch_errors(batch, errors) 183 | 184 | # Log test errors 185 | log_errors(errors, writer, epoch, 'test') 186 | 187 | # Log loss parameters, if there are any 188 | for p_name, p in criterion.named_parameters(): 189 | writer.add_scalar(p_name, p, epoch) 190 | 191 | writer.flush() 192 | model.train() 193 | 194 | # Save model and optimizer weights every n and last epochs: 195 | if epoch % 500 == 0 or epoch == args.epochs - 1: 196 | torch.save({ 197 | 'model_state_dict': model.state_dict(), 198 | 'optimizer_state_dict': optimizer.state_dict(), 199 | 'criterion_state_dict': criterion.state_dict() 200 | }, os.path.join(writer.log_dir, 'weights', f'epoch_{epoch}.pth')) 201 | 202 | writer.close() 203 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def load_model(weights_path=None): 5 | """ 6 | Loads MobileNetV2 pre-trained on ImageNet from PyTorch's cloud. 7 | Modifies last layers to fit our pose regression problem. 
8 | """ 9 | # Base model is MobileNetV2 from PyTorch's hub 10 | model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True) 11 | 12 | # We modify the classifier of MobileNetV2 with a custom regressor 13 | in_features = list(model.classifier.children())[-1].in_features 14 | model.classifier = torch.nn.Sequential( 15 | torch.nn.ReLU(), 16 | torch.nn.Linear( 17 | in_features=in_features, 18 | out_features=2048, 19 | bias=True 20 | ), 21 | torch.nn.ReLU(), 22 | torch.nn.Linear( 23 | in_features=2048, 24 | out_features=7, 25 | bias=True 26 | ) 27 | ) 28 | if weights_path is not None: 29 | model.load_state_dict(torch.load(weights_path)['model_state_dict']) 30 | return model 31 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import normalize 3 | from kornia.geometry.conversions import quaternion_to_rotation_matrix, QuaternionCoeffOrder 4 | 5 | 6 | def angle_between_quaternions(q, r): 7 | """ 8 | Works on batchs of quaternions only. 9 | `q` and `r` must be batchs of unit quaternions with shape (n, 4). 10 | """ 11 | return 2 * torch.sum(q * r, dim=1).abs().clip(0, 1).arccos() 12 | 13 | 14 | def l1_loss(input, target, reduce='mean'): 15 | """ 16 | Computes batch L1 loss with `reduce` reduction. 17 | `input` and `target` must have shape (batch_size, *). 18 | L1 norm will be computed for each element on the batch. 19 | """ 20 | loss = torch.abs(target - input).sum(dim=1) 21 | if reduce == 'none': 22 | return loss 23 | elif reduce == 'mean': 24 | return loss.mean() 25 | else: 26 | raise Exception(f'Reduction method {reduce} not known') 27 | 28 | 29 | def l2_loss(input, target, reduce='mean'): 30 | """ 31 | Computes batch L2 loss with `reduce` reduction. 32 | `input` and `target` must have shape (batch_size, *). 33 | L2 norm will be computed for each element on the batch. 34 | """ 35 | loss = torch.square(target - input).sum(dim=1).sqrt() 36 | if reduce == 'none': 37 | return loss 38 | elif reduce == 'mean': 39 | return loss.mean() 40 | else: 41 | raise Exception(f'Reduction method {reduce} not known') 42 | 43 | 44 | def compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye): 45 | """ 46 | Computes A, B, and C matrix given estimated and ground truth poses 47 | and normal vector n. 48 | `w_t_c` and `w_t_chat` must have shape (batch_size, 3, 1). 49 | `c_R_w` and `chat_R_w` must have shape (batch_size, 3, 3). 50 | `n` must have shape (3, 1). 51 | `eye` is the (3, 3) identity matrix on the proper device. 52 | """ 53 | chat_t_c = chat_R_w @ (w_t_c - w_t_chat) 54 | chat_R_c = chat_R_w @ c_R_w.transpose(1, 2) 55 | 56 | A = eye - chat_R_c 57 | C = c_n @ chat_t_c.transpose(1, 2) 58 | B = C @ A 59 | A = A @ A.transpose(1, 2) 60 | B = B + B.transpose(1, 2) 61 | C = C @ C.transpose(1, 2) 62 | 63 | return A, B, C 64 | 65 | 66 | def batch_to_device(batch, device): 67 | """ 68 | If `device` is not 'cpu', moves all data in batch to the GPU. 69 | """ 70 | if device != 'cpu': 71 | for key, value in batch.items(): 72 | if type(value) is torch.Tensor: 73 | batch[key] = value.to(device) 74 | elif type(value[0]) is torch.Tensor: 75 | for index_value, value_value in enumerate(value): 76 | value[index_value] = value_value.to(device) 77 | return batch 78 | 79 | 80 | def project(w_t_c, c_R_w, w_P, K=None): 81 | """ 82 | Projects 3D points P expressed in frame w onto frame c camera view. 
83 | `w_t_c` is the (3, 1) shaped translation from frame c to frame w. 84 | `c_R_w` is the (3, 3) rotation matrix from frame w to frame c. 85 | `w_P` are (n, 3) shaped 3D points P expressed in the w frame. 86 | `K` is frame c camera matrix. 87 | """ 88 | c_p = c_R_w @ (w_P.T - w_t_c) 89 | if K is not None: 90 | c_p = K @ c_p 91 | c_p = c_p[:2] / c_p[2] 92 | return c_p 93 | 94 | 95 | def batch_errors(batch, errors): 96 | """ 97 | Computes translation, rotation and reprojection errors for the batch. 98 | """ 99 | b_errors = { 100 | 't_errors': [l2_loss(batch['w_t_chat'], batch['w_t_c'], reduce='none').squeeze()], 101 | 'q_errors': [angle_between_quaternions(batch['normalized_chat_q_w'], batch['c_q_w'])], 102 | 'reprojection_error_sum': 0, 103 | 'reprojection_distance_sum': 0, 104 | 'l1_reprojection_error_sum': 0, 105 | 'n_points': 0 106 | } 107 | 108 | for w_t_chat, chat_R_w, w_P, c_p, K in zip(batch['w_t_chat'], batch['chat_R_w'], batch['w_P'], 109 | batch['c_p'], batch['K']): 110 | chat_p = project(w_t_chat, chat_R_w, w_P, K=K) 111 | diff = chat_p.T - c_p 112 | reprojection_errors = torch.square(diff).sum(dim=1).clip(0, 1000000) 113 | b_errors['reprojection_error_sum'] += reprojection_errors.sum() 114 | b_errors['reprojection_distance_sum'] += reprojection_errors.sqrt().sum() 115 | b_errors['l1_reprojection_error_sum'] += torch.abs(diff).sum(dim=1).clip(0, 1000000).sum() 116 | b_errors['n_points'] += c_p.shape[0] 117 | 118 | if len(errors) == 0: 119 | for key, value in b_errors.items(): 120 | errors[key] = value 121 | else: 122 | for key, value in b_errors.items(): 123 | errors[key] += value 124 | 125 | 126 | def batch_compute_utils(batch): 127 | """ 128 | Computes inplace useful data for the batch. 129 | - Computes a normalized quaternion, and its corresponding rotation matrix. 130 | - Reshapes translation component to fit shape (batchs_size, 3, 1). 131 | """ 132 | batch['w_t_chat'] = batch['w_t_chat'].view(-1, 3, 1) 133 | batch['normalized_chat_q_w'] = normalize(batch['chat_q_w'], dim=1) 134 | batch['chat_R_w'] = quaternion_to_rotation_matrix(batch['chat_q_w'], order=QuaternionCoeffOrder.WXYZ) 135 | 136 | 137 | def log_poses(log_file, batch, epoch, data_type): 138 | """ 139 | Logs batch estimated poses in log file. 140 | """ 141 | log_file.write('\n'.join([ 142 | f'{epoch},{image_file},{data_type},{",".join(map(str, w_t_chat.squeeze().tolist()))},' 143 | f'{",".join(map(str, chat_q_w.tolist()))}' 144 | for image_file, w_t_chat, chat_q_w in 145 | zip(batch['image_file'], batch['w_t_chat'], batch['chat_q_w'])]) + '\n' 146 | ) 147 | 148 | 149 | def log_errors(errors, writer, epoch, data_type): 150 | """ 151 | Logs epoch poses errors in tensorboard. 
152 | """ 153 | t_errors = torch.hstack(errors['t_errors']) 154 | q_errors = torch.hstack(errors['q_errors']).rad2deg() 155 | 156 | writer.add_scalar(f'{data_type} distance median', t_errors.median(), epoch) 157 | writer.add_scalar(f'{data_type} angle median', q_errors.median(), epoch) 158 | writer.add_scalar( 159 | f'{data_type} mean reprojection error', 160 | errors['reprojection_error_sum'] / errors['n_points'], epoch 161 | ) 162 | writer.add_scalar( 163 | f'{data_type} mean reprojection distance', 164 | errors['reprojection_distance_sum'] / errors['n_points'], epoch 165 | ) 166 | writer.add_scalar( 167 | f'{data_type} mean l1 reprojection error', 168 | errors['l1_reprojection_error_sum'] / errors['n_points'], epoch 169 | ) 170 | 171 | for meter_threshold, deg_threshold in zip( 172 | [0.05, 0.15, 0.25, 0.5, 0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5], 173 | [2, 5, 10, 15, 1, 2, 3, 5, 2, 5, 10] 174 | ): 175 | score = torch.logical_and( 176 | t_errors <= meter_threshold, q_errors <= deg_threshold 177 | ).sum() / t_errors.shape[0] 178 | writer.add_scalar( 179 | f'{data_type} percentage localized within {meter_threshold}m, {deg_threshold}deg', 180 | score, epoch 181 | ) 182 | --------------------------------------------------------------------------------
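As a closing illustration, here is a minimal, self-contained sketch (not part of the repository) of how the homography losses above consume a batch. It builds the batch dictionary with the keys and shapes documented in `datasets.py` and `utils.batch_compute_utils`, fills it with random poses, and evaluates `LocalHomographyLoss`; the depth bounds `xmin`/`xmax` are arbitrary values chosen purely for illustration.

```python
import torch
from torch.nn.functional import normalize
from kornia.geometry.conversions import quaternion_to_rotation_matrix, QuaternionCoeffOrder

from losses import LocalHomographyLoss

# Batch of random ground-truth and estimated poses, using the keys and shapes
# expected by the losses: (n, 3, 1) translations and (n, 3, 3) rotation matrices.
n = 4
c_q_w = normalize(torch.randn(n, 4), dim=1)      # ground-truth world-to-camera quaternions (wxyz)
chat_q_w = normalize(torch.randn(n, 4), dim=1)   # estimated world-to-camera quaternions (wxyz)
batch = {
    'w_t_c': torch.randn(n, 3, 1),               # ground-truth camera-to-world translations
    'c_R_w': quaternion_to_rotation_matrix(c_q_w, order=QuaternionCoeffOrder.WXYZ),
    'w_t_chat': torch.randn(n, 3, 1),            # estimated camera-to-world translations
    'chat_R_w': quaternion_to_rotation_matrix(chat_q_w, order=QuaternionCoeffOrder.WXYZ),
    'xmin': torch.full((n,), 0.5),               # arbitrary per-image depth bounds (illustration only)
    'xmax': torch.full((n,), 10.0),
}

criterion = LocalHomographyLoss()
loss = criterion(batch)
print(loss)  # scalar tensor; identical ground-truth and estimated poses would give zero
```

During training, `main.py` builds the same dictionary from the data loader and the network output, so this sketch only replaces those two sources with random tensors.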