├── .gitignore ├── LICENSE ├── README.md ├── augmentations.py ├── commons.py ├── cosface_loss.py ├── datasets ├── __init__.py ├── dataset_utils.py ├── eigenplaces_dataset.py ├── map_utils.py └── test_dataset.py ├── eigenplaces_model ├── __init__.py ├── eigenplaces_network.py └── layers.py ├── eval.py ├── hubconf.py ├── parser.py ├── requirements.txt ├── test.py ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .spyproject 2 | .idea 3 | __pycache__ 4 | logs 5 | cache 6 | jobs 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gabriele Berton, Gabriele Trivigno, Carlo Masone, Barbara Caputo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # EigenPlaces: Training Viewpoint Robust Models for Visual Place Recognition 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-amstertime)](https://paperswithcode.com/sota/visual-place-recognition-on-amstertime?p=eigenplaces-training-viewpoint-robust-models) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-eynsham)](https://paperswithcode.com/sota/visual-place-recognition-on-eynsham?p=eigenplaces-training-viewpoint-robust-models) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-pittsburgh-30k)](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-30k?p=eigenplaces-training-viewpoint-robust-models) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-san-francisco)](https://paperswithcode.com/sota/visual-place-recognition-on-san-francisco?p=eigenplaces-training-viewpoint-robust-models) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-sf-xl-test-v1)](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v1?p=eigenplaces-training-viewpoint-robust-models) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-tokyo247)](https://paperswithcode.com/sota/visual-place-recognition-on-tokyo247?p=eigenplaces-training-viewpoint-robust-models) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-pittsburgh-250k)](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-250k?p=eigenplaces-training-viewpoint-robust-models) 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/eigenplaces-training-viewpoint-robust-models/visual-place-recognition-on-sf-xl-test-v2)](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v2?p=eigenplaces-training-viewpoint-robust-models) 12 | 13 | This is the official PyTorch implementation of the ICCV 2023 paper "EigenPlaces: Training Viewpoint Robust Models for Visual Place Recognition". 14 | The paper presents a new training method that provides the model with samples taken from multiple viewpoints, in order to make it robust to camera viewpoint changes. It achieves state-of-the-art results on all datasets with large viewpoint shifts between queries and database images.
15 | 16 | Alongside the paper, we also released a codebase to reproduce the results of all the other baselines (NetVLAD, SFRS, Conv-AP, CosPlace, MixVPR), in order to provide a standardized and fair evaluation framework, available at [https://github.com/gmberton/VPR-methods-evaluation](https://github.com/gmberton/VPR-methods-evaluation) 17 | 18 | [[ICCV 2023 Open Access](https://openaccess.thecvf.com/content/ICCV2023/html/Berton_EigenPlaces_Training_Viewpoint_Robust_Models_for_Visual_Place_Recognition_ICCV_2023_paper.html)] [[ArXiv](https://arxiv.org/abs/2308.10832)] [[BibTeX](https://github.com/gmberton/EigenPlaces#cite)] 19 | 20 |

21 | 22 | 23 | 24 | 25 |

26 | 27 | 28 | ## Train 29 | Training is performed on the SF-XL dataset, which you can download from [here](https://github.com/gmberton/CosPlace). Make sure to download the training panoramas, which EigenPlaces takes as input and automatically crops with the required orientation. 30 | After downloading the SF-XL dataset, simply run 31 | 32 | `$ python3 train.py --train_dataset_folder path/to/sf_xl/raw/train/panoramas --val_dataset_folder path/to/sf_xl/processed/val --test_dataset_folder path/to/sf_xl/processed/test` 33 | 34 | The script automatically splits SF-XL into CosPlace Groups and saves the resulting object in the folder `cache`. 35 | By default, training is performed with a ResNet-18 backbone, 512-dimensional descriptors and AMP, which uses less than 8 GB of VRAM. 36 | 37 | To change the backbone or the output descriptors dimensionality, simply run something like this 38 | 39 | `$ python3 train.py --backbone ResNet50 --fc_output_dim 128` 40 | 41 | Run `$ python3 train.py -h` to see all the hyperparameters that you can change; they include all the hyperparameters mentioned in the paper. 42 | 43 | ## Test 44 | You can test one of our trained models as follows (this downloads the model from torch.hub) 45 | 46 | `$ python3 eval.py --backbone ResNet50 --fc_output_dim 2048 --resume_model torchhub` 47 | 48 | or a model trained by you like this 49 | 50 | `$ python3 eval.py --backbone ResNet50 --fc_output_dim 2048 --resume_model path/to/best_model.pth` 51 | 52 | ## Trained Models 53 | 54 | All our trained models are available on [PyTorch Hub](https://pytorch.org/docs/stable/hub.html), so that you can use them in any codebase, without cloning this repository, simply like this 55 | ``` 56 | import torch 57 | model = torch.hub.load("gmberton/eigenplaces", "get_trained_model", backbone="ResNet50", fc_output_dim=2048) 58 | ``` 59 | 60 | Available trained models are ResNet18 (output dim 256 or 512), ResNet50 (output dim 128, 256, 512 or 2048), ResNet101 (output dim 128, 256, 512 or 2048) and VGG16 (output dim 512).
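
As a minimal sketch of how to use a model loaded from PyTorch Hub to compute a descriptor for a single image (the path `query.jpg` is only a placeholder; the preprocessing mirrors the one used in `datasets/test_dataset.py`):

```
import torch
import torchvision.transforms as T
from PIL import Image

model = torch.hub.load("gmberton/eigenplaces", "get_trained_model",
                       backbone="ResNet50", fc_output_dim=2048).eval()

# Same normalization used by the TestDataset in this repo
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

with torch.no_grad():
    # "query.jpg" is a placeholder: use any RGB image of a place
    image = transform(Image.open("query.jpg").convert("RGB")).unsqueeze(0)
    descriptor = model(image)  # L2-normalized tensor of shape [1, 2048]
```

Descriptors of queries and database images can then be compared with a simple nearest-neighbor search, as done with faiss in `test.py`.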
61 | 62 | 63 | ## Acknowledgements 64 | Parts of this repo are inspired by the following repositories: 65 | - [CosFace implementation in PyTorch](https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py) 66 | - [CNN Image Retrieval in PyTorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) (for the GeM layer) 67 | - [Visual Geo-localization benchmark](https://github.com/gmberton/deep-visual-geo-localization-benchmark) (for the evaluation / test code) 68 | - [CosPlace](https://github.com/gmberton/CosPlace) 69 | 70 | ## Cite 71 | Here is the BibTeX to cite our paper 72 | ``` 73 | @inproceedings{Berton_2023_EigenPlaces, 74 | title={EigenPlaces: Training Viewpoint Robust Models for Visual Place Recognition}, 75 | author={Berton, Gabriele and Trivigno, Gabriele and Caputo, Barbara and Masone, Carlo}, 76 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 77 | year={2023}, 78 | month={October}, 79 | pages={11080-11090} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /augmentations.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from typing import Tuple, Union 4 | import torchvision.transforms as T 5 | 6 | 7 | class DeviceAgnosticColorJitter(T.ColorJitter): 8 | def __init__(self, brightness: float = 0., contrast: float = 0., saturation: float = 0., hue: float = 0.): 9 | """This is the same as T.ColorJitter but it only accepts batches of images and works on GPU""" 10 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 11 | 12 | def forward(self, images: torch.Tensor) -> torch.Tensor: 13 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}" 14 | B, C, H, W = images.shape 15 | # Applies a different color jitter to each image 16 | color_jitter = super(DeviceAgnosticColorJitter, self).forward 17 | augmented_images = [color_jitter(img).unsqueeze(0) for img in images] 18 | augmented_images = torch.cat(augmented_images) 19 | assert augmented_images.shape == torch.Size([B, C, H, W]) 20 | return augmented_images 21 | 22 | 23 | class DeviceAgnosticRandomResizedCrop(T.RandomResizedCrop): 24 | def __init__(self, size: Union[int, Tuple[int, int]], scale: Tuple[float, float]): 25 | """This is the same as T.RandomResizedCrop but it only accepts batches of images and works on GPU""" 26 | super().__init__(size=size, scale=scale, antialias=True) 27 | 28 | def forward(self, images: torch.Tensor) -> torch.Tensor: 29 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}" 30 | B, C, H, W = images.shape 31 | # Applies a different random resized crop to each image 32 | random_resized_crop = super(DeviceAgnosticRandomResizedCrop, self).forward 33 | augmented_images = [random_resized_crop(img).unsqueeze(0) for img in images] 34 | augmented_images = torch.cat(augmented_images) 35 | return augmented_images 36 | 37 | 38 | if __name__ == "__main__": 39 | """ 40 | You can run this script to visualize the transformations, and verify that 41 | the augmentations are applied individually on each image of the batch.
42 | """ 43 | from PIL import Image 44 | # Import skimage in here, so it is not necessary to install it unless you run this script 45 | from skimage import data 46 | 47 | # Initialize DeviceAgnosticRandomResizedCrop 48 | random_crop = DeviceAgnosticRandomResizedCrop(size=[256, 256], scale=[0.5, 1]) 49 | # Create a batch with 2 astronaut images 50 | pil_image = Image.fromarray(data.astronaut()) 51 | tensor_image = T.functional.to_tensor(pil_image).unsqueeze(0) 52 | images_batch = torch.cat([tensor_image, tensor_image]) 53 | # Apply augmentation (individually on each of the 2 images) 54 | augmented_batch = random_crop(images_batch) 55 | # Convert to PIL images 56 | augmented_image_0 = T.functional.to_pil_image(augmented_batch[0]) 57 | augmented_image_1 = T.functional.to_pil_image(augmented_batch[1]) 58 | # Visualize the original image, as well as the two augmented ones 59 | pil_image.show() 60 | augmented_image_0.show() 61 | augmented_image_1.show() 62 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import torch 5 | import random 6 | import logging 7 | import traceback 8 | import numpy as np 9 | 10 | 11 | class InfiniteDataLoader(torch.utils.data.DataLoader): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.dataset_iterator = super().__iter__() 15 | 16 | def __iter__(self): 17 | return self 18 | 19 | def __next__(self): 20 | try: 21 | batch = next(self.dataset_iterator) 22 | except StopIteration: 23 | self.dataset_iterator = super().__iter__() 24 | batch = next(self.dataset_iterator) 25 | return batch 26 | 27 | 28 | def make_deterministic(seed: int = 0): 29 | """Make results deterministic. If seed == -1, do not make deterministic. 30 | Running your script in a deterministic way might slow it down. 31 | Note that for some packages (eg: sklearn's PCA) this function is not enough. 32 | """ 33 | seed = int(seed) 34 | if seed == -1: 35 | return 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | 44 | def setup_logging(output_folder: str, exist_ok: bool = False, console: str = "debug", 45 | info_filename: str = "info.log", debug_filename: str = "debug.log"): 46 | """Set up logging files and console output. 47 | Creates one file for INFO logs and one for DEBUG logs. 48 | Args: 49 | output_folder (str): creates the folder where to save the files. 50 | exist_ok (boolean): if False throw a FileExistsError if output_folder already exists 51 | debug (str): 52 | if == "debug" prints on console debug messages and higher 53 | if == "info" prints on console info messages and higher 54 | if == None does not use console (useful when a logger has already been set) 55 | info_filename (str): the name of the info file. if None, don't create info file 56 | debug_filename (str): the name of the debug file. 
if None, don't create debug file 57 | """ 58 | if not exist_ok and os.path.exists(output_folder): 59 | raise FileExistsError(f"{output_folder} already exists!") 60 | os.makedirs(output_folder, exist_ok=True) 61 | import matplotlib 62 | logging.getLogger('matplotlib').disabled = True 63 | logging.getLogger('PIL').disabled = True 64 | logging.getLogger('staticmap').disabled = True 65 | logging.getLogger('requests').disabled = True 66 | logging.getLogger('PIL').setLevel(logging.WARNING) 67 | import requests 68 | logging.getLogger("requests").setLevel(logging.WARNING) 69 | import urllib3 70 | logging.getLogger("urllib3").setLevel(logging.WARNING) 71 | 72 | logging.getLogger('matplotlib.font_manager').disabled = True 73 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S") 74 | logger = logging.getLogger('') 75 | logger.setLevel(logging.DEBUG) 76 | 77 | if info_filename is not None: 78 | info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}') 79 | info_file_handler.setLevel(logging.INFO) 80 | info_file_handler.setFormatter(base_formatter) 81 | logger.addHandler(info_file_handler) 82 | 83 | if debug_filename is not None: 84 | debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}') 85 | debug_file_handler.setLevel(logging.DEBUG) 86 | debug_file_handler.setFormatter(base_formatter) 87 | logger.addHandler(debug_file_handler) 88 | 89 | if console is not None: 90 | console_handler = logging.StreamHandler() 91 | if console == "debug": 92 | console_handler.setLevel(logging.DEBUG) 93 | if console == "info": 94 | console_handler.setLevel(logging.INFO) 95 | console_handler.setFormatter(base_formatter) 96 | logger.addHandler(console_handler) 97 | 98 | def my_handler(type_, value, tb): 99 | logger.info("\n" + "".join(traceback.format_exception(type, value, tb))) 100 | logging.info("Experiment finished (with some errors)") 101 | sys.excepthook = my_handler 102 | -------------------------------------------------------------------------------- /cosface_loss.py: -------------------------------------------------------------------------------- 1 | 2 | # Based on https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | 9 | def cosine_sim(x1: torch.Tensor, x2: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: 10 | ip = torch.mm(x1, x2.t()) 11 | w1 = torch.norm(x1, 2, dim) 12 | w2 = torch.norm(x2, 2, dim) 13 | return ip / torch.ger(w1, w2).clamp(min=eps) 14 | 15 | 16 | class MarginCosineProduct(nn.Module): 17 | """Implement of large margin cosine distance: 18 | Args: 19 | in_features: size of each input sample 20 | out_features: size of each output sample 21 | s: norm of input feature 22 | m: margin 23 | """ 24 | def __init__(self, in_features: int, out_features: int, s: float = 30.0, m: float = 0.40): 25 | super().__init__() 26 | self.in_features = in_features 27 | self.out_features = out_features 28 | self.s = s 29 | self.m = m 30 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 31 | nn.init.xavier_uniform_(self.weight) 32 | 33 | def forward(self, inputs: torch.Tensor, label: torch.Tensor) -> torch.Tensor: 34 | cosine = cosine_sim(inputs, self.weight) 35 | one_hot = torch.zeros_like(cosine) 36 | one_hot.scatter_(1, label.view(-1, 1), 1.0) 37 | output = self.s * (cosine - one_hot * self.m) 38 | return output 39 | 40 | def __repr__(self): 41 | return self.__class__.__name__ + '(' \ 42 | + 'in_features=' + 
str(self.in_features) \ 43 | + ', out_features=' + str(self.out_features) \ 44 | + ', s=' + str(self.s) \ 45 | + ', m=' + str(self.m) + ')' 46 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmberton/EigenPlaces/fef66475d45f9a65e7e10ba2c360c25396478f44/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | from glob import glob 5 | from PIL import ImageFile 6 | 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def read_images_paths(dataset_folder, get_abs_path=False): 11 | """Find images within 'dataset_folder' and return their relative paths as a list. 12 | If there is a file 'dataset_folder'_images_paths.txt, read paths from such file. 13 | Otherwise, use glob(). Keeping the paths in the file speeds up computation, 14 | because using glob over large folders can be slow. 15 | 16 | Parameters 17 | ---------- 18 | dataset_folder : str, folder containing JPEG images 19 | get_abs_path : bool, if True return absolute paths, otherwise remove 20 | dataset_folder from each path 21 | 22 | Returns 23 | ------- 24 | images_paths : list[str], paths of JPEG images within dataset_folder 25 | """ 26 | 27 | if not os.path.exists(dataset_folder): 28 | raise FileNotFoundError(f"Folder {dataset_folder} does not exist") 29 | 30 | file_with_paths = dataset_folder + "_images_paths.txt" 31 | if os.path.exists(file_with_paths): 32 | logging.debug(f"Reading paths of images within {dataset_folder} from {file_with_paths}") 33 | with open(file_with_paths, "r") as file: 34 | images_paths = file.read().splitlines() 35 | images_paths = [dataset_folder + "/" + path for path in images_paths] 36 | # Sanity check that paths within the file exist 37 | if not os.path.exists(images_paths[0]): 38 | raise FileNotFoundError(f"Image with path {images_paths[0]} " 39 | f"does not exist within {dataset_folder}. 
It is likely " 40 | f"that the content of {file_with_paths} is wrong.") 41 | else: 42 | logging.debug(f"Searching images in {dataset_folder} with glob()") 43 | images_paths = sorted(glob(f"{dataset_folder}/**/*.jpg", recursive=True)) 44 | if len(images_paths) == 0: 45 | raise FileNotFoundError(f"Directory {dataset_folder} does not contain any JPEG images") 46 | 47 | if not get_abs_path: # Remove dataset_folder from the path 48 | images_paths = [p[len(dataset_folder) + 1:] for p in images_paths] 49 | 50 | return images_paths 51 | 52 | -------------------------------------------------------------------------------- /datasets/eigenplaces_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import utm 4 | import math 5 | import torch 6 | import random 7 | import imageio 8 | import logging 9 | import numpy as np 10 | from PIL import Image 11 | from PIL import ImageFile 12 | import torchvision.transforms as tfm 13 | from collections import defaultdict 14 | 15 | from datasets.map_utils import create_map 16 | import datasets.dataset_utils as dataset_utils 17 | 18 | ImageFile.LOAD_TRUNCATED_IMAGES = True 19 | 20 | PANO_WIDTH = int(512*6.5) 21 | 22 | 23 | def get_angle(focal_point, obs_point): 24 | obs_e, obs_n = float(obs_point[0]), float(obs_point[1]) 25 | focal_e, focal_n = focal_point 26 | side1 = focal_e - obs_e 27 | side2 = focal_n - obs_n 28 | angle = - math.atan2(side1, side2) / math.pi * 90 * 2 29 | return angle 30 | 31 | 32 | def get_eigen_things(utm_coords): 33 | mu = utm_coords.mean(0) 34 | norm_data = utm_coords - mu 35 | eigenvectors, eigenvalues, v = np.linalg.svd(norm_data.T, full_matrices=False) 36 | return eigenvectors, eigenvalues, mu 37 | 38 | 39 | def rotate_2d_vector(vector, angle): 40 | assert vector.shape == (2,) 41 | theta = np.deg2rad(angle) 42 | rot_mat = np.array([[np.cos(theta), -np.sin(theta)], 43 | [np.sin(theta), np.cos(theta)]]) 44 | rotated_point = np.dot(rot_mat, vector) 45 | return rotated_point 46 | 47 | 48 | def get_focal_point(utm_coords, meters_from_center=20, angle=0): 49 | """Return the focal point from a set of utm coords""" 50 | B, D = utm_coords.shape 51 | assert D == 2 52 | eigenvectors, eigenvalues, mu = get_eigen_things(utm_coords) 53 | 54 | direction = rotate_2d_vector(eigenvectors[1], angle) 55 | focal_point = mu + direction * meters_from_center 56 | return focal_point 57 | 58 | 59 | class EigenPlacesDataset(torch.utils.data.Dataset): 60 | def __init__(self, dataset_folder, 61 | M=20, N=5, focal_dist=10, current_group=0, 62 | min_images_per_class=10, angle=0, visualize_classes=0): 63 | """ 64 | Parameters (please check our paper for a clearer explanation of the parameters). 65 | ---------- 66 | dataset_folder : str, the path of the folder with the train images. 67 | M : int, the length of the side of each cell in meters. 68 | N : int, distance (M-wise) between two classes of the same group. 69 | focal_dist : int, distance (in meters) between the center of the class and 70 | the focal point. The center of the class is computed as the 71 | mean of the positions of the images within the class. 72 | current_group : int, which one of the groups to consider. 73 | min_images_per_class : int, minimum number of image in a class. 74 | angle : int, the angle formed between the line of the first principal 75 | component, and the line that connects the center of gravity of the 76 | images to the focal point. 77 | visualize_classes : int, the number of classes for which to create 78 | visualizations. 
Visualizations of a class consists in its map and 79 | the images belonging to it. 80 | """ 81 | super().__init__() 82 | self.M = M 83 | self.N = N 84 | self.focal_dist = focal_dist 85 | self.current_group = current_group 86 | self.dataset_folder = dataset_folder 87 | 88 | filename = f"cache/sfxl_M{M}_N{N}_mipc{min_images_per_class}.torch" 89 | if not os.path.exists(filename): 90 | os.makedirs("cache", exist_ok=True) 91 | logging.info(f"Cached dataset {filename} does not exist, I'll create it now.") 92 | self.initialize(dataset_folder, M, N, min_images_per_class, filename) 93 | elif current_group == 0: 94 | logging.info(f"Using cached dataset {filename}") 95 | 96 | classes_per_group, self.images_per_class = torch.load(filename) 97 | if current_group >= len(classes_per_group): 98 | raise ValueError(f"With this configuration there are only {len(classes_per_group)} " + 99 | f"groups, therefore I can't create the {current_group}th group. " + 100 | "You should reduce the number of groups by setting for example " + 101 | f"'--groups_num {current_group}'") 102 | self.classes_ids = classes_per_group[current_group] 103 | 104 | new_classes_ids = [] 105 | self.focal_point_per_class = {} 106 | for class_id in self.classes_ids: 107 | paths = self.images_per_class[class_id] 108 | u_coords = np.array([p.split("@")[1:3] for p in paths]).astype(float) 109 | 110 | focal_point = get_focal_point(u_coords, focal_dist, angle=angle) 111 | new_classes_ids.append(class_id) 112 | self.focal_point_per_class[class_id] = focal_point 113 | 114 | self.classes_ids = new_classes_ids 115 | 116 | # This is only for logging, debugging and visualizations 117 | for class_num in range(visualize_classes): 118 | random_class_id = random.choice(self.classes_ids) 119 | paths = self.images_per_class[random_class_id] 120 | focal_point = self.focal_point_per_class[random_class_id] 121 | focal_point_lat_lon = np.array(utm.to_latlon(focal_point[0], focal_point[1], 10, 'S')) 122 | lats_lons = np.array([p.split("@")[5:7] for p in paths]).astype(float) 123 | lats_lons += (np.random.randn(*lats_lons.shape) / 500000) # Add a little noise to avoid overlapping 124 | 125 | min_e, min_n = random_class_id 126 | cell_utms = (min_e, min_n), (min_e, min_n + M), (min_e + M, min_n + M), (min_e + M, min_n) 127 | cell_corners = np.array([utm.to_latlon(*u, 10, 'S') for u in cell_utms]) 128 | 129 | output_folder = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) 130 | folder = f"{output_folder}/visualizations/group{current_group}_{class_num}_{random_class_id}" 131 | os.makedirs(folder) 132 | try: 133 | img_map = create_map([lats_lons, lats_lons.mean(0).reshape(1, 2), focal_point_lat_lon.reshape(1, 2), cell_corners], 134 | colors=["r", "b", "g", "orange"], 135 | legend_names=["images position", "center of mass", 136 | "focal point", f"cell corners ({M} meters)"], 137 | dot_sizes=[10, 100, 100, 100]) 138 | imageio.imsave(f"{folder}/@00_map.jpg", img_map) 139 | except RuntimeError: 140 | # Sometimes there are errors due to staticmap (komoot) servers 141 | logging.warn("There was some problem while downloading the map of the class for visualization. 
" 142 | "This will not influence training.") 143 | 144 | images_paths = self.images_per_class[random_class_id] 145 | for path in images_paths: 146 | crop = self.get_crop(self.dataset_folder + "/" + path, focal_point) 147 | crop = tfm.functional.to_pil_image(crop) 148 | crop.save(f"{folder}/{os.path.basename(path)}") 149 | 150 | @staticmethod 151 | def get_crop(pano_path, focal_point): 152 | obs_point = pano_path.split("@")[1:3] 153 | angle = - get_angle(focal_point, obs_point) % 360 154 | crop_offset = int((angle / 360 * PANO_WIDTH) % PANO_WIDTH) 155 | yaw = int(pano_path.split("@")[9]) 156 | north_yaw_in_degrees = (180-yaw) % 360 157 | yaw_offset = int((north_yaw_in_degrees / 360) * PANO_WIDTH) 158 | offset = (yaw_offset + crop_offset - 256) % PANO_WIDTH 159 | pano_pil = Image.open(pano_path) 160 | if offset + 512 <= PANO_WIDTH: 161 | pil_crop = pano_pil.crop((offset, 0, offset + 512, 512)) 162 | else: 163 | crop1 = pano_pil.crop((offset, 0, PANO_WIDTH, 512)) 164 | crop2 = pano_pil.crop((0, 0, 512 - (PANO_WIDTH - offset), 512)) 165 | pil_crop = Image.new('RGB', (512, 512)) 166 | pil_crop.paste(crop1, (0, 0)) 167 | pil_crop.paste(crop2, (crop1.size[0], 0)) 168 | crop = tfm.functional.to_tensor(pil_crop) 169 | 170 | return crop 171 | 172 | def __getitem__(self, class_num): 173 | # This function takes as input the class_num instead of the index of 174 | # the image. This way each class is equally represented during training. 175 | class_id = self.classes_ids[class_num] 176 | focal_point = self.focal_point_per_class[class_id] 177 | pano_path = self.dataset_folder + "/" + random.choice(self.images_per_class[class_id]) 178 | crop = self.get_crop(pano_path, focal_point) 179 | return crop, class_num, pano_path 180 | 181 | def get_images_num(self): 182 | """Return the number of images within this group.""" 183 | return sum([len(self.images_per_class[c]) for c in self.classes_ids]) 184 | 185 | def __len__(self): 186 | """Return the number of classes within this group.""" 187 | return len(self.classes_ids) 188 | 189 | @staticmethod 190 | def initialize(dataset_folder, M, N, min_images_per_class, filename): 191 | logging.debug(f"Searching training images in {dataset_folder}") 192 | 193 | images_paths = dataset_utils.read_images_paths(dataset_folder) 194 | logging.debug(f"Found {len(images_paths)} images") 195 | 196 | logging.debug("For each image, get its UTM east, UTM north from its path") 197 | images_metadatas = [p.split("@") for p in images_paths] 198 | # field 1 is UTM east, field 2 is UTM north 199 | utmeast_utmnorth = [(m[1], m[2]) for m in images_metadatas] 200 | utmeast_utmnorth = np.array(utmeast_utmnorth).astype(float) 201 | 202 | logging.debug("For each image, get class and group to which it belongs") 203 | class_id__group_id = [EigenPlacesDataset.get__class_id__group_id(*m, M, N) 204 | for m in utmeast_utmnorth] 205 | 206 | logging.debug("Group together images belonging to the same class") 207 | images_per_class = defaultdict(list) 208 | for image_path, (class_id, _) in zip(images_paths, class_id__group_id): 209 | images_per_class[class_id].append(image_path) 210 | 211 | # Images_per_class is a dict where the key is class_id, and the value 212 | # is a list with the paths of images within that class. 
213 | images_per_class = {k: v for k, v in images_per_class.items() if len(v) >= min_images_per_class} 214 | 215 | logging.debug("Group together classes belonging to the same group") 216 | # Classes_per_group is a dict where the key is group_id, and the value 217 | # is a list with the class_ids belonging to that group. 218 | classes_per_group = defaultdict(set) 219 | for class_id, group_id in class_id__group_id: 220 | if class_id not in images_per_class: 221 | continue # Skip classes with too few images 222 | classes_per_group[group_id].add(class_id) 223 | 224 | # Convert classes_per_group to a list of lists. 225 | # Each sublist represents the classes within a group. 226 | classes_per_group = [list(c) for c in classes_per_group.values()] 227 | 228 | torch.save((classes_per_group, images_per_class), filename) 229 | 230 | @staticmethod 231 | def get__class_id__group_id(utm_east, utm_north, M, N): 232 | """Return class_id and group_id for a given point. 233 | The class_id is a triplet (tuple) of UTM_east, UTM_north 234 | (e.g. (396520, 4983800)). 235 | The group_id represents the group to which the class belongs 236 | (e.g. (0, 1)), and it is between (0, 0) and (N, N). 237 | """ 238 | rounded_utm_east = int(utm_east // M * M) # Rounded to nearest lower multiple of M 239 | rounded_utm_north = int(utm_north // M * M) 240 | 241 | class_id = (rounded_utm_east, rounded_utm_north) 242 | # group_id goes from (0, 0) to (N, N) 243 | group_id = (rounded_utm_east % (M * N) // M, 244 | rounded_utm_north % (M * N) // M) 245 | return class_id, group_id 246 | -------------------------------------------------------------------------------- /datasets/map_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import math 4 | import imageio 5 | import numpy as np 6 | import geopy.distance 7 | import matplotlib.pyplot as plt 8 | from staticmap import StaticMap, Polygon 9 | 10 | 11 | def _lon_to_x(lon, zoom): 12 | if not (-180 <= lon <= 180): lon = (lon + 180) % 360 - 180 13 | return ((lon + 180.) 
/ 360) * pow(2, zoom) 14 | 15 | 16 | def _lat_to_y(lat, zoom): 17 | if not (-90 <= lat <= 90): lat = (lat + 90) % 180 - 90 18 | return (1 - math.log(math.tan(lat * math.pi / 180) + 1 / math.cos(lat * math.pi / 180)) / math.pi) / 2 * pow(2, zoom) 19 | 20 | 21 | def _download_map_image(min_lat=45.0, min_lon=7.6, max_lat=45.1, max_lon=7.7, size=2000): 22 | """"Download a map of the chosen area as a numpy image""" 23 | mean_lat = (min_lat + max_lat) / 2 24 | mean_lon = (min_lon + max_lon) / 2 25 | static_map = StaticMap(size, size) 26 | static_map.add_polygon( 27 | Polygon(((min_lon, min_lat), (min_lon, max_lat), (max_lon, max_lat), (max_lon, min_lat)), None, '#FFFFFF')) 28 | zoom = static_map._calculate_zoom() 29 | 30 | # print(((min_lat, min_lon), (max_lat, max_lon))) 31 | dist = geopy.distance.geodesic((min_lat, min_lon), (max_lat, max_lon)).m 32 | if dist < 50: 33 | zoom = 22 34 | else: 35 | zoom = 20 # static_map._calculate_zoom() 36 | static_map = StaticMap(size, size) 37 | image = static_map.render(zoom, [mean_lon, mean_lat]) 38 | # print(f"You can see the map on Google Maps at this link www.google.com/maps/place/@{mean_lat},{mean_lon},{zoom - 1}z") 39 | min_lat_px, min_lon_px, max_lat_px, max_lon_px = \ 40 | static_map._y_to_px(_lat_to_y(min_lat, zoom)), \ 41 | static_map._x_to_px(_lon_to_x(min_lon, zoom)), \ 42 | static_map._y_to_px(_lat_to_y(max_lat, zoom)), \ 43 | static_map._x_to_px(_lon_to_x(max_lon, zoom)) 44 | assert 0 <= max_lat_px < min_lat_px < size and 0 <= min_lon_px < max_lon_px < size 45 | return np.array(image)[max_lat_px:min_lat_px, min_lon_px:max_lon_px], static_map, zoom 46 | 47 | 48 | def get_edges(coordinates, enlarge=0): 49 | """ 50 | Send the edges of the coordinates, i.e. the most south, west, north and 51 | east coordinates. 
52 | :param coordinates: A list of numpy.arrays of shape (Nx2) 53 | :param float enlarge: How much to increase the coordinates, to enlarge 54 | the area included between the points 55 | :return: a tuple with the four float 56 | """ 57 | min_lat, min_lon, max_lat, max_lon = (*np.concatenate(coordinates).min(0), *np.concatenate(coordinates).max(0)) 58 | diff_lat = (max_lat - min_lat) * enlarge 59 | diff_lon = (max_lon - min_lon) * enlarge 60 | inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon = \ 61 | min_lat - diff_lat, min_lon - diff_lon, max_lat + diff_lat, max_lon + diff_lon 62 | return inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon 63 | 64 | 65 | def create_map(coordinates, colors=None, dot_sizes=None, legend_names=None, map_intensity=0.6): 66 | 67 | dot_sizes = dot_sizes if dot_sizes is not None else [10] * len(coordinates) 68 | colors = colors if colors is not None else ["r"] * len(coordinates) 69 | assert len(coordinates) == len(dot_sizes) == len(colors), \ 70 | f"The number of coordinates must be equals to the number of colors and dot_sizes, but they're " \ 71 | f"{len(coordinates)}, {len(colors)}, {len(dot_sizes)}" 72 | 73 | # Add two dummy points to slightly enlarge the map 74 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates, enlarge=0.1) 75 | coordinates.append(np.array([[min_lat, min_lon], [max_lat, max_lon]])) 76 | # Download the map of the chosen area 77 | map_img, static_map, zoom = _download_map_image(min_lat, min_lon, max_lat, max_lon) 78 | 79 | scatters = [] 80 | fig = plt.figure(figsize=(map_img.shape[1] / 100, map_img.shape[0] / 100), dpi=1000) 81 | for i, coord in enumerate(coordinates): 82 | for i in range(len(coord)): # Scale latitudes because of earth's curvature 83 | coord[i, 0] = -static_map._y_to_px(_lat_to_y(coord[i, 0], zoom)) 84 | for coord, size, color in zip(coordinates, dot_sizes, colors): 85 | scatters.append(plt.scatter(coord[:, 1], coord[:, 0], s=size, color=color)) 86 | 87 | if legend_names != None: 88 | plt.legend(scatters, legend_names, scatterpoints=1, loc='best', 89 | ncol=1, framealpha=0, prop={"weight": "bold", "size": 20}) 90 | 91 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates) 92 | plt.ylim(min_lat, max_lat) 93 | plt.xlim(min_lon, max_lon) 94 | fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 95 | fig.canvas.draw() 96 | plot_img = np.array(fig.canvas.renderer._renderer) 97 | plt.close() 98 | 99 | plot_img = cv2.resize(plot_img[:, :, :3], map_img.shape[:2][::-1], interpolation=cv2.INTER_LANCZOS4) 100 | map_img[(map_img.sum(2) < 444)] = 188 # brighten dark pixels 101 | map_img = (((map_img / 255) ** map_intensity) * 255).astype(np.uint8) # fade map 102 | mask = (plot_img.sum(2) == 255 * 3)[:, :, None] # mask of plot, to find white pixels 103 | final_map = map_img * mask + plot_img * (~mask) 104 | return final_map 105 | 106 | 107 | if __name__ == "__main__": 108 | # Create a map containing major cities of Italy, Germany and France 109 | coordinates = [ 110 | np.array([[41.8931, 12.4828], [45.4669, 9.1900], [40.8333, 14.2500]]), 111 | np.array([[52.5200, 13.4050], [48.7775, 9.1800], [48.1375, 11.5750]]), 112 | np.array([[48.8567, 2.3522], [43.2964, 5.3700], [45.7600, 4.8400]]) 113 | ] 114 | map_img = create_map( 115 | coordinates, 116 | colors=["green", "black", "blue"], 117 | dot_sizes=[1000, 1000, 1000], 118 | legend_names=[ 119 | "Main Italian Cities", 120 | "Main German Cities", 121 | "Main French Cities", 122 | ]) 123 | 124 | imageio.imsave("cities.png", map_img) 125 | 126 | 
-------------------------------------------------------------------------------- /datasets/test_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | from sklearn.neighbors import NearestNeighbors 8 | 9 | import datasets.dataset_utils as dataset_utils 10 | 11 | 12 | class TestDataset(data.Dataset): 13 | def __init__(self, dataset_folder, database_folder="database", 14 | queries_folder="queries", positive_dist_threshold=25): 15 | """Dataset with images from database and queries, used for validation and test. 16 | Parameters 17 | ---------- 18 | dataset_folder : str, should contain the path to the val or test set, 19 | which contains the folders {database_folder} and {queries_folder}. 20 | database_folder : str, name of folder with the database. 21 | queries_folder : str, name of folder with the queries. 22 | positive_dist_threshold : int, distance in meters for a prediction to 23 | be considered a positive. 24 | """ 25 | super().__init__() 26 | 27 | self.database_folder = os.path.join(dataset_folder, database_folder) 28 | self.queries_folder = os.path.join(dataset_folder, queries_folder) 29 | self.database_paths = dataset_utils.read_images_paths(self.database_folder, get_abs_path=True) 30 | self.queries_paths = dataset_utils.read_images_paths(self.queries_folder, get_abs_path=True) 31 | 32 | self.dataset_name = os.path.basename(dataset_folder) 33 | 34 | #### Read paths and UTM coordinates for all images. 35 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg 36 | self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float) 37 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float) 38 | 39 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters) 40 | knn = NearestNeighbors(n_jobs=-1) 41 | knn.fit(self.database_utms) 42 | self.positives_per_query = knn.radius_neighbors( 43 | self.queries_utms, radius=positive_dist_threshold, return_distance=False 44 | ) 45 | 46 | self.images_paths = self.database_paths + self.queries_paths 47 | 48 | self.database_num = len(self.database_paths) 49 | self.queries_num = len(self.queries_paths) 50 | 51 | self.base_transform = transforms.Compose([ 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 54 | ]) 55 | 56 | def __getitem__(self, index): 57 | image_path = self.images_paths[index] 58 | pil_img = Image.open(image_path) 59 | normalized_img = self.base_transform(pil_img) 60 | return normalized_img, index 61 | 62 | def __len__(self): 63 | return len(self.images_paths) 64 | 65 | def __repr__(self): 66 | return f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >" 67 | 68 | def get_positives(self): 69 | return self.positives_per_query 70 | -------------------------------------------------------------------------------- /eigenplaces_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmberton/EigenPlaces/fef66475d45f9a65e7e10ba2c360c25396478f44/eigenplaces_model/__init__.py -------------------------------------------------------------------------------- /eigenplaces_model/eigenplaces_network.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import logging 4 | import torchvision 5 | from torch import nn 6 | from typing import Tuple 7 | 8 | from eigenplaces_model.layers import Flatten, L2Norm, GeM 9 | 10 | # The number of channels in the last convolutional layer, the one before average pooling 11 | CHANNELS_NUM_IN_LAST_CONV = { 12 | "ResNet18": 512, 13 | "ResNet50": 2048, 14 | "ResNet101": 2048, 15 | "ResNet152": 2048, 16 | "VGG16": 512, 17 | } 18 | 19 | 20 | class GeoLocalizationNet_(nn.Module): 21 | def __init__(self, backbone : str, fc_output_dim : int): 22 | """Return a model_ for GeoLocalization. 23 | 24 | Args: 25 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet. 26 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to the descriptors dimension. 27 | """ 28 | super().__init__() 29 | assert backbone in CHANNELS_NUM_IN_LAST_CONV, f"backbone must be one of {list(CHANNELS_NUM_IN_LAST_CONV.keys())}" 30 | self.backbone, features_dim = _get_backbone(backbone) 31 | self.aggregation = nn.Sequential( 32 | L2Norm(), 33 | GeM(), 34 | Flatten(), 35 | nn.Linear(features_dim, fc_output_dim), 36 | L2Norm() 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.backbone(x) 41 | x = self.aggregation(x) 42 | return x 43 | 44 | 45 | def _get_torchvision_model(backbone_name : str) -> torch.nn.Module: 46 | """This function takes the name of a backbone and returns the corresponding pretrained 47 | model from torchvision. Examples of backbone_name are 'VGG16' or 'ResNet18' 48 | """ 49 | return getattr(torchvision.models, backbone_name.lower())() 50 | 51 | 52 | def _get_backbone(backbone_name : str) -> Tuple[torch.nn.Module, int]: 53 | backbone = _get_torchvision_model(backbone_name) 54 | 55 | logging.info("Loading pretrained backbone's weights from CosPlace") 56 | cosplace = torch.hub.load("gmberton/cosplace", "get_trained_model", backbone=backbone_name, fc_output_dim=512) 57 | new_sd = {k1: v2 for (k1, v1), (k2, v2) in zip(backbone.state_dict().items(), cosplace.state_dict().items()) 58 | if v1.shape == v2.shape} 59 | backbone.load_state_dict(new_sd, strict=False) 60 | 61 | if backbone_name.startswith("ResNet"): 62 | for name, child in backbone.named_children(): 63 | if name == "layer3": # Freeze layers before conv_3 64 | break 65 | for params in child.parameters(): 66 | params.requires_grad = False 67 | logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones") 68 | layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer 69 | 70 | elif backbone_name == "VGG16": 71 | layers = list(backbone.features.children())[:-2] # Remove avg pooling and FC layer 72 | for layer in layers[:-5]: 73 | for p in layer.parameters(): 74 | p.requires_grad = False 75 | logging.debug("Train last layers of the VGG-16, freeze the previous ones") 76 | 77 | backbone = torch.nn.Sequential(*layers) 78 | 79 | features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name] 80 | 81 | return backbone, features_dim 82 | 83 | -------------------------------------------------------------------------------- /eigenplaces_model/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | def gem(x, p=torch.ones(1)*3, eps: float = 1e-6): 9 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), 
x.size(-1))).pow(1./p) 10 | 11 | 12 | class GeM(nn.Module): 13 | def __init__(self, p=3, eps=1e-6): 14 | super().__init__() 15 | self.p = Parameter(torch.ones(1)*p) 16 | self.eps = eps 17 | 18 | def forward(self, x): 19 | return gem(x, p=self.p, eps=self.eps) 20 | 21 | def __repr__(self): 22 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})" 23 | 24 | 25 | class Flatten(torch.nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | def forward(self, x): 30 | assert x.shape[2] == x.shape[3] == 1, f"{x.shape[2]} != {x.shape[3]} != 1" 31 | return x[:, :, 0, 0] 32 | 33 | 34 | class L2Norm(nn.Module): 35 | def __init__(self, dim=1): 36 | super().__init__() 37 | self.dim = dim 38 | 39 | def forward(self, x): 40 | return F.normalize(x, p=2.0, dim=self.dim) 41 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import logging 5 | import multiprocessing 6 | from datetime import datetime 7 | 8 | import test 9 | import parser 10 | import commons 11 | from datasets.test_dataset import TestDataset 12 | from eigenplaces_model import eigenplaces_network 13 | 14 | torch.backends.cudnn.benchmark = True # Provides a speedup 15 | 16 | args = parser.parse_arguments() 17 | start_time = datetime.now() 18 | output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}" 19 | commons.make_deterministic(args.seed) 20 | commons.setup_logging(output_folder, console="info") 21 | logging.info(" ".join(sys.argv)) 22 | logging.info(f"Arguments: {args}") 23 | logging.info(f"The outputs are being saved in {output_folder}") 24 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.") 25 | 26 | #### Model 27 | if args.resume_model == "torchhub": 28 | model = torch.hub.load("gmberton/eigenplaces", "get_trained_model", 29 | backbone=args.backbone, fc_output_dim=args.fc_output_dim) 30 | else: 31 | model = eigenplaces_network.GeoLocalizationNet_(args.backbone, args.fc_output_dim) 32 | 33 | if args.resume_model is not None: 34 | logging.info(f"Loading model_ from {args.resume_model}") 35 | model_state_dict = torch.load(args.resume_model) 36 | model.load_state_dict(model_state_dict) 37 | else: 38 | logging.info("WARNING: You didn't provide a path to resume the model_ (--resume_model parameter). 
" + 39 | "Evaluation will be computed using randomly initialized weights.") 40 | 41 | model = model.to(args.device) 42 | 43 | test_ds = TestDataset(args.test_dataset_folder, queries_folder="queries", 44 | positive_dist_threshold=args.positive_dist_threshold) 45 | 46 | recalls, recalls_str = test.test(args, test_ds, model) 47 | logging.info(f"{test_ds}: {recalls_str}") 48 | 49 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | 2 | dependencies = ['torch', 'torchvision'] 3 | 4 | import torch 5 | from eigenplaces_model import eigenplaces_network 6 | 7 | 8 | AVAILABLE_TRAINED_MODELS = { 9 | # backbone : list of available fc_output_dim, which is equivalent to descriptors dimensionality 10 | "VGG16": [ 512], 11 | "ResNet18": [ 256, 512], 12 | "ResNet50": [128, 256, 512, 1024, 2048], 13 | "ResNet101": [128, 256, 512, 1024, 2048], 14 | } 15 | 16 | 17 | def get_trained_model(backbone : str = "ResNet50", fc_output_dim : int = 2048) -> torch.nn.Module: 18 | """Return a model trained with EigenPlaces on San Francisco eXtra Large. 19 | 20 | Args: 21 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet. 22 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to 23 | the descriptors dimension. Must be between 32 and 2048, depending on model's availability. 24 | 25 | Return: 26 | model (torch.nn.Module): a trained model. 27 | """ 28 | print(f"Returning EigenPlaces model with backbone: {backbone} with features dimension {fc_output_dim}") 29 | if backbone not in AVAILABLE_TRAINED_MODELS: 30 | raise ValueError(f"Parameter `backbone` is set to {backbone} but it must be one of {list(AVAILABLE_TRAINED_MODELS.keys())}") 31 | try: 32 | fc_output_dim = int(fc_output_dim) 33 | except: 34 | raise ValueError(f"Parameter `fc_output_dim` must be an integer, but it is set to {fc_output_dim}") 35 | if fc_output_dim not in AVAILABLE_TRAINED_MODELS[backbone]: 36 | raise ValueError(f"Parameter `fc_output_dim` is set to {fc_output_dim}, but for backbone {backbone} " 37 | f"it must be one of {list(AVAILABLE_TRAINED_MODELS[backbone])}") 38 | model = eigenplaces_network.GeoLocalizationNet_(backbone, fc_output_dim) 39 | model.load_state_dict( 40 | torch.hub.load_state_dict_from_url( 41 | f'https://github.com/gmberton/EigenPlaces/releases/download/v1.0/{backbone}_{fc_output_dim}_eigenplaces.pth', 42 | map_location=torch.device('cpu')) 43 | ) 44 | return model -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | 5 | def parse_arguments(): 6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | # CosPlace Groups parameters 8 | parser.add_argument("--M", type=int, default=15, help="_") 9 | parser.add_argument("--N", type=int, default=3, help="_") 10 | parser.add_argument("--focal_dist", type=int, default=10, help="_") # done GS 11 | parser.add_argument("--s", type=float, default=100, help="_") 12 | parser.add_argument("--m", type=float, default=0.4, help="_") 13 | parser.add_argument("--lambda_lat", type=float, default=1., help="_") 14 | parser.add_argument("--lambda_front", type=float, default=1., help="_") 15 | parser.add_argument("--groups_num", type=int, default=0, 16 | help="If set to 0 use N*N groups") 17 | 18 | parser.add_argument("--min_images_per_class", 
type=int, default=5, help="_") 19 | # Model parameters 20 | parser.add_argument("--backbone", type=str, default="ResNet18", 21 | choices=["VGG16", "ResNet18", "ResNet50", "ResNet101", "ResNet152"], help="_") 22 | parser.add_argument("--fc_output_dim", type=int, default=512, 23 | help="Output dimension of final fully connected layer") 24 | # Training parameters 25 | parser.add_argument("--batch_size", type=int, default=32, help="_") 26 | parser.add_argument("--epochs_num", type=int, default=40, help="_") 27 | parser.add_argument("--iterations_per_epoch", type=int, default=5000, help="_") 28 | parser.add_argument("--lr", type=float, default=0.00001, help="_") 29 | parser.add_argument("--classifiers_lr", type=float, default=0.01, help="_") 30 | # Data augmentation 31 | parser.add_argument("--brightness", type=float, default=0.7, help="_") 32 | parser.add_argument("--contrast", type=float, default=0.7, help="_") 33 | parser.add_argument("--hue", type=float, default=0.5, help="_") 34 | parser.add_argument("--saturation", type=float, default=0.7, help="_") 35 | parser.add_argument("--random_resized_crop", type=float, default=0.5, help="_") 36 | 37 | # Validation / test parameters 38 | parser.add_argument("--infer_batch_size", type=int, default=16, 39 | help="Batch size for inference (validating and testing)") 40 | parser.add_argument("--positive_dist_threshold", type=int, default=25, 41 | help="distance in meters for a prediction to be considered a positive") 42 | # Resume parameters 43 | parser.add_argument("--resume_train", type=str, default=None, 44 | help="path to checkpoint to resume, e.g. logs/.../last_checkpoint.pth") 45 | parser.add_argument("--resume_model", type=str, default=None, 46 | help="path to model_ to resume, e.g. logs/.../best_model.pth. 
" 47 | "Use \"torchhub\" if you want to use one of our pretrained models") 48 | # Other parameters 49 | parser.add_argument("--device", type=str, default="cuda", 50 | choices=["cuda", "cpu"], help="_") 51 | parser.add_argument("--seed", type=int, default=0, help="_") 52 | parser.add_argument("--num_workers", type=int, default=8, help="_") 53 | parser.add_argument("--visualize_classes", type=int, default=0, 54 | help="Save map visualizations for X classes in the save_dir") 55 | 56 | # Paths parameters 57 | parser.add_argument("--train_dataset_folder", type=str, default=None, 58 | help="path of the folder with training images") 59 | parser.add_argument("--val_dataset_folder", type=str, default=None, 60 | help="path of the folder with val images (split in database/queries)") 61 | parser.add_argument("--test_dataset_folder", type=str, default=None, 62 | help="path of the folder with test images (split in database/queries)") 63 | parser.add_argument("--save_dir", type=str, default="default", 64 | help="name of directory on which to save the logs, under logs/save_dir") 65 | 66 | args = parser.parse_args() 67 | if args.groups_num == 0: 68 | args.groups_num = args.N * args.N 69 | 70 | return args 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | geopy>=2.3.0 2 | imageio>=2.31.1 3 | matplotlib>=3.7.2 4 | numpy>=1.25.1 5 | Pillow==9.0.0 6 | Requests>=2.31.0 7 | scikit_learn>=1.3.0 8 | scikit_image 9 | staticmap>=0.5.5 10 | torch>=2.0.1 11 | torchmetrics>=1.0.1 12 | torchvision>=0.15.2 13 | tqdm>=4.65.0 14 | urllib3>=2.0.3 15 | utm>=0.7.0 16 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import faiss 3 | import torch 4 | import logging 5 | import numpy as np 6 | from tqdm import tqdm 7 | from typing import Tuple 8 | from argparse import Namespace 9 | from torch.utils.data.dataset import Subset 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | 13 | # Compute R@1, R@5, R@10, R@20 14 | RECALL_VALUES = [1, 5, 10, 20] 15 | 16 | 17 | def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module, batchify : bool = False) -> Tuple[np.ndarray, str]: 18 | """Compute descriptors of the given dataset and compute the recalls.""" 19 | 20 | model = model.eval() 21 | with torch.no_grad(): 22 | logging.debug("Extracting database descriptors for evaluation/testing") 23 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) 24 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers, 25 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) 26 | all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32") 27 | for images, indices in tqdm(database_dataloader, ncols=100): 28 | descriptors = model(images.to(args.device)) 29 | descriptors = descriptors.cpu().numpy() 30 | all_descriptors[indices.numpy(), :] = descriptors 31 | 32 | logging.debug("Extracting queries descriptors for evaluation/testing") 33 | if batchify: 34 | queries_infer_batch_size = args.infer_batch_size 35 | else: 36 | queries_infer_batch_size = 1 37 | queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num))) 38 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers, 39 | 
batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda")) 40 | for images, indices in tqdm(queries_dataloader, ncols=100): 41 | descriptors = model(images.to(args.device)) 42 | descriptors = descriptors.cpu().numpy() 43 | all_descriptors[indices.numpy(), :] = descriptors 44 | 45 | queries_descriptors = all_descriptors[eval_ds.database_num:] 46 | database_descriptors = all_descriptors[:eval_ds.database_num] 47 | 48 | # Use a kNN to find predictions 49 | faiss_index = faiss.IndexFlatL2(args.fc_output_dim) 50 | faiss_index.add(database_descriptors) 51 | del database_descriptors, all_descriptors 52 | 53 | logging.debug("Calculating recalls") 54 | _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES)) 55 | 56 | #### For each query, check if the predictions are correct 57 | positives_per_query = eval_ds.get_positives() 58 | recalls = np.zeros(len(RECALL_VALUES)) 59 | for query_index, preds in enumerate(predictions): 60 | for i, n in enumerate(RECALL_VALUES): 61 | if np.any(np.in1d(preds[:n], positives_per_query[query_index])): 62 | recalls[i:] += 1 63 | break 64 | # Divide by queries_num and multiply by 100, so the recalls are in percentages 65 | recalls = recalls / eval_ds.queries_num * 100 66 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)]) 67 | return recalls, recalls_str 68 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import logging 5 | import torchmetrics 6 | from tqdm import tqdm 7 | import multiprocessing 8 | from datetime import datetime 9 | import torchvision.transforms as tfm 10 | 11 | import test 12 | import util 13 | import parser 14 | import commons 15 | import cosface_loss 16 | import augmentations 17 | from eigenplaces_model import eigenplaces_network 18 | from datasets.test_dataset import TestDataset 19 | from datasets.eigenplaces_dataset import EigenPlacesDataset 20 | 21 | torch.backends.cudnn.benchmark = True # Provides a speedup 22 | 23 | args = parser.parse_arguments() 24 | start_time = datetime.now() 25 | output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}" 26 | commons.make_deterministic(args.seed) 27 | commons.setup_logging(output_folder, console="debug") 28 | logging.info(" ".join(sys.argv)) 29 | logging.info(f"Arguments: {args}") 30 | logging.info(f"The outputs are being saved in {output_folder}") 31 | 32 | #### Model 33 | model = eigenplaces_network.GeoLocalizationNet_(args.backbone, args.fc_output_dim) 34 | 35 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.") 36 | 37 | if args.resume_model is not None: 38 | logging.debug(f"Loading model from {args.resume_model}") 39 | model_state_dict = torch.load(args.resume_model) 40 | model.load_state_dict(model_state_dict) 41 | 42 | model = model.to(args.device).train() 43 | 44 | #### Optimizer 45 | criterion = torch.nn.CrossEntropyLoss() 46 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 47 | 48 | #### Datasets 49 | groups = [EigenPlacesDataset( 50 | args.train_dataset_folder, M=args.M, N=args.N, focal_dist=args.focal_dist, 51 | current_group=n//2, min_images_per_class=args.min_images_per_class, 52 | angle=[0, 90][n % 2], visualize_classes=args.visualize_classes) 53 | for n in range(args.groups_num * 2) 54 | ] 55 | # Each group has its own classifier, which depends on the number of 
classes in the group 56 | classifiers = [cosface_loss.MarginCosineProduct( 57 | args.fc_output_dim, len(group), s=args.s, m=args.m) for group in groups] 58 | classifiers_optimizers = [torch.optim.Adam(classifier.parameters(), lr=args.classifiers_lr) for classifier in classifiers] 59 | 60 | gpu_augmentation = tfm.Compose([ 61 | augmentations.DeviceAgnosticColorJitter(brightness=args.brightness, 62 | contrast=args.contrast, 63 | saturation=args.saturation, 64 | hue=args.hue), 65 | augmentations.DeviceAgnosticRandomResizedCrop([512, 512], 66 | scale=[1-args.random_resized_crop, 1]), 67 | tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 68 | ]) 69 | 70 | logging.info(f"Using {len(groups)} groups") 71 | logging.info(f"The {len(groups)} groups have respectively the following " 72 | f"number of classes {[len(g) for g in groups]}") 73 | logging.info(f"The {len(groups)} groups have respectively the following " 74 | f"number of images {[g.get_images_num() for g in groups]}") 75 | 76 | logging.info(f"There are {len(groups[0])} classes for the first group, " + 77 | f"each epoch has {args.iterations_per_epoch} iterations " + 78 | f"with batch_size {args.batch_size}, therefore the model sees each class (on average) " + 79 | f"{args.iterations_per_epoch * args.batch_size / len(groups[0]):.1f} times per epoch") 80 | 81 | val_ds = TestDataset(f"{args.val_dataset_folder}") 82 | logging.info(f"Validation set: {val_ds}") 83 | 84 | #### Resume 85 | if args.resume_train: 86 | model, model_optimizer, classifiers, classifiers_optimizers, \ 87 | best_val_recall1, start_epoch_num = \ 88 | util.resume_train(args, output_folder, model, model_optimizer, 89 | classifiers, classifiers_optimizers) 90 | 91 | model = model.to(args.device) 92 | epoch_num = start_epoch_num - 1 93 | logging.info(f"Resuming from epoch {start_epoch_num} with best R@1 {best_val_recall1:.1f} " + 94 | f"from checkpoint {args.resume_train}") 95 | else: 96 | best_val_recall1 = start_epoch_num = 0 97 | 98 | #### Train / evaluation loop 99 | logging.info("Start training ...") 100 | 101 | scaler = torch.cuda.amp.GradScaler() 102 | 103 | for epoch_num in range(start_epoch_num, args.epochs_num): 104 | 105 | #### Train 106 | epoch_start_time = datetime.now() 107 | 108 | def get_iterator(groups, classifiers, classifiers_optimizers, batch_size, g_num): 109 | assert len(groups) == len(classifiers) == len(classifiers_optimizers) 110 | classifiers[g_num] = classifiers[g_num].to(args.device) 111 | util.move_to_device(classifiers_optimizers[g_num], args.device) 112 | return commons.InfiniteDataLoader(groups[g_num], num_workers=args.num_workers, 113 | batch_size=batch_size, shuffle=True, 114 | pin_memory=(args.device == "cuda"), drop_last=True) 115 | 116 | # Select classifier and dataloader according to epoch 117 | current_dataset_num = (epoch_num % args.groups_num) * 2 118 | 119 | iterators = [] 120 | for i in range(2): 121 | iterators.append(get_iterator(groups, classifiers, classifiers_optimizers, 122 | args.batch_size, current_dataset_num + i)) 123 | lateral_loss = torchmetrics.MeanMetric() 124 | frontal_loss = torchmetrics.MeanMetric() 125 | 126 | model = model.train() 127 | for iteration in tqdm(range(args.iterations_per_epoch), ncols=100): 128 | model_optimizer.zero_grad() 129 | 130 | #### EigenPlace ITERATION #### 131 | for i in range(2): 132 | classifiers_optimizers[current_dataset_num + i].zero_grad() 133 | 134 | images, targets, _ = next(iterators[i]) 135 | images, targets = images.to(args.device), targets.to(args.device) 136 | 
with torch.cuda.amp.autocast():
137 |                 images = gpu_augmentation(images)
138 |                 descriptors = model(images)
139 |                 output = classifiers[current_dataset_num + i](descriptors, targets)
140 |                 loss = criterion(output, targets)
141 |                 if i == 0:
142 |                     loss *= args.lambda_lat
143 |                 else:
144 |                     loss *= args.lambda_front
145 |             del images, output
146 |             scaler.scale(loss).backward()
147 |             scaler.step(classifiers_optimizers[current_dataset_num + i])
148 |             if i == 0:
149 |                 lateral_loss.update(loss.detach().cpu())
150 |             else:
151 |                 frontal_loss.update(loss.detach().cpu())
152 |             del loss
153 |         #######################
154 | 
155 |         scaler.step(model_optimizer)  # Single backbone update, after both lateral and frontal losses
156 |         scaler.update()
157 | 
158 |     for i in range(2):  # Move the current group's classifiers back to CPU to free GPU memory
159 |         classifiers[current_dataset_num + i] = classifiers[current_dataset_num + i].cpu()
160 |         util.move_to_device(classifiers_optimizers[current_dataset_num + i], "cpu")
161 | 
162 |     logging.debug(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]} - "
163 |                   f"group {current_dataset_num} lateral_loss = {lateral_loss.compute():.4f} - "
164 |                   f"group {current_dataset_num + 1} frontal_loss = {frontal_loss.compute():.4f}")
165 | 
166 |     #### Evaluation
167 |     recalls, recalls_str = test.test(args, val_ds, model, batchify=True)
168 |     logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, {val_ds}: {recalls_str}")
169 |     is_best = recalls[0] > best_val_recall1
170 |     best_val_recall1 = max(recalls[0], best_val_recall1)
171 |     # Save checkpoint, which contains all training parameters
172 |     util.save_checkpoint({
173 |         "epoch_num": epoch_num + 1,
174 |         "model_state_dict": model.state_dict(),
175 |         "optimizer_state_dict": model_optimizer.state_dict(),
176 |         "classifiers_state_dict": [c.state_dict() for c in classifiers],
177 |         "optimizers_state_dict": [c.state_dict() for c in classifiers_optimizers],
178 |         "best_val_recall1": best_val_recall1
179 |     }, is_best, output_folder)
180 | 
181 | logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}")
182 | 
183 | #### Test best model on test set v1
184 | best_model_state_dict = torch.load(f"{output_folder}/best_model.pth")
185 | model.load_state_dict(best_model_state_dict)
186 | 
187 | test_ds = TestDataset(f"{args.test_dataset_folder}")
188 | recalls, recalls_str = test.test(args, test_ds, model)
189 | logging.info(f"{test_ds}: {recalls_str}")
190 | 
191 | logging.info("Experiment finished (without any errors)")
192 | 
193 | 
194 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import shutil
4 | import logging
5 | from typing import List
6 | from argparse import Namespace
7 | from cosface_loss import MarginCosineProduct
8 | 
9 | 
10 | def move_to_device(optimizer: torch.optim.Optimizer, device: str):
11 |     for state in optimizer.state.values():
12 |         for k, v in state.items():
13 |             if torch.is_tensor(v):
14 |                 state[k] = v.to(device)
15 | 
16 | 
17 | def save_checkpoint(state: dict, is_best: bool, output_folder: str,
18 |                     ckpt_filename: str = "last_checkpoint.pth"):
19 |     checkpoint_path = f"{output_folder}/{ckpt_filename}"
20 |     torch.save(state, checkpoint_path)
21 |     if is_best:
22 |         torch.save(state["model_state_dict"], f"{output_folder}/best_model.pth")
23 | 
24 | 
25 | def resume_train(args: Namespace, output_folder: str, model: torch.nn.Module,
26 |                  model_optimizer: torch.optim.Optimizer,
27 |                  classifiers: List[MarginCosineProduct], classifiers_optimizers: List[torch.optim.Optimizer]):
28 | 
29 |     """Load model, optimizer, classifiers, and other training parameters from the checkpoint."""
30 |     logging.info(f"Loading checkpoint: {args.resume_train}")
31 |     checkpoint = torch.load(args.resume_train)
32 |     start_epoch_num = checkpoint["epoch_num"]
33 | 
34 |     model_state_dict = checkpoint["model_state_dict"]
35 |     model.load_state_dict(model_state_dict)
36 | 
37 |     model = model.to(args.device)
38 |     model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
39 | 
40 |     # Load classifiers and their optimizers
41 |     assert args.groups_num*2 == len(classifiers) == len(classifiers_optimizers) == \
42 |         len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \
43 |         (f"{args.groups_num}, {len(classifiers)}, {len(classifiers_optimizers)}, "
44 |          f"{len(checkpoint['classifiers_state_dict'])}, {len(checkpoint['optimizers_state_dict'])}")
45 | 
46 |     for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]):
47 |         # Move classifiers to GPU before loading their optimizers
48 |         c = c.to(args.device)
49 |         c.load_state_dict(sd)
50 |     for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]):
51 |         c.load_state_dict(sd)
52 |     for c in classifiers:
53 |         # Move classifiers back to CPU to save some GPU memory
54 |         c = c.cpu()
55 | 
56 |     best_val_recall1 = checkpoint["best_val_recall1"]
57 | 
58 |     # Copy best model to current output_folder
59 |     shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder)
60 | 
61 |     return model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num
62 | 
--------------------------------------------------------------------------------
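
The recall@N logic used in test.py can be exercised in isolation. The toy sketch below reproduces the same faiss-search-plus-positives loop with random descriptors; the dimensions, counts, and the random positives are placeholder assumptions, so it only serves to sanity-check the metric without a model or a dataset.

```python
# Toy sketch of the recall@N computation from test.py, with random data.
import faiss
import numpy as np

RECALL_VALUES = [1, 5, 10, 20]
dim, db_num, q_num = 128, 1000, 50  # arbitrary sizes for the sketch
database = np.random.rand(db_num, dim).astype(np.float32)
queries = np.random.rand(q_num, dim).astype(np.float32)
# One array of correct database indices per query (random placeholders here)
positives_per_query = [np.random.choice(db_num, 5, replace=False) for _ in range(q_num)]

index = faiss.IndexFlatL2(dim)
index.add(database)
_, predictions = index.search(queries, max(RECALL_VALUES))

recalls = np.zeros(len(RECALL_VALUES))
for query_index, preds in enumerate(predictions):
    for i, n in enumerate(RECALL_VALUES):
        if np.any(np.isin(preds[:n], positives_per_query[query_index])):
            recalls[i:] += 1  # a hit within the top-N also counts for every larger N
            break
recalls = recalls / q_num * 100
print(", ".join(f"R@{v}: {r:.1f}" for v, r in zip(RECALL_VALUES, recalls)))
```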
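train.py relies on cosface_loss.MarginCosineProduct, which is not reproduced in this excerpt. A CosFace-style classifier of this kind typically computes scaled cosine logits with an additive margin on the target class; the sketch below illustrates that idea under stated assumptions (it is not the repo's actual implementation, and the default s and m values are assumptions).

```python
# Illustrative CosFace-style classifier head (not the repo's cosface_loss.py).
import torch
import torch.nn as nn
import torch.nn.functional as F

class MarginCosineProductSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, s: float = 30.0, m: float = 0.40):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Cosine similarity between L2-normalized descriptors and class weights
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        # Subtract the margin only from the target-class logit, then scale
        one_hot = F.one_hot(targets, num_classes=cosine.shape[1]).to(cosine.dtype)
        return self.s * (cosine - one_hot * self.m)

# Usage mirroring train.py (targets must be int64 class indices):
#   logits = classifier(descriptors, targets)
#   loss = torch.nn.CrossEntropyLoss()(logits, targets)
```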
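Since util.save_checkpoint stores the best weights as best_model.pth and train.py reloads them with model.load_state_dict, a trained checkpoint can also be reused for plain descriptor extraction. A minimal sketch follows, assuming a "ResNet50" backbone with fc_output_dim=2048, a placeholder checkpoint and image path, and a 512x512 resize; only the normalization statistics are taken from gpu_augmentation above, everything else is an assumption to be matched to the actual training run.

```python
# Minimal sketch: reload a best_model.pth saved by util.save_checkpoint and
# extract a descriptor for a single image. Paths, backbone, output dim and
# resize are illustrative assumptions.
import torch
import torchvision.transforms as tfm
from PIL import Image
from eigenplaces_model import eigenplaces_network

model = eigenplaces_network.GeoLocalizationNet_("ResNet50", 2048)
model.load_state_dict(torch.load("path/to/best_model.pth", map_location="cpu"))
model = model.eval()

preprocess = tfm.Compose([
    tfm.ToTensor(),
    tfm.Resize([512, 512]),
    tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # same stats as gpu_augmentation
])

with torch.inference_mode():
    image = preprocess(Image.open("path/to/query.jpg").convert("RGB")).unsqueeze(0)
    descriptor = model(image)  # shape [1, fc_output_dim], ready for nearest-neighbor search
print(descriptor.shape)
```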