├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── results.png
│   └── slide-window.png
├── datasets_ws.py
├── environment.yml
├── model
│   ├── __init__.py
│   ├── aggregation.py
│   ├── functional.py
│   ├── network.py
│   ├── non_local.py
│   ├── normalization.py
│   └── sync_batchnorm
│       ├── __init__.py
│       ├── batchnorm.py
│       ├── batchnorm_reimpl.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── parser.py
├── test.py
├── tools
│   ├── __init__.py
│   ├── commons.py
│   ├── loss.py
│   ├── map_builder.py
│   ├── paper_utils.py
│   ├── util.py
│   └── visual.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Autogenerated folders
2 | __pycache__
3 | logs
4 | test
5 | data
6 |
7 | # IDEs generated folders
8 | .spyproject
9 | venv/
10 | .idea/
11 | __MACOSX/
12 | **/.DS_Store
13 |
14 | # other
15 | pretrained
16 | *.pth
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 zafirshi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PanoVPR: Towards Unified Perspective-to-Equirectangular Visual Place Recognition via Sliding Windows across the Panoramic View
2 |
3 | Ze Shi* · Hao Shi* · Kailun Yang · Zhe Yin · Yining Lin · Kaiwei Wang
4 |
28 | ## Update
29 |
30 | - [2023-11] :gear: Code Release
31 | - [2023-08] :tada: PanoVPR is accepted to the 26th IEEE International Conference on Intelligent Transportation Systems ([ITSC-2023](https://2023.ieee-itsc.org/)).
32 | - [2023-03] :construction: Initialize the repo and release the [arXiv](https://arxiv.org/pdf/2303.14095.pdf) version
33 |
34 | ## Introduction
35 |
36 | We propose a **V**isual **P**lace **R**ecognition framework for retrieving **Pano**ramic database images using perspective query images, dubbed **PanoVPR**. To achieve this, we adopt a sliding-window approach on the panoramic database images to narrow the model's observation range over the large field-of-view panoramas. We achieve promising results on a derived dataset, *Pitts250K-P2E*, and a real-world dataset, *YQ360*.
37 |
38 | For more details, please check our [arXiv](https://arxiv.org/pdf/2303.14095.pdf) paper.
39 |
40 | ## Sliding window strategy
41 |
42 | 
43 |
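The core idea, mirrored from `shift_window_on_img` and `shift_window_on_descriptor` in `datasets_ws.py`, is sketched below: the equirectangular database image is cut into query-sized windows (wrapping around the right edge), and at retrieval time the query descriptor is compared against each window's descriptor, ranking a database panorama by its best-matching window. This is a simplified illustration under our own naming; functions such as `split_panorama` and `rank_by_best_window` are not part of the repository API.

```python
import numpy as np
import torch


def split_panorama(pano: torch.Tensor, win_num: int, win_len: int) -> torch.Tensor:
    """Cut a (C, H, W) panorama into win_num windows of width win_len,
    wrapping around the right edge (valid for equirectangular images)."""
    width = pano.shape[-1]
    stride = width // win_num
    windows = []
    for i in range(win_num):
        left, right = i * stride, i * stride + win_len
        if right <= width:
            windows.append(pano[..., left:right])
        else:  # cyclic wrap-around
            windows.append(torch.cat([pano[..., left:], pano[..., :right - width]], dim=-1))
    return torch.stack(windows, dim=0)  # (win_num, C, H, win_len)


def rank_by_best_window(query_desc: np.ndarray, pano_descs: np.ndarray, dim: int, stride: int) -> np.ndarray:
    """Rank panoramas by the minimum L2 distance between the query descriptor
    (length dim) and any window of the concatenated per-window descriptors."""
    dists = []
    for left in range(0, pano_descs.shape[1], stride):
        window = pano_descs[:, left:left + dim]
        if window.shape[1] == dim:  # this sketch ignores the trailing partial window
            dists.append(np.linalg.norm(query_desc - window, axis=1))
    best = np.min(np.stack(dists, axis=1), axis=1)  # best window per panorama
    return np.argsort(best)  # database indices, most similar first


# Toy example: a 224x1792 panorama split into 24 overlapping 224-wide windows.
crops = split_panorama(torch.rand(3, 224, 8 * 224), win_num=24, win_len=224)
print(crops.shape)  # torch.Size([24, 3, 224, 224])
order = rank_by_best_window(np.random.rand(768), np.random.rand(10, 24 * 768), dim=768, stride=768)
print(order[:3])  # indices of the three best-matching panoramas
```
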
44 | ## Qualitative results
45 |
46 | 
47 |
48 | ## Usage
49 |
50 | ### Dataset Preparation
51 |
52 | Before starting, you need to download the Pitts250K-P2E dataset and the YQ360 dataset [[OneDrive Link](https://zjueducn-my.sharepoint.com/:f:/g/personal/zafirshi_zju_edu_cn/Ei4N__otNrVAjxku0UnT-pQBdsOSF3PvAEi8Z9wGu7Aj0w?e=LvVwIp)][[BaiduYun Link](https://pan.baidu.com/s/1IBcpAwnwY5YlqfgfSqRz-w?pwd=Pano)].
53 |
54 | If the link is out of date, please email _office_makeit@163.com_ for the latest available link!
55 |
56 | Afterwards, specify the `--datasets_folder` parameter in the `parser.py` file.
57 |
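As a rough guide (inferred from the folder names used in `datasets_ws.py`, not an official specification), the loader expects the data to be organized roughly as follows, with image names encoding UTM coordinates as `@utm_easting@utm_northing@...@.jpg`; each split contains the same two sub-folders:

```
<datasets_folder>/
└── pitts250k/
    └── images/
        ├── train/
        │   ├── database_pano_clean/   # equirectangular database images
        │   └── queries_split/         # perspective query images
        ├── val/
        └── test/
```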
58 |
59 | ### Setup
60 |
61 | First, create an environment from `environment.yml` using [Conda](https://docs.conda.io/projects/miniconda/en/latest/miniconda-install.html), then activate it.
62 |
63 | ```bash
64 | conda env create -f environment.yml --prefix /path/to/env
65 | conda activate /path/to/env
66 | ```
67 |
68 | ### Train
69 |
70 | To train the network, choose the training configuration and the dataset by specifying
71 | parameters such as `--backbone`, `--split_nums`, and `--dataset_name` on the command line,
72 | and adjust the remaining parameters as needed for your setup.
73 |
74 | By default, the output results are stored in the `./logs/{save_dir}` folder.
75 |
76 | **Please note that the `--title` parameter must be specified in the command line.**
77 |
78 | ```bash
79 | # Train on Pitts250K-P2E
80 | python train.py --title swinTx24 \
81 | --save_dir clean_branch_test \
82 | --backbone swin_tiny \
83 | --split_nums 24 \
84 | --dataset_name pitts250k \
85 | --cache_refresh_rate 125 \
86 | --neg_sample 100 \
87 | --queries_per_epoch 2000
88 | ```
89 |
90 | ### Inference
91 |
92 | For inference, pass the absolute path of the saved `best_model.pth` via the `--resume` parameter.
93 |
94 | ```bash
95 | # Val and Test On Pitts250K-P2E
96 | python test.py --title test_swinTx24 \
97 | --save_dir clean_branch_test \
98 | --backbone swin_tiny \
99 | --split_nums 24 \
100 | --dataset_name pitts250k \
101 | --cache_refresh_rate 125 \
102 | --neg_sample 100 \
103 | --queries_per_epoch 2000 \
104 | --resume /path/to/best_model.pth
105 | ```
106 |
107 | ## Acknowledgments
108 |
109 | We thank the authors of the following repositories for their open source code:
110 |
111 | - [tranleanh/image-panorama-stitching](https://github.com/tranleanh/image-panorama-stitching)
112 | - [gmberton/datasets_vg](https://github.com/gmberton/datasets_vg)
113 | - [gmberton/deep-visual-geo-localization-benchmark](https://github.com/gmberton/deep-visual-geo-localization-benchmark)
114 |
115 |
116 | ## Cite Our Work
117 |
118 | If you find our work useful, please cite it as:
119 |
120 | ```bib
121 | @INPROCEEDINGS{shi2023panovpr,
122 | author={Shi, Ze and Shi, Hao and Yang, Kailun and Yin, Zhe and Lin, Yining and Wang, Kaiwei},
123 | booktitle={2023 IEEE 26th International Conference on Intelligent Transportation Systems (ITSC)},
124 | title={PanoVpr: Towards Unified Perspective-to-Equirectangular Visual Place Recognition via Sliding Windows Across the Panoramic View},
125 | year={2023},
126 | pages={1333-1340},
127 | doi={10.1109/ITSC57777.2023.10421857}
128 | }
129 | ```
130 |
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/assets/results.png
--------------------------------------------------------------------------------
/assets/slide-window.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/assets/slide-window.png
--------------------------------------------------------------------------------
/datasets_ws.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import PIL
5 | import torch
6 | import faiss
7 | import matplotlib.pyplot as plt
8 | import logging
9 | import numpy as np
10 | from glob import glob
11 | from tqdm import tqdm
12 | from os.path import join
13 | import torch.utils.data as data
14 | import torchvision.transforms as transforms
15 | from torch.utils.data.dataset import Subset
16 | from sklearn.neighbors import NearestNeighbors
17 | from torch.utils.data.dataloader import DataLoader
18 | from tools.visual import path_to_pil_img
19 |
20 | base_transform = transforms.Compose([
21 | transforms.ToTensor(),
22 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
23 | ])
24 |
25 |
26 | def shift_window_on_descriptor(short_vector, long_matrix, window_size, divisor_factor, sorted_indices_num):
27 |     """
28 |     Slide a window over the long database descriptor matrix, comparing each window against the
29 |     short query descriptor, and compute the L2 distance within each window range.
30 |     :param divisor_factor: stride divisor, i.e. window_stride = window_size / divisor_factor
31 |     :param short_vector: short numpy vector produced from the query image
32 |     :param long_matrix: long multi-row matrix (2D numpy array) produced from the sampled database images
33 |     :param window_size: feature window length used as the base shifting length
34 |     :param sorted_indices_num: number of indices to keep from the distance-sorted candidates
35 |     :return:
36 |         sorted_indices_num_in_matrix: indices of the selected database rows
37 |         window_loc_in_matrix: index of the best-matching window within each selected row
38 |     """
39 | left, right = 0, window_size
40 | window_stride = int(window_size / divisor_factor)
41 |
42 | part_distance_list = []
43 | width_bound = long_matrix.shape[1]
44 | counter = 0
45 |
46 | while left < width_bound:
47 | if right <= width_bound:
48 | part_distance = np.linalg.norm(short_vector - long_matrix[:, left:right], ord=2, axis=1)
49 | else:
50 | cycle_part = np.concatenate((long_matrix[:,left:],long_matrix[:,:right-width_bound]),axis=1)
51 | part_distance = np.linalg.norm(short_vector - cycle_part, ord=2, axis=1)
52 | part_distance_list.append(part_distance)
53 |
54 | left += window_stride
55 | right += window_stride
56 | counter += 1
57 |
58 | maintain_table = np.array(part_distance_list).transpose()
59 |
60 | # Select the smallest 'sorted_indices_num' 'long_matrix_index' based on distance
61 | sorted_indices_num_in_matrix = np.argsort(np.min(maintain_table, axis=1))[:sorted_indices_num]
62 |
63 | window_loc_in_matrix = np.argmin(maintain_table, axis=1)[sorted_indices_num_in_matrix]
64 | return sorted_indices_num_in_matrix, window_loc_in_matrix
65 |
66 |
67 | def shift_window_on_img(img: torch.Tensor, win_num: int, win_stride: int, win_len: int) -> torch.Tensor:
68 |     """
69 |     Slide a window over a panoramic image tensor and stack the resulting crops.
70 |     :param img: wide input tensor, e.g. of shape (3, 448, 3584)
71 |     :param win_num: number of split windows [should be calculated carefully before training]
72 |     :param win_stride: pixel step between consecutive windows
73 |     :param win_len: pixel width of each window
74 |     :return: split pano img -> a tensor of shape (win_num, 3, 448, win_len)
75 |     """
76 | input_img_width = img.shape[-1]
77 | img_split_list = []
78 | for i in range(win_num):
79 | sw_left = i * win_stride
80 | sw_right = i * win_stride + win_len
81 | if sw_left <= input_img_width and sw_right <= input_img_width:
82 | img_split_list += [img[:, :, int(sw_left):int(sw_right)]]
83 |             # Cyclic case: the window crosses the right image boundary (e.g. win_num = 15 = 7 + 8 in a single pass);
84 |             # wrap around by concatenating the rightmost and the leftmost parts of the panorama.
85 | elif sw_left <= input_img_width < sw_right:
86 | img_split_list += [torch.concat([img[:, :, int(sw_left):],
87 | img[:, :, :int(sw_right - input_img_width)]], dim=-1)]
88 | else:
89 | break
90 | img = torch.stack(img_split_list, dim=0)
91 | return img
92 |
93 |
94 | class PCADataset(data.Dataset):
95 | def __init__(self, args, datasets_folder="dataset", dataset_folder="pitts30k/images/train"):
96 | dataset_folder_full_path = join(datasets_folder, dataset_folder)
97 | if not os.path.exists(dataset_folder_full_path):
98 | raise FileNotFoundError(f"Folder {dataset_folder_full_path} does not exist")
99 | self.images_paths = sorted(glob(join(dataset_folder_full_path, "**", "*.jpg"), recursive=True))
100 |
101 | def __getitem__(self, index):
102 | return base_transform(path_to_pil_img(self.images_paths[index]))
103 |
104 | def __len__(self):
105 | return len(self.images_paths)
106 |
107 |
108 | class BaseDataset(data.Dataset):
109 | """Dataset with images from database and queries, used for inference (testing and building cache).
110 | """
111 |
112 | def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train"):
113 | super().__init__()
114 | self.args = args
115 | self.dataset_name = dataset_name
116 | self.dataset_folder = join(datasets_folder, dataset_name, "images", split)
117 | if not os.path.exists(self.dataset_folder): raise FileNotFoundError(
118 | f"Folder {self.dataset_folder} does not exist")
119 |
120 | self.resize = args.resize
121 | self.query_resize = args.query_resize
122 | self.database_resize = args.database_resize
123 | self.test_method = args.test_method
124 |
125 | self.split_nums = args.split_nums
126 |         # Sliding-window length in pixels; defaults to the width of the perspective query image.
127 |         self.window_len = self.resize[1]
128 |         self.window_stride = 8 * self.window_len / self.split_nums  # Pixel stride of the window (the pano is resized to 8x the query width).
129 |         if self.window_stride < self.window_len:
130 |             logging.debug('[Note]: sliding window uses overlapping strides')
131 |
132 | # for display,locate the selected window_num
133 | self.pos_focus_patch = []
134 | self.neg_focus_patch = []
135 |
136 | #### Read paths and UTM coordinates for all images.
137 | database_folder = join(self.dataset_folder, "database_pano_clean")
138 | queries_folder = join(self.dataset_folder, "queries_split")
139 |
140 | if not os.path.exists(database_folder): raise FileNotFoundError(f"Folder {database_folder} does not exist")
141 | if not os.path.exists(queries_folder): raise FileNotFoundError(f"Folder {queries_folder} does not exist")
142 | self.database_paths = sorted(glob(join(database_folder, "**", "*.jpg"), recursive=True))
143 | self.queries_paths = sorted(glob(join(queries_folder, "**", "*.jpg"), recursive=True))
144 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg
145 |         self.database_utms = np.array(
146 |             [(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float)
147 |         self.queries_utms = np.array(
148 |             [(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float)
149 |
150 |         # Find soft_positives_per_query, which are within val_positive_dist_threshold (default 25 meters)
151 | knn = NearestNeighbors(n_jobs=-1)
152 | knn.fit(self.database_utms)
153 | self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms,
154 | radius=args.val_positive_dist_threshold,
155 | return_distance=False)
156 |
157 | self.images_paths = list(self.database_paths) + list(self.queries_paths)
158 |
159 | self.database_num = len(self.database_paths)
160 | self.queries_num = len(self.queries_paths)
161 |
162 | def __getitem__(self, index):
163 | if index < self.database_num:
164 | # split pano_database images into several subs
165 | img = path_to_pil_img(self.images_paths[index])
166 | img = self.resize_database_p2e(img) # shape:(3,224,224*8)
167 |
168 | elif index >= self.database_num:
169 | # query image just resize
170 | img = path_to_pil_img(self.images_paths[index])
171 | img = base_transform(img)
172 | img = transforms.functional.resize(img, self.query_resize)
173 | # resize to adapt backbone input size
174 | img = transforms.functional.resize(img, self.resize)
175 |
176 | return img, index
177 |
178 | def resize_database_p2e(self, img: PIL.Image) -> torch.Tensor:
179 | img = base_transform(img)
180 | img = transforms.functional.resize(img, self.database_resize)
181 | img = transforms.functional.resize(img, (self.resize[0], 8 * self.resize[1])) # shape:(3,224,224*8)
182 | return img
183 |
184 | def __len__(self):
185 | return len(self.images_paths)
186 |
187 | def __repr__(self):
188 | return (
189 | f"< {self.__class__.__name__}, {self.dataset_name} - #database: {self.database_num}; #queries: {self.queries_num} >")
190 |
191 | def get_positives(self):
192 | return self.soft_positives_per_query
193 |
194 |
195 | class TripletsDataset(BaseDataset):
196 | """Dataset used for training, it is used to compute the triplets
197 | with TripletsDataset.compute_triplets() with various mining methods.
198 | If is_inference == True, uses methods of the parent class BaseDataset,
199 | this is used for example when computing the cache, because we compute features
200 | of each image, not triplets.
201 | """
202 |
203 | def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train", negs_num_per_query=10):
204 | super().__init__(args, datasets_folder, dataset_name, split)
205 | self.mining = args.mining
206 | self.neg_samples_num = args.neg_samples_num # Number of negatives to randomly sample
207 | self.negs_num_per_query = negs_num_per_query # Number of negatives per query in each batch
208 | if self.mining == "full": # "Full database mining" keeps a cache with last used negatives
209 | self.neg_cache = [np.empty((0,), dtype=np.int32) for _ in range(self.queries_num)]
210 | self.is_inference = False
211 |
212 | identity_transform = transforms.Lambda(lambda x: x)
213 | self.resized_transform = transforms.Compose([
214 | transforms.Resize(self.resize) if self.resize is not None else identity_transform,
215 | base_transform
216 | ])
217 |
218 | self.query_transform = transforms.Compose([
219 |             transforms.ColorJitter(brightness=args.brightness) if args.brightness is not None else identity_transform,
220 |             transforms.ColorJitter(contrast=args.contrast) if args.contrast is not None else identity_transform,
221 |             transforms.ColorJitter(saturation=args.saturation) if args.saturation is not None else identity_transform,
222 |             transforms.ColorJitter(hue=args.hue) if args.hue is not None else identity_transform,
223 |             transforms.RandomPerspective(
224 |                 args.rand_perspective) if args.rand_perspective is not None else identity_transform,
225 |             transforms.RandomResizedCrop(size=self.resize, scale=(1 - args.random_resized_crop, 1)) \
226 |                 if args.random_resized_crop is not None else identity_transform,
227 |             transforms.RandomRotation(
228 |                 degrees=args.random_rotation) if args.random_rotation is not None else identity_transform,
229 | self.resized_transform,
230 | ])
231 |
232 | # Find hard_positives_per_query, which are within train_positives_dist_threshold (10 meters)
233 | knn = NearestNeighbors(n_jobs=-1)
234 | knn.fit(self.database_utms)
235 | self.hard_positives_per_query = list(knn.radius_neighbors(self.queries_utms,
236 | radius=args.train_positives_dist_threshold,
237 | # 10 meters
238 | return_distance=False))
239 |
240 | #### Some queries might have no positive, we should remove those queries.
241 | queries_without_any_hard_positive = \
242 | np.where(np.array([len(p) for p in self.hard_positives_per_query], dtype=object) == 0)[0]
243 | if len(queries_without_any_hard_positive) != 0:
244 | logging.info(f"There are {len(queries_without_any_hard_positive)} queries without any positives " +
245 | "within the training set. They won't be considered as they're useless for training.")
246 | # Remove queries without positives
247 | self.hard_positives_per_query = np.delete(self.hard_positives_per_query, queries_without_any_hard_positive)
248 | self.queries_paths = np.delete(self.queries_paths, queries_without_any_hard_positive)
249 |
250 | # Recompute images_paths and queries_num because some queries might have been removed
251 | self.images_paths = list(self.database_paths) + list(self.queries_paths)
252 | self.queries_num = len(self.queries_paths)
253 |
254 |
255 | def __getitem__(self, index):
256 | if self.is_inference:
257 | return super().__getitem__(index)
258 | query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index],
259 | (1, 1, self.negs_num_per_query))
260 |
261 | query = self.query_transform(path_to_pil_img(self.queries_paths[query_index]))
262 | database_indices_list = torch.concat([best_positive_index, neg_indexes]).tolist()
263 |
264 | pano_database_list = [self.resize_database_p2e(each_pano_database) for each_pano_database in
265 | [path_to_pil_img(self.database_paths[i]) for i in database_indices_list]]
266 |
267 | pano_database = torch.stack(pano_database_list, dim=0)
268 | return query, pano_database
269 |
270 | def __len__(self):
271 | if self.is_inference:
272 | return super().__len__()
273 | else:
274 | return len(self.triplets_global_indexes)
275 |
276 | def compute_triplets(self, args, model):
277 | self.is_inference = True
278 |
279 | if self.mining == "partial":
280 | self.compute_triplets_partial(args, model)
281 | elif self.mining == "full":
282 | raise NotImplementedError(f"{self.mining} has not been implemented yet, use partial instead")
283 | elif self.mining == "random":
284 | raise NotImplementedError(f"{self.mining} has not been implemented yet, use partial instead")
285 | else:
286 | raise KeyError(f"{self.mining} is not among the options (partial, full, random)")
287 |
288 | def compute_cache_database(self, args, model, subset_ds, database_cache_shape):
289 | """Compute the cache containing features of images, which is used to
290 | find best positive and hardest negatives."""
291 |
292 | subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers,
293 | batch_size=args.infer_batch_size, shuffle=False,
294 | pin_memory=(args.device == "cuda"))
295 | model = model.eval()
296 |
297 | # RAMEfficient2DMatrix can be replaced by np.zeros, but using
298 | # RAMEfficient2DMatrix is RAM efficient for full database mining.
299 | database_cache = RAMEfficient2DMatrix(database_cache_shape, dtype=np.float32)
300 |
301 | with torch.no_grad():
302 | for images, indexes in tqdm(subset_dl, ncols=100, desc='Compute Database Cache'):
303 | # images shape: 16(infer_bs),3,224,224*8 -> 16,16(split_num),3,224,224
304 |
305 | images = torch.stack([shift_window_on_img(each_img, win_num=self.split_nums,
306 | win_stride=self.window_stride, win_len=self.window_len)
307 | for each_img in images])
308 |
309 | B, N, C, resize_H, resize_W = images.shape
310 | images = images.view(B * N, C, resize_H, resize_W)
311 |
312 | images = images.to(args.device)
313 | features = model(images)
314 | features = torch.flatten(features.view(B, N, -1), start_dim=1)
315 |
316 | database_cache[indexes.numpy()] = features.cpu().numpy()
317 |
318 | return database_cache
319 |
320 | @staticmethod
321 | def compute_cache_query(args, model, subset_ds, query_cache_shape, database_num):
322 | """Compute the cache containing features of images, which is used to
323 | find best positive and hardest negatives."""
324 |
325 | subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers,
326 | batch_size=args.infer_batch_size, shuffle=False,
327 | pin_memory=(args.device == "cuda"))
328 | model = model.eval()
329 |
330 | query_cache = RAMEfficient2DMatrix(query_cache_shape, dtype=np.float32)
331 |
332 | with torch.no_grad():
333 | for images, indexes in tqdm(subset_dl, ncols=100, desc='Compute Query Cache'):
334 | images = images.to(args.device)
335 | features = model(images)
336 | # minus to begin from 0
337 | query_cache[indexes.numpy() - database_num] = features.cpu().numpy()
338 |
339 | return query_cache
340 |
341 | def get_query_features(self, query_index, query_cache):
342 | query_features = query_cache[query_index]
343 | if query_features is None:
344 | raise RuntimeError(f"For query {self.queries_paths[query_index]} " +
345 | f"with index {query_index} features have not been computed!\n" +
346 | "There might be some bug with caching")
347 | return query_features
348 |
349 | def get_best_positive_index(self, args, query_index, database_cache, query_features) -> int:
350 | positives_features = database_cache[self.hard_positives_per_query[query_index]]
351 |
352 | # Segmentally compare the similarity between descriptors,
353 | # calculate the average row-wise (each row represents a panoramic image)
354 | best_positive_in_matrix, window_loc = shift_window_on_descriptor(short_vector=query_features,
355 | long_matrix=positives_features,
356 | window_size=args.features_dim,
357 | divisor_factor=args.reduce_factor,
358 | sorted_indices_num=1)
359 |
360 | # Search the best positive (within 10 meters AND nearest in features space)
361 | best_positive_index = self.hard_positives_per_query[query_index][best_positive_in_matrix]
362 | self.pos_focus_patch.append(window_loc)
363 | return best_positive_index.item()
364 |
365 | def get_hardest_negatives_indexes(self, args, database_cache, query_features, neg_samples):
366 | neg_features = database_cache[neg_samples] # shape:(unknown, 8*feature_dim)
367 |
368 | neg_nums_in_matrix, window_loc = shift_window_on_descriptor(short_vector=query_features,
369 | long_matrix=neg_features,
370 | window_size=args.features_dim,
371 | divisor_factor=args.reduce_factor,
372 | sorted_indices_num=self.negs_num_per_query)
373 |
374 | neg_indexes = neg_samples[neg_nums_in_matrix]
375 | self.neg_focus_patch.append(window_loc)
376 | return neg_indexes.tolist()
377 |
378 | def compute_triplets_partial(self, args, model):
379 | triplets_global_indexes = []
380 | if self.mining == "partial":
381 | sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)
382 | else:
383 |             raise ValueError(f'Unexpected mining method "{self.mining}"; queries can only be sampled with "partial" mining')
384 |
385 | sampled_database_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False)
386 | # Take all the positives
387 | positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes]
388 | positives_indexes = [p for pos in positives_indexes for p in pos]
389 | # Merge them into database_indexes and remove duplicates
390 | database_indexes = list(sampled_database_indexes) + positives_indexes
391 | database_indexes = list(np.unique(database_indexes))
392 |
393 | subset_ds_database = Subset(self, database_indexes)
394 | subset_ds_query = Subset(self, list(sampled_queries_indexes + self.database_num))
395 |
396 | database_cache = self.compute_cache_database(args, model, subset_ds_database,
397 | database_cache_shape=(
398 | self.database_num, args.split_nums * args.features_dim))
399 |
400 | query_cache = self.compute_cache_query(args, model, subset_ds_query,
401 | query_cache_shape=(self.queries_num, args.features_dim),
402 | database_num=self.database_num)
403 |
404 | for query_index in tqdm(sampled_queries_indexes, ncols=100, desc='Neg Mining'):
405 |
406 | query_features = self.get_query_features(query_index, query_cache)
407 | best_positive_index = self.get_best_positive_index(args, query_index, database_cache, query_features)
408 |
409 | # Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives
410 | soft_positives = self.soft_positives_per_query[query_index]
411 | neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True)
412 |
413 | # Take all database images that are negatives and are within the sampled database images (aka database_indexes)
414 | neg_indexes = self.get_hardest_negatives_indexes(args, database_cache, query_features, neg_indexes)
415 | triplets_global_indexes.append((int(query_index), best_positive_index, *neg_indexes))
416 |
417 | self.triplets_global_indexes = torch.tensor(triplets_global_indexes)
418 |
419 |
420 | class RAMEfficient2DMatrix:
421 | """This class behaves similarly to a numpy.ndarray initialized
422 | with np.zeros(), but is implemented to save RAM when the rows
423 | within the 2D array are sparse. In this case it's needed because
424 | we don't always compute features for each image, just for few of
425 | them"""
426 |
427 | def __init__(self, shape, dtype=np.float32):
428 | self.shape = shape
429 | self.dtype = dtype
430 | self.matrix = [None] * shape[0]
431 |
432 | def __setitem__(self, indexes, vals):
433 | assert vals.shape[1] == self.shape[1], f"{vals.shape[1]} {self.shape[1]}"
434 | for i, val in zip(indexes, vals):
435 | self.matrix[i] = val.astype(self.dtype, copy=False)
436 |
437 | def __getitem__(self, index):
438 | if hasattr(index, "__len__"):
439 | return np.array([self.matrix[i] for i in index])
440 | else:
441 | return self.matrix[index]
442 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: PanoVPR2
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=mkl
9 | - brotlipy=0.7.0=py37h27cfd23_1003
10 | - bzip2=1.0.8=h7b6447c_0
11 | - ca-certificates=2022.10.11=h06a4308_0
12 | - certifi=2022.9.24=py37h06a4308_0
13 | - cffi=1.15.1=py37h74dc2b5_0
14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
15 | - cryptography=38.0.1=py37h9ce1e76_0
16 | - cudatoolkit=11.3.1=h2bc3f7f_2
17 | - faiss-cpu=1.5.3=py37h6bb024c_0
18 | - ffmpeg=4.3=hf484d3e_0
19 | - freetype=2.12.1=h4a9f257_0
20 | - giflib=5.2.1=h7b6447c_0
21 | - gmp=6.2.1=h295c915_3
22 | - gnutls=3.6.15=he1e5248_0
23 | - idna=3.4=py37h06a4308_0
24 | - intel-openmp=2021.4.0=h06a4308_3561
25 | - jpeg=9e=h7f8727e_0
26 | - lame=3.100=h7b6447c_0
27 | - lcms2=2.12=h3be6417_0
28 | - ld_impl_linux-64=2.38=h1181459_1
29 | - lerc=3.0=h295c915_0
30 | - libdeflate=1.8=h7f8727e_5
31 | - libffi=3.3=he6710b0_2
32 | - libgcc-ng=11.2.0=h1234567_1
33 | - libgomp=11.2.0=h1234567_1
34 | - libiconv=1.16=h7f8727e_2
35 | - libidn2=2.3.2=h7f8727e_0
36 | - libpng=1.6.37=hbc83047_0
37 | - libstdcxx-ng=11.2.0=h1234567_1
38 | - libtasn1=4.16.0=h27cfd23_0
39 | - libtiff=4.4.0=hecacb30_2
40 | - libunistring=0.9.10=h27cfd23_0
41 | - libuv=1.40.0=h7b6447c_0
42 | - libwebp=1.2.4=h11a3e52_0
43 | - libwebp-base=1.2.4=h5eee18b_0
44 | - lz4-c=1.9.3=h295c915_1
45 | - mkl=2021.4.0=h06a4308_640
46 | - mkl-service=2.4.0=py37h7f8727e_0
47 | - mkl_fft=1.3.1=py37hd3c417c_0
48 | - mkl_random=1.2.2=py37h51133e4_0
49 | - ncurses=6.3=h5eee18b_3
50 | - nettle=3.7.3=hbbd107a_1
51 | - numpy=1.21.5=py37h6c91a56_3
52 | - numpy-base=1.21.5=py37ha15fc14_3
53 | - openh264=2.1.1=h4ff587b_0
54 | - openssl=1.1.1s=h7f8727e_0
55 | - pillow=9.2.0=py37hace64e9_1
56 | - pip=22.2.2=py37h06a4308_0
57 | - pycparser=2.21=pyhd3eb1b0_0
58 | - pyopenssl=22.0.0=pyhd3eb1b0_0
59 | - pysocks=1.7.1=py37_1
60 | - python=3.7.15=haa1d7c7_0
61 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0
62 | - pytorch-mutex=1.0=cuda
63 | - readline=8.2=h5eee18b_0
64 | - setuptools=65.5.0=py37h06a4308_0
65 | - six=1.16.0=pyhd3eb1b0_1
66 | - sqlite=3.40.0=h5082296_0
67 | - tk=8.6.12=h1ccaba5_0
68 | - torchaudio=0.11.0=py37_cu113
69 | - torchvision=0.12.0=py37_cu113
70 | - typing_extensions=4.3.0=py37h06a4308_0
71 | - urllib3=1.26.12=py37h06a4308_0
72 | - wheel=0.37.1=pyhd3eb1b0_0
73 | - xz=5.2.6=h5eee18b_0
74 | - zlib=1.2.13=h5eee18b_0
75 | - zstd=1.5.2=ha4553b6_0
76 | - pip:
77 | - absl-py==1.3.0
78 | - addict==2.4.0
79 | - cachetools==5.2.0
80 | - click==8.1.3
81 | - colorama==0.4.6
82 | - cycler==0.11.0
83 | - einops==0.6.0
84 | - filelock==3.8.0
85 | - fonttools==4.38.0
86 | - google-auth==2.14.1
87 | - google-auth-oauthlib==0.4.6
88 | - googledrivedownloader==0.4
89 | - grpcio==1.50.0
90 | - huggingface-hub==0.0.12
91 | - imageio==2.26.0
92 | - importlib-metadata==5.0.0
93 | - joblib==1.2.0
94 | - kiwisolver==1.4.4
95 | - markdown==3.4.1
96 | - markdown-it-py==2.2.0
97 | - markupsafe==2.1.1
98 | - matplotlib==3.5.3
99 | - mdurl==0.1.2
100 | - mmcls==0.25.0
101 | - mmcv-full==1.7.1
102 | - mmsegmentation==0.30.0
103 | - model-index==0.1.11
104 | - networkx==2.6.3
105 | - oauthlib==3.2.2
106 | - opencv-python==4.7.0.72
107 | - openmim==0.3.6
108 | - ordered-set==4.1.0
109 | - packaging==21.3
110 | - pandas==1.3.5
111 | - prettytable==3.6.0
112 | - protobuf==3.20.3
113 | - pyasn1==0.4.8
114 | - pyasn1-modules==0.2.8
115 | - pydantic==1.10.7
116 | - pygments==2.14.0
117 | - pyparsing==3.0.9
118 | - python-dateutil==2.8.2
119 | - pytz==2022.7.1
120 | - pywavelets==1.3.0
121 | - pyyaml==6.0
122 | - regex==2022.10.31
123 | - requests==2.26.0
124 | - requests-oauthlib==1.3.1
125 | - rich==13.3.1
126 | - rsa==4.9
127 | - sacremoses==0.0.53
128 | - scikit-image==0.17.2
129 | - scikit-learn==0.24.1
130 | - scipy==1.7.3
131 | - seaborn==0.12.2
132 | - staticmap==0.5.4
133 | - tabulate==0.9.0
134 | - tensorboard==2.11.0
135 | - tensorboard-data-server==0.6.1
136 | - tensorboard-plugin-wit==1.8.1
137 | - threadpoolctl==3.1.0
138 | - tifffile==2021.11.2
139 | - timm==0.6.12
140 | - tokenizers==0.10.3
141 | - torchscan==0.1.1
142 | - tqdm==4.48.2
143 | - transformers==4.8.2
144 | - wcwidth==0.2.6
145 | - werkzeug==2.2.2
146 | - yapf==0.32.0
147 | - zipp==3.10.0
148 | prefix: /home/shize/.conda/envs/PanoVPR
149 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/model/__init__.py
--------------------------------------------------------------------------------
/model/aggregation.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | import faiss
5 | import logging
6 | import numpy as np
7 | from tqdm import tqdm
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn.parameter import Parameter
11 | from torch.utils.data import DataLoader, SubsetRandomSampler
12 |
13 | import model.functional as LF
14 | import model.normalization as normalization
15 |
16 | class MAC(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | def forward(self, x):
20 | return LF.mac(x)
21 | def __repr__(self):
22 | return self.__class__.__name__ + '()'
23 |
24 | class SPoC(nn.Module):
25 | def __init__(self):
26 | super().__init__()
27 | def forward(self, x):
28 | return LF.spoc(x)
29 | def __repr__(self):
30 | return self.__class__.__name__ + '()'
31 |
32 | class GeM(nn.Module):
33 | def __init__(self, p=3, eps=1e-6, work_with_tokens=False):
34 | super().__init__()
35 | self.p = Parameter(torch.ones(1)*p)
36 | self.eps = eps
37 | self.work_with_tokens=work_with_tokens
38 | def forward(self, x):
39 | return LF.gem(x, p=self.p, eps=self.eps, work_with_tokens=self.work_with_tokens)
40 | def __repr__(self):
41 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
42 |
43 | class RMAC(nn.Module):
44 | def __init__(self, L=3, eps=1e-6):
45 | super().__init__()
46 | self.L = L
47 | self.eps = eps
48 | def forward(self, x):
49 | return LF.rmac(x, L=self.L, eps=self.eps)
50 | def __repr__(self):
51 | return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')'
52 |
53 |
54 | class Flatten(torch.nn.Module):
55 | def __init__(self): super().__init__()
56 | def forward(self, x): assert x.shape[2] == x.shape[3] == 1; return x[:,:,0,0]
57 |
58 | class RRM(nn.Module):
59 |     """Residual Retrieval Module as described in the paper
60 |     `Leveraging EfficientNet and Contrastive Learning for Accurate Global-scale
61 |     Location Estimation `
62 | """
63 | def __init__(self, dim):
64 | super().__init__()
65 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
66 | self.flatten = Flatten()
67 | self.ln1 = nn.LayerNorm(normalized_shape=dim)
68 | self.fc1 = nn.Linear(in_features=dim, out_features=dim)
69 | self.relu = nn.ReLU()
70 | self.fc2 = nn.Linear(in_features=dim, out_features=dim)
71 | self.ln2 = nn.LayerNorm(normalized_shape=dim)
72 | self.l2 = normalization.L2Norm()
73 | def forward(self, x):
74 | x = self.avgpool(x)
75 | x = self.flatten(x)
76 | x = self.ln1(x)
77 | identity = x
78 | out = self.fc2(self.relu(self.fc1(x)))
79 | out += identity
80 | out = self.l2(self.ln2(out))
81 | return out
82 |
83 |
84 | # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py
85 | class NetVLAD(nn.Module):
86 | """NetVLAD layer implementation"""
87 |
88 | def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False):
89 | """
90 | Args:
91 | clusters_num : int
92 | The number of clusters
93 | dim : int
94 | Dimension of descriptors
95 | alpha : float
96 | Parameter of initialization. Larger value is harder assignment.
97 | normalize_input : bool
98 | If true, descriptor-wise L2 normalization is applied to input.
99 | """
100 | super().__init__()
101 | self.clusters_num = clusters_num
102 | self.dim = dim
103 | self.alpha = 0
104 | self.normalize_input = normalize_input
105 | self.work_with_tokens = work_with_tokens
106 | if work_with_tokens:
107 | self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False)
108 | else:
109 | self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False)
110 | self.centroids = nn.Parameter(torch.rand(clusters_num, dim))
111 |
112 | def init_params(self, centroids, descriptors):
113 | centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
114 | dots = np.dot(centroids_assign, descriptors.T)
115 | dots.sort(0)
116 | dots = dots[::-1, :] # sort, descending
117 |
118 | self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item()
119 | self.centroids = nn.Parameter(torch.from_numpy(centroids))
120 | if self.work_with_tokens:
121 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2))
122 | else:
123 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3))
124 | self.conv.bias = None
125 |
126 | def forward(self, x):
127 | if self.work_with_tokens:
128 | x = x.permute(0, 2, 1)
129 | N, D, _ = x.shape[:]
130 | else:
131 | N, D, H, W = x.shape[:]
132 | if self.normalize_input:
133 | x = F.normalize(x, p=2, dim=1) # Across descriptor dim
134 | x_flatten = x.view(N, D, -1) # shape: N, D, HW
135 | soft_assign = self.conv(x).view(N, self.clusters_num, -1)
136 | soft_assign = F.softmax(soft_assign, dim=1)
137 | vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device)
138 | for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage
139 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \
140 | self.centroids[D:D+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
141 | residual = residual * soft_assign[:,D:D+1,:].unsqueeze(2)
142 | vlad[:,D:D+1,:] = residual.sum(dim=-1)
143 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
144 | vlad = vlad.view(N, -1) # Flatten
145 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
146 | return vlad
147 |
148 | def initialize_netvlad_layer(self, args, cluster_ds, backbone):
149 | if args.backbone.startswith('swin'):
150 | descriptors_num = 20000
151 |             # For Swin, sample 40 of the 49 output tokens per image: (16,49,768)->(16,40,768); otherwise sample 100, e.g. (16,196,256)->(16,100,256)
152 | descs_num_per_image = 40
153 | else:
154 | descriptors_num = 50000
155 | descs_num_per_image = 100
156 | images_num = math.ceil(descriptors_num / descs_num_per_image)
157 | random_sampler_query = SubsetRandomSampler(np.random.choice(range(cluster_ds.database_num,
158 | cluster_ds.database_num+cluster_ds.queries_num),
159 | int(images_num), replace=False))
160 | random_sampler_database = SubsetRandomSampler(np.random.choice(range(cluster_ds.database_num),
161 | int(images_num), replace=False))
162 |
163 | random_dl_query = DataLoader(dataset=cluster_ds, num_workers=args.num_workers,
164 | batch_size=args.infer_batch_size, sampler=random_sampler_query)
165 |
166 | random_dl_database = DataLoader(dataset=cluster_ds, num_workers=args.num_workers,
167 | batch_size=args.infer_batch_size, sampler=random_sampler_database)
168 |
169 | with torch.no_grad():
170 | backbone = backbone.eval()
171 | logging.debug("Extracting features to initialize NetVLAD layer")
172 |
173 | descriptors_q = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32)
174 | for iteration, (inputs, idx) in enumerate(tqdm(random_dl_query, ncols=100, desc='Initializing NetVLAD [Query]')):
175 | inputs = inputs.to(args.device)
176 | outputs = backbone(inputs)
177 | # vit special case: set backbone forward output to a tensor
178 | if args.backbone.startswith('vit'):
179 | outputs = outputs.last_hidden_state
180 |
181 | if args.backbone.startswith('swin') or args.backbone.startswith('vit'):
182 |                     # Swin/ViT outputs are already token sequences, so no reshape is needed
183 | norm_outputs = F.normalize(outputs, p=2, dim=2)
184 | image_descriptors = norm_outputs
185 | else:
186 | norm_outputs = F.normalize(outputs, p=2, dim=1)
187 | image_descriptors = norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1)
188 | image_descriptors = image_descriptors.cpu().numpy()
189 | batchix = iteration * args.infer_batch_size * descs_num_per_image
190 | for ix in range(image_descriptors.shape[0]):
191 | sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False)
192 | startix = batchix + ix * descs_num_per_image
193 | descriptors_q[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :]
194 |
195 | descriptors_db = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32)
196 | for iteration, (inputs, idx) in enumerate(tqdm(random_dl_database, ncols=100, desc='Initializing NetVLAD [Database]')):
197 | rand_patch_idx = np.random.choice(args.split_nums, 1, replace=False)
198 | inputs_width = inputs.shape[3]
199 | stride = int(inputs_width / args.split_nums)
200 | left = int(stride * rand_patch_idx)
201 | right = int(stride * rand_patch_idx + args.resize[1])
202 | if left < right <= inputs_width:
203 | inputs = inputs[:, :, :, left:right]
204 | elif left < inputs_width < right:
205 | inputs = torch.cat((inputs[:, :, :, left:], inputs[:, :, :, :right - inputs_width]), dim=-1)
206 |
207 | inputs = inputs.to(args.device)
208 | outputs = backbone(inputs)
209 | # vit special case: set backbone forward output to a tensor
210 | if args.backbone.startswith('vit'):
211 | outputs = outputs.last_hidden_state
212 |
213 | if args.backbone.startswith('swin') or args.backbone.startswith('vit'):
214 |                     # Swin/ViT outputs are already token sequences, so no reshape is needed
215 | norm_outputs = F.normalize(outputs, p=2, dim=2)
216 | image_descriptors = norm_outputs
217 | else:
218 | norm_outputs = F.normalize(outputs, p=2, dim=1)
219 | image_descriptors = norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1)
220 | image_descriptors = image_descriptors.cpu().numpy()
221 | batchix = iteration * args.infer_batch_size * descs_num_per_image
222 | for ix in range(image_descriptors.shape[0]):
223 | sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False)
224 | startix = batchix + ix * descs_num_per_image
225 | descriptors_db[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :]
226 |
227 | descriptors = np.concatenate((descriptors_q,descriptors_db),axis=0)
228 | kmeans = faiss.Kmeans(args.features_dim, self.clusters_num, niter=100, verbose=False)
229 | kmeans.train(descriptors)
230 | logging.debug(f"NetVLAD centroids shape: {kmeans.centroids.shape}")
231 | self.init_params(kmeans.centroids, descriptors)
232 | self = self.to(args.device)
233 |
234 |
235 | class CRNModule(nn.Module):
236 | def __init__(self, dim):
237 | super().__init__()
238 | # Downsample pooling
239 | self.downsample_pool = nn.AvgPool2d(kernel_size=3, stride=(2, 2),
240 | padding=0, ceil_mode=True)
241 |
242 | # Multiscale Context Filters
243 | self.filter_3_3 = nn.Conv2d(in_channels=dim, out_channels=32,
244 | kernel_size=(3, 3), padding=1)
245 | self.filter_5_5 = nn.Conv2d(in_channels=dim, out_channels=32,
246 | kernel_size=(5, 5), padding=2)
247 | self.filter_7_7 = nn.Conv2d(in_channels=dim, out_channels=20,
248 | kernel_size=(7, 7), padding=3)
249 |
250 | # Accumulation weight
251 | self.acc_w = nn.Conv2d(in_channels=84, out_channels=1, kernel_size=(1, 1))
252 | # Upsampling
253 | self.upsample = F.interpolate
254 |
255 | self._initialize_weights()
256 |
257 | def _initialize_weights(self):
258 | # Initialize Context Filters
259 | torch.nn.init.xavier_normal_(self.filter_3_3.weight)
260 | torch.nn.init.constant_(self.filter_3_3.bias, 0.0)
261 | torch.nn.init.xavier_normal_(self.filter_5_5.weight)
262 | torch.nn.init.constant_(self.filter_5_5.bias, 0.0)
263 | torch.nn.init.xavier_normal_(self.filter_7_7.weight)
264 | torch.nn.init.constant_(self.filter_7_7.bias, 0.0)
265 |
266 | torch.nn.init.constant_(self.acc_w.weight, 1.0)
267 | torch.nn.init.constant_(self.acc_w.bias, 0.0)
268 | self.acc_w.weight.requires_grad = False
269 | self.acc_w.bias.requires_grad = False
270 |
271 | def forward(self, x):
272 | # Contextual Reweighting Network
273 | x_crn = self.downsample_pool(x)
274 |
275 | # Compute multiscale context filters g_n
276 | g_3 = self.filter_3_3(x_crn)
277 | g_5 = self.filter_5_5(x_crn)
278 | g_7 = self.filter_7_7(x_crn)
279 | g = torch.cat((g_3, g_5, g_7), dim=1)
280 | g = F.relu(g)
281 |
282 | w = F.relu(self.acc_w(g)) # Accumulation weight
283 | mask = self.upsample(w, scale_factor=2, mode='bilinear') # Reweighting Mask
284 |
285 | return mask
286 |
287 |
288 | class CRN(NetVLAD):
289 | def __init__(self, clusters_num=64, dim=128, normalize_input=True):
290 | super().__init__(clusters_num, dim, normalize_input)
291 | self.crn = CRNModule(dim)
292 |
293 | def forward(self, x):
294 | N, D, H, W = x.shape[:]
295 | if self.normalize_input:
296 | x = F.normalize(x, p=2, dim=1) # Across descriptor dim
297 |
298 | mask = self.crn(x)
299 |
300 | x_flatten = x.view(N, D, -1)
301 | soft_assign = self.conv(x).view(N, self.clusters_num, -1)
302 | soft_assign = F.softmax(soft_assign, dim=1)
303 |
304 | # Weight soft_assign using CRN's mask
305 | soft_assign = soft_assign * mask.view(N, 1, H * W)
306 |
307 | vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device)
308 | for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage
309 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \
310 | self.centroids[D:D + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
311 | residual = residual * soft_assign[:, D:D + 1, :].unsqueeze(2)
312 | vlad[:, D:D + 1, :] = residual.sum(dim=-1)
313 |
314 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
315 | vlad = vlad.view(N, -1) # Flatten
316 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
317 | return vlad
318 |
319 |
--------------------------------------------------------------------------------
/model/functional.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | def sare_ind(query, positive, negative):
7 | '''all 3 inputs are supposed to be shape 1xn_features'''
8 | dist_pos = ((query - positive)**2).sum(1)
9 | dist_neg = ((query - negative)**2).sum(1)
10 |
11 | dist = - torch.cat((dist_pos, dist_neg))
12 | dist = F.log_softmax(dist, 0)
13 |
14 | #loss = (- dist[:, 0]).mean() on a batch
15 | loss = -dist[0]
16 | return loss
17 |
18 | def sare_joint(query, positive, negatives):
19 | '''query and positive have to be 1xn_features; whereas negatives has to be
20 | shape n_negative x n_features. n_negative is usually 10'''
21 | # NOTE: the implementation is the same if batch_size=1 as all operations
22 | # are vectorial. If there were the additional n_batch dimension a different
23 | # handling of that situation would have to be implemented here.
24 | # This function is declared anyway for the sake of clarity as the 2 should
25 | # be called in different situations because, even though there would be
26 | # no Exceptions, there would actually be a conceptual error.
27 | return sare_ind(query, positive, negatives)
28 |
29 | def mac(x):
30 | return F.adaptive_max_pool2d(x, (1,1))
31 |
32 | def spoc(x):
33 | return F.adaptive_avg_pool2d(x, (1,1))
34 |
35 | def gem(x, p=3, eps=1e-6, work_with_tokens=False):
36 | if work_with_tokens:
37 | x = x.permute(0, 2, 1)
38 |         # unsqueeze to maintain compatibility with Flatten
39 | return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p).unsqueeze(3)
40 | else:
41 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
42 |
43 | def rmac(x, L=3, eps=1e-6):
44 | ovr = 0.4 # desired overlap of neighboring regions
45 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
46 | W = x.size(3)
47 | H = x.size(2)
48 | w = min(W, H)
49 | # w2 = math.floor(w/2.0 - 1)
50 | b = (max(H, W)-w)/(steps-1)
51 | (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension
52 | # region overplus per dimension
53 | Wd = 0;
54 | Hd = 0;
55 | if H < W:
56 | Wd = idx.item() + 1
57 | elif H > W:
58 | Hd = idx.item() + 1
59 | v = F.max_pool2d(x, (x.size(-2), x.size(-1)))
60 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v)
61 | for l in range(1, L+1):
62 | wl = math.floor(2*w/(l+1))
63 | wl2 = math.floor(wl/2 - 1)
64 | if l+Wd == 1:
65 | b = 0
66 | else:
67 | b = (W-wl)/(l+Wd-1)
68 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates
69 | if l+Hd == 1:
70 | b = 0
71 | else:
72 | b = (H-wl)/(l+Hd-1)
73 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates
74 | for i_ in cenH.tolist():
75 | for j_ in cenW.tolist():
76 | if wl == 0:
77 | continue
78 | R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:]
79 | R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()]
80 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1)))
81 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt)
82 | v += vt
83 | return v
84 |
85 |
--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from collections import OrderedDict
4 |
5 | import timm
6 | import torch
7 | import logging
8 | import torchvision
9 | from torch import nn
10 | from os.path import join
11 | from timm.models.swin_transformer import SwinTransformer
12 | from timm.models.convnext import ConvNeXt
13 |
14 | from model.aggregation import Flatten
15 | from model.normalization import L2Norm
16 | import model.aggregation as aggregation
17 | from model.non_local import NonLocalBlock
18 | from tools import util
19 |
20 | from mmseg.models.decode_heads import FPNHead
21 | from mmseg.ops import Upsample, resize
22 |
23 |
24 | class GeoLocalizationNet(nn.Module):
25 | """The used networks are composed of a backbone and an aggregation layer.
26 | """
27 | def __init__(self, args):
28 | super().__init__()
29 | self.backbone = get_backbone(args)
30 |
31 | self.arch_name = args.backbone
32 | self.aggregation = get_aggregation(args)
33 | self.self_att = False
34 |
35 | if args.aggregation in ["gem", "spoc", "mac", "rmac"]:
36 | if args.l2 == "before_pool":
37 | self.aggregation = nn.Sequential(L2Norm(), self.aggregation, Flatten())
38 | elif args.l2 == "after_pool":
39 | self.aggregation = nn.Sequential(self.aggregation, L2Norm(), Flatten())
40 | elif args.l2 == "none":
41 | self.aggregation = nn.Sequential(self.aggregation, Flatten())
42 |
43 | if args.fc_output_dim != None:
44 | # Concatenate fully connected layer to the aggregation layer
45 | self.aggregation = nn.Sequential(self.aggregation,
46 | nn.Linear(args.features_dim, args.fc_output_dim),
47 | L2Norm())
48 | args.features_dim = args.fc_output_dim
49 | if args.non_local:
50 | non_local_list = [NonLocalBlock(channel_feat=get_output_channels_dim(self.backbone),
51 | channel_inner=args.channel_bottleneck)]* args.num_non_local
52 | self.non_local = nn.Sequential(*non_local_list)
53 | self.self_att = True
54 |
55 | def forward(self, x):
56 | x = self.backbone(x)
57 |
58 | if self.self_att:
59 | x = self.non_local(x)
60 |
61 | x = self.aggregation(x)
62 |
63 | return x
64 |
65 |
66 | class SwinBackbone(SwinTransformer):
67 |
68 | def __init__(self,
69 | backbone_name: str,
70 | depths=(2,2,6,2),
71 | img_size=224,
72 | window_size=7,
73 | embed_dim=96,
74 | num_heads=(3, 6, 12, 24),
75 | patch_size=4,
76 | patch_norm=True,
77 | norm_layer=nn.LayerNorm):
78 | super().__init__(img_size=img_size, patch_size=patch_size, patch_norm=patch_norm, window_size=window_size,
79 | depths=depths, norm_layer=norm_layer,embed_dim=embed_dim,num_heads=num_heads)
80 |
81 | self.depths = depths
82 | self.multi_feature = []
83 | self.out_layer_idx = self.get_multi_layer_out(depths)
84 |
85 | # load pretrained params
86 | original_model = timm.create_model(backbone_name, pretrained=True)
87 | self.load_state_dict(original_model.state_dict(), strict=False)
88 |
89 | self.query_patch_embed = util.PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3,
90 | embed_dim=self.embed_dim,
91 | norm_layer=norm_layer if patch_norm else None)
92 |
93 | self.dataset_patch_embed = util.PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=5,
94 | embed_dim=self.embed_dim,
95 | norm_layer=norm_layer if patch_norm else None)
96 |
97 | @staticmethod
98 | def get_multi_layer_out(depths):
99 | out_layer_idx = list(depths)
100 | cache = 0
101 | for idx, i in enumerate(depths):
102 | cache += i
103 | out_layer_idx[idx] = cache
104 | return out_layer_idx
105 |
106 | def load_state_dict(self, state_dict, strict=True):
107 |
108 | new_state_dict = OrderedDict()
109 | for k, v in state_dict.items():
110 | if not k.startswith('patch_embed'):
111 | new_state_dict[k] = v
112 |
113 | super().load_state_dict(new_state_dict, strict=strict)
114 |
115 | def forward_features(self, x):
116 | if x.shape[1] == 3:
117 | x = self.query_patch_embed(x)
118 | elif x.shape[1] == 5:
119 | x = self.dataset_patch_embed(x)
120 | else:
121 |             raise ValueError(f'input tensor channels should be 3 or 5, but got {x.shape[1]}')
122 |
123 | if self.absolute_pos_embed is not None:
124 | x = x + self.absolute_pos_embed
125 | x = self.pos_drop(x)
126 |
127 | multi_feature = []
128 | for layer_idx, layer in enumerate(self.layers):
129 | if layer_idx in range(self.num_layers):
130 | B, L, C = x.shape
131 | multi_feature.append(x.view(B, int(L**0.5), int(L**0.5), C).permute(0,3,1,2))
132 |
133 | x = layer(x)
134 | self.multi_feature = multi_feature
135 | assert len(self.multi_feature) == len(self.depths)
136 |
137 | x = self.norm(x) # B L C
138 | return x
139 |
140 |
141 | def forward(self, x):
142 | x = self.forward_features(x)
143 | return x
144 |
145 |
146 | class FPNHeadPooling(FPNHead):
147 | def __init__(self, feature_dim=None, **kwargs):
148 |         """
149 |         Fuse multi-layer features.
150 |         :param feature_dim: (int) number of output channels after passing through this module
151 |         :param kwargs: important args that must be passed in:
152 |             - in_channels (int|Sequence[int]): Input channels.
153 |             - channels (int): Intermediate channel count used before the unify-channel convolution.
154 |             - num_classes (int): Number of classes. default=1000 [not used]
155 |             - feature_strides (tuple[int]): The strides for input feature maps.
156 |                 All strides are supposed to be powers of 2. The first
157 |                 one corresponds to the largest resolution.
158 |         """
159 | super().__init__(**kwargs)
160 | self.conv_unify = nn.Conv2d(self.channels, feature_dim, kernel_size=1)
162 |
163 |
164 | def unify_channel(self,x):
165 | out = self.conv_unify(x)
166 | return out
167 |
168 |
169 | def forward(self, inputs):
170 | x = self._transform_inputs(inputs)
171 |
172 | output = self.scale_heads[0](x[0])
173 | for i in range(1, len(self.feature_strides)):
174 | # non inplace
175 | output = output + resize(
176 | self.scale_heads[i](x[i]),
177 | size=output.shape[2:],
178 | mode='bilinear',
179 | align_corners=self.align_corners)
180 |
181 | output = self.unify_channel(output) # B,C,H,W
182 | output = torch.flatten(output,start_dim=2).permute(0,2,1) # BCHW -> BCL -> BLC
183 | return output
184 |
185 |
186 | class ConvNeXtBackbone(ConvNeXt):
187 | def __init__(self,
188 | backbone_name:str,
189 | depths=(3,3,9,3),
190 | dims=(96,192,384,768)):
191 | super().__init__(depths=depths, dims=dims)
192 |
193 | original_model = timm.create_model(backbone_name, pretrained=True)
194 | self.load_state_dict(original_model.state_dict(), strict=False)
195 |
196 | def forward(self, x):
197 | # Drop the head
198 | x = self.forward_features(x)
199 | return x
200 |
201 |
202 | def get_aggregation(args):
203 | if args.aggregation == "gem":
204 | return aggregation.GeM(work_with_tokens=args.work_with_tokens)
205 | elif args.aggregation == "spoc":
206 | return aggregation.SPoC()
207 | elif args.aggregation == "mac":
208 | return aggregation.MAC()
209 | elif args.aggregation == "rmac":
210 | return aggregation.RMAC()
211 | elif args.aggregation == "netvlad":
212 | return aggregation.NetVLAD(clusters_num=args.netvlad_clusters, dim=args.features_dim,
213 | work_with_tokens=args.work_with_tokens)
214 | elif args.aggregation == 'crn':
215 | return aggregation.CRN(clusters_num=args.netvlad_clusters, dim=args.features_dim)
216 | elif args.aggregation == "rrm":
217 | return aggregation.RRM(args.features_dim)
218 | elif args.aggregation == 'none'\
219 | or args.aggregation == 'cls' \
220 | or args.aggregation == 'seqpool':
221 | return nn.Identity()
222 |
223 |
224 | def get_backbone(args):
225 | # The aggregation layer works differently based on the type of architecture
226 | args.work_with_tokens = args.backbone.startswith('swin')
227 |
228 | if args.backbone.startswith("swin"):
229 | if args.backbone.endswith("tiny"):
230 | # check input image size
231 | assert args.resize[0] == 224
232 | model_cfg = dict(depths=(2, 2, 6, 2),
233 | window_size=7, img_size=224,
234 | embed_dim=96, num_heads=(3,6,12,24))
235 | backbone = SwinBackbone('swin_tiny_patch4_window7_224', **model_cfg)
236 | args.features_dim = 96 * 8
237 | elif args.backbone.endswith("small"):
238 | assert args.resize[0] == 224
239 | model_cfg = dict(depths=(2, 2, 18, 2),
240 | window_size=7, img_size=224,
241 | embed_dim=96,num_heads=(3,6,12,24))
242 | backbone = SwinBackbone('swin_small_patch4_window7_224', **model_cfg)
243 | args.features_dim = 96 * 8
244 | elif args.backbone.endswith("base"):
245 | assert args.resize[0] == 384 or args.resize[0] == 224
246 | if args.resize[0] == 384:
247 | model_cfg = dict(depths=(2, 2, 18, 2),
248 | window_size=12, img_size=384,
249 | embed_dim=128, num_heads=(4, 8, 16, 32))
250 | backbone = SwinBackbone('swin_base_patch4_window12_384', **model_cfg)
251 | else:
252 | model_cfg = dict(depths=(2, 2, 18, 2),
253 | window_size=7, img_size=224,
254 | embed_dim=128, num_heads=(4, 8, 16, 32))
255 | backbone = SwinBackbone('swin_base_patch4_window7_224_in22k', **model_cfg)
256 | args.features_dim = 128 * 8
257 | else:
258 | raise NotImplementedError(f"The interface of {args.backbone} is not implemented")
259 | return backbone
260 |
261 | elif args.backbone.startswith("convnext"):
262 | if args.backbone.endswith("tiny"):
263 | # check input image size
264 |             assert args.resize[0] == 224, f'Input size should be 224 for convnext_tiny, but got {args.resize[0]}'
265 | model_cfg = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
266 | backbone = ConvNeXtBackbone('convnext_tiny', **model_cfg)
267 | args.features_dim = 96 * 8
268 | elif args.backbone.endswith("small"):
269 |             assert args.resize[0] == 224, f'Input size should be 224 for convnext_small, but got {args.resize[0]}'
270 | model_cfg = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
271 | backbone = ConvNeXtBackbone('convnext_small', **model_cfg)
272 | args.features_dim = 96 * 8
273 | elif args.backbone.endswith("base"):
274 |             assert args.resize[0] == 384, f'Input size should be 384 for convnext_base, but got {args.resize[0]}'
275 | model_cfg = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
276 | backbone = ConvNeXtBackbone('convnext_base', **model_cfg)
277 | args.features_dim = 128 * 8
278 | else:
279 | raise NotImplementedError(f"The interface of {args.backbone} is not implemented")
280 | return backbone
281 |
282 |
283 | def get_output_channels_dim(model, type:str= 'feat'):
284 | """Return the number of channels in the output of a model."""
285 | if type == 'feat':
286 | return model(torch.ones([1, 3, 224, 224])).shape[1]
287 | elif type == 'token':
288 | return model(torch.ones([1, 3, 224, 224])).shape[2]
289 | else:
290 |         raise ValueError(f"type should be 'feat' or 'token', but got {type}")
291 |
292 |
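A minimal sketch (not repo code) of how these helpers fit together: `get_backbone` sets `args.features_dim` and `args.work_with_tokens` as a side effect, so it must run before `get_aggregation`. The `SimpleNamespace` below only carries the fields these two functions read; the real values come from parser.py, and building the Swin backbone downloads timm pretrained weights.

```python
from types import SimpleNamespace
from model.network import get_backbone, get_aggregation

# Hypothetical minimal args; the real ones come from parser.parse_arguments().
args = SimpleNamespace(backbone="swin_tiny", resize=[224, 224], aggregation="gem")

backbone = get_backbone(args)        # sets args.work_with_tokens and args.features_dim
aggregation = get_aggregation(args)  # GeM pooling over token features for Swin backbones
print(args.features_dim)             # 96 * 8 = 768 for swin_tiny
```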
--------------------------------------------------------------------------------
/model/non_local.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import einops
4 |
5 |
6 | class NonLocalBlock(nn.Module):
7 | def __init__(self, channel_feat, channel_inner, gamma=1):
8 | super().__init__()
9 | self.q_conv = nn.Conv2d(in_channels=channel_feat,
10 | out_channels=channel_inner,
11 | kernel_size=1)
12 | self.k_conv = nn.Conv2d(in_channels=channel_feat,
13 | out_channels=channel_inner,
14 | kernel_size=1)
15 | self.v_conv = nn.Conv2d(in_channels=channel_feat,
16 | out_channels=channel_inner,
17 | kernel_size=1)
18 | self.merge_conv = nn.Conv2d(in_channels=channel_inner,
19 | out_channels=channel_feat,
20 | kernel_size=1)
21 | self.gamma = gamma
22 |
23 | def forward(self, x):
24 | b, c, h, w = x.shape[:]
25 | q_tensor = self.q_conv(x)
26 | k_tensor = self.k_conv(x)
27 | v_tensor = self.v_conv(x)
28 |
29 | q_tensor = einops.rearrange(q_tensor, 'b c h w -> b c (h w)')
30 | k_tensor = einops.rearrange(k_tensor, 'b c h w -> b c (h w)')
31 | v_tensor = einops.rearrange(v_tensor, 'b c h w -> b c (h w)')
32 |
33 | qk_tensor = torch.einsum('b c i, b c j -> b i j', q_tensor, k_tensor) # where i = j = (h * w)
34 | attention = torch.softmax(qk_tensor, -1)
35 | out = torch.einsum('b n i, b c i -> b c n', attention, v_tensor)
36 | out = einops.rearrange(out, 'b c (h w) -> b c h w', h=h, w=w)
37 | out = self.merge_conv(out)
38 | out = self.gamma * out + x
39 | return out
40 |
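A quick shape check (a sketch, not repo code): the block projects the input to `channel_inner`, applies spatial self-attention, projects back, and adds a `gamma`-scaled residual, so the output shape equals the input shape.

```python
import torch
from model.non_local import NonLocalBlock

block = NonLocalBlock(channel_feat=64, channel_inner=32, gamma=1)
x = torch.randn(2, 64, 14, 14)   # B, C, H, W
y = block(x)
print(y.shape)                   # torch.Size([2, 64, 14, 14])
```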
--------------------------------------------------------------------------------
/model/normalization.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class L2Norm(nn.Module):
6 | def __init__(self, dim=1):
7 | super().__init__()
8 | self.dim = dim
9 | def forward(self, x):
10 | return F.normalize(x, p=2, dim=self.dim)
11 |
12 |
--------------------------------------------------------------------------------
/model/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import set_sbn_eps_mode
12 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
13 | from .batchnorm import patch_sync_batchnorm, convert_model
14 | from .replicate import DataParallelWithCallback, patch_replication_callback
15 |
--------------------------------------------------------------------------------
/model/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 | import contextlib
13 |
14 | import torch
15 | import torch.nn.functional as F
16 |
17 | from torch.nn.modules.batchnorm import _BatchNorm
18 |
19 | try:
20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
21 | except ImportError:
22 | ReduceAddCoalesced = Broadcast = None
23 |
24 | try:
25 | from jactorch.parallel.comm import SyncMaster
26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
27 | except ImportError:
28 | from .comm import SyncMaster
29 | from .replicate import DataParallelWithCallback
30 |
31 | __all__ = [
32 | 'set_sbn_eps_mode',
33 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
34 | 'patch_sync_batchnorm', 'convert_model'
35 | ]
36 |
37 |
38 | SBN_EPS_MODE = 'clamp'
39 |
40 |
41 | def set_sbn_eps_mode(mode):
42 | global SBN_EPS_MODE
43 | assert mode in ('clamp', 'plus')
44 | SBN_EPS_MODE = mode
45 |
46 |
47 | def _sum_ft(tensor):
48 | """sum over the first and last dimention"""
49 | return tensor.sum(dim=0).sum(dim=-1)
50 |
51 |
52 | def _unsqueeze_ft(tensor):
53 | """add new dimensions at the front and the tail"""
54 | return tensor.unsqueeze(0).unsqueeze(-1)
55 |
56 |
57 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
58 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
59 |
60 |
61 | class _SynchronizedBatchNorm(_BatchNorm):
62 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
63 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
64 |
65 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
66 | track_running_stats=track_running_stats)
67 |
68 | if not self.track_running_stats:
69 | import warnings
70 | warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
71 |
72 | self._sync_master = SyncMaster(self._data_parallel_master)
73 |
74 | self._is_parallel = False
75 | self._parallel_id = None
76 | self._slave_pipe = None
77 |
78 | def forward(self, input):
79 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
80 | if not (self._is_parallel and self.training):
81 | return F.batch_norm(
82 | input, self.running_mean, self.running_var, self.weight, self.bias,
83 | self.training, self.momentum, self.eps)
84 |
85 | # Resize the input to (B, C, -1).
86 | input_shape = input.size()
87 | assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
88 | input = input.view(input.size(0), self.num_features, -1)
89 |
90 | # Compute the sum and square-sum.
91 | sum_size = input.size(0) * input.size(2)
92 | input_sum = _sum_ft(input)
93 | input_ssum = _sum_ft(input ** 2)
94 |
95 | # Reduce-and-broadcast the statistics.
96 | if self._parallel_id == 0:
97 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
98 | else:
99 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
100 |
101 | # Compute the output.
102 | if self.affine:
103 | # MJY:: Fuse the multiplication for speed.
104 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
105 | else:
106 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
107 |
108 | # Reshape it.
109 | return output.view(input_shape)
110 |
111 | def __data_parallel_replicate__(self, ctx, copy_id):
112 | self._is_parallel = True
113 | self._parallel_id = copy_id
114 |
115 | # parallel_id == 0 means master device.
116 | if self._parallel_id == 0:
117 | ctx.sync_master = self._sync_master
118 | else:
119 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
120 |
121 | def _data_parallel_master(self, intermediates):
122 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
123 |
124 | # Always using same "device order" makes the ReduceAdd operation faster.
125 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
126 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
127 |
128 | to_reduce = [i[1][:2] for i in intermediates]
129 | to_reduce = [j for i in to_reduce for j in i] # flatten
130 | target_gpus = [i[1].sum.get_device() for i in intermediates]
131 |
132 | sum_size = sum([i[1].sum_size for i in intermediates])
133 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
134 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
135 |
136 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
137 |
138 | outputs = []
139 | for i, rec in enumerate(intermediates):
140 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
141 |
142 | return outputs
143 |
144 | def _compute_mean_std(self, sum_, ssum, size):
145 | """Compute the mean and standard-deviation with sum and square-sum. This method
146 | also maintains the moving average on the master device."""
147 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
148 | mean = sum_ / size
149 | sumvar = ssum - sum_ * mean
150 | unbias_var = sumvar / (size - 1)
151 | bias_var = sumvar / size
152 |
153 | if hasattr(torch, 'no_grad'):
154 | with torch.no_grad():
155 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
156 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
157 | else:
158 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
159 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
160 |
161 | if SBN_EPS_MODE == 'clamp':
162 | return mean, bias_var.clamp(self.eps) ** -0.5
163 | elif SBN_EPS_MODE == 'plus':
164 | return mean, (bias_var + self.eps) ** -0.5
165 | else:
166 | raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
167 |
168 |
169 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
170 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
171 | mini-batch.
172 |
173 | .. math::
174 |
175 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
176 |
177 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
178 | standard-deviation are reduced across all devices during training.
179 |
180 | For example, when one uses `nn.DataParallel` to wrap the network during
181 |     training, PyTorch's implementation normalizes the tensor on each device using
182 |     the statistics only on that device, which accelerates the computation and
183 |     is also easy to implement, but the statistics might be inaccurate.
184 |     Instead, in this synchronized version, the statistics will be computed
185 |     over all training samples distributed on multiple devices.
186 |
187 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
188 |     as the built-in PyTorch implementation.
189 |
190 | The mean and standard-deviation are calculated per-dimension over
191 | the mini-batches and gamma and beta are learnable parameter vectors
192 | of size C (where C is the input size).
193 |
194 | During training, this layer keeps a running estimate of its computed mean
195 | and variance. The running sum is kept with a default momentum of 0.1.
196 |
197 | During evaluation, this running mean/variance is used for normalization.
198 |
199 | Because the BatchNorm is done over the `C` dimension, computing statistics
200 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
201 |
202 | Args:
203 | num_features: num_features from an expected input of size
204 | `batch_size x num_features [x width]`
205 | eps: a value added to the denominator for numerical stability.
206 | Default: 1e-5
207 | momentum: the value used for the running_mean and running_var
208 | computation. Default: 0.1
209 | affine: a boolean value that when set to ``True``, gives the layer learnable
210 | affine parameters. Default: ``True``
211 |
212 | Shape::
213 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
214 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
215 |
216 | Examples:
217 | >>> # With Learnable Parameters
218 | >>> m = SynchronizedBatchNorm1d(100)
219 | >>> # Without Learnable Parameters
220 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
221 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
222 | >>> output = m(input)
223 | """
224 |
225 | def _check_input_dim(self, input):
226 | if input.dim() != 2 and input.dim() != 3:
227 | raise ValueError('expected 2D or 3D input (got {}D input)'
228 | .format(input.dim()))
229 |
230 |
231 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
232 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
233 | of 3d inputs
234 |
235 | .. math::
236 |
237 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
238 |
239 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
240 | standard-deviation are reduced across all devices during training.
241 |
242 | For example, when one uses `nn.DataParallel` to wrap the network during
243 |     training, PyTorch's implementation normalizes the tensor on each device using
244 |     the statistics only on that device, which accelerates the computation and
245 |     is also easy to implement, but the statistics might be inaccurate.
246 |     Instead, in this synchronized version, the statistics will be computed
247 |     over all training samples distributed on multiple devices.
248 |
249 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
250 |     as the built-in PyTorch implementation.
251 |
252 | The mean and standard-deviation are calculated per-dimension over
253 | the mini-batches and gamma and beta are learnable parameter vectors
254 | of size C (where C is the input size).
255 |
256 | During training, this layer keeps a running estimate of its computed mean
257 | and variance. The running sum is kept with a default momentum of 0.1.
258 |
259 | During evaluation, this running mean/variance is used for normalization.
260 |
261 | Because the BatchNorm is done over the `C` dimension, computing statistics
262 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
263 |
264 | Args:
265 | num_features: num_features from an expected input of
266 | size batch_size x num_features x height x width
267 | eps: a value added to the denominator for numerical stability.
268 | Default: 1e-5
269 | momentum: the value used for the running_mean and running_var
270 | computation. Default: 0.1
271 | affine: a boolean value that when set to ``True``, gives the layer learnable
272 | affine parameters. Default: ``True``
273 |
274 | Shape::
275 | - Input: :math:`(N, C, H, W)`
276 | - Output: :math:`(N, C, H, W)` (same shape as input)
277 |
278 | Examples:
279 | >>> # With Learnable Parameters
280 | >>> m = SynchronizedBatchNorm2d(100)
281 | >>> # Without Learnable Parameters
282 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
283 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
284 | >>> output = m(input)
285 | """
286 |
287 | def _check_input_dim(self, input):
288 | if input.dim() != 4:
289 | raise ValueError('expected 4D input (got {}D input)'
290 | .format(input.dim()))
291 |
292 |
293 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
294 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
295 | of 4d inputs
296 |
297 | .. math::
298 |
299 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
300 |
301 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
302 | standard-deviation are reduced across all devices during training.
303 |
304 | For example, when one uses `nn.DataParallel` to wrap the network during
305 |     training, PyTorch's implementation normalizes the tensor on each device using
306 |     the statistics only on that device, which accelerates the computation and
307 |     is also easy to implement, but the statistics might be inaccurate.
308 |     Instead, in this synchronized version, the statistics will be computed
309 |     over all training samples distributed on multiple devices.
310 |
311 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
312 |     as the built-in PyTorch implementation.
313 |
314 | The mean and standard-deviation are calculated per-dimension over
315 | the mini-batches and gamma and beta are learnable parameter vectors
316 | of size C (where C is the input size).
317 |
318 | During training, this layer keeps a running estimate of its computed mean
319 | and variance. The running sum is kept with a default momentum of 0.1.
320 |
321 | During evaluation, this running mean/variance is used for normalization.
322 |
323 | Because the BatchNorm is done over the `C` dimension, computing statistics
324 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
325 | or Spatio-temporal BatchNorm
326 |
327 | Args:
328 | num_features: num_features from an expected input of
329 | size batch_size x num_features x depth x height x width
330 | eps: a value added to the denominator for numerical stability.
331 | Default: 1e-5
332 | momentum: the value used for the running_mean and running_var
333 | computation. Default: 0.1
334 | affine: a boolean value that when set to ``True``, gives the layer learnable
335 | affine parameters. Default: ``True``
336 |
337 | Shape::
338 | - Input: :math:`(N, C, D, H, W)`
339 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
340 |
341 | Examples:
342 | >>> # With Learnable Parameters
343 | >>> m = SynchronizedBatchNorm3d(100)
344 | >>> # Without Learnable Parameters
345 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
346 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
347 | >>> output = m(input)
348 | """
349 |
350 | def _check_input_dim(self, input):
351 | if input.dim() != 5:
352 | raise ValueError('expected 5D input (got {}D input)'
353 | .format(input.dim()))
354 |
355 |
356 | @contextlib.contextmanager
357 | def patch_sync_batchnorm():
358 | import torch.nn as nn
359 |
360 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
361 |
362 | nn.BatchNorm1d = SynchronizedBatchNorm1d
363 | nn.BatchNorm2d = SynchronizedBatchNorm2d
364 | nn.BatchNorm3d = SynchronizedBatchNorm3d
365 |
366 | yield
367 |
368 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
369 |
370 |
371 | def convert_model(module):
372 |     """Traverse the input module and its children recursively
373 |     and replace every instance of torch.nn.modules.batchnorm.BatchNorm*N*d
374 |     with SynchronizedBatchNorm*N*d
375 |
376 |     Args:
377 |         module: the input module to be converted to a SyncBN model
378 |
379 | Examples:
380 | >>> import torch.nn as nn
381 | >>> import torchvision
382 | >>> # m is a standard pytorch model
383 | >>> m = torchvision.models.resnet18(True)
384 | >>> m = nn.DataParallel(m)
385 | >>> # after convert, m is using SyncBN
386 | >>> m = convert_model(m)
387 | """
388 | if isinstance(module, torch.nn.DataParallel):
389 | mod = module.module
390 | mod = convert_model(mod)
391 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
392 | return mod
393 |
394 | mod = module
395 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
396 | torch.nn.modules.batchnorm.BatchNorm2d,
397 | torch.nn.modules.batchnorm.BatchNorm3d],
398 | [SynchronizedBatchNorm1d,
399 | SynchronizedBatchNorm2d,
400 | SynchronizedBatchNorm3d]):
401 | if isinstance(module, pth_module):
402 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
403 | mod.running_mean = module.running_mean
404 | mod.running_var = module.running_var
405 | if module.affine:
406 | mod.weight.data = module.weight.data.clone().detach()
407 | mod.bias.data = module.bias.data.clone().detach()
408 |
409 | for name, child in module.named_children():
410 | mod.add_module(name, convert_model(child))
411 |
412 | return mod
413 |
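Besides `convert_model` shown in the docstring above, `patch_sync_batchnorm` can be used as a context manager. A hedged sketch follows (it assumes at least two visible CUDA devices and uses torchvision's ResNet-18 purely as an example model).

```python
import torchvision
from model.sync_batchnorm import patch_sync_batchnorm, DataParallelWithCallback

# Any nn.BatchNorm*d instantiated inside the `with` block becomes its
# synchronized counterpart; the originals are restored on exit.
with patch_sync_batchnorm():
    m = torchvision.models.resnet18()

m = DataParallelWithCallback(m, device_ids=[0, 1])  # assumes two GPUs are available
```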
--------------------------------------------------------------------------------
/model/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
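A small numerical sanity check (a sketch, not repo code): in training mode, `BatchNorm2dReimpl` should agree with `torch.nn.BatchNorm2d` once both modules share the same affine parameters.

```python
import torch
import torch.nn as nn
from model.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

torch.manual_seed(0)
x = torch.randn(8, 16, 7, 7)

ref = nn.BatchNorm2d(16, eps=1e-5, momentum=0.1)
reimpl = BatchNorm2dReimpl(16, eps=1e-5, momentum=0.1)
reimpl.weight.data.copy_(ref.weight.data)   # align the randomly initialised affine params
reimpl.bias.data.copy_(ref.bias.data)

print(torch.allclose(ref(x), reimpl(x), atol=1e-5))  # expected: True
```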
--------------------------------------------------------------------------------
/model/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During the replication, as data parallel triggers a callback on each module, all slave devices should
60 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually is the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
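A tiny sketch (not repo code) of the `FutureResult` one-to-one pipe that `SlavePipe` and `SyncMaster` build on: one thread publishes a value, another blocks on `get()` until it arrives.

```python
import threading
from model.sync_batchnorm.comm import FutureResult

future = FutureResult()

def producer():
    future.put(42)   # wakes up whoever is blocked in future.get()

t = threading.Thread(target=producer)
t.start()
print(future.get())  # 42
t.join()
```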
--------------------------------------------------------------------------------
/model/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all replicated modules are isomorphic, we assign each sub-module a context
34 |     (shared among the copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
55 |     the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/model/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
29 |
30 |
--------------------------------------------------------------------------------
/parser.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import argparse
5 |
6 |
7 | def parse_arguments():
8 | parser = argparse.ArgumentParser(description="Benchmarking Visual Geolocalization",
9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | # Training parameters
11 | parser.add_argument("--train_batch_size", type=int, default=2,
12 | help="Number of triplets (query, pos, negs) in a batch. Each triplet consists of 12 images")
13 | parser.add_argument("--infer_batch_size", type=int, default=16,
14 | help="Batch size for inference (caching and testing)")
15 | parser.add_argument("--criterion", type=str, default='triplet',
16 | help='loss to be used')
17 | parser.add_argument("--margin", type=float, default=0.1,
18 | help="margin for the triplet loss")
19 | parser.add_argument("--epochs_num", type=int, default=60,
20 | help="number of epochs to train for")
21 | parser.add_argument("--patience", type=int, default=10, help="_")
22 | parser.add_argument("--lr", type=float, default=0.00001, help="_")
23 | parser.add_argument("--optim", type=str, default="adam", choices=["adam", "sgd"])
24 | parser.add_argument("--cache_refresh_rate", type=int, default=125,
25 | help="How often to refresh cache, in number of queries")
26 |     # MARK: parameters to focus on when fine-tuning
27 | parser.add_argument("--queries_per_epoch", type=int, default=500,
28 | help="How many queries to consider for one epoch. Must be multiple of cache_refresh_rate")
29 | parser.add_argument("--negs_num_per_query", type=int, default=10,
30 | help="How many negatives to consider per each query in the loss")
31 | parser.add_argument("--neg_samples_num", type=int, default=125,
32 | help="How many negatives to use to compute the hardest ones")
33 | parser.add_argument("--mining", type=str, default="partial", choices=["partial", "full", "random"])
34 | # Model parameters
35 | parser.add_argument("--backbone", type=str, default="swin_tiny",
36 | choices=["swin_tiny", "swin_small","convnext_tiny", "convnext_small"], help="_")
37 | parser.add_argument("--l2", type=str, default="before_pool", choices=["before_pool", "after_pool", "none"],
38 | help="When (and if) to apply the l2 norm with shallow aggregation layers")
39 | parser.add_argument("--aggregation", type=str, default="gem", choices=["gem", "spoc", "mac", "rmac"])
40 |
41 | parser.add_argument('--pca_dim', type=int, default=None, help="PCA dimension (number of principal components). If None, PCA is not used.")
42 | parser.add_argument('--num_non_local', type=int, default=1, help="Num of non local blocks")
43 | parser.add_argument("--non_local", action='store_true', help="_")
44 | parser.add_argument('--channel_bottleneck', type=int, default=128, help="Channel bottleneck for Non-Local blocks")
45 | parser.add_argument('--fc_output_dim', type=int, default=None,
46 | help="Output dimension of fully connected layer. If None, don't use a fully connected layer.")
47 | parser.add_argument("--clip", type=int, default=10, choices=[1,3,5,7,10], help='Gradient clip avoiding loss wave')
48 | # Shift window parameters (in image or vector)
49 |     parser.add_argument('--reduce_factor', type=int, default=1, help='Window-stride reduction factor: stride = window length / reduce_factor')
50 |     parser.add_argument('--split_nums', type=int, default=16, help='Number of windows each panoramic database image is split into')
51 | # Initialization parameters
52 | parser.add_argument("--seed", type=int, default=0)
53 | parser.add_argument("--resume", type=str, default=None,
54 | help="Path to load checkpoint from, for resuming training or testing.")
55 | # Other parameters
56 | parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
57 | parser.add_argument("--num_workers", type=int, default=6, help="num_workers for all dataloaders")
58 | parser.add_argument('--resize', type=int, default=[224, 224], nargs=2, help="Resizing shape for images (HxW).")
59 | parser.add_argument('--query_resize', type=int, default=[448, 448], nargs=2, help="Resize for query")
60 | parser.add_argument('--database_resize', type=int, default=[448, 3584], nargs=2, help="Resize for database")
61 | parser.add_argument('--test_method', type=str, default="hard_resize",
62 | choices=["hard_resize", "single_query", "central_crop", "five_crops", "nearest_crop", "maj_voting"],
63 | help="This includes pre/post-processing methods and prediction refinement")
64 | parser.add_argument("--val_positive_dist_threshold", type=int, default=25, help="_")
65 | parser.add_argument("--train_positives_dist_threshold", type=int, default=10, help="_")
66 | parser.add_argument('--recall_values', type=int, default=[1, 5, 10, 20], nargs="+",
67 | help="Recalls to be computed, such as R@5.")
68 | # Data augmentation parameters
69 | parser.add_argument("--brightness", type=float, default=None, help="_")
70 | parser.add_argument("--contrast", type=float, default=None, help="_")
71 | parser.add_argument("--saturation", type=float, default=None, help="_")
72 | parser.add_argument("--hue", type=float, default=None, help="_")
73 | parser.add_argument("--rand_perspective", type=float, default=None, help="_")
74 | parser.add_argument("--horizontal_flip", action='store_true', help="_")
75 | parser.add_argument("--random_resized_crop", type=float, default=None, help="_")
76 | parser.add_argument("--random_rotation", type=float, default=None, help="_")
77 | # Paths parameters
78 | parser.add_argument("--datasets_folder", type=str, default='/home/shize/Datasets',
79 | help="Path with all datasets")
80 | parser.add_argument("--dataset_name", type=str, default="pitts250k", help="Relative path of the dataset")
81 | parser.add_argument("--pca_dataset_folder", type=str, default=None,
82 |                         help="Path with images to be used to compute PCA (e.g. pitts30k/images/train)")
83 | parser.add_argument("--save_dir", type=str, default="SwinT",
84 |                         help="Folder name of the current run (saved in ./logs/)")
85 |     parser.add_argument("--title", type=str, required=True, help="Short title summarizing the experiment config (appended to the log folder name)")
86 | parser.add_argument("--train_visual_save_path", type=str, default='visualize/Insert_x32w/train_set/',
87 | help="Mining [train] results save path")
88 | parser.add_argument("--val_visual_save_path", type=str, default='visualize/Insert_x32w/val_set/',
89 | help="Inference [val] results save path")
90 | parser.add_argument("--test_visual_save_path", type=str, default='visualize/Insert_x32w/test_set/',
91 | help="Inference [test] results save path")
92 | args = parser.parse_args()
93 |
94 | if args.datasets_folder == None:
95 | try:
96 | args.datasets_folder = os.environ['DATASETS_FOLDER']
97 | except KeyError:
98 | raise Exception("You should set the parameter --datasets_folder or export " +
99 | "the DATASETS_FOLDER environment variable as such \n" +
100 | "export DATASETS_FOLDER=../datasets_vg/datasets")
101 |
102 | if args.queries_per_epoch % args.cache_refresh_rate != 0:
103 | raise ValueError("Ensure that queries_per_epoch is divisible by cache_refresh_rate, " +
104 | f"because {args.queries_per_epoch} is not divisible by {args.cache_refresh_rate}")
105 |
106 | if args.pca_dim != None and args.pca_dataset_folder == None:
107 | raise ValueError("Please specify --pca_dataset_folder when using pca")
108 |
109 | if args.split_nums < 8:
110 |         raise ValueError("split_nums should be one of 8/16/24/32")
111 |
112 | return args
113 |
114 |
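Back-of-the-envelope arithmetic (an illustration, not repo code) for the default sliding-window settings: with `--database_resize 448 3584` and `--split_nums 16`, each panoramic database image yields 16 window positions with a horizontal stride of 3584 / 16 = 224 pixels. The exact window width is decided in datasets_ws.py and is only assumed here.

```python
db_height, db_width = 448, 3584   # default --database_resize
split_nums = 16                   # default --split_nums
window_stride = db_width // split_nums
print(window_stride)              # 224 pixels between consecutive window positions
```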
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Tuple
4 | import unittest
5 | import faiss
6 | import torch
7 | import logging
8 | import numpy as np
9 | from matplotlib import pyplot as plt
10 | from numpy import ndarray
11 | from tqdm import tqdm
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data.dataset import Subset
14 |
15 | from tools import commons
16 | import datasets_ws
17 | import parser
18 | from model import network
19 | from collections import OrderedDict
20 | from os.path import join
21 | from datetime import datetime
22 |
23 | from datasets_ws import shift_window_on_descriptor  # sliding-window similarity over descriptors
24 | from model.sync_batchnorm import convert_model
25 | from tools.visual import display_inference
26 |
27 |
28 | def test(args, eval_ds, model, test_method="hard_resize", pca=None, show_inference_results=None, save_path=None):
29 | """Compute features of the given dataset and compute the recalls."""
30 |
31 | assert test_method in ["hard_resize"], f"test_method can't be {test_method}"
32 |
33 | model = model.eval()
34 | eval_ds.test_method = test_method
35 |
36 | with torch.no_grad():
37 | logging.debug("Extracting database features for evaluation/testing")
38 |
39 | database_features = np.empty((eval_ds.database_num, args.split_nums * args.features_dim), dtype="float32")
40 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
41 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
42 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
43 |
44 |         # Database inputs come in as B, C, H, W full panoramas; they are split into B*split_nums windows below
45 | for inputs, indices in tqdm(database_dataloader, ncols=100, desc='Extracting database features'):
46 | B, C, H, W = inputs.shape
47 | inputs = torch.stack([datasets_ws.shift_window_on_img(one_pano, eval_ds.split_nums, eval_ds.window_stride,
48 | eval_ds.window_len) for one_pano in inputs])
49 | inputs = inputs.view(B * eval_ds.split_nums, C, eval_ds.resize[0], eval_ds.resize[1])
50 |
51 | features = model(inputs.to(args.device))
52 | # B*split_nums, feature_dim -> # B, split_nums*feature_dim
53 | features = torch.flatten(features.view(B, eval_ds.split_nums, -1), start_dim=1)
54 | features = features.cpu().numpy()
55 | if pca != None:
56 | features = pca.transform(features)
57 |
58 | database_features[indices.numpy(), :] = features
59 |
60 | logging.debug("Extracting queries features for evaluation/testing")
61 |
62 | queries_infer_batch_size = args.infer_batch_size
63 |
64 | queries_features = np.empty((eval_ds.queries_num, args.features_dim), dtype="float32")
65 | queries_subset_ds = Subset(eval_ds,
66 | list(range(eval_ds.database_num, eval_ds.database_num + eval_ds.queries_num)))
67 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
68 | batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda"))
69 |
70 |         # Query inputs shape: B, C, H, W
71 | for inputs, indices in tqdm(queries_dataloader, ncols=100, desc='Extracting queries features'):
72 | features = model(inputs.to(args.device))
73 | features = features.cpu().numpy()
74 |
75 | if pca != None:
76 | features = pca.transform(features)
77 |
78 |             # NOTE: subtract database_num so query indices start from 0
79 | queries_features[indices.numpy() - eval_ds.database_num, :] = features
80 |
81 | # Sliding Window Matching Descriptor
82 | shift_window_start = time.time()
83 | predictions = []
84 | focus_patch_loc = []
85 | for one_query_feature in queries_features:
86 | predictions_per_query, focus_patch_loc_per_query = shift_window_on_descriptor(one_query_feature,
87 | database_features,
88 | args.features_dim,
89 | args.reduce_factor,
90 | max(args.recall_values))
91 | predictions.append(predictions_per_query)
92 | focus_patch_loc.append(focus_patch_loc_per_query) # show results interface
93 | shift_window_end = time.time()
94 |         print(f'Searching all queries in the panoramic database took {shift_window_end-shift_window_start:.3f}s')
95 |
96 | # Visualization of Inference Results
97 | if show_inference_results:
98 | os.makedirs(save_path, exist_ok=True)
99 | display_inference(eval_ds, predictions, save_path, focus_patch_loc)
100 |
101 | #### For each query, check if the predictions are correct
102 | check_start = time.time()
103 | positives_per_query = eval_ds.get_positives()
104 | # args.recall_values by default is [1, 5, 10, 20]
105 | recalls = np.zeros(len(args.recall_values))
106 | for query_index, pred in enumerate(predictions):
107 | for i, n in enumerate(args.recall_values):
108 | if np.any(np.in1d(pred[:n], positives_per_query[query_index])):
109 | recalls[i:] += 1
110 | break
111 |         # Divide by the number of queries and multiply by 100, so the recalls are percentages
112 | recalls = recalls / eval_ds.queries_num * 100
113 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)])
114 | check_end = time.time()
115 |         print(f'Checking whether the predictions are correct took {check_end-check_start:.3f}s')
116 | return recalls, recalls_str
117 |
118 |
119 | def main():
120 | # Initial setup: parser
121 | args = parser.parse_arguments()
122 |
123 | # Set Logger
124 | start_time = datetime.now()
125 | args.save_dir = join("logs", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S') + '_' + args.title)
126 | commons.setup_logging(args.save_dir)
127 | logging.info(f"The outputs are being saved in {args.save_dir}")
128 |
129 | # Initialize model
130 | model = network.GeoLocalizationNet(args)
131 | model = model.to(args.device)
132 |
133 |     # Multi-GPU setting
134 | if torch.cuda.device_count() > 1:
135 | model = torch.nn.DataParallel(model)
136 |
137 | # val_ds
138 | val_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "val")
139 | logging.debug(f"Val set: {val_ds}")
140 |
141 | # test_ds
142 | test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test")
143 | logging.debug(f"Test set: {test_ds}")
144 |
145 | # load model params and run
146 | best_model_state_dict = torch.load(join(args.resume, "best_model.pth"))["model_state_dict"]
147 |
148 | if not torch.cuda.device_count() >= 2:
149 | best_model_state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in best_model_state_dict.items()})
150 |
151 | model.load_state_dict(best_model_state_dict)
152 | logging.info('Load pretrained model correctly!')
153 |
154 | recalls, recalls_str = test(args, eval_ds=val_ds, model=model)
155 | logging.info(f"Recalls on [Val-set]:{val_ds}: {recalls_str}")
156 |
157 | recalls, recalls_str = test(args, eval_ds=test_ds, model=model, show_inference_results=None)
158 | logging.info(f"Recalls on [Test-set]:{test_ds}: {recalls_str}")
159 |
160 |
161 | if __name__ == '__main__':
162 | main()
163 |
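A self-contained sketch of the descriptor-level sliding-window matching that `datasets_ws.shift_window_on_descriptor` is used for above (the function itself lives in datasets_ws.py; this re-creation follows the analogous cyclic-window logic in tools/loss.py and is only an approximation): each database descriptor is a concatenation of `split_nums` window descriptors, and a query is scored against each panorama by its best-matching window.

```python
import torch

def rank_by_best_window(query, database, feature_dim, reduce_factor, top_k):
    """query: (feature_dim,); database: (num_db, split_nums * feature_dim)."""
    step = feature_dim // reduce_factor
    width = database.shape[1]
    best = torch.full((database.shape[0],), float("inf"))
    left = 0
    while left < width:
        right = left + feature_dim
        if right <= width:
            window = database[:, left:right]
        else:  # wrap around: the panorama is cyclic along its width
            window = torch.cat((database[:, left:], database[:, :right - width]), dim=1)
        best = torch.minimum(best, torch.norm(query - window, p=2, dim=1))
        left += step
    return torch.argsort(best)[:top_k]   # database indices, most similar first

query = torch.randn(768)
database = torch.randn(100, 16 * 768)    # 100 panoramas, 16 windows each
print(rank_by_best_window(query, database, feature_dim=768, reduce_factor=1, top_k=5))
```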
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zafirshi/PanoVPR/7f576b7679691882fc8f0346930deb0aff6d1e38/tools/__init__.py
--------------------------------------------------------------------------------
/tools/commons.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | This file contains some functions and classes which can be useful in very diverse projects.
4 | """
5 |
6 | import os
7 | import sys
8 | import torch
9 | import random
10 | import logging
11 | import traceback
12 | import numpy as np
13 | from os.path import join
14 |
15 |
16 | def make_deterministic(seed=0, speedup=None):
17 | """Make results deterministic. If seed == -1, do not make deterministic.
18 | Running the script in a deterministic way might slow it down.
19 | """
20 | if seed == -1:
21 | return
22 | random.seed(seed)
23 | np.random.seed(seed)
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 | torch.backends.cudnn.deterministic = True
27 | torch.backends.cudnn.benchmark = True if speedup else False
28 |
29 |
30 | def setup_logging(save_dir, console="debug",
31 | info_filename="info.log", debug_filename="debug.log"):
32 | """Set up logging files and console output.
33 | Creates one file for INFO logs and one for DEBUG logs.
34 | Args:
35 | save_dir (str): creates the folder where to save the files.
36 |         console (str):
37 | if == "debug" prints on console debug messages and higher
38 | if == "info" prints on console info messages and higher
39 | if == None does not use console (useful when a logger has already been set)
40 | info_filename (str): the name of the info file. if None, don't create info file
41 | debug_filename (str): the name of the debug file. if None, don't create debug file
42 | """
43 | if os.path.exists(save_dir):
44 | raise FileExistsError(f"{save_dir} already exists!")
45 | os.makedirs(save_dir, exist_ok=True)
46 | # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
47 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
48 | logger = logging.getLogger('')
49 | logger.setLevel(logging.DEBUG)
50 |
51 | if info_filename != None:
52 | info_file_handler = logging.FileHandler(join(save_dir, info_filename))
53 | info_file_handler.setLevel(logging.INFO)
54 | info_file_handler.setFormatter(base_formatter)
55 | logger.addHandler(info_file_handler)
56 |
57 | if debug_filename != None:
58 | debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))
59 | debug_file_handler.setLevel(logging.DEBUG)
60 | debug_file_handler.setFormatter(base_formatter)
61 | logger.addHandler(debug_file_handler)
62 |
63 | if console != None:
64 | console_handler = logging.StreamHandler()
65 | if console == "debug": console_handler.setLevel(logging.DEBUG)
66 | if console == "info": console_handler.setLevel(logging.INFO)
67 | console_handler.setFormatter(base_formatter)
68 | logger.addHandler(console_handler)
69 |
70 | def exception_handler(type_, value, tb):
71 |         logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
72 | sys.excepthook = exception_handler
73 |
74 |
--------------------------------------------------------------------------------
/tools/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def shift_window_triple_loss(args, query_feature, database_feature, loss_fn):
5 | """
6 | Calculate the shift window triple loss for a given query feature against a database feature.
7 |
8 | This function computes the similarity between a query feature and different windows of a database feature.
9 | It uses a shift window approach where the window moves across the database feature in steps defined by
10 | the reduce_factor in args. The loss is calculated using the provided loss function (loss_fn) which takes
11 | the query feature, a positive match from the database, and a negative match from the database.
12 |
13 | Parameters:
14 | - args (Namespace): A namespace or similar object containing configuration parameters, including
15 | 'reduce_factor' which determines the step size for window shifting.
16 | - query_feature (torch.Tensor): A tensor representing the query feature.
17 | - database_feature (torch.Tensor): A tensor representing the database feature against which the query
18 | is compared.
19 | - loss_fn (function): A loss function that takes three arguments (query_feature, positive_feature,
20 | negative_feature) and returns a scalar loss value.
21 |
22 | Returns:
23 | - loss_sum (float): The accumulated loss over all the shifted windows.
24 | - min_index_in_row (torch.Tensor): The indices of the minimum values in the maintain_table, which
25 | represent the most similar window of database_feature to the query_feature.
26 |
27 | The function operates by sliding a window across the database feature and computing a similarity
28 | measure between the query feature and each window using L2 norm. It identifies the most similar
29 | window (positive match) and uses other windows as negative matches to compute the loss. The process
30 | is repeated for each window, and the losses are summed up to get the total loss.
31 |
32 | The function assumes that the query_feature and database_feature have compatible shapes and that the
33 | loss_fn is properly defined to handle the inputs.
34 |
35 | Example:
36 | >>> loss, indices = shift_window_triple_loss(args, query_feature, database_feature, loss_fn)
37 | """
38 | window_size = query_feature.shape[0]
39 | step = int(window_size/args.reduce_factor)
40 | left, right = 0, window_size
41 | loss_sum = 0.
42 |
43 | split_similarity_list = []
44 | width_bound = database_feature.shape[1]
45 | while left < width_bound:
46 | if right <= width_bound:
47 | split_similarity = torch.norm(query_feature - database_feature[:, left:right], p=2, dim=1)
48 | else:
49 | cycle = torch.concat((database_feature[:,left:],database_feature[:,:right-width_bound]),dim=1)
50 | split_similarity = torch.norm(query_feature-cycle,p=2,dim=1)
51 | split_similarity_list.append(split_similarity)
52 |
53 | left += step
54 | right += step
55 |
56 | # Similarity matrix for each split window and get the most similar slice_index using L2 Norm
57 |     maintain_table = torch.transpose(torch.stack(split_similarity_list), 0, 1)  # shape: (11, split_nums) = (1 positive + 10 negatives, windows)
58 | maintain_table_aggregation, min_index_in_row = torch.min(maintain_table, dim=1)
59 |
60 | # 0 -> positive and slice long vector according to slice_index above
61 | if min_index_in_row[0] * step + window_size <= width_bound:
62 | filtered_positive = database_feature[0,
63 | min_index_in_row[0] * step:min_index_in_row[0] * step + window_size] # shape:feature_dim
64 | else:
65 | filtered_positive = torch.concat((database_feature[0,min_index_in_row[0] * step:],
66 | database_feature[0,:min_index_in_row[0] * step + window_size - width_bound]),
67 | dim=-1)
68 |
69 |
70 | for i in range(1, 11):
71 | # 1 -> 10 refer to negs
72 | if min_index_in_row[i] * step + window_size <=width_bound:
73 | filtered_negative = database_feature[i,
74 | min_index_in_row[i] * step:min_index_in_row[i] * step + window_size] # shape:feature_dim
75 | else:
76 | filtered_negative = torch.concat((database_feature[i,min_index_in_row[i] * step:],
77 | database_feature[i,:min_index_in_row[i] * step + window_size - width_bound]),
78 | dim=-1)
79 |
80 | loss_sum += loss_fn(query_feature, filtered_positive, filtered_negative)
81 |
82 | return loss_sum, min_index_in_row
83 |
84 |
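85 | 
86 | # Illustrative sketch (not part of the training pipeline): how shift_window_triple_loss might be called
87 | # on dummy tensors. The feature dimension, panorama width, and reduce_factor below are assumptions chosen
88 | # only to make the shapes line up; they are not values used elsewhere in this repository.
89 | if __name__ == '__main__':
90 |     from argparse import Namespace
91 | 
92 |     _args = Namespace(reduce_factor=2)      # window step = window_size / reduce_factor
93 |     _query = torch.rand(128)                # one query feature of dimension 128
94 |     _database = torch.rand(11, 128 * 8)     # 1 positive + 10 negative panoramic features
95 |     _loss_fn = torch.nn.TripletMarginLoss(margin=0.1, p=2, reduction="sum")
96 |     _loss, _best_windows = shift_window_triple_loss(_args, _query, _database, _loss_fn)
97 |     print(_loss.item(), _best_windows.tolist())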
--------------------------------------------------------------------------------
/tools/map_builder.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import os
4 | import cv2
5 | import math
6 | import numpy as np
7 | from glob import glob
8 | from skimage import io
9 | from os.path import join
10 | import matplotlib.cm as cm
11 | import matplotlib.pyplot as plt
12 | from collections import defaultdict
13 | from staticmap import StaticMap, Polygon
14 | from matplotlib.colors import ListedColormap
15 |
16 |
17 | def _lon_to_x(lon, zoom):
18 | if not (-180 <= lon <= 180): lon = (lon + 180) % 360 - 180
19 | return ((lon + 180.) / 360) * pow(2, zoom)
20 |
21 |
22 | def _lat_to_y(lat, zoom):
23 | if not (-90 <= lat <= 90): lat = (lat + 90) % 180 - 90
24 | return (1 - math.log(math.tan(lat * math.pi / 180) + 1 / math.cos(lat * math.pi / 180)) / math.pi) / 2 * pow(2, zoom)
25 |
26 |
27 | def _download_map_image(min_lat=45.0, min_lon=7.6, max_lat=45.1, max_lon=7.7, size=2000):
28 | """"Download a map of the chosen area as a numpy image"""
29 | mean_lat = (min_lat + max_lat) / 2
30 | mean_lon = (min_lon + max_lon) / 2
31 | static_map = StaticMap(size, size, url_template='https://api.mapbox.com/styles/v1/mapbox/satellite-v9/tiles/{z}/{x}/{y}?access_token=')
32 | static_map.add_polygon(
33 | Polygon(((min_lon, min_lat), (min_lon, max_lat), (max_lon, max_lat), (max_lon, min_lat)), None, '#FFFFFF'))
34 | zoom = static_map._calculate_zoom()
35 | static_map = StaticMap(size, size)
36 | image = static_map.render(zoom, [mean_lon, mean_lat])
37 | print(
38 | f"You can see the map on Google Maps at this link www.google.com/maps/place/@{mean_lat},{mean_lon},{zoom - 1}z")
39 | min_lat_px, min_lon_px, max_lat_px, max_lon_px = \
40 | static_map._y_to_px(_lat_to_y(min_lat, zoom)), \
41 | static_map._x_to_px(_lon_to_x(min_lon, zoom)), \
42 | static_map._y_to_px(_lat_to_y(max_lat, zoom)), \
43 | static_map._x_to_px(_lon_to_x(max_lon, zoom))
44 | assert 0 <= max_lat_px < min_lat_px < size and 0 <= min_lon_px < max_lon_px < size
45 | return np.array(image)[max_lat_px:min_lat_px, min_lon_px:max_lon_px], static_map, zoom
46 |
47 |
48 | def get_edges(coordinates, enlarge=0):
49 | """
50 |     Return the edges of the coordinates, i.e. the southernmost, westernmost, northernmost and
51 |     easternmost coordinates.
52 | :param coordinates: A list of numpy.arrays of shape (Nx2)
53 | :param float enlarge: How much to increase the coordinates, to enlarge
54 | the area included between the points
55 |     :return: a tuple with the four floats
56 | """
57 | min_lat, min_lon, max_lat, max_lon = (*np.concatenate(coordinates).min(0), *np.concatenate(coordinates).max(0))
58 | diff_lat = (max_lat - min_lat) * enlarge
59 | diff_lon = (max_lon - min_lon) * enlarge
60 | inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon = \
61 | min_lat - diff_lat, min_lon - diff_lon, max_lat + diff_lat, max_lon + diff_lon
62 | return inc_min_lat, inc_min_lon, inc_max_lat, inc_max_lon
63 |
64 |
65 | def _create_map(coordinates, colors=None, dot_sizes=None, legend_names=None, map_intensity=0.6):
66 | dot_sizes = dot_sizes if dot_sizes is not None else [10] * len(coordinates)
67 | colors = colors if colors is not None else ["r"] * len(coordinates)
68 | assert len(coordinates) == len(dot_sizes) == len(colors), \
69 | f"The number of coordinates must be equals to the number of colors and dot_sizes, but they're " \
70 | f"{len(coordinates)}, {len(colors)}, {len(dot_sizes)}"
71 |
72 | # Add two dummy points to slightly enlarge the map
73 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates, enlarge=0.1)
74 | coordinates.append(np.array([[min_lat, min_lon], [max_lat, max_lon]]))
75 | # Download the map of the chosen area
76 | map_img, static_map, zoom = _download_map_image(min_lat, min_lon, max_lat, max_lon)
77 |
78 | scatters = []
79 | fig = plt.figure(figsize=(map_img.shape[1] / 100, map_img.shape[0] / 100), dpi=1000)
80 |     for coord in coordinates:
81 | for i in range(len(coord)): # Scale latitudes because of earth's curvature
82 | coord[i, 0] = -static_map._y_to_px(_lat_to_y(coord[i, 0], zoom))
83 | for coord, size, color in zip(coordinates, dot_sizes, colors):
84 | scatters.append(plt.scatter(coord[:, 1], coord[:, 0], s=size, color=color))
85 |
86 |     if legend_names is not None:
87 | plt.legend(scatters, legend_names, scatterpoints=10000, loc='upper left',
88 | ncol=1, framealpha=0, prop={"weight": "bold", "size": 20})
89 |
90 | min_lat, min_lon, max_lat, max_lon = get_edges(coordinates)
91 | plt.ylim(min_lat, max_lat)
92 | plt.xlim(min_lon, max_lon)
93 | fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
94 | fig.canvas.draw()
95 | plot_img = np.array(fig.canvas.renderer._renderer)
96 | plt.close()
97 |
98 | plot_img = cv2.resize(plot_img[:, :, :3], map_img.shape[:2][::-1], interpolation=cv2.INTER_LANCZOS4)
99 | map_img[(map_img.sum(2) < 444)] = 188 # brighten dark pixels
100 | map_img = (((map_img / 255) ** map_intensity) * 255).astype(np.uint8) # fade map
101 | mask = (plot_img.sum(2) == 255 * 3)[:, :, None] # mask of plot, to find white pixels
102 | final_map = map_img * mask + plot_img * (~mask)
103 | return final_map
104 |
105 |
106 | def _get_coordinates_from_dataset(dataset_folder, extension="jpg"):
107 | """
108 | Takes as input the path of a dataset, such as "datasets/st_lucia/images"
109 | and returns
110 | [("train/database", [[45, 8.1], [45.2, 8.2]]), ("train/queries", [[45, 8.1], [45.2, 8.2]])]
111 | """
112 | images_paths = glob(join(dataset_folder, "**", f"*.{extension}"), recursive=True)
113 | if len(images_paths) != 0:
114 | print(f"I found {len(images_paths)} images in {dataset_folder}")
115 | else:
116 | raise ValueError(f"I found no images in {dataset_folder} !")
117 |
118 | grouped_gps_coords = defaultdict(list)
119 |
120 | for image_path in images_paths:
121 | full_path = os.path.dirname(image_path)
122 | full_parent_path, parent_dir = os.path.split(full_path)
123 | parent_parent_dir = os.path.split(full_parent_path)[1]
124 |
125 | # folder_name is for example "train - database"
126 | folder_name = " - ".join([parent_parent_dir, parent_dir])
127 |
128 | gps_lat = image_path.split("@")[6]
129 | gps_lon = image_path.split("@")[7] if not image_path.split("@")[7].endswith(f'{extension}') \
130 | else image_path.split("@")[7][:-len(extension)-1]
131 | gps_coords = gps_lat, gps_lon
132 | grouped_gps_coords[folder_name].append(gps_coords)
133 |
134 | grouped_gps_coords = sorted([(k, np.array(v).astype(np.float64))
135 | for k, v in grouped_gps_coords.items()])
136 | return grouped_gps_coords
137 |
138 |
139 | def build_map_from_dataset(dataset_folder, dot_sizes=None):
140 | """dataset_folder is the path that contains the 'images' folder."""
141 | grouped_gps_coords = _get_coordinates_from_dataset(join(dataset_folder))
142 | SORTED_FOLDERS = ["train - database", "train - queries", "val - database", "val - queries",
143 | "test - database", "test - queries"]
144 | try:
145 | grouped_gps_coords = sorted(grouped_gps_coords, key=lambda x: SORTED_FOLDERS.index(x[0]))
146 | except ValueError:
147 | pass # This dataset has different folder names than the standard train-val-test database-queries.
148 | coordinates = []
149 | legend_names = []
150 | for folder_name, coords in grouped_gps_coords:
151 | legend_names.append(f"{folder_name} - {len(coords)}")
152 | coordinates.append(coords)
153 |
154 | colors = cm.rainbow(np.linspace(0, 1, len(legend_names)))
155 | colors = ListedColormap(colors)
156 | colors = colors.colors
157 | if len(legend_names) == 1:
158 | legend_names = None # If there's only one folder, don't show the legend
159 | colors = np.array([[1, 0, 0]])
160 |
161 | map_img = _create_map(coordinates, colors, dot_sizes, legend_names)
162 |
163 | print(f"Map image resolution: {map_img.shape}")
164 | dataset_name = os.path.basename(os.path.abspath(dataset_folder))
165 | io.imsave(join(dataset_folder, f"map_{dataset_name}_final.png"), map_img)
166 |
167 | def main():
168 | dataset_folder = '/home/zafirshi/Datasets/YQ360'
169 | build_map_from_dataset(dataset_folder)
170 |
171 | if __name__ == '__main__':
172 | main()
173 |
--------------------------------------------------------------------------------
/tools/paper_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import seaborn as sns
3 | import matplotlib.pyplot as plt
4 | import matplotlib.cm as cm
5 | import numpy as np
6 |
7 |
8 | def results_comparison_in_graph(out_name: str):
9 | method = ['NetVLAD', 'Berton et.al', 'Orhan et.al', 'Swin-T', 'PanoVPR(Swin-T)',
10 | 'Swin-S', 'PanoVPR(Swin-S)', 'ConvNeXt-T', 'PanoVPR(ConvNeXt-T)',
11 | 'ConvNeXt-S', 'PanoVPR(ConvNeXt-S)']
12 | params = [7.23, 86.86, 136.62, 28.29, 28.29, 49.61, 49.61, 28.59, 28.59, 50.22, 50.22]
13 | r1 = [4.0, 8.0, 47.0, 10.1, 41.4, 12.4, 38.2, 9.7, 34.0, 14.2, 48.8]
14 |
15 | fig, ax = plt.subplots()
16 |
17 | for idx, (m,p,r) in enumerate(zip(method,params,r1)):
18 | if m == 'NetVLAD':
19 | ax.scatter(x=p, y=r, s=100, c='#1f77b4', label=m, alpha=0.5)
20 | elif m == 'Berton et.al':
21 | ax.scatter(x=p, y=r, s=100, c='#ff7f0e', label=m, alpha=0.5)
22 | elif m == 'Orhan et.al':
23 | ax.scatter(x=p, y=r, s=100, c='#2ca02c', label=m, alpha=0.5)
24 | elif m == 'Swin-T' or m == 'PanoVPR(Swin-T)':
25 | if m == 'PanoVPR(Swin-T)':
26 | ax.scatter(x=p, y=r, s=100, c='#d62728', label=m, marker='*', alpha=0.5)
27 | # add curve
28 | x_base, y_base = params[idx-1], r1[idx-1]
29 | x_end, y_end = p, r
30 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center",
31 | arrowprops=dict(color='#d62728',
32 | arrowstyle="-|>, head_length=1, head_width=0.4",
33 | linewidth=1,
34 | connectionstyle="arc3,rad=-0.3",
35 | linestyle='dashed',
36 | )
37 | )
38 | else:
39 | ax.scatter(x=p, y=r, s=100, c='#d62728', label=m, alpha=0.5)
40 | elif m == 'Swin-S' or m == 'PanoVPR(Swin-S)':
41 | if m == 'PanoVPR(Swin-S)':
42 | ax.scatter(x=p, y=r, s=100, c='#bcbd22', label=m, marker='*', alpha=0.5)
43 | # add curve
44 | x_base, y_base = params[idx-1], r1[idx-1]
45 | x_end, y_end = p, r
46 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center",
47 | arrowprops=dict(color='#bcbd22',
48 | arrowstyle="-|>, head_length=1, head_width=0.4",
49 | linewidth=1,
50 | connectionstyle="arc3,rad=-0.2",
51 | linestyle='dashed',
52 | )
53 | )
54 | else:
55 | ax.scatter(x=p, y=r, s=100, c='#bcbd22', label=m, alpha=0.5)
56 | elif m == 'ConvNeXt-T' or m =='PanoVPR(ConvNeXt-T)':
57 | if m == 'PanoVPR(ConvNeXt-T)':
58 | ax.scatter(x=p, y=r, s=100, c='#7f7f7f', label=m, marker='*', alpha=0.5)
59 | # add curve
60 | x_base, y_base = params[idx-1], r1[idx-1]
61 | x_end, y_end = p, r
62 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center",
63 | arrowprops=dict(color='#7f7f7f',
64 | arrowstyle="-|>, head_length=1, head_width=0.4",
65 | linewidth=1,
66 | connectionstyle="arc3,rad=0.2",
67 | linestyle='dashed',
68 | )
69 | )
70 | else:
71 | ax.scatter(x=p, y=r, s=100, c='#7f7f7f', label=m, alpha=0.5)
72 | elif m == 'ConvNeXt-S' or m == 'PanoVPR(ConvNeXt-S)':
73 | if m == 'PanoVPR(ConvNeXt-S)':
74 | ax.scatter(x=p, y=r, s=100, c='#e377c2', label=m, marker='*', alpha=0.5)
75 | # add curve
76 | x_base, y_base = params[idx-1], r1[idx-1]
77 | x_end, y_end = p, r
78 | ax.annotate("", xy=(x_end, y_end),xytext=(x_base, y_base),size=4, va="center", ha="center",
79 | arrowprops=dict(color='#e377c2',
80 | arrowstyle="-|>, head_length=1, head_width=0.4",
81 | linewidth=1,
82 | connectionstyle="arc3,rad=0.3",
83 | linestyle='dashed',
84 | )
85 | )
86 | else:
87 | ax.scatter(x=p, y=r, s=100, c='#e377c2', label=m, alpha=0.5)
88 |
89 |
90 | # add legend
91 | ax.legend(loc="center right", prop = {'size':8})
92 |
93 | plt.title('Performance comparison of different methods')
94 | plt.xlabel('Parameters(M)')
95 | plt.ylabel('R@1(%)')
96 | plt.grid(True)
97 |
98 | plt.savefig(out_name, dpi=300, bbox_inches='tight')
99 | plt.show()
100 |
101 |
102 | if __name__=="__main__":
103 | results_comparison_in_graph('test.png')
--------------------------------------------------------------------------------
/tools/util.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | import torch
4 | import shutil
5 | import logging
6 | import argparse
7 | import numpy as np
8 | from collections import OrderedDict
9 | from os.path import join
10 |
11 | from sklearn.decomposition import PCA
12 |
13 |
14 | from torch import nn as nn
15 |
16 | from timm.models.layers.helpers import to_2tuple
17 | from timm.models.layers.trace_utils import _assert
18 | from mmcv.cnn import get_model_complexity_info
19 |
20 | import datasets_ws
21 | import parser
22 | from model import network
23 |
24 |
25 | def get_flops(input_shape=(224, 224)):
26 | # Initial setup: parser
27 | args = parser.parse_arguments()
28 |
29 | # Initialize model
30 | model = network.GeoLocalizationNet(args)
31 | model = model.to(args.device)
32 |
33 | # input shape
34 | input_shape = (3, input_shape[0], input_shape[1])
35 |
36 | model.eval()
37 |
38 | flops, params = get_model_complexity_info(model, input_shape)
39 | split_line = '=' * 30
40 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
41 | split_line, input_shape, flops, params))
42 | print('!!!Please be cautious if you use the results in papers. '
43 | 'You may need to check if all ops are supported and verify that the '
44 | 'flops computation is correct.')
45 |
46 |
47 | def save_checkpoint(args, state, is_best, filename):
48 | model_path = join(args.save_dir, filename)
49 | torch.save(state, model_path)
50 | if is_best:
51 | shutil.copyfile(model_path, join(args.save_dir, "best_model.pth"))
52 |
53 |
54 | def resume_model(args, model):
55 | checkpoint = torch.load(args.resume, map_location=args.device)
56 | if 'model_state_dict' in checkpoint:
57 | state_dict = checkpoint['model_state_dict']
58 | else:
59 |         # The pre-trained models that we provide in the README do not have 'model_state_dict' in the keys as
60 | # the checkpoint is directly the state dict
61 | state_dict = checkpoint
62 |     # if the model contains the prefix "module", which is appended by
63 |     # DataParallel, remove it to avoid errors when loading the dict
64 | if list(state_dict.keys())[0].startswith('module'):
65 | state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in state_dict.items()})
66 | model.load_state_dict(state_dict)
67 | return model
68 |
69 |
70 | def resume_train(args, model, optimizer=None, strict=False):
71 | """Load model, optimizer, and other training parameters"""
72 | logging.debug(f"Loading checkpoint: {args.resume}")
73 | checkpoint = torch.load(args.resume)
74 | start_epoch_num = checkpoint["epoch_num"]
75 | model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
76 | if optimizer:
77 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
78 | best_r5 = checkpoint["best_r5"]
79 | not_improved_num = checkpoint["not_improved_num"]
80 | logging.debug(f"Loaded checkpoint: start_epoch_num = {start_epoch_num}, " \
81 | f"current_best_R@5 = {best_r5:.1f}")
82 | if args.resume.endswith("last_model.pth"): # Copy best model to current save_dir
83 | shutil.copy(args.resume.replace("last_model.pth", "best_model.pth"), args.save_dir)
84 | return model, optimizer, best_r5, start_epoch_num, not_improved_num
85 |
86 |
87 | def compute_pca(args, model, pca_dataset_folder, full_features_dim):
88 | model = model.eval()
89 | pca_ds = datasets_ws.PCADataset(args, args.datasets_folder, pca_dataset_folder)
90 | dl = torch.utils.data.DataLoader(pca_ds, args.infer_batch_size, shuffle=True)
91 | pca_features = np.empty([min(len(pca_ds), 2**14), full_features_dim])
92 | with torch.no_grad():
93 | for i, images in enumerate(dl):
94 | if i*args.infer_batch_size >= len(pca_features): break
95 | features = model(images).cpu().numpy()
96 | pca_features[i*args.infer_batch_size : (i*args.infer_batch_size)+len(features)] = features
97 | pca = PCA(args.pca_dim)
98 | pca.fit(pca_features)
99 | return pca
100 |
101 |
102 | class PatchEmbed(nn.Module):
103 | """ 2D Image to Patch Embedding
104 | """
105 | def __init__(
106 | self,
107 | img_size=224,
108 | patch_size=16,
109 | in_chans=3,
110 | embed_dim=768,
111 | norm_layer=None,
112 | flatten=True,
113 | bias=True,
114 | ):
115 | super().__init__()
116 | img_size = to_2tuple(img_size)
117 | patch_size = to_2tuple(patch_size)
118 | self.img_size = img_size
119 | self.patch_size = patch_size
120 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
121 | self.num_patches = self.grid_size[0] * self.grid_size[1]
122 | self.flatten = flatten
123 |
124 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
125 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
126 |
127 | def forward(self, x):
128 | B, C, H, W = x.shape
129 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
130 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
131 | x = self.proj(x)
132 | if self.flatten:
133 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
134 | x = self.norm(x)
135 | return x
136 |
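137 | 
138 | # Illustrative sketch (not used elsewhere in this repository): pushing a dummy image through PatchEmbed
139 | # to show the BCHW -> BNC flattening. The batch size and random input below are assumptions made purely
140 | # for this example.
141 | if __name__ == '__main__':
142 |     _patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
143 |     _dummy = torch.rand(2, 3, 224, 224)
144 |     _tokens = _patch_embed(_dummy)
145 |     print(_tokens.shape)  # torch.Size([2, 196, 768]) since (224 / 16) ** 2 = 196 patches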
--------------------------------------------------------------------------------
/tools/visual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from matplotlib import pyplot as plt
4 | from tqdm import tqdm
5 |
6 |
7 | def path_to_pil_img(path):
8 | return Image.open(path).convert("RGB")
9 |
10 |
11 | def add_mask(train_set, img, img_path_idx, win_len=448):
12 | pos_mask = np.zeros(train_set.database_resize, dtype=np.float64) + (100 / 255)
13 | pos_w = img.shape[1] # 448, 448*8, 3
14 | mask_left = (pos_w / train_set.split_nums) * img_path_idx
15 | mask_right = mask_left + win_len
16 |
17 | if mask_left < mask_right <= pos_w:
18 | pos_mask[:, int(mask_left):int(mask_right)] = 0
19 |     elif mask_left < pos_w < mask_right:  # window wraps around the right edge of the panorama
20 |         pos_mask[:, int(mask_left):] = 0; pos_mask[:, :int(mask_right - pos_w)] = 0
21 | else:
22 | raise Exception('Adding mask on img goes WRONG!')
23 | pos_mask_3d = np.stack((pos_mask, pos_mask, pos_mask), axis=2)
24 |
25 | normal_img = (img / 255.).astype(np.float64)
26 | img = np.clip(normal_img - pos_mask_3d, 0, 1)
27 | img = (img * 255).astype(np.uint8)
28 |
29 | return img
30 |
31 |
32 | def display_mining(train_set, triplets_global_indexes_array, save_path, pos_patch_loc, neg_patch_loc):
33 | full_images_path = np.array(train_set.images_paths)
34 | db_num = train_set.database_num
35 | for iter_idx, triple in enumerate(tqdm(triplets_global_indexes_array, ncols=100, desc='Show Mining Results')):
36 | if iter_idx % 50 == 0:
37 | query_idx, pos_idx, neg_idx = triple[0], triple[1], triple[2:]
38 | query_img_path = str(full_images_path[query_idx + db_num])
39 | pos_img_path = str(full_images_path[pos_idx])
40 | neg_img_path = full_images_path[neg_idx].tolist()
41 |
42 | query = np.array(path_to_pil_img(query_img_path).resize((448, 448)))
43 | pos = np.array(path_to_pil_img(pos_img_path).resize((3584, 448)))
44 | negs = list(map(lambda x: np.array(path_to_pil_img(x).resize((3584, 448))), neg_img_path))
45 | black = np.zeros((448 * 10, 448, 3), dtype=np.uint8)
46 |
47 | # mask
48 | pos = add_mask(train_set, pos, pos_patch_loc[iter_idx].item())
49 | for i, each_neg in enumerate(negs):
50 | negs[i] = add_mask(train_set, each_neg, neg_patch_loc[iter_idx][i].item())
51 |
52 | right = np.concatenate([pos] + negs, axis=0)
53 | left = np.concatenate((query, black), axis=0)
54 | one_in_all = np.concatenate((left, right), axis=1)
55 |
56 | plt.imsave(save_path + f'iter{iter_idx}.png', one_in_all)
57 |
58 | with open(save_path + 'record.txt', 'a+') as f:
59 | f.write(f'======> iter_idx:{iter_idx}/{train_set.queries_num} <======\n')
60 | f.write(f'Query_Path: {query_img_path}\n')
61 | f.write(f'Pos_Path: {pos_img_path}\n')
62 | for each_neg_path in neg_img_path:
63 | f.write(f'Neg_Path: {each_neg_path}\n')
64 |
65 |
66 | def display_inference(eval_ds, predictions, save_path, focus_patch_loc):
67 | full_images_path = np.array(eval_ds.images_paths)
68 | db_num = eval_ds.database_num
69 | for query_idx, each_query_pred in enumerate(tqdm(predictions, ncols=100, desc='Show Inference Results')):
70 | if query_idx % 50 == 0:
71 | query_img_path = str(full_images_path[query_idx + db_num])
72 |             db_img_path = full_images_path[each_query_pred[:5]].tolist()  # take the top-5 predictions for display
73 |
74 | query = np.array(path_to_pil_img(query_img_path).resize((448, 448)))
75 | db = list(map(lambda x: np.array(path_to_pil_img(x).resize((3584, 448))), db_img_path))
76 | black = np.zeros((448 * 4, 448, 3), dtype=np.uint8)
77 |
78 | # add mask
79 | for i, each_db in enumerate(db):
80 | db[i] = add_mask(eval_ds, each_db, focus_patch_loc[query_idx][i].item())
81 |
82 | right = np.concatenate(db, axis=0)
83 | left = np.concatenate((query, black), axis=0)
84 | one_in_all = np.concatenate((left, right), axis=1)
85 |
86 | # save img
87 | plt.imsave(save_path + f'query_idx{query_idx}.png', one_in_all)
88 |
89 | with open(save_path + 'inference_list.txt', 'a+') as f:
90 | f.write(f'======> query_idx:{query_idx}/{eval_ds.queries_num} <======\n')
91 | f.write(f'Query_Path: {query_img_path}\n')
92 | for each_db_path in db_img_path:
93 | f.write(f'Pred_db_Path: {each_db_path}\n')
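94 | 
95 | 
96 | # Illustrative sketch (not called by the pipeline): add_mask keeps one sliding window of a panorama bright
97 | # and darkens the rest. The SimpleNamespace below only mimics the two dataset attributes that add_mask
98 | # reads; its values and the random panorama are assumptions made for this example.
99 | if __name__ == '__main__':
100 |     from types import SimpleNamespace
101 | 
102 |     _fake_ds = SimpleNamespace(database_resize=(448, 3584), split_nums=16)
103 |     _pano = np.random.randint(0, 255, size=(448, 3584, 3), dtype=np.uint8)
104 |     _masked = add_mask(_fake_ds, _pano, img_path_idx=3, win_len=448)
105 |     plt.imsave('masked_window_demo.png', _masked)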
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import math
4 | import torch
5 | import logging
6 | import numpy as np
7 | from matplotlib import pyplot as plt
8 | from torch.utils import data
9 | from tqdm import tqdm
10 | import torch.nn as nn
11 | import multiprocessing
12 | from os.path import join
13 | from datetime import datetime
14 | import torchvision.transforms as transforms
15 | from torch.utils.data.dataloader import DataLoader
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.cuda.amp import autocast, GradScaler
18 |
19 | from tools import util
20 | import test
21 | import parser
22 | from tools import commons
23 | import datasets_ws
24 | from model import network
25 | from model.sync_batchnorm import convert_model
26 | from model.functional import sare_ind, sare_joint
27 | from tools.visual import display_mining
28 | from tools.loss import shift_window_triple_loss
29 |
30 |
31 | def train(args,
32 | model: nn.Module,
33 | train_set: datasets_ws.TripletsDataset,
34 | loss_fn: nn.TripletMarginLoss,
35 | optimizer,
36 | val_set: datasets_ws.BaseDataset,
37 | writer,
38 | start_epoch_num=0,
39 | best_r5=0,
40 | not_improved_num=0,
41 | show_mining_triplet_img=None,
42 | visual_mining_save_path=None,
43 | show_inference_results=None,
44 | visual_val_save_path=None
45 | ):
46 | """
47 | Trains a given model using triplet loss and performs validation.
48 |
49 | Args:
50 | args: A namespace or dictionary containing training parameters.
51 | model (nn.Module): The neural network model to train.
52 | train_set (datasets_ws.TripletsDataset): The training dataset containing triplets.
53 | loss_fn (nn.TripletMarginLoss): The loss function to optimize.
54 | optimizer: The optimization algorithm.
55 | val_set (datasets_ws.BaseDataset): The validation dataset.
56 | writer: A summary writer object for logging.
57 | start_epoch_num (int): The starting epoch number for training.
58 | best_r5 (float): The best recall@5 score obtained so far.
59 | not_improved_num (int): Counter for epochs without improvement.
60 | show_mining_triplet_img (callable, optional): Function to visualize mining results.
61 | visual_mining_save_path (str, optional): Path to save visual mining results.
62 | show_inference_results (callable, optional): Function to visualize inference results.
63 | visual_val_save_path (str, optional): Path to save validation visualization results.
64 |
65 | The training process includes mining hard triplets, calculating loss, and updating the model parameters.
66 | Validation is performed at the end of each epoch to monitor the recall metrics and early stopping is applied
67 | based on the recall@5 metric.
68 | """
69 |
70 | scaler = GradScaler()
71 |
72 | # Training loop
73 | for epoch_num in range(start_epoch_num, args.epochs_num):
74 | logging.info(f"Start training epoch: {epoch_num:02d}")
75 |
76 | epoch_start_time = datetime.now()
77 | epoch_losses = np.zeros((0, 1), dtype=np.float32)
78 |
79 | # How many loops should an epoch last
80 | loops_num = math.ceil(args.queries_per_epoch / args.cache_refresh_rate)
81 | for loop_num in range(loops_num):
82 | logging.debug(f"Cache: {loop_num} / {loops_num}")
83 |
84 | # Compute triplets to use in the triplet loss
85 | train_set.is_inference = True
86 | train_set.compute_triplets(args, model)
87 | train_set.is_inference = False
88 |
89 | # Visualizing mining results.
90 | if show_mining_triplet_img:
91 | os.makedirs(f'{visual_mining_save_path}' + f'mining_epoch{epoch_num}/', exist_ok=True)
92 | save_path = f'{visual_mining_save_path}' + f'mining_epoch{epoch_num}/loopNum{loop_num}_'
93 | triplets_global_indexes_array = train_set.triplets_global_indexes.numpy()
94 | # Obtaining the sub-window focused on by the model.
95 |                 pos_patch_loc, neg_patch_loc = train_set.pos_focus_patch, train_set.neg_focus_patch
96 | display_mining(train_set, triplets_global_indexes_array, save_path, pos_patch_loc, neg_patch_loc)
97 |
98 | triplets_dl = DataLoader(dataset=train_set, num_workers=args.num_workers,
99 | batch_size=args.train_batch_size,
100 | pin_memory=(args.device == "cuda"),
101 | drop_last=True)
102 |
103 | model = model.train()
104 |
105 | cache_losses = np.zeros((0, 1), dtype=np.float32)
106 | batch_nums = args.cache_refresh_rate / args.train_batch_size
107 |
108 | # images shape: (train_batch_size*12)*3*H*W
109 | for batch_idx, (query, pano_database) in enumerate(tqdm(triplets_dl, ncols=100, desc='Training')):
110 | # query shape:(B, 3, 224, 224) || pano_database shape:(B, 11, 3, 224, 224*8)
111 | pano_database_4d = torch.flatten(pano_database, end_dim=1) # B*11, 3, 224, 224*8
112 |
113 | pano_split_list = []
114 | for pano_full in pano_database_4d:
115 | pano_split = datasets_ws.shift_window_on_img(pano_full, train_set.split_nums,
116 | train_set.window_stride, train_set.window_len)
117 |
118 | pano_split_list.append(pano_split)
119 |
120 | pano_database_4d = torch.flatten(torch.stack(pano_split_list), end_dim=1)
121 |
122 | optimizer.zero_grad()
123 | with autocast():
124 | database_feature = model(pano_database_4d.to(args.device)).view(args.train_batch_size,
125 | 1 + args.negs_num_per_query, -1)
126 | query_feature = model(query.to(args.device))
127 |
128 | loss_triplet = 0
129 | loss = 0
130 | for idx in range(args.train_batch_size):
131 | loss, min_index_in_row = shift_window_triple_loss(args, query_feature[idx], database_feature[idx], loss_fn)
132 | loss_triplet += loss
133 | del database_feature, query_feature
134 |
135 | loss_triplet /= (args.train_batch_size * args.negs_num_per_query)
136 |
137 | scaler.scale(loss_triplet).backward()
138 |
139 | scaler.unscale_(optimizer)
140 | # clip_grad 1,3,5,7,10
141 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
142 |
143 | scaler.step(optimizer)
144 | scaler.update()
145 |
146 | # Keep track of all losses by appending them to epoch_losses
147 | batch_loss = loss_triplet.item()
148 |
149 | writer.add_scalar('Loss/Batch_Loss', batch_loss,
150 | epoch_num * loops_num * batch_nums + loop_num * batch_nums + batch_idx)
151 |
152 | cache_losses = np.append(cache_losses, batch_loss)
153 | del loss_triplet
154 |
155 | logging.debug(f"Epoch[{epoch_num:02d}]({loop_num}/{loops_num}): " +
156 | f"latest batch triplet loss = {batch_loss:.4f}, " +
157 | f"current cache triplet loss = {cache_losses.mean():.4f}")
158 | writer.add_scalar('Loss/Cache_Loss', cache_losses.mean().item(), epoch_num * loops_num + loop_num)
159 |
160 | # epoch_losses should update after calculating cache_loss
161 | epoch_losses = np.append(epoch_losses, cache_losses)
162 |
163 | logging.info(f"Finished epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, "
164 | f"average epoch triplet loss = {epoch_losses.mean():.4f}")
165 | writer.add_scalar('Loss/Epoch_Loss', epoch_losses.mean().item(), epoch_num)
166 |
167 | # Compute recalls on validation set
168 | recalls, recalls_str = test.test(args, val_set, model,
169 | show_inference_results=show_inference_results,
170 | save_path=visual_val_save_path)
171 | logging.info(f"Recalls on val set {val_set}: {recalls_str}")
172 |
173 | writer.add_scalar('Recall/@1', recalls[0], epoch_num)
174 | writer.add_scalar('Recall/@5', recalls[1], epoch_num)
175 | writer.add_scalar('Recall/@10', recalls[2], epoch_num)
176 | writer.add_scalar('Recall/@20', recalls[3], epoch_num)
177 |
178 | is_best = recalls[1] > best_r5
179 |
180 | # Save checkpoint, which contains all training parameters
181 | util.save_checkpoint(args, {"epoch_num": epoch_num, "model_state_dict": model.state_dict(),
182 | "optimizer_state_dict": optimizer.state_dict(), "recalls": recalls,
183 | "best_r5": best_r5,
184 | "not_improved_num": not_improved_num
185 | }, is_best, filename="last_model.pth")
186 |
187 | # If recall@5 did not improve for "many" epochs, stop training
188 | if is_best:
189 | logging.info(f"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}")
190 | best_r5 = recalls[1]
191 | not_improved_num = 0
192 | else:
193 | not_improved_num += 1
194 | logging.info(
195 | f"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}")
196 | if not_improved_num >= args.patience:
197 | logging.info(f"Performance did not improve for {not_improved_num} epochs. Stop training.")
198 | break
199 |
200 | logging.info(f"Trained for {epoch_num + 1:02d} epochs, Best R@5: {best_r5:.1f}")
201 |
202 |
203 | def main():
204 | # Initial setup: parser, logging...
205 | args = parser.parse_arguments()
206 | start_time = datetime.now()
207 | args.save_dir = join("logs", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S') + '_' + args.title)
208 | commons.setup_logging(args.save_dir)
209 |     commons.make_deterministic(args.seed, speedup=False)  # speedup=False makes results reproducible
210 | logging.info(f"Arguments: {args}")
211 | logging.info(f"The outputs are being saved in {args.save_dir}")
212 | logging.info(f"Using {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs")
213 |
214 | # Creation of Datasets
215 | logging.debug(f"Loading dataset {args.dataset_name} from folder {args.datasets_folder}")
216 |
217 | triplets_ds = datasets_ws.TripletsDataset(args, args.datasets_folder, args.dataset_name, "train",
218 | args.negs_num_per_query)
219 | logging.info(f"Train query set: {triplets_ds}")
220 |
221 | val_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "val")
222 | logging.info(f"Val set: {val_ds}")
223 |
224 | test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test")
225 | logging.info(f"Test set: {test_ds}")
226 |
227 | #### Initialize model
228 | model = network.GeoLocalizationNet(args)
229 | model = model.to(args.device)
230 | model = torch.nn.DataParallel(model)
231 |
232 | #### Setup Optimizer and Loss
233 | if args.optim == "adam":
234 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
235 | elif args.optim == "sgd":
236 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001)
237 |
238 | criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction="sum")
239 |
240 | #### Resume model, optimizer, and other training parameters
241 | if args.resume:
242 | model, optimizer, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, optimizer)
243 | logging.info(f"Resuming from epoch {start_epoch_num} with best recall@5 {best_r5:.1f}")
244 | else:
245 | best_r5 = start_epoch_num = not_improved_num = 0
246 |
247 | if torch.cuda.device_count() >= 2:
248 |         # When using more than one GPU, use sync_batchnorm with torch.nn.DataParallel
249 | model = convert_model(model)
250 | model = model.cuda()
251 |
252 | # Add tensorboard monitor
253 | writer = SummaryWriter(log_dir=args.save_dir)
254 |
255 | # Train model on train set and validate model on validation set every train epoch
256 | train(args,
257 | model,
258 | train_set=triplets_ds, loss_fn=criterion_triplet, optimizer=optimizer, val_set=val_ds,
259 | writer=writer, start_epoch_num=start_epoch_num, best_r5=best_r5, not_improved_num=not_improved_num,
260 | show_mining_triplet_img=False,
261 | visual_mining_save_path=args.train_visual_save_path,
262 | show_inference_results=False,
263 | visual_val_save_path=args.val_visual_save_path)
264 | logging.info(f"Trained total in {str(datetime.now() - start_time)[:-7]}")
265 |
266 | #### Test best model on test set
267 | best_model_state_dict = torch.load(join(args.save_dir, "best_model.pth"))["model_state_dict"]
268 | model.load_state_dict(best_model_state_dict)
269 |
270 | recalls, recalls_str = test.test(args, test_ds, model, test_method=args.test_method,
271 | show_inference_results=False,
272 | save_path=args.test_visual_save_path)
273 | logging.info(f"Recalls on {test_ds}: {recalls_str}")
274 | # show test results in one graph
275 | for i in range(len(recalls)):
276 | writer.add_scalar('Test', recalls[i], i + 1)
277 | writer.close()
278 |
279 |
280 | if __name__ == '__main__':
281 | main()
282 |
--------------------------------------------------------------------------------