├── .gitignore
├── LICENSE
├── README.md
├── augmentations.py
├── commons.py
├── cosface_loss.py
├── datasets
│   ├── __init__.py
│   ├── dataset_utils.py
│   ├── eigenplaces_dataset.py
│   ├── map_utils.py
│   └── test_dataset.py
├── eigenplaces_model
│   ├── __init__.py
│   ├── eigenplaces_network.py
│   └── layers.py
├── eval.py
├── hubconf.py
├── parser.py
├── requirements.txt
├── test.py
├── train.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .spyproject
2 | .idea
3 | __pycache__
4 | logs
5 | cache
6 | jobs
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Gabriele Berton, Gabriele Trivigno, Carlo Masone, Barbara Caputo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # EigenPlaces: Training Viewpoint Robust Models for Visual Place Recognition
3 |
4 | [Papers With Code leaderboard: AmsterTime](https://paperswithcode.com/sota/visual-place-recognition-on-amstertime?p=eigenplaces-training-viewpoint-robust-models)
5 | [Papers With Code leaderboard: Eynsham](https://paperswithcode.com/sota/visual-place-recognition-on-eynsham?p=eigenplaces-training-viewpoint-robust-models)
6 | [Papers With Code leaderboard: Pittsburgh-30k](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-30k?p=eigenplaces-training-viewpoint-robust-models)
7 | [Papers With Code leaderboard: San Francisco](https://paperswithcode.com/sota/visual-place-recognition-on-san-francisco?p=eigenplaces-training-viewpoint-robust-models)
8 | [Papers With Code leaderboard: SF-XL test v1](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v1?p=eigenplaces-training-viewpoint-robust-models)
9 | [Papers With Code leaderboard: Tokyo 24/7](https://paperswithcode.com/sota/visual-place-recognition-on-tokyo247?p=eigenplaces-training-viewpoint-robust-models)
10 | [Papers With Code leaderboard: Pittsburgh-250k](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-250k?p=eigenplaces-training-viewpoint-robust-models)
11 | [Papers With Code leaderboard: SF-XL test v2](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v2?p=eigenplaces-training-viewpoint-robust-models)
12 |
13 | This is the official PyTorch implementation of the ICCV 2023 paper "EigenPlaces: Training Viewpoint Robust Models for Visual Place Recognition".
14 | The paper presents a new training method that provides the model with samples from multiple viewpoints, making it robust to camera viewpoint changes. It achieves state-of-the-art results on datasets with large viewpoint shifts between query and database images.
15 |
16 | Together with the paper, we also released a codebase to reproduce the results of all other baselines (NetVLAD, SFRS, Conv-AP, CosPlace, MixVPR), providing a standardized and fair evaluation framework: [https://github.com/gmberton/VPR-methods-evaluation](https://github.com/gmberton/VPR-methods-evaluation)
17 |
18 | [[ICCV 2023 Open Access](https://openaccess.thecvf.com/content/ICCV2023/html/Berton_EigenPlaces_Training_Viewpoint_Robust_Models_for_Visual_Place_Recognition_ICCV_2023_paper.html)] [[ArXiv](https://arxiv.org/abs/2308.10832)] [[BibTex](https://github.com/gmberton/EigenPlaces#cite)]
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | ## Train
29 | Training is performed on the SF-XL dataset, which you can download from [here](https://github.com/gmberton/CosPlace). Make sure to download the training panoramas, which EigenPlaces takes as input and automatically crops with the required orientation.
30 | After downloading the SF-XL dataset, simply run
31 |
32 | `$ python3 train.py --train_dataset_folder path/to/sf_xl/raw/train/panoramas --val_dataset_folder path/to/sf_xl/processed/val --test_dataset_folder path/to/sf_xl/processed/test`
33 |
34 | The script automatically splits SF-XL into CosPlace Groups, and saves the resulting object in the folder `cache`.
35 | By default, training is performed with a ResNet-18 backbone, 512-dimensional descriptors and AMP, which uses less than 8 GB of VRAM.
36 |
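The cached groups object can be inspected if needed; here is a minimal sketch (assuming the default `--M 15 --N 3 --min_images_per_class 5`, which determine the cache filename):

```
import torch

# classes_per_group: list of lists of class ids (one list per group)
# images_per_class: dict mapping each class id to the paths of its images
classes_per_group, images_per_class = torch.load("cache/sfxl_M15_N3_mipc5.torch")
print(f"{len(classes_per_group)} groups, {len(classes_per_group[0])} classes in the first group")
```
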
37 | To change the backbone or the output descriptor dimensionality, simply run something like this
38 |
39 | `$ python3 train.py --backbone ResNet50 --fc_output_dim 128`
40 |
41 | Run `$ python3 train.py -h` to see all the hyperparameters that you can change; among them you will find all the hyperparameters mentioned in the paper.
42 |
43 | ## Test
44 | You can test one of our trained models as follows (the model is automatically downloaded from torch.hub)
45 | 
46 | `$ python3 eval.py --test_dataset_folder path/to/sf_xl/processed/test --backbone ResNet50 --fc_output_dim 2048 --resume_model torchhub`
47 | 
48 | or a model trained by you, like this
49 | 
50 | `$ python3 eval.py --test_dataset_folder path/to/sf_xl/processed/test --backbone ResNet50 --fc_output_dim 2048 --resume_model path/to/best_model.pth`
51 |
52 | ## Trained Models
53 |
54 | All our trained models are available on [PyTorch Hub](https://pytorch.org/docs/stable/hub.html), so you can use them in any codebase without cloning this repository, simply like this
55 | ```
56 | import torch
57 | model = torch.hub.load("gmberton/eigenplaces", "get_trained_model", backbone="ResNet50", fc_output_dim=2048)
58 | ```
59 |
60 | Available trained models are ResNet18 (with output dim 256 or 512), ResNet50 (output dim 128, 256, 512, 1024 or 2048), ResNet101 (output dim 128, 256, 512, 1024 or 2048) and VGG16 (output dim 512).
61 |
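As a minimal usage sketch, the model maps a batch of RGB images to one L2-normalized descriptor of dimension `fc_output_dim` per image (the 512x512 resolution below is just an example; real images should be normalized with the ImageNet statistics used in `datasets/test_dataset.py`):

```
import torch

model = torch.hub.load("gmberton/eigenplaces", "get_trained_model", backbone="ResNet50", fc_output_dim=2048)
model = model.eval()

with torch.no_grad():
    # Dummy batch of two 512x512 RGB images; real images should be normalized with
    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], as in datasets/test_dataset.py
    images = torch.rand(2, 3, 512, 512)
    descriptors = model(images)

print(descriptors.shape)  # torch.Size([2, 2048]), one L2-normalized descriptor per image
```

The descriptors can then be compared with a nearest-neighbor index (e.g. faiss, as done in `test.py`) to retrieve the database images closest to each query.
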
62 |
63 | ## Acknowledgements
64 | Parts of this repo are inspired by the following repositories:
65 | - [CosFace implementation in PyTorch](https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py)
66 | - [CNN Image Retrieval in PyTorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) (for the GeM layer)
67 | - [Visual Geo-localization benchmark](https://github.com/gmberton/deep-visual-geo-localization-benchmark) (for the evaluation / test code)
68 | - [CosPlace](https://github.com/gmberton/CosPlace)
69 |
70 | ## Cite
71 | Here is the BibTeX to cite our paper
72 | ```
73 | @inproceedings{Berton_2023_EigenPlaces,
74 | title={EigenPlaces: Training Viewpoint Robust Models for Visual Place Recognition},
75 | author={Berton, Gabriele and Trivigno, Gabriele and Caputo, Barbara and Masone, Carlo},
76 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
77 | year={2023},
78 | month={October},
79 | pages={11080-11090}
80 | }
81 | ```
82 |
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from typing import Tuple, Union
4 | import torchvision.transforms as T
5 |
6 |
7 | class DeviceAgnosticColorJitter(T.ColorJitter):
8 | def __init__(self, brightness: float = 0., contrast: float = 0., saturation: float = 0., hue: float = 0.):
9 | """This is the same as T.ColorJitter but it only accepts batches of images and works on GPU"""
10 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
11 |
12 | def forward(self, images: torch.Tensor) -> torch.Tensor:
13 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}"
14 | B, C, H, W = images.shape
15 | # Applies a different color jitter to each image
16 | color_jitter = super(DeviceAgnosticColorJitter, self).forward
17 | augmented_images = [color_jitter(img).unsqueeze(0) for img in images]
18 | augmented_images = torch.cat(augmented_images)
19 | assert augmented_images.shape == torch.Size([B, C, H, W])
20 | return augmented_images
21 |
22 |
23 | class DeviceAgnosticRandomResizedCrop(T.RandomResizedCrop):
24 |     def __init__(self, size: Union[int, Tuple[int, int]], scale: Tuple[float, float]):
25 | """This is the same as T.RandomResizedCrop but it only accepts batches of images and works on GPU"""
26 | super().__init__(size=size, scale=scale, antialias=True)
27 |
28 | def forward(self, images: torch.Tensor) -> torch.Tensor:
29 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}"
30 | B, C, H, W = images.shape
31 |         # Applies a different random resized crop to each image
32 | random_resized_crop = super(DeviceAgnosticRandomResizedCrop, self).forward
33 | augmented_images = [random_resized_crop(img).unsqueeze(0) for img in images]
34 | augmented_images = torch.cat(augmented_images)
35 | return augmented_images
36 |
37 |
38 | if __name__ == "__main__":
39 | """
40 | You can run this script to visualize the transformations, and verify that
41 | the augmentations are applied individually on each image of the batch.
42 | """
43 | from PIL import Image
44 | # Import skimage in here, so it is not necessary to install it unless you run this script
45 | from skimage import data
46 |
47 | # Initialize DeviceAgnosticRandomResizedCrop
48 | random_crop = DeviceAgnosticRandomResizedCrop(size=[256, 256], scale=[0.5, 1])
49 | # Create a batch with 2 astronaut images
50 | pil_image = Image.fromarray(data.astronaut())
51 | tensor_image = T.functional.to_tensor(pil_image).unsqueeze(0)
52 | images_batch = torch.cat([tensor_image, tensor_image])
53 | # Apply augmentation (individually on each of the 2 images)
54 | augmented_batch = random_crop(images_batch)
55 | # Convert to PIL images
56 | augmented_image_0 = T.functional.to_pil_image(augmented_batch[0])
57 | augmented_image_1 = T.functional.to_pil_image(augmented_batch[1])
58 | # Visualize the original image, as well as the two augmented ones
59 | pil_image.show()
60 | augmented_image_0.show()
61 | augmented_image_1.show()
62 |
--------------------------------------------------------------------------------
/commons.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import torch
5 | import random
6 | import logging
7 | import traceback
8 | import numpy as np
9 |
10 |
11 | class InfiniteDataLoader(torch.utils.data.DataLoader):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 | self.dataset_iterator = super().__iter__()
15 |
16 | def __iter__(self):
17 | return self
18 |
19 | def __next__(self):
20 | try:
21 | batch = next(self.dataset_iterator)
22 | except StopIteration:
23 | self.dataset_iterator = super().__iter__()
24 | batch = next(self.dataset_iterator)
25 | return batch
26 |
27 |
28 | def make_deterministic(seed: int = 0):
29 | """Make results deterministic. If seed == -1, do not make deterministic.
30 | Running your script in a deterministic way might slow it down.
31 | Note that for some packages (eg: sklearn's PCA) this function is not enough.
32 | """
33 | seed = int(seed)
34 | if seed == -1:
35 | return
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed)
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 |
43 |
44 | def setup_logging(output_folder: str, exist_ok: bool = False, console: str = "debug",
45 | info_filename: str = "info.log", debug_filename: str = "debug.log"):
46 | """Set up logging files and console output.
47 | Creates one file for INFO logs and one for DEBUG logs.
48 | Args:
49 |         output_folder (str): the folder where the log files are saved (it is created if it does not exist).
50 |         exist_ok (bool): if False, throw a FileExistsError if output_folder already exists.
51 |         console (str):
52 | if == "debug" prints on console debug messages and higher
53 | if == "info" prints on console info messages and higher
54 | if == None does not use console (useful when a logger has already been set)
55 | info_filename (str): the name of the info file. if None, don't create info file
56 | debug_filename (str): the name of the debug file. if None, don't create debug file
57 | """
58 | if not exist_ok and os.path.exists(output_folder):
59 | raise FileExistsError(f"{output_folder} already exists!")
60 | os.makedirs(output_folder, exist_ok=True)
61 | import matplotlib
62 | logging.getLogger('matplotlib').disabled = True
63 | logging.getLogger('PIL').disabled = True
64 | logging.getLogger('staticmap').disabled = True
65 | logging.getLogger('requests').disabled = True
66 | logging.getLogger('PIL').setLevel(logging.WARNING)
67 | import requests
68 | logging.getLogger("requests").setLevel(logging.WARNING)
69 | import urllib3
70 | logging.getLogger("urllib3").setLevel(logging.WARNING)
71 |
72 | logging.getLogger('matplotlib.font_manager').disabled = True
73 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
74 | logger = logging.getLogger('')
75 | logger.setLevel(logging.DEBUG)
76 |
77 | if info_filename is not None:
78 | info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}')
79 | info_file_handler.setLevel(logging.INFO)
80 | info_file_handler.setFormatter(base_formatter)
81 | logger.addHandler(info_file_handler)
82 |
83 | if debug_filename is not None:
84 | debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}')
85 | debug_file_handler.setLevel(logging.DEBUG)
86 | debug_file_handler.setFormatter(base_formatter)
87 | logger.addHandler(debug_file_handler)
88 |
89 | if console is not None:
90 | console_handler = logging.StreamHandler()
91 | if console == "debug":
92 | console_handler.setLevel(logging.DEBUG)
93 | if console == "info":
94 | console_handler.setLevel(logging.INFO)
95 | console_handler.setFormatter(base_formatter)
96 | logger.addHandler(console_handler)
97 |
98 | def my_handler(type_, value, tb):
99 |         logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
100 | logging.info("Experiment finished (with some errors)")
101 | sys.excepthook = my_handler
102 |
--------------------------------------------------------------------------------
/cosface_loss.py:
--------------------------------------------------------------------------------
1 |
2 | # Based on https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import Parameter
7 |
8 |
9 | def cosine_sim(x1: torch.Tensor, x2: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
10 | ip = torch.mm(x1, x2.t())
11 | w1 = torch.norm(x1, 2, dim)
12 | w2 = torch.norm(x2, 2, dim)
13 | return ip / torch.ger(w1, w2).clamp(min=eps)
14 |
15 |
16 | class MarginCosineProduct(nn.Module):
17 |     """Implementation of the large margin cosine distance (CosFace loss).
18 | Args:
19 | in_features: size of each input sample
20 | out_features: size of each output sample
21 | s: norm of input feature
22 | m: margin
23 | """
24 | def __init__(self, in_features: int, out_features: int, s: float = 30.0, m: float = 0.40):
25 | super().__init__()
26 | self.in_features = in_features
27 | self.out_features = out_features
28 | self.s = s
29 | self.m = m
30 | self.weight = Parameter(torch.Tensor(out_features, in_features))
31 | nn.init.xavier_uniform_(self.weight)
32 |
33 | def forward(self, inputs: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
34 | cosine = cosine_sim(inputs, self.weight)
35 | one_hot = torch.zeros_like(cosine)
36 | one_hot.scatter_(1, label.view(-1, 1), 1.0)
37 | output = self.s * (cosine - one_hot * self.m)
38 | return output
39 |
40 | def __repr__(self):
41 | return self.__class__.__name__ + '(' \
42 | + 'in_features=' + str(self.in_features) \
43 | + ', out_features=' + str(self.out_features) \
44 | + ', s=' + str(self.s) \
45 | + ', m=' + str(self.m) + ')'
46 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gmberton/EigenPlaces/fef66475d45f9a65e7e10ba2c360c25396478f44/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 | from glob import glob
5 | from PIL import ImageFile
6 |
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 |
10 | def read_images_paths(dataset_folder, get_abs_path=False):
11 | """Find images within 'dataset_folder' and return their relative paths as a list.
12 | If there is a file 'dataset_folder'_images_paths.txt, read paths from such file.
13 | Otherwise, use glob(). Keeping the paths in the file speeds up computation,
14 | because using glob over large folders can be slow.
15 |
16 | Parameters
17 | ----------
18 | dataset_folder : str, folder containing JPEG images
19 | get_abs_path : bool, if True return absolute paths, otherwise remove
20 | dataset_folder from each path
21 |
22 | Returns
23 | -------
24 | images_paths : list[str], paths of JPEG images within dataset_folder
25 | """
26 |
27 | if not os.path.exists(dataset_folder):
28 | raise FileNotFoundError(f"Folder {dataset_folder} does not exist")
29 |
30 | file_with_paths = dataset_folder + "_images_paths.txt"
31 | if os.path.exists(file_with_paths):
32 | logging.debug(f"Reading paths of images within {dataset_folder} from {file_with_paths}")
33 | with open(file_with_paths, "r") as file:
34 | images_paths = file.read().splitlines()
35 | images_paths = [dataset_folder + "/" + path for path in images_paths]
36 | # Sanity check that paths within the file exist
37 | if not os.path.exists(images_paths[0]):
38 | raise FileNotFoundError(f"Image with path {images_paths[0]} "
39 | f"does not exist within {dataset_folder}. It is likely "
40 | f"that the content of {file_with_paths} is wrong.")
41 | else:
42 | logging.debug(f"Searching images in {dataset_folder} with glob()")
43 | images_paths = sorted(glob(f"{dataset_folder}/**/*.jpg", recursive=True))
44 | if len(images_paths) == 0:
45 | raise FileNotFoundError(f"Directory {dataset_folder} does not contain any JPEG images")
46 |
47 | if not get_abs_path: # Remove dataset_folder from the path
48 | images_paths = [p[len(dataset_folder) + 1:] for p in images_paths]
49 |
50 | return images_paths
51 |
52 |
--------------------------------------------------------------------------------
/datasets/eigenplaces_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import utm
4 | import math
5 | import torch
6 | import random
7 | import imageio
8 | import logging
9 | import numpy as np
10 | from PIL import Image
11 | from PIL import ImageFile
12 | import torchvision.transforms as tfm
13 | from collections import defaultdict
14 |
15 | from datasets.map_utils import create_map
16 | import datasets.dataset_utils as dataset_utils
17 |
18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
19 |
20 | PANO_WIDTH = int(512*6.5)
21 |
22 |
23 | def get_angle(focal_point, obs_point):
24 | obs_e, obs_n = float(obs_point[0]), float(obs_point[1])
25 | focal_e, focal_n = focal_point
26 | side1 = focal_e - obs_e
27 | side2 = focal_n - obs_n
28 | angle = - math.atan2(side1, side2) / math.pi * 90 * 2
29 | return angle
30 |
31 |
32 | def get_eigen_things(utm_coords):
33 | mu = utm_coords.mean(0)
34 | norm_data = utm_coords - mu
35 | eigenvectors, eigenvalues, v = np.linalg.svd(norm_data.T, full_matrices=False)
36 | return eigenvectors, eigenvalues, mu
37 |
38 |
39 | def rotate_2d_vector(vector, angle):
40 | assert vector.shape == (2,)
41 | theta = np.deg2rad(angle)
42 | rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
43 | [np.sin(theta), np.cos(theta)]])
44 | rotated_point = np.dot(rot_mat, vector)
45 | return rotated_point
46 |
47 |
48 | def get_focal_point(utm_coords, meters_from_center=20, angle=0):
49 | """Return the focal point from a set of utm coords"""
50 | B, D = utm_coords.shape
51 | assert D == 2
52 | eigenvectors, eigenvalues, mu = get_eigen_things(utm_coords)
53 |
54 | direction = rotate_2d_vector(eigenvectors[1], angle)
55 | focal_point = mu + direction * meters_from_center
56 | return focal_point
57 |
58 |
59 | class EigenPlacesDataset(torch.utils.data.Dataset):
60 | def __init__(self, dataset_folder,
61 | M=20, N=5, focal_dist=10, current_group=0,
62 | min_images_per_class=10, angle=0, visualize_classes=0):
63 | """
64 | Parameters (please check our paper for a clearer explanation of the parameters).
65 | ----------
66 | dataset_folder : str, the path of the folder with the train images.
67 | M : int, the length of the side of each cell in meters.
68 | N : int, distance (M-wise) between two classes of the same group.
69 | focal_dist : int, distance (in meters) between the center of the class and
70 | the focal point. The center of the class is computed as the
71 | mean of the positions of the images within the class.
72 | current_group : int, which one of the groups to consider.
73 |         min_images_per_class : int, minimum number of images in a class.
74 | angle : int, the angle formed between the line of the first principal
75 | component, and the line that connects the center of gravity of the
76 | images to the focal point.
77 |         visualize_classes : int, the number of classes for which to create
78 |             visualizations. The visualization of a class consists of its map and
79 |             the images belonging to it.
80 | """
81 | super().__init__()
82 | self.M = M
83 | self.N = N
84 | self.focal_dist = focal_dist
85 | self.current_group = current_group
86 | self.dataset_folder = dataset_folder
87 |
88 | filename = f"cache/sfxl_M{M}_N{N}_mipc{min_images_per_class}.torch"
89 | if not os.path.exists(filename):
90 | os.makedirs("cache", exist_ok=True)
91 | logging.info(f"Cached dataset {filename} does not exist, I'll create it now.")
92 | self.initialize(dataset_folder, M, N, min_images_per_class, filename)
93 | elif current_group == 0:
94 | logging.info(f"Using cached dataset {filename}")
95 |
96 | classes_per_group, self.images_per_class = torch.load(filename)
97 | if current_group >= len(classes_per_group):
98 | raise ValueError(f"With this configuration there are only {len(classes_per_group)} " +
99 | f"groups, therefore I can't create the {current_group}th group. " +
100 | "You should reduce the number of groups by setting for example " +
101 | f"'--groups_num {current_group}'")
102 | self.classes_ids = classes_per_group[current_group]
103 |
104 | new_classes_ids = []
105 | self.focal_point_per_class = {}
106 | for class_id in self.classes_ids:
107 | paths = self.images_per_class[class_id]
108 | u_coords = np.array([p.split("@")[1:3] for p in paths]).astype(float)
109 |
110 | focal_point = get_focal_point(u_coords, focal_dist, angle=angle)
111 | new_classes_ids.append(class_id)
112 | self.focal_point_per_class[class_id] = focal_point
113 |
114 | self.classes_ids = new_classes_ids
115 |
116 | # This is only for logging, debugging and visualizations
117 | for class_num in range(visualize_classes):
118 | random_class_id = random.choice(self.classes_ids)
119 | paths = self.images_per_class[random_class_id]
120 | focal_point = self.focal_point_per_class[random_class_id]
121 | focal_point_lat_lon = np.array(utm.to_latlon(focal_point[0], focal_point[1], 10, 'S'))
122 | lats_lons = np.array([p.split("@")[5:7] for p in paths]).astype(float)
123 | lats_lons += (np.random.randn(*lats_lons.shape) / 500000) # Add a little noise to avoid overlapping
124 |
125 | min_e, min_n = random_class_id
126 | cell_utms = (min_e, min_n), (min_e, min_n + M), (min_e + M, min_n + M), (min_e + M, min_n)
127 | cell_corners = np.array([utm.to_latlon(*u, 10, 'S') for u in cell_utms])
128 |
129 | output_folder = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
130 | folder = f"{output_folder}/visualizations/group{current_group}_{class_num}_{random_class_id}"
131 | os.makedirs(folder)
132 | try:
133 | img_map = create_map([lats_lons, lats_lons.mean(0).reshape(1, 2), focal_point_lat_lon.reshape(1, 2), cell_corners],
134 | colors=["r", "b", "g", "orange"],
135 | legend_names=["images position", "center of mass",
136 | "focal point", f"cell corners ({M} meters)"],
137 | dot_sizes=[10, 100, 100, 100])
138 | imageio.imsave(f"{folder}/@00_map.jpg", img_map)
139 | except RuntimeError:
140 | # Sometimes there are errors due to staticmap (komoot) servers
141 | logging.warn("There was some problem while downloading the map of the class for visualization. "
142 | "This will not influence training.")
143 |
144 | images_paths = self.images_per_class[random_class_id]
145 | for path in images_paths:
146 | crop = self.get_crop(self.dataset_folder + "/" + path, focal_point)
147 | crop = tfm.functional.to_pil_image(crop)
148 | crop.save(f"{folder}/{os.path.basename(path)}")
149 |
150 | @staticmethod
151 | def get_crop(pano_path, focal_point):
152 | obs_point = pano_path.split("@")[1:3]
153 | angle = - get_angle(focal_point, obs_point) % 360
154 | crop_offset = int((angle / 360 * PANO_WIDTH) % PANO_WIDTH)
155 | yaw = int(pano_path.split("@")[9])
156 | north_yaw_in_degrees = (180-yaw) % 360
157 | yaw_offset = int((north_yaw_in_degrees / 360) * PANO_WIDTH)
158 | offset = (yaw_offset + crop_offset - 256) % PANO_WIDTH
159 | pano_pil = Image.open(pano_path)
160 | if offset + 512 <= PANO_WIDTH:
161 | pil_crop = pano_pil.crop((offset, 0, offset + 512, 512))
162 | else:
163 | crop1 = pano_pil.crop((offset, 0, PANO_WIDTH, 512))
164 | crop2 = pano_pil.crop((0, 0, 512 - (PANO_WIDTH - offset), 512))
165 | pil_crop = Image.new('RGB', (512, 512))
166 | pil_crop.paste(crop1, (0, 0))
167 | pil_crop.paste(crop2, (crop1.size[0], 0))
168 | crop = tfm.functional.to_tensor(pil_crop)
169 |
170 | return crop
171 |
172 | def __getitem__(self, class_num):
173 | # This function takes as input the class_num instead of the index of
174 | # the image. This way each class is equally represented during training.
175 | class_id = self.classes_ids[class_num]
176 | focal_point = self.focal_point_per_class[class_id]
177 | pano_path = self.dataset_folder + "/" + random.choice(self.images_per_class[class_id])
178 | crop = self.get_crop(pano_path, focal_point)
179 | return crop, class_num, pano_path
180 |
181 | def get_images_num(self):
182 | """Return the number of images within this group."""
183 | return sum([len(self.images_per_class[c]) for c in self.classes_ids])
184 |
185 | def __len__(self):
186 | """Return the number of classes within this group."""
187 | return len(self.classes_ids)
188 |
189 | @staticmethod
190 | def initialize(dataset_folder, M, N, min_images_per_class, filename):
191 | logging.debug(f"Searching training images in {dataset_folder}")
192 |
193 | images_paths = dataset_utils.read_images_paths(dataset_folder)
194 | logging.debug(f"Found {len(images_paths)} images")
195 |
196 | logging.debug("For each image, get its UTM east, UTM north from its path")
197 | images_metadatas = [p.split("@") for p in images_paths]
198 | # field 1 is UTM east, field 2 is UTM north
199 | utmeast_utmnorth = [(m[1], m[2]) for m in images_metadatas]
200 | utmeast_utmnorth = np.array(utmeast_utmnorth).astype(float)
201 |
202 | logging.debug("For each image, get class and group to which it belongs")
203 | class_id__group_id = [EigenPlacesDataset.get__class_id__group_id(*m, M, N)
204 | for m in utmeast_utmnorth]
205 |
206 | logging.debug("Group together images belonging to the same class")
207 | images_per_class = defaultdict(list)
208 | for image_path, (class_id, _) in zip(images_paths, class_id__group_id):
209 | images_per_class[class_id].append(image_path)
210 |
211 | # Images_per_class is a dict where the key is class_id, and the value
212 | # is a list with the paths of images within that class.
213 | images_per_class = {k: v for k, v in images_per_class.items() if len(v) >= min_images_per_class}
214 |
215 | logging.debug("Group together classes belonging to the same group")
216 | # Classes_per_group is a dict where the key is group_id, and the value
217 | # is a list with the class_ids belonging to that group.
218 | classes_per_group = defaultdict(set)
219 | for class_id, group_id in class_id__group_id:
220 | if class_id not in images_per_class:
221 | continue # Skip classes with too few images
222 | classes_per_group[group_id].add(class_id)
223 |
224 | # Convert classes_per_group to a list of lists.
225 | # Each sublist represents the classes within a group.
226 | classes_per_group = [list(c) for c in classes_per_group.values()]
227 |
228 | torch.save((classes_per_group, images_per_class), filename)
229 |
230 | @staticmethod
231 | def get__class_id__group_id(utm_east, utm_north, M, N):
232 | """Return class_id and group_id for a given point.
233 |         The class_id is a pair (tuple) of UTM_east, UTM_north
234 |         (e.g. (396520, 4983800)).
235 |         The group_id represents the group to which the class belongs
236 |         (e.g. (0, 1)), and it is between (0, 0) and (N-1, N-1).
237 | """
238 | rounded_utm_east = int(utm_east // M * M) # Rounded to nearest lower multiple of M
239 | rounded_utm_north = int(utm_north // M * M)
240 |
241 | class_id = (rounded_utm_east, rounded_utm_north)
242 |         # group_id goes from (0, 0) to (N-1, N-1)
243 | group_id = (rounded_utm_east % (M * N) // M,
244 | rounded_utm_north % (M * N) // M)
245 | return class_id, group_id
246 |
--------------------------------------------------------------------------------
/datasets/map_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import math
4 | import imageio
5 | import numpy as np
6 | import geopy.distance
7 | import matplotlib.pyplot as plt
8 | from staticmap import StaticMap, Polygon
9 |
10 |
11 | def _lon_to_x(lon, zoom):
12 | if not (-180 <= lon <= 180): lon = (lon + 180) % 360 - 180
13 | return ((lon + 180.) / 360) * pow(2, zoom)
14 |
15 |
16 | def _lat_to_y(lat, zoom):
17 | if not (-90 <= lat <= 90): lat = (lat + 90) % 180 - 90
18 | return (1 - math.log(math.tan(lat * math.pi / 180) + 1 / math.cos(lat * math.pi / 180)) / math.pi) / 2 * pow(2, zoom)
19 |
20 |
21 | def _download_map_image(min_lat=45.0, min_lon=7.6, max_lat=45.1, max_lon=7.7, size=2000):
22 |     """Download a map of the chosen area as a numpy image."""
23 | mean_lat = (min_lat + max_lat) / 2
24 | mean_lon = (min_lon + max_lon) / 2
25 | static_map = StaticMap(size, size)
26 | static_map.add_polygon(
27 | Polygon(((min_lon, min_lat), (min_lon, max_lat), (max_lon, max_lat), (max_lon, min_lat)), None, '#FFFFFF'))
28 | zoom = static_map._calculate_zoom()
29 |
30 | # print(((min_lat, min_lon), (max_lat, max_lon)))
31 | dist = geopy.distance.geodesic((min_lat, min_lon), (max_lat, max_lon)).m
32 | if dist < 50:
33 | zoom = 22
34 | else:
35 | zoom = 20 # static_map._calculate_zoom()
36 | static_map = StaticMap(size, size)
37 | image = static_map.render(zoom, [mean_lon, mean_lat])
38 | # print(f"You can see the map on Google Maps at this link www.google.com/maps/place/@{mean_lat},{mean_lon},{zoom - 1}z")
39 | min_lat_px, min_lon_px, max_lat_px, max_lon_px = \
40 | static_map._y_to_px(_lat_to_y(min_lat, zoom)), \
41 | static_map._x_to_px(_lon_to_x(min_lon, zoom)), \
42 | static_map._y_to_px(_lat_to_y(max_lat, zoom)), \
43 | static_map._x_to_px(_lon_to_x(max_lon, zoom))
44 | assert 0 <= max_lat_px < min_lat_px < size and 0 <= min_lon_px < max_lon_px < size
45 | return np.array(image)[max_lat_px:min_lat_px, min_lon_px:max_lon_px], static_map, zoom
46 |
47 |
48 | def get_edges(coordinates, enlarge=0):
49 | """
50 |     Return the edges of the coordinates, i.e. the southernmost, westernmost,
51 |     northernmost and easternmost coordinates.
52 | :param coordinates: A list of numpy.arrays of shape (Nx2)
53 | :param float enlarge: How much to increase the coordinates, to enlarge
54 | the area included between the points
55 |     :return: a tuple with the four floats
56 | """
57 | min_lat, min_lon, max_lat, max_lon = (*np.concatenate(coordinates).min(0), *np.concatenate(coordinates).max(0))
58 | diff_lat = (max_lat - min_lat) * enlarge
59 | diff_lon = (max_lon - min_lon) * enlarge
60 | inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon = \
61 | min_lat - diff_lat, min_lon - diff_lon, max_lat + diff_lat, max_lon + diff_lon
62 | return inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon
63 |
64 |
65 | def create_map(coordinates, colors=None, dot_sizes=None, legend_names=None, map_intensity=0.6):
66 |
67 | dot_sizes = dot_sizes if dot_sizes is not None else [10] * len(coordinates)
68 | colors = colors if colors is not None else ["r"] * len(coordinates)
69 | assert len(coordinates) == len(dot_sizes) == len(colors), \
70 |         f"The number of coordinates must be equal to the number of colors and dot_sizes, but they are " \
71 | f"{len(coordinates)}, {len(colors)}, {len(dot_sizes)}"
72 |
73 | # Add two dummy points to slightly enlarge the map
74 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates, enlarge=0.1)
75 | coordinates.append(np.array([[min_lat, min_lon], [max_lat, max_lon]]))
76 | # Download the map of the chosen area
77 | map_img, static_map, zoom = _download_map_image(min_lat, min_lon, max_lat, max_lon)
78 |
79 | scatters = []
80 | fig = plt.figure(figsize=(map_img.shape[1] / 100, map_img.shape[0] / 100), dpi=1000)
81 |     for coord in coordinates:
82 |         for i in range(len(coord)):  # Scale latitudes because of earth's curvature
83 | coord[i, 0] = -static_map._y_to_px(_lat_to_y(coord[i, 0], zoom))
84 | for coord, size, color in zip(coordinates, dot_sizes, colors):
85 | scatters.append(plt.scatter(coord[:, 1], coord[:, 0], s=size, color=color))
86 |
87 |     if legend_names is not None:
88 | plt.legend(scatters, legend_names, scatterpoints=1, loc='best',
89 | ncol=1, framealpha=0, prop={"weight": "bold", "size": 20})
90 |
91 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates)
92 | plt.ylim(min_lat, max_lat)
93 | plt.xlim(min_lon, max_lon)
94 | fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
95 | fig.canvas.draw()
96 | plot_img = np.array(fig.canvas.renderer._renderer)
97 | plt.close()
98 |
99 | plot_img = cv2.resize(plot_img[:, :, :3], map_img.shape[:2][::-1], interpolation=cv2.INTER_LANCZOS4)
100 | map_img[(map_img.sum(2) < 444)] = 188 # brighten dark pixels
101 | map_img = (((map_img / 255) ** map_intensity) * 255).astype(np.uint8) # fade map
102 | mask = (plot_img.sum(2) == 255 * 3)[:, :, None] # mask of plot, to find white pixels
103 | final_map = map_img * mask + plot_img * (~mask)
104 | return final_map
105 |
106 |
107 | if __name__ == "__main__":
108 | # Create a map containing major cities of Italy, Germany and France
109 | coordinates = [
110 | np.array([[41.8931, 12.4828], [45.4669, 9.1900], [40.8333, 14.2500]]),
111 | np.array([[52.5200, 13.4050], [48.7775, 9.1800], [48.1375, 11.5750]]),
112 | np.array([[48.8567, 2.3522], [43.2964, 5.3700], [45.7600, 4.8400]])
113 | ]
114 | map_img = create_map(
115 | coordinates,
116 | colors=["green", "black", "blue"],
117 | dot_sizes=[1000, 1000, 1000],
118 | legend_names=[
119 | "Main Italian Cities",
120 | "Main German Cities",
121 | "Main French Cities",
122 | ])
123 |
124 | imageio.imsave("cities.png", map_img)
125 |
126 |
--------------------------------------------------------------------------------
/datasets/test_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | from sklearn.neighbors import NearestNeighbors
8 |
9 | import datasets.dataset_utils as dataset_utils
10 |
11 |
12 | class TestDataset(data.Dataset):
13 | def __init__(self, dataset_folder, database_folder="database",
14 | queries_folder="queries", positive_dist_threshold=25):
15 | """Dataset with images from database and queries, used for validation and test.
16 | Parameters
17 | ----------
18 | dataset_folder : str, should contain the path to the val or test set,
19 | which contains the folders {database_folder} and {queries_folder}.
20 | database_folder : str, name of folder with the database.
21 | queries_folder : str, name of folder with the queries.
22 | positive_dist_threshold : int, distance in meters for a prediction to
23 | be considered a positive.
24 | """
25 | super().__init__()
26 |
27 | self.database_folder = os.path.join(dataset_folder, database_folder)
28 | self.queries_folder = os.path.join(dataset_folder, queries_folder)
29 | self.database_paths = dataset_utils.read_images_paths(self.database_folder, get_abs_path=True)
30 | self.queries_paths = dataset_utils.read_images_paths(self.queries_folder, get_abs_path=True)
31 |
32 | self.dataset_name = os.path.basename(dataset_folder)
33 |
34 | #### Read paths and UTM coordinates for all images.
35 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg
36 | self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float)
37 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float)
38 |
39 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters)
40 | knn = NearestNeighbors(n_jobs=-1)
41 | knn.fit(self.database_utms)
42 | self.positives_per_query = knn.radius_neighbors(
43 | self.queries_utms, radius=positive_dist_threshold, return_distance=False
44 | )
45 |
46 | self.images_paths = self.database_paths + self.queries_paths
47 |
48 | self.database_num = len(self.database_paths)
49 | self.queries_num = len(self.queries_paths)
50 |
51 | self.base_transform = transforms.Compose([
52 | transforms.ToTensor(),
53 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
54 | ])
55 |
56 | def __getitem__(self, index):
57 | image_path = self.images_paths[index]
58 | pil_img = Image.open(image_path)
59 | normalized_img = self.base_transform(pil_img)
60 | return normalized_img, index
61 |
62 | def __len__(self):
63 | return len(self.images_paths)
64 |
65 | def __repr__(self):
66 | return f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >"
67 |
68 | def get_positives(self):
69 | return self.positives_per_query
70 |
--------------------------------------------------------------------------------
/eigenplaces_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gmberton/EigenPlaces/fef66475d45f9a65e7e10ba2c360c25396478f44/eigenplaces_model/__init__.py
--------------------------------------------------------------------------------
/eigenplaces_model/eigenplaces_network.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import logging
4 | import torchvision
5 | from torch import nn
6 | from typing import Tuple
7 |
8 | from eigenplaces_model.layers import Flatten, L2Norm, GeM
9 |
10 | # The number of channels in the last convolutional layer, the one before average pooling
11 | CHANNELS_NUM_IN_LAST_CONV = {
12 | "ResNet18": 512,
13 | "ResNet50": 2048,
14 | "ResNet101": 2048,
15 | "ResNet152": 2048,
16 | "VGG16": 512,
17 | }
18 |
19 |
20 | class GeoLocalizationNet_(nn.Module):
21 | def __init__(self, backbone : str, fc_output_dim : int):
22 |         """Return a model for geo-localization.
23 |
24 | Args:
25 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet.
26 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to the descriptors dimension.
27 | """
28 | super().__init__()
29 | assert backbone in CHANNELS_NUM_IN_LAST_CONV, f"backbone must be one of {list(CHANNELS_NUM_IN_LAST_CONV.keys())}"
30 | self.backbone, features_dim = _get_backbone(backbone)
31 | self.aggregation = nn.Sequential(
32 | L2Norm(),
33 | GeM(),
34 | Flatten(),
35 | nn.Linear(features_dim, fc_output_dim),
36 | L2Norm()
37 | )
38 |
39 | def forward(self, x):
40 | x = self.backbone(x)
41 | x = self.aggregation(x)
42 | return x
43 |
44 |
45 | def _get_torchvision_model(backbone_name : str) -> torch.nn.Module:
46 |     """This function takes the name of a backbone and returns the corresponding torchvision
47 |     model with randomly initialized weights. Examples of backbone_name are 'VGG16' or 'ResNet18'.
48 | """
49 | return getattr(torchvision.models, backbone_name.lower())()
50 |
51 |
52 | def _get_backbone(backbone_name : str) -> Tuple[torch.nn.Module, int]:
53 | backbone = _get_torchvision_model(backbone_name)
54 |
55 | logging.info("Loading pretrained backbone's weights from CosPlace")
56 | cosplace = torch.hub.load("gmberton/cosplace", "get_trained_model", backbone=backbone_name, fc_output_dim=512)
57 | new_sd = {k1: v2 for (k1, v1), (k2, v2) in zip(backbone.state_dict().items(), cosplace.state_dict().items())
58 | if v1.shape == v2.shape}
59 | backbone.load_state_dict(new_sd, strict=False)
60 |
61 | if backbone_name.startswith("ResNet"):
62 | for name, child in backbone.named_children():
63 |             if name == "layer3":  # Freeze all layers before layer3
64 | break
65 | for params in child.parameters():
66 | params.requires_grad = False
67 | logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones")
68 | layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer
69 |
70 | elif backbone_name == "VGG16":
71 | layers = list(backbone.features.children())[:-2] # Remove avg pooling and FC layer
72 | for layer in layers[:-5]:
73 | for p in layer.parameters():
74 | p.requires_grad = False
75 | logging.debug("Train last layers of the VGG-16, freeze the previous ones")
76 |
77 | backbone = torch.nn.Sequential(*layers)
78 |
79 | features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name]
80 |
81 | return backbone, features_dim
82 |
83 |
--------------------------------------------------------------------------------
/eigenplaces_model/layers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.parameter import Parameter
6 |
7 |
8 | def gem(x, p=torch.ones(1)*3, eps: float = 1e-6):
9 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
10 |
11 |
12 | class GeM(nn.Module):
13 | def __init__(self, p=3, eps=1e-6):
14 | super().__init__()
15 | self.p = Parameter(torch.ones(1)*p)
16 | self.eps = eps
17 |
18 | def forward(self, x):
19 | return gem(x, p=self.p, eps=self.eps)
20 |
21 | def __repr__(self):
22 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})"
23 |
24 |
25 | class Flatten(torch.nn.Module):
26 | def __init__(self):
27 | super().__init__()
28 |
29 | def forward(self, x):
30 | assert x.shape[2] == x.shape[3] == 1, f"{x.shape[2]} != {x.shape[3]} != 1"
31 | return x[:, :, 0, 0]
32 |
33 |
34 | class L2Norm(nn.Module):
35 | def __init__(self, dim=1):
36 | super().__init__()
37 | self.dim = dim
38 |
39 | def forward(self, x):
40 | return F.normalize(x, p=2.0, dim=self.dim)
41 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import torch
4 | import logging
5 | import multiprocessing
6 | from datetime import datetime
7 |
8 | import test
9 | import parser
10 | import commons
11 | from datasets.test_dataset import TestDataset
12 | from eigenplaces_model import eigenplaces_network
13 |
14 | torch.backends.cudnn.benchmark = True # Provides a speedup
15 |
16 | args = parser.parse_arguments()
17 | start_time = datetime.now()
18 | output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
19 | commons.make_deterministic(args.seed)
20 | commons.setup_logging(output_folder, console="info")
21 | logging.info(" ".join(sys.argv))
22 | logging.info(f"Arguments: {args}")
23 | logging.info(f"The outputs are being saved in {output_folder}")
24 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.")
25 |
26 | #### Model
27 | if args.resume_model == "torchhub":
28 | model = torch.hub.load("gmberton/eigenplaces", "get_trained_model",
29 | backbone=args.backbone, fc_output_dim=args.fc_output_dim)
30 | else:
31 | model = eigenplaces_network.GeoLocalizationNet_(args.backbone, args.fc_output_dim)
32 |
33 | if args.resume_model is not None and args.resume_model != "torchhub":
34 |     logging.info(f"Loading model from {args.resume_model}")
35 |     model_state_dict = torch.load(args.resume_model)
36 |     model.load_state_dict(model_state_dict)
37 | elif args.resume_model is None:
38 |     logging.info("WARNING: You didn't provide a path to resume the model (--resume_model parameter). " +
39 |                  "Evaluation will be computed using randomly initialized weights.")
40 |
41 | model = model.to(args.device)
42 |
43 | test_ds = TestDataset(args.test_dataset_folder, queries_folder="queries",
44 | positive_dist_threshold=args.positive_dist_threshold)
45 |
46 | recalls, recalls_str = test.test(args, test_ds, model)
47 | logging.info(f"{test_ds}: {recalls_str}")
48 |
49 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 |
2 | dependencies = ['torch', 'torchvision']
3 |
4 | import torch
5 | from eigenplaces_model import eigenplaces_network
6 |
7 |
8 | AVAILABLE_TRAINED_MODELS = {
9 | # backbone : list of available fc_output_dim, which is equivalent to descriptors dimensionality
10 | "VGG16": [ 512],
11 | "ResNet18": [ 256, 512],
12 | "ResNet50": [128, 256, 512, 1024, 2048],
13 | "ResNet101": [128, 256, 512, 1024, 2048],
14 | }
15 |
16 |
17 | def get_trained_model(backbone : str = "ResNet50", fc_output_dim : int = 2048) -> torch.nn.Module:
18 | """Return a model trained with EigenPlaces on San Francisco eXtra Large.
19 |
20 | Args:
21 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet.
22 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to
23 |             the descriptors dimension. Must be one of the values listed in AVAILABLE_TRAINED_MODELS for the chosen backbone.
24 |
25 | Return:
26 | model (torch.nn.Module): a trained model.
27 | """
28 |     print(f"Returning EigenPlaces model with backbone {backbone} and descriptor dimension {fc_output_dim}")
29 | if backbone not in AVAILABLE_TRAINED_MODELS:
30 | raise ValueError(f"Parameter `backbone` is set to {backbone} but it must be one of {list(AVAILABLE_TRAINED_MODELS.keys())}")
31 | try:
32 | fc_output_dim = int(fc_output_dim)
33 |     except (TypeError, ValueError):
34 | raise ValueError(f"Parameter `fc_output_dim` must be an integer, but it is set to {fc_output_dim}")
35 | if fc_output_dim not in AVAILABLE_TRAINED_MODELS[backbone]:
36 | raise ValueError(f"Parameter `fc_output_dim` is set to {fc_output_dim}, but for backbone {backbone} "
37 | f"it must be one of {list(AVAILABLE_TRAINED_MODELS[backbone])}")
38 | model = eigenplaces_network.GeoLocalizationNet_(backbone, fc_output_dim)
39 | model.load_state_dict(
40 | torch.hub.load_state_dict_from_url(
41 | f'https://github.com/gmberton/EigenPlaces/releases/download/v1.0/{backbone}_{fc_output_dim}_eigenplaces.pth',
42 | map_location=torch.device('cpu'))
43 | )
44 | return model
--------------------------------------------------------------------------------
/parser.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 |
4 |
5 | def parse_arguments():
6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7 | # CosPlace Groups parameters
8 | parser.add_argument("--M", type=int, default=15, help="_")
9 | parser.add_argument("--N", type=int, default=3, help="_")
10 | parser.add_argument("--focal_dist", type=int, default=10, help="_") # done GS
11 | parser.add_argument("--s", type=float, default=100, help="_")
12 | parser.add_argument("--m", type=float, default=0.4, help="_")
13 | parser.add_argument("--lambda_lat", type=float, default=1., help="_")
14 | parser.add_argument("--lambda_front", type=float, default=1., help="_")
15 | parser.add_argument("--groups_num", type=int, default=0,
16 | help="If set to 0 use N*N groups")
17 |
18 | parser.add_argument("--min_images_per_class", type=int, default=5, help="_")
19 | # Model parameters
20 | parser.add_argument("--backbone", type=str, default="ResNet18",
21 | choices=["VGG16", "ResNet18", "ResNet50", "ResNet101", "ResNet152"], help="_")
22 | parser.add_argument("--fc_output_dim", type=int, default=512,
23 | help="Output dimension of final fully connected layer")
24 | # Training parameters
25 | parser.add_argument("--batch_size", type=int, default=32, help="_")
26 | parser.add_argument("--epochs_num", type=int, default=40, help="_")
27 | parser.add_argument("--iterations_per_epoch", type=int, default=5000, help="_")
28 | parser.add_argument("--lr", type=float, default=0.00001, help="_")
29 | parser.add_argument("--classifiers_lr", type=float, default=0.01, help="_")
30 | # Data augmentation
31 | parser.add_argument("--brightness", type=float, default=0.7, help="_")
32 | parser.add_argument("--contrast", type=float, default=0.7, help="_")
33 | parser.add_argument("--hue", type=float, default=0.5, help="_")
34 | parser.add_argument("--saturation", type=float, default=0.7, help="_")
35 | parser.add_argument("--random_resized_crop", type=float, default=0.5, help="_")
36 |
37 | # Validation / test parameters
38 | parser.add_argument("--infer_batch_size", type=int, default=16,
39 | help="Batch size for inference (validating and testing)")
40 | parser.add_argument("--positive_dist_threshold", type=int, default=25,
41 | help="distance in meters for a prediction to be considered a positive")
42 | # Resume parameters
43 | parser.add_argument("--resume_train", type=str, default=None,
44 | help="path to checkpoint to resume, e.g. logs/.../last_checkpoint.pth")
45 | parser.add_argument("--resume_model", type=str, default=None,
46 | help="path to model_ to resume, e.g. logs/.../best_model.pth. "
47 | "Use \"torchhub\" if you want to use one of our pretrained models")
48 | # Other parameters
49 | parser.add_argument("--device", type=str, default="cuda",
50 | choices=["cuda", "cpu"], help="_")
51 | parser.add_argument("--seed", type=int, default=0, help="_")
52 | parser.add_argument("--num_workers", type=int, default=8, help="_")
53 | parser.add_argument("--visualize_classes", type=int, default=0,
54 | help="Save map visualizations for X classes in the save_dir")
55 |
56 | # Paths parameters
57 | parser.add_argument("--train_dataset_folder", type=str, default=None,
58 | help="path of the folder with training images")
59 | parser.add_argument("--val_dataset_folder", type=str, default=None,
60 | help="path of the folder with val images (split in database/queries)")
61 | parser.add_argument("--test_dataset_folder", type=str, default=None,
62 | help="path of the folder with test images (split in database/queries)")
63 | parser.add_argument("--save_dir", type=str, default="default",
64 | help="name of directory on which to save the logs, under logs/save_dir")
65 |
66 | args = parser.parse_args()
67 | if args.groups_num == 0:
68 | args.groups_num = args.N * args.N
69 |
70 | return args
71 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | geopy>=2.3.0
2 | imageio>=2.31.1
3 | matplotlib>=3.7.2
4 | numpy>=1.25.1
5 | Pillow==9.0.0
6 | Requests>=2.31.0
7 | scikit_learn>=1.3.0
8 | scikit_image
9 | staticmap>=0.5.5
10 | torch>=2.0.1
11 | torchmetrics>=1.0.1
12 | torchvision>=0.15.2
13 | tqdm>=4.65.0
14 | urllib3>=2.0.3
15 | utm>=0.7.0
16 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import faiss
3 | import torch
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 | from typing import Tuple
8 | from argparse import Namespace
9 | from torch.utils.data.dataset import Subset
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 |
13 | # Compute R@1, R@5, R@10, R@20
14 | RECALL_VALUES = [1, 5, 10, 20]
15 |
16 |
17 | def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module, batchify : bool = False) -> Tuple[np.ndarray, str]:
18 | """Compute descriptors of the given dataset and compute the recalls."""
19 |
20 | model = model.eval()
21 | with torch.no_grad():
22 | logging.debug("Extracting database descriptors for evaluation/testing")
23 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
24 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
25 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
26 | all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32")
27 | for images, indices in tqdm(database_dataloader, ncols=100):
28 | descriptors = model(images.to(args.device))
29 | descriptors = descriptors.cpu().numpy()
30 | all_descriptors[indices.numpy(), :] = descriptors
31 |
32 | logging.debug("Extracting queries descriptors for evaluation/testing")
33 | if batchify:
34 | queries_infer_batch_size = args.infer_batch_size
35 | else:
36 | queries_infer_batch_size = 1
37 | queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))
38 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
39 | batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda"))
40 | for images, indices in tqdm(queries_dataloader, ncols=100):
41 | descriptors = model(images.to(args.device))
42 | descriptors = descriptors.cpu().numpy()
43 | all_descriptors[indices.numpy(), :] = descriptors
44 |
45 | queries_descriptors = all_descriptors[eval_ds.database_num:]
46 | database_descriptors = all_descriptors[:eval_ds.database_num]
47 |
48 | # Use a kNN to find predictions
49 | faiss_index = faiss.IndexFlatL2(args.fc_output_dim)
50 | faiss_index.add(database_descriptors)
51 | del database_descriptors, all_descriptors
52 |
53 | logging.debug("Calculating recalls")
54 | _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))
55 |
56 | #### For each query, check if the predictions are correct
57 | positives_per_query = eval_ds.get_positives()
58 | recalls = np.zeros(len(RECALL_VALUES))
59 | for query_index, preds in enumerate(predictions):
60 | for i, n in enumerate(RECALL_VALUES):
61 | if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
62 | recalls[i:] += 1
63 | break
64 | # Divide by queries_num and multiply by 100, so the recalls are in percentages
65 | recalls = recalls / eval_ds.queries_num * 100
66 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])
67 | return recalls, recalls_str
68 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import torch
4 | import logging
5 | import torchmetrics
6 | from tqdm import tqdm
7 | import multiprocessing
8 | from datetime import datetime
9 | import torchvision.transforms as tfm
10 |
11 | import test
12 | import util
13 | import parser
14 | import commons
15 | import cosface_loss
16 | import augmentations
17 | from eigenplaces_model import eigenplaces_network
18 | from datasets.test_dataset import TestDataset
19 | from datasets.eigenplaces_dataset import EigenPlacesDataset
20 |
21 | torch.backends.cudnn.benchmark = True # Provides a speedup
22 |
23 | args = parser.parse_arguments()
24 | start_time = datetime.now()
25 | output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
26 | commons.make_deterministic(args.seed)
27 | commons.setup_logging(output_folder, console="debug")
28 | logging.info(" ".join(sys.argv))
29 | logging.info(f"Arguments: {args}")
30 | logging.info(f"The outputs are being saved in {output_folder}")
31 |
32 | #### Model
33 | model = eigenplaces_network.GeoLocalizationNet_(args.backbone, args.fc_output_dim)
34 |
35 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.")
36 |
37 | if args.resume_model is not None:
38 | logging.debug(f"Loading model from {args.resume_model}")
39 | model_state_dict = torch.load(args.resume_model)
40 | model.load_state_dict(model_state_dict)
41 |
42 | model = model.to(args.device).train()
43 |
44 | #### Optimizer
45 | criterion = torch.nn.CrossEntropyLoss()
46 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
47 |
48 | #### Datasets
49 | groups = [EigenPlacesDataset(
50 | args.train_dataset_folder, M=args.M, N=args.N, focal_dist=args.focal_dist,
51 | current_group=n//2, min_images_per_class=args.min_images_per_class,
52 | angle=[0, 90][n % 2], visualize_classes=args.visualize_classes)
53 | for n in range(args.groups_num * 2)
54 | ]
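# Groups come in pairs: even indices hold the lateral datasets (angle=0) and odd
# indices the frontal datasets (angle=90) of the same spatial group (n // 2).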
55 | # Each group has its own classifier, whose size depends on the number of classes in the group
56 | classifiers = [cosface_loss.MarginCosineProduct(
57 | args.fc_output_dim, len(group), s=args.s, m=args.m) for group in groups]
58 | classifiers_optimizers = [torch.optim.Adam(classifier.parameters(), lr=args.classifiers_lr) for classifier in classifiers]
59 |
60 | gpu_augmentation = tfm.Compose([
61 | augmentations.DeviceAgnosticColorJitter(brightness=args.brightness,
62 | contrast=args.contrast,
63 | saturation=args.saturation,
64 | hue=args.hue),
65 | augmentations.DeviceAgnosticRandomResizedCrop([512, 512],
66 | scale=[1-args.random_resized_crop, 1]),
67 | tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
68 | ])
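# These augmentations run on the GPU on already-batched tensors: the images are
# moved to the device and augmented inside the autocast block of the training loop.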
69 |
70 | logging.info(f"Using {len(groups)} groups")
71 | logging.info(f"The {len(groups)} groups have respectively the following "
72 | f"number of classes {[len(g) for g in groups]}")
73 | logging.info(f"The {len(groups)} groups have respectively the following "
74 | f"number of images {[g.get_images_num() for g in groups]}")
75 |
76 | logging.info(f"There are {len(groups[0])} classes for the first group, " +
77 | f"each epoch has {args.iterations_per_epoch} iterations " +
78 | f"with batch_size {args.batch_size}, therefore the model sees each class (on average) " +
79 | f"{args.iterations_per_epoch * args.batch_size / len(groups[0]):.1f} times per epoch")
80 |
81 | val_ds = TestDataset(args.val_dataset_folder)
82 | logging.info(f"Validation set: {val_ds}")
83 |
84 | #### Resume
85 | if args.resume_train:
86 | model, model_optimizer, classifiers, classifiers_optimizers, \
87 | best_val_recall1, start_epoch_num = \
88 | util.resume_train(args, output_folder, model, model_optimizer,
89 | classifiers, classifiers_optimizers)
90 |
91 | model = model.to(args.device)
92 | epoch_num = start_epoch_num - 1
93 | logging.info(f"Resuming from epoch {start_epoch_num} with best R@1 {best_val_recall1:.1f} " +
94 | f"from checkpoint {args.resume_train}")
95 | else:
96 | best_val_recall1 = start_epoch_num = 0
97 |
98 | #### Train / evaluation loop
99 | logging.info("Start training ...")
100 |
101 | scaler = torch.cuda.amp.GradScaler()
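# A single GradScaler is shared by the backbone optimizer and the classifier
# optimizers; forward passes run under autocast and losses are scaled before
# backward to avoid float16 gradient underflow.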
102 |
103 | for epoch_num in range(start_epoch_num, args.epochs_num):
104 |
105 | #### Train
106 | epoch_start_time = datetime.now()
107 |
108 | def get_iterator(groups, classifiers, classifiers_optimizers, batch_size, g_num):
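"""Move the selected group's classifier (and its optimizer state) to the GPU and return an infinite dataloader over that group."""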
109 | assert len(groups) == len(classifiers) == len(classifiers_optimizers)
110 | classifiers[g_num] = classifiers[g_num].to(args.device)
111 | util.move_to_device(classifiers_optimizers[g_num], args.device)
112 | return commons.InfiniteDataLoader(groups[g_num], num_workers=args.num_workers,
113 | batch_size=batch_size, shuffle=True,
114 | pin_memory=(args.device == "cuda"), drop_last=True)
115 |
116 | # Select classifier and dataloader according to the epoch
117 | current_dataset_num = (epoch_num % args.groups_num) * 2  # even index = lateral group; +1 = frontal group
118 |
119 | iterators = []
120 | for i in range(2):
121 | iterators.append(get_iterator(groups, classifiers, classifiers_optimizers,
122 | args.batch_size, current_dataset_num + i))
123 | lateral_loss = torchmetrics.MeanMetric()
124 | frontal_loss = torchmetrics.MeanMetric()
125 |
126 | model = model.train()
127 | for iteration in tqdm(range(args.iterations_per_epoch), ncols=100):
128 | model_optimizer.zero_grad()
129 |
130 | #### EigenPlace ITERATION ####
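# Each iteration draws one batch from the lateral iterator (i == 0) and one from
# the frontal iterator (i == 1); each batch updates its own CosFace classifier
# immediately, while the backbone optimizer is stepped once after both losses
# have been backpropagated.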
131 | for i in range(2):
132 | classifiers_optimizers[current_dataset_num + i].zero_grad()
133 |
134 | images, targets, _ = next(iterators[i])
135 | images, targets = images.to(args.device), targets.to(args.device)
136 | with torch.cuda.amp.autocast():
137 | images = gpu_augmentation(images)
138 | descriptors = model(images)
139 | output = classifiers[current_dataset_num + i](descriptors, targets)
140 | loss = criterion(output, targets)
141 | if i == 0:
142 | loss *= args.lambda_lat
143 | else:
144 | loss *= args.lambda_front
145 | del images, output
146 | scaler.scale(loss).backward()
147 | scaler.step(classifiers_optimizers[current_dataset_num + i])
148 | if i == 0:
149 | lateral_loss.update(loss.detach().cpu())
150 | else:
151 | frontal_loss.update(loss.detach().cpu())
152 | del loss
153 | #######################
154 |
155 | scaler.step(model_optimizer)
156 | scaler.update()
157 |
158 | for i in range(2):
159 | classifiers[current_dataset_num + i] = classifiers[current_dataset_num + i].cpu()
160 | util.move_to_device(classifiers_optimizers[current_dataset_num + i], "cpu")
161 |
162 | logging.debug(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]} - "
163 | f"group {current_dataset_num} lateral_loss = {lateral_loss.compute():.4f} - "
164 | f"group {current_dataset_num + 1} frontal_loss = {frontal_loss.compute():.4f}")
165 |
166 | #### Evaluation
167 | recalls, recalls_str = test.test(args, val_ds, model, batchify=True)
168 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, {val_ds}: {recalls_str}")
169 | is_best = recalls[0] > best_val_recall1
170 | best_val_recall1 = max(recalls[0], best_val_recall1)
171 | # Save checkpoint, which contains all training parameters
172 | util.save_checkpoint({
173 | "epoch_num": epoch_num + 1,
174 | "model_state_dict": model.state_dict(),
175 | "optimizer_state_dict": model_optimizer.state_dict(),
176 | "classifiers_state_dict": [c.state_dict() for c in classifiers],
177 | "optimizers_state_dict": [c.state_dict() for c in classifiers_optimizers],
178 | "best_val_recall1": best_val_recall1
179 | }, is_best, output_folder)
180 |
181 | logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}")
182 |
183 | #### Test best model on test set v1
184 | best_model_state_dict = torch.load(f"{output_folder}/best_model.pth")
185 | model.load_state_dict(best_model_state_dict)
186 |
187 | test_ds = TestDataset(args.test_dataset_folder)
188 | recalls, recalls_str = test.test(args, test_ds, model)
189 | logging.info(f"{test_ds}: {recalls_str}")
190 |
191 | logging.info("Experiment finished (without any errors)")
192 |
193 |
194 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import shutil
4 | import logging
5 | from typing import List
6 | from argparse import Namespace
7 | from cosface_loss import MarginCosineProduct
8 |
9 |
10 | def move_to_device(optimizer: torch.optim.Optimizer, device: str):
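"""Move the optimizer's state tensors (e.g. Adam's moment buffers) to the given device."""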
11 | for state in optimizer.state.values():
12 | for k, v in state.items():
13 | if torch.is_tensor(v):
14 | state[k] = v.to(device)
15 |
16 |
17 | def save_checkpoint(state: dict, is_best: bool, output_folder: str,
18 | ckpt_filename: str = "last_checkpoint.pth"):
19 | checkpoint_path = f"{output_folder}/{ckpt_filename}"
20 | torch.save(state, checkpoint_path)
21 | if is_best:
22 | torch.save(state["model_state_dict"], f"{output_folder}/best_model.pth")
23 |
24 |
25 | def resume_train(args: Namespace, output_folder: str, model: torch.nn.Module,
26 | model_optimizer: torch.optim.Optimizer,
27 | classifiers: List[MarginCosineProduct], classifiers_optimizers: List[torch.optim.Optimizer]):
28 |
29 | """Load model_, optimizer, and other training parameters"""
30 | logging.info(f"Loading checkpoint: {args.resume_train}")
31 | checkpoint = torch.load(args.resume_train)
32 | start_epoch_num = checkpoint["epoch_num"]
33 |
34 | model_state_dict = checkpoint["model_state_dict"]
35 | model.load_state_dict(model_state_dict)
36 |
37 | model = model.to(args.device)
38 | model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
39 |
40 | # load classifiers and optimizers
41 | assert args.groups_num*2 == len(classifiers) == len(classifiers_optimizers) == \
42 | len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \
43 | (f"{args.groups_num}, {len(classifiers)}, {len(classifiers_optimizers)}, "
44 | f"{len(checkpoint['classifiers_state_dict'])}, {len(checkpoint['optimizers_state_dict'])}")
45 |
46 | for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]):
47 | # Move classifiers to GPU before loading their optimizers
48 | c = c.to(args.device)
49 | c.load_state_dict(sd)
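# Optimizer.load_state_dict() casts the loaded state to the device of the
# corresponding parameters, which is why the classifiers were moved to the GPU first.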
50 | for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]):
51 | c.load_state_dict(sd)
52 | for c in classifiers:
53 | # Move classifiers back to CPU to save some GPU memory
54 | c = c.cpu()
55 |
56 | best_val_recall1 = checkpoint["best_val_recall1"]
57 |
58 | # Copy best model to current output_folder
59 | shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder)
60 |
61 | return model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num
62 |
--------------------------------------------------------------------------------