├── .gitignore ├── LICENSE ├── README.md ├── augmentations.py ├── commons.py ├── cosface_loss.py ├── cosplace_model ├── __init__.py ├── cosplace_network.py └── layers.py ├── datasets ├── __init__.py ├── dataset_utils.py ├── test_dataset.py └── train_dataset.py ├── eval.py ├── hubconf.py ├── parser.py ├── requirements.txt ├── test.py ├── train.py ├── util.py └── visualizations.py /.gitignore: -------------------------------------------------------------------------------- 1 | .spyproject 2 | __pycache__ 3 | logs 4 | cache 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Gabriele Berton, Carlo Masone, Barbara Caputo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Rethinking Visual Geo-localization for Large-Scale Applications 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-pittsburgh-250k)](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-250k?p=rethinking-visual-geo-localization-for-large)[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-pittsburgh-30k)](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-30k?p=rethinking-visual-geo-localization-for-large) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-tokyo247)](https://paperswithcode.com/sota/visual-place-recognition-on-tokyo247?p=rethinking-visual-geo-localization-for-large) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-mapillary-val)](https://paperswithcode.com/sota/visual-place-recognition-on-mapillary-val?p=rethinking-visual-geo-localization-for-large) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-st-lucia)](https://paperswithcode.com/sota/visual-place-recognition-on-st-lucia?p=rethinking-visual-geo-localization-for-large) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-sf-xl-test-v1)](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v1?p=rethinking-visual-geo-localization-for-large) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-visual-geo-localization-for-large/visual-place-recognition-on-sf-xl-test-v2)](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v2?p=rethinking-visual-geo-localization-for-large) 10 | 11 | This is the official pyTorch implementation of the CVPR 2022 paper "Rethinking Visual Geo-localization for Large-Scale Applications". 12 | The paper presents a new dataset called San Francisco eXtra Large (SF-XL, go [_here_](https://forms.gle/wpyDzhDyoWLQygAT9) to download it), and a highly scalable training method (called CosPlace), which allows to reach SOTA results with compact descriptors. 13 | 14 | 15 | [[CVPR OpenAccess](https://openaccess.thecvf.com/content/CVPR2022/html/Berton_Rethinking_Visual_Geo-Localization_for_Large-Scale_Applications_CVPR_2022_paper.html)] [[ArXiv](https://arxiv.org/abs/2204.02287)] [[Video](https://www.youtube.com/watch?v=oDyL6oVNN3I)] [[BibTex](https://github.com/gmberton/CosPlace?tab=readme-ov-file#cite)] 16 | 17 | 18 | 19 | The images below represent respectively: 20 | 1) the map of San Francisco eXtra Large 21 | 2) a visualization of how CosPlace Groups (read datasets) are formed 22 | 3) results with CosPlace vs other methods on Pitts250k (CosPlace trained on SF-XL, others on Pitts30k) 23 |

24 | 25 | 26 | 27 |

28 | 29 | 30 | 31 | ## Train 32 | After downloading the SF-XL dataset, simply run 33 | 34 | `$ python3 train.py --train_set_folder path/to/sf_xl/raw/train/database --val_set_folder path/to/sf_xl/processed/val --test_set_folder path/to/sf_xl/processed/test` 35 | 36 | the script automatically splits SF-XL in CosPlace Groups, and saves the resulting object in the folder `cache`. 37 | By default training is performed with a ResNet-18 with descriptors dimensionality 512, which fits in less than 4GB of VRAM. 38 | 39 | To change the backbone or the output descriptors dimensionality simply run 40 | 41 | `$ python3 train.py --backbone ResNet50 --fc_output_dim 128` 42 | 43 | You can also speed up your training with Automatic Mixed Precision (note that all results/statistics from the paper did not use AMP) 44 | 45 | `$ python3 train.py --use_amp16` 46 | 47 | Run `$ python3 train.py -h` to have a look at all the hyperparameters that you can change. You will find all hyperparameters mentioned in the paper. 48 | 49 | #### Dataset size and lightweight version 50 | 51 | The SF-XL dataset is about 1 TB. 52 | For training only a subset of the images is used, and you can use this subset for training, which is only 360 GB. 53 | If this is still too heavy for you (e.g. if you're using Colab), but you would like to run CosPlace, we also created a small version of SF-XL, which is only 5 GB. 54 | Obviously, using the small version will lead to lower results, and it should be used only for debugging / exploration purposes. 55 | More information on the dataset and lightweight version are on the README that you can find on the dataset download page (go [_here_](https://forms.gle/wpyDzhDyoWLQygAT9) to find it). 56 | 57 | #### Reproducibility 58 | Results from the paper are fully reproducible, and we followed deep learning's best practices (average over multiple runs for the main results, validation/early stopping and hyperparameter search on the val set). 59 | If you are a researcher comparing your work against ours, please make sure to follow these best practices and avoid picking the best model on the test set. 60 | 61 | 62 | ## Test 63 | You can test a trained model as such 64 | 65 | `$ python3 eval.py --backbone ResNet50 --fc_output_dim 128 --resume_model path/to/best_model.pth` 66 | 67 | You can download plenty of trained models below. 68 | 69 | 70 | ### Visualize predictions 71 | 72 | Predictions can be easily visualized through the `num_preds_to_save` parameter. For example running this 73 | 74 | ``` 75 | python3 eval.py --backbone ResNet50 --fc_output_dim 512 --resume_model path/to/best_model.pth \ 76 | --num_preds_to_save=3 --exp_name=cosplace_on_stlucia 77 | ``` 78 | will generate under the path `./logs/cosplace_on_stlucia/*/preds` images such as 79 | 80 |

81 | 82 |

83 | 84 | Given that saving predictions for each query might take long, you can also pass the parameter `--save_only_wrong_preds` which will save only predictions for wrongly predicted queries (i.e. where the first prediction is wrong), which should be the most interesting failure cases. 85 | 86 | 87 | ## Trained Models 88 | 89 | We now have all our trained models on [PyTorch Hub](https://pytorch.org/docs/stable/hub.html), so that you can use them in any codebase without cloning this repository simply like this 90 | ``` 91 | import torch 92 | model = torch.hub.load("gmberton/cosplace", "get_trained_model", backbone="ResNet50", fc_output_dim=2048) 93 | ``` 94 | 95 | As an alternative, you can download the trained models from the table below, which provides links to models with different backbones and dimensionality of descriptors, trained on SF-XL. 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 |
The columns of the table correspond to the dimension of the descriptors.

| Model      | 32   | 64   | 128  | 256  | 512  | 1024 | 2048 |
|:-----------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| ResNet-18  | link | link | link | link | link | -    | -    |
| ResNet-50  | link | link | link | link | link | link | link |
| ResNet-101 | link | link | link | link | link | link | link |
| ResNet-152 | link | link | link | link | link | link | link |
| VGG-16     | -    | link | link | link | link | -    | -    |
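If you prefer not to go through PyTorch Hub, a downloaded checkpoint can be loaded directly into the model class of this repository, as `eval.py` does internally. Here is a minimal sketch (the checkpoint path is a placeholder, and `backbone` / `fc_output_dim` must match the file you downloaded):
```
import torch
from cosplace_model import cosplace_network

# Backbone and fc_output_dim must match the downloaded checkpoint
model = cosplace_network.GeoLocalizationNet(backbone="ResNet50", fc_output_dim=128)
state_dict = torch.load("path/to/downloaded_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model = model.eval()
```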
162 | 163 | Or you can download all models at once at [this link](https://drive.google.com/drive/folders/1WzSLnv05FLm-XqP5DxR5nXaaixH23uvV?usp=sharing) 164 | 165 | ## Issues 166 | If you have any questions regarding our code or dataset, feel free to open an issue or send an email to berton.gabri@gmail.com 167 | 168 | ## Acknowledgements 169 | Parts of this repo are inspired by the following repositories: 170 | - [CosFace implementation in PyTorch](https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py) 171 | - [CNN Image Retrieval in PyTorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) (for the GeM layer) 172 | - [Visual Geo-localization benchmark](https://github.com/gmberton/deep-visual-geo-localization-benchmark) (for the evaluation / test code) 173 | 174 | ## Cite 175 | Here is the bibtex to cite our paper 176 | ``` 177 | @InProceedings{Berton_CVPR_2022_CosPlace, 178 | author = {Berton, Gabriele and Masone, Carlo and Caputo, Barbara}, 179 | title = {Rethinking Visual Geo-Localization for Large-Scale Applications}, 180 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 181 | month = {June}, 182 | year = {2022}, 183 | pages = {4878-4888} 184 | } 185 | ``` 186 | -------------------------------------------------------------------------------- /augmentations.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from typing import Tuple, Union 4 | import torchvision.transforms as T 5 | 6 | 7 | class DeviceAgnosticColorJitter(T.ColorJitter): 8 | def __init__(self, brightness: float = 0., contrast: float = 0., saturation: float = 0., hue: float = 0.): 9 | """This is the same as T.ColorJitter but it only accepts batches of images and works on GPU""" 10 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 11 | 12 | def forward(self, images: torch.Tensor) -> torch.Tensor: 13 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}" 14 | B, C, H, W = images.shape 15 | # Applies a different color jitter to each image 16 | color_jitter = super(DeviceAgnosticColorJitter, self).forward 17 | augmented_images = [color_jitter(img).unsqueeze(0) for img in images] 18 | augmented_images = torch.cat(augmented_images) 19 | assert augmented_images.shape == torch.Size([B, C, H, W]) 20 | return augmented_images 21 | 22 | 23 | class DeviceAgnosticRandomResizedCrop(T.RandomResizedCrop): 24 | def __init__(self, size: Union[int, Tuple[int, int]], scale: float): 25 | """This is the same as T.RandomResizedCrop but it only accepts batches of images and works on GPU""" 26 | super().__init__(size=size, scale=scale, antialias=True) 27 | 28 | def forward(self, images: torch.Tensor) -> torch.Tensor: 29 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}" 30 | B, C, H, W = images.shape 31 | # Applies a different RandomResizedCrop to each image 32 | random_resized_crop = super(DeviceAgnosticRandomResizedCrop, self).forward 33 | augmented_images = [random_resized_crop(img).unsqueeze(0) for img in images] 34 | augmented_images = torch.cat(augmented_images) 35 | return augmented_images 36 | 37 | 38 | if __name__ == "__main__": 39 | """ 40 | You can run this script to visualize the transformations, and verify that 41 | the augmentations are applied individually on each image of the batch. 
42 | """ 43 | from PIL import Image 44 | # Import skimage in here, so it is not necessary to install it unless you run this script 45 | from skimage import data 46 | 47 | # Initialize DeviceAgnosticRandomResizedCrop 48 | random_crop = DeviceAgnosticRandomResizedCrop(size=[256, 256], scale=[0.5, 1]) 49 | # Create a batch with 2 astronaut images 50 | pil_image = Image.fromarray(data.astronaut()) 51 | tensor_image = T.functional.to_tensor(pil_image).unsqueeze(0) 52 | images_batch = torch.cat([tensor_image, tensor_image]) 53 | # Apply augmentation (individually on each of the 2 images) 54 | augmented_batch = random_crop(images_batch) 55 | # Convert to PIL images 56 | augmented_image_0 = T.functional.to_pil_image(augmented_batch[0]) 57 | augmented_image_1 = T.functional.to_pil_image(augmented_batch[1]) 58 | # Visualize the original image, as well as the two augmented ones 59 | pil_image.show() 60 | augmented_image_0.show() 61 | augmented_image_1.show() 62 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import torch 5 | import random 6 | import logging 7 | import traceback 8 | import numpy as np 9 | 10 | 11 | class InfiniteDataLoader(torch.utils.data.DataLoader): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.dataset_iterator = super().__iter__() 15 | 16 | def __iter__(self): 17 | return self 18 | 19 | def __next__(self): 20 | try: 21 | batch = next(self.dataset_iterator) 22 | except StopIteration: 23 | self.dataset_iterator = super().__iter__() 24 | batch = next(self.dataset_iterator) 25 | return batch 26 | 27 | 28 | def make_deterministic(seed: int = 0): 29 | """Make results deterministic. If seed == -1, do not make deterministic. 30 | Running your script in a deterministic way might slow it down. 31 | Note that for some packages (eg: sklearn's PCA) this function is not enough. 32 | """ 33 | seed = int(seed) 34 | if seed == -1: 35 | return 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | 44 | def setup_logging(output_folder: str, exist_ok: bool = False, console: str = "debug", 45 | info_filename: str = "info.log", debug_filename: str = "debug.log"): 46 | """Set up logging files and console output. 47 | Creates one file for INFO logs and one for DEBUG logs. 48 | Args: 49 | output_folder (str): creates the folder where to save the files. 50 | exist_ok (boolean): if False throw a FileExistsError if output_folder already exists 51 | debug (str): 52 | if == "debug" prints on console debug messages and higher 53 | if == "info" prints on console info messages and higher 54 | if == None does not use console (useful when a logger has already been set) 55 | info_filename (str): the name of the info file. if None, don't create info file 56 | debug_filename (str): the name of the debug file. 
if None, don't create debug file 57 | """ 58 | if not exist_ok and os.path.exists(output_folder): 59 | raise FileExistsError(f"{output_folder} already exists!") 60 | os.makedirs(output_folder, exist_ok=True) 61 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S") 62 | logger = logging.getLogger('') 63 | logger.setLevel(logging.DEBUG) 64 | 65 | if info_filename is not None: 66 | info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}') 67 | info_file_handler.setLevel(logging.INFO) 68 | info_file_handler.setFormatter(base_formatter) 69 | logger.addHandler(info_file_handler) 70 | 71 | if debug_filename is not None: 72 | debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}') 73 | debug_file_handler.setLevel(logging.DEBUG) 74 | debug_file_handler.setFormatter(base_formatter) 75 | logger.addHandler(debug_file_handler) 76 | 77 | if console is not None: 78 | console_handler = logging.StreamHandler() 79 | if console == "debug": 80 | console_handler.setLevel(logging.DEBUG) 81 | if console == "info": 82 | console_handler.setLevel(logging.INFO) 83 | console_handler.setFormatter(base_formatter) 84 | logger.addHandler(console_handler) 85 | 86 | def my_handler(type_, value, tb): 87 | logger.info("\n" + "".join(traceback.format_exception(type, value, tb))) 88 | logging.info("Experiment finished (with some errors)") 89 | sys.excepthook = my_handler 90 | -------------------------------------------------------------------------------- /cosface_loss.py: -------------------------------------------------------------------------------- 1 | 2 | # Based on https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | 9 | def cosine_sim(x1: torch.Tensor, x2: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: 10 | ip = torch.mm(x1, x2.t()) 11 | w1 = torch.norm(x1, 2, dim) 12 | w2 = torch.norm(x2, 2, dim) 13 | return ip / torch.ger(w1, w2).clamp(min=eps) 14 | 15 | 16 | class MarginCosineProduct(nn.Module): 17 | """Implement of large margin cosine distance: 18 | Args: 19 | in_features: size of each input sample 20 | out_features: size of each output sample 21 | s: norm of input feature 22 | m: margin 23 | """ 24 | def __init__(self, in_features: int, out_features: int, s: float = 30.0, m: float = 0.40): 25 | super().__init__() 26 | self.in_features = in_features 27 | self.out_features = out_features 28 | self.s = s 29 | self.m = m 30 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 31 | nn.init.xavier_uniform_(self.weight) 32 | 33 | def forward(self, inputs: torch.Tensor, label: torch.Tensor) -> torch.Tensor: 34 | cosine = cosine_sim(inputs, self.weight) 35 | one_hot = torch.zeros_like(cosine) 36 | one_hot.scatter_(1, label.view(-1, 1), 1.0) 37 | output = self.s * (cosine - one_hot * self.m) 38 | return output 39 | 40 | def __repr__(self): 41 | return self.__class__.__name__ + '(' \ 42 | + 'in_features=' + str(self.in_features) \ 43 | + ', out_features=' + str(self.out_features) \ 44 | + ', s=' + str(self.s) \ 45 | + ', m=' + str(self.m) + ')' 46 | -------------------------------------------------------------------------------- /cosplace_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmberton/CosPlace/eb8ab8f956da3e0c90d0a1acc74b689506093b6b/cosplace_model/__init__.py 
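For clarity, here is a minimal sketch of how `cosface_loss.MarginCosineProduct` is combined with a standard cross-entropy loss, mirroring the training loop in `train.py` (the batch size, descriptor dimension and number of classes below are made-up values):
```
import torch
from cosface_loss import MarginCosineProduct

# Made-up sizes: 8 descriptors of dimension 512, 100 classes in the current group
classifier = MarginCosineProduct(in_features=512, out_features=100)
criterion = torch.nn.CrossEntropyLoss()

descriptors = torch.randn(8, 512)          # would come from the GeoLocalizationNet
targets = torch.randint(0, 100, (8,))      # class index of each descriptor
logits = classifier(descriptors, targets)  # scaled cosine similarities, margin applied to the target class
loss = criterion(logits, targets)
loss.backward()
```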
-------------------------------------------------------------------------------- /cosplace_model/cosplace_network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import logging 4 | import torchvision 5 | from torch import nn 6 | from typing import Tuple 7 | 8 | from cosplace_model.layers import Flatten, L2Norm, GeM 9 | 10 | # The number of channels in the last convolutional layer, the one before average pooling 11 | CHANNELS_NUM_IN_LAST_CONV = { 12 | "ResNet18": 512, 13 | "ResNet50": 2048, 14 | "ResNet101": 2048, 15 | "ResNet152": 2048, 16 | "VGG16": 512, 17 | "EfficientNet_B0": 1280, 18 | "EfficientNet_B1": 1280, 19 | "EfficientNet_B2": 1408, 20 | "EfficientNet_B3": 1536, 21 | "EfficientNet_B4": 1792, 22 | "EfficientNet_B5": 2048, 23 | "EfficientNet_B6": 2304, 24 | "EfficientNet_B7": 2560, 25 | } 26 | 27 | 28 | class GeoLocalizationNet(nn.Module): 29 | def __init__(self, backbone : str, fc_output_dim : int, train_all_layers : bool = False): 30 | """Return a model for GeoLocalization. 31 | 32 | Args: 33 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet. 34 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to the descriptors dimension. 35 | train_all_layers (bool): whether to freeze the first layers of the backbone during training or not. 36 | """ 37 | super().__init__() 38 | assert backbone in CHANNELS_NUM_IN_LAST_CONV, f"backbone must be one of {list(CHANNELS_NUM_IN_LAST_CONV.keys())}" 39 | self.backbone, features_dim = get_backbone(backbone, train_all_layers) 40 | self.aggregation = nn.Sequential( 41 | L2Norm(), 42 | GeM(), 43 | Flatten(), 44 | nn.Linear(features_dim, fc_output_dim), 45 | L2Norm() 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.backbone(x) 50 | x = self.aggregation(x) 51 | return x 52 | 53 | 54 | def get_pretrained_torchvision_model(backbone_name : str) -> torch.nn.Module: 55 | """This function takes the name of a backbone and returns the corresponding pretrained 56 | model from torchvision. 
Examples of backbone_name are 'VGG16' or 'ResNet18' 57 | """ 58 | try: # Newer versions of pytorch require to pass weights=weights_module.DEFAULT 59 | weights_module = getattr(__import__('torchvision.models', fromlist=[f"{backbone_name}_Weights"]), f"{backbone_name}_Weights") 60 | model = getattr(torchvision.models, backbone_name.lower())(weights=weights_module.DEFAULT) 61 | except (ImportError, AttributeError): # Older versions of pytorch require to pass pretrained=True 62 | model = getattr(torchvision.models, backbone_name.lower())(pretrained=True) 63 | return model 64 | 65 | 66 | def get_backbone(backbone_name : str, train_all_layers : bool) -> Tuple[torch.nn.Module, int]: 67 | backbone = get_pretrained_torchvision_model(backbone_name) 68 | if backbone_name.startswith("ResNet"): 69 | if train_all_layers: 70 | logging.debug(f"Train all layers of the {backbone_name}") 71 | else: 72 | for name, child in backbone.named_children(): 73 | if name == "layer3": # Freeze layers before conv_3 74 | break 75 | for params in child.parameters(): 76 | params.requires_grad = False 77 | logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones") 78 | 79 | layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer 80 | 81 | elif backbone_name == "VGG16": 82 | layers = list(backbone.features.children())[:-2] # Remove avg pooling and FC layer 83 | if train_all_layers: 84 | logging.debug("Train all layers of the VGG-16") 85 | else: 86 | for layer in layers[:-5]: 87 | for p in layer.parameters(): 88 | p.requires_grad = False 89 | logging.debug("Train last layers of the VGG-16, freeze the previous ones") 90 | 91 | elif backbone_name.startswith("EfficientNet"): 92 | if train_all_layers: 93 | logging.debug(f"Train all layers of the {backbone_name}") 94 | else: 95 | for name, child in backbone.features.named_children(): 96 | if name == "5": # Freeze layers before block 5 97 | break 98 | for params in child.parameters(): 99 | params.requires_grad = False 100 | logging.debug(f"Train only the last three blocks of the {backbone_name}, freeze the previous ones") 101 | layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer 102 | 103 | backbone = torch.nn.Sequential(*layers) 104 | features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name] 105 | 106 | return backbone, features_dim 107 | -------------------------------------------------------------------------------- /cosplace_model/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | def gem(x, p=torch.ones(1)*3, eps: float = 1e-6): 9 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 10 | 11 | 12 | class GeM(nn.Module): 13 | def __init__(self, p=3, eps=1e-6): 14 | super().__init__() 15 | self.p = Parameter(torch.ones(1)*p) 16 | self.eps = eps 17 | 18 | def forward(self, x): 19 | return gem(x, p=self.p, eps=self.eps) 20 | 21 | def __repr__(self): 22 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})" 23 | 24 | 25 | class Flatten(torch.nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | def forward(self, x): 30 | assert x.shape[2] == x.shape[3] == 1, f"{x.shape[2]} != {x.shape[3]} != 1" 31 | return x[:, :, 0, 0] 32 | 33 | 34 | class L2Norm(nn.Module): 35 | def __init__(self, dim=1): 36 | super().__init__() 37 | self.dim = dim 38 | 39 | def 
forward(self, x): 40 | return F.normalize(x, p=2.0, dim=self.dim) 41 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmberton/CosPlace/eb8ab8f956da3e0c90d0a1acc74b689506093b6b/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | from glob import glob 5 | from PIL import ImageFile 6 | 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def read_images_paths(dataset_folder, get_abs_path=False): 11 | """Find images within 'dataset_folder' and return their relative paths as a list. 12 | If there is a file 'dataset_folder'_images_paths.txt, read paths from such file. 13 | Otherwise, use glob(). Keeping the paths in the file speeds up computation, 14 | because using glob over large folders can be slow. 15 | 16 | Parameters 17 | ---------- 18 | dataset_folder : str, folder containing JPEG images 19 | get_abs_path : bool, if True return absolute paths, otherwise remove 20 | dataset_folder from each path 21 | 22 | Returns 23 | ------- 24 | images_paths : list[str], paths of JPEG images within dataset_folder 25 | """ 26 | 27 | if not os.path.exists(dataset_folder): 28 | raise FileNotFoundError(f"Folder {dataset_folder} does not exist") 29 | 30 | file_with_paths = dataset_folder + "_images_paths.txt" 31 | if os.path.exists(file_with_paths): 32 | logging.debug(f"Reading paths of images within {dataset_folder} from {file_with_paths}") 33 | with open(file_with_paths, "r") as file: 34 | images_paths = file.read().splitlines() 35 | images_paths = [os.path.join(dataset_folder, path) for path in images_paths] 36 | # Sanity check that paths within the file exist 37 | if not os.path.exists(images_paths[0]): 38 | raise FileNotFoundError(f"Image with path {images_paths[0]} " 39 | f"does not exist within {dataset_folder}. 
It is likely " 40 | f"that the content of {file_with_paths} is wrong.") 41 | else: 42 | logging.debug(f"Searching images in {dataset_folder} with glob()") 43 | images_paths = sorted(glob(f"{dataset_folder}/**/*.jpg", recursive=True)) 44 | if len(images_paths) == 0: 45 | raise FileNotFoundError(f"Directory {dataset_folder} does not contain any JPEG images") 46 | 47 | if not get_abs_path: # Remove dataset_folder from the path 48 | images_paths = [p[len(dataset_folder) + 1:] for p in images_paths] 49 | 50 | return images_paths 51 | 52 | -------------------------------------------------------------------------------- /datasets/test_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | from sklearn.neighbors import NearestNeighbors 8 | 9 | import datasets.dataset_utils as dataset_utils 10 | 11 | 12 | class TestDataset(data.Dataset): 13 | def __init__(self, dataset_folder, database_folder="database", 14 | queries_folder="queries", positive_dist_threshold=25, 15 | image_size=512, resize_test_imgs=False): 16 | self.database_folder = dataset_folder + "/" + database_folder 17 | self.queries_folder = dataset_folder + "/" + queries_folder 18 | self.database_paths = dataset_utils.read_images_paths(self.database_folder, get_abs_path=True) 19 | self.queries_paths = dataset_utils.read_images_paths(self.queries_folder, get_abs_path=True) 20 | 21 | self.dataset_name = os.path.basename(dataset_folder) 22 | 23 | #### Read paths and UTM coordinates for all images. 24 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg 25 | self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float) 26 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float) 27 | 28 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters) 29 | knn = NearestNeighbors(n_jobs=-1) 30 | knn.fit(self.database_utms) 31 | self.positives_per_query = knn.radius_neighbors( 32 | self.queries_utms, radius=positive_dist_threshold, return_distance=False 33 | ) 34 | 35 | self.images_paths = self.database_paths + self.queries_paths 36 | 37 | self.database_num = len(self.database_paths) 38 | self.queries_num = len(self.queries_paths) 39 | 40 | transforms_list = [] 41 | if resize_test_imgs: 42 | # Resize to image_size along the shorter side while maintaining aspect ratio 43 | transforms_list += [transforms.Resize(image_size, antialias=True)] 44 | transforms_list += [ 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 47 | ] 48 | self.base_transform = transforms.Compose(transforms_list) 49 | 50 | @staticmethod 51 | def open_image(path): 52 | return Image.open(path).convert("RGB") 53 | 54 | def __getitem__(self, index): 55 | image_path = self.images_paths[index] 56 | pil_img = TestDataset.open_image(image_path) 57 | normalized_img = self.base_transform(pil_img) 58 | return normalized_img, index 59 | 60 | def __len__(self): 61 | return len(self.images_paths) 62 | 63 | def __repr__(self): 64 | return f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >" 65 | 66 | def get_positives(self): 67 | return self.positives_per_query 68 | -------------------------------------------------------------------------------- 
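As a concrete illustration of the filename convention used above, the snippet below shows how the UTM fields are recovered from a path by splitting on `@` (the example filename is invented; only the positions of fields 1, 2 and 9 reflect what `TestDataset` and `TrainDataset` actually read):
```
# Invented path following the @utm_east@utm_north@...@ naming convention
path = "database/@0598412.75@4156732.50@10@S@037.77890@-122.41250@pano_id@0@120@@@@@@.jpg"

fields = path.split("@")
utm_east = float(fields[1])    # field 1: UTM east
utm_north = float(fields[2])   # field 2: UTM north
heading = float(fields[9])     # field 9: heading in degrees (used by TrainDataset)
```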
/datasets/train_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import random 5 | import logging 6 | import numpy as np 7 | from PIL import Image 8 | from PIL import ImageFile 9 | import torchvision.transforms as T 10 | from collections import defaultdict 11 | 12 | import datasets.dataset_utils as dataset_utils 13 | 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | class TrainDataset(torch.utils.data.Dataset): 19 | def __init__(self, args, dataset_folder, M=10, alpha=30, N=5, L=2, 20 | current_group=0, min_images_per_class=10): 21 | """ 22 | Parameters (please check our paper for a clearer explanation of the parameters). 23 | ---------- 24 | args : args for data augmentation 25 | dataset_folder : str, the path of the folder with the train images. 26 | M : int, the length of the side of each cell in meters. 27 | alpha : int, size of each class in degrees. 28 | N : int, distance (M-wise) between two classes of the same group. 29 | L : int, distance (alpha-wise) between two classes of the same group. 30 | current_group : int, which one of the groups to consider. 31 | min_images_per_class : int, minimum number of image in a class. 32 | """ 33 | super().__init__() 34 | self.M = M 35 | self.alpha = alpha 36 | self.N = N 37 | self.L = L 38 | self.current_group = current_group 39 | self.dataset_folder = dataset_folder 40 | self.augmentation_device = args.augmentation_device 41 | 42 | # dataset_name should be either "processed", "small" or "raw", if you're using SF-XL 43 | dataset_name = os.path.basename(dataset_folder) 44 | filename = f"cache/{dataset_name}_M{M}_N{N}_alpha{alpha}_L{L}_mipc{min_images_per_class}.torch" 45 | if not os.path.exists(filename): 46 | os.makedirs("cache", exist_ok=True) 47 | logging.info(f"Cached dataset {filename} does not exist, I'll create it now.") 48 | self.initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename) 49 | elif current_group == 0: 50 | logging.info(f"Using cached dataset {filename}") 51 | 52 | classes_per_group, self.images_per_class = torch.load(filename) 53 | if current_group >= len(classes_per_group): 54 | raise ValueError(f"With this configuration there are only {len(classes_per_group)} " + 55 | f"groups, therefore I can't create the {current_group}th group. " + 56 | "You should reduce the number of groups by setting for example " + 57 | f"'--groups_num {current_group}'") 58 | self.classes_ids = classes_per_group[current_group] 59 | 60 | if self.augmentation_device == "cpu": 61 | self.transform = T.Compose([ 62 | T.ColorJitter(brightness=args.brightness, 63 | contrast=args.contrast, 64 | saturation=args.saturation, 65 | hue=args.hue), 66 | T.RandomResizedCrop([args.image_size, args.image_size], scale=[1-args.random_resized_crop, 1], antialias=True), 67 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 68 | ]) 69 | 70 | @staticmethod 71 | def open_image(path): 72 | return Image.open(path).convert("RGB") 73 | 74 | def __getitem__(self, class_num): 75 | # This function takes as input the class_num instead of the index of 76 | # the image. This way each class is equally represented during training. 77 | 78 | class_id = self.classes_ids[class_num] 79 | # Pick a random image among those in this class. 
80 | image_path = os.path.join(self.dataset_folder, random.choice(self.images_per_class[class_id])) 81 | 82 | try: 83 | pil_image = TrainDataset.open_image(image_path) 84 | except Exception as e: 85 | logging.info(f"ERROR image {image_path} couldn't be opened, it might be corrupted.") 86 | raise e 87 | 88 | tensor_image = T.functional.to_tensor(pil_image) 89 | assert tensor_image.shape == torch.Size([3, 512, 512]), \ 90 | f"Image {image_path} should have shape [3, 512, 512] but has {tensor_image.shape}." 91 | 92 | if self.augmentation_device == "cpu": 93 | tensor_image = self.transform(tensor_image) 94 | 95 | return tensor_image, class_num, image_path 96 | 97 | def get_images_num(self): 98 | """Return the number of images within this group.""" 99 | return sum([len(self.images_per_class[c]) for c in self.classes_ids]) 100 | 101 | def __len__(self): 102 | """Return the number of classes within this group.""" 103 | return len(self.classes_ids) 104 | 105 | @staticmethod 106 | def initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename): 107 | logging.debug(f"Searching training images in {dataset_folder}") 108 | 109 | images_paths = dataset_utils.read_images_paths(dataset_folder) 110 | logging.debug(f"Found {len(images_paths)} images") 111 | 112 | logging.debug("For each image, get its UTM east, UTM north and heading from its path") 113 | images_metadatas = [p.split("@") for p in images_paths] 114 | # field 1 is UTM east, field 2 is UTM north, field 9 is heading 115 | utmeast_utmnorth_heading = [(m[1], m[2], m[9]) for m in images_metadatas] 116 | utmeast_utmnorth_heading = np.array(utmeast_utmnorth_heading).astype(np.float64) 117 | 118 | logging.debug("For each image, get class and group to which it belongs") 119 | class_id__group_id = [TrainDataset.get__class_id__group_id(*m, M, alpha, N, L) 120 | for m in utmeast_utmnorth_heading] 121 | 122 | logging.debug("Group together images belonging to the same class") 123 | images_per_class = defaultdict(list) 124 | for image_path, (class_id, _) in zip(images_paths, class_id__group_id): 125 | images_per_class[class_id].append(image_path) 126 | 127 | # Images_per_class is a dict where the key is class_id, and the value 128 | # is a list with the paths of images within that class. 129 | images_per_class = {k: v for k, v in images_per_class.items() if len(v) >= min_images_per_class} 130 | 131 | logging.debug("Group together classes belonging to the same group") 132 | # Classes_per_group is a dict where the key is group_id, and the value 133 | # is a list with the class_ids belonging to that group. 134 | classes_per_group = defaultdict(set) 135 | for class_id, group_id in class_id__group_id: 136 | if class_id not in images_per_class: 137 | continue # Skip classes with too few images 138 | classes_per_group[group_id].add(class_id) 139 | 140 | # Convert classes_per_group to a list of lists. 141 | # Each sublist represents the classes within a group. 142 | classes_per_group = [list(c) for c in classes_per_group.values()] 143 | 144 | torch.save((classes_per_group, images_per_class), filename) 145 | 146 | @staticmethod 147 | def get__class_id__group_id(utm_east, utm_north, heading, M, alpha, N, L): 148 | """Return class_id and group_id for a given point. 149 | The class_id is a triplet (tuple) of UTM_east, UTM_north and 150 | heading (e.g. (396520, 4983800,120)). 151 | The group_id represents the group to which the class belongs 152 | (e.g. (0, 1, 0)), and it is between (0, 0, 0) and (N, N, L). 
153 | """ 154 | rounded_utm_east = int(utm_east // M * M) # Rounded to nearest lower multiple of M 155 | rounded_utm_north = int(utm_north // M * M) 156 | rounded_heading = int(heading // alpha * alpha) 157 | 158 | class_id = (rounded_utm_east, rounded_utm_north, rounded_heading) 159 | # group_id goes from (0, 0, 0) to (N, N, L) 160 | group_id = (rounded_utm_east % (M * N) // M, 161 | rounded_utm_north % (M * N) // M, 162 | rounded_heading % (alpha * L) // alpha) 163 | return class_id, group_id 164 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import logging 5 | import multiprocessing 6 | from datetime import datetime 7 | 8 | import test 9 | import parser 10 | import commons 11 | from cosplace_model import cosplace_network 12 | from datasets.test_dataset import TestDataset 13 | 14 | torch.backends.cudnn.benchmark = True # Provides a speedup 15 | 16 | args = parser.parse_arguments(is_training=False) 17 | start_time = datetime.now() 18 | args.output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}" 19 | commons.make_deterministic(args.seed) 20 | commons.setup_logging(args.output_folder, console="info") 21 | logging.info(" ".join(sys.argv)) 22 | logging.info(f"Arguments: {args}") 23 | logging.info(f"The outputs are being saved in {args.output_folder}") 24 | 25 | #### Model 26 | model = cosplace_network.GeoLocalizationNet(args.backbone, args.fc_output_dim) 27 | 28 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.") 29 | 30 | if args.resume_model is not None: 31 | logging.info(f"Loading model from {args.resume_model}") 32 | model_state_dict = torch.load(args.resume_model) 33 | model.load_state_dict(model_state_dict) 34 | else: 35 | logging.info("WARNING: You didn't provide a path to resume the model (--resume_model parameter). " + 36 | "Evaluation will be computed using randomly initialized weights.") 37 | 38 | model = model.to(args.device) 39 | 40 | test_ds = TestDataset(args.test_set_folder, queries_folder="queries_v1", 41 | positive_dist_threshold=args.positive_dist_threshold) 42 | 43 | recalls, recalls_str = test.test(args, test_ds, model, args.num_preds_to_save) 44 | logging.info(f"{test_ds}: {recalls_str}") 45 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | 2 | dependencies = ['torch', 'torchvision'] 3 | 4 | import torch 5 | from cosplace_model import cosplace_network 6 | 7 | 8 | AVAILABLE_TRAINED_MODELS = { 9 | # backbone : list of available fc_output_dim, which is equivalent to descriptors dimensionality 10 | "VGG16": [ 64, 128, 256, 512], 11 | "ResNet18": [32, 64, 128, 256, 512], 12 | "ResNet50": [32, 64, 128, 256, 512, 1024, 2048], 13 | "ResNet101": [32, 64, 128, 256, 512, 1024, 2048], 14 | "ResNet152": [32, 64, 128, 256, 512, 1024, 2048], 15 | } 16 | 17 | 18 | def get_trained_model(backbone : str = "ResNet50", fc_output_dim : int = 2048) -> torch.nn.Module: 19 | """Return a model trained with CosPlace on San Francisco eXtra Large. 20 | 21 | Args: 22 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet. 23 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to 24 | the descriptors dimension. Must be between 32 and 2048, depending on model's availability. 
25 | 26 | Return: 27 | model (torch.nn.Module): a trained model. 28 | """ 29 | print(f"Returning CosPlace model with backbone: {backbone} with features dimension {fc_output_dim}") 30 | if backbone not in AVAILABLE_TRAINED_MODELS: 31 | raise ValueError(f"Parameter `backbone` is set to {backbone} but it must be one of {list(AVAILABLE_TRAINED_MODELS.keys())}") 32 | try: 33 | fc_output_dim = int(fc_output_dim) 34 | except: 35 | raise ValueError(f"Parameter `fc_output_dim` must be an integer, but it is set to {fc_output_dim}") 36 | if fc_output_dim not in AVAILABLE_TRAINED_MODELS[backbone]: 37 | raise ValueError(f"Parameter `fc_output_dim` is set to {fc_output_dim}, but for backbone {backbone} " 38 | f"it must be one of {list(AVAILABLE_TRAINED_MODELS[backbone])}") 39 | model = cosplace_network.GeoLocalizationNet(backbone, fc_output_dim) 40 | model.load_state_dict( 41 | torch.hub.load_state_dict_from_url( 42 | f'https://github.com/gmberton/CosPlace/releases/download/v1.0/{backbone}_{fc_output_dim}_cosplace.pth', 43 | map_location=torch.device('cpu')) 44 | ) 45 | return model 46 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | 5 | def parse_arguments(is_training: bool = True): 6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | # CosPlace Groups parameters 8 | parser.add_argument("--M", type=int, default=10, help="_") 9 | parser.add_argument("--alpha", type=int, default=30, help="_") 10 | parser.add_argument("--N", type=int, default=5, help="_") 11 | parser.add_argument("--L", type=int, default=2, help="_") 12 | parser.add_argument("--groups_num", type=int, default=8, help="_") 13 | parser.add_argument("--min_images_per_class", type=int, default=10, help="_") 14 | # Model parameters 15 | parser.add_argument("--backbone", type=str, default="ResNet18", 16 | choices=["VGG16", 17 | "ResNet18", "ResNet50", "ResNet101", "ResNet152", 18 | "EfficientNet_B0", "EfficientNet_B1", "EfficientNet_B2", 19 | "EfficientNet_B3", "EfficientNet_B4", "EfficientNet_B5", 20 | "EfficientNet_B6", "EfficientNet_B7"], help="_") 21 | parser.add_argument("--fc_output_dim", type=int, default=512, 22 | help="Output dimension of final fully connected layer") 23 | parser.add_argument("--train_all_layers", default=False, action="store_true", 24 | help="If true, train all layers of the backbone") 25 | # Training parameters 26 | parser.add_argument("--use_amp16", action="store_true", 27 | help="use Automatic Mixed Precision") 28 | parser.add_argument("--augmentation_device", type=str, default="cuda", 29 | choices=["cuda", "cpu"], 30 | help="on which device to run data augmentation") 31 | parser.add_argument("--batch_size", type=int, default=32, help="_") 32 | parser.add_argument("--epochs_num", type=int, default=50, help="_") 33 | parser.add_argument("--iterations_per_epoch", type=int, default=10000, help="_") 34 | parser.add_argument("--lr", type=float, default=0.00001, help="_") 35 | parser.add_argument("--classifiers_lr", type=float, default=0.01, help="_") 36 | parser.add_argument("--image_size", type=int, default=512, 37 | help="Width and height of training images (1:1 aspect ratio))") 38 | parser.add_argument("--resize_test_imgs", default=False, action="store_true", 39 | help="If the test images should be resized to image_size along" 40 | "the shorter side while maintaining aspect ratio") 41 | # Data augmentation 42 | 
parser.add_argument("--brightness", type=float, default=0.7, help="_") 43 | parser.add_argument("--contrast", type=float, default=0.7, help="_") 44 | parser.add_argument("--hue", type=float, default=0.5, help="_") 45 | parser.add_argument("--saturation", type=float, default=0.7, help="_") 46 | parser.add_argument("--random_resized_crop", type=float, default=0.5, help="_") 47 | # Validation / test parameters 48 | parser.add_argument("--infer_batch_size", type=int, default=16, 49 | help="Batch size for inference (validating and testing)") 50 | parser.add_argument("--positive_dist_threshold", type=int, default=25, 51 | help="distance in meters for a prediction to be considered a positive") 52 | # Resume parameters 53 | parser.add_argument("--resume_train", type=str, default=None, 54 | help="path to checkpoint to resume, e.g. logs/.../last_checkpoint.pth") 55 | parser.add_argument("--resume_model", type=str, default=None, 56 | help="path to model to resume, e.g. logs/.../best_model.pth") 57 | # Other parameters 58 | parser.add_argument("--device", type=str, default="cuda", 59 | choices=["cuda", "cpu"], help="_") 60 | parser.add_argument("--seed", type=int, default=0, help="_") 61 | parser.add_argument("--num_workers", type=int, default=8, help="_") 62 | parser.add_argument("--num_preds_to_save", type=int, default=0, 63 | help="At the end of training, save N preds for each query. " 64 | "Try with a small number like 3") 65 | parser.add_argument("--save_only_wrong_preds", action="store_true", 66 | help="When saving preds (if num_preds_to_save != 0) save only " 67 | "preds for difficult queries, i.e. with uncorrect first prediction") 68 | # Paths parameters 69 | if is_training: # train and val sets are needed only for training 70 | parser.add_argument("--train_set_folder", type=str, required=True, 71 | help="path of the folder with training images") 72 | parser.add_argument("--val_set_folder", type=str, required=True, 73 | help="path of the folder with val images (split in database/queries)") 74 | parser.add_argument("--test_set_folder", type=str, required=True, 75 | help="path of the folder with test images (split in database/queries)") 76 | parser.add_argument("--save_dir", type=str, default="default", 77 | help="name of directory on which to save the logs, under logs/save_dir") 78 | 79 | args = parser.parse_args() 80 | 81 | return args 82 | 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss_cpu>=1.7.1 2 | numpy>=1.21.2 3 | Pillow>=9.0.1 4 | scikit_learn>=1.0.2 5 | torch>=1.8.2 6 | torchvision>=0.9.2 7 | tqdm>=4.62.3 8 | utm>=0.7.0 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import faiss 3 | import torch 4 | import logging 5 | import numpy as np 6 | from tqdm import tqdm 7 | from typing import Tuple 8 | from argparse import Namespace 9 | from torch.utils.data.dataset import Subset 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | import visualizations 13 | 14 | 15 | # Compute R@1, R@5, R@10, R@20 16 | RECALL_VALUES = [1, 5, 10, 20] 17 | 18 | 19 | def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module, 20 | num_preds_to_save: int = 0) -> Tuple[np.ndarray, str]: 21 | """Compute descriptors of the given dataset and compute the recalls.""" 22 | 23 | model = model.eval() 24 | with 
torch.no_grad(): 25 | logging.debug("Extracting database descriptors for evaluation/testing") 26 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) 27 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers, 28 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) 29 | all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32") 30 | for images, indices in tqdm(database_dataloader, ncols=100): 31 | descriptors = model(images.to(args.device)) 32 | descriptors = descriptors.cpu().numpy() 33 | all_descriptors[indices.numpy(), :] = descriptors 34 | 35 | logging.debug("Extracting queries descriptors for evaluation/testing using batch size 1") 36 | queries_infer_batch_size = 1 37 | queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num))) 38 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers, 39 | batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda")) 40 | for images, indices in tqdm(queries_dataloader, ncols=100): 41 | descriptors = model(images.to(args.device)) 42 | descriptors = descriptors.cpu().numpy() 43 | all_descriptors[indices.numpy(), :] = descriptors 44 | 45 | queries_descriptors = all_descriptors[eval_ds.database_num:] 46 | database_descriptors = all_descriptors[:eval_ds.database_num] 47 | 48 | # Use a kNN to find predictions 49 | faiss_index = faiss.IndexFlatL2(args.fc_output_dim) 50 | faiss_index.add(database_descriptors) 51 | del database_descriptors, all_descriptors 52 | 53 | logging.debug("Calculating recalls") 54 | _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES)) 55 | 56 | #### For each query, check if the predictions are correct 57 | positives_per_query = eval_ds.get_positives() 58 | recalls = np.zeros(len(RECALL_VALUES)) 59 | for query_index, preds in enumerate(predictions): 60 | for i, n in enumerate(RECALL_VALUES): 61 | if np.any(np.in1d(preds[:n], positives_per_query[query_index])): 62 | recalls[i:] += 1 63 | break 64 | 65 | # Divide by queries_num and multiply by 100, so the recalls are in percentages 66 | recalls = recalls / eval_ds.queries_num * 100 67 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)]) 68 | 69 | # Save visualizations of predictions 70 | if num_preds_to_save != 0: 71 | # For each query save num_preds_to_save predictions 72 | visualizations.save_preds(predictions[:, :num_preds_to_save], eval_ds, args.output_folder, args.save_only_wrong_preds) 73 | 74 | return recalls, recalls_str 75 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import logging 5 | import numpy as np 6 | from tqdm import tqdm 7 | import multiprocessing 8 | from datetime import datetime 9 | import torchvision.transforms as T 10 | 11 | import test 12 | import util 13 | import parser 14 | import commons 15 | import cosface_loss 16 | import augmentations 17 | from cosplace_model import cosplace_network 18 | from datasets.test_dataset import TestDataset 19 | from datasets.train_dataset import TrainDataset 20 | 21 | torch.backends.cudnn.benchmark = True # Provides a speedup 22 | 23 | args = parser.parse_arguments() 24 | start_time = datetime.now() 25 | args.output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}" 26 | 
commons.make_deterministic(args.seed) 27 | commons.setup_logging(args.output_folder, console="debug") 28 | logging.info(" ".join(sys.argv)) 29 | logging.info(f"Arguments: {args}") 30 | logging.info(f"The outputs are being saved in {args.output_folder}") 31 | 32 | #### Model 33 | model = cosplace_network.GeoLocalizationNet(args.backbone, args.fc_output_dim, args.train_all_layers) 34 | 35 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.") 36 | 37 | if args.resume_model is not None: 38 | logging.debug(f"Loading model from {args.resume_model}") 39 | model_state_dict = torch.load(args.resume_model) 40 | model.load_state_dict(model_state_dict) 41 | 42 | model = model.to(args.device).train() 43 | 44 | #### Optimizer 45 | criterion = torch.nn.CrossEntropyLoss() 46 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 47 | 48 | #### Datasets 49 | groups = [TrainDataset(args, args.train_set_folder, M=args.M, alpha=args.alpha, N=args.N, L=args.L, 50 | current_group=n, min_images_per_class=args.min_images_per_class) for n in range(args.groups_num)] 51 | # Each group has its own classifier, which depends on the number of classes in the group 52 | classifiers = [cosface_loss.MarginCosineProduct(args.fc_output_dim, len(group)) for group in groups] 53 | classifiers_optimizers = [torch.optim.Adam(classifier.parameters(), lr=args.classifiers_lr) for classifier in classifiers] 54 | 55 | logging.info(f"Using {len(groups)} groups") 56 | logging.info(f"The {len(groups)} groups have respectively the following number of classes {[len(g) for g in groups]}") 57 | logging.info(f"The {len(groups)} groups have respectively the following number of images {[g.get_images_num() for g in groups]}") 58 | 59 | val_ds = TestDataset(args.val_set_folder, positive_dist_threshold=args.positive_dist_threshold, 60 | image_size=args.image_size, resize_test_imgs=args.resize_test_imgs) 61 | test_ds = TestDataset(args.test_set_folder, queries_folder="queries_v1", 62 | positive_dist_threshold=args.positive_dist_threshold, 63 | image_size=args.image_size, resize_test_imgs=args.resize_test_imgs) 64 | logging.info(f"Validation set: {val_ds}") 65 | logging.info(f"Test set: {test_ds}") 66 | 67 | #### Resume 68 | if args.resume_train: 69 | model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num = \ 70 | util.resume_train(args, args.output_folder, model, model_optimizer, classifiers, classifiers_optimizers) 71 | model = model.to(args.device) 72 | epoch_num = start_epoch_num - 1 73 | logging.info(f"Resuming from epoch {start_epoch_num} with best R@1 {best_val_recall1:.1f} from checkpoint {args.resume_train}") 74 | else: 75 | best_val_recall1 = start_epoch_num = 0 76 | 77 | #### Train / evaluation loop 78 | logging.info("Start training ...") 79 | logging.info(f"There are {len(groups[0])} classes for the first group, " + 80 | f"each epoch has {args.iterations_per_epoch} iterations " + 81 | f"with batch_size {args.batch_size}, therefore the model sees each class (on average) " + 82 | f"{args.iterations_per_epoch * args.batch_size / len(groups[0]):.1f} times per epoch") 83 | 84 | 85 | if args.augmentation_device == "cuda": 86 | gpu_augmentation = T.Compose([ 87 | augmentations.DeviceAgnosticColorJitter(brightness=args.brightness, 88 | contrast=args.contrast, 89 | saturation=args.saturation, 90 | hue=args.hue), 91 | augmentations.DeviceAgnosticRandomResizedCrop([args.image_size, args.image_size], 92 | scale=[1-args.random_resized_crop, 1]), 93 | 
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 94 | ]) 95 | 96 | if args.use_amp16: 97 | scaler = torch.cuda.amp.GradScaler() 98 | 99 | for epoch_num in range(start_epoch_num, args.epochs_num): 100 | 101 | #### Train 102 | epoch_start_time = datetime.now() 103 | # Select classifier and dataloader according to epoch 104 | current_group_num = epoch_num % args.groups_num 105 | classifiers[current_group_num] = classifiers[current_group_num].to(args.device) 106 | util.move_to_device(classifiers_optimizers[current_group_num], args.device) 107 | 108 | dataloader = commons.InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers, 109 | batch_size=args.batch_size, shuffle=True, 110 | pin_memory=(args.device == "cuda"), drop_last=True) 111 | 112 | dataloader_iterator = iter(dataloader) 113 | model = model.train() 114 | 115 | epoch_losses = np.zeros((0, 1), dtype=np.float32) 116 | for iteration in tqdm(range(args.iterations_per_epoch), ncols=100): 117 | images, targets, _ = next(dataloader_iterator) 118 | images, targets = images.to(args.device), targets.to(args.device) 119 | 120 | if args.augmentation_device == "cuda": 121 | images = gpu_augmentation(images) 122 | 123 | model_optimizer.zero_grad() 124 | classifiers_optimizers[current_group_num].zero_grad() 125 | 126 | if not args.use_amp16: 127 | descriptors = model(images) 128 | output = classifiers[current_group_num](descriptors, targets) 129 | loss = criterion(output, targets) 130 | loss.backward() 131 | epoch_losses = np.append(epoch_losses, loss.item()) 132 | del loss, output, images 133 | model_optimizer.step() 134 | classifiers_optimizers[current_group_num].step() 135 | else: # Use AMP 16 136 | with torch.cuda.amp.autocast(): 137 | descriptors = model(images) 138 | output = classifiers[current_group_num](descriptors, targets) 139 | loss = criterion(output, targets) 140 | scaler.scale(loss).backward() 141 | epoch_losses = np.append(epoch_losses, loss.item()) 142 | del loss, output, images 143 | scaler.step(model_optimizer) 144 | scaler.step(classifiers_optimizers[current_group_num]) 145 | scaler.update() 146 | 147 | classifiers[current_group_num] = classifiers[current_group_num].cpu() 148 | util.move_to_device(classifiers_optimizers[current_group_num], "cpu") 149 | 150 | logging.debug(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, " 151 | f"loss = {epoch_losses.mean():.4f}") 152 | 153 | #### Evaluation 154 | recalls, recalls_str = test.test(args, val_ds, model) 155 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, {val_ds}: {recalls_str[:20]}") 156 | is_best = recalls[0] > best_val_recall1 157 | best_val_recall1 = max(recalls[0], best_val_recall1) 158 | # Save checkpoint, which contains all training parameters 159 | util.save_checkpoint({ 160 | "epoch_num": epoch_num + 1, 161 | "model_state_dict": model.state_dict(), 162 | "optimizer_state_dict": model_optimizer.state_dict(), 163 | "classifiers_state_dict": [c.state_dict() for c in classifiers], 164 | "optimizers_state_dict": [c.state_dict() for c in classifiers_optimizers], 165 | "best_val_recall1": best_val_recall1 166 | }, is_best, args.output_folder) 167 | 168 | 169 | logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}") 170 | 171 | #### Test best model on test set v1 172 | best_model_state_dict = torch.load(f"{args.output_folder}/best_model.pth") 173 | model.load_state_dict(best_model_state_dict) 174 | 175 | logging.info(f"Now 
testing on the test set: {test_ds}") 176 | recalls, recalls_str = test.test(args, test_ds, model, args.num_preds_to_save) 177 | logging.info(f"{test_ds}: {recalls_str}") 178 | 179 | logging.info("Experiment finished (without any errors)") 180 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import shutil 4 | import logging 5 | from typing import Type, List 6 | from argparse import Namespace 7 | from cosface_loss import MarginCosineProduct 8 | 9 | 10 | def move_to_device(optimizer: Type[torch.optim.Optimizer], device: str): 11 | for state in optimizer.state.values(): 12 | for k, v in state.items(): 13 | if torch.is_tensor(v): 14 | state[k] = v.to(device) 15 | 16 | 17 | def save_checkpoint(state: dict, is_best: bool, output_folder: str, 18 | ckpt_filename: str = "last_checkpoint.pth"): 19 | # TODO it would be better to move weights to cpu before saving 20 | checkpoint_path = f"{output_folder}/{ckpt_filename}" 21 | torch.save(state, checkpoint_path) 22 | if is_best: 23 | torch.save(state["model_state_dict"], f"{output_folder}/best_model.pth") 24 | 25 | 26 | def resume_train(args: Namespace, output_folder: str, model: torch.nn.Module, 27 | model_optimizer: Type[torch.optim.Optimizer], classifiers: List[MarginCosineProduct], 28 | classifiers_optimizers: List[Type[torch.optim.Optimizer]]): 29 | """Load model, optimizer, and other training parameters""" 30 | logging.info(f"Loading checkpoint: {args.resume_train}") 31 | checkpoint = torch.load(args.resume_train) 32 | start_epoch_num = checkpoint["epoch_num"] 33 | 34 | model_state_dict = checkpoint["model_state_dict"] 35 | model.load_state_dict(model_state_dict) 36 | 37 | model = model.to(args.device) 38 | model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 39 | 40 | assert args.groups_num == len(classifiers) == len(classifiers_optimizers) == \ 41 | len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \ 42 | (f"{args.groups_num}, {len(classifiers)}, {len(classifiers_optimizers)}, " 43 | f"{len(checkpoint['classifiers_state_dict'])}, {len(checkpoint['optimizers_state_dict'])}") 44 | 45 | for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]): 46 | # Move classifiers to GPU before loading their optimizers 47 | c = c.to(args.device) 48 | c.load_state_dict(sd) 49 | for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]): 50 | c.load_state_dict(sd) 51 | for c in classifiers: 52 | # Move classifiers back to CPU to save some GPU memory 53 | c = c.cpu() 54 | 55 | best_val_recall1 = checkpoint["best_val_recall1"] 56 | 57 | # Copy best model to current output_folder 58 | shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder) 59 | 60 | return model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num 61 | -------------------------------------------------------------------------------- /visualizations.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import numpy as np 5 | from tqdm import tqdm 6 | from skimage.transform import rescale 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | 10 | # Height and width of a single image 11 | H = 512 12 | W = 512 13 | TEXT_H = 175 14 | FONTSIZE = 80 15 | SPACE = 50 # Space between two images 16 | 17 | 18 | def 
write_labels_to_image(labels=["text1", "text2"]):
 19 |     """Creates a header image with one text label per image, spaced horizontally to align with the images below."""
 20 |     font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", FONTSIZE)
 21 |     img = Image.new('RGB', ((W * len(labels)) + 50 * (len(labels)-1), TEXT_H), (1, 1, 1))
 22 |     d = ImageDraw.Draw(img)
 23 |     for i, text in enumerate(labels):
 24 |         _, _, w, h = d.textbbox((0,0), text, font=font)
 25 |         d.text(((W+SPACE)*i + W//2 - w//2, 1), text, fill=(0, 0, 0), font=font)
 26 |     return np.array(img)
 27 | 
 28 | 
 29 | def draw(img, c=(0, 255, 0), thickness=20):
 30 |     """Draw a colored (usually red or green) box around an image."""
 31 |     p = np.array([[0, 0], [0, img.shape[0]], [img.shape[1], img.shape[0]], [img.shape[1], 0]])
 32 |     for i in range(3):
 33 |         cv2.line(img, (p[i, 0], p[i, 1]), (p[i+1, 0], p[i+1, 1]), c, thickness=thickness*2)
 34 |     return cv2.line(img, (p[3, 0], p[3, 1]), (p[0, 0], p[0, 1]), c, thickness=thickness*2)
 35 | 
 36 | 
 37 | def build_prediction_image(images_paths, preds_correct=None):
 38 |     """Build a row of images, where the first is the query and the rest are predictions.
 39 |     Each prediction gets a green box if it is correct and a red box otherwise; the query gets no box.
 40 |     """
 41 |     assert len(images_paths) == len(preds_correct)
 42 |     labels = ["Query"] + [f"Pr{i} - {is_correct}" for i, is_correct in enumerate(preds_correct[1:])]
 43 |     num_images = len(images_paths)
 44 |     images = [np.array(Image.open(path)) for path in images_paths]
 45 |     for img, correct in zip(images, preds_correct):
 46 |         if correct is None:
 47 |             continue
 48 |         color = (0, 255, 0) if correct else (255, 0, 0)
 49 |         draw(img, color)
 50 |     concat_image = np.ones([H, (num_images*W)+((num_images-1)*SPACE), 3])
 51 |     rescaleds = [rescale(i, [min(H/i.shape[0], W/i.shape[1]), min(H/i.shape[0], W/i.shape[1]), 1]) for i in images]
 52 |     for i, image in enumerate(rescaleds):
 53 |         pad_width = (W - image.shape[1] + 1) // 2
 54 |         pad_height = (H - image.shape[0] + 1) // 2
 55 |         image = np.pad(image, [[pad_height, pad_height], [pad_width, pad_width], [0, 0]], constant_values=1)[:H, :W]
 56 |         concat_image[: , i*(W+SPACE) : i*(W+SPACE)+W] = image
 57 |     try:
 58 |         labels_image = write_labels_to_image(labels)
 59 |         final_image = np.concatenate([labels_image, concat_image])
 60 |     except OSError:  # Handle error in case of missing PIL ImageFont
 61 |         final_image = concat_image
 62 |     final_image = Image.fromarray((final_image*255).astype(np.uint8))
 63 |     return final_image
 64 | 
 65 | 
 66 | def save_file_with_paths(query_path, preds_paths, positives_paths, output_path):
 67 |     file_content = []
 68 |     file_content.append("Query path:")
 69 |     file_content.append(query_path + "\n")
 70 |     file_content.append("Predictions paths:")
 71 |     file_content.append("\n".join(preds_paths) + "\n")
 72 |     file_content.append("Positives paths:")
 73 |     file_content.append("\n".join(positives_paths) + "\n")
 74 |     with open(output_path, "w") as file:
 75 |         _ = file.write("\n".join(file_content))
 76 | 
 77 | 
 78 | def save_preds(predictions, eval_ds, output_folder, save_only_wrong_preds=None):
 79 |     """For each query, save an image containing the query and its predictions,
 80 |     and a file with the paths of the query, its predictions and its positives.
 81 | 
 82 |     Parameters
 83 |     ----------
 84 |     predictions : np.array of shape [num_queries x num_preds_to_viz], with the preds
 85 |         for each query
 86 |     eval_ds : TestDataset
 87 |     output_folder : str / Path with the path to save the predictions
 88 |     save_only_wrong_preds : bool, if True save only the wrongly predicted queries,
 89 |         i.e. the ones where the first pred is incorrect (further than 25 m)
 90 |     """
 91 |     positives_per_query = eval_ds.get_positives()
 92 |     os.makedirs(f"{output_folder}/preds", exist_ok=True)
 93 |     for query_index, preds in enumerate(tqdm(predictions, ncols=80, desc=f"Saving preds in {output_folder}")):
 94 |         query_path = eval_ds.queries_paths[query_index]
 95 |         list_of_images_paths = [query_path]
 96 |         # List of None (query), True (correct preds) or False (wrong preds)
 97 |         preds_correct = [None]
 98 |         for pred_index, pred in enumerate(preds):
 99 |             pred_path = eval_ds.database_paths[pred]
100 |             list_of_images_paths.append(pred_path)
101 |             is_correct = pred in positives_per_query[query_index]
102 |             preds_correct.append(is_correct)
103 | 
104 |         if save_only_wrong_preds and preds_correct[1]:
105 |             continue
106 | 
107 |         prediction_image = build_prediction_image(list_of_images_paths, preds_correct)
108 |         pred_image_path = f"{output_folder}/preds/{query_index:03d}.jpg"
109 |         prediction_image.save(pred_image_path)
110 | 
111 |         positives_paths = [eval_ds.database_paths[idx] for idx in positives_per_query[query_index]]
112 |         save_file_with_paths(
113 |             query_path=list_of_images_paths[0],
114 |             preds_paths=list_of_images_paths[1:],
115 |             positives_paths=positives_paths,
116 |             output_path=f"{output_folder}/preds/{query_index:03d}.txt"
117 |         )
118 | 
119 | 
120 | 
--------------------------------------------------------------------------------
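Editor's note: the sketch below shows how the visualization helpers above are typically driven. It is a minimal, hypothetical usage example, not part of the repository: the dataset path, the threshold/image-size values, the import paths and the placeholder predictions array are illustrative assumptions. In this codebase the predictions normally come from test.test(args, test_ds, model, args.num_preds_to_save), as called at the end of train.py.

import numpy as np
from datasets.test_dataset import TestDataset
from visualizations import save_preds

# Hypothetical test split; the constructor keywords mirror the ones used in train.py,
# the concrete values (path, 25 m threshold, 512 px, no resizing) are assumptions.
eval_ds = TestDataset("path/to/sf_xl/processed/test", queries_folder="queries_v1",
                      positive_dist_threshold=25, image_size=512, resize_test_imgs=False)

# Placeholder predictions of shape [num_queries x 3]: database indices sorted from
# best to worst match. In practice these come from the retrieval step in test.test().
predictions = np.zeros((len(eval_ds.queries_paths), 3), dtype=int)

# For each query whose first prediction is wrong, writes preds/XXX.jpg (query plus
# three framed predictions) and preds/XXX.txt (query, prediction and positive paths).
save_preds(predictions, eval_ds, output_folder="logs/example", save_only_wrong_preds=True)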
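Similarly, a toy sketch of the tensor shapes in one training iteration may help when reading the loop in train.py: a batch of descriptors goes through the current group's MarginCosineProduct classifier, and the resulting cosine logits are scored with CrossEntropyLoss. The batch size, descriptor dimension and class count below are made-up numbers, and the random tensor stands in for model(images); the actual loop in train.py remains the reference.

import torch
import torch.nn.functional as F
from cosface_loss import MarginCosineProduct

B, D, C = 4, 512, 1000                                 # assumed batch size, descriptor dim, classes in the group
classifier = MarginCosineProduct(D, C)                 # one such classifier per CosPlace group, as in train.py
criterion = torch.nn.CrossEntropyLoss()

descriptors = F.normalize(torch.randn(B, D), dim=1)    # stand-in for model(images), shape [B, D]
targets = torch.randint(0, C, (B,))                    # class index of each image within the group

logits = classifier(descriptors, targets)              # [B, C] cosine logits, margin applied to the target class
loss = criterion(logits, targets)
loss.backward()                                        # the model and classifier optimizers would then step, as in the loop above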