├── .gitignore ├── LICENSE ├── README.md ├── assets ├── results.png └── slide-window.png ├── datasets_ws.py ├── environment.yml ├── model ├── __init__.py ├── aggregation.py ├── functional.py ├── network.py ├── non_local.py ├── normalization.py └── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── parser.py ├── test.py ├── tools ├── __init__.py ├── commons.py ├── loss.py ├── map_builder.py ├── paper_utils.py ├── util.py └── visual.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Autogenerated folders 2 | __pycache__ 3 | logs 4 | test 5 | data 6 | 7 | # IDEs generated folders 8 | .spyproject 9 | venv/ 10 | .idea/ 11 | __MACOSX/ 12 | **/.DS_Store 13 | 14 | # other 15 | pretrained 16 | *.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 zafirshi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## PanoVPR: Towards Unified Perspective-to-Equirectangular Visual Place Recognition via Sliding Windows across the Panoramic View 4 | 5 |
6 | 7 |

8 | Ze Shi* · 9 | Hao Shi* · 10 | Kailun Yang · 11 | Zhe Yin · 12 | Yining Lin · 13 | Kaiwei Wang 14 |

15 |

16 | 17 |

18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |

27 | 28 | ## Update 29 | 30 | - [2023-11] :gear: Code Release 31 | - [2023-08] :tada: PanoVPR is accepted to the 26th IEEE International Conference on Intelligent Transportation Systems ([ITSC-2023](https://2023.ieee-itsc.org/)). 32 | - [2023-03] :construction: Init repo and release [arXiv](https://arxiv.org/pdf/2303.14095.pdf) version 33 | 34 | ## Introduction 35 | 36 | We propose a **V**isual **P**lace **R**ecognition framework for retrieving **Pano**ramic database images using perspective query images, dubbed **PanoVPR**. To achieve this, we adopt a sliding window approach on panoramic database images to narrow the model's observation range over the large field-of-view panoramas. We achieve promising results on a derived dataset, *Pitts250K-P2E*, and a real-world scenario dataset, *YQ360*. 37 | 38 | For more details, please check our [arXiv](https://arxiv.org/pdf/2303.14095.pdf) paper. 39 | 40 | ## Sliding window strategy 41 | 42 | ![Sliding window](assets/slide-window.png) 43 | 44 | ## Qualitative results 45 | 46 | ![Qualitative results](assets/results.png) 47 | 48 | ## Usage 49 | 50 | ### Dataset Preparation 51 | 52 | Before starting, you need to download the Pitts250K-P2E dataset and the YQ360 dataset [[OneDrive Link](https://zjueducn-my.sharepoint.com/:f:/g/personal/zafirshi_zju_edu_cn/Ei4N__otNrVAjxku0UnT-pQBdsOSF3PvAEi8Z9wGu7Aj0w?e=LvVwIp)][[BaiduYun Link](https://pan.baidu.com/s/1IBcpAwnwY5YlqfgfSqRz-w?pwd=Pano)]. 53 | 54 | If the link is out of date, please email _office_makeit@163.com_ for the latest available link! 55 | 56 | Afterwards, specify the `--datasets_folder` parameter in the `parser.py` file. 57 | 58 | 59 | ### Setup 60 | 61 | First create an environment from the file `environment.yml` using [Conda](https://docs.conda.io/projects/miniconda/en/latest/miniconda-install.html), and then activate it. 
62 | 63 | ```bash 64 | conda env create -f environment.yml --prefix /path/to/env 65 | conda activate /path/to/env 66 | ``` 67 | 68 | ### Train 69 | 70 | To train the network, you can change the training configuration and the dataset used 71 | by specifying parameters such as `--backbone`, `--split_nums`, and `--dataset_name` on the command line. 72 | 73 | Adjust the other parameters to suit your setup. 74 | By default, the output results are stored in the `./logs/{save_dir}` folder. 75 | 76 | **Please note that the `--title` parameter must be specified on the command line.** 77 | 78 | ```bash 79 | # Train on Pitts250K-P2E 80 | python train.py --title swinTx24 \ 81 | --save_dir clean_branch_test \ 82 | --backbone swin_tiny \ 83 | --split_nums 24 \ 84 | --dataset_name pitts250k \ 85 | --cache_refresh_rate 125 \ 86 | --neg_sample 100 \ 87 | --queries_per_epoch 2000 88 | ``` 89 | 90 | ### Inference 91 | 92 | For inference, pass the absolute path where `best_model.pth` is stored via the `--resume` parameter. 93 | 94 | ```bash 95 | # Validate and test on Pitts250K-P2E 96 | python test.py --title test_swinTx24 \ 97 | --save_dir clean_branch_test \ 98 | --backbone swin_tiny \ 99 | --split_nums 24 \ 100 | --dataset_name pitts250k \ 101 | --cache_refresh_rate 125 \ 102 | --neg_sample 100 \ 103 | --queries_per_epoch 2000 \ 104 | --resume 105 | ``` 106 | 107 | ## Acknowledgments 108 | 109 | We thank the authors of the following repositories for their open source code: 110 | 111 | - [tranleanh/image-panorama-stitching](https://github.com/tranleanh/image-panorama-stitching) 112 | - [gmberton/datasets_vg](https://github.com/gmberton/datasets_vg) 113 | - [gmberton/deep-visual-geo-localization-benchmark](https://github.com/gmberton/deep-visual-geo-localization-benchmark) 114 | 115 | 116 | ## Cite Our Work 117 | 118 | Thanks for using our work. 
You can cite it as: 119 | 120 | ```bib 121 | @INPROCEEDINGS{shi2023panovpr, 122 | author={Shi, Ze and Shi, Hao and Yang, Kailun and Yin, Zhe and Lin, Yining and Wang, Kaiwei}, 123 | booktitle={2023 IEEE 26th International Conference on Intelligent Transportation Systems (ITSC)}, 124 | title={PanoVpr: Towards Unified Perspective-to-Equirectangular Visual Place Recognition via Sliding Windows Across the Panoramic View}, 125 | year={2023}, 126 | pages={1333-1340}, 127 | doi={10.1109/ITSC57777.2023.10421857} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/assets/results.png -------------------------------------------------------------------------------- /assets/slide-window.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/assets/slide-window.png -------------------------------------------------------------------------------- /datasets_ws.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import PIL 5 | import torch 6 | import faiss 7 | import matplotlib.pyplot as plt 8 | import logging 9 | import numpy as np 10 | from glob import glob 11 | from tqdm import tqdm 12 | from os.path import join 13 | import torch.utils.data as data 14 | import torchvision.transforms as transforms 15 | from torch.utils.data.dataset import Subset 16 | from sklearn.neighbors import NearestNeighbors 17 | from torch.utils.data.dataloader import DataLoader 18 | from tools.visual import path_to_pil_img 19 | 20 | base_transform = transforms.Compose([ 21 | transforms.ToTensor(), 22 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], 
std=[0.229, 0.224, 0.225]), 23 | ]) 24 | 25 | 26 | def shift_window_on_descriptor(short_vector, long_matrix, window_size, divisor_factor, sorted_indices_num): 27 | """ 28 | Shift window in descriptor comparing between short_query_vector and long_database_matrix, 29 | and calculate similarity in window range 30 | :param divisor_factor: window_stride=window_size/n. n->divisor_factor 31 | :param short_vector: a short numpy vector generated by query image 32 | :param long_matrix: a long and multi-row matrix(2d numpy array) generated by sampled negative images 33 | :param window_size: feature_window_length set as shifting base length 34 | :param sorted_indices_num: num of index we select from sorted candidate descriptors 35 | :return: 36 | sorted_indices_num_in_matrix: database index 37 | window_loc_in_matrix: focus patch in selected database matrix 38 | """ 39 | left, right = 0, window_size 40 | window_stride = int(window_size / divisor_factor) 41 | 42 | part_distance_list = [] 43 | width_bound = long_matrix.shape[1] 44 | counter = 0 45 | 46 | while left < width_bound: 47 | if right <= width_bound: 48 | part_distance = np.linalg.norm(short_vector - long_matrix[:, left:right], ord=2, axis=1) 49 | else: 50 | cycle_part = np.concatenate((long_matrix[:,left:],long_matrix[:,:right-width_bound]),axis=1) 51 | part_distance = np.linalg.norm(short_vector - cycle_part, ord=2, axis=1) 52 | part_distance_list.append(part_distance) 53 | 54 | left += window_stride 55 | right += window_stride 56 | counter += 1 57 | 58 | maintain_table = np.array(part_distance_list).transpose() 59 | 60 | # Select the smallest 'sorted_indices_num' 'long_matrix_index' based on distance 61 | sorted_indices_num_in_matrix = np.argsort(np.min(maintain_table, axis=1))[:sorted_indices_num] 62 | 63 | window_loc_in_matrix = np.argmin(maintain_table, axis=1)[sorted_indices_num_in_matrix] 64 | return sorted_indices_num_in_matrix, window_loc_in_matrix 65 | 66 | 67 | def shift_window_on_img(img: torch.Tensor, 
win_num: int, win_stride: int, win_len: int) -> torch.Tensor: 68 | """ 69 | Slide a window across a wide (panoramic) image and stack the crops. 70 | :param img: image tensor with a long width, e.g. shape (3, 448, 3584) 71 | :param win_num: number of windows to extract [should be calculated carefully before training] 72 | :param win_stride: horizontal step between consecutive windows, in pixels 73 | :param win_len: width of each window, in pixels 74 | :return: stacked windows, a tensor of shape (N, 3, H, win_len); windows that cross the right edge wrap around to the left 75 | """ 76 | input_img_width = img.shape[-1] 77 | img_split_list = [] 78 | for i in range(win_num): 79 | sw_left = i * win_stride 80 | sw_right = i * win_stride + win_len 81 | if sw_left <= input_img_width and sw_right <= input_img_width: 82 | img_split_list += [img[:, :, int(sw_left):int(sw_right)]] 83 | # when one-go directly, win_num should be 15(7+8) or other numbers 84 | # wrap around: concatenate the rightmost part with the leftmost part of the panorama 85 | elif sw_left <= input_img_width < sw_right: 86 | img_split_list += [torch.concat([img[:, :, int(sw_left):], 87 | img[:, :, :int(sw_right - input_img_width)]], dim=-1)] 88 | else: 89 | break 90 | img = torch.stack(img_split_list, dim=0) 91 | return img 92 | 93 | 94 | class PCADataset(data.Dataset): 95 | def __init__(self, args, datasets_folder="dataset", dataset_folder="pitts30k/images/train"): 96 | dataset_folder_full_path = join(datasets_folder, dataset_folder) 97 | if not os.path.exists(dataset_folder_full_path): 98 | raise FileNotFoundError(f"Folder {dataset_folder_full_path} does not exist") 99 | self.images_paths = sorted(glob(join(dataset_folder_full_path, "**", "*.jpg"), recursive=True)) 100 | 101 | def __getitem__(self, index): 102 | return base_transform(path_to_pil_img(self.images_paths[index])) 103 | 104 | def __len__(self): 105 | return len(self.images_paths) 106 | 107 | 108 | class BaseDataset(data.Dataset): 109 | """Dataset with images from database and queries, used for inference (testing and building cache). 
110 | """ 111 | 112 | def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train"): 113 | super().__init__() 114 | self.args = args 115 | self.dataset_name = dataset_name 116 | self.dataset_folder = join(datasets_folder, dataset_name, "images", split) 117 | if not os.path.exists(self.dataset_folder): raise FileNotFoundError( 118 | f"Folder {self.dataset_folder} does not exist") 119 | 120 | self.resize = args.resize 121 | self.query_resize = args.query_resize 122 | self.database_resize = args.database_resize 123 | self.test_method = args.test_method 124 | 125 | self.split_nums = args.split_nums 126 | # Non-overlapping sliding window length (in pixels), default to the width of the perspective query image. 127 | self.window_len = self.resize[1] 128 | self.window_stride = 8 * self.window_len / self.split_nums # Pixel step length of the sliding window. 129 | if self.window_stride < self.window_len: 130 | logging.debug('[Note]:slide window using overlapping way') 131 | 132 | # for display,locate the selected window_num 133 | self.pos_focus_patch = [] 134 | self.neg_focus_patch = [] 135 | 136 | #### Read paths and UTM coordinates for all images. 
137 | database_folder = join(self.dataset_folder, "database_pano_clean") 138 | queries_folder = join(self.dataset_folder, "queries_split") 139 | 140 | if not os.path.exists(database_folder): raise FileNotFoundError(f"Folder {database_folder} does not exist") 141 | if not os.path.exists(queries_folder): raise FileNotFoundError(f"Folder {queries_folder} does not exist") 142 | self.database_paths = sorted(glob(join(database_folder, "**", "*.jpg"), recursive=True)) 143 | self.queries_paths = sorted(glob(join(queries_folder, "**", "*.jpg"), recursive=True)) 144 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg 145 | self.database_utms = np.array( 146 | [(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(np.float64) # np.float was removed in NumPy 1.24 147 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype( 148 | np.float64) 149 | 150 | # Find soft_positives_per_query, which are within val_positive_dist_threshold (default 25 meters) 151 | knn = NearestNeighbors(n_jobs=-1) 152 | knn.fit(self.database_utms) 153 | self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms, 154 | radius=args.val_positive_dist_threshold, 155 | return_distance=False) 156 | 157 | self.images_paths = list(self.database_paths) + list(self.queries_paths) 158 | 159 | self.database_num = len(self.database_paths) 160 | self.queries_num = len(self.queries_paths) 161 | 162 | def __getitem__(self, index): 163 | if index < self.database_num: 164 | # split pano_database images into several subs 165 | img = path_to_pil_img(self.images_paths[index]) 166 | img = self.resize_database_p2e(img) # shape:(3,224,224*8) 167 | 168 | elif index >= self.database_num: 169 | # query image just resize 170 | img = path_to_pil_img(self.images_paths[index]) 171 | img = base_transform(img) 172 | img = transforms.functional.resize(img, self.query_resize) 173 | # resize to adapt backbone input size 174 | img = transforms.functional.resize(img, 
self.resize) 175 | 176 | return img, index 177 | 178 | def resize_database_p2e(self, img: PIL.Image) -> torch.Tensor: 179 | img = base_transform(img) 180 | img = transforms.functional.resize(img, self.database_resize) 181 | img = transforms.functional.resize(img, (self.resize[0], 8 * self.resize[1])) # shape:(3,224,224*8) 182 | return img 183 | 184 | def __len__(self): 185 | return len(self.images_paths) 186 | 187 | def __repr__(self): 188 | return ( 189 | f"< {self.__class__.__name__}, {self.dataset_name} - #database: {self.database_num}; #queries: {self.queries_num} >") 190 | 191 | def get_positives(self): 192 | return self.soft_positives_per_query 193 | 194 | 195 | class TripletsDataset(BaseDataset): 196 | """Dataset used for training, it is used to compute the triplets 197 | with TripletsDataset.compute_triplets() with various mining methods. 198 | If is_inference == True, uses methods of the parent class BaseDataset, 199 | this is used for example when computing the cache, because we compute features 200 | of each image, not triplets. 
201 | """ 202 | 203 | def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train", negs_num_per_query=10): 204 | super().__init__(args, datasets_folder, dataset_name, split) 205 | self.mining = args.mining 206 | self.neg_samples_num = args.neg_samples_num # Number of negatives to randomly sample 207 | self.negs_num_per_query = negs_num_per_query # Number of negatives per query in each batch 208 | if self.mining == "full": # "Full database mining" keeps a cache with last used negatives 209 | self.neg_cache = [np.empty((0,), dtype=np.int32) for _ in range(self.queries_num)] 210 | self.is_inference = False 211 | 212 | identity_transform = transforms.Lambda(lambda x: x) 213 | self.resized_transform = transforms.Compose([ 214 | transforms.Resize(self.resize) if self.resize is not None else identity_transform, 215 | base_transform 216 | ]) 217 | 218 | self.query_transform = transforms.Compose([ 219 | transforms.ColorJitter(brightness=args.brightness) if args.brightness is not None else identity_transform, 220 | transforms.ColorJitter(contrast=args.contrast) if args.contrast is not None else identity_transform, 221 | transforms.ColorJitter(saturation=args.saturation) if args.saturation is not None else identity_transform, 222 | transforms.ColorJitter(hue=args.hue) if args.hue is not None else identity_transform, 223 | transforms.RandomPerspective( 224 | args.rand_perspective) if args.rand_perspective is not None else identity_transform, 225 | transforms.RandomResizedCrop(size=self.resize, scale=(1 - args.random_resized_crop, 1)) \ 226 | if args.random_resized_crop is not None else identity_transform, 227 | transforms.RandomRotation( 228 | degrees=args.random_rotation) if args.random_rotation is not None else identity_transform, 229 | self.resized_transform, 230 | ]) 231 | 232 | # Find hard_positives_per_query, which are within train_positives_dist_threshold (10 meters) 233 | knn = NearestNeighbors(n_jobs=-1) 234 | knn.fit(self.database_utms) 235 | 
self.hard_positives_per_query = list(knn.radius_neighbors(self.queries_utms, 236 | radius=args.train_positives_dist_threshold, 237 | # 10 meters 238 | return_distance=False)) 239 | 240 | #### Some queries might have no positive, we should remove those queries. 241 | queries_without_any_hard_positive = \ 242 | np.where(np.array([len(p) for p in self.hard_positives_per_query], dtype=object) == 0)[0] 243 | if len(queries_without_any_hard_positive) != 0: 244 | logging.info(f"There are {len(queries_without_any_hard_positive)} queries without any positives " + 245 | "within the training set. They won't be considered as they're useless for training.") 246 | # Remove queries without positives 247 | self.hard_positives_per_query = np.delete(self.hard_positives_per_query, queries_without_any_hard_positive) 248 | self.queries_paths = np.delete(self.queries_paths, queries_without_any_hard_positive) 249 | 250 | # Recompute images_paths and queries_num because some queries might have been removed 251 | self.images_paths = list(self.database_paths) + list(self.queries_paths) 252 | self.queries_num = len(self.queries_paths) 253 | 254 | 255 | def __getitem__(self, index): 256 | if self.is_inference: 257 | return super().__getitem__(index) 258 | query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index], 259 | (1, 1, self.negs_num_per_query)) 260 | 261 | query = self.query_transform(path_to_pil_img(self.queries_paths[query_index])) 262 | database_indices_list = torch.concat([best_positive_index, neg_indexes]).tolist() 263 | 264 | pano_database_list = [self.resize_database_p2e(each_pano_database) for each_pano_database in 265 | [path_to_pil_img(self.database_paths[i]) for i in database_indices_list]] 266 | 267 | pano_database = torch.stack(pano_database_list, dim=0) 268 | return query, pano_database 269 | 270 | def __len__(self): 271 | if self.is_inference: 272 | return super().__len__() 273 | else: 274 | return len(self.triplets_global_indexes) 
275 | 276 | def compute_triplets(self, args, model): 277 | self.is_inference = True 278 | 279 | if self.mining == "partial": 280 | self.compute_triplets_partial(args, model) 281 | elif self.mining == "full": 282 | raise NotImplementedError(f"{self.mining} has not been implemented yet, use partial instead") 283 | elif self.mining == "random": 284 | raise NotImplementedError(f"{self.mining} has not been implemented yet, use partial instead") 285 | else: 286 | raise KeyError(f"{self.mining} is not among the options (partial, full, random)") 287 | 288 | def compute_cache_database(self, args, model, subset_ds, database_cache_shape): 289 | """Compute the cache containing features of images, which is used to 290 | find best positive and hardest negatives.""" 291 | 292 | subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers, 293 | batch_size=args.infer_batch_size, shuffle=False, 294 | pin_memory=(args.device == "cuda")) 295 | model = model.eval() 296 | 297 | # RAMEfficient2DMatrix can be replaced by np.zeros, but using 298 | # RAMEfficient2DMatrix is RAM efficient for full database mining. 
299 | database_cache = RAMEfficient2DMatrix(database_cache_shape, dtype=np.float32) 300 | 301 | with torch.no_grad(): 302 | for images, indexes in tqdm(subset_dl, ncols=100, desc='Compute Database Cache'): 303 | # images shape: 16(infer_bs),3,224,224*8 -> 16,16(split_num),3,224,224 304 | 305 | images = torch.stack([shift_window_on_img(each_img, win_num=self.split_nums, 306 | win_stride=self.window_stride, win_len=self.window_len) 307 | for each_img in images]) 308 | 309 | B, N, C, resize_H, resize_W = images.shape 310 | images = images.view(B * N, C, resize_H, resize_W) 311 | 312 | images = images.to(args.device) 313 | features = model(images) 314 | features = torch.flatten(features.view(B, N, -1), start_dim=1) 315 | 316 | database_cache[indexes.numpy()] = features.cpu().numpy() 317 | 318 | return database_cache 319 | 320 | @staticmethod 321 | def compute_cache_query(args, model, subset_ds, query_cache_shape, database_num): 322 | """Compute the cache containing features of images, which is used to 323 | find best positive and hardest negatives.""" 324 | 325 | subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers, 326 | batch_size=args.infer_batch_size, shuffle=False, 327 | pin_memory=(args.device == "cuda")) 328 | model = model.eval() 329 | 330 | query_cache = RAMEfficient2DMatrix(query_cache_shape, dtype=np.float32) 331 | 332 | with torch.no_grad(): 333 | for images, indexes in tqdm(subset_dl, ncols=100, desc='Compute Query Cache'): 334 | images = images.to(args.device) 335 | features = model(images) 336 | # minus to begin from 0 337 | query_cache[indexes.numpy() - database_num] = features.cpu().numpy() 338 | 339 | return query_cache 340 | 341 | def get_query_features(self, query_index, query_cache): 342 | query_features = query_cache[query_index] 343 | if query_features is None: 344 | raise RuntimeError(f"For query {self.queries_paths[query_index]} " + 345 | f"with index {query_index} features have not been computed!\n" + 346 | "There might be 
some bug with caching") 347 | return query_features 348 | 349 | def get_best_positive_index(self, args, query_index, database_cache, query_features) -> int: 350 | positives_features = database_cache[self.hard_positives_per_query[query_index]] 351 | 352 | # Segmentally compare the similarity between descriptors, 353 | # calculate the average row-wise (each row represents a panoramic image) 354 | best_positive_in_matrix, window_loc = shift_window_on_descriptor(short_vector=query_features, 355 | long_matrix=positives_features, 356 | window_size=args.features_dim, 357 | divisor_factor=args.reduce_factor, 358 | sorted_indices_num=1) 359 | 360 | # Search the best positive (within 10 meters AND nearest in features space) 361 | best_positive_index = self.hard_positives_per_query[query_index][best_positive_in_matrix] 362 | self.pos_focus_patch.append(window_loc) 363 | return best_positive_index.item() 364 | 365 | def get_hardest_negatives_indexes(self, args, database_cache, query_features, neg_samples): 366 | neg_features = database_cache[neg_samples] # shape:(unknown, 8*feature_dim) 367 | 368 | neg_nums_in_matrix, window_loc = shift_window_on_descriptor(short_vector=query_features, 369 | long_matrix=neg_features, 370 | window_size=args.features_dim, 371 | divisor_factor=args.reduce_factor, 372 | sorted_indices_num=self.negs_num_per_query) 373 | 374 | neg_indexes = neg_samples[neg_nums_in_matrix] 375 | self.neg_focus_patch.append(window_loc) 376 | return neg_indexes.tolist() 377 | 378 | def compute_triplets_partial(self, args, model): 379 | triplets_global_indexes = [] 380 | if self.mining == "partial": 381 | sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False) 382 | else: 383 | raise ValueError(f'sampled_queries_indexes is set wrong') 384 | 385 | sampled_database_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False) 386 | # Take all the positives 387 | positives_indexes = 
[self.hard_positives_per_query[i] for i in sampled_queries_indexes] 388 | positives_indexes = [p for pos in positives_indexes for p in pos] 389 | # Merge them into database_indexes and remove duplicates 390 | database_indexes = list(sampled_database_indexes) + positives_indexes 391 | database_indexes = list(np.unique(database_indexes)) 392 | 393 | subset_ds_database = Subset(self, database_indexes) 394 | subset_ds_query = Subset(self, list(sampled_queries_indexes + self.database_num)) 395 | 396 | database_cache = self.compute_cache_database(args, model, subset_ds_database, 397 | database_cache_shape=( 398 | self.database_num, args.split_nums * args.features_dim)) 399 | 400 | query_cache = self.compute_cache_query(args, model, subset_ds_query, 401 | query_cache_shape=(self.queries_num, args.features_dim), 402 | database_num=self.database_num) 403 | 404 | for query_index in tqdm(sampled_queries_indexes, ncols=100, desc='Neg Mining'): 405 | 406 | query_features = self.get_query_features(query_index, query_cache) 407 | best_positive_index = self.get_best_positive_index(args, query_index, database_cache, query_features) 408 | 409 | # Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives 410 | soft_positives = self.soft_positives_per_query[query_index] 411 | neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True) 412 | 413 | # Take all database images that are negatives and are within the sampled database images (aka database_indexes) 414 | neg_indexes = self.get_hardest_negatives_indexes(args, database_cache, query_features, neg_indexes) 415 | triplets_global_indexes.append((int(query_index), best_positive_index, *neg_indexes)) 416 | 417 | self.triplets_global_indexes = torch.tensor(triplets_global_indexes) 418 | 419 | 420 | class RAMEfficient2DMatrix: 421 | """This class behaves similarly to a numpy.ndarray initialized 422 | with np.zeros(), but is implemented to save RAM when the rows 423 | 
within the 2D array are sparse. In this case it's needed because 424 | we don't always compute features for each image, just for few of 425 | them""" 426 | 427 | def __init__(self, shape, dtype=np.float32): 428 | self.shape = shape 429 | self.dtype = dtype 430 | self.matrix = [None] * shape[0] 431 | 432 | def __setitem__(self, indexes, vals): 433 | assert vals.shape[1] == self.shape[1], f"{vals.shape[1]} {self.shape[1]}" 434 | for i, val in zip(indexes, vals): 435 | self.matrix[i] = val.astype(self.dtype, copy=False) 436 | 437 | def __getitem__(self, index): 438 | if hasattr(index, "__len__"): 439 | return np.array([self.matrix[i] for i in index]) 440 | else: 441 | return self.matrix[index] 442 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: PanoVPR2 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - brotlipy=0.7.0=py37h27cfd23_1003 10 | - bzip2=1.0.8=h7b6447c_0 11 | - ca-certificates=2022.10.11=h06a4308_0 12 | - certifi=2022.9.24=py37h06a4308_0 13 | - cffi=1.15.1=py37h74dc2b5_0 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cryptography=38.0.1=py37h9ce1e76_0 16 | - cudatoolkit=11.3.1=h2bc3f7f_2 17 | - faiss-cpu=1.5.3=py37h6bb024c_0 18 | - ffmpeg=4.3=hf484d3e_0 19 | - freetype=2.12.1=h4a9f257_0 20 | - giflib=5.2.1=h7b6447c_0 21 | - gmp=6.2.1=h295c915_3 22 | - gnutls=3.6.15=he1e5248_0 23 | - idna=3.4=py37h06a4308_0 24 | - intel-openmp=2021.4.0=h06a4308_3561 25 | - jpeg=9e=h7f8727e_0 26 | - lame=3.100=h7b6447c_0 27 | - lcms2=2.12=h3be6417_0 28 | - ld_impl_linux-64=2.38=h1181459_1 29 | - lerc=3.0=h295c915_0 30 | - libdeflate=1.8=h7f8727e_5 31 | - libffi=3.3=he6710b0_2 32 | - libgcc-ng=11.2.0=h1234567_1 33 | - libgomp=11.2.0=h1234567_1 34 | - libiconv=1.16=h7f8727e_2 35 | - libidn2=2.3.2=h7f8727e_0 36 | - libpng=1.6.37=hbc83047_0 
37 | - libstdcxx-ng=11.2.0=h1234567_1 38 | - libtasn1=4.16.0=h27cfd23_0 39 | - libtiff=4.4.0=hecacb30_2 40 | - libunistring=0.9.10=h27cfd23_0 41 | - libuv=1.40.0=h7b6447c_0 42 | - libwebp=1.2.4=h11a3e52_0 43 | - libwebp-base=1.2.4=h5eee18b_0 44 | - lz4-c=1.9.3=h295c915_1 45 | - mkl=2021.4.0=h06a4308_640 46 | - mkl-service=2.4.0=py37h7f8727e_0 47 | - mkl_fft=1.3.1=py37hd3c417c_0 48 | - mkl_random=1.2.2=py37h51133e4_0 49 | - ncurses=6.3=h5eee18b_3 50 | - nettle=3.7.3=hbbd107a_1 51 | - numpy=1.21.5=py37h6c91a56_3 52 | - numpy-base=1.21.5=py37ha15fc14_3 53 | - openh264=2.1.1=h4ff587b_0 54 | - openssl=1.1.1s=h7f8727e_0 55 | - pillow=9.2.0=py37hace64e9_1 56 | - pip=22.2.2=py37h06a4308_0 57 | - pycparser=2.21=pyhd3eb1b0_0 58 | - pyopenssl=22.0.0=pyhd3eb1b0_0 59 | - pysocks=1.7.1=py37_1 60 | - python=3.7.15=haa1d7c7_0 61 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 62 | - pytorch-mutex=1.0=cuda 63 | - readline=8.2=h5eee18b_0 64 | - setuptools=65.5.0=py37h06a4308_0 65 | - six=1.16.0=pyhd3eb1b0_1 66 | - sqlite=3.40.0=h5082296_0 67 | - tk=8.6.12=h1ccaba5_0 68 | - torchaudio=0.11.0=py37_cu113 69 | - torchvision=0.12.0=py37_cu113 70 | - typing_extensions=4.3.0=py37h06a4308_0 71 | - urllib3=1.26.12=py37h06a4308_0 72 | - wheel=0.37.1=pyhd3eb1b0_0 73 | - xz=5.2.6=h5eee18b_0 74 | - zlib=1.2.13=h5eee18b_0 75 | - zstd=1.5.2=ha4553b6_0 76 | - pip: 77 | - absl-py==1.3.0 78 | - addict==2.4.0 79 | - cachetools==5.2.0 80 | - click==8.1.3 81 | - colorama==0.4.6 82 | - cycler==0.11.0 83 | - einops==0.6.0 84 | - filelock==3.8.0 85 | - fonttools==4.38.0 86 | - google-auth==2.14.1 87 | - google-auth-oauthlib==0.4.6 88 | - googledrivedownloader==0.4 89 | - grpcio==1.50.0 90 | - huggingface-hub==0.0.12 91 | - imageio==2.26.0 92 | - importlib-metadata==5.0.0 93 | - joblib==1.2.0 94 | - kiwisolver==1.4.4 95 | - markdown==3.4.1 96 | - markdown-it-py==2.2.0 97 | - markupsafe==2.1.1 98 | - matplotlib==3.5.3 99 | - mdurl==0.1.2 100 | - mmcls==0.25.0 101 | - mmcv-full==1.7.1 102 | - 
mmsegmentation==0.30.0 103 | - model-index==0.1.11 104 | - networkx==2.6.3 105 | - oauthlib==3.2.2 106 | - opencv-python==4.7.0.72 107 | - openmim==0.3.6 108 | - ordered-set==4.1.0 109 | - packaging==21.3 110 | - pandas==1.3.5 111 | - prettytable==3.6.0 112 | - protobuf==3.20.3 113 | - pyasn1==0.4.8 114 | - pyasn1-modules==0.2.8 115 | - pydantic==1.10.7 116 | - pygments==2.14.0 117 | - pyparsing==3.0.9 118 | - python-dateutil==2.8.2 119 | - pytz==2022.7.1 120 | - pywavelets==1.3.0 121 | - pyyaml==6.0 122 | - regex==2022.10.31 123 | - requests==2.26.0 124 | - requests-oauthlib==1.3.1 125 | - rich==13.3.1 126 | - rsa==4.9 127 | - sacremoses==0.0.53 128 | - scikit-image==0.17.2 129 | - scikit-learn==0.24.1 130 | - scipy==1.7.3 131 | - seaborn==0.12.2 132 | - staticmap==0.5.4 133 | - tabulate==0.9.0 134 | - tensorboard==2.11.0 135 | - tensorboard-data-server==0.6.1 136 | - tensorboard-plugin-wit==1.8.1 137 | - threadpoolctl==3.1.0 138 | - tifffile==2021.11.2 139 | - timm==0.6.12 140 | - tokenizers==0.10.3 141 | - torchscan==0.1.1 142 | - tqdm==4.48.2 143 | - transformers==4.8.2 144 | - wcwidth==0.2.6 145 | - werkzeug==2.2.2 146 | - yapf==0.32.0 147 | - zipp==3.10.0 148 | prefix: /home/shize/.conda/envs/PanoVPR 149 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/model/__init__.py -------------------------------------------------------------------------------- /model/aggregation.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import faiss 5 | import logging 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.parameter import Parameter 11 | from torch.utils.data import DataLoader, 
SubsetRandomSampler 12 | 13 | import model.functional as LF 14 | import model.normalization as normalization 15 | 16 | class MAC(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | def forward(self, x): 20 | return LF.mac(x) 21 | def __repr__(self): 22 | return self.__class__.__name__ + '()' 23 | 24 | class SPoC(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | def forward(self, x): 28 | return LF.spoc(x) 29 | def __repr__(self): 30 | return self.__class__.__name__ + '()' 31 | 32 | class GeM(nn.Module): 33 | def __init__(self, p=3, eps=1e-6, work_with_tokens=False): 34 | super().__init__() 35 | self.p = Parameter(torch.ones(1)*p) 36 | self.eps = eps 37 | self.work_with_tokens=work_with_tokens 38 | def forward(self, x): 39 | return LF.gem(x, p=self.p, eps=self.eps, work_with_tokens=self.work_with_tokens) 40 | def __repr__(self): 41 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 42 | 43 | class RMAC(nn.Module): 44 | def __init__(self, L=3, eps=1e-6): 45 | super().__init__() 46 | self.L = L 47 | self.eps = eps 48 | def forward(self, x): 49 | return LF.rmac(x, L=self.L, eps=self.eps) 50 | def __repr__(self): 51 | return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')' 52 | 53 | 54 | class Flatten(torch.nn.Module): 55 | def __init__(self): super().__init__() 56 | def forward(self, x): assert x.shape[2] == x.shape[3] == 1; return x[:,:,0,0] 57 | 58 | class RRM(nn.Module): 59 | """Residual Retrieval Module as described in the paper 60 | `Leveraging EfficientNet and Contrastive Learning for Accurate Global-scale 61 | Location Estimation ` 62 | """ 63 | def __init__(self, dim): 64 | super().__init__() 65 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) 66 | self.flatten = Flatten() 67 | self.ln1 = nn.LayerNorm(normalized_shape=dim) 68 | self.fc1 = nn.Linear(in_features=dim, out_features=dim) 69 | self.relu = nn.ReLU() 70 | self.fc2 = 
nn.Linear(in_features=dim, out_features=dim) 71 | self.ln2 = nn.LayerNorm(normalized_shape=dim) 72 | self.l2 = normalization.L2Norm() 73 | def forward(self, x): 74 | x = self.avgpool(x) 75 | x = self.flatten(x) 76 | x = self.ln1(x) 77 | identity = x 78 | out = self.fc2(self.relu(self.fc1(x))) 79 | out += identity 80 | out = self.l2(self.ln2(out)) 81 | return out 82 | 83 | 84 | # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py 85 | class NetVLAD(nn.Module): 86 | """NetVLAD layer implementation""" 87 | 88 | def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False): 89 | """ 90 | Args: 91 | clusters_num : int 92 | The number of clusters 93 | dim : int 94 | Dimension of descriptors 95 | alpha : float 96 | Parameter of initialization. Larger value is harder assignment. 97 | normalize_input : bool 98 | If true, descriptor-wise L2 normalization is applied to input. 99 | """ 100 | super().__init__() 101 | self.clusters_num = clusters_num 102 | self.dim = dim 103 | self.alpha = 0 104 | self.normalize_input = normalize_input 105 | self.work_with_tokens = work_with_tokens 106 | if work_with_tokens: 107 | self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False) 108 | else: 109 | self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False) 110 | self.centroids = nn.Parameter(torch.rand(clusters_num, dim)) 111 | 112 | def init_params(self, centroids, descriptors): 113 | centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True) 114 | dots = np.dot(centroids_assign, descriptors.T) 115 | dots.sort(0) 116 | dots = dots[::-1, :] # sort, descending 117 | 118 | self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item() 119 | self.centroids = nn.Parameter(torch.from_numpy(centroids)) 120 | if self.work_with_tokens: 121 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2)) 122 | else: 123 | self.conv.weight = 
nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3)) 124 | self.conv.bias = None 125 | 126 | def forward(self, x): 127 | if self.work_with_tokens: 128 | x = x.permute(0, 2, 1) 129 | N, D, _ = x.shape[:] 130 | else: 131 | N, D, H, W = x.shape[:] 132 | if self.normalize_input: 133 | x = F.normalize(x, p=2, dim=1) # Across descriptor dim 134 | x_flatten = x.view(N, D, -1) # shape: N, D, HW 135 | soft_assign = self.conv(x).view(N, self.clusters_num, -1) 136 | soft_assign = F.softmax(soft_assign, dim=1) 137 | vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device) 138 | for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage 139 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ 140 | self.centroids[D:D+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 141 | residual = residual * soft_assign[:,D:D+1,:].unsqueeze(2) 142 | vlad[:,D:D+1,:] = residual.sum(dim=-1) 143 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 144 | vlad = vlad.view(N, -1) # Flatten 145 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 146 | return vlad 147 | 148 | def initialize_netvlad_layer(self, args, cluster_ds, backbone): 149 | if args.backbone.startswith('swin'): 150 | descriptors_num = 20000 151 | # for swin choose 40 swin_output(16,49,768)->(16,40,768) :sample 100 in res18_output(16,196,256)->(16,100,256) 152 | descs_num_per_image = 40 153 | else: 154 | descriptors_num = 50000 155 | descs_num_per_image = 100 156 | images_num = math.ceil(descriptors_num / descs_num_per_image) 157 | random_sampler_query = SubsetRandomSampler(np.random.choice(range(cluster_ds.database_num, 158 | cluster_ds.database_num+cluster_ds.queries_num), 159 | int(images_num), replace=False)) 160 | random_sampler_database = SubsetRandomSampler(np.random.choice(range(cluster_ds.database_num), 161 | int(images_num), replace=False)) 162 | 163 | random_dl_query = 
DataLoader(dataset=cluster_ds, num_workers=args.num_workers, 164 | batch_size=args.infer_batch_size, sampler=random_sampler_query) 165 | 166 | random_dl_database = DataLoader(dataset=cluster_ds, num_workers=args.num_workers, 167 | batch_size=args.infer_batch_size, sampler=random_sampler_database) 168 | 169 | with torch.no_grad(): 170 | backbone = backbone.eval() 171 | logging.debug("Extracting features to initialize NetVLAD layer") 172 | 173 | descriptors_q = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32) 174 | for iteration, (inputs, idx) in enumerate(tqdm(random_dl_query, ncols=100, desc='Initializing NetVLAD [Query]')): 175 | inputs = inputs.to(args.device) 176 | outputs = backbone(inputs) 177 | # vit special case: set backbone forward output to a tensor 178 | if args.backbone.startswith('vit'): 179 | outputs = outputs.last_hidden_state 180 | 181 | if args.backbone.startswith('swin') or args.backbone.startswith('vit'): 182 | # swin don't need to reshape 183 | norm_outputs = F.normalize(outputs, p=2, dim=2) 184 | image_descriptors = norm_outputs 185 | else: 186 | norm_outputs = F.normalize(outputs, p=2, dim=1) 187 | image_descriptors = norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1) 188 | image_descriptors = image_descriptors.cpu().numpy() 189 | batchix = iteration * args.infer_batch_size * descs_num_per_image 190 | for ix in range(image_descriptors.shape[0]): 191 | sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False) 192 | startix = batchix + ix * descs_num_per_image 193 | descriptors_q[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :] 194 | 195 | descriptors_db = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32) 196 | for iteration, (inputs, idx) in enumerate(tqdm(random_dl_database, ncols=100, desc='Initializing NetVLAD [Database]')): 197 | rand_patch_idx = np.random.choice(args.split_nums, 1, replace=False) 198 | 
inputs_width = inputs.shape[3] 199 | stride = int(inputs_width / args.split_nums) 200 | left = int(stride * rand_patch_idx) 201 | right = int(stride * rand_patch_idx + args.resize[1]) 202 | if left < right <= inputs_width: 203 | inputs = inputs[:, :, :, left:right] 204 | elif left < inputs_width < right: 205 | inputs = torch.cat((inputs[:, :, :, left:], inputs[:, :, :, :right - inputs_width]), dim=-1) 206 | 207 | inputs = inputs.to(args.device) 208 | outputs = backbone(inputs) 209 | # vit special case: set backbone forward output to a tensor 210 | if args.backbone.startswith('vit'): 211 | outputs = outputs.last_hidden_state 212 | 213 | if args.backbone.startswith('swin') or args.backbone.startswith('vit'): 214 | # swin don't need to reshape 215 | norm_outputs = F.normalize(outputs, p=2, dim=2) 216 | image_descriptors = norm_outputs 217 | else: 218 | norm_outputs = F.normalize(outputs, p=2, dim=1) 219 | image_descriptors = norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1) 220 | image_descriptors = image_descriptors.cpu().numpy() 221 | batchix = iteration * args.infer_batch_size * descs_num_per_image 222 | for ix in range(image_descriptors.shape[0]): 223 | sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False) 224 | startix = batchix + ix * descs_num_per_image 225 | descriptors_db[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :] 226 | 227 | descriptors = np.concatenate((descriptors_q,descriptors_db),axis=0) 228 | kmeans = faiss.Kmeans(args.features_dim, self.clusters_num, niter=100, verbose=False) 229 | kmeans.train(descriptors) 230 | logging.debug(f"NetVLAD centroids shape: {kmeans.centroids.shape}") 231 | self.init_params(kmeans.centroids, descriptors) 232 | self = self.to(args.device) 233 | 234 | 235 | class CRNModule(nn.Module): 236 | def __init__(self, dim): 237 | super().__init__() 238 | # Downsample pooling 239 | self.downsample_pool = nn.AvgPool2d(kernel_size=3, 
stride=(2, 2), 240 | padding=0, ceil_mode=True) 241 | 242 | # Multiscale Context Filters 243 | self.filter_3_3 = nn.Conv2d(in_channels=dim, out_channels=32, 244 | kernel_size=(3, 3), padding=1) 245 | self.filter_5_5 = nn.Conv2d(in_channels=dim, out_channels=32, 246 | kernel_size=(5, 5), padding=2) 247 | self.filter_7_7 = nn.Conv2d(in_channels=dim, out_channels=20, 248 | kernel_size=(7, 7), padding=3) 249 | 250 | # Accumulation weight 251 | self.acc_w = nn.Conv2d(in_channels=84, out_channels=1, kernel_size=(1, 1)) 252 | # Upsampling 253 | self.upsample = F.interpolate 254 | 255 | self._initialize_weights() 256 | 257 | def _initialize_weights(self): 258 | # Initialize Context Filters 259 | torch.nn.init.xavier_normal_(self.filter_3_3.weight) 260 | torch.nn.init.constant_(self.filter_3_3.bias, 0.0) 261 | torch.nn.init.xavier_normal_(self.filter_5_5.weight) 262 | torch.nn.init.constant_(self.filter_5_5.bias, 0.0) 263 | torch.nn.init.xavier_normal_(self.filter_7_7.weight) 264 | torch.nn.init.constant_(self.filter_7_7.bias, 0.0) 265 | 266 | torch.nn.init.constant_(self.acc_w.weight, 1.0) 267 | torch.nn.init.constant_(self.acc_w.bias, 0.0) 268 | self.acc_w.weight.requires_grad = False 269 | self.acc_w.bias.requires_grad = False 270 | 271 | def forward(self, x): 272 | # Contextual Reweighting Network 273 | x_crn = self.downsample_pool(x) 274 | 275 | # Compute multiscale context filters g_n 276 | g_3 = self.filter_3_3(x_crn) 277 | g_5 = self.filter_5_5(x_crn) 278 | g_7 = self.filter_7_7(x_crn) 279 | g = torch.cat((g_3, g_5, g_7), dim=1) 280 | g = F.relu(g) 281 | 282 | w = F.relu(self.acc_w(g)) # Accumulation weight 283 | mask = self.upsample(w, scale_factor=2, mode='bilinear') # Reweighting Mask 284 | 285 | return mask 286 | 287 | 288 | class CRN(NetVLAD): 289 | def __init__(self, clusters_num=64, dim=128, normalize_input=True): 290 | super().__init__(clusters_num, dim, normalize_input) 291 | self.crn = CRNModule(dim) 292 | 293 | def forward(self, x): 294 | N, D, H, W = 
x.shape[:] 295 | if self.normalize_input: 296 | x = F.normalize(x, p=2, dim=1) # Across descriptor dim 297 | 298 | mask = self.crn(x) 299 | 300 | x_flatten = x.view(N, D, -1) 301 | soft_assign = self.conv(x).view(N, self.clusters_num, -1) 302 | soft_assign = F.softmax(soft_assign, dim=1) 303 | 304 | # Weight soft_assign using CRN's mask 305 | soft_assign = soft_assign * mask.view(N, 1, H * W) 306 | 307 | vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device) 308 | for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage 309 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ 310 | self.centroids[D:D + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 311 | residual = residual * soft_assign[:, D:D + 1, :].unsqueeze(2) 312 | vlad[:, D:D + 1, :] = residual.sum(dim=-1) 313 | 314 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 315 | vlad = vlad.view(N, -1) # Flatten 316 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 317 | return vlad 318 | 319 | -------------------------------------------------------------------------------- /model/functional.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | def sare_ind(query, positive, negative): 7 | '''all 3 inputs are supposed to be shape 1xn_features''' 8 | dist_pos = ((query - positive)**2).sum(1) 9 | dist_neg = ((query - negative)**2).sum(1) 10 | 11 | dist = - torch.cat((dist_pos, dist_neg)) 12 | dist = F.log_softmax(dist, 0) 13 | 14 | #loss = (- dist[:, 0]).mean() on a batch 15 | loss = -dist[0] 16 | return loss 17 | 18 | def sare_joint(query, positive, negatives): 19 | '''query and positive have to be 1xn_features; whereas negatives has to be 20 | shape n_negative x n_features. 
n_negative is usually 10''' 21 | # NOTE: the implementation is the same if batch_size=1 as all operations 22 | # are vectorial. If there were the additional n_batch dimension a different 23 | # handling of that situation would have to be implemented here. 24 | # This function is declared anyway for the sake of clarity as the 2 should 25 | # be called in different situations because, even though there would be 26 | # no Exceptions, there would actually be a conceptual error. 27 | return sare_ind(query, positive, negatives) 28 | 29 | def mac(x): 30 | return F.adaptive_max_pool2d(x, (1,1)) 31 | 32 | def spoc(x): 33 | return F.adaptive_avg_pool2d(x, (1,1)) 34 | 35 | def gem(x, p=3, eps=1e-6, work_with_tokens=False): 36 | if work_with_tokens: 37 | x = x.permute(0, 2, 1) 38 | # unsqueeze to maintain compatibility with Flatten 39 | return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p).unsqueeze(3) 40 | else: 41 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 42 | 43 | def rmac(x, L=3, eps=1e-6): 44 | ovr = 0.4 # desired overlap of neighboring regions 45 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 46 | W = x.size(3) 47 | H = x.size(2) 48 | w = min(W, H) 49 | # w2 = math.floor(w/2.0 - 1) 50 | b = (max(H, W)-w)/(steps-1) 51 | (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension 52 | # region overplus per dimension 53 | Wd = 0 54 | Hd = 0 55 | if H < W: 56 | Wd = idx.item() + 1 57 | elif H > W: 58 | Hd = idx.item() + 1 59 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 60 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 61 | for l in range(1, L+1): 62 | wl = math.floor(2*w/(l+1)) 63 | wl2 = math.floor(wl/2 - 1) 64 | if l+Wd == 1: 65 | b = 0 66 | else: 67 | b = (W-wl)/(l+Wd-1) 68 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates 69 | if l+Hd == 1: 70 | b = 0 71 | else: 72 | b = 
(H-wl)/(l+Hd-1) 73 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates 74 | for i_ in cenH.tolist(): 75 | for j_ in cenW.tolist(): 76 | if wl == 0: 77 | continue 78 | R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:] 79 | R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()] 80 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 81 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 82 | v += vt 83 | return v 84 | 85 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from collections import OrderedDict 4 | 5 | import timm 6 | import torch 7 | import logging 8 | import torchvision 9 | from torch import nn 10 | from os.path import join 11 | from timm.models.swin_transformer import SwinTransformer 12 | from timm.models.convnext import ConvNeXt 13 | 14 | from model.aggregation import Flatten 15 | from model.normalization import L2Norm 16 | import model.aggregation as aggregation 17 | from model.non_local import NonLocalBlock 18 | from tools import util 19 | 20 | from mmseg.models.decode_heads import FPNHead 21 | from mmseg.ops import Upsample, resize 22 | 23 | 24 | class GeoLocalizationNet(nn.Module): 25 | """The used networks are composed of a backbone and an aggregation layer. 
26 | """ 27 | def __init__(self, args): 28 | super().__init__() 29 | self.backbone = get_backbone(args) 30 | 31 | self.arch_name = args.backbone 32 | self.aggregation = get_aggregation(args) 33 | self.self_att = False 34 | 35 | if args.aggregation in ["gem", "spoc", "mac", "rmac"]: 36 | if args.l2 == "before_pool": 37 | self.aggregation = nn.Sequential(L2Norm(), self.aggregation, Flatten()) 38 | elif args.l2 == "after_pool": 39 | self.aggregation = nn.Sequential(self.aggregation, L2Norm(), Flatten()) 40 | elif args.l2 == "none": 41 | self.aggregation = nn.Sequential(self.aggregation, Flatten()) 42 | 43 | if args.fc_output_dim != None: 44 | # Concatenate fully connected layer to the aggregation layer 45 | self.aggregation = nn.Sequential(self.aggregation, 46 | nn.Linear(args.features_dim, args.fc_output_dim), 47 | L2Norm()) 48 | args.features_dim = args.fc_output_dim 49 | if args.non_local: 50 | non_local_list = [NonLocalBlock(channel_feat=get_output_channels_dim(self.backbone), 51 | channel_inner=args.channel_bottleneck)]* args.num_non_local 52 | self.non_local = nn.Sequential(*non_local_list) 53 | self.self_att = True 54 | 55 | def forward(self, x): 56 | x = self.backbone(x) 57 | 58 | if self.self_att: 59 | x = self.non_local(x) 60 | 61 | x = self.aggregation(x) 62 | 63 | return x 64 | 65 | 66 | class SwinBackbone(SwinTransformer): 67 | 68 | def __init__(self, 69 | backbone_name: str, 70 | depths=(2,2,6,2), 71 | img_size=224, 72 | window_size=7, 73 | embed_dim=96, 74 | num_heads=(3, 6, 12, 24), 75 | patch_size=4, 76 | patch_norm=True, 77 | norm_layer=nn.LayerNorm): 78 | super().__init__(img_size=img_size, patch_size=patch_size, patch_norm=patch_norm, window_size=window_size, 79 | depths=depths, norm_layer=norm_layer,embed_dim=embed_dim,num_heads=num_heads) 80 | 81 | self.depths = depths 82 | self.multi_feature = [] 83 | self.out_layer_idx = self.get_multi_layer_out(depths) 84 | 85 | # load pretrained params 86 | original_model = timm.create_model(backbone_name, 
pretrained=True) 87 | self.load_state_dict(original_model.state_dict(), strict=False) 88 | 89 | self.query_patch_embed = util.PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, 90 | embed_dim=self.embed_dim, 91 | norm_layer=norm_layer if patch_norm else None) 92 | 93 | self.dataset_patch_embed = util.PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=5, 94 | embed_dim=self.embed_dim, 95 | norm_layer=norm_layer if patch_norm else None) 96 | 97 | @staticmethod 98 | def get_multi_layer_out(depths): 99 | out_layer_idx = list(depths) 100 | cache = 0 101 | for idx, i in enumerate(depths): 102 | cache += i 103 | out_layer_idx[idx] = cache 104 | return out_layer_idx 105 | 106 | def load_state_dict(self, state_dict, strict=True): 107 | 108 | new_state_dict = OrderedDict() 109 | for k, v in state_dict.items(): 110 | if not k.startswith('patch_embed'): 111 | new_state_dict[k] = v 112 | 113 | super().load_state_dict(new_state_dict, strict=strict) 114 | 115 | def forward_features(self, x): 116 | if x.shape[1] == 3: 117 | x = self.query_patch_embed(x) 118 | elif x.shape[1] == 5: 119 | x = self.dataset_patch_embed(x) 120 | else: 121 | raise ValueError(f'input tensor channel should be 3 or 5, but get {x.shape[1]} now') 122 | 123 | if self.absolute_pos_embed is not None: 124 | x = x + self.absolute_pos_embed 125 | x = self.pos_drop(x) 126 | 127 | multi_feature = [] 128 | for layer_idx, layer in enumerate(self.layers): 129 | if layer_idx in range(self.num_layers): 130 | B, L, C = x.shape 131 | multi_feature.append(x.view(B, int(L**0.5), int(L**0.5), C).permute(0,3,1,2)) 132 | 133 | x = layer(x) 134 | self.multi_feature = multi_feature 135 | assert len(self.multi_feature) == len(self.depths) 136 | 137 | x = self.norm(x) # B L C 138 | return x 139 | 140 | 141 | def forward(self, x): 142 | x = self.forward_features(x) 143 | return x 144 | 145 | 146 | class FPNHeadPooling(FPNHead): 147 | def __init__(self, feature_dim=None, **kwargs): 148 | """ 149 | fuse 
multi-layer features 150 | :param feature_dim: (int) number of output channels produced by this module 151 | :param kwargs: arguments forwarded to FPNHead 152 | - in_channels (int|Sequence[int]): Input channels. 153 | - channels (int): Intermediate channel count inside the head, before conv_unify maps it to feature_dim 154 | - num_classes (int): Number of classes. default=1000 [not used] 155 | - feature_strides (tuple[int]): The strides for input feature maps. 156 | stack_lateral. All strides are supposed to be powers of 2. The first 157 | one is of largest resolution. 158 | """ 159 | super().__init__(**kwargs) 160 | self.conv_unify = nn.Conv2d(self.channels, feature_dim, kernel_size=1) 161 | pass 162 | 163 | 164 | def unify_channel(self,x): 165 | out = self.conv_unify(x) 166 | return out 167 | 168 | 169 | def forward(self, inputs): 170 | x = self._transform_inputs(inputs) 171 | 172 | output = self.scale_heads[0](x[0]) 173 | for i in range(1, len(self.feature_strides)): 174 | # non inplace 175 | output = output + resize( 176 | self.scale_heads[i](x[i]), 177 | size=output.shape[2:], 178 | mode='bilinear', 179 | align_corners=self.align_corners) 180 | 181 | output = self.unify_channel(output) # B,C,H,W 182 | output = torch.flatten(output,start_dim=2).permute(0,2,1) # BCHW -> BCL -> BLC 183 | return output 184 | 185 | 186 | class ConvNeXtBackbone(ConvNeXt): 187 | def __init__(self, 188 | backbone_name:str, 189 | depths=(3,3,9,3), 190 | dims=(96,192,384,768)): 191 | super().__init__(depths=depths, dims=dims) 192 | 193 | original_model = timm.create_model(backbone_name, pretrained=True) 194 | self.load_state_dict(original_model.state_dict(), strict=False) 195 | 196 | def forward(self, x): 197 | # Drop the head 198 | x = self.forward_features(x) 199 | return x 200 | 201 | 202 | def get_aggregation(args): 203 | if args.aggregation == "gem": 204 | return aggregation.GeM(work_with_tokens=args.work_with_tokens) 205 | elif args.aggregation == "spoc": 206 | return 
aggregation.SPoC() 207 | elif args.aggregation == "mac": 208 | return aggregation.MAC() 209 | elif args.aggregation == "rmac": 210 | return aggregation.RMAC() 211 | elif args.aggregation == "netvlad": 212 | return aggregation.NetVLAD(clusters_num=args.netvlad_clusters, dim=args.features_dim, 213 | work_with_tokens=args.work_with_tokens) 214 | elif args.aggregation == 'crn': 215 | return aggregation.CRN(clusters_num=args.netvlad_clusters, dim=args.features_dim) 216 | elif args.aggregation == "rrm": 217 | return aggregation.RRM(args.features_dim) 218 | elif args.aggregation == 'none'\ 219 | or args.aggregation == 'cls' \ 220 | or args.aggregation == 'seqpool': 221 | return nn.Identity() 222 | 223 | 224 | def get_backbone(args): 225 | # The aggregation layer works differently based on the type of architecture 226 | args.work_with_tokens = args.backbone.startswith('swin') 227 | 228 | if args.backbone.startswith("swin"): 229 | if args.backbone.endswith("tiny"): 230 | # check input image size 231 | assert args.resize[0] == 224 232 | model_cfg = dict(depths=(2, 2, 6, 2), 233 | window_size=7, img_size=224, 234 | embed_dim=96, num_heads=(3,6,12,24)) 235 | backbone = SwinBackbone('swin_tiny_patch4_window7_224', **model_cfg) 236 | args.features_dim = 96 * 8 237 | elif args.backbone.endswith("small"): 238 | assert args.resize[0] == 224 239 | model_cfg = dict(depths=(2, 2, 18, 2), 240 | window_size=7, img_size=224, 241 | embed_dim=96,num_heads=(3,6,12,24)) 242 | backbone = SwinBackbone('swin_small_patch4_window7_224', **model_cfg) 243 | args.features_dim = 96 * 8 244 | elif args.backbone.endswith("base"): 245 | assert args.resize[0] == 384 or args.resize[0] == 224 246 | if args.resize[0] == 384: 247 | model_cfg = dict(depths=(2, 2, 18, 2), 248 | window_size=12, img_size=384, 249 | embed_dim=128, num_heads=(4, 8, 16, 32)) 250 | backbone = SwinBackbone('swin_base_patch4_window12_384', **model_cfg) 251 | else: 252 | model_cfg = dict(depths=(2, 2, 18, 2), 253 | window_size=7, 
img_size=224, 254 | embed_dim=128, num_heads=(4, 8, 16, 32)) 255 | backbone = SwinBackbone('swin_base_patch4_window7_224_in22k', **model_cfg) 256 | args.features_dim = 128 * 8 257 | else: 258 | raise NotImplementedError(f"The interface of {args.backbone} is not implemented") 259 | return backbone 260 | 261 | elif args.backbone.startswith("convnext"): 262 | if args.backbone.endswith("tiny"): 263 | # check input image size 264 | assert args.resize[0] == 224, f'Input size should be either 224 or 384, but get {args.resize[0]}' 265 | model_cfg = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768)) 266 | backbone = ConvNeXtBackbone('convnext_tiny', **model_cfg) 267 | args.features_dim = 96 * 8 268 | elif args.backbone.endswith("small"): 269 | assert args.resize[0] == 224, f'Input size should be either 224 or 384, but get {args.resize[0]}' 270 | model_cfg = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]) 271 | backbone = ConvNeXtBackbone('convnext_small', **model_cfg) 272 | args.features_dim = 96 * 8 273 | elif args.backbone.endswith("base"): 274 | assert args.resize[0] == 384, f'Input size should be either 224 or 384, but get {args.resize[0]}' 275 | model_cfg = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]) 276 | backbone = ConvNeXtBackbone('convnext_base', **model_cfg) 277 | args.features_dim = 128 * 8 278 | else: 279 | raise NotImplementedError(f"The interface of {args.backbone} is not implemented") 280 | return backbone 281 | 282 | 283 | def get_output_channels_dim(model, type:str= 'feat'): 284 | """Return the number of channels in the output of a model.""" 285 | if type == 'feat': 286 | return model(torch.ones([1, 3, 224, 224])).shape[1] 287 | elif type == 'token': 288 | return model(torch.ones([1, 3, 224, 224])).shape[2] 289 | else: 290 | raise Exception(f'type Err: which should be feat or token but get {type}') 291 | 292 | -------------------------------------------------------------------------------- /model/non_local.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import einops 4 | 5 | 6 | class NonLocalBlock(nn.Module): 7 | def __init__(self, channel_feat, channel_inner, gamma=1): 8 | super().__init__() 9 | self.q_conv = nn.Conv2d(in_channels=channel_feat, 10 | out_channels=channel_inner, 11 | kernel_size=1) 12 | self.k_conv = nn.Conv2d(in_channels=channel_feat, 13 | out_channels=channel_inner, 14 | kernel_size=1) 15 | self.v_conv = nn.Conv2d(in_channels=channel_feat, 16 | out_channels=channel_inner, 17 | kernel_size=1) 18 | self.merge_conv = nn.Conv2d(in_channels=channel_inner, 19 | out_channels=channel_feat, 20 | kernel_size=1) 21 | self.gamma = gamma 22 | 23 | def forward(self, x): 24 | b, c, h, w = x.shape[:] 25 | q_tensor = self.q_conv(x) 26 | k_tensor = self.k_conv(x) 27 | v_tensor = self.v_conv(x) 28 | 29 | q_tensor = einops.rearrange(q_tensor, 'b c h w -> b c (h w)') 30 | k_tensor = einops.rearrange(k_tensor, 'b c h w -> b c (h w)') 31 | v_tensor = einops.rearrange(v_tensor, 'b c h w -> b c (h w)') 32 | 33 | qk_tensor = torch.einsum('b c i, b c j -> b i j', q_tensor, k_tensor) # where i = j = (h * w) 34 | attention = torch.softmax(qk_tensor, -1) 35 | out = torch.einsum('b n i, b c i -> b c n', attention, v_tensor) 36 | out = einops.rearrange(out, 'b c (h w) -> b c h w', h=h, w=w) 37 | out = self.merge_conv(out) 38 | out = self.gamma * out + x 39 | return out 40 | -------------------------------------------------------------------------------- /model/normalization.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class L2Norm(nn.Module): 6 | def __init__(self, dim=1): 7 | super().__init__() 8 | self.dim = dim 9 | def forward(self, x): 10 | return F.normalize(x, p=2, dim=self.dim) 11 | 12 | -------------------------------------------------------------------------------- 
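The two files above are compact enough to trace by hand. The following is a minimal, dependency-free sketch of what `NonLocalBlock.forward` and `L2Norm.forward` compute for a single sample, under some stated assumptions: the helper names `softmax`, `non_local_attention`, and `l2_normalize` are illustrative and not part of the repository, the 1x1 convolutions and the `gamma * out + x` residual are left out, and plain Python lists stand in for tensors.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def non_local_attention(q, k, v):
    """Hypothetical helper: q, k, v are lists of C channel rows, each
    holding N = H * W positions (one sample, already flattened).

    Mirrors the two einsum calls in NonLocalBlock.forward:
      scores[i][j] = sum_c q[c][i] * k[c][j]
      out[c][n]    = sum_i softmax(scores[n])[i] * v[c][i]
    """
    channels, positions = len(q), len(q[0])
    scores = [[sum(q[c][i] * k[c][j] for c in range(channels))
               for j in range(positions)]
              for i in range(positions)]
    # Softmax runs over the last axis, so each row of attention sums to 1.
    attention = [softmax(row) for row in scores]
    return [[sum(attention[n][i] * v[c][i] for i in range(positions))
             for n in range(positions)]
            for c in range(len(v))]


def l2_normalize(vec, eps=1e-12):
    """Hypothetical helper: scale vec to unit L2 norm along one dimension,
    dividing by max(norm, eps) so an all-zero vector stays at zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]
```

Because every attention row sums to 1, a value map that is constant across positions passes through unchanged, which makes a quick sanity check when modifying the block.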
/model/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import set_sbn_eps_mode 12 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 13 | from .batchnorm import patch_sync_batchnorm, convert_model 14 | from .replicate import DataParallelWithCallback, patch_replication_callback 15 | -------------------------------------------------------------------------------- /model/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | import contextlib 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from torch.nn.modules.batchnorm import _BatchNorm 18 | 19 | try: 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | except ImportError: 22 | ReduceAddCoalesced = Broadcast = None 23 | 24 | try: 25 | from jactorch.parallel.comm import SyncMaster 26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback 27 | except ImportError: 28 | from .comm import SyncMaster 29 | from .replicate import DataParallelWithCallback 30 | 31 | __all__ = [ 32 | 'set_sbn_eps_mode', 33 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 34 | 'patch_sync_batchnorm', 'convert_model' 35 | ] 36 | 37 | 38 | SBN_EPS_MODE = 'clamp' 39 | 40 | 41 | def set_sbn_eps_mode(mode): 42 | global SBN_EPS_MODE 43 | assert mode in ('clamp', 'plus') 44 | SBN_EPS_MODE = mode 45 | 46 | 47 | def _sum_ft(tensor): 48 | """sum over the first and last dimension""" 49 | return tensor.sum(dim=0).sum(dim=-1) 50 | 51 | 52 | def _unsqueeze_ft(tensor): 53 | """add new dimensions at the front and the tail""" 54 | return tensor.unsqueeze(0).unsqueeze(-1) 55 | 56 | 57 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 58 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 59 | 60 | 61 | class _SynchronizedBatchNorm(_BatchNorm): 62 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): 63 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 
64 | 65 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, 66 | track_running_stats=track_running_stats) 67 | 68 | if not self.track_running_stats: 69 | import warnings 70 | warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') 71 | 72 | self._sync_master = SyncMaster(self._data_parallel_master) 73 | 74 | self._is_parallel = False 75 | self._parallel_id = None 76 | self._slave_pipe = None 77 | 78 | def forward(self, input): 79 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 80 | if not (self._is_parallel and self.training): 81 | return F.batch_norm( 82 | input, self.running_mean, self.running_var, self.weight, self.bias, 83 | self.training, self.momentum, self.eps) 84 | 85 | # Resize the input to (B, C, -1). 86 | input_shape = input.size() 87 | assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) 88 | input = input.view(input.size(0), self.num_features, -1) 89 | 90 | # Compute the sum and square-sum. 91 | sum_size = input.size(0) * input.size(2) 92 | input_sum = _sum_ft(input) 93 | input_ssum = _sum_ft(input ** 2) 94 | 95 | # Reduce-and-broadcast the statistics. 96 | if self._parallel_id == 0: 97 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 98 | else: 99 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 100 | 101 | # Compute the output. 102 | if self.affine: 103 | # MJY:: Fuse the multiplication for speed. 104 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 105 | else: 106 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 107 | 108 | # Reshape it. 
109 | return output.view(input_shape) 110 | 111 | def __data_parallel_replicate__(self, ctx, copy_id): 112 | self._is_parallel = True 113 | self._parallel_id = copy_id 114 | 115 | # parallel_id == 0 means master device. 116 | if self._parallel_id == 0: 117 | ctx.sync_master = self._sync_master 118 | else: 119 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 120 | 121 | def _data_parallel_master(self, intermediates): 122 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 123 | 124 | # Always using same "device order" makes the ReduceAdd operation faster. 125 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 126 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 127 | 128 | to_reduce = [i[1][:2] for i in intermediates] 129 | to_reduce = [j for i in to_reduce for j in i] # flatten 130 | target_gpus = [i[1].sum.get_device() for i in intermediates] 131 | 132 | sum_size = sum([i[1].sum_size for i in intermediates]) 133 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 134 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 135 | 136 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 137 | 138 | outputs = [] 139 | for i, rec in enumerate(intermediates): 140 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 141 | 142 | return outputs 143 | 144 | def _compute_mean_std(self, sum_, ssum, size): 145 | """Compute the mean and standard-deviation with sum and square-sum. This method 146 | also maintains the moving average on the master device.""" 147 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
148 | mean = sum_ / size 149 | sumvar = ssum - sum_ * mean 150 | unbias_var = sumvar / (size - 1) 151 | bias_var = sumvar / size 152 | 153 | if hasattr(torch, 'no_grad'): 154 | with torch.no_grad(): 155 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 156 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 157 | else: 158 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 159 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 160 | 161 | if SBN_EPS_MODE == 'clamp': 162 | return mean, bias_var.clamp(self.eps) ** -0.5 163 | elif SBN_EPS_MODE == 'plus': 164 | return mean, (bias_var + self.eps) ** -0.5 165 | else: 166 | raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) 167 | 168 | 169 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 170 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 171 | mini-batch. 172 | 173 | .. math:: 174 | 175 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 176 | 177 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 178 | standard-deviation are reduced across all devices during training. 179 | 180 | For example, when one uses `nn.DataParallel` to wrap the network during 181 | training, PyTorch's implementation normalize the tensor on each device using 182 | the statistics only on that device, which accelerated the computation and 183 | is also easy to implement, but the statistics might be inaccurate. 184 | Instead, in this synchronized version, the statistics will be computed 185 | over all training samples distributed on multiple devices. 186 | 187 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 188 | as the built-in PyTorch implementation. 
189 | 190 | The mean and standard-deviation are calculated per-dimension over 191 | the mini-batches and gamma and beta are learnable parameter vectors 192 | of size C (where C is the input size). 193 | 194 | During training, this layer keeps a running estimate of its computed mean 195 | and variance. The running sum is kept with a default momentum of 0.1. 196 | 197 | During evaluation, this running mean/variance is used for normalization. 198 | 199 | Because the BatchNorm is done over the `C` dimension, computing statistics 200 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 201 | 202 | Args: 203 | num_features: num_features from an expected input of size 204 | `batch_size x num_features [x width]` 205 | eps: a value added to the denominator for numerical stability. 206 | Default: 1e-5 207 | momentum: the value used for the running_mean and running_var 208 | computation. Default: 0.1 209 | affine: a boolean value that when set to ``True``, gives the layer learnable 210 | affine parameters. Default: ``True`` 211 | 212 | Shape:: 213 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 214 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 215 | 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm1d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 2 and input.dim() != 3: 227 | raise ValueError('expected 2D or 3D input (got {}D input)' 228 | .format(input.dim())) 229 | 230 | 231 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 232 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 233 | of 3d inputs 234 | 235 | .. 
math:: 236 | 237 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 238 | 239 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 240 | standard-deviation are reduced across all devices during training. 241 | 242 | For example, when one uses `nn.DataParallel` to wrap the network during 243 | training, PyTorch's implementation normalize the tensor on each device using 244 | the statistics only on that device, which accelerated the computation and 245 | is also easy to implement, but the statistics might be inaccurate. 246 | Instead, in this synchronized version, the statistics will be computed 247 | over all training samples distributed on multiple devices. 248 | 249 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 250 | as the built-in PyTorch implementation. 251 | 252 | The mean and standard-deviation are calculated per-dimension over 253 | the mini-batches and gamma and beta are learnable parameter vectors 254 | of size C (where C is the input size). 255 | 256 | During training, this layer keeps a running estimate of its computed mean 257 | and variance. The running sum is kept with a default momentum of 0.1. 258 | 259 | During evaluation, this running mean/variance is used for normalization. 260 | 261 | Because the BatchNorm is done over the `C` dimension, computing statistics 262 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 263 | 264 | Args: 265 | num_features: num_features from an expected input of 266 | size batch_size x num_features x height x width 267 | eps: a value added to the denominator for numerical stability. 268 | Default: 1e-5 269 | momentum: the value used for the running_mean and running_var 270 | computation. Default: 0.1 271 | affine: a boolean value that when set to ``True``, gives the layer learnable 272 | affine parameters. 
Default: ``True`` 273 | 274 | Shape:: 275 | - Input: :math:`(N, C, H, W)` 276 | - Output: :math:`(N, C, H, W)` (same shape as input) 277 | 278 | Examples: 279 | >>> # With Learnable Parameters 280 | >>> m = SynchronizedBatchNorm2d(100) 281 | >>> # Without Learnable Parameters 282 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 283 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 284 | >>> output = m(input) 285 | """ 286 | 287 | def _check_input_dim(self, input): 288 | if input.dim() != 4: 289 | raise ValueError('expected 4D input (got {}D input)' 290 | .format(input.dim())) 291 | 292 | 293 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 294 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 295 | of 4d inputs 296 | 297 | .. math:: 298 | 299 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 300 | 301 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 302 | standard-deviation are reduced across all devices during training. 303 | 304 | For example, when one uses `nn.DataParallel` to wrap the network during 305 | training, PyTorch's implementation normalize the tensor on each device using 306 | the statistics only on that device, which accelerated the computation and 307 | is also easy to implement, but the statistics might be inaccurate. 308 | Instead, in this synchronized version, the statistics will be computed 309 | over all training samples distributed on multiple devices. 310 | 311 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 312 | as the built-in PyTorch implementation. 313 | 314 | The mean and standard-deviation are calculated per-dimension over 315 | the mini-batches and gamma and beta are learnable parameter vectors 316 | of size C (where C is the input size). 317 | 318 | During training, this layer keeps a running estimate of its computed mean 319 | and variance. The running sum is kept with a default momentum of 0.1. 
320 | 321 | During evaluation, this running mean/variance is used for normalization. 322 | 323 | Because the BatchNorm is done over the `C` dimension, computing statistics 324 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 325 | or Spatio-temporal BatchNorm 326 | 327 | Args: 328 | num_features: num_features from an expected input of 329 | size batch_size x num_features x depth x height x width 330 | eps: a value added to the denominator for numerical stability. 331 | Default: 1e-5 332 | momentum: the value used for the running_mean and running_var 333 | computation. Default: 0.1 334 | affine: a boolean value that when set to ``True``, gives the layer learnable 335 | affine parameters. Default: ``True`` 336 | 337 | Shape:: 338 | - Input: :math:`(N, C, D, H, W)` 339 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 340 | 341 | Examples: 342 | >>> # With Learnable Parameters 343 | >>> m = SynchronizedBatchNorm3d(100) 344 | >>> # Without Learnable Parameters 345 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 346 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 347 | >>> output = m(input) 348 | """ 349 | 350 | def _check_input_dim(self, input): 351 | if input.dim() != 5: 352 | raise ValueError('expected 5D input (got {}D input)' 353 | .format(input.dim())) 354 | 355 | 356 | @contextlib.contextmanager 357 | def patch_sync_batchnorm(): 358 | import torch.nn as nn 359 | 360 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 361 | 362 | nn.BatchNorm1d = SynchronizedBatchNorm1d 363 | nn.BatchNorm2d = SynchronizedBatchNorm2d 364 | nn.BatchNorm3d = SynchronizedBatchNorm3d 365 | try: 366 | yield 367 | finally: # restore the original classes even if the body raises 368 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup 369 | 370 | 371 | def convert_model(module): 372 | """Traverse the input module and its children recursively 373 | and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d 374 | with SynchronizedBatchNorm*N*d 375 | 376 | Args: 377 |
module: the input module needs to be convert to SyncBN model 378 | 379 | Examples: 380 | >>> import torch.nn as nn 381 | >>> import torchvision 382 | >>> # m is a standard pytorch model 383 | >>> m = torchvision.models.resnet18(True) 384 | >>> m = nn.DataParallel(m) 385 | >>> # after convert, m is using SyncBN 386 | >>> m = convert_model(m) 387 | """ 388 | if isinstance(module, torch.nn.DataParallel): 389 | mod = module.module 390 | mod = convert_model(mod) 391 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids) 392 | return mod 393 | 394 | mod = module 395 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 396 | torch.nn.modules.batchnorm.BatchNorm2d, 397 | torch.nn.modules.batchnorm.BatchNorm3d], 398 | [SynchronizedBatchNorm1d, 399 | SynchronizedBatchNorm2d, 400 | SynchronizedBatchNorm3d]): 401 | if isinstance(module, pth_module): 402 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 403 | mod.running_mean = module.running_mean 404 | mod.running_var = module.running_var 405 | if module.affine: 406 | mod.weight.data = module.weight.data.clone().detach() 407 | mod.bias.data = module.bias.data.clone().detach() 408 | 409 | for name, child in module.named_children(): 410 | mod.add_module(name, convert_model(child)) 411 | 412 | return mod 413 | -------------------------------------------------------------------------------- /model/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
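Editorial sketch (not part of the repository): `_compute_mean_std` in batchnorm.py above derives the batch statistics from per-device sums and square-sums. The same arithmetic is shown here in plain Python, with scalars standing in for per-channel tensors. The names `combine_stats` and `partial_sums` are hypothetical and exist only for this illustration.

```python
# Minimal sketch, assuming scalar statistics instead of per-channel tensors.
# `combine_stats` and `partial_sums` are hypothetical helper names.

def partial_sums(values):
    """Return (sum, square-sum, count) for the samples held by one device."""
    return sum(values), sum(v * v for v in values), len(values)


def combine_stats(partials, eps=1e-5, eps_mode="clamp"):
    """Reduce per-device partial sums into mean, unbiased variance and inverse std."""
    total_sum = sum(p[0] for p in partials)
    total_ssum = sum(p[1] for p in partials)
    size = sum(p[2] for p in partials)
    if size <= 1:
        raise ValueError("unbiased variance requires size > 1")

    mean = total_sum / size
    sumvar = total_ssum - total_sum * mean  # equals sum((x - mean) ** 2)
    unbiased_var = sumvar / (size - 1)      # used for the running variance
    biased_var = sumvar / size              # used to normalize the current batch

    if eps_mode == "clamp":
        inv_std = max(biased_var, eps) ** -0.5   # variance floored at eps
    elif eps_mode == "plus":
        inv_std = (biased_var + eps) ** -0.5     # eps added to the variance
    else:
        raise ValueError("Unknown EPS mode: {}.".format(eps_mode))
    return mean, unbiased_var, inv_std


# Two "devices" holding different slices of the same batch.
device0 = [1.0, 2.0, 3.0]
device1 = [4.0, 5.0]
mean, unbiased_var, inv_std = combine_stats([partial_sums(device0), partial_sums(device1)])
print(mean, unbiased_var, inv_std)
```

Only the sum, the square-sum and the sample count cross device boundaries, which is why each replica sends a three-field message rather than its activations.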
10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, 
width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /model/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object.
58 | 59 | - During replication, data parallel triggers a callback on each module; all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected 62 | and passed to the registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device).
108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. See `_SynchronizedBatchNorm` for a usage example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /model/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all replicas are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called before the callback 38 | of any slave copy. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | The replication callback `__data_parallel_replicate__` of each module will be invoked after the replica is created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation.
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /model/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
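Editorial sketch (not part of the repository): the handshake described by `SyncMaster` and `SlavePipe` in comm.py above is a gather, reduce, broadcast round between threads. The standard-library illustration below assumes a one-slot `queue.Queue` as a stand-in for `FutureResult` and a plain sum as a stand-in for the master callback. The function `run_round` is hypothetical.

```python
import queue
import threading


def run_round(values):
    """Run one gather/reduce/broadcast round; values[0] belongs to the master."""
    inbox = queue.Queue()  # workers -> master
    # One single-slot reply queue per worker (master -> worker), used like a future.
    replies = {i: queue.Queue(maxsize=1) for i in range(1, len(values))}
    results = {}

    def worker(ident, value):
        inbox.put((ident, value))              # send the local statistic
        results[ident] = replies[ident].get()  # block until the master answers

    threads = [threading.Thread(target=worker, args=(i, v))
               for i, v in enumerate(values) if i != 0]
    for t in threads:
        t.start()

    # Master: start with its own message, then collect one message per worker.
    collected = [(0, values[0])]
    for _ in threads:
        collected.append(inbox.get())
    total = sum(v for _, v in collected)       # the "reduce" step

    for ident in replies:                      # the "broadcast" step
        replies[ident].put(total)
    results[0] = total

    for t in threads:
        t.join()
    return results


out = run_round([1.0, 2.0, 3.0])
print(sorted(out.items()))
```

Every participant ends the round holding the same reduced value, which is the property the synchronized layer relies on when it normalizes each replica with shared statistics.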
10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) 29 | 30 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import argparse 5 | 6 | 7 | def parse_arguments(): 8 | parser = argparse.ArgumentParser(description="Benchmarking Visual Geolocalization", 9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | # Training parameters 11 | parser.add_argument("--train_batch_size", type=int, default=2, 12 | help="Number of triplets (query, pos, negs) in a batch. 
Each triplet consists of 12 images") 13 | parser.add_argument("--infer_batch_size", type=int, default=16, 14 | help="Batch size for inference (caching and testing)") 15 | parser.add_argument("--criterion", type=str, default='triplet', 16 | help='loss to be used') 17 | parser.add_argument("--margin", type=float, default=0.1, 18 | help="margin for the triplet loss") 19 | parser.add_argument("--epochs_num", type=int, default=60, 20 | help="number of epochs to train for") 21 | parser.add_argument("--patience", type=int, default=10, help="_") 22 | parser.add_argument("--lr", type=float, default=0.00001, help="_") 23 | parser.add_argument("--optim", type=str, default="adam", choices=["adam", "sgd"]) 24 | parser.add_argument("--cache_refresh_rate", type=int, default=125, 25 | help="How often to refresh cache, in number of queries") 26 | # MARK: fine-tune should focus on 27 | parser.add_argument("--queries_per_epoch", type=int, default=500, 28 | help="How many queries to consider for one epoch. 
Must be multiple of cache_refresh_rate") 29 | parser.add_argument("--negs_num_per_query", type=int, default=10, 30 | help="How many negatives to consider per each query in the loss") 31 | parser.add_argument("--neg_samples_num", type=int, default=125, 32 | help="How many negatives to use to compute the hardest ones") 33 | parser.add_argument("--mining", type=str, default="partial", choices=["partial", "full", "random"]) 34 | # Model parameters 35 | parser.add_argument("--backbone", type=str, default="swin_tiny", 36 | choices=["swin_tiny", "swin_small","convnext_tiny", "convnext_small"], help="_") 37 | parser.add_argument("--l2", type=str, default="before_pool", choices=["before_pool", "after_pool", "none"], 38 | help="When (and if) to apply the l2 norm with shallow aggregation layers") 39 | parser.add_argument("--aggregation", type=str, default="gem", choices=["gem", "spoc", "mac", "rmac"]) 40 | 41 | parser.add_argument('--pca_dim', type=int, default=None, help="PCA dimension (number of principal components). If None, PCA is not used.") 42 | parser.add_argument('--num_non_local', type=int, default=1, help="Num of non local blocks") 43 | parser.add_argument("--non_local", action='store_true', help="_") 44 | parser.add_argument('--channel_bottleneck', type=int, default=128, help="Channel bottleneck for Non-Local blocks") 45 | parser.add_argument('--fc_output_dim', type=int, default=None, 46 | help="Output dimension of fully connected layer. 
If None, don't use a fully connected layer.") 47 | parser.add_argument("--clip", type=int, default=10, choices=[1,3,5,7,10], help='Gradient clip avoiding loss wave') 48 | # Shift window parameters (in image or vector) 49 | parser.add_argument('--reduce_factor', type=int, default=1, help='/n -> window_stride shorten factor compared to args.feature') 50 | parser.add_argument('--split_nums', type=int, default=16, help='choose how many parts to split pano_datasets image') 51 | # Initialization parameters 52 | parser.add_argument("--seed", type=int, default=0) 53 | parser.add_argument("--resume", type=str, default=None, 54 | help="Path to load checkpoint from, for resuming training or testing.") 55 | # Other parameters 56 | parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) 57 | parser.add_argument("--num_workers", type=int, default=6, help="num_workers for all dataloaders") 58 | parser.add_argument('--resize', type=int, default=[224, 224], nargs=2, help="Resizing shape for images (HxW).") 59 | parser.add_argument('--query_resize', type=int, default=[448, 448], nargs=2, help="Resize for query") 60 | parser.add_argument('--database_resize', type=int, default=[448, 3584], nargs=2, help="Resize for database") 61 | parser.add_argument('--test_method', type=str, default="hard_resize", 62 | choices=["hard_resize", "single_query", "central_crop", "five_crops", "nearest_crop", "maj_voting"], 63 | help="This includes pre/post-processing methods and prediction refinement") 64 | parser.add_argument("--val_positive_dist_threshold", type=int, default=25, help="_") 65 | parser.add_argument("--train_positives_dist_threshold", type=int, default=10, help="_") 66 | parser.add_argument('--recall_values', type=int, default=[1, 5, 10, 20], nargs="+", 67 | help="Recalls to be computed, such as R@5.") 68 | # Data augmentation parameters 69 | parser.add_argument("--brightness", type=float, default=None, help="_") 70 | parser.add_argument("--contrast", 
type=float, default=None, help="_") 71 | parser.add_argument("--saturation", type=float, default=None, help="_") 72 | parser.add_argument("--hue", type=float, default=None, help="_") 73 | parser.add_argument("--rand_perspective", type=float, default=None, help="_") 74 | parser.add_argument("--horizontal_flip", action='store_true', help="_") 75 | parser.add_argument("--random_resized_crop", type=float, default=None, help="_") 76 | parser.add_argument("--random_rotation", type=float, default=None, help="_") 77 | # Paths parameters 78 | parser.add_argument("--datasets_folder", type=str, default='/home/shize/Datasets', 79 | help="Path with all datasets") 80 | parser.add_argument("--dataset_name", type=str, default="pitts250k", help="Relative path of the dataset") 81 | parser.add_argument("--pca_dataset_folder", type=str, default=None, 82 | help="Path with images to be used to compute PCA (e.g. pitts30k/images/train)") 83 | parser.add_argument("--save_dir", type=str, default="SwinT", 84 | help="Folder name of the current run (saved in ./logs)") 85 | parser.add_argument("--title", type=str, required=True, help="Short summary of the experiment config") 86 | parser.add_argument("--train_visual_save_path", type=str, default='visualize/Insert_x32w/train_set/', 87 | help="Mining [train] results save path") 88 | parser.add_argument("--val_visual_save_path", type=str, default='visualize/Insert_x32w/val_set/', 89 | help="Inference [val] results save path") 90 | parser.add_argument("--test_visual_save_path", type=str, default='visualize/Insert_x32w/test_set/', 91 | help="Inference [test] results save path") 92 | args = parser.parse_args() 93 | 94 | if args.datasets_folder is None: 95 | try: 96 | args.datasets_folder = os.environ['DATASETS_FOLDER'] 97 | except KeyError: 98 | raise Exception("You should set the parameter --datasets_folder or export " + 99 | "the DATASETS_FOLDER environment variable as such \n" + 100 | "export DATASETS_FOLDER=../datasets_vg/datasets") 101 | 102 | if
args.queries_per_epoch % args.cache_refresh_rate != 0: 103 | raise ValueError("Ensure that queries_per_epoch is divisible by cache_refresh_rate, " + 104 | f"because {args.queries_per_epoch} is not divisible by {args.cache_refresh_rate}") 105 | 106 | if args.pca_dim != None and args.pca_dataset_folder == None: 107 | raise ValueError("Please specify --pca_dataset_folder when using pca") 108 | 109 | if args.split_nums < 8: 110 | raise ValueError("split_nums should be specified to 8/16/24/32") 111 | 112 | return args 113 | 114 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Tuple 4 | import unittest 5 | import faiss 6 | import torch 7 | import logging 8 | import numpy as np 9 | from matplotlib import pyplot as plt 10 | from numpy import ndarray 11 | from tqdm import tqdm 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.dataset import Subset 14 | 15 | from tools import commons 16 | import datasets_ws 17 | import parser 18 | from model import network 19 | from collections import OrderedDict 20 | from os.path import join 21 | from datetime import datetime 22 | 23 | from datasets_ws import shift_window_on_descriptor # import shift_window_similar calculate func 24 | from model.sync_batchnorm import convert_model 25 | from tools.visual import display_inference 26 | 27 | 28 | def test(args, eval_ds, model, test_method="hard_resize", pca=None, show_inference_results=None, save_path=None): 29 | """Compute features of the given dataset and compute the recalls.""" 30 | 31 | assert test_method in ["hard_resize"], f"test_method can't be {test_method}" 32 | 33 | model = model.eval() 34 | eval_ds.test_method = test_method 35 | 36 | with torch.no_grad(): 37 | logging.debug("Extracting database features for evaluation/testing") 38 | 39 | database_features = 
np.empty((eval_ds.database_num, args.split_nums * args.features_dim), dtype="float32") 40 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) 41 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers, 42 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) 43 | 44 | # Database inputs shape : B, N, C, resize_H, resize_W 45 | for inputs, indices in tqdm(database_dataloader, ncols=100, desc='Extracting database features'): 46 | B, C, H, W = inputs.shape 47 | inputs = torch.stack([datasets_ws.shift_window_on_img(one_pano, eval_ds.split_nums, eval_ds.window_stride, 48 | eval_ds.window_len) for one_pano in inputs]) 49 | inputs = inputs.view(B * eval_ds.split_nums, C, eval_ds.resize[0], eval_ds.resize[1]) 50 | 51 | features = model(inputs.to(args.device)) 52 | # B*split_nums, feature_dim -> # B, split_nums*feature_dim 53 | features = torch.flatten(features.view(B, eval_ds.split_nums, -1), start_dim=1) 54 | features = features.cpu().numpy() 55 | if pca != None: 56 | features = pca.transform(features) 57 | 58 | database_features[indices.numpy(), :] = features 59 | 60 | logging.debug("Extracting queries features for evaluation/testing") 61 | 62 | queries_infer_batch_size = args.infer_batch_size 63 | 64 | queries_features = np.empty((eval_ds.queries_num, args.features_dim), dtype="float32") 65 | queries_subset_ds = Subset(eval_ds, 66 | list(range(eval_ds.database_num, eval_ds.database_num + eval_ds.queries_num))) 67 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers, 68 | batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda")) 69 | 70 | # Query features shape: B, C, H, W 71 | for inputs, indices in tqdm(queries_dataloader, ncols=100, desc='Extracting queries features'): 72 | features = model(inputs.to(args.device)) 73 | features = features.cpu().numpy() 74 | 75 | if pca != None: 76 | features = pca.transform(features) 77 | 78 | # NOTE!! 
subtract database_num so that query indices start from 0 79 | queries_features[indices.numpy() - eval_ds.database_num, :] = features 80 | 81 | # Sliding Window Matching Descriptor 82 | shift_window_start = time.time() 83 | predictions = [] 84 | focus_patch_loc = [] 85 | for one_query_feature in queries_features: 86 | predictions_per_query, focus_patch_loc_per_query = shift_window_on_descriptor(one_query_feature, 87 | database_features, 88 | args.features_dim, 89 | args.reduce_factor, 90 | max(args.recall_values)) 91 | predictions.append(predictions_per_query) 92 | focus_patch_loc.append(focus_patch_loc_per_query) # kept for the visualization interface 93 | shift_window_end = time.time() 94 | print(f'Searching all queries in the panorama database took {shift_window_end-shift_window_start:.3f}s') 95 | 96 | # Visualization of Inference Results 97 | if show_inference_results: 98 | os.makedirs(save_path, exist_ok=True) 99 | display_inference(eval_ds, predictions, save_path, focus_patch_loc) 100 | 101 | #### For each query, check if the predictions are correct 102 | check_start = time.time() 103 | positives_per_query = eval_ds.get_positives() 104 | # args.recall_values by default is [1, 5, 10, 20] 105 | recalls = np.zeros(len(args.recall_values)) 106 | for query_index, pred in enumerate(predictions): 107 | for i, n in enumerate(args.recall_values): 108 | if np.any(np.isin(pred[:n], positives_per_query[query_index])): 109 | recalls[i:] += 1 110 | break 111 | # Divide by the number of queries and multiply by 100, so the recalls are percentages 112 | recalls = recalls / eval_ds.queries_num * 100 113 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)]) 114 | check_end = time.time() 115 | print(f'Checking whether the predictions are correct took {check_end-check_start:.3f}s') 116 | return recalls, recalls_str 117 | 118 | 119 | def main(): 120 | # Initial setup: parser 121 | args = parser.parse_arguments() 122 | 123 | # Set Logger 124 | start_time = datetime.now() 125 | args.save_dir 
= join("logs", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S') + '_' + args.title) 126 | commons.setup_logging(args.save_dir) 127 | logging.info(f"The outputs are being saved in {args.save_dir}") 128 | 129 | # Initialize model 130 | model = network.GeoLocalizationNet(args) 131 | model = model.to(args.device) 132 | 133 | # Multi-GPU setting 134 | if torch.cuda.device_count() > 1: 135 | model = torch.nn.DataParallel(model) 136 | 137 | # val_ds 138 | val_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "val") 139 | logging.debug(f"Val set: {val_ds}") 140 | 141 | # test_ds 142 | test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test") 143 | logging.debug(f"Test set: {test_ds}") 144 | 145 | # load model params and run 146 | best_model_state_dict = torch.load(join(args.resume, "best_model.pth"))["model_state_dict"] 147 | 148 | if not torch.cuda.device_count() >= 2: 149 | best_model_state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in best_model_state_dict.items()}) 150 | 151 | model.load_state_dict(best_model_state_dict) 152 | logging.info('Loaded pretrained model successfully!') 153 | 154 | recalls, recalls_str = test(args, eval_ds=val_ds, model=model) 155 | logging.info(f"Recalls on [Val-set]:{val_ds}: {recalls_str}") 156 | 157 | recalls, recalls_str = test(args, eval_ds=test_ds, model=model, show_inference_results=None) 158 | logging.info(f"Recalls on [Test-set]:{test_ds}: {recalls_str}") 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 163 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/tools/__init__.py -------------------------------------------------------------------------------- /tools/commons.py: 
-------------------------------------------------------------------------------- 1 | 2 | """ 3 | This file contains some functions and classes which can be useful in very diverse projects. 4 | """ 5 | 6 | import os 7 | import sys 8 | import torch 9 | import random 10 | import logging 11 | import traceback 12 | import numpy as np 13 | from os.path import join 14 | 15 | 16 | def make_deterministic(seed=0, speedup=None): 17 | """Make results deterministic. If seed == -1, do not make deterministic. 18 | Running the script in a deterministic way might slow it down. 19 | """ 20 | if seed == -1: 21 | return 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = bool(speedup) 28 | 29 | 30 | def setup_logging(save_dir, console="debug", 31 | info_filename="info.log", debug_filename="debug.log"): 32 | """Set up logging files and console output. 33 | Creates one file for INFO logs and one for DEBUG logs. 34 | Args: 35 | save_dir (str): folder to create, where the log files are saved. 36 | console (str): 37 | if == "debug" prints on console debug messages and higher 38 | if == "info" prints on console info messages and higher 39 | if == None does not use console (useful when a logger has already been set) 40 | info_filename (str): the name of the info file. if None, don't create info file 41 | debug_filename (str): the name of the debug file. 
if None, don't create debug file 42 | """ 43 | if os.path.exists(save_dir): 44 | raise FileExistsError(f"{save_dir} already exists!") 45 | os.makedirs(save_dir, exist_ok=True) 46 | # logging.Logger.manager.loggerDict.keys() to check which loggers are in use 47 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S") 48 | logger = logging.getLogger('') 49 | logger.setLevel(logging.DEBUG) 50 | 51 | if info_filename is not None: 52 | info_file_handler = logging.FileHandler(join(save_dir, info_filename)) 53 | info_file_handler.setLevel(logging.INFO) 54 | info_file_handler.setFormatter(base_formatter) 55 | logger.addHandler(info_file_handler) 56 | 57 | if debug_filename is not None: 58 | debug_file_handler = logging.FileHandler(join(save_dir, debug_filename)) 59 | debug_file_handler.setLevel(logging.DEBUG) 60 | debug_file_handler.setFormatter(base_formatter) 61 | logger.addHandler(debug_file_handler) 62 | 63 | if console is not None: 64 | console_handler = logging.StreamHandler() 65 | if console == "debug": console_handler.setLevel(logging.DEBUG) 66 | if console == "info": console_handler.setLevel(logging.INFO) 67 | console_handler.setFormatter(base_formatter) 68 | logger.addHandler(console_handler) 69 | 70 | def exception_handler(type_, value, tb): 71 | # Pass the exception class received by the hook (type_), not the builtin `type` 72 | logger.info("\n" + "".join(traceback.format_exception(type_, value, tb))) 73 | sys.excepthook = exception_handler 74 | 75 | -------------------------------------------------------------------------------- /tools/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def shift_window_triple_loss(args, query_feature, database_feature, loss_fn): 5 | """ 6 | Calculate the shift window triple loss for a given query feature against a database feature. 7 | 8 | This function computes the similarity between a query feature and different windows of a database feature. 
9 | It uses a shift window approach where the window moves across the database feature in steps defined by 10 | the reduce_factor in args. The loss is calculated using the provided loss function (loss_fn) which takes 11 | the query feature, a positive match from the database, and a negative match from the database. 12 | 13 | Parameters: 14 | - args (Namespace): A namespace or similar object containing configuration parameters, including 15 | 'reduce_factor' which determines the step size for window shifting. 16 | - query_feature (torch.Tensor): A tensor representing the query feature. 17 | - database_feature (torch.Tensor): A tensor representing the database feature against which the query 18 | is compared. 19 | - loss_fn (function): A loss function that takes three arguments (query_feature, positive_feature, 20 | negative_feature) and returns a scalar loss value. 21 | 22 | Returns: 23 | - loss_sum (float): The accumulated loss over all the shifted windows. 24 | - min_index_in_row (torch.Tensor): The indices of the minimum values in the maintain_table, which 25 | represent the most similar window of database_feature to the query_feature. 26 | 27 | The function operates by sliding a window across the database feature and computing a similarity 28 | measure between the query feature and each window using L2 norm. It identifies the most similar 29 | window (positive match) and uses other windows as negative matches to compute the loss. The process 30 | is repeated for each window, and the losses are summed up to get the total loss. 31 | 32 | The function assumes that the query_feature and database_feature have compatible shapes and that the 33 | loss_fn is properly defined to handle the inputs. 34 | 35 | Example: 36 | >>> loss, indices = shift_window_triple_loss(args, query_feature, database_feature, loss_fn) 37 | """ 38 | window_size = query_feature.shape[0] 39 | step = int(window_size/args.reduce_factor) 40 | left, right = 0, window_size 41 | loss_sum = 0. 
42 | 43 | split_similarity_list = [] 44 | width_bound = database_feature.shape[1] 45 | while left < width_bound: 46 | if right <= width_bound: 47 | split_similarity = torch.norm(query_feature - database_feature[:, left:right], p=2, dim=1) 48 | else: 49 | cycle = torch.concat((database_feature[:,left:],database_feature[:,:right-width_bound]),dim=1) 50 | split_similarity = torch.norm(query_feature-cycle,p=2,dim=1) 51 | split_similarity_list.append(split_similarity) 52 | 53 | left += step 54 | right += step 55 | 56 | # Similarity matrix for each split window and get the most similar slice_index using L2 Norm 57 | maintain_table = torch.transpose(torch.stack(split_similarity_list), 0, 1) # shape:11, split_nums 58 | maintain_table_aggregation, min_index_in_row = torch.min(maintain_table, dim=1) 59 | 60 | # 0 -> positive and slice long vector according to slice_index above 61 | if min_index_in_row[0] * step + window_size <= width_bound: 62 | filtered_positive = database_feature[0, 63 | min_index_in_row[0] * step:min_index_in_row[0] * step + window_size] # shape:feature_dim 64 | else: 65 | filtered_positive = torch.concat((database_feature[0,min_index_in_row[0] * step:], 66 | database_feature[0,:min_index_in_row[0] * step + window_size - width_bound]), 67 | dim=-1) 68 | 69 | 70 | for i in range(1, 11): 71 | # 1 -> 10 refer to negs 72 | if min_index_in_row[i] * step + window_size <=width_bound: 73 | filtered_negative = database_feature[i, 74 | min_index_in_row[i] * step:min_index_in_row[i] * step + window_size] # shape:feature_dim 75 | else: 76 | filtered_negative = torch.concat((database_feature[i,min_index_in_row[i] * step:], 77 | database_feature[i,:min_index_in_row[i] * step + window_size - width_bound]), 78 | dim=-1) 79 | 80 | loss_sum += loss_fn(query_feature, filtered_positive, filtered_negative) 81 | 82 | return loss_sum, min_index_in_row 83 | 84 | -------------------------------------------------------------------------------- /tools/map_builder.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import os 4 | import cv2 5 | import math 6 | import numpy as np 7 | from glob import glob 8 | from skimage import io 9 | from os.path import join 10 | import matplotlib.cm as cm 11 | import matplotlib.pyplot as plt 12 | from collections import defaultdict 13 | from staticmap import StaticMap, Polygon 14 | from matplotlib.colors import ListedColormap 15 | 16 | 17 | def _lon_to_x(lon, zoom): 18 | if not (-180 <= lon <= 180): lon = (lon + 180) % 360 - 180 19 | return ((lon + 180.) / 360) * pow(2, zoom) 20 | 21 | 22 | def _lat_to_y(lat, zoom): 23 | if not (-90 <= lat <= 90): lat = (lat + 90) % 180 - 90 24 | return (1 - math.log(math.tan(lat * math.pi / 180) + 1 / math.cos(lat * math.pi / 180)) / math.pi) / 2 * pow(2, zoom) 25 | 26 | 27 | def _download_map_image(min_lat=45.0, min_lon=7.6, max_lat=45.1, max_lon=7.7, size=2000): 28 | """Download a map of the chosen area as a numpy image""" 29 | mean_lat = (min_lat + max_lat) / 2 30 | mean_lon = (min_lon + max_lon) / 2 31 | static_map = StaticMap(size, size, url_template='https://api.mapbox.com/styles/v1/mapbox/satellite-v9/tiles/{z}/{x}/{y}?access_token=') 32 | static_map.add_polygon( 33 | Polygon(((min_lon, min_lat), (min_lon, max_lat), (max_lon, max_lat), (max_lon, min_lat)), None, '#FFFFFF')) 34 | zoom = static_map._calculate_zoom() 35 | static_map = StaticMap(size, size) 36 | image = static_map.render(zoom, [mean_lon, mean_lat]) 37 | print( 38 | f"You can see the map on Google Maps at this link www.google.com/maps/place/@{mean_lat},{mean_lon},{zoom - 1}z") 39 | min_lat_px, min_lon_px, max_lat_px, max_lon_px = \ 40 | static_map._y_to_px(_lat_to_y(min_lat, zoom)), \ 41 | static_map._x_to_px(_lon_to_x(min_lon, zoom)), \ 42 | static_map._y_to_px(_lat_to_y(max_lat, zoom)), \ 43 | static_map._x_to_px(_lon_to_x(max_lon, zoom)) 44 | assert 0 <= max_lat_px < min_lat_px < size and 0 <= min_lon_px 
< max_lon_px < size 45 | return np.array(image)[max_lat_px:min_lat_px, min_lon_px:max_lon_px], static_map, zoom 46 | 47 | 48 | def get_edges(coordinates, enlarge=0): 49 | """ 50 | Return the edges of the coordinates, i.e. the southernmost, westernmost, northernmost and 51 | easternmost coordinates. 52 | :param coordinates: A list of numpy.arrays of shape (Nx2) 53 | :param float enlarge: How much to increase the coordinates, to enlarge 54 | the area included between the points 55 | :return: a tuple of four floats 56 | """ 57 | min_lat, min_lon, max_lat, max_lon = (*np.concatenate(coordinates).min(0), *np.concatenate(coordinates).max(0)) 58 | diff_lat = (max_lat - min_lat) * enlarge 59 | diff_lon = (max_lon - min_lon) * enlarge 60 | inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon = \ 61 | min_lat - diff_lat, min_lon - diff_lon, max_lat + diff_lat, max_lon + diff_lon 62 | return inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon 63 | 64 | 65 | def _create_map(coordinates, colors=None, dot_sizes=None, legend_names=None, map_intensity=0.6): 66 | dot_sizes = dot_sizes if dot_sizes is not None else [10] * len(coordinates) 67 | colors = colors if colors is not None else ["r"] * len(coordinates) 68 | assert len(coordinates) == len(dot_sizes) == len(colors), \ 69 | f"The number of coordinates must be equal to the number of colors and dot_sizes, but they're " \ 70 | f"{len(coordinates)}, {len(colors)}, {len(dot_sizes)}" 71 | 72 | # Add two dummy points to slightly enlarge the map 73 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates, enlarge=0.1) 74 | coordinates.append(np.array([[min_lat, min_lon], [max_lat, max_lon]])) 75 | # Download the map of the chosen area 76 | map_img, static_map, zoom = _download_map_image(min_lat, min_lon, max_lat, max_lon) 77 | 78 | scatters = [] 79 | fig = plt.figure(figsize=(map_img.shape[1] / 100, map_img.shape[0] / 100), dpi=1000) 80 | for coord in coordinates: 81 | for i in range(len(coord)): # Scale latitudes because of earth's 
curvature 82 | coord[i, 0] = -static_map._y_to_px(_lat_to_y(coord[i, 0], zoom)) 83 | for coord, size, color in zip(coordinates, dot_sizes, colors): 84 | scatters.append(plt.scatter(coord[:, 1], coord[:, 0], s=size, color=color)) 85 | 86 | if legend_names != None: 87 | plt.legend(scatters, legend_names, scatterpoints=10000, loc='upper left', 88 | ncol=1, framealpha=0, prop={"weight": "bold", "size": 20}) 89 | 90 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates) 91 | plt.ylim(min_lat, max_lat) 92 | plt.xlim(min_lon, max_lon) 93 | fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 94 | fig.canvas.draw() 95 | plot_img = np.array(fig.canvas.renderer._renderer) 96 | plt.close() 97 | 98 | plot_img = cv2.resize(plot_img[:, :, :3], map_img.shape[:2][::-1], interpolation=cv2.INTER_LANCZOS4) 99 | map_img[(map_img.sum(2) < 444)] = 188 # brighten dark pixels 100 | map_img = (((map_img / 255) ** map_intensity) * 255).astype(np.uint8) # fade map 101 | mask = (plot_img.sum(2) == 255 * 3)[:, :, None] # mask of plot, to find white pixels 102 | final_map = map_img * mask + plot_img * (~mask) 103 | return final_map 104 | 105 | 106 | def _get_coordinates_from_dataset(dataset_folder, extension="jpg"): 107 | """ 108 | Takes as input the path of a dataset, such as "datasets/st_lucia/images" 109 | and returns 110 | [("train/database", [[45, 8.1], [45.2, 8.2]]), ("train/queries", [[45, 8.1], [45.2, 8.2]])] 111 | """ 112 | images_paths = glob(join(dataset_folder, "**", f"*.{extension}"), recursive=True) 113 | if len(images_paths) != 0: 114 | print(f"I found {len(images_paths)} images in {dataset_folder}") 115 | else: 116 | raise ValueError(f"I found no images in {dataset_folder} !") 117 | 118 | grouped_gps_coords = defaultdict(list) 119 | 120 | for image_path in images_paths: 121 | full_path = os.path.dirname(image_path) 122 | full_parent_path, parent_dir = os.path.split(full_path) 123 | parent_parent_dir = os.path.split(full_parent_path)[1] 124 | 125 | # folder_name is for 
example "train - database" 126 | folder_name = " - ".join([parent_parent_dir, parent_dir]) 127 | 128 | gps_lat = image_path.split("@")[6] 129 | gps_lon = image_path.split("@")[7] if not image_path.split("@")[7].endswith(f'{extension}') \ 130 | else image_path.split("@")[7][:-len(extension)-1] 131 | gps_coords = gps_lat, gps_lon 132 | grouped_gps_coords[folder_name].append(gps_coords) 133 | 134 | grouped_gps_coords = sorted([(k, np.array(v).astype(np.float64)) 135 | for k, v in grouped_gps_coords.items()]) 136 | return grouped_gps_coords 137 | 138 | 139 | def build_map_from_dataset(dataset_folder, dot_sizes=None): 140 | """dataset_folder is the path that contains the 'images' folder.""" 141 | grouped_gps_coords = _get_coordinates_from_dataset(join(dataset_folder)) 142 | SORTED_FOLDERS = ["train - database", "train - queries", "val - database", "val - queries", 143 | "test - database", "test - queries"] 144 | try: 145 | grouped_gps_coords = sorted(grouped_gps_coords, key=lambda x: SORTED_FOLDERS.index(x[0])) 146 | except ValueError: 147 | pass # This dataset has different folder names than the standard train-val-test database-queries. 
148 | coordinates = [] 149 | legend_names = [] 150 | for folder_name, coords in grouped_gps_coords: 151 | legend_names.append(f"{folder_name} - {len(coords)}") 152 | coordinates.append(coords) 153 | 154 | colors = cm.rainbow(np.linspace(0, 1, len(legend_names))) 155 | colors = ListedColormap(colors) 156 | colors = colors.colors 157 | if len(legend_names) == 1: 158 | legend_names = None # If there's only one folder, don't show the legend 159 | colors = np.array([[1, 0, 0]]) 160 | 161 | map_img = _create_map(coordinates, colors, dot_sizes, legend_names) 162 | 163 | print(f"Map image resolution: {map_img.shape}") 164 | dataset_name = os.path.basename(os.path.abspath(dataset_folder)) 165 | io.imsave(join(dataset_folder, f"map_{dataset_name}_final.png"), map_img) 166 | 167 | def main(): 168 | dataset_folder = '/home/zafirshi/Datasets/YQ360' 169 | build_map_from_dataset(dataset_folder) 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /tools/paper_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | import matplotlib.cm as cm 5 | import numpy as np 6 | 7 | 8 | def results_comparison_in_graph(out_name: str): 9 | method = ['NetVLAD', 'Berton et.al', 'Orhan et.al', 'Swin-T', 'PanoVPR(Swin-T)', 10 | 'Swin-S', 'PanoVPR(Swin-S)', 'ConvNeXt-T', 'PanoVPR(ConvNeXt-T)', 11 | 'ConvNeXt-S', 'PanoVPR(ConvNeXt-S)'] 12 | params = [7.23, 86.86, 136.62, 28.29, 28.29, 49.61, 49.61, 28.59, 28.59, 50.22, 50.22] 13 | r1 = [4.0, 8.0, 47.0, 10.1, 41.4, 12.4, 38.2, 9.7, 34.0, 14.2, 48.8] 14 | 15 | fig, ax = plt.subplots() 16 | 17 | for idx, (m,p,r) in enumerate(zip(method,params,r1)): 18 | if m == 'NetVLAD': 19 | ax.scatter(x=p, y=r, s=100, c='#1f77b4', label=m, alpha=0.5) 20 | elif m == 'Berton et.al': 21 | ax.scatter(x=p, y=r, s=100, c='#ff7f0e', label=m, alpha=0.5) 22 | elif m 
== 'Orhan et.al': 23 | ax.scatter(x=p, y=r, s=100, c='#2ca02c', label=m, alpha=0.5) 24 | elif m == 'Swin-T' or m == 'PanoVPR(Swin-T)': 25 | if m == 'PanoVPR(Swin-T)': 26 | ax.scatter(x=p, y=r, s=100, c='#d62728', label=m, marker='*', alpha=0.5) 27 | # add curve 28 | x_base, y_base = params[idx-1], r1[idx-1] 29 | x_end, y_end = p, r 30 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center", 31 | arrowprops=dict(color='#d62728', 32 | arrowstyle="-|>, head_length=1, head_width=0.4", 33 | linewidth=1, 34 | connectionstyle="arc3,rad=-0.3", 35 | linestyle='dashed', 36 | ) 37 | ) 38 | else: 39 | ax.scatter(x=p, y=r, s=100, c='#d62728', label=m, alpha=0.5) 40 | elif m == 'Swin-S' or m == 'PanoVPR(Swin-S)': 41 | if m == 'PanoVPR(Swin-S)': 42 | ax.scatter(x=p, y=r, s=100, c='#bcbd22', label=m, marker='*', alpha=0.5) 43 | # add curve 44 | x_base, y_base = params[idx-1], r1[idx-1] 45 | x_end, y_end = p, r 46 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center", 47 | arrowprops=dict(color='#bcbd22', 48 | arrowstyle="-|>, head_length=1, head_width=0.4", 49 | linewidth=1, 50 | connectionstyle="arc3,rad=-0.2", 51 | linestyle='dashed', 52 | ) 53 | ) 54 | else: 55 | ax.scatter(x=p, y=r, s=100, c='#bcbd22', label=m, alpha=0.5) 56 | elif m == 'ConvNeXt-T' or m =='PanoVPR(ConvNeXt-T)': 57 | if m == 'PanoVPR(ConvNeXt-T)': 58 | ax.scatter(x=p, y=r, s=100, c='#7f7f7f', label=m, marker='*', alpha=0.5) 59 | # add curve 60 | x_base, y_base = params[idx-1], r1[idx-1] 61 | x_end, y_end = p, r 62 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center", 63 | arrowprops=dict(color='#7f7f7f', 64 | arrowstyle="-|>, head_length=1, head_width=0.4", 65 | linewidth=1, 66 | connectionstyle="arc3,rad=0.2", 67 | linestyle='dashed', 68 | ) 69 | ) 70 | else: 71 | ax.scatter(x=p, y=r, s=100, c='#7f7f7f', label=m, alpha=0.5) 72 | elif m == 'ConvNeXt-S' or m == 'PanoVPR(ConvNeXt-S)': 73 | if m == 
'PanoVPR(ConvNeXt-S)': 74 | ax.scatter(x=p, y=r, s=100, c='#e377c2', label=m, marker='*', alpha=0.5) 75 | # add curve 76 | x_base, y_base = params[idx-1], r1[idx-1] 77 | x_end, y_end = p, r 78 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center", 79 | arrowprops=dict(color='#e377c2', 80 | arrowstyle="-|>, head_length=1, head_width=0.4", 81 | linewidth=1, 82 | connectionstyle="arc3,rad=0.3", 83 | linestyle='dashed', 84 | ) 85 | ) 86 | else: 87 | ax.scatter(x=p, y=r, s=100, c='#e377c2', label=m, alpha=0.5) 88 | 89 | 90 | # add legend 91 | ax.legend(loc="center right", prop = {'size':8}) 92 | 93 | plt.title('Performance comparison of different methods') 94 | plt.xlabel('Parameters(M)') 95 | plt.ylabel('R@1(%)') 96 | plt.grid(True) 97 | 98 | plt.savefig(out_name, dpi=300, bbox_inches='tight') 99 | plt.show() 100 | 101 | 102 | if __name__=="__main__": 103 | results_comparison_in_graph('test.png') -------------------------------------------------------------------------------- /tools/util.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import torch 4 | import shutil 5 | import logging 6 | import argparse 7 | import numpy as np 8 | from collections import OrderedDict 9 | from os.path import join 10 | 11 | from sklearn.decomposition import PCA 12 | 13 | 14 | from torch import nn as nn 15 | 16 | from timm.models.layers.helpers import to_2tuple 17 | from timm.models.layers.trace_utils import _assert 18 | from mmcv.cnn import get_model_complexity_info 19 | 20 | import datasets_ws 21 | import parser 22 | from model import network 23 | 24 | 25 | def get_flops(input_shape=(224, 224)): 26 | # Initial setup: parser 27 | args = parser.parse_arguments() 28 | 29 | # Initialize model 30 | model = network.GeoLocalizationNet(args) 31 | model = model.to(args.device) 32 | 33 | # input shape 34 | input_shape = (3, input_shape[0], input_shape[1]) 35 | 36 | model.eval() 37 | 38 | flops, params 
= get_model_complexity_info(model, input_shape) 39 | split_line = '=' * 30 40 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 41 | split_line, input_shape, flops, params)) 42 | print('!!!Please be cautious if you use the results in papers. ' 43 | 'You may need to check if all ops are supported and verify that the ' 44 | 'flops computation is correct.') 45 | 46 | 47 | def save_checkpoint(args, state, is_best, filename): 48 | model_path = join(args.save_dir, filename) 49 | torch.save(state, model_path) 50 | if is_best: 51 | shutil.copyfile(model_path, join(args.save_dir, "best_model.pth")) 52 | 53 | 54 | def resume_model(args, model): 55 | checkpoint = torch.load(args.resume, map_location=args.device) 56 | if 'model_state_dict' in checkpoint: 57 | state_dict = checkpoint['model_state_dict'] 58 | else: 59 | # The pre-trained models that we provide in the README do not have 'model_state_dict' in the keys as 60 | # the checkpoint is directly the state dict 61 | state_dict = checkpoint 62 | # if the model contains the prefix "module" which is appended by 63 | # DataParallel, remove it to avoid errors when loading dict 64 | if list(state_dict.keys())[0].startswith('module'): 65 | state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in state_dict.items()}) 66 | model.load_state_dict(state_dict) 67 | return model 68 | 69 | 70 | def resume_train(args, model, optimizer=None, strict=False): 71 | """Load model, optimizer, and other training parameters""" 72 | logging.debug(f"Loading checkpoint: {args.resume}") 73 | checkpoint = torch.load(args.resume) 74 | start_epoch_num = checkpoint["epoch_num"] 75 | model.load_state_dict(checkpoint["model_state_dict"], strict=strict) 76 | if optimizer: 77 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 78 | best_r5 = checkpoint["best_r5"] 79 | not_improved_num = checkpoint["not_improved_num"] 80 | logging.debug(f"Loaded checkpoint: start_epoch_num = {start_epoch_num}, " \ 81 | f"current_best_R@5 = 
{best_r5:.1f}") 82 | if args.resume.endswith("last_model.pth"): # Copy best model to current save_dir 83 | shutil.copy(args.resume.replace("last_model.pth", "best_model.pth"), args.save_dir) 84 | return model, optimizer, best_r5, start_epoch_num, not_improved_num 85 | 86 | 87 | def compute_pca(args, model, pca_dataset_folder, full_features_dim): 88 | model = model.eval() 89 | pca_ds = datasets_ws.PCADataset(args, args.datasets_folder, pca_dataset_folder) 90 | dl = torch.utils.data.DataLoader(pca_ds, args.infer_batch_size, shuffle=True) 91 | pca_features = np.empty([min(len(pca_ds), 2**14), full_features_dim]) 92 | with torch.no_grad(): 93 | for i, images in enumerate(dl): 94 | if i*args.infer_batch_size >= len(pca_features): break 95 | features = model(images).cpu().numpy() 96 | pca_features[i*args.infer_batch_size : (i*args.infer_batch_size)+len(features)] = features 97 | pca = PCA(args.pca_dim) 98 | pca.fit(pca_features) 99 | return pca 100 | 101 | 102 | class PatchEmbed(nn.Module): 103 | """ 2D Image to Patch Embedding 104 | """ 105 | def __init__( 106 | self, 107 | img_size=224, 108 | patch_size=16, 109 | in_chans=3, 110 | embed_dim=768, 111 | norm_layer=None, 112 | flatten=True, 113 | bias=True, 114 | ): 115 | super().__init__() 116 | img_size = to_2tuple(img_size) 117 | patch_size = to_2tuple(patch_size) 118 | self.img_size = img_size 119 | self.patch_size = patch_size 120 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 121 | self.num_patches = self.grid_size[0] * self.grid_size[1] 122 | self.flatten = flatten 123 | 124 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 125 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 126 | 127 | def forward(self, x): 128 | B, C, H, W = x.shape 129 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 130 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't 
match model ({self.img_size[1]}).") 131 | x = self.proj(x) 132 | if self.flatten: 133 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 134 | x = self.norm(x) 135 | return x 136 | -------------------------------------------------------------------------------- /tools/visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from matplotlib import pyplot as plt 4 | from tqdm import tqdm 5 | 6 | 7 | def path_to_pil_img(path): 8 | return Image.open(path).convert("RGB") 9 | 10 | 11 | def add_mask(train_set, img, img_path_idx, win_len=448): 12 | pos_mask = np.zeros(train_set.database_resize, dtype=np.float64) + (100 / 255) 13 | pos_w = img.shape[1] # 448, 448*8, 3 14 | mask_left = (pos_w / train_set.split_nums) * img_path_idx 15 | mask_right = mask_left + win_len 16 | 17 | if mask_left < mask_right <= pos_w: 18 | pos_mask[:, int(mask_left):int(mask_right)] = 0 19 | elif mask_left< pos_w < mask_right: 20 | pos_mask[:int(mask_right-pos_w), int(mask_left):] = 0 21 | else: 22 | raise Exception('Adding mask on img goes WRONG!') 23 | pos_mask_3d = np.stack((pos_mask, pos_mask, pos_mask), axis=2) 24 | 25 | normal_img = (img / 255.).astype(np.float64) 26 | img = np.clip(normal_img - pos_mask_3d, 0, 1) 27 | img = (img * 255).astype(np.uint8) 28 | 29 | return img 30 | 31 | 32 | def display_mining(train_set, triplets_global_indexes_array, save_path, pos_patch_loc, neg_patch_loc): 33 | full_images_path = np.array(train_set.images_paths) 34 | db_num = train_set.database_num 35 | for iter_idx, triple in enumerate(tqdm(triplets_global_indexes_array, ncols=100, desc='Show Mining Results')): 36 | if iter_idx % 50 == 0: 37 | query_idx, pos_idx, neg_idx = triple[0], triple[1], triple[2:] 38 | query_img_path = str(full_images_path[query_idx + db_num]) 39 | pos_img_path = str(full_images_path[pos_idx]) 40 | neg_img_path = full_images_path[neg_idx].tolist() 41 | 42 | query = 
np.array(path_to_pil_img(query_img_path).resize((448, 448))) 43 | pos = np.array(path_to_pil_img(pos_img_path).resize((3584, 448))) 44 | negs = list(map(lambda x: np.array(path_to_pil_img(x).resize((3584, 448))), neg_img_path)) 45 | black = np.zeros((448 * 10, 448, 3), dtype=np.uint8) 46 | 47 | # mask 48 | pos = add_mask(train_set, pos, pos_patch_loc[iter_idx].item()) 49 | for i, each_neg in enumerate(negs): 50 | negs[i] = add_mask(train_set, each_neg, neg_patch_loc[iter_idx][i].item()) 51 | 52 | right = np.concatenate([pos] + negs, axis=0) 53 | left = np.concatenate((query, black), axis=0) 54 | one_in_all = np.concatenate((left, right), axis=1) 55 | 56 | plt.imsave(save_path + f'iter{iter_idx}.png', one_in_all) 57 | 58 | with open(save_path + 'record.txt', 'a+') as f: 59 | f.write(f'======> iter_idx:{iter_idx}/{train_set.queries_num} <======\n') 60 | f.write(f'Query_Path: {query_img_path}\n') 61 | f.write(f'Pos_Path: {pos_img_path}\n') 62 | for each_neg_path in neg_img_path: 63 | f.write(f'Neg_Path: {each_neg_path}\n') 64 | 65 | 66 | def display_inference(eval_ds, predictions, save_path, focus_patch_loc): 67 | full_images_path = np.array(eval_ds.images_paths) 68 | db_num = eval_ds.database_num 69 | for query_idx, each_query_pred in enumerate(tqdm(predictions, ncols=100, desc='Show Inference Results')): 70 | if query_idx % 50 == 0: 71 | query_img_path = str(full_images_path[query_idx + db_num]) 72 | db_img_path = full_images_path[each_query_pred[:5]].tolist() # keep only the top-5 predictions for display 73 | 74 | query = np.array(path_to_pil_img(query_img_path).resize((448, 448))) 75 | db = list(map(lambda x: np.array(path_to_pil_img(x).resize((3584, 448))), db_img_path)) 76 | black = np.zeros((448 * 4, 448, 3), dtype=np.uint8) 77 | 78 | # add mask 79 | for i, each_db in enumerate(db): 80 | db[i] = add_mask(eval_ds, each_db, focus_patch_loc[query_idx][i].item()) 81 | 82 | right = np.concatenate(db, axis=0) 83 | left = np.concatenate((query, black), axis=0) 84 | one_in_all = np.concatenate((left, 
right), axis=1) 85 | 86 | # save img 87 | plt.imsave(save_path + f'query_idx{query_idx}.png', one_in_all) 88 | 89 | with open(save_path + 'inference_list.txt', 'a+') as f: 90 | f.write(f'======> query_idx:{query_idx}/{eval_ds.queries_num} <======\n') 91 | f.write(f'Query_Path: {query_img_path}\n') 92 | for each_db_path in db_img_path: 93 | f.write(f'Pred_db_Path: {each_db_path}\n') -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import math 4 | import torch 5 | import logging 6 | import numpy as np 7 | from matplotlib import pyplot as plt 8 | from torch.utils import data 9 | from tqdm import tqdm 10 | import torch.nn as nn 11 | import multiprocessing 12 | from os.path import join 13 | from datetime import datetime 14 | import torchvision.transforms as transforms 15 | from torch.utils.data.dataloader import DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.cuda.amp import autocast, GradScaler 18 | 19 | from tools import util 20 | import test 21 | import parser 22 | from tools import commons 23 | import datasets_ws 24 | from model import network 25 | from model.sync_batchnorm import convert_model 26 | from model.functional import sare_ind, sare_joint 27 | from tools.visual import display_mining 28 | from tools.loss import shift_window_triple_loss 29 | 30 | 31 | def train(args, 32 | model: nn.Module, 33 | train_set: datasets_ws.TripletsDataset, 34 | loss_fn: nn.TripletMarginLoss, 35 | optimizer, 36 | val_set: datasets_ws.BaseDataset, 37 | writer, 38 | start_epoch_num=0, 39 | best_r5=0, 40 | not_improved_num=0, 41 | show_mining_triplet_img=None, 42 | visual_mining_save_path=None, 43 | show_inference_results=None, 44 | visual_val_save_path=None 45 | ): 46 | """ 47 | Trains a given model using triplet loss and performs validation. 
48 | 49 | Args: 50 | args: A namespace or dictionary containing training parameters. 51 | model (nn.Module): The neural network model to train. 52 | train_set (datasets_ws.TripletsDataset): The training dataset containing triplets. 53 | loss_fn (nn.TripletMarginLoss): The loss function to optimize. 54 | optimizer: The optimization algorithm. 55 | val_set (datasets_ws.BaseDataset): The validation dataset. 56 | writer: A summary writer object for logging. 57 | start_epoch_num (int): The starting epoch number for training. 58 | best_r5 (float): The best recall@5 score obtained so far. 59 | not_improved_num (int): Counter for epochs without improvement. 60 | show_mining_triplet_img (callable, optional): Function to visualize mining results. 61 | visual_mining_save_path (str, optional): Path to save visual mining results. 62 | show_inference_results (callable, optional): Function to visualize inference results. 63 | visual_val_save_path (str, optional): Path to save validation visualization results. 64 | 65 | The training process includes mining hard triplets, calculating loss, and updating the model parameters. 66 | Validation is performed at the end of each epoch to monitor the recall metrics and early stopping is applied 67 | based on the recall@5 metric. 68 | """ 69 | 70 | scaler = GradScaler() 71 | 72 | # Training loop 73 | for epoch_num in range(start_epoch_num, args.epochs_num): 74 | logging.info(f"Start training epoch: {epoch_num:02d}") 75 | 76 | epoch_start_time = datetime.now() 77 | epoch_losses = np.zeros((0, 1), dtype=np.float32) 78 | 79 | # How many loops should an epoch last 80 | loops_num = math.ceil(args.queries_per_epoch / args.cache_refresh_rate) 81 | for loop_num in range(loops_num): 82 | logging.debug(f"Cache: {loop_num} / {loops_num}") 83 | 84 | # Compute triplets to use in the triplet loss 85 | train_set.is_inference = True 86 | train_set.compute_triplets(args, model) 87 | train_set.is_inference = False 88 | 89 | # Visualizing mining results. 
90 | if show_mining_triplet_img: 91 | os.makedirs(f'{visual_mining_save_path}' + f'mining_epoch{epoch_num}/', exist_ok=True) 92 | save_path = f'{visual_mining_save_path}' + f'mining_epoch{epoch_num}/loopNum{loop_num}_' 93 | triplets_global_indexes_array = train_set.triplets_global_indexes.numpy() 94 | # Obtaining the sub-window focused on by the model. 95 | pos_patch_loc, neg_patch_loc = train_set.pos_focus_patch, train_set.neg_focus_patch 96 | display_mining(train_set, triplets_global_indexes_array, save_path, pos_patch_loc, neg_patch_loc) 97 | 98 | triplets_dl = DataLoader(dataset=train_set, num_workers=args.num_workers, 99 | batch_size=args.train_batch_size, 100 | pin_memory=(args.device == "cuda"), 101 | drop_last=True) 102 | 103 | model = model.train() 104 | 105 | cache_losses = np.zeros((0, 1), dtype=np.float32) 106 | batch_nums = args.cache_refresh_rate / args.train_batch_size 107 | 108 | # images shape: (train_batch_size*12)*3*H*W 109 | for batch_idx, (query, pano_database) in enumerate(tqdm(triplets_dl, ncols=100, desc='Training')): 110 | # query shape:(B, 3, 224, 224) || pano_database shape:(B, 11, 3, 224, 224*8) 111 | pano_database_4d = torch.flatten(pano_database, end_dim=1) # B*11, 3, 224, 224*8 112 | 113 | pano_split_list = [] 114 | for pano_full in pano_database_4d: 115 | pano_split = datasets_ws.shift_window_on_img(pano_full, train_set.split_nums, 116 | train_set.window_stride, train_set.window_len) 117 | 118 | pano_split_list.append(pano_split) 119 | 120 | pano_database_4d = torch.flatten(torch.stack(pano_split_list), end_dim=1) 121 | 122 | optimizer.zero_grad() 123 | with autocast(): 124 | database_feature = model(pano_database_4d.to(args.device)).view(args.train_batch_size, 125 | 1 + args.negs_num_per_query, -1) 126 | query_feature = model(query.to(args.device)) 127 | 128 | loss_triplet = 0 129 | loss = 0 130 | for idx in range(args.train_batch_size): 131 | loss, min_index_in_row = shift_window_triple_loss(args, 
query_feature[idx], database_feature[idx], loss_fn) 132 | loss_triplet += loss 133 | del database_feature, query_feature 134 | 135 | loss_triplet /= (args.train_batch_size * args.negs_num_per_query) 136 | 137 | scaler.scale(loss_triplet).backward() 138 | 139 | scaler.unscale_(optimizer) 140 | # clip_grad 1,3,5,7,10 141 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 142 | 143 | scaler.step(optimizer) 144 | scaler.update() 145 | 146 | # Keep track of all losses by appending them to epoch_losses 147 | batch_loss = loss_triplet.item() 148 | 149 | writer.add_scalar('Loss/Batch_Loss', batch_loss, 150 | epoch_num * loops_num * batch_nums + loop_num * batch_nums + batch_idx) 151 | 152 | cache_losses = np.append(cache_losses, batch_loss) 153 | del loss_triplet 154 | 155 | logging.debug(f"Epoch[{epoch_num:02d}]({loop_num}/{loops_num}): " + 156 | f"latest batch triplet loss = {batch_loss:.4f}, " + 157 | f"current cache triplet loss = {cache_losses.mean():.4f}") 158 | writer.add_scalar('Loss/Cache_Loss', cache_losses.mean().item(), epoch_num * loops_num + loop_num) 159 | 160 | # epoch_losses should update after calculating cache_loss 161 | epoch_losses = np.append(epoch_losses, cache_losses) 162 | 163 | logging.info(f"Finished epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, " 164 | f"average epoch triplet loss = {epoch_losses.mean():.4f}") 165 | writer.add_scalar('Loss/Epoch_Loss', epoch_losses.mean().item(), epoch_num) 166 | 167 | # Compute recalls on validation set 168 | recalls, recalls_str = test.test(args, val_set, model, 169 | show_inference_results=show_inference_results, 170 | save_path=visual_val_save_path) 171 | logging.info(f"Recalls on val set {val_set}: {recalls_str}") 172 | 173 | writer.add_scalar('Recall/@1', recalls[0], epoch_num) 174 | writer.add_scalar('Recall/@5', recalls[1], epoch_num) 175 | writer.add_scalar('Recall/@10', recalls[2], epoch_num) 176 | writer.add_scalar('Recall/@20', recalls[3], epoch_num) 177 | 178 | 
is_best = recalls[1] > best_r5 179 | 180 | # Save checkpoint, which contains all training parameters 181 | util.save_checkpoint(args, {"epoch_num": epoch_num, "model_state_dict": model.state_dict(), 182 | "optimizer_state_dict": optimizer.state_dict(), "recalls": recalls, 183 | "best_r5": best_r5, 184 | "not_improved_num": not_improved_num 185 | }, is_best, filename="last_model.pth") 186 | 187 | # If recall@5 did not improve for "many" epochs, stop training 188 | if is_best: 189 | logging.info(f"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}") 190 | best_r5 = recalls[1] 191 | not_improved_num = 0 192 | else: 193 | not_improved_num += 1 194 | logging.info( 195 | f"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}") 196 | if not_improved_num >= args.patience: 197 | logging.info(f"Performance did not improve for {not_improved_num} epochs. Stop training.") 198 | break 199 | 200 | logging.info(f"Trained for {epoch_num + 1:02d} epochs, Best R@5: {best_r5:.1f}") 201 | 202 | 203 | def main(): 204 | # Initial setup: parser, logging... 
205 | args = parser.parse_arguments() 206 | start_time = datetime.now() 207 | args.save_dir = join("logs", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S') + '_' + args.title) 208 | commons.setup_logging(args.save_dir) 209 | commons.make_deterministic(args.seed, speedup=False) # speedup=False make results Reproducible 210 | logging.info(f"Arguments: {args}") 211 | logging.info(f"The outputs are being saved in {args.save_dir}") 212 | logging.info(f"Using {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs") 213 | 214 | # Creation of Datasets 215 | logging.debug(f"Loading dataset {args.dataset_name} from folder {args.datasets_folder}") 216 | 217 | triplets_ds = datasets_ws.TripletsDataset(args, args.datasets_folder, args.dataset_name, "train", 218 | args.negs_num_per_query) 219 | logging.info(f"Train query set: {triplets_ds}") 220 | 221 | val_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "val") 222 | logging.info(f"Val set: {val_ds}") 223 | 224 | test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test") 225 | logging.info(f"Test set: {test_ds}") 226 | 227 | #### Initialize model 228 | model = network.GeoLocalizationNet(args) 229 | model = model.to(args.device) 230 | model = torch.nn.DataParallel(model) 231 | 232 | #### Setup Optimizer and Loss 233 | if args.optim == "adam": 234 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 235 | elif args.optim == "sgd": 236 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001) 237 | 238 | criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction="sum") 239 | 240 | #### Resume model, optimizer, and other training parameters 241 | if args.resume: 242 | model, optimizer, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, optimizer) 243 | logging.info(f"Resuming from epoch {start_epoch_num} with best recall@5 {best_r5:.1f}") 244 | else: 245 | best_r5 
= start_epoch_num = not_improved_num = 0 246 | 247 | if torch.cuda.device_count() >= 2: 248 | # When using more than 1GPU, use sync_batchnorm for torch.nn.DataParallel 249 | model = convert_model(model) 250 | model = model.cuda() 251 | 252 | # Add tensorboard monitor 253 | writer = SummaryWriter(log_dir=args.save_dir) 254 | 255 | # Train model on train set and validate model on validation set every train epoch 256 | train(args, 257 | model, 258 | train_set=triplets_ds, loss_fn=criterion_triplet, optimizer=optimizer, val_set=val_ds, 259 | writer=writer, start_epoch_num=start_epoch_num, best_r5=best_r5, not_improved_num=not_improved_num, 260 | show_mining_triplet_img=False, 261 | visual_mining_save_path=args.train_visual_save_path, 262 | show_inference_results=False, 263 | visual_val_save_path=args.val_visual_save_path) 264 | logging.info(f"Trained total in {str(datetime.now() - start_time)[:-7]}") 265 | 266 | #### Test best model on test set 267 | best_model_state_dict = torch.load(join(args.save_dir, "best_model.pth"))["model_state_dict"] 268 | model.load_state_dict(best_model_state_dict) 269 | 270 | recalls, recalls_str = test.test(args, test_ds, model, test_method=args.test_method, 271 | show_inference_results=False, 272 | save_path=args.test_visual_save_path) 273 | logging.info(f"Recalls on {test_ds}: {recalls_str}") 274 | # show test results in one graph 275 | for i in range(len(recalls)): 276 | writer.add_scalar('Test', recalls[i], i + 1) 277 | writer.close() 278 | 279 | 280 | if __name__ == '__main__': 281 | main() 282 | --------------------------------------------------------------------------------