├── .gitignore
├── LICENSE
├── README.md
├── augmentations.py
├── commons.py
├── cosface_loss.py
├── cosplace_model
│   ├── __init__.py
│   ├── cosplace_network.py
│   └── layers.py
├── datasets
│   ├── __init__.py
│   ├── dataset_utils.py
│   ├── test_dataset.py
│   └── train_dataset.py
├── eval.py
├── hubconf.py
├── parser.py
├── requirements.txt
├── test.py
├── train.py
├── util.py
└── visualizations.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .spyproject
2 | __pycache__
3 | logs
4 | cache
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Gabriele Berton, Carlo Masone, Barbara Caputo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Rethinking Visual Geo-localization for Large-Scale Applications
3 |
4 | Papers with Code leaderboards: [Pittsburgh-250k](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-250k?p=rethinking-visual-geo-localization-for-large) · [Pittsburgh-30k](https://paperswithcode.com/sota/visual-place-recognition-on-pittsburgh-30k?p=rethinking-visual-geo-localization-for-large)
5 | · [Tokyo 24/7](https://paperswithcode.com/sota/visual-place-recognition-on-tokyo247?p=rethinking-visual-geo-localization-for-large)
6 | · [Mapillary val](https://paperswithcode.com/sota/visual-place-recognition-on-mapillary-val?p=rethinking-visual-geo-localization-for-large)
7 | · [St Lucia](https://paperswithcode.com/sota/visual-place-recognition-on-st-lucia?p=rethinking-visual-geo-localization-for-large)
8 | · [SF-XL test v1](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v1?p=rethinking-visual-geo-localization-for-large)
9 | · [SF-XL test v2](https://paperswithcode.com/sota/visual-place-recognition-on-sf-xl-test-v2?p=rethinking-visual-geo-localization-for-large)
10 |
11 | This is the official PyTorch implementation of the CVPR 2022 paper "Rethinking Visual Geo-localization for Large-Scale Applications".
12 | The paper presents a new dataset called San Francisco eXtra Large (SF-XL, go [_here_](https://forms.gle/wpyDzhDyoWLQygAT9) to download it) and a highly scalable training method (called CosPlace), which makes it possible to reach SOTA results with compact descriptors.
13 |
14 |
15 | [[CVPR OpenAccess](https://openaccess.thecvf.com/content/CVPR2022/html/Berton_Rethinking_Visual_Geo-Localization_for_Large-Scale_Applications_CVPR_2022_paper.html)] [[ArXiv](https://arxiv.org/abs/2204.02287)] [[Video](https://www.youtube.com/watch?v=oDyL6oVNN3I)] [[BibTex](https://github.com/gmberton/CosPlace?tab=readme-ov-file#cite)]
16 |
17 |
18 |
19 | The images below represent respectively:
20 | 1) the map of San Francisco eXtra Large
21 | 2) a visualization of how CosPlace Groups (read datasets) are formed
22 | 3) results with CosPlace vs other methods on Pitts250k (CosPlace trained on SF-XL, others on Pitts30k)
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 | ## Train
32 | After downloading the SF-XL dataset, simply run
33 |
34 | `$ python3 train.py --train_set_folder path/to/sf_xl/raw/train/database --val_set_folder path/to/sf_xl/processed/val --test_set_folder path/to/sf_xl/processed/test`
35 |
36 | The script automatically splits SF-XL into CosPlace Groups, and saves the resulting object in the folder `cache`.
37 | By default, training is performed with a ResNet-18 with descriptor dimensionality 512, which fits in less than 4 GB of VRAM.
38 |
39 | To change the backbone or the output descriptor dimensionality, simply run
40 |
41 | `$ python3 train.py --backbone ResNet50 --fc_output_dim 128`
42 |
43 | You can also speed up your training with Automatic Mixed Precision (note that none of the results/statistics from the paper used AMP)
44 |
45 | `$ python3 train.py --use_amp16`
46 |
47 | Run `$ python3 train.py -h` to see all the hyperparameters that you can change; you will find all the hyperparameters mentioned in the paper.
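For reference, here is what a more explicit (and purely illustrative) command looks like, spelling out the CosPlace Groups parameters and a few training hyperparameters with their default values from `parser.py` (the dataset paths are placeholders):

```
$ python3 train.py --train_set_folder path/to/sf_xl/raw/train/database \
    --val_set_folder path/to/sf_xl/processed/val --test_set_folder path/to/sf_xl/processed/test \
    --M 10 --alpha 30 --N 5 --L 2 --groups_num 8 \
    --batch_size 32 --lr 0.00001 --classifiers_lr 0.01 --epochs_num 50 --iterations_per_epoch 10000
```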
48 |
49 | #### Dataset size and lightweight version
50 |
51 | The SF-XL dataset is about 1 TB.
52 | For training, only a subset of the images is used, and this training subset is only 360 GB.
53 | If this is still too heavy for you (e.g. if you're using Colab), but you would like to run CosPlace, we also created a small version of SF-XL, which is only 5 GB.
54 | Obviously, using the small version will lead to lower results, and it should be used only for debugging / exploration purposes.
55 | More information on the dataset and its lightweight version is in the README that you can find on the dataset download page (go [_here_](https://forms.gle/wpyDzhDyoWLQygAT9) to find it).
56 |
57 | #### Reproducibility
58 | Results from the paper are fully reproducible, and we followed deep learning's best practices (averaging over multiple runs for the main results, validation/early stopping, and hyperparameter search on the val set).
59 | If you are a researcher comparing your work against ours, please make sure to follow these best practices and avoid picking the best model on the test set.
60 |
61 |
62 | ## Test
63 | You can test a trained model as follows
64 |
65 | `$ python3 eval.py --backbone ResNet50 --fc_output_dim 128 --resume_model path/to/best_model.pth`
66 |
67 | You can download plenty of trained models below.
68 |
69 |
70 | ### Visualize predictions
71 |
72 | Predictions can be easily visualized through the `num_preds_to_save` parameter. For example, running this
73 |
74 | ```
75 | python3 eval.py --backbone ResNet50 --fc_output_dim 512 --resume_model path/to/best_model.pth \
76 |     --num_preds_to_save=3 --save_dir=cosplace_on_stlucia
77 | ```
78 | will save images of the predictions for each query under the path `./logs/cosplace_on_stlucia/*/preds`
79 |
80 |
81 |
82 |
83 |
84 | Given that saving predictions for each query might take a long time, you can also pass the parameter `--save_only_wrong_preds`, which saves predictions only for wrongly predicted queries (i.e. where the first prediction is wrong); these are usually the most interesting failure cases.
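For example, the following command (with a placeholder model path) saves the top-3 predictions only for the failure cases:

```
python3 eval.py --backbone ResNet50 --fc_output_dim 512 --resume_model path/to/best_model.pth \
        --num_preds_to_save=3 --save_only_wrong_preds
```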
85 |
86 |
87 | ## Trained Models
88 |
89 | We now have all our trained models on [PyTorch Hub](https://pytorch.org/docs/stable/hub.html), so that you can use them in any codebase without cloning this repository, simply like this:
90 | ```
91 | import torch
92 | model = torch.hub.load("gmberton/cosplace", "get_trained_model", backbone="ResNet50", fc_output_dim=2048)
93 | ```
94 |
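Once loaded, the model maps a batch of images to L2-normalized descriptors of size `fc_output_dim`. Below is a minimal inference sketch; the image path and the resize size are illustrative placeholders, not values required by the repository:

```
import torch
import torchvision.transforms as T
from PIL import Image

model = torch.hub.load("gmberton/cosplace", "get_trained_model", backbone="ResNet50", fc_output_dim=2048)
model = model.eval()

# Standard ImageNet normalization, as used by the test dataloader in this repo
transform = T.Compose([
    T.Resize(512, antialias=True),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = transform(Image.open("my_query.jpg").convert("RGB")).unsqueeze(0)  # shape [1, 3, H, W]
with torch.no_grad():
    descriptor = model(image)  # shape [1, 2048], L2-normalized
```
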
95 | As an alternative, you can download the trained models from the table below, which provides links to models with different backbones and descriptor dimensionalities, all trained on SF-XL.
96 |
97 |
98 |
99 | | Model / Descriptor dim. | 32 | 64 | 128 | 256 | 512 | 1024 | 2048 |
100 | |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
101 | | ResNet-18 | link | link | link | link | link | - | - |
102 | | ResNet-50 | link | link | link | link | link | link | link |
103 | | ResNet-101 | link | link | link | link | link | link | link |
104 | | ResNet-152 | link | link | link | link | link | link | link |
105 | | VGG-16 | - | link | link | link | link | - | - |
161 |
162 |
163 | Or you can download all models at once from [this link](https://drive.google.com/drive/folders/1WzSLnv05FLm-XqP5DxR5nXaaixH23uvV?usp=sharing).
164 |
165 | ## Issues
166 | If you have any questions regarding our code or dataset, feel free to open an issue or send an email to berton.gabri@gmail.com
167 |
168 | ## Acknowledgements
169 | Parts of this repo are inspired by the following repositories:
170 | - [CosFace implementation in PyTorch](https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py)
171 | - [CNN Image Retrieval in PyTorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) (for the GeM layer)
172 | - [Visual Geo-localization benchmark](https://github.com/gmberton/deep-visual-geo-localization-benchmark) (for the evaluation / test code)
173 |
174 | ## Cite
175 | Here is the BibTeX to cite our paper:
176 | ```
177 | @InProceedings{Berton_CVPR_2022_CosPlace,
178 | author = {Berton, Gabriele and Masone, Carlo and Caputo, Barbara},
179 | title = {Rethinking Visual Geo-Localization for Large-Scale Applications},
180 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
181 | month = {June},
182 | year = {2022},
183 | pages = {4878-4888}
184 | }
185 | ```
186 |
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from typing import Tuple, Union
4 | import torchvision.transforms as T
5 |
6 |
7 | class DeviceAgnosticColorJitter(T.ColorJitter):
8 | def __init__(self, brightness: float = 0., contrast: float = 0., saturation: float = 0., hue: float = 0.):
9 | """This is the same as T.ColorJitter but it only accepts batches of images and works on GPU"""
10 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
11 |
12 | def forward(self, images: torch.Tensor) -> torch.Tensor:
13 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}"
14 | B, C, H, W = images.shape
15 | # Applies a different color jitter to each image
16 | color_jitter = super(DeviceAgnosticColorJitter, self).forward
17 | augmented_images = [color_jitter(img).unsqueeze(0) for img in images]
18 | augmented_images = torch.cat(augmented_images)
19 | assert augmented_images.shape == torch.Size([B, C, H, W])
20 | return augmented_images
21 |
22 |
23 | class DeviceAgnosticRandomResizedCrop(T.RandomResizedCrop):
24 |     def __init__(self, size: Union[int, Tuple[int, int]], scale: Tuple[float, float]):
25 | """This is the same as T.RandomResizedCrop but it only accepts batches of images and works on GPU"""
26 | super().__init__(size=size, scale=scale, antialias=True)
27 |
28 | def forward(self, images: torch.Tensor) -> torch.Tensor:
29 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}"
30 | B, C, H, W = images.shape
31 | # Applies a different RandomResizedCrop to each image
32 | random_resized_crop = super(DeviceAgnosticRandomResizedCrop, self).forward
33 | augmented_images = [random_resized_crop(img).unsqueeze(0) for img in images]
34 | augmented_images = torch.cat(augmented_images)
35 | return augmented_images
36 |
37 |
38 | if __name__ == "__main__":
39 | """
40 | You can run this script to visualize the transformations, and verify that
41 | the augmentations are applied individually on each image of the batch.
42 | """
43 | from PIL import Image
44 | # Import skimage in here, so it is not necessary to install it unless you run this script
45 | from skimage import data
46 |
47 | # Initialize DeviceAgnosticRandomResizedCrop
48 | random_crop = DeviceAgnosticRandomResizedCrop(size=[256, 256], scale=[0.5, 1])
49 | # Create a batch with 2 astronaut images
50 | pil_image = Image.fromarray(data.astronaut())
51 | tensor_image = T.functional.to_tensor(pil_image).unsqueeze(0)
52 | images_batch = torch.cat([tensor_image, tensor_image])
53 | # Apply augmentation (individually on each of the 2 images)
54 | augmented_batch = random_crop(images_batch)
55 | # Convert to PIL images
56 | augmented_image_0 = T.functional.to_pil_image(augmented_batch[0])
57 | augmented_image_1 = T.functional.to_pil_image(augmented_batch[1])
58 | # Visualize the original image, as well as the two augmented ones
59 | pil_image.show()
60 | augmented_image_0.show()
61 | augmented_image_1.show()
62 |
--------------------------------------------------------------------------------
/commons.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import torch
5 | import random
6 | import logging
7 | import traceback
8 | import numpy as np
9 |
10 |
11 | class InfiniteDataLoader(torch.utils.data.DataLoader):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 | self.dataset_iterator = super().__iter__()
15 |
16 | def __iter__(self):
17 | return self
18 |
19 | def __next__(self):
20 | try:
21 | batch = next(self.dataset_iterator)
22 | except StopIteration:
23 | self.dataset_iterator = super().__iter__()
24 | batch = next(self.dataset_iterator)
25 | return batch
26 |
27 |
28 | def make_deterministic(seed: int = 0):
29 | """Make results deterministic. If seed == -1, do not make deterministic.
30 | Running your script in a deterministic way might slow it down.
31 | Note that for some packages (eg: sklearn's PCA) this function is not enough.
32 | """
33 | seed = int(seed)
34 | if seed == -1:
35 | return
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed)
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 |
43 |
44 | def setup_logging(output_folder: str, exist_ok: bool = False, console: str = "debug",
45 | info_filename: str = "info.log", debug_filename: str = "debug.log"):
46 | """Set up logging files and console output.
47 | Creates one file for INFO logs and one for DEBUG logs.
48 | Args:
49 |         output_folder (str): the folder where the log files will be saved (it will be created).
50 | exist_ok (boolean): if False throw a FileExistsError if output_folder already exists
51 |         console (str):
52 | if == "debug" prints on console debug messages and higher
53 | if == "info" prints on console info messages and higher
54 | if == None does not use console (useful when a logger has already been set)
55 | info_filename (str): the name of the info file. if None, don't create info file
56 | debug_filename (str): the name of the debug file. if None, don't create debug file
57 | """
58 | if not exist_ok and os.path.exists(output_folder):
59 | raise FileExistsError(f"{output_folder} already exists!")
60 | os.makedirs(output_folder, exist_ok=True)
61 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
62 | logger = logging.getLogger('')
63 | logger.setLevel(logging.DEBUG)
64 |
65 | if info_filename is not None:
66 | info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}')
67 | info_file_handler.setLevel(logging.INFO)
68 | info_file_handler.setFormatter(base_formatter)
69 | logger.addHandler(info_file_handler)
70 |
71 | if debug_filename is not None:
72 | debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}')
73 | debug_file_handler.setLevel(logging.DEBUG)
74 | debug_file_handler.setFormatter(base_formatter)
75 | logger.addHandler(debug_file_handler)
76 |
77 | if console is not None:
78 | console_handler = logging.StreamHandler()
79 | if console == "debug":
80 | console_handler.setLevel(logging.DEBUG)
81 | if console == "info":
82 | console_handler.setLevel(logging.INFO)
83 | console_handler.setFormatter(base_formatter)
84 | logger.addHandler(console_handler)
85 |
86 | def my_handler(type_, value, tb):
87 |         logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
88 | logging.info("Experiment finished (with some errors)")
89 | sys.excepthook = my_handler
90 |
--------------------------------------------------------------------------------
/cosface_loss.py:
--------------------------------------------------------------------------------
1 |
2 | # Based on https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import Parameter
7 |
8 |
9 | def cosine_sim(x1: torch.Tensor, x2: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
10 | ip = torch.mm(x1, x2.t())
11 | w1 = torch.norm(x1, 2, dim)
12 | w2 = torch.norm(x2, 2, dim)
13 | return ip / torch.ger(w1, w2).clamp(min=eps)
14 |
15 |
16 | class MarginCosineProduct(nn.Module):
17 | """Implement of large margin cosine distance:
18 | Args:
19 | in_features: size of each input sample
20 | out_features: size of each output sample
21 | s: norm of input feature
22 | m: margin
23 | """
24 | def __init__(self, in_features: int, out_features: int, s: float = 30.0, m: float = 0.40):
25 | super().__init__()
26 | self.in_features = in_features
27 | self.out_features = out_features
28 | self.s = s
29 | self.m = m
30 | self.weight = Parameter(torch.Tensor(out_features, in_features))
31 | nn.init.xavier_uniform_(self.weight)
32 |
33 | def forward(self, inputs: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
34 | cosine = cosine_sim(inputs, self.weight)
35 | one_hot = torch.zeros_like(cosine)
36 | one_hot.scatter_(1, label.view(-1, 1), 1.0)
37 | output = self.s * (cosine - one_hot * self.m)
38 | return output
39 |
40 | def __repr__(self):
41 | return self.__class__.__name__ + '(' \
42 | + 'in_features=' + str(self.in_features) \
43 | + ', out_features=' + str(self.out_features) \
44 | + ', s=' + str(self.s) \
45 | + ', m=' + str(self.m) + ')'
46 |
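if __name__ == "__main__":
    # Minimal usage sketch (descriptor dimension and number of classes are illustrative values):
    # the layer takes descriptors and their class labels, and its output is fed to a
    # standard CrossEntropyLoss, as done in train.py.
    margin_cos = MarginCosineProduct(in_features=512, out_features=100)
    descriptors = torch.nn.functional.normalize(torch.randn(4, 512), dim=1)
    labels = torch.randint(0, 100, (4,))
    logits = margin_cos(descriptors, labels)  # shape [4, 100]
    loss = torch.nn.CrossEntropyLoss()(logits, labels)
    print(margin_cos, loss.item())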
--------------------------------------------------------------------------------
/cosplace_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gmberton/CosPlace/eb8ab8f956da3e0c90d0a1acc74b689506093b6b/cosplace_model/__init__.py
--------------------------------------------------------------------------------
/cosplace_model/cosplace_network.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import logging
4 | import torchvision
5 | from torch import nn
6 | from typing import Tuple
7 |
8 | from cosplace_model.layers import Flatten, L2Norm, GeM
9 |
10 | # The number of channels in the last convolutional layer, the one before average pooling
11 | CHANNELS_NUM_IN_LAST_CONV = {
12 | "ResNet18": 512,
13 | "ResNet50": 2048,
14 | "ResNet101": 2048,
15 | "ResNet152": 2048,
16 | "VGG16": 512,
17 | "EfficientNet_B0": 1280,
18 | "EfficientNet_B1": 1280,
19 | "EfficientNet_B2": 1408,
20 | "EfficientNet_B3": 1536,
21 | "EfficientNet_B4": 1792,
22 | "EfficientNet_B5": 2048,
23 | "EfficientNet_B6": 2304,
24 | "EfficientNet_B7": 2560,
25 | }
26 |
27 |
28 | class GeoLocalizationNet(nn.Module):
29 | def __init__(self, backbone : str, fc_output_dim : int, train_all_layers : bool = False):
30 | """Return a model for GeoLocalization.
31 |
32 | Args:
33 |             backbone (str): which torchvision backbone to use. Must be VGG16, a ResNet or an EfficientNet.
34 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to the descriptors dimension.
35 |             train_all_layers (bool): if True, train all layers of the backbone; otherwise freeze its early layers.
36 | """
37 | super().__init__()
38 | assert backbone in CHANNELS_NUM_IN_LAST_CONV, f"backbone must be one of {list(CHANNELS_NUM_IN_LAST_CONV.keys())}"
39 | self.backbone, features_dim = get_backbone(backbone, train_all_layers)
40 | self.aggregation = nn.Sequential(
41 | L2Norm(),
42 | GeM(),
43 | Flatten(),
44 | nn.Linear(features_dim, fc_output_dim),
45 | L2Norm()
46 | )
47 |
48 | def forward(self, x):
49 | x = self.backbone(x)
50 | x = self.aggregation(x)
51 | return x
52 |
53 |
54 | def get_pretrained_torchvision_model(backbone_name : str) -> torch.nn.Module:
55 | """This function takes the name of a backbone and returns the corresponding pretrained
56 | model from torchvision. Examples of backbone_name are 'VGG16' or 'ResNet18'
57 | """
58 | try: # Newer versions of pytorch require to pass weights=weights_module.DEFAULT
59 | weights_module = getattr(__import__('torchvision.models', fromlist=[f"{backbone_name}_Weights"]), f"{backbone_name}_Weights")
60 | model = getattr(torchvision.models, backbone_name.lower())(weights=weights_module.DEFAULT)
61 | except (ImportError, AttributeError): # Older versions of pytorch require to pass pretrained=True
62 | model = getattr(torchvision.models, backbone_name.lower())(pretrained=True)
63 | return model
64 |
65 |
66 | def get_backbone(backbone_name : str, train_all_layers : bool) -> Tuple[torch.nn.Module, int]:
67 | backbone = get_pretrained_torchvision_model(backbone_name)
68 | if backbone_name.startswith("ResNet"):
69 | if train_all_layers:
70 | logging.debug(f"Train all layers of the {backbone_name}")
71 | else:
72 | for name, child in backbone.named_children():
73 | if name == "layer3": # Freeze layers before conv_3
74 | break
75 | for params in child.parameters():
76 | params.requires_grad = False
77 | logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones")
78 |
79 | layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer
80 |
81 | elif backbone_name == "VGG16":
82 |         layers = list(backbone.features.children())[:-2]  # Remove the final ReLU and MaxPool
83 | if train_all_layers:
84 | logging.debug("Train all layers of the VGG-16")
85 | else:
86 | for layer in layers[:-5]:
87 | for p in layer.parameters():
88 | p.requires_grad = False
89 | logging.debug("Train last layers of the VGG-16, freeze the previous ones")
90 |
91 | elif backbone_name.startswith("EfficientNet"):
92 | if train_all_layers:
93 | logging.debug(f"Train all layers of the {backbone_name}")
94 | else:
95 | for name, child in backbone.features.named_children():
96 | if name == "5": # Freeze layers before block 5
97 | break
98 | for params in child.parameters():
99 | params.requires_grad = False
100 | logging.debug(f"Train only the last three blocks of the {backbone_name}, freeze the previous ones")
101 | layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer
102 |
103 | backbone = torch.nn.Sequential(*layers)
104 | features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name]
105 |
106 | return backbone, features_dim
107 |
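if __name__ == "__main__":
    # Quick sanity-check sketch (backbone, output dimension and input size are illustrative values):
    # the network maps a batch of images to L2-normalized descriptors of size fc_output_dim.
    model = GeoLocalizationNet(backbone="ResNet18", fc_output_dim=512)
    with torch.no_grad():
        descriptors = model(torch.randn(2, 3, 512, 512))
    print(descriptors.shape)        # torch.Size([2, 512])
    print(descriptors.norm(dim=1))  # ~1 for each descriptor, thanks to the final L2Norm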
--------------------------------------------------------------------------------
/cosplace_model/layers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.parameter import Parameter
6 |
7 |
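# Generalized Mean (GeM) pooling: for each channel, compute (mean over spatial locations of x^p)^(1/p).
# With p = 1 this reduces to average pooling, while large p approaches max pooling;
# in the GeM module below, p is a learnable parameter initialized to 3.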
8 | def gem(x, p=torch.ones(1)*3, eps: float = 1e-6):
9 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
10 |
11 |
12 | class GeM(nn.Module):
13 | def __init__(self, p=3, eps=1e-6):
14 | super().__init__()
15 | self.p = Parameter(torch.ones(1)*p)
16 | self.eps = eps
17 |
18 | def forward(self, x):
19 | return gem(x, p=self.p, eps=self.eps)
20 |
21 | def __repr__(self):
22 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})"
23 |
24 |
25 | class Flatten(torch.nn.Module):
26 | def __init__(self):
27 | super().__init__()
28 |
29 | def forward(self, x):
30 | assert x.shape[2] == x.shape[3] == 1, f"{x.shape[2]} != {x.shape[3]} != 1"
31 | return x[:, :, 0, 0]
32 |
33 |
34 | class L2Norm(nn.Module):
35 | def __init__(self, dim=1):
36 | super().__init__()
37 | self.dim = dim
38 |
39 | def forward(self, x):
40 | return F.normalize(x, p=2.0, dim=self.dim)
41 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gmberton/CosPlace/eb8ab8f956da3e0c90d0a1acc74b689506093b6b/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 | from glob import glob
5 | from PIL import ImageFile
6 |
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 |
10 | def read_images_paths(dataset_folder, get_abs_path=False):
11 | """Find images within 'dataset_folder' and return their relative paths as a list.
12 | If there is a file 'dataset_folder'_images_paths.txt, read paths from such file.
13 | Otherwise, use glob(). Keeping the paths in the file speeds up computation,
14 | because using glob over large folders can be slow.
15 |
16 | Parameters
17 | ----------
18 | dataset_folder : str, folder containing JPEG images
19 | get_abs_path : bool, if True return absolute paths, otherwise remove
20 | dataset_folder from each path
21 |
22 | Returns
23 | -------
24 | images_paths : list[str], paths of JPEG images within dataset_folder
25 | """
26 |
27 | if not os.path.exists(dataset_folder):
28 | raise FileNotFoundError(f"Folder {dataset_folder} does not exist")
29 |
30 | file_with_paths = dataset_folder + "_images_paths.txt"
31 | if os.path.exists(file_with_paths):
32 | logging.debug(f"Reading paths of images within {dataset_folder} from {file_with_paths}")
33 | with open(file_with_paths, "r") as file:
34 | images_paths = file.read().splitlines()
35 | images_paths = [os.path.join(dataset_folder, path) for path in images_paths]
36 | # Sanity check that paths within the file exist
37 | if not os.path.exists(images_paths[0]):
38 | raise FileNotFoundError(f"Image with path {images_paths[0]} "
39 | f"does not exist within {dataset_folder}. It is likely "
40 | f"that the content of {file_with_paths} is wrong.")
41 | else:
42 | logging.debug(f"Searching images in {dataset_folder} with glob()")
43 | images_paths = sorted(glob(f"{dataset_folder}/**/*.jpg", recursive=True))
44 | if len(images_paths) == 0:
45 | raise FileNotFoundError(f"Directory {dataset_folder} does not contain any JPEG images")
46 |
47 | if not get_abs_path: # Remove dataset_folder from the path
48 | images_paths = [p[len(dataset_folder) + 1:] for p in images_paths]
49 |
50 | return images_paths
51 |
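if __name__ == "__main__":
    # Sketch of how one could create the optional "<dataset_folder>_images_paths.txt" file
    # (one path per line, relative to dataset_folder), so that later runs can skip the slow glob().
    import sys
    folder = sys.argv[1].rstrip("/")
    paths = sorted(glob(f"{folder}/**/*.jpg", recursive=True))
    with open(folder + "_images_paths.txt", "w") as file:
        file.write("\n".join(p[len(folder) + 1:] for p in paths))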
52 |
--------------------------------------------------------------------------------
/datasets/test_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | from sklearn.neighbors import NearestNeighbors
8 |
9 | import datasets.dataset_utils as dataset_utils
10 |
11 |
12 | class TestDataset(data.Dataset):
13 | def __init__(self, dataset_folder, database_folder="database",
14 | queries_folder="queries", positive_dist_threshold=25,
15 | image_size=512, resize_test_imgs=False):
16 | self.database_folder = dataset_folder + "/" + database_folder
17 | self.queries_folder = dataset_folder + "/" + queries_folder
18 | self.database_paths = dataset_utils.read_images_paths(self.database_folder, get_abs_path=True)
19 | self.queries_paths = dataset_utils.read_images_paths(self.queries_folder, get_abs_path=True)
20 |
21 | self.dataset_name = os.path.basename(dataset_folder)
22 |
23 | #### Read paths and UTM coordinates for all images.
24 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg
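        # e.g. a file named "@0547231.46@4184059.12@...@.jpg" (illustrative values) yields
        # UTM easting 547231.46 and UTM northing 4184059.12 (field 0 is the part of the path before the first "@")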
25 | self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float)
26 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float)
27 |
28 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters)
29 | knn = NearestNeighbors(n_jobs=-1)
30 | knn.fit(self.database_utms)
31 | self.positives_per_query = knn.radius_neighbors(
32 | self.queries_utms, radius=positive_dist_threshold, return_distance=False
33 | )
34 |
35 | self.images_paths = self.database_paths + self.queries_paths
36 |
37 | self.database_num = len(self.database_paths)
38 | self.queries_num = len(self.queries_paths)
39 |
40 | transforms_list = []
41 | if resize_test_imgs:
42 | # Resize to image_size along the shorter side while maintaining aspect ratio
43 | transforms_list += [transforms.Resize(image_size, antialias=True)]
44 | transforms_list += [
45 | transforms.ToTensor(),
46 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47 | ]
48 | self.base_transform = transforms.Compose(transforms_list)
49 |
50 | @staticmethod
51 | def open_image(path):
52 | return Image.open(path).convert("RGB")
53 |
54 | def __getitem__(self, index):
55 | image_path = self.images_paths[index]
56 | pil_img = TestDataset.open_image(image_path)
57 | normalized_img = self.base_transform(pil_img)
58 | return normalized_img, index
59 |
60 | def __len__(self):
61 | return len(self.images_paths)
62 |
63 | def __repr__(self):
64 | return f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >"
65 |
66 | def get_positives(self):
67 | return self.positives_per_query
68 |
--------------------------------------------------------------------------------
/datasets/train_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import random
5 | import logging
6 | import numpy as np
7 | from PIL import Image
8 | from PIL import ImageFile
9 | import torchvision.transforms as T
10 | from collections import defaultdict
11 |
12 | import datasets.dataset_utils as dataset_utils
13 |
14 |
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 |
17 |
18 | class TrainDataset(torch.utils.data.Dataset):
19 | def __init__(self, args, dataset_folder, M=10, alpha=30, N=5, L=2,
20 | current_group=0, min_images_per_class=10):
21 | """
22 | Parameters (please check our paper for a clearer explanation of the parameters).
23 | ----------
24 | args : args for data augmentation
25 | dataset_folder : str, the path of the folder with the train images.
26 | M : int, the length of the side of each cell in meters.
27 | alpha : int, size of each class in degrees.
28 | N : int, distance (M-wise) between two classes of the same group.
29 | L : int, distance (alpha-wise) between two classes of the same group.
30 | current_group : int, which one of the groups to consider.
31 |         min_images_per_class : int, minimum number of images per class.
32 | """
33 | super().__init__()
34 | self.M = M
35 | self.alpha = alpha
36 | self.N = N
37 | self.L = L
38 | self.current_group = current_group
39 | self.dataset_folder = dataset_folder
40 | self.augmentation_device = args.augmentation_device
41 |
42 | # dataset_name should be either "processed", "small" or "raw", if you're using SF-XL
43 | dataset_name = os.path.basename(dataset_folder)
44 | filename = f"cache/{dataset_name}_M{M}_N{N}_alpha{alpha}_L{L}_mipc{min_images_per_class}.torch"
45 | if not os.path.exists(filename):
46 | os.makedirs("cache", exist_ok=True)
47 | logging.info(f"Cached dataset {filename} does not exist, I'll create it now.")
48 | self.initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename)
49 | elif current_group == 0:
50 | logging.info(f"Using cached dataset {filename}")
51 |
52 | classes_per_group, self.images_per_class = torch.load(filename)
53 | if current_group >= len(classes_per_group):
54 | raise ValueError(f"With this configuration there are only {len(classes_per_group)} " +
55 | f"groups, therefore I can't create the {current_group}th group. " +
56 | "You should reduce the number of groups by setting for example " +
57 | f"'--groups_num {current_group}'")
58 | self.classes_ids = classes_per_group[current_group]
59 |
60 | if self.augmentation_device == "cpu":
61 | self.transform = T.Compose([
62 | T.ColorJitter(brightness=args.brightness,
63 | contrast=args.contrast,
64 | saturation=args.saturation,
65 | hue=args.hue),
66 | T.RandomResizedCrop([args.image_size, args.image_size], scale=[1-args.random_resized_crop, 1], antialias=True),
67 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
68 | ])
69 |
70 | @staticmethod
71 | def open_image(path):
72 | return Image.open(path).convert("RGB")
73 |
74 | def __getitem__(self, class_num):
75 | # This function takes as input the class_num instead of the index of
76 | # the image. This way each class is equally represented during training.
77 |
78 | class_id = self.classes_ids[class_num]
79 | # Pick a random image among those in this class.
80 | image_path = os.path.join(self.dataset_folder, random.choice(self.images_per_class[class_id]))
81 |
82 | try:
83 | pil_image = TrainDataset.open_image(image_path)
84 | except Exception as e:
85 | logging.info(f"ERROR image {image_path} couldn't be opened, it might be corrupted.")
86 | raise e
87 |
88 | tensor_image = T.functional.to_tensor(pil_image)
89 | assert tensor_image.shape == torch.Size([3, 512, 512]), \
90 | f"Image {image_path} should have shape [3, 512, 512] but has {tensor_image.shape}."
91 |
92 | if self.augmentation_device == "cpu":
93 | tensor_image = self.transform(tensor_image)
94 |
95 | return tensor_image, class_num, image_path
96 |
97 | def get_images_num(self):
98 | """Return the number of images within this group."""
99 | return sum([len(self.images_per_class[c]) for c in self.classes_ids])
100 |
101 | def __len__(self):
102 | """Return the number of classes within this group."""
103 | return len(self.classes_ids)
104 |
105 | @staticmethod
106 | def initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename):
107 | logging.debug(f"Searching training images in {dataset_folder}")
108 |
109 | images_paths = dataset_utils.read_images_paths(dataset_folder)
110 | logging.debug(f"Found {len(images_paths)} images")
111 |
112 | logging.debug("For each image, get its UTM east, UTM north and heading from its path")
113 | images_metadatas = [p.split("@") for p in images_paths]
114 | # field 1 is UTM east, field 2 is UTM north, field 9 is heading
115 | utmeast_utmnorth_heading = [(m[1], m[2], m[9]) for m in images_metadatas]
116 | utmeast_utmnorth_heading = np.array(utmeast_utmnorth_heading).astype(np.float64)
117 |
118 | logging.debug("For each image, get class and group to which it belongs")
119 | class_id__group_id = [TrainDataset.get__class_id__group_id(*m, M, alpha, N, L)
120 | for m in utmeast_utmnorth_heading]
121 |
122 | logging.debug("Group together images belonging to the same class")
123 | images_per_class = defaultdict(list)
124 | for image_path, (class_id, _) in zip(images_paths, class_id__group_id):
125 | images_per_class[class_id].append(image_path)
126 |
127 | # Images_per_class is a dict where the key is class_id, and the value
128 | # is a list with the paths of images within that class.
129 | images_per_class = {k: v for k, v in images_per_class.items() if len(v) >= min_images_per_class}
130 |
131 | logging.debug("Group together classes belonging to the same group")
132 | # Classes_per_group is a dict where the key is group_id, and the value
133 | # is a list with the class_ids belonging to that group.
134 | classes_per_group = defaultdict(set)
135 | for class_id, group_id in class_id__group_id:
136 | if class_id not in images_per_class:
137 | continue # Skip classes with too few images
138 | classes_per_group[group_id].add(class_id)
139 |
140 | # Convert classes_per_group to a list of lists.
141 | # Each sublist represents the classes within a group.
142 | classes_per_group = [list(c) for c in classes_per_group.values()]
143 |
144 | torch.save((classes_per_group, images_per_class), filename)
145 |
146 | @staticmethod
147 | def get__class_id__group_id(utm_east, utm_north, heading, M, alpha, N, L):
148 | """Return class_id and group_id for a given point.
149 | The class_id is a triplet (tuple) of UTM_east, UTM_north and
150 |         heading (e.g. (396520, 4983800, 120)).
151 | The group_id represents the group to which the class belongs
152 | (e.g. (0, 1, 0)), and it is between (0, 0, 0) and (N, N, L).
153 | """
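        # Worked example with the defaults M=10, alpha=30, N=5, L=2: a point with
        # utm_east=396526.7, utm_north=4983807.3 and heading=127 gets
        # class_id = (396520, 4983800, 120) and group_id = (2, 0, 0),
        # since 396520 % (10*5) // 10 = 2, 4983800 % 50 // 10 = 0 and 120 % (30*2) // 30 = 0.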
154 | rounded_utm_east = int(utm_east // M * M) # Rounded to nearest lower multiple of M
155 | rounded_utm_north = int(utm_north // M * M)
156 | rounded_heading = int(heading // alpha * alpha)
157 |
158 | class_id = (rounded_utm_east, rounded_utm_north, rounded_heading)
159 | # group_id goes from (0, 0, 0) to (N, N, L)
160 | group_id = (rounded_utm_east % (M * N) // M,
161 | rounded_utm_north % (M * N) // M,
162 | rounded_heading % (alpha * L) // alpha)
163 | return class_id, group_id
164 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import torch
4 | import logging
5 | import multiprocessing
6 | from datetime import datetime
7 |
8 | import test
9 | import parser
10 | import commons
11 | from cosplace_model import cosplace_network
12 | from datasets.test_dataset import TestDataset
13 |
14 | torch.backends.cudnn.benchmark = True # Provides a speedup
15 |
16 | args = parser.parse_arguments(is_training=False)
17 | start_time = datetime.now()
18 | args.output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
19 | commons.make_deterministic(args.seed)
20 | commons.setup_logging(args.output_folder, console="info")
21 | logging.info(" ".join(sys.argv))
22 | logging.info(f"Arguments: {args}")
23 | logging.info(f"The outputs are being saved in {args.output_folder}")
24 |
25 | #### Model
26 | model = cosplace_network.GeoLocalizationNet(args.backbone, args.fc_output_dim)
27 |
28 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.")
29 |
30 | if args.resume_model is not None:
31 | logging.info(f"Loading model from {args.resume_model}")
32 | model_state_dict = torch.load(args.resume_model)
33 | model.load_state_dict(model_state_dict)
34 | else:
35 | logging.info("WARNING: You didn't provide a path to resume the model (--resume_model parameter). " +
36 | "Evaluation will be computed using randomly initialized weights.")
37 |
38 | model = model.to(args.device)
39 |
40 | test_ds = TestDataset(args.test_set_folder, queries_folder="queries_v1",
41 | positive_dist_threshold=args.positive_dist_threshold)
42 |
43 | recalls, recalls_str = test.test(args, test_ds, model, args.num_preds_to_save)
44 | logging.info(f"{test_ds}: {recalls_str}")
45 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 |
2 | dependencies = ['torch', 'torchvision']
3 |
4 | import torch
5 | from cosplace_model import cosplace_network
6 |
7 |
8 | AVAILABLE_TRAINED_MODELS = {
9 | # backbone : list of available fc_output_dim, which is equivalent to descriptors dimensionality
10 | "VGG16": [ 64, 128, 256, 512],
11 | "ResNet18": [32, 64, 128, 256, 512],
12 | "ResNet50": [32, 64, 128, 256, 512, 1024, 2048],
13 | "ResNet101": [32, 64, 128, 256, 512, 1024, 2048],
14 | "ResNet152": [32, 64, 128, 256, 512, 1024, 2048],
15 | }
16 |
17 |
18 | def get_trained_model(backbone : str = "ResNet50", fc_output_dim : int = 2048) -> torch.nn.Module:
19 | """Return a model trained with CosPlace on San Francisco eXtra Large.
20 |
21 | Args:
22 | backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet.
23 | fc_output_dim (int): the output dimension of the last fc layer, equivalent to
24 | the descriptors dimension. Must be between 32 and 2048, depending on model's availability.
25 |
26 | Return:
27 | model (torch.nn.Module): a trained model.
28 | """
29 | print(f"Returning CosPlace model with backbone: {backbone} with features dimension {fc_output_dim}")
30 | if backbone not in AVAILABLE_TRAINED_MODELS:
31 | raise ValueError(f"Parameter `backbone` is set to {backbone} but it must be one of {list(AVAILABLE_TRAINED_MODELS.keys())}")
32 | try:
33 | fc_output_dim = int(fc_output_dim)
34 |     except (TypeError, ValueError):
35 | raise ValueError(f"Parameter `fc_output_dim` must be an integer, but it is set to {fc_output_dim}")
36 | if fc_output_dim not in AVAILABLE_TRAINED_MODELS[backbone]:
37 | raise ValueError(f"Parameter `fc_output_dim` is set to {fc_output_dim}, but for backbone {backbone} "
38 | f"it must be one of {list(AVAILABLE_TRAINED_MODELS[backbone])}")
39 | model = cosplace_network.GeoLocalizationNet(backbone, fc_output_dim)
40 | model.load_state_dict(
41 | torch.hub.load_state_dict_from_url(
42 | f'https://github.com/gmberton/CosPlace/releases/download/v1.0/{backbone}_{fc_output_dim}_cosplace.pth',
43 | map_location=torch.device('cpu'))
44 | )
45 | return model
46 |
--------------------------------------------------------------------------------
/parser.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 |
4 |
5 | def parse_arguments(is_training: bool = True):
6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7 | # CosPlace Groups parameters
8 | parser.add_argument("--M", type=int, default=10, help="_")
9 | parser.add_argument("--alpha", type=int, default=30, help="_")
10 | parser.add_argument("--N", type=int, default=5, help="_")
11 | parser.add_argument("--L", type=int, default=2, help="_")
12 | parser.add_argument("--groups_num", type=int, default=8, help="_")
13 | parser.add_argument("--min_images_per_class", type=int, default=10, help="_")
14 | # Model parameters
15 | parser.add_argument("--backbone", type=str, default="ResNet18",
16 | choices=["VGG16",
17 | "ResNet18", "ResNet50", "ResNet101", "ResNet152",
18 | "EfficientNet_B0", "EfficientNet_B1", "EfficientNet_B2",
19 | "EfficientNet_B3", "EfficientNet_B4", "EfficientNet_B5",
20 | "EfficientNet_B6", "EfficientNet_B7"], help="_")
21 | parser.add_argument("--fc_output_dim", type=int, default=512,
22 | help="Output dimension of final fully connected layer")
23 | parser.add_argument("--train_all_layers", default=False, action="store_true",
24 | help="If true, train all layers of the backbone")
25 | # Training parameters
26 | parser.add_argument("--use_amp16", action="store_true",
27 | help="use Automatic Mixed Precision")
28 | parser.add_argument("--augmentation_device", type=str, default="cuda",
29 | choices=["cuda", "cpu"],
30 | help="on which device to run data augmentation")
31 | parser.add_argument("--batch_size", type=int, default=32, help="_")
32 | parser.add_argument("--epochs_num", type=int, default=50, help="_")
33 | parser.add_argument("--iterations_per_epoch", type=int, default=10000, help="_")
34 | parser.add_argument("--lr", type=float, default=0.00001, help="_")
35 | parser.add_argument("--classifiers_lr", type=float, default=0.01, help="_")
36 | parser.add_argument("--image_size", type=int, default=512,
37 | help="Width and height of training images (1:1 aspect ratio))")
38 | parser.add_argument("--resize_test_imgs", default=False, action="store_true",
39 | help="If the test images should be resized to image_size along"
40 | "the shorter side while maintaining aspect ratio")
41 | # Data augmentation
42 | parser.add_argument("--brightness", type=float, default=0.7, help="_")
43 | parser.add_argument("--contrast", type=float, default=0.7, help="_")
44 | parser.add_argument("--hue", type=float, default=0.5, help="_")
45 | parser.add_argument("--saturation", type=float, default=0.7, help="_")
46 | parser.add_argument("--random_resized_crop", type=float, default=0.5, help="_")
47 | # Validation / test parameters
48 | parser.add_argument("--infer_batch_size", type=int, default=16,
49 | help="Batch size for inference (validating and testing)")
50 | parser.add_argument("--positive_dist_threshold", type=int, default=25,
51 | help="distance in meters for a prediction to be considered a positive")
52 | # Resume parameters
53 | parser.add_argument("--resume_train", type=str, default=None,
54 | help="path to checkpoint to resume, e.g. logs/.../last_checkpoint.pth")
55 | parser.add_argument("--resume_model", type=str, default=None,
56 | help="path to model to resume, e.g. logs/.../best_model.pth")
57 | # Other parameters
58 | parser.add_argument("--device", type=str, default="cuda",
59 | choices=["cuda", "cpu"], help="_")
60 | parser.add_argument("--seed", type=int, default=0, help="_")
61 | parser.add_argument("--num_workers", type=int, default=8, help="_")
62 | parser.add_argument("--num_preds_to_save", type=int, default=0,
63 | help="At the end of training, save N preds for each query. "
64 | "Try with a small number like 3")
65 | parser.add_argument("--save_only_wrong_preds", action="store_true",
66 | help="When saving preds (if num_preds_to_save != 0) save only "
67 | "preds for difficult queries, i.e. with uncorrect first prediction")
68 | # Paths parameters
69 | if is_training: # train and val sets are needed only for training
70 | parser.add_argument("--train_set_folder", type=str, required=True,
71 | help="path of the folder with training images")
72 | parser.add_argument("--val_set_folder", type=str, required=True,
73 | help="path of the folder with val images (split in database/queries)")
74 | parser.add_argument("--test_set_folder", type=str, required=True,
75 | help="path of the folder with test images (split in database/queries)")
76 | parser.add_argument("--save_dir", type=str, default="default",
77 | help="name of directory on which to save the logs, under logs/save_dir")
78 |
79 | args = parser.parse_args()
80 |
81 | return args
82 |
83 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss_cpu>=1.7.1
2 | numpy>=1.21.2
3 | Pillow>=9.0.1
4 | scikit_learn>=1.0.2
5 | torch>=1.8.2
6 | torchvision>=0.9.2
7 | tqdm>=4.62.3
8 | utm>=0.7.0
9 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import faiss
3 | import torch
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 | from typing import Tuple
8 | from argparse import Namespace
9 | from torch.utils.data.dataset import Subset
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 | import visualizations
13 |
14 |
15 | # Compute R@1, R@5, R@10, R@20
16 | RECALL_VALUES = [1, 5, 10, 20]
17 |
18 |
19 | def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module,
20 | num_preds_to_save: int = 0) -> Tuple[np.ndarray, str]:
21 | """Compute descriptors of the given dataset and compute the recalls."""
22 |
23 | model = model.eval()
24 | with torch.no_grad():
25 | logging.debug("Extracting database descriptors for evaluation/testing")
26 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
27 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
28 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
29 | all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32")
30 | for images, indices in tqdm(database_dataloader, ncols=100):
31 | descriptors = model(images.to(args.device))
32 | descriptors = descriptors.cpu().numpy()
33 | all_descriptors[indices.numpy(), :] = descriptors
34 |
35 | logging.debug("Extracting queries descriptors for evaluation/testing using batch size 1")
36 | queries_infer_batch_size = 1
37 | queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))
38 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
39 | batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda"))
40 | for images, indices in tqdm(queries_dataloader, ncols=100):
41 | descriptors = model(images.to(args.device))
42 | descriptors = descriptors.cpu().numpy()
43 | all_descriptors[indices.numpy(), :] = descriptors
44 |
45 | queries_descriptors = all_descriptors[eval_ds.database_num:]
46 | database_descriptors = all_descriptors[:eval_ds.database_num]
47 |
48 | # Use a kNN to find predictions
49 | faiss_index = faiss.IndexFlatL2(args.fc_output_dim)
50 | faiss_index.add(database_descriptors)
51 | del database_descriptors, all_descriptors
52 |
53 | logging.debug("Calculating recalls")
54 | _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))
55 |
56 | #### For each query, check if the predictions are correct
57 | positives_per_query = eval_ds.get_positives()
58 | recalls = np.zeros(len(RECALL_VALUES))
59 | for query_index, preds in enumerate(predictions):
60 | for i, n in enumerate(RECALL_VALUES):
61 | if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
62 | recalls[i:] += 1
63 | break
64 |
65 | # Divide by queries_num and multiply by 100, so the recalls are in percentages
66 | recalls = recalls / eval_ds.queries_num * 100
67 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])
68 |
69 | # Save visualizations of predictions
70 | if num_preds_to_save != 0:
71 | # For each query save num_preds_to_save predictions
72 | visualizations.save_preds(predictions[:, :num_preds_to_save], eval_ds, args.output_folder, args.save_only_wrong_preds)
73 |
74 | return recalls, recalls_str
75 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import torch
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 | import multiprocessing
8 | from datetime import datetime
9 | import torchvision.transforms as T
10 |
11 | import test
12 | import util
13 | import parser
14 | import commons
15 | import cosface_loss
16 | import augmentations
17 | from cosplace_model import cosplace_network
18 | from datasets.test_dataset import TestDataset
19 | from datasets.train_dataset import TrainDataset
20 |
21 | torch.backends.cudnn.benchmark = True # Provides a speedup
22 |
23 | args = parser.parse_arguments()
24 | start_time = datetime.now()
25 | args.output_folder = f"logs/{args.save_dir}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
26 | commons.make_deterministic(args.seed)
27 | commons.setup_logging(args.output_folder, console="debug")
28 | logging.info(" ".join(sys.argv))
29 | logging.info(f"Arguments: {args}")
30 | logging.info(f"The outputs are being saved in {args.output_folder}")
31 |
32 | #### Model
33 | model = cosplace_network.GeoLocalizationNet(args.backbone, args.fc_output_dim, args.train_all_layers)
34 |
35 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.")
36 |
37 | if args.resume_model is not None:
38 | logging.debug(f"Loading model from {args.resume_model}")
39 | model_state_dict = torch.load(args.resume_model)
40 | model.load_state_dict(model_state_dict)
41 |
42 | model = model.to(args.device).train()
43 |
44 | #### Optimizer
45 | criterion = torch.nn.CrossEntropyLoss()
46 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
47 |
48 | #### Datasets
49 | groups = [TrainDataset(args, args.train_set_folder, M=args.M, alpha=args.alpha, N=args.N, L=args.L,
50 | current_group=n, min_images_per_class=args.min_images_per_class) for n in range(args.groups_num)]
51 | # Each group has its own classifier, which depends on the number of classes in the group
52 | classifiers = [cosface_loss.MarginCosineProduct(args.fc_output_dim, len(group)) for group in groups]
53 | classifiers_optimizers = [torch.optim.Adam(classifier.parameters(), lr=args.classifiers_lr) for classifier in classifiers]
54 |
55 | logging.info(f"Using {len(groups)} groups")
56 | logging.info(f"The {len(groups)} groups have respectively the following number of classes {[len(g) for g in groups]}")
57 | logging.info(f"The {len(groups)} groups have respectively the following number of images {[g.get_images_num() for g in groups]}")
58 |
59 | val_ds = TestDataset(args.val_set_folder, positive_dist_threshold=args.positive_dist_threshold,
60 | image_size=args.image_size, resize_test_imgs=args.resize_test_imgs)
61 | test_ds = TestDataset(args.test_set_folder, queries_folder="queries_v1",
62 | positive_dist_threshold=args.positive_dist_threshold,
63 | image_size=args.image_size, resize_test_imgs=args.resize_test_imgs)
64 | logging.info(f"Validation set: {val_ds}")
65 | logging.info(f"Test set: {test_ds}")
66 |
67 | #### Resume
68 | if args.resume_train:
69 | model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num = \
70 | util.resume_train(args, args.output_folder, model, model_optimizer, classifiers, classifiers_optimizers)
71 | model = model.to(args.device)
72 | epoch_num = start_epoch_num - 1
73 | logging.info(f"Resuming from epoch {start_epoch_num} with best R@1 {best_val_recall1:.1f} from checkpoint {args.resume_train}")
74 | else:
75 | best_val_recall1 = start_epoch_num = 0
76 |
77 | #### Train / evaluation loop
78 | logging.info("Start training ...")
79 | logging.info(f"There are {len(groups[0])} classes for the first group, " +
80 | f"each epoch has {args.iterations_per_epoch} iterations " +
81 | f"with batch_size {args.batch_size}, therefore the model sees each class (on average) " +
82 | f"{args.iterations_per_epoch * args.batch_size / len(groups[0]):.1f} times per epoch")
83 |
84 |
85 | if args.augmentation_device == "cuda":
86 | gpu_augmentation = T.Compose([
87 | augmentations.DeviceAgnosticColorJitter(brightness=args.brightness,
88 | contrast=args.contrast,
89 | saturation=args.saturation,
90 | hue=args.hue),
91 | augmentations.DeviceAgnosticRandomResizedCrop([args.image_size, args.image_size],
92 | scale=[1-args.random_resized_crop, 1]),
93 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
94 | ])
95 |
96 | if args.use_amp16:
97 | scaler = torch.cuda.amp.GradScaler()
98 |
99 | for epoch_num in range(start_epoch_num, args.epochs_num):
100 |
101 | #### Train
102 | epoch_start_time = datetime.now()
103 | # Select classifier and dataloader according to epoch
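    # e.g. with the default --groups_num 8, epoch 0 trains on group 0, ..., epoch 7 on group 7, then epoch 8 cycles back to group 0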
104 | current_group_num = epoch_num % args.groups_num
105 | classifiers[current_group_num] = classifiers[current_group_num].to(args.device)
106 | util.move_to_device(classifiers_optimizers[current_group_num], args.device)
107 |
108 | dataloader = commons.InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers,
109 | batch_size=args.batch_size, shuffle=True,
110 | pin_memory=(args.device == "cuda"), drop_last=True)
111 |
112 | dataloader_iterator = iter(dataloader)
113 | model = model.train()
114 |
115 | epoch_losses = np.zeros((0, 1), dtype=np.float32)
116 | for iteration in tqdm(range(args.iterations_per_epoch), ncols=100):
117 | images, targets, _ = next(dataloader_iterator)
118 | images, targets = images.to(args.device), targets.to(args.device)
119 |
120 | if args.augmentation_device == "cuda":
121 | images = gpu_augmentation(images)
122 |
123 | model_optimizer.zero_grad()
124 | classifiers_optimizers[current_group_num].zero_grad()
125 |
126 | if not args.use_amp16:
127 | descriptors = model(images)
128 | output = classifiers[current_group_num](descriptors, targets)
129 | loss = criterion(output, targets)
130 | loss.backward()
131 | epoch_losses = np.append(epoch_losses, loss.item())
132 | del loss, output, images
133 | model_optimizer.step()
134 | classifiers_optimizers[current_group_num].step()
135 | else: # Use AMP 16
136 | with torch.cuda.amp.autocast():
137 | descriptors = model(images)
138 | output = classifiers[current_group_num](descriptors, targets)
139 | loss = criterion(output, targets)
140 | scaler.scale(loss).backward()
141 | epoch_losses = np.append(epoch_losses, loss.item())
142 | del loss, output, images
143 | scaler.step(model_optimizer)
144 | scaler.step(classifiers_optimizers[current_group_num])
145 | scaler.update()
146 |
147 | classifiers[current_group_num] = classifiers[current_group_num].cpu()
148 | util.move_to_device(classifiers_optimizers[current_group_num], "cpu")
149 |
150 | logging.debug(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, "
151 | f"loss = {epoch_losses.mean():.4f}")
152 |
153 | #### Evaluation
154 | recalls, recalls_str = test.test(args, val_ds, model)
155 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, {val_ds}: {recalls_str[:20]}")
156 | is_best = recalls[0] > best_val_recall1
157 | best_val_recall1 = max(recalls[0], best_val_recall1)
158 | # Save checkpoint, which contains all training parameters
159 | util.save_checkpoint({
160 | "epoch_num": epoch_num + 1,
161 | "model_state_dict": model.state_dict(),
162 | "optimizer_state_dict": model_optimizer.state_dict(),
163 | "classifiers_state_dict": [c.state_dict() for c in classifiers],
164 | "optimizers_state_dict": [c.state_dict() for c in classifiers_optimizers],
165 | "best_val_recall1": best_val_recall1
166 | }, is_best, args.output_folder)
167 |
168 |
169 | logging.info(f"Trained for {epoch_num+1:02d} epochs, in total {str(datetime.now() - start_time)[:-7]}")
170 |
171 | #### Test best model on test set v1
172 | best_model_state_dict = torch.load(f"{args.output_folder}/best_model.pth")
173 | model.load_state_dict(best_model_state_dict)
174 |
175 | logging.info(f"Now testing on the test set: {test_ds}")
176 | recalls, recalls_str = test.test(args, test_ds, model, args.num_preds_to_save)
177 | logging.info(f"{test_ds}: {recalls_str}")
178 |
179 | logging.info("Experiment finished (without any errors)")
180 |
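181 | # Hedged usage sketch (not part of the original script): the checkpoint written
182 | # by util.save_checkpoint above can be reloaded manually for inspection or for
183 | # warm-starting, using only the keys defined in the dict saved in the loop:
184 | #   ckpt = torch.load(f"{args.output_folder}/last_checkpoint.pth", map_location="cpu")
185 | #   logging.info(f"Checkpoint at epoch {ckpt['epoch_num']} with best R@1 {ckpt['best_val_recall1']:.1f}")
186 | #   model.load_state_dict(ckpt["model_state_dict"])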
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import shutil
4 | import logging
5 | from typing import List
6 | from argparse import Namespace
7 | from cosface_loss import MarginCosineProduct
8 |
9 |
10 | def move_to_device(optimizer: torch.optim.Optimizer, device: str):
11 | for state in optimizer.state.values():
12 | for k, v in state.items():
13 | if torch.is_tensor(v):
14 | state[k] = v.to(device)
15 |
16 |
17 | def save_checkpoint(state: dict, is_best: bool, output_folder: str,
18 | ckpt_filename: str = "last_checkpoint.pth"):
19 | # TODO it would be better to move weights to cpu before saving
20 | checkpoint_path = f"{output_folder}/{ckpt_filename}"
21 | torch.save(state, checkpoint_path)
22 | if is_best:
23 | torch.save(state["model_state_dict"], f"{output_folder}/best_model.pth")
24 |
25 |
26 | def resume_train(args: Namespace, output_folder: str, model: torch.nn.Module,
27 |                  model_optimizer: torch.optim.Optimizer, classifiers: List[MarginCosineProduct],
28 |                  classifiers_optimizers: List[torch.optim.Optimizer]):
29 | """Load model, optimizer, and other training parameters"""
30 | logging.info(f"Loading checkpoint: {args.resume_train}")
31 | checkpoint = torch.load(args.resume_train)
32 | start_epoch_num = checkpoint["epoch_num"]
33 |
34 | model_state_dict = checkpoint["model_state_dict"]
35 | model.load_state_dict(model_state_dict)
36 |
37 | model = model.to(args.device)
38 | model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
39 |
40 | assert args.groups_num == len(classifiers) == len(classifiers_optimizers) == \
41 | len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \
42 | (f"{args.groups_num}, {len(classifiers)}, {len(classifiers_optimizers)}, "
43 | f"{len(checkpoint['classifiers_state_dict'])}, {len(checkpoint['optimizers_state_dict'])}")
44 |
45 | for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]):
46 | # Move classifiers to GPU before loading their optimizers
47 | c = c.to(args.device)
48 | c.load_state_dict(sd)
49 | for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]):
50 | c.load_state_dict(sd)
51 | for c in classifiers:
52 | # Move classifiers back to CPU to save some GPU memory
53 | c = c.cpu()
54 |
55 | best_val_recall1 = checkpoint["best_val_recall1"]
56 |
57 | # Copy best model to current output_folder
58 | shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder)
59 |
60 | return model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num
61 |
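62 | # Hedged usage sketch (not part of the original file): move_to_device keeps an
63 | # optimizer's internal buffers (e.g. Adam's exp_avg tensors) on the same device
64 | # as the parameters it updates, mirroring the per-group pattern in train.py.
65 | # The classifier/optimizer names below are hypothetical:
66 | #   classifier = classifier.to("cuda")
67 | #   move_to_device(classifier_optimizer, "cuda")
68 | #   ...                                            # run some training iterations
69 | #   classifier = classifier.cpu()
70 | #   move_to_device(classifier_optimizer, "cpu")    # free GPU memory between groups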
--------------------------------------------------------------------------------
/visualizations.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import cv2
4 | import numpy as np
5 | from tqdm import tqdm
6 | from skimage.transform import rescale
7 | from PIL import Image, ImageDraw, ImageFont
8 |
9 |
10 | # Height and width of a single image
11 | H = 512
12 | W = 512
13 | TEXT_H = 175
14 | FONTSIZE = 80
15 | SPACE = 50 # Space between two images
16 |
17 |
18 | def write_labels_to_image(labels=["text1", "text2"]):
19 |     """Create a horizontal strip of text labels, one centered above each image slot."""
20 | font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", FONTSIZE)
21 |     img = Image.new('RGB', ((W * len(labels)) + SPACE * (len(labels)-1), TEXT_H), (1, 1, 1))
22 | d = ImageDraw.Draw(img)
23 | for i, text in enumerate(labels):
24 | _, _, w, h = d.textbbox((0,0), text, font=font)
25 | d.text(((W+SPACE)*i + W//2 - w//2, 1), text, fill=(0, 0, 0), font=font)
26 | return np.array(img)
27 |
28 |
29 | def draw(img, c=(0, 255, 0), thickness=20):
30 | """Draw a colored (usually red or green) box around an image."""
31 | p = np.array([[0, 0], [0, img.shape[0]], [img.shape[1], img.shape[0]], [img.shape[1], 0]])
32 | for i in range(3):
33 | cv2.line(img, (p[i, 0], p[i, 1]), (p[i+1, 0], p[i+1, 1]), c, thickness=thickness*2)
34 | return cv2.line(img, (p[3, 0], p[3, 1]), (p[0, 0], p[0, 1]), c, thickness=thickness*2)
35 |
36 |
37 | def build_prediction_image(images_paths, preds_correct=None):
38 | """Build a row of images, where the first is the query and the rest are predictions.
39 |     For each prediction, draw a green box if it is correct and a red one otherwise (the query gets no box).
40 | """
41 | assert len(images_paths) == len(preds_correct)
42 | labels = ["Query"] + [f"Pr{i} - {is_correct}" for i, is_correct in enumerate(preds_correct[1:])]
43 | num_images = len(images_paths)
44 | images = [np.array(Image.open(path)) for path in images_paths]
45 | for img, correct in zip(images, preds_correct):
46 | if correct is None:
47 | continue
48 | color = (0, 255, 0) if correct else (255, 0, 0)
49 | draw(img, color)
50 | concat_image = np.ones([H, (num_images*W)+((num_images-1)*SPACE), 3])
51 | rescaleds = [rescale(i, [min(H/i.shape[0], W/i.shape[1]), min(H/i.shape[0], W/i.shape[1]), 1]) for i in images]
52 | for i, image in enumerate(rescaleds):
53 | pad_width = (W - image.shape[1] + 1) // 2
54 | pad_height = (H - image.shape[0] + 1) // 2
55 | image = np.pad(image, [[pad_height, pad_height], [pad_width, pad_width], [0, 0]], constant_values=1)[:H, :W]
56 | concat_image[: , i*(W+SPACE) : i*(W+SPACE)+W] = image
57 | try:
58 | labels_image = write_labels_to_image(labels)
59 | final_image = np.concatenate([labels_image, concat_image])
60 | except OSError: # Handle error in case of missing PIL ImageFont
61 | final_image = concat_image
62 | final_image = Image.fromarray((final_image*255).astype(np.uint8))
63 | return final_image
64 |
65 |
66 | def save_file_with_paths(query_path, preds_paths, positives_paths, output_path):
67 | file_content = []
68 | file_content.append("Query path:")
69 | file_content.append(query_path + "\n")
70 | file_content.append("Predictions paths:")
71 | file_content.append("\n".join(preds_paths) + "\n")
72 | file_content.append("Positives paths:")
73 | file_content.append("\n".join(positives_paths) + "\n")
74 | with open(output_path, "w") as file:
75 | _ = file.write("\n".join(file_content))
76 |
77 |
78 | def save_preds(predictions, eval_ds, output_folder, save_only_wrong_preds=False):
79 | """For each query, save an image containing the query and its predictions,
80 | and a file with the paths of the query, its predictions and its positives.
81 |
82 | Parameters
83 | ----------
84 | predictions : np.array of shape [num_queries x num_preds_to_viz], with the preds
85 | for each query
86 | eval_ds : TestDataset
87 | output_folder : str / Path with the path to save the predictions
88 | save_only_wrong_preds : bool, if True save only the wrongly predicted queries,
89 |         i.e. the ones whose first prediction is incorrect (further than 25 m)
90 | """
91 | positives_per_query = eval_ds.get_positives()
92 | os.makedirs(f"{output_folder}/preds", exist_ok=True)
93 | for query_index, preds in enumerate(tqdm(predictions, ncols=80, desc=f"Saving preds in {output_folder}")):
94 | query_path = eval_ds.queries_paths[query_index]
95 | list_of_images_paths = [query_path]
96 | # List of None (query), True (correct preds) or False (wrong preds)
97 | preds_correct = [None]
98 | for pred_index, pred in enumerate(preds):
99 | pred_path = eval_ds.database_paths[pred]
100 | list_of_images_paths.append(pred_path)
101 | is_correct = pred in positives_per_query[query_index]
102 | preds_correct.append(is_correct)
103 |
104 | if save_only_wrong_preds and preds_correct[1]:
105 | continue
106 |
107 | prediction_image = build_prediction_image(list_of_images_paths, preds_correct)
108 | pred_image_path = f"{output_folder}/preds/{query_index:03d}.jpg"
109 | prediction_image.save(pred_image_path)
110 |
111 | positives_paths = [eval_ds.database_paths[idx] for idx in positives_per_query[query_index]]
112 | save_file_with_paths(
113 | query_path=list_of_images_paths[0],
114 | preds_paths=list_of_images_paths[1:],
115 | positives_paths=positives_paths,
116 | output_path=f"{output_folder}/preds/{query_index:03d}.txt"
117 | )
118 |
119 |
120 |
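121 | # Hedged usage sketch (not part of the original file): save_preds is presumably
122 | # invoked from test.test when num_preds_to_save is set; a manual call could look
123 | # like the following, where the dataset, folder and top-k value are placeholders:
124 | #   preds_to_viz = predictions[:, :3]   # keep only the first 3 predictions per query
125 | #   save_preds(preds_to_viz, eval_ds=test_ds, output_folder="logs/my_run",
126 | #              save_only_wrong_preds=True)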
--------------------------------------------------------------------------------