├── .gitignore ├── LICENSE ├── README.md ├── assets └── architecture.jpg ├── cache_dataset_datastruct.py ├── evaluation.py ├── jist ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── test_dataset.py │ └── train_dataset.py ├── evals │ ├── __init__.py │ ├── cp_test.py │ └── test.py ├── models │ ├── __init__.py │ └── network.py └── utils │ ├── __init__.py │ ├── augmentations.py │ ├── commons.py │ ├── cosface_loss.py │ ├── cp_utils.py │ ├── data.py │ ├── logging.py │ ├── parser.py │ └── utils.py ├── requirements.txt └── train_double_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | *.ipynb_checkpoints 3 | *.ipynb 4 | mytest.py 5 | saved_objects/ 6 | downloaded/ 7 | *secret* 8 | logs/ 9 | cache 10 | *corrupt* 11 | *fake* 12 | 2legion* 13 | out.txt 14 | vit_legion.job 15 | ign*.* 16 | test 17 | *.pkl 18 | testing*.py 19 | *.npy 20 | f_*launcher.py 21 | jobs/ 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # Mac and PyCharm files 154 | .DS_Store 155 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gabriele Trivigno, Gabriele Berton 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JIST: Joint Image and Sequence Training for Sequential Visual Place Recognition 2 | 3 |

4 | ![Overview of the JIST framework](assets/architecture.jpg)
5 | *Overview of the JIST framework.* 6 |

7 | 8 | This is the official repository for the paper "[JIST: Joint Image and Sequence Training for Sequential Visual Place Recognition](https://ieeexplore.ieee.org/document/10339796)". It can be used to reproduce the results of the paper, and to experiment with different aggregation methods for sequential descriptors in Visual Place Recognition. 9 | 10 | ## Install Locally 11 | Create your local environment and then install the required packages using: 12 | ``` bash 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Datasets 17 | The experiments in the paper use two main datasets: Mapillary Street-Level Sequences (MSLS) and Oxford RobotCar. 18 | To download them, you can refer to the repository of our previous paper here. 19 | - For MSLS, download the dataset from the official website and format it using the instructions at the link above. 20 | - For RobotCar, the repository linked above provides an already pre-processed version of the dataset via Google Drive. 21 | 22 | In this work we also use the SF-XL dataset, which in total is about 1 TB, although we only use the `processed` subset as done in CosPlace, which is around 400 GB. More information is available in [CosPlace](https://github.com/gmberton/CosPlace). 23 | You can request the download, specifying the processed version of the dataset, using the form [_here_](https://forms.gle/wpyDzhDyoWLQygAT9). 24 | 25 | ## Run Experiments 26 | Once the datasets are ready, you can run the experiments with the architecture of your choice. 27 | 28 | **NB**: building MSLS sequences requires some heavy pre-processing to create the needed data structures. The dataset class caches them automatically, 29 | so they are computed only the first time. The first experiment that you ever launch will therefore take 1-2 hours to build these structures, which are 30 | saved in a `cache` directory; subsequent experiments will then start quickly. Note that this procedure caches everything with relative paths, 31 | so if you want to run experiments on multiple machines you can simply copy the `cache` directory. 32 | Finally, note that these data structures must be computed for each sequence length, so `cache` will potentially contain one file for each sequence length 33 | that you want to experiment with. You can also precompute them with the following command: 34 | ```bash 35 | python cache_dataset_datastruct.py --msls_folder /path/to/msls --seq_len SL 36 | ``` 37 | 38 | To replicate our results you can train a model as follows: 39 | 40 | ```bash 41 | python train_double_dataset.py --exp_name exp_name \ 42 | --dataset_folder /path/to/sf_xl/processed \ 43 | --seq_dataset_path /path/to/msls \ 44 | --aggregation_type seqgem 45 | ``` 46 | 47 | ### Evaluate trained models 48 | It is possible to evaluate the trained models using: 49 | ``` bash 50 | python evaluation.py \ 51 | --resume_model /path/to/best_model.pth --seq_dataset_path /path/to/dataset 52 | ``` 53 | 54 | ### Download trained models 55 | You can download our JIST model with a ResNet-18 backbone and SeqGeM aggregation from [this Google Drive link](https://drive.google.com/file/d/1F6eVrR-0LseE-tfbT8Y8WT92_O11lt5o/view?usp=sharing). 
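If you want to use the downloaded checkpoint outside of `evaluation.py`, the sketch below shows how it could be loaded and applied to a single sequence of frames. It is a minimal example, assuming the checkpoint is a plain `state_dict` as loaded in `evaluation.py`; the `args` namespace and the filename `best_model.pth` are hypothetical, and the values (backbone, descriptor dimension, sequence length) must match those used to train the downloaded model.

```python
import torch
from types import SimpleNamespace

from jist.models import JistModel

# Hypothetical arguments: adjust them to match the downloaded checkpoint.
args = SimpleNamespace(backbone="resnet18", fc_output_dim=512, seq_length=5)

model = JistModel(args, agg_type="seqgem")
state_dict = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# One sequence of seq_length frames -> one descriptor of size model.aggregation_dim.
frames = torch.rand(args.seq_length, 3, 480, 640)       # (seq_len, 3, H, W)
with torch.no_grad():
    frame_features = model(frames)                       # (seq_len, fc_output_dim)
    seq_descriptor = model.aggregate(frame_features)     # (1, aggregation_dim)
print(seq_descriptor.shape)
```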
56 | 57 | ## Acknowledgements 58 | 59 | Parts of this repo are inspired by the following repositories: 60 | - [SeqVLAD](https://github.com/vandal-vpr/vg-transformers) 61 | - [CosFace](https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py) 62 | - [CosPlace](https://github.com/gmberton/CosPlace) 63 | 64 | ## Cite 65 | Here is the bibtex to cite our paper 66 | 67 | ``` 68 | @ARTICLE{Berton_2023_Jist, 69 | author={Berton, Gabriele and Trivigno, Gabriele and Caputo, Barbara and Masone, Carlo}, 70 | journal={IEEE Robotics and Automation Letters}, 71 | title={JIST: Joint Image and Sequence Training for Sequential Visual Place Recognition}, 72 | year={2023}, 73 | volume={}, 74 | number={}, 75 | pages={1-8}, 76 | doi={10.1109/LRA.2023.3339058} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /assets/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ga1i13o/JIST/02014f8934fdd0203907dad41a2ead1f7484298f/assets/architecture.jpg -------------------------------------------------------------------------------- /cache_dataset_datastruct.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from jist.datasets import TrainDataset 3 | 4 | 5 | def main(args): 6 | folder = args.msls_folder 7 | seq_len = args.seq_len 8 | 9 | print(f'Caching dataset with seq len {seq_len}') 10 | triplets_ds = TrainDataset(cities='', dataset_folder=folder, split='train', 11 | seq_len=seq_len, pos_thresh=10, neg_thresh=25) 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description="Sequence Visual Geolocalization", 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--msls_folder', type=str, required=True) 18 | parser.add_argument('--seq_len', type=int, default=5) 19 | args = parser.parse_args() 20 | 21 | main(args) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | import torch 4 | 5 | from jist.datasets import BaseDataset 6 | from jist import utils, evals 7 | from jist.models import JistModel 8 | 9 | 10 | def evaluation(args): 11 | start_time = datetime.now() 12 | args.output_folder = f"test/{args.exp_name}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}" 13 | utils.setup_logging(args.output_folder, console="info") 14 | logging.info(f"Arguments: {args}") 15 | logging.info(f"The outputs are being saved in {args.output_folder}") 16 | 17 | ### Definition of the model 18 | model = JistModel(args, agg_type=args.aggregation_type) 19 | 20 | if args.resume_model != None: 21 | logging.debug(f"Loading model from {args.resume_model}") 22 | model_state_dict = torch.load(args.resume_model) 23 | model.load_state_dict(model_state_dict) 24 | 25 | model = model.to(args.device) 26 | 27 | meta = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} 28 | img_shape = (args.img_shape[0], args.img_shape[1]) 29 | transform = utils.configure_transform(image_dim=img_shape, meta=meta) 30 | 31 | eval_ds = BaseDataset(dataset_folder=args.seq_dataset_path, split='test', 32 | base_transform=transform, seq_len=args.seq_length, 33 | pos_thresh=args.val_posDistThr, reverse_frames=args.reverse) 34 | logging.info(f"Test set: {eval_ds}") 35 | 36 | logging.info(f"Backbone output channels are {model.features_dim}, 
features descriptor dim is {model.fc_output_dim}, " 37 | f"sequence descriptor dim is {model.aggregation_dim}") 38 | 39 | _, recalls_str = evals.test(args, eval_ds, model) 40 | logging.info(f"Recalls on test set: {recalls_str}") 41 | logging.info(f"Finished in {str(datetime.now() - start_time)[:-7]}") 42 | 43 | 44 | if __name__ == "__main__": 45 | args = utils.parse_arguments() 46 | evaluation(args) 47 | -------------------------------------------------------------------------------- /jist/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | from . import models 3 | from . import evals 4 | from . import utils 5 | -------------------------------------------------------------------------------- /jist/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['dataset', 'train_dataset', 'test_dataset'] 2 | 3 | 4 | from .dataset import BaseDataset, TrainDataset, collate_fn 5 | from .train_dataset import CosplaceTrainDataset 6 | from .test_dataset import TestDataset 7 | -------------------------------------------------------------------------------- /jist/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import faiss 4 | import logging 5 | import numpy as np 6 | from glob import glob 7 | from tqdm import tqdm 8 | from PIL import Image 9 | from os.path import join 10 | import torch.utils.data as data 11 | from torch.utils.data.dataset import Subset 12 | from sklearn.neighbors import NearestNeighbors 13 | from torch.utils.data.dataloader import DataLoader 14 | 15 | from jist.utils import RAMEfficient2DMatrix 16 | 17 | 18 | def collate_fn(batch): 19 | """Creates mini-batch tensors from the list of tuples (images, 20 | triplets_local_indexes, triplets_global_indexes). 21 | triplets_local_indexes are the indexes referring to each triplet within images. 22 | triplets_global_indexes are the global indexes of each image. 23 | Args: 24 | batch: list of tuple (images, triplets_local_indexes, triplets_global_indexes). 25 | considering each query to have 10 negatives (negs_num_per_query=10): 26 | - images: torch tensor of shape (12, 3, h, w). 27 | - triplets_local_indexes: torch tensor of shape (10, 3). 28 | - triplets_global_indexes: torch tensor of shape (12). 29 | Returns: 30 | images: torch tensor of shape (batch_size*12, 3, h, w). 31 | triplets_local_indexes: torch tensor of shape (batch_size*10, 3). 32 | triplets_global_indexes: torch tensor of shape (batch_size, 12). 
33 | """ 34 | images = torch.cat([e[0] for e in batch]) 35 | triplets_local_indexes = torch.cat([e[1][None] for e in batch]) 36 | triplets_global_indexes = torch.cat([e[2][None] for e in batch]) 37 | for i, (local_indexes, global_indexes) in enumerate(zip(triplets_local_indexes, triplets_global_indexes)): 38 | local_indexes += len(global_indexes) * i # Increment local indexes by offset (len(global_indexes) is 12) 39 | return images, torch.cat(tuple(triplets_local_indexes)), triplets_global_indexes 40 | 41 | 42 | class BaseDataset(data.Dataset): 43 | def __init__(self, cities='', dataset_folder="datasets", split="train", base_transform=None, 44 | seq_len=3, pos_thresh=25, neg_thresh=25, reverse_frames=False): 45 | super().__init__() 46 | self.dataset_folder = join(dataset_folder, split) 47 | self.seq_len = seq_len 48 | 49 | # the algorithm to create sequences works with odd-numbered sequences only 50 | cut_last_frame = ((seq_len % 2) == 0) 51 | if cut_last_frame: 52 | self.seq_len += 1 53 | self.base_transform = base_transform 54 | 55 | if not os.path.exists(self.dataset_folder): raise FileNotFoundError( 56 | f"Folder {self.dataset_folder} does not exist") 57 | 58 | self.init_data(cities, split, pos_thresh, neg_thresh) 59 | if reverse_frames: 60 | self.db_paths = [",".join(path.split(',')[::-1]) for path in self.db_paths] 61 | self.images_paths = self.db_paths + self.q_paths 62 | self.database_num = len(self.db_paths) 63 | self.queries_num = len(self.qIdx) 64 | 65 | if cut_last_frame: 66 | self.__cut_last_frame() 67 | 68 | def init_data(self, cities, split, pos_thresh, neg_thresh): 69 | if ('msls' in self.dataset_folder) and (split == 'train'): 70 | os.makedirs('cache', exist_ok=True) 71 | cache_file = f'cache/msls_seq{self.seq_len}_{cities}.torch' 72 | if os.path.isfile(cache_file): 73 | logging.info(f'Loading cached data from {cache_file}...') 74 | cache_dict = torch.load(cache_file) 75 | self.__dict__.update(cache_dict) 76 | return 77 | else: 78 | logging.info('Data structures were not cached, building them now...') 79 | #### Read paths and UTM coordinates for all images. 
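# NOTE (added summary, not part of the original code): the steps below do the following:
# 1) build_sequences() groups the frames of each recording into fixed-length sequences,
#    keeping only windows whose frame numbers are strictly consecutive;
# 2) UTM coordinates are parsed from the "@utm_east@utm_north@..." fields of the file names;
# 3) a NearestNeighbors radius search over the database UTMs finds, for every query frame,
#    the database frames within pos_thresh meters (hard positives) and, for the train split,
#    within neg_thresh meters (soft positives, later excluded from negative mining);
# 4) the frame-level matches are mapped back to sequence indices (qIdx, pIdx, nonNegIdx),
#    and for the MSLS train split the result is cached under cache/ for later runs.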
80 | database_folder = join(self.dataset_folder, "database") 81 | queries_folder = join(self.dataset_folder, "queries") 82 | 83 | cities = cities 84 | self.db_paths, all_db_paths, db_idx_frame_to_seq = build_sequences(database_folder, 85 | seq_len=self.seq_len, cities=cities, 86 | desc='loading database...') 87 | self.q_paths, all_q_paths, q_idx_frame_to_seq = build_sequences(queries_folder, 88 | seq_len=self.seq_len, cities=cities, 89 | desc='loading queries...') 90 | 91 | q_unique_idxs = np.unique([idx for seq_frames_idx in q_idx_frame_to_seq for idx in seq_frames_idx]) 92 | db_unique_idxs = np.unique([idx for seq_frames_idx in db_idx_frame_to_seq for idx in seq_frames_idx]) 93 | 94 | self.database_utms = np.array( 95 | [(path.split("@")[1], path.split("@")[2]) for path in all_db_paths[db_unique_idxs]]).astype(np.float64) 96 | self.queries_utms = np.array( 97 | [(path.split("@")[1], path.split("@")[2]) for path in all_q_paths[q_unique_idxs]]).astype( 98 | np.float64) 99 | 100 | knn = NearestNeighbors(n_jobs=-1) 101 | knn.fit(self.database_utms) 102 | self.hard_positives_per_query = knn.radius_neighbors(self.queries_utms, 103 | radius=pos_thresh, 104 | return_distance=False) 105 | if split == 'train': 106 | # Find soft_positives_per_query, which are within val_positive_dist_threshold (deafult 25 meters) 107 | knn = NearestNeighbors(n_jobs=-1) 108 | knn.fit(self.database_utms) 109 | self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms, 110 | radius=neg_thresh, 111 | return_distance=False) 112 | self.qIdx = [] 113 | self.pIdx = [] 114 | self.nonNegIdx = [] 115 | self.q_without_pos = 0 116 | for q in tqdm(range(len(q_idx_frame_to_seq)), ncols=100, desc='Finding positives and negatives...'): 117 | q_frame_idxs = q_idx_frame_to_seq[q] 118 | unique_q_frame_idxs = np.where(np.in1d(q_unique_idxs, q_frame_idxs)) 119 | 120 | p_uniq_frame_idxs = np.unique( 121 | [p for pos in self.hard_positives_per_query[unique_q_frame_idxs] for p in pos]) 122 | 123 | if len(p_uniq_frame_idxs) > 0: 124 | # p_seq_idx = np.where(np.in1d(db_unique_idxs, p_uniq_frame_idxs))[0] 125 | p_seq_idx = np.where(np.in1d(db_idx_frame_to_seq, db_unique_idxs[p_uniq_frame_idxs]) 126 | .reshape(db_idx_frame_to_seq.shape))[0] 127 | 128 | self.qIdx.append(q) 129 | self.pIdx.append(np.unique(p_seq_idx)) 130 | 131 | if split == 'train': 132 | nonNeg_uniq_frame_idxs = np.unique( 133 | [p for pos in self.soft_positives_per_query[unique_q_frame_idxs] for p in pos]) 134 | #nonNeg_seq_idx = np.where(np.in1d(db_unique_idxs, nonNeg_uniq_frame_idxs)) 135 | #self.nonNegIdx.append(nonNeg_seq_idx) 136 | nonNeg_seq_idx = np.where(np.in1d(db_idx_frame_to_seq, db_unique_idxs[nonNeg_uniq_frame_idxs]) 137 | .reshape(db_idx_frame_to_seq.shape))[0] 138 | self.nonNegIdx.append(np.unique(nonNeg_seq_idx)) 139 | else: 140 | self.q_without_pos += 1 141 | 142 | self.qIdx = np.array(self.qIdx) 143 | self.pIdx = np.array(self.pIdx, dtype=object) 144 | if split == 'train': 145 | save_dict = { 146 | 'db_paths': self.db_paths, 147 | 'q_paths': self.q_paths, 148 | 'database_utms': self.database_utms, 149 | 'queries_utms': self.queries_utms, 150 | 'hard_positives_per_query': self.hard_positives_per_query, 151 | 'soft_positives_per_query': self.soft_positives_per_query, 152 | 'qIdx': self.qIdx, 153 | 'pIdx': self.pIdx, 154 | 'nonNegIdx': self.nonNegIdx, 155 | 'q_without_pos': self.q_without_pos 156 | } 157 | torch.save(save_dict, cache_file) 158 | 159 | def __cut_last_frame(self): 160 | for i, seq in enumerate(self.images_paths): 161 | 
self.images_paths[i] = ','.join((seq.split(',')[:-1])) 162 | for i, seq in enumerate(self.db_paths): 163 | self.db_paths[i] = ','.join((seq.split(',')[:-1])) 164 | for i, seq in enumerate(self.q_paths): 165 | self.q_paths[i] = ','.join((seq.split(',')[:-1])) 166 | 167 | def __getitem__(self, index): 168 | old_index = index 169 | if index >= self.database_num: 170 | q_index = index - self.database_num 171 | index = self.qIdx[q_index] + self.database_num 172 | 173 | img = torch.stack([self.base_transform(Image.open(join(self.dataset_folder, im))) for im in self.images_paths[index].split(',')]) 174 | 175 | return img, index, old_index 176 | 177 | def __len__(self): 178 | return len(self.images_paths) 179 | 180 | def __repr__(self): 181 | return ( 182 | f"< {self.__class__.__name__}, ' #database: {self.database_num}; #queries: {self.queries_num} >") 183 | 184 | def get_positives(self): 185 | return self.pIdx 186 | 187 | 188 | def filter_by_cities(x, cities): 189 | for city in cities: 190 | if x.find(city) > 0: 191 | return True 192 | return False 193 | 194 | 195 | def build_sequences(folder, seq_len=3, cities='', desc='loading'): 196 | if cities != '': 197 | if not isinstance(cities, list): 198 | cities = [cities] 199 | base_path = os.path.dirname(folder) 200 | paths = [] 201 | all_paths = [] 202 | idx_frame_to_seq = [] 203 | seqs_folders = sorted(glob(join(folder, '*'), recursive=True)) 204 | for seq in tqdm(seqs_folders, ncols=100, desc=desc): 205 | start_index = len(all_paths) 206 | frame_nums = np.array(list(map(lambda x: int(x.split('@')[4]), sorted(glob(join(seq, '*')))))) 207 | # seq_paths = np.array(sorted(glob(join(seq, '*')))) 208 | # read the full paths, then keep only relative ones. allows caching only rel. paths 209 | full_seq_paths = sorted(glob(join(seq, '*'))) 210 | seq_paths = np.array([s_p.replace(f'{base_path}/', '') for s_p in full_seq_paths]) 211 | 212 | if cities != '': 213 | sample_path = seq_paths[0] 214 | if not filter_by_cities(sample_path, cities): 215 | continue 216 | 217 | sorted_idx_frames = np.argsort(frame_nums) 218 | all_paths += list(seq_paths[sorted_idx_frames]) 219 | for idx, frame_num in enumerate(frame_nums): 220 | if idx < (seq_len // 2) or idx >= (len(frame_nums) - seq_len // 2): continue 221 | # find surrounding frames in sequence 222 | seq_idx = np.arange(-seq_len // 2, seq_len // 2) + 1 + idx 223 | if (np.diff(frame_nums[sorted_idx_frames][seq_idx]) == 1).all(): 224 | paths.append(",".join(seq_paths[sorted_idx_frames][seq_idx])) 225 | idx_frame_to_seq.append(seq_idx + start_index) 226 | 227 | return paths, np.array(all_paths), np.array(idx_frame_to_seq) 228 | 229 | 230 | class TrainDataset(BaseDataset): 231 | def __init__(self, cities='', dataset_folder="datasets", split="train", base_transform=None, 232 | seq_len=3, pos_thresh=25, neg_thresh=25, infer_batch_size=8, 233 | num_workers=3, img_shape=(480, 640), 234 | cached_negatives=1000, cached_queries=1000, nNeg=10): 235 | super().__init__(dataset_folder=dataset_folder, split=split, cities=cities, base_transform=base_transform, 236 | seq_len=seq_len, pos_thresh=pos_thresh, neg_thresh=neg_thresh) 237 | self.cached_negatives = cached_negatives # Number of negatives to randomly sample 238 | self.num_workers = num_workers 239 | self.cached_queries = cached_queries 240 | self.device = torch.device('cuda') if torch.cuda.is_available() else "cpu" 241 | self.bs = infer_batch_size 242 | self.img_shape = img_shape 243 | self.nNeg = nNeg # Number of negatives per query in each batch 244 | self.is_inference = 
False 245 | self.query_transform = self.base_transform 246 | 247 | def __getitem__(self, index): 248 | if self.is_inference: 249 | # At inference time return the single image. This is used for caching or computing NetVLAD's clusters 250 | return super().__getitem__(index) 251 | query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index], 252 | (1, 1, self.nNeg)) 253 | 254 | query = torch.stack( 255 | [self.base_transform(Image.open(join(self.dataset_folder, im))) for im in self.q_paths[query_index].split(',')]) 256 | 257 | positive = torch.stack( 258 | [self.base_transform(Image.open(join(self.dataset_folder, im))) for im in self.db_paths[best_positive_index].split(',')]) 259 | 260 | negatives = [torch.stack([self.base_transform(Image.open(join(self.dataset_folder, im)))for im in self.db_paths[idx].split(',')]) 261 | for idx in neg_indexes] 262 | 263 | images = torch.stack((query, positive, *negatives), 0) 264 | triplets_local_indexes = torch.empty((0, 3), dtype=torch.int) 265 | for neg_num in range(len(neg_indexes)): 266 | triplets_local_indexes = torch.cat( 267 | (triplets_local_indexes, torch.tensor([0, 1, 2 + neg_num]).reshape(1, 3))) 268 | return images, triplets_local_indexes, self.triplets_global_indexes[index] 269 | 270 | def __len__(self): 271 | if self.is_inference: 272 | # At inference time return the number of images. This is used for caching or computing NetVLAD's clusters 273 | return super().__len__() 274 | else: 275 | return len(self.triplets_global_indexes) 276 | 277 | def compute_triplets(self, model): 278 | self.is_inference = True 279 | self.compute_triplets_partial(model) 280 | self.is_inference = False 281 | 282 | def compute_cache(self, model, subset_ds, cache_shape): 283 | subset_dl = DataLoader(dataset=subset_ds, num_workers=self.num_workers, 284 | batch_size=self.bs, shuffle=False, 285 | pin_memory=(self.device == "cuda")) 286 | model = model.eval() 287 | cache = RAMEfficient2DMatrix(cache_shape, dtype=np.float32) 288 | with torch.no_grad(): 289 | with torch.cuda.amp.autocast(): 290 | for images, indexes, _ in tqdm(subset_dl, desc="cache feat extraction", ncols=100): 291 | images = images.view(-1, 3, self.img_shape[0], self.img_shape[1]) 292 | frames_features = model(images.to(self.device)) 293 | aggregated_features = model.aggregate(frames_features) 294 | cache[indexes.numpy()] = aggregated_features.cpu().numpy() 295 | return cache 296 | 297 | def get_best_positive_index(self, qidx, cache, query_features): 298 | positives_features = cache[self.pIdx[qidx]] 299 | faiss_index = faiss.IndexFlatL2(len(query_features)) 300 | faiss_index.add(positives_features) 301 | # Search the best positive (within 10 meters AND nearest in features space) 302 | _, best_positive_num = faiss_index.search(query_features.reshape(1, -1), 1) 303 | best_positive_index = self.pIdx[qidx][best_positive_num[0]] 304 | return best_positive_index 305 | 306 | def get_hardest_negatives_indexes(self, cache, query_features, neg_indexes): 307 | neg_features = cache[neg_indexes] 308 | 309 | faiss_index = faiss.IndexFlatL2(len(query_features)) 310 | faiss_index.add(neg_features) 311 | 312 | _, neg_nums = faiss_index.search(query_features.reshape(1, -1), self.nNeg) 313 | neg_nums = neg_nums.reshape(-1) 314 | neg_idxs = neg_indexes[neg_nums].astype(np.int32) 315 | 316 | return neg_idxs 317 | 318 | def compute_triplets_partial(self, model): 319 | self.triplets_global_indexes = [] 320 | # Take 1000 random queries 321 | sampled_queries_indexes = 
np.random.choice(self.queries_num, self.cached_queries, replace=False) 322 | # Sample 1000 random database images for the negatives 323 | sampled_database_indexes = np.random.choice(self.database_num, self.cached_negatives, replace=False) 324 | 325 | positives_indexes = np.unique([idx for db_idx in self.pIdx[sampled_queries_indexes] for idx in db_idx]) 326 | database_indexes = list(sampled_database_indexes) + list(positives_indexes) 327 | subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num)) 328 | cache = self.compute_cache(model, subset_ds, cache_shape=[len(self), model.aggregation_dim]) 329 | 330 | for q in tqdm(sampled_queries_indexes, desc="computing hard negatives", ncols=100): 331 | qidx = self.qIdx[q] + self.database_num 332 | query_features = cache[qidx] 333 | 334 | best_positive_index = self.get_best_positive_index(q, cache, query_features) 335 | if isinstance(best_positive_index, np.ndarray): 336 | best_positive_index = best_positive_index[0] 337 | # Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives 338 | soft_positives = self.nonNegIdx[q] 339 | neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True) 340 | # Take all database images that are negatives and are within the sampled database images 341 | neg_indexes = self.get_hardest_negatives_indexes(cache, query_features, neg_indexes) 342 | self.triplets_global_indexes.append((self.qIdx[q], best_positive_index, *neg_indexes)) 343 | 344 | # self.triplets_global_indexes is a tensor of shape [1000, 12] 345 | self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes) 346 | 347 | -------------------------------------------------------------------------------- /jist/datasets/test_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | from glob import glob 5 | from PIL import Image 6 | import torch.utils.data as data 7 | import torchvision.transforms as transforms 8 | from sklearn.neighbors import NearestNeighbors 9 | 10 | 11 | def open_image(path): 12 | return Image.open(path).convert("RGB") 13 | 14 | 15 | class TestDataset(data.Dataset): 16 | def __init__(self, dataset_folder, database_folder="database", 17 | queries_folder="queries", positive_dist_threshold=25): 18 | """Dataset with images from database and queries, used for validation and test. 19 | Parameters 20 | ---------- 21 | dataset_folder : str, should contain the path to the val or test set, 22 | which contains the folders {database_folder} and {queries_folder}. 23 | database_folder : str, name of folder with the database. 24 | queries_folder : str, name of folder with the queries. 25 | positive_dist_threshold : int, distance in meters for a prediction to 26 | be considered a positive. 
27 | """ 28 | super().__init__() 29 | self.dataset_folder = dataset_folder 30 | self.database_folder = os.path.join(dataset_folder, database_folder) 31 | self.queries_folder = os.path.join(dataset_folder, queries_folder) 32 | self.dataset_name = os.path.basename(dataset_folder) 33 | 34 | if not os.path.exists(self.dataset_folder): 35 | raise FileNotFoundError(f"Folder {self.dataset_folder} does not exist") 36 | if not os.path.exists(self.database_folder): 37 | raise FileNotFoundError(f"Folder {self.database_folder} does not exist") 38 | if not os.path.exists(self.queries_folder): 39 | raise FileNotFoundError(f"Folder {self.queries_folder} does not exist") 40 | 41 | self.base_transform = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 44 | ]) 45 | 46 | #### Read paths and UTM coordinates for all images. 47 | self.database_paths = sorted(glob(os.path.join(self.database_folder, "**", "*.jpg"), recursive=True)) 48 | self.queries_paths = sorted(glob(os.path.join(self.queries_folder, "**", "*.jpg"), recursive=True)) 49 | 50 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg 51 | self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(np.float64) 52 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(np.float64) 53 | 54 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters) 55 | knn = NearestNeighbors(n_jobs=-1) 56 | knn.fit(self.database_utms) 57 | self.positives_per_query = knn.radius_neighbors(self.queries_utms, 58 | radius=positive_dist_threshold, 59 | return_distance=False) 60 | 61 | self.images_paths = [p for p in self.database_paths] 62 | self.images_paths += [p for p in self.queries_paths] 63 | 64 | self.database_num = len(self.database_paths) 65 | self.queries_num = len(self.queries_paths) 66 | 67 | def __getitem__(self, index): 68 | image_path = self.images_paths[index] 69 | pil_img = open_image(image_path) 70 | normalized_img = self.base_transform(pil_img) 71 | return normalized_img, index 72 | 73 | def __len__(self): 74 | return len(self.images_paths) 75 | 76 | def __repr__(self): 77 | return (f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >") 78 | 79 | def get_positives(self): 80 | return self.positives_per_query 81 | 82 | -------------------------------------------------------------------------------- /jist/datasets/train_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import random 5 | import logging 6 | import numpy as np 7 | from PIL import Image 8 | from PIL import ImageFile 9 | import torchvision.transforms as T 10 | from collections import defaultdict 11 | 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | def open_image(path): 16 | return Image.open(path).convert("RGB") 17 | 18 | 19 | class CosplaceTrainDataset(torch.utils.data.Dataset): 20 | def __init__(self, args, dataset_folder, M=10, alpha=30, N=5, L=2, 21 | current_group=0, min_images_per_class=10): 22 | """ 23 | Parameters (please check our paper for a clearer explanation of the parameters). 24 | ---------- 25 | args : args for data augmentation 26 | dataset_folder : str, the path of the folder with the train images. 27 | M : int, the length of the side of each cell in meters. 28 | alpha : int, size of each class in degrees. 
29 | N : int, distance (M-wise) between two classes of the same group. 30 | L : int, distance (alpha-wise) between two classes of the same group. 31 | current_group : int, which one of the groups to consider. 32 | min_images_per_class : int, minimum number of image in a class. 33 | """ 34 | super().__init__() 35 | self.M = M 36 | self.alpha = alpha 37 | self.N = N 38 | self.L = L 39 | self.current_group = current_group 40 | self.dataset_folder = dataset_folder 41 | self.augmentation_device = args.augmentation_device 42 | 43 | # dataset_name should be "processed" (if you're using SF-XL) 44 | dataset_name = os.path.basename(args.dataset_folder.strip('/')) 45 | filename = f"cache/{dataset_name}_M{M}_N{N}_mipc{min_images_per_class}.torch" 46 | if not os.path.exists(filename): 47 | os.makedirs("cache", exist_ok=True) 48 | logging.info(f"Cached dataset {filename} does not exist, I'll create it now.") 49 | self.initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename) 50 | elif current_group == 0: 51 | logging.info(f"Using cached dataset {filename}") 52 | 53 | self.base_path = dataset_folder[:dataset_folder.find(dataset_name)] 54 | classes_per_group, self.images_per_class = torch.load(filename) 55 | if current_group >= len(classes_per_group): 56 | raise ValueError(f"With this configuration there are only {len(classes_per_group)} " + 57 | f"groups, therefore I can't create the {current_group}th group. " + 58 | "You should reduce the number of groups in --groups_num") 59 | 60 | self.classes_ids = classes_per_group[current_group] 61 | if self.augmentation_device == "cpu": 62 | self.transform = T.Compose([ 63 | T.ColorJitter(brightness=args.brightness, 64 | contrast=args.contrast, 65 | saturation=args.saturation, 66 | hue=args.hue), 67 | T.RandomResizedCrop([512, 512], scale=[1-args.random_resized_crop, 1]), 68 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 69 | ]) 70 | 71 | 72 | def __getitem__(self, class_num): 73 | # This function takes as input the class_num instead of the index of 74 | # the image. This way each class is equally represented during training. 75 | 76 | class_id = self.classes_ids[class_num] 77 | # Pick a random image among those in this class. 78 | image_path = random.choice(self.images_per_class[class_id]) 79 | image_path = os.path.join(self.base_path, image_path) 80 | try: 81 | pil_image = open_image(image_path) 82 | except Exception as e: 83 | logging.info(f"ERROR image {image_path} couldn't be opened, it might be corrupted.") 84 | raise e 85 | 86 | tensor_image = T.functional.to_tensor(pil_image) 87 | assert tensor_image.shape == torch.Size([3, 512, 512]), \ 88 | f"Image {image_path} should have shape [3, 512, 512] but has {tensor_image.shape}." 
89 | 90 | if self.augmentation_device == "cpu": 91 | tensor_image = self.transform(tensor_image) 92 | 93 | return tensor_image, class_num, image_path 94 | 95 | 96 | def get_images_num(self): 97 | """Return the number of images within this group.""" 98 | return sum([len(self.images_per_class[c]) for c in self.classes_ids]) 99 | 100 | 101 | def __len__(self): 102 | """Return the number of classes within this group.""" 103 | return len(self.classes_ids) 104 | 105 | 106 | @staticmethod 107 | def initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename): 108 | logging.debug(f"Searching training images in {dataset_folder}") 109 | 110 | all_paths = os.path.join(dataset_folder, '../../all_images_paths.txt') 111 | with open(all_paths, "r") as file: 112 | lines = file.readlines() 113 | images_paths = [l.replace("\n", "") for l in lines if l.startswith("processed/train")] 114 | logging.debug(f"Found {len(images_paths)} images") 115 | 116 | logging.debug("For each image, get its UTM east, UTM north and heading from its path") 117 | images_metadatas = [p.split("@") for p in images_paths] 118 | # field 1 is UTM east, field 2 is UTM north, field 9 is heading 119 | utmeast_utmnorth_heading = [(m[1], m[2], m[9]) for m in images_metadatas] 120 | utmeast_utmnorth_heading = np.array(utmeast_utmnorth_heading).astype(np.float64) 121 | 122 | logging.debug("For each image, get class and group to which it belongs") 123 | class_id__group_id = [CosplaceTrainDataset.get__class_id__group_id(*m, M, alpha, N, L) 124 | for m in utmeast_utmnorth_heading] 125 | 126 | logging.debug("Group together images belonging to the same class") 127 | images_per_class = defaultdict(list) 128 | for image_path, (class_id, _) in zip(images_paths, class_id__group_id): 129 | images_per_class[class_id].append(image_path) 130 | 131 | # Images_per_class is a dict where the key is class_id, and the value 132 | # is a list with the paths of images within that class. 133 | images_per_class = {k: v for k, v in images_per_class.items() if len(v) >= min_images_per_class} 134 | 135 | logging.debug("Group together classes belonging to the same group") 136 | # Classes_per_group is a dict where the key is group_id, and the value 137 | # is a list with the class_ids belonging to that group. 138 | classes_per_group = defaultdict(set) 139 | for class_id, group_id in class_id__group_id: 140 | if class_id not in images_per_class: 141 | continue # Skip classes with too few images 142 | classes_per_group[group_id].add(class_id) 143 | 144 | # Convert classes_per_group to a list of lists. 145 | # Each sublist represents the classes within a group. 146 | classes_per_group = [list(c) for c in classes_per_group.values()] 147 | 148 | torch.save((classes_per_group, images_per_class), filename) 149 | 150 | 151 | @staticmethod 152 | def get__class_id__group_id(utm_east, utm_north, heading, M, alpha, N, L): 153 | """Return class_id and group_id for a given point. 154 | The class_id is a triplet (tuple) of UTM_east, UTM_north and 155 | heading (e.g. (396520, 4983800,120)). 156 | The group_id represents the group to which the class belongs 157 | (e.g. (0, 1, 0)), and it is between (0, 0, 0) and (N, N, L). 
158 | """ 159 | rounded_utm_east = int(utm_east // M * M) # Rounded to nearest lower multiple of M 160 | rounded_utm_north = int(utm_north // M * M) 161 | rounded_heading = int(heading // alpha * alpha) 162 | 163 | class_id = (rounded_utm_east, rounded_utm_north, rounded_heading) 164 | # group_id goes from (0, 0, 0) to (N, N, L) 165 | group_id = (rounded_utm_east % (M * N) // M, 166 | rounded_utm_north % (M * N) // M, 167 | rounded_heading % (alpha * L) // alpha) 168 | return class_id, group_id 169 | 170 | -------------------------------------------------------------------------------- /jist/evals/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['test', 'cp_test'] 2 | 3 | 4 | from .test import test 5 | from .cp_test import cosplace_test 6 | 7 | -------------------------------------------------------------------------------- /jist/evals/cp_test.py: -------------------------------------------------------------------------------- 1 | 2 | import faiss 3 | import torch 4 | import logging 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.dataset import Subset 9 | 10 | 11 | # Compute R@1, R@5, R@10, R@20 12 | RECALL_VALUES = [1, 5, 10, 20] 13 | 14 | 15 | def cosplace_test(args, eval_ds, model): 16 | """Compute descriptors of the given dataset and compute the recalls.""" 17 | 18 | model = model.eval() 19 | with torch.no_grad(): 20 | logging.debug("Extracting database descriptors for evaluation/testing") 21 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) 22 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers, 23 | batch_size=args.infer_batch_size, pin_memory=(args.device=="cuda")) 24 | all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32") 25 | for images, indices in tqdm(database_dataloader, ncols=100): 26 | descriptors = model(images.to(args.device)) 27 | descriptors = descriptors.cpu().numpy() 28 | all_descriptors[indices.numpy(), :] = descriptors 29 | 30 | logging.debug("Extracting queries descriptors for evaluation/testing") 31 | queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num))) 32 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers, 33 | batch_size=args.infer_batch_size, pin_memory=(args.device=="cuda")) 34 | for images, indices in tqdm(queries_dataloader, ncols=100): 35 | descriptors = model(images.to(args.device)) 36 | descriptors = descriptors.cpu().numpy() 37 | all_descriptors[indices.numpy(), :] = descriptors 38 | 39 | queries_descriptors = all_descriptors[eval_ds.database_num:] 40 | database_descriptors = all_descriptors[:eval_ds.database_num] 41 | 42 | # Use a kNN to find predictionss 43 | faiss_index = faiss.IndexFlatL2(args.fc_output_dim) 44 | faiss_index.add(database_descriptors) 45 | del database_descriptors, all_descriptors 46 | 47 | logging.debug("Calculating recalls") 48 | _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES)) 49 | 50 | #### For each query, check if the predictions are correct 51 | positives_per_query = eval_ds.get_positives() 52 | recalls = np.zeros(len(RECALL_VALUES)) 53 | for query_index, preds in enumerate(predictions): 54 | for i, n in enumerate(RECALL_VALUES): 55 | if np.any(np.in1d(preds[:n], positives_per_query[query_index])): 56 | recalls[i:] += 1 57 | break 58 | # Divide by queries_num and multiply by 100, so the 
recalls are in percentages 59 | recalls = recalls / eval_ds.queries_num * 100 60 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)]) 61 | return recalls, recalls_str 62 | 63 | -------------------------------------------------------------------------------- /jist/evals/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from tqdm import tqdm 5 | import faiss 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.dataset import Subset 8 | 9 | 10 | def test(args, eval_ds, model): 11 | model = model.eval() 12 | 13 | query_num = eval_ds.queries_num 14 | gallery_num = eval_ds.database_num 15 | all_features = np.empty((query_num + gallery_num, model.aggregation_dim), dtype=np.float32) 16 | 17 | with torch.no_grad(): 18 | logging.debug("Extracting gallery features for evaluation/testing") 19 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) 20 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=4, 21 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) 22 | 23 | for images, indices, _ in tqdm(database_dataloader, ncols=100): 24 | images = images.contiguous().view(-1, 3, args.img_shape[0], args.img_shape[1]) 25 | frames_features = model(images.to(args.device)) 26 | aggregated_features = model.aggregate(frames_features) 27 | all_features[indices.numpy(), :] = aggregated_features.cpu().numpy() 28 | 29 | logging.debug("Extracting queries features for evaluation/testing") 30 | queries_subset_ds = Subset(eval_ds, 31 | list(range(eval_ds.database_num, eval_ds.database_num + eval_ds.queries_num))) 32 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=4, 33 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) 34 | 35 | for images, _, indices in tqdm(queries_dataloader, ncols=100): 36 | images = images.contiguous().view(-1, 3, args.img_shape[0], args.img_shape[1]) 37 | frames_features = model(images.to(args.device)) 38 | aggregated_features = model.aggregate(frames_features) 39 | all_features[indices.numpy(), :] = aggregated_features.cpu().numpy() 40 | 41 | torch.cuda.empty_cache() 42 | queries_features = all_features[eval_ds.database_num:] 43 | gallery_features = all_features[:eval_ds.database_num] 44 | 45 | faiss_index = faiss.IndexFlatL2(model.aggregation_dim) 46 | faiss_index.add(gallery_features) 47 | 48 | logging.debug("Calculating recalls") 49 | _, predictions = faiss_index.search(queries_features, 10) 50 | 51 | # For each query, check if the predictions are correct 52 | positives_per_query = eval_ds.pIdx 53 | recall_values = [1, 5, 10] # recall@1, recall@5, recall@10 54 | recalls = np.zeros(len(recall_values)) 55 | for query_index, pred in enumerate(predictions): 56 | for i, n in enumerate(recall_values): 57 | if np.any(np.in1d(pred[:n], positives_per_query[query_index])): 58 | recalls[i:] += 1 59 | break 60 | # Divide by the number of queries*100, so the recalls are in percentages 61 | recalls = recalls / len(eval_ds.qIdx) * 100 62 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(recall_values, recalls)]) 63 | return recalls, recalls_str 64 | -------------------------------------------------------------------------------- /jist/models/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['network'] 2 | 3 | 4 | from jist.models.network import JistModel 5 | 
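Both `cp_test.py` and `test.py` above compute recall@N with the same counting rule: a query contributes to R@n, and to every larger cutoff, as soon as one of its top-n retrieved database items is a true positive. The following is a small self-contained sketch of that rule on hypothetical toy predictions (the arrays here are made up for illustration, not taken from the repository):

```python
import numpy as np

RECALL_VALUES = [1, 5, 10]

# Toy retrieval output: for each query, the database indices sorted by
# descriptor distance (closest first), and the set of true positives.
predictions = np.array([[3, 7, 1, 0, 2, 9, 8, 4, 5, 6],
                        [5, 2, 8, 1, 0, 3, 4, 6, 7, 9]])
positives_per_query = [np.array([7, 12]),  # first hit at rank 2 -> counts for R@5 and R@10
                       np.array([3])]      # first hit at rank 6 -> counts for R@10 only

recalls = np.zeros(len(RECALL_VALUES))
for query_index, preds in enumerate(predictions):
    for i, n in enumerate(RECALL_VALUES):
        if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
            recalls[i:] += 1  # a hit within the top-n also counts for all larger cutoffs
            break

recalls = recalls / len(predictions) * 100
print(", ".join(f"R@{v}: {r:.1f}" for v, r in zip(RECALL_VALUES, recalls)))
# R@1: 0.0, R@5: 50.0, R@10: 100.0
```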
-------------------------------------------------------------------------------- /jist/models/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import einops 5 | 6 | 7 | def seq_gem(x, p=torch.ones(1)*3, eps: float = 1e-6): 8 | B, D, SL = x.shape 9 | return F.avg_pool1d(x.clamp(min=eps).pow(p), SL).pow(1./p) 10 | 11 | 12 | class SeqGeM(nn.Module): 13 | def __init__(self, p=3, eps=1e-6): 14 | super().__init__() 15 | self.p = torch.nn.Parameter(torch.ones(1)*p) 16 | self.eps = eps 17 | def forward(self, x): 18 | B, SL, D = x.shape 19 | x = einops.rearrange(x, "b sl d -> b d sl") 20 | x = seq_gem(x, p=self.p, eps=self.eps) 21 | assert x.shape == torch.Size([B, D, 1]), f"{x.shape}" 22 | return x[:, :, 0] 23 | def __repr__(self): 24 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})" 25 | 26 | 27 | class JistModel(nn.Module): 28 | def __init__(self, args, agg_type="concat"): 29 | super().__init__() 30 | self.model = torch.hub.load("gmberton/cosplace", "get_trained_model", 31 | backbone=args.backbone, fc_output_dim=args.fc_output_dim) 32 | for name, param in self.model.named_parameters(): 33 | if name.startswith("backbone.7"): # Train only last residual block 34 | break 35 | param.requires_grad = False 36 | assert name.startswith("backbone.7"), "are you using a resnet? this only work with resnets" 37 | 38 | self.features_dim = self.model.aggregation[3].in_features 39 | self.fc_output_dim = self.model.aggregation[3].out_features 40 | self.seq_length = args.seq_length 41 | if agg_type == "concat": 42 | self.aggregation_dim = self.fc_output_dim * args.seq_length 43 | if agg_type == "mean": 44 | self.aggregation_dim = self.fc_output_dim 45 | if agg_type == "max": 46 | self.aggregation_dim = self.fc_output_dim 47 | if agg_type == "conv1d": 48 | self.conv1d = torch.nn.Conv1d(self.fc_output_dim, self.fc_output_dim, self.seq_length) 49 | self.aggregation_dim = self.fc_output_dim 50 | if agg_type in ["simplefc", "meanfc"]: 51 | self.aggregation_dim = self.fc_output_dim 52 | self.final_fc = torch.nn.Linear(self.fc_output_dim * args.seq_length, self.fc_output_dim, bias=False) 53 | if agg_type == "meanfc": 54 | # Initialize as a mean pooling over the frames 55 | weights = torch.zeros_like(self.final_fc.weight) 56 | for i in range(self.fc_output_dim): 57 | for j in range(args.seq_length): 58 | weights[i, j * self.fc_output_dim + i] = 1 / args.seq_length 59 | self.final_fc.weight = torch.nn.Parameter(weights) 60 | if agg_type == "seqgem": 61 | self.aggregation_dim = self.fc_output_dim 62 | self.seq_gem = SeqGeM() 63 | 64 | self.agg_type = agg_type 65 | 66 | def forward(self, x): 67 | return self.model(x) 68 | 69 | def aggregate(self, frames_features): 70 | if self.agg_type == "concat": 71 | concat_features = einops.rearrange(frames_features, "(b sl) d -> b (sl d)", sl=self.seq_length) 72 | return concat_features 73 | if self.agg_type == "mean": 74 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length) 75 | return aggregated_features.mean(1) 76 | if self.agg_type == "max": 77 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length) 78 | return aggregated_features.max(1)[0] 79 | if self.agg_type == "conv1d": 80 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length) 81 | aggregated_features = 
einops.rearrange(aggregated_features, "b sl d -> b d sl", sl=self.seq_length) 82 | features = self.conv1d(aggregated_features) 83 | if len(features.shape) > 2 and features.shape[2] == 1: 84 | features = features[:, :, 0] 85 | return features 86 | if self.agg_type in ["simplefc", "meanfc"]: 87 | concat_features = einops.rearrange(frames_features, "(b sl) d -> b (sl d)", sl=self.seq_length) 88 | return self.final_fc(concat_features) 89 | if self.agg_type == "seqgem": 90 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length) 91 | return self.seq_gem(aggregated_features) 92 | -------------------------------------------------------------------------------- /jist/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['parser', 'logging', 'utils', 'data', 'commons', 'augmentations'] 2 | 3 | 4 | from .parser import parse_arguments 5 | from .logging import setup_logging 6 | from .utils import save_checkpoint, resume_train, load_pretrained_backbone, configure_transform 7 | from .data import RAMEfficient2DMatrix 8 | from .commons import InfiniteDataLoader, make_deterministic, delete_model_gradients 9 | from .cp_utils import move_to_device 10 | from .augmentations import DeviceAgnosticColorJitter, DeviceAgnosticRandomResizedCrop 11 | from .cosface_loss import MarginCosineProduct 12 | -------------------------------------------------------------------------------- /jist/utils/augmentations.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torchvision.transforms as T 4 | 5 | 6 | class DeviceAgnosticColorJitter(T.ColorJitter): 7 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 8 | """This is the same as T.ColorJitter but it only accepts batches of images and works on GPU""" 9 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 10 | def forward(self, images): 11 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}" 12 | B, C, H, W = images.shape 13 | # Applies a different color jitter to each image 14 | color_jitter = super(DeviceAgnosticColorJitter, self).forward 15 | augmented_images = [color_jitter(img).unsqueeze(0) for img in images] 16 | augmented_images = torch.cat(augmented_images) 17 | assert augmented_images.shape == torch.Size([B, C, H, W]) 18 | return augmented_images 19 | 20 | class DeviceAgnosticRandomResizedCrop(T.RandomResizedCrop): 21 | def __init__(self, size, scale): 22 | """This is the same as T.RandomResizedCrop but it only accepts batches of images and works on GPU""" 23 | super().__init__(size=size, scale=scale) 24 | def forward(self, images): 25 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}" 26 | B, C, H, W = images.shape 27 | # Applies a different color jitter to each image 28 | random_resized_crop = super(DeviceAgnosticRandomResizedCrop, self).forward 29 | augmented_images = [random_resized_crop(img).unsqueeze(0) for img in images] 30 | augmented_images = torch.cat(augmented_images) 31 | return augmented_images 32 | 33 | 34 | if __name__ == "__main__": 35 | """ 36 | You can run this script to visualize the transformations, and verify that 37 | the augmentations are applied individually on each image of the batch. 
38 | """ 39 | from PIL import Image 40 | # Import skimage in here, so it is not necessary to install it unless you run this script 41 | from skimage import data 42 | 43 | # Initialize DeviceAgnosticRandomResizedCrop 44 | random_crop = DeviceAgnosticRandomResizedCrop(size=[256, 256], scale=[0.5, 1]) 45 | # Create a batch with 2 astronaut images 46 | pil_image = Image.fromarray(data.astronaut()) 47 | tensor_image = T.functional.to_tensor(pil_image).unsqueeze(0) 48 | images_batch = torch.cat([tensor_image, tensor_image]) 49 | # Apply augmentation (individually on each of the 2 images) 50 | augmented_batch = random_crop(images_batch) 51 | # Convert to PIL images 52 | augmented_image_0 = T.functional.to_pil_image(augmented_batch[0]) 53 | augmented_image_1 = T.functional.to_pil_image(augmented_batch[1]) 54 | # Visualize the original image, as well as the two augmented ones 55 | pil_image.show() 56 | augmented_image_0.show() 57 | augmented_image_1.show() 58 | 59 | -------------------------------------------------------------------------------- /jist/utils/commons.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import random 4 | import numpy as np 5 | 6 | 7 | class InfiniteDataLoader(torch.utils.data.DataLoader): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.dataset_iterator = super().__iter__() 11 | 12 | def __iter__(self): 13 | return self 14 | 15 | def __next__(self): 16 | try: 17 | batch = next(self.dataset_iterator) 18 | except StopIteration: 19 | self.dataset_iterator = super().__iter__() 20 | batch = next(self.dataset_iterator) 21 | return batch 22 | 23 | 24 | def delete_model_gradients(model): 25 | """Set gradients to None to free some GPU memory. Useful before inference. 26 | Note that using optimizer.zero_grad() sets them to 0, which keeps using GPU space. 27 | """ 28 | for param in model.parameters(): 29 | param.grad = None 30 | 31 | 32 | def make_deterministic(seed=0): 33 | """Make results deterministic. If seed == -1, do not make deterministic. 34 | Running your script in a deterministic way might slow it down. 35 | Note that for some packages (eg: sklearn's PCA) this function is not enough. 
36 | """ 37 | seed = int(seed) 38 | if seed == -1: 39 | return 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | -------------------------------------------------------------------------------- /jist/utils/cosface_loss.py: -------------------------------------------------------------------------------- 1 | 2 | # Based on https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | def cosine_sim(x1, x2, dim=1, eps=1e-8): 9 | ip = torch.mm(x1, x2.t()) 10 | w1 = torch.norm(x1, 2, dim) 11 | w2 = torch.norm(x2, 2, dim) 12 | return ip / torch.ger(w1,w2).clamp(min=eps) 13 | 14 | class MarginCosineProduct(nn.Module): 15 | """Implement of large margin cosine distance: 16 | Args: 17 | in_features: size of each input sample 18 | out_features: size of each output sample 19 | s: norm of input feature 20 | m: margin 21 | """ 22 | def __init__(self, in_features, out_features, s=30.0, m=0.40): 23 | super().__init__() 24 | self.in_features = in_features 25 | self.out_features = out_features 26 | self.s = s 27 | self.m = m 28 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 29 | nn.init.xavier_uniform_(self.weight) 30 | def forward(self, input, label): 31 | cosine = cosine_sim(input, self.weight) 32 | one_hot = torch.zeros_like(cosine) 33 | one_hot.scatter_(1, label.view(-1, 1), 1.0) 34 | output = self.s * (cosine - one_hot * self.m) 35 | return output 36 | def __repr__(self): 37 | return self.__class__.__name__ + '(' \ 38 | + 'in_features=' + str(self.in_features) \ 39 | + ', out_features=' + str(self.out_features) \ 40 | + ', s=' + str(self.s) \ 41 | + ', m=' + str(self.m) + ')' 42 | 43 | -------------------------------------------------------------------------------- /jist/utils/cp_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import shutil 4 | import logging 5 | 6 | 7 | def move_to_device(optimizer, device): 8 | for state in optimizer.state.values(): 9 | for k, v in state.items(): 10 | if torch.is_tensor(v): 11 | state[k] = v.to(device) 12 | 13 | 14 | def save_checkpoint(state, is_best, output_folder, ckpt_filename="last_checkpoint.pth"): 15 | # TODO it would be better to move weights to cpu before saving 16 | checkpoint_path = f"{output_folder}/{ckpt_filename}" 17 | torch.save(state, checkpoint_path) 18 | if is_best: 19 | torch.save(state["model_state_dict"], f"{output_folder}/best_model.pth") 20 | 21 | 22 | def resume_train(args, output_folder, model, model_optimizer, classifiers, classifiers_optimizers): 23 | """Load model, optimizer, and other training parameters""" 24 | logging.info(f"Loading checkpoint: {args.resume_train}") 25 | checkpoint = torch.load(args.resume_train) 26 | start_epoch_num = checkpoint["epoch_num"] 27 | 28 | model_state_dict = checkpoint["model_state_dict"] 29 | model.load_state_dict(model_state_dict) 30 | 31 | model = model.to(args.device) 32 | model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 33 | 34 | assert args.groups_num == len(classifiers) == len(classifiers_optimizers) == len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \ 35 | f"{args.groups_num} , {len(classifiers)} , {len(classifiers_optimizers)} , {len(checkpoint['classifiers_state_dict'])} , 
{len(checkpoint['optimizers_state_dict'])}" 36 | 37 | for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]): 38 | # Move classifiers to GPU before loading their optimizers 39 | c = c.to(args.device) 40 | c.load_state_dict(sd) 41 | for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]): 42 | c.load_state_dict(sd) 43 | for c in classifiers: 44 | # Move classifiers back to CPU to save some GPU memory 45 | c = c.cpu() 46 | 47 | sequence_best_r5 = checkpoint["best_seq_val_recall5"] 48 | 49 | # Copy best model to current output_folder 50 | shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder) 51 | 52 | return model, model_optimizer, classifiers, classifiers_optimizers, sequence_best_r5, start_epoch_num 53 | -------------------------------------------------------------------------------- /jist/utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RAMEfficient2DMatrix: 5 | """This class behaves similarly to a numpy.ndarray initialized 6 | with np.zeros(), but is implemented to save RAM when the rows 7 | within the 2D array are sparse. In this case it's needed because 8 | we don't always compute features for each image, just for few of 9 | them""" 10 | 11 | def __init__(self, shape, dtype=np.float32): 12 | self.shape = shape 13 | self.dtype = dtype 14 | self.matrix = [None] * shape[0] 15 | 16 | def __setitem__(self, indexes, vals): 17 | assert vals.shape[1] == self.shape[1], f"{vals.shape[1]} {self.shape[1]}" 18 | for i, val in zip(indexes, vals): 19 | self.matrix[i] = val.astype(self.dtype, copy=False) 20 | 21 | def __getitem__(self, index): 22 | if hasattr(index, "__len__"): 23 | return np.array([self.matrix[i] for i in index]) 24 | else: 25 | return self.matrix[index] 26 | -------------------------------------------------------------------------------- /jist/utils/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import sys 4 | import traceback 5 | 6 | 7 | def setup_logging(output_folder, console="debug", 8 | info_filename="info.log", debug_filename="debug.log"): 9 | """Set up logging files and console output. 10 | Creates one file for INFO logs and one for DEBUG logs. 11 | Args: 12 | output_folder (str): creates the folder where to save the files. 13 | debug (str): 14 | if == "debug" prints on console debug messages and higher 15 | if == "info" prints on console info messages and higher 16 | if == None does not use console (useful when a logger has already been set) 17 | info_filename (str): the name of the info file. if None, don't create info file 18 | debug_filename (str): the name of the debug file. 
if None, don't create debug file 19 | """ 20 | os.makedirs(output_folder, exist_ok=True) 21 | # logging.Logger.manager.loggerDict.keys() to check which loggers are in use 22 | import warnings 23 | warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional") 24 | logging.getLogger('matplotlib.font_manager').disabled = True 25 | logging.getLogger('shapely').disabled = True 26 | logging.getLogger('shapely.geometry').disabled = True 27 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S") 28 | logger = logging.getLogger('') 29 | logger.setLevel(logging.DEBUG) 30 | logging.getLogger('PIL').setLevel(logging.INFO) # turn off logging tag for some images 31 | 32 | if info_filename != None: 33 | info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}') 34 | info_file_handler.setLevel(logging.INFO) 35 | info_file_handler.setFormatter(base_formatter) 36 | logger.addHandler(info_file_handler) 37 | 38 | if debug_filename != None: 39 | debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}') 40 | debug_file_handler.setLevel(logging.DEBUG) 41 | debug_file_handler.setFormatter(base_formatter) 42 | logger.addHandler(debug_file_handler) 43 | 44 | if console != None: 45 | console_handler = logging.StreamHandler() 46 | if console == "debug": console_handler.setLevel(logging.DEBUG) 47 | if console == "info": console_handler.setLevel(logging.INFO) 48 | console_handler.setFormatter(base_formatter) 49 | logger.addHandler(console_handler) 50 | 51 | def exception_handler(type_, value, tb): 52 | logger.info("\n" + "".join(traceback.format_exception(type, value, tb))) 53 | sys.excepthook = exception_handler -------------------------------------------------------------------------------- /jist/utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | 5 | 6 | def _get_device_count(): 7 | if torch.cuda.is_available(): 8 | return torch.cuda.device_count() 9 | return -1 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description="Sequence Visual Geolocalization", 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | # reproducibility 16 | parser.add_argument('--seed', type=int, default=0) 17 | 18 | # dataset 19 | parser.add_argument("--city", type=str, default='', help='subset of cities from train set') 20 | parser.add_argument("--seq_length", type=int, default=5, 21 | help="Number of images in each sequence") 22 | parser.add_argument("--reverse", action='store_true', default=False, help='reverse DB sequences frames') 23 | parser.add_argument("--cut_last_frame", action='store_true', default=False, help='cut last sequence frame') 24 | parser.add_argument("--val_posDistThr", type=int, default=25, help="_") 25 | parser.add_argument("--train_posDistThr", type=int, default=10, help="_") 26 | parser.add_argument("--negDistThr", type=int, default=25, help="_") 27 | parser.add_argument('--img_shape', type=int, default=[480, 640], nargs=2, 28 | help="Resizing shape for images (HxW).") 29 | 30 | # about triplets and mining 31 | parser.add_argument("--nNeg", type=int, default=5, 32 | help="How many negatives to consider per each query in the loss") 33 | parser.add_argument("--cached_negatives", type=int, default=3000, 34 | help="How many negatives to use to compute the hardest ones") 35 | parser.add_argument("--cached_queries", type=int, default=1000, 36 | help="How many queries to keep cached") 37 | 
parser.add_argument("--queries_per_epoch", type=int, default=5000, 38 | help="How many queries to consider for one epoch. Must be multiple of cached_queries") 39 | 40 | # models 41 | parser.add_argument("--resume", type=str, default=None, 42 | help="Path to load checkpoint from, for resuming training or testing.") 43 | parser.add_argument("--pretrain_model", type=str, default=None, 44 | help="Path to load pretrained model from.") 45 | 46 | # training pars 47 | parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) 48 | parser.add_argument("--num_workers", type=int, default=8, 49 | help="num_workers for all dataloaders") 50 | 51 | parser.add_argument("--num_sub_epochs", type=int, default=10, 52 | help="How many times to recompute cache per epoch.") 53 | parser.add_argument("--train_batch_size", type=int, default=4, 54 | help="Number of triplets: (query + pos + negs) * seq_length.") 55 | parser.add_argument("--infer_batch_size", type=int, default=8, 56 | help="Batch size for inference (caching and testing)") 57 | 58 | parser.add_argument("--epochs_num", type=int, default=5, 59 | help="number of epochs to train for") 60 | 61 | parser.add_argument("--margin", type=float, default=0.1, 62 | help="margin for the triplet loss") 63 | parser.add_argument("--lambda_triplet", type=float, default=10000, help="weight to triplet loss") 64 | parser.add_argument("--lambda_im2im", type=float, default=100, help="weight to cosplace loss") 65 | 66 | # PATHS 67 | parser.add_argument("--seq_dataset_path", type=str, help="Path of the seq2seq dataset") 68 | parser.add_argument("--dataset_folder", type=str, # should end in 'sf_xl/processed" 69 | help="path of the SF-XL processed folder with train/val/test sets") 70 | parser.add_argument("--exp_name", type=str, default="default", 71 | help="Folder name of the current run (saved in ./runs/)") 72 | 73 | # CosPlace Groups parameters 74 | parser.add_argument("--M", type=int, default=10, help="_") 75 | parser.add_argument("--alpha", type=int, default=30, help="_") 76 | parser.add_argument("--N", type=int, default=5, help="_") 77 | parser.add_argument("--L", type=int, default=2, help="_") 78 | parser.add_argument("--groups_num", type=int, default=8, help="_") 79 | parser.add_argument("--min_images_per_class", type=int, default=10, help="_") 80 | # Model parameters 81 | parser.add_argument("--backbone", type=str, default="ResNet18", 82 | choices=["ResNet18", "ResNet50", "ResNet101", "VGG16"], help="_") 83 | parser.add_argument("--aggregation_type", type=str, default="seqgem", 84 | choices=["concat", "mean", "max", "simplefc", "conv1d", "meanfc", "seqgem"], help="_") 85 | parser.add_argument("--fc_output_dim", type=int, default=512, 86 | help="Output dimension of final fully connected layer") 87 | parser.add_argument("--augmentation_device", type=str, default="cuda", 88 | choices=["cuda", "cpu"], 89 | help="on which device to run data augmentation") 90 | parser.add_argument("--cp_batch_size", type=int, default=64, help="_") 91 | parser.add_argument("--lr", type=float, default=0.00001, help="_") 92 | parser.add_argument("--classifiers_lr", type=float, default=0.01, help="_") 93 | # Data augmentation 94 | parser.add_argument("--brightness", type=float, default=0.7, help="_") 95 | parser.add_argument("--contrast", type=float, default=0.7, help="_") 96 | parser.add_argument("--hue", type=float, default=0.5, help="_") 97 | parser.add_argument("--saturation", type=float, default=0.7, help="_") 98 | parser.add_argument("--random_resized_crop", 
                        type=float, default=0.5, help="_")
99 |     # Resume parameters
100 |     parser.add_argument("--resume_train", type=str, default=None,
101 |                         help="path to checkpoint to resume, e.g. logs/.../last_checkpoint.pth")
102 |     parser.add_argument("--resume_model", type=str, default=None,
103 |                         help="path to model to resume, e.g. logs/.../best_model.pth")
104 | 
105 |     args = parser.parse_args()
106 | 
107 |     if args.dataset_folder is not None:
108 |         args.train_set_folder = os.path.join(args.dataset_folder, "train")
109 |         if not os.path.exists(args.train_set_folder):
110 |             raise FileNotFoundError(f"Folder {args.train_set_folder} does not exist")
111 | 
112 |         args.val_set_folder = os.path.join(args.dataset_folder, "val")
113 |         if not os.path.exists(args.val_set_folder):
114 |             raise FileNotFoundError(f"Folder {args.val_set_folder} does not exist")
115 | 
116 |         args.test_set_folder = os.path.join(args.dataset_folder, "test")
117 |         if not os.path.exists(args.test_set_folder):
118 |             raise FileNotFoundError(f"Folder {args.test_set_folder} does not exist")
119 | 
120 |     if args.queries_per_epoch % args.cached_queries != 0:
121 |         raise ValueError("Please ensure that queries_per_epoch is divisible by cached_queries, " +
122 |                          f"because {args.queries_per_epoch} is not divisible by {args.cached_queries}")
123 | 
124 |     return args
125 | 
126 | 
-------------------------------------------------------------------------------- /jist/utils/utils.py: --------------------------------------------------------------------------------
1 | import shutil
2 | from torchvision import transforms
3 | import torch
4 | import logging
5 | 
6 | 
7 | def save_checkpoint(args, state, is_best, filename):
8 |     model_path = f"{args.output_folder}/{filename}"
9 |     torch.save(state, model_path)
10 |     if is_best:
11 |         shutil.copyfile(model_path, f"{args.output_folder}/best_model.pth")
12 | 
13 | 
14 | def resume_train(args, model, optimizer=None, strict=False):
15 |     """Load model, optimizer, and other training parameters"""
16 |     logging.debug(f"Loading checkpoint: {args.resume}")
17 |     checkpoint = torch.load(args.resume)
18 |     start_epoch_num = checkpoint["epoch_num"]
19 |     model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
20 |     if optimizer:
21 |         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
22 |     best_r5 = checkpoint["best_r5"]
23 |     logging.debug(f"Loaded checkpoint: start_epoch_num = {start_epoch_num}, " \
24 |                   f"current_best_R@5 = {best_r5:.1f}")
25 |     if args.resume.endswith("last_model.pth"):  # Copy best model to current output_folder
26 |         shutil.copy(args.resume.replace("last_model.pth", "best_model.pth"), args.output_folder)
27 |     return model, optimizer, best_r5, start_epoch_num
28 | 
29 | 
30 | def load_pretrained_backbone(args, model):
31 |     """Load a pretrained backbone"""
32 |     logging.debug(f"Loading checkpoint: {args.pretrain_model}")
33 |     checkpoint = torch.load(args.pretrain_model)
34 |     model.load_state_dict(checkpoint["model_state_dict"], strict=False)
35 |     return model
36 | 
37 | 
38 | def configure_transform(image_dim, meta):
39 |     normalize = transforms.Normalize(mean=meta['mean'], std=meta['std'])
40 |     transform = transforms.Compose([
41 |         transforms.Resize(image_dim),
42 |         transforms.ToTensor(),
43 |         normalize,
44 |     ])
45 | 
46 |     return transform
47 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | 
3 | 
torch==1.13.1+cu117 4 | torchvision==0.14.1+cu117 5 | faiss-cpu==1.7.3 6 | scikit-learn==1.2.0 7 | scipy==1.10.0 8 | transformers==4.25.1 9 | timm==0.6.12 10 | torchmetrics==0.11.1 11 | einops 12 | tqdm -------------------------------------------------------------------------------- /train_double_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import logging 5 | import torchmetrics 6 | from torch import nn 7 | import numpy as np 8 | from tqdm import tqdm 9 | import multiprocessing 10 | from datetime import datetime 11 | import torchvision.transforms as T 12 | from torch.utils.data.dataloader import DataLoader 13 | from os.path import join 14 | torch.backends.cudnn.benchmark= True # Provides a speedup 15 | 16 | # ours 17 | from jist.datasets import (BaseDataset, TrainDataset, collate_fn, 18 | CosplaceTrainDataset, TestDataset) 19 | from jist import utils 20 | from jist.models import JistModel 21 | from jist import evals 22 | from jist.utils import (parse_arguments, setup_logging, MarginCosineProduct, 23 | configure_transform, delete_model_gradients, 24 | InfiniteDataLoader, make_deterministic, move_to_device) 25 | 26 | args = parse_arguments() 27 | start_time = datetime.now() 28 | output_folder = f"logs/{args.exp_name}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}" 29 | setup_logging(output_folder) 30 | make_deterministic(args.seed) 31 | logging.info(" ".join(sys.argv)) 32 | logging.info(f"Arguments: {args}") 33 | logging.info(f"The outputs are being saved in {output_folder}") 34 | 35 | # Model 36 | model = JistModel(args, agg_type=args.aggregation_type) 37 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.") 38 | 39 | if args.resume_model != None: 40 | logging.debug(f"Loading model from {args.resume_model}") 41 | model_state_dict = torch.load(args.resume_model) 42 | model.load_state_dict(model_state_dict) 43 | 44 | model = model.to(args.device).train() 45 | 46 | #### Optimizer 47 | criterion = torch.nn.CrossEntropyLoss() 48 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 49 | criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction="sum") 50 | 51 | #|||||||||||||||||||||||| Datasets 52 | #### Datasets cosplace 53 | groups = [CosplaceTrainDataset(args, args.train_set_folder, M=args.M, alpha=args.alpha, N=args.N, L=args.L, 54 | current_group=n, min_images_per_class=args.min_images_per_class) for n in range(args.groups_num)] 55 | # Each group has its own classifier, which depends on the number of classes in the group 56 | classifiers = [MarginCosineProduct(args.fc_output_dim, len(group)) for group in groups] 57 | classifiers_optimizers = [torch.optim.Adam(classifier.parameters(), lr=args.classifiers_lr) for classifier in classifiers] 58 | 59 | logging.info(f"Using {len(groups)} groups") 60 | logging.info(f"The {len(groups)} groups have respectively the following number of classes {[len(g) for g in groups]}") 61 | logging.info(f"The {len(groups)} groups have respectively the following number of images {[g.get_images_num() for g in groups]}") 62 | 63 | cp_val_ds = TestDataset(args.val_set_folder, positive_dist_threshold=args.val_posDistThr) 64 | logging.info(f"Validation set: {cp_val_ds}") 65 | 66 | #### Datasets sequence 67 | # get transform 68 | meta = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} 69 | img_shape = (args.img_shape[0], args.img_shape[1]) 70 | transform = configure_transform(image_dim=img_shape, 
meta=meta) 71 | 72 | logging.info("Loading train set...") 73 | triplets_ds = TrainDataset(cities=args.city, dataset_folder=args.seq_dataset_path, split='train', 74 | base_transform=transform, seq_len=args.seq_length, 75 | pos_thresh=args.train_posDistThr, neg_thresh=args.negDistThr, infer_batch_size=args.infer_batch_size, 76 | num_workers=args.num_workers, img_shape=args.img_shape, 77 | cached_negatives=args.cached_negatives, 78 | cached_queries=args.cached_queries, nNeg=args.nNeg) 79 | 80 | logging.info(f"Train set: {triplets_ds}") 81 | logging.info("Loading val set...") 82 | val_ds = BaseDataset(dataset_folder=args.seq_dataset_path, split='val', 83 | base_transform=transform, seq_len=args.seq_length, 84 | pos_thresh=args.val_posDistThr) 85 | logging.info(f"Val set: {val_ds}") 86 | 87 | logging.info("Loading test set...") 88 | test_ds = BaseDataset(dataset_folder=args.seq_dataset_path, split='test', 89 | base_transform=transform, seq_len=args.seq_length, 90 | pos_thresh=args.val_posDistThr) 91 | logging.info(f"Test set: {test_ds}") 92 | 93 | 94 | #### Resume 95 | if args.resume_train: 96 | model, model_optimizer, classifiers, classifiers_optimizers, sequence_best_r5, start_epoch_num = \ 97 | utils.cp_utils.resume_train(args, output_folder, model, model_optimizer, classifiers, classifiers_optimizers) 98 | 99 | epoch_num = start_epoch_num - 1 100 | best_val_recall1 = 0 101 | logging.info(f"Resuming from epoch {start_epoch_num} with best seq R@5 {sequence_best_r5:.1f} from checkpoint {args.resume_train}") 102 | else: 103 | best_val_recall1 = start_epoch_num = sequence_best_r5 = 0 104 | 105 | #### Train / evaluation loop 106 | iterations_per_epoch = args.cached_queries // args.train_batch_size * args.num_sub_epochs 107 | logging.info("Start training ...") 108 | logging.info(f"There are {len(groups[0])} classes for the first group, " + 109 | f"each epoch has {iterations_per_epoch} iterations " + 110 | f"with batch_size {args.cp_batch_size}, therefore the model sees each class (on average) " + 111 | f"{iterations_per_epoch * args.cp_batch_size / len(groups[0]):.1f} times per epoch") 112 | logging.info(f"Backbone output channels are {model.features_dim}, features descriptor dim is {model.fc_output_dim}, " 113 | f"sequence descriptor dim is {model.aggregation_dim}") 114 | 115 | gpu_augmentation = T.Compose([ 116 | utils.augmentations.DeviceAgnosticColorJitter(brightness=args.brightness, contrast=args.contrast, 117 | saturation=args.saturation, hue=args.hue), 118 | utils.augmentations.DeviceAgnosticRandomResizedCrop([512, 512], 119 | scale=[1-args.random_resized_crop, 1]), 120 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 121 | ]) 122 | 123 | scaler = torch.cuda.amp.GradScaler() 124 | for epoch_num in range(start_epoch_num, args.epochs_num): 125 | 126 | epoch_start_time = datetime.now() 127 | # Select classifier and dataloader according to epoch 128 | current_group_num = epoch_num % args.groups_num 129 | classifiers[current_group_num] = classifiers[current_group_num].to(args.device) 130 | move_to_device(classifiers_optimizers[current_group_num], args.device) 131 | dataloader = InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers, drop_last=True, 132 | batch_size=args.cp_batch_size, shuffle=True, pin_memory=(args.device == "cuda")) 133 | dataloader_iterator = iter(dataloader) 134 | model = model.train() 135 | 136 | sequence_mean_loss = torchmetrics.MeanMetric() 137 | cosplace_mean_loss = torchmetrics.MeanMetric() 138 | 139 | seq_epoch_losses = 
np.zeros((0, 1), dtype=np.float32) 140 | 141 | for num_sub_epoch in range(args.num_sub_epochs): 142 | logging.debug(f"Cache: {num_sub_epoch + 1} / {args.num_sub_epochs}") 143 | 144 | # creates triplets on the smaller cache set 145 | triplets_ds.compute_triplets(model) 146 | triplets_dl = DataLoader(dataset=triplets_ds, num_workers=args.num_workers, 147 | batch_size=args.train_batch_size, 148 | collate_fn=collate_fn, 149 | pin_memory=(args.device == "cuda"), 150 | drop_last=True) 151 | 152 | model = model.train() 153 | tqdm_bar = tqdm(triplets_dl, ncols=100) 154 | for images, _, _ in tqdm_bar: 155 | model_optimizer.zero_grad() 156 | classifiers_optimizers[current_group_num].zero_grad() 157 | 158 | if args.lambda_triplet != 0: 159 | #### ITERATION ON SEQUENCES 160 | # images shape: (bsz, seq_len*(nNeg + 2), 3, H, W) 161 | # triplets_local_indexes shape: (bsz, nNeg+2) -> contains -1 for query, 1 for pos, 0 for neg 162 | # reshape images to only have 4-d 163 | images = images.view(-1, 3, *img_shape) 164 | # features : (bsz*(nNeg+2), model_output_size) 165 | with torch.cuda.amp.autocast(): 166 | features = model(images.to(args.device)) 167 | features = model.aggregate(features) 168 | 169 | # Compute loss by passing the triplets one by one 170 | sequence_loss = 0 171 | features = features.view(args.train_batch_size, -1, model.aggregation_dim) 172 | for b in range(args.train_batch_size): 173 | query = features[b:b + 1, 0] # size (1, output_dim) 174 | pos = features[b:b + 1, 1] # size (1, output_dim) 175 | negatives = features[b, 2:] # size (nNeg, output_dim) 176 | # negatives has 10 images , pos and query 1 but 177 | # the loss yields same result as calling it 10 times 178 | sequence_loss += criterion_triplet(query, pos, negatives) 179 | del images, features 180 | sequence_loss /= (args.train_batch_size * args.nNeg) 181 | sequence_loss *= args.lambda_triplet 182 | scaler.scale(sequence_loss).backward() 183 | sequence_mean_loss.update(sequence_loss.item()) 184 | del sequence_loss 185 | else: 186 | sequence_mean_loss.update(-1) 187 | 188 | if args.lambda_im2im != 0: 189 | #### ITERATION ON COSPLACE 190 | images, targets, _ = next(dataloader_iterator) 191 | images, targets = images.to(args.device), targets.to(args.device) 192 | 193 | if args.augmentation_device == "cuda": 194 | images = gpu_augmentation(images) 195 | 196 | with torch.cuda.amp.autocast(): 197 | descriptors = model(images) 198 | output = classifiers[current_group_num](descriptors, targets) 199 | cosplace_loss = criterion(output, targets) 200 | del output, images, descriptors, targets 201 | cosplace_loss *= args.lambda_im2im 202 | scaler.scale(cosplace_loss).backward() 203 | cosplace_mean_loss.update(cosplace_loss.item()) 204 | del cosplace_loss 205 | else: 206 | cosplace_mean_loss.update(-1) 207 | 208 | scaler.step(model_optimizer) 209 | if args.lambda_im2im != 0: 210 | scaler.step(classifiers_optimizers[current_group_num]) 211 | scaler.update() 212 | tqdm_bar.set_description(f"seq_loss: {sequence_mean_loss.compute():.4f} - cos_loss: {cosplace_mean_loss.compute():.2f}") 213 | 214 | logging.debug(f"Epoch[{epoch_num:02d}]({num_sub_epoch + 1}/{args.num_sub_epochs}): " + 215 | f"epoch sequence loss = {sequence_mean_loss.compute():.4f} - " + 216 | f"epoch cosplace loss = {cosplace_mean_loss.compute():.4f}") 217 | 218 | classifiers[current_group_num] = classifiers[current_group_num].cpu() 219 | move_to_device(classifiers_optimizers[current_group_num], "cpu") 220 | delete_model_gradients(model) 221 | 222 | #### Evaluation CosPlace 223 | 
cosplace_recalls, cosplace_recalls_str = evals.cosplace_test(args, cp_val_ds, model) 224 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, cosPlace {cp_val_ds}: {cosplace_recalls_str[:20]}") 225 | cosplace_is_best = cosplace_recalls[0] > best_val_recall1 226 | cosplace_best_val_recall1 = max(cosplace_recalls[0], best_val_recall1) 227 | 228 | #### Evaluation Sequence 229 | sequence_recalls, sequence_recalls_str = evals.test(args, val_ds, model) 230 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, sequence {val_ds}: {sequence_recalls_str}") 231 | sequence_is_best = sequence_recalls[1] > sequence_best_r5 232 | 233 | if sequence_is_best: 234 | logging.info(f"Improved: previous best R@5 = {sequence_best_r5:.1f}, current R@5 = {sequence_recalls[1]:.1f}") 235 | sequence_best_r5 = sequence_recalls[1] 236 | else: 237 | logging.info(f"Not improved: best R@5 = {sequence_best_r5:.1f}, current R@5 = {sequence_recalls[1]:.1f}") 238 | 239 | utils.cp_utils.save_checkpoint({ 240 | "epoch_num": epoch_num + 1, 241 | "model_state_dict": model.state_dict(), 242 | "optimizer_state_dict": model_optimizer.state_dict(), 243 | "classifiers_state_dict": [c.state_dict() for c in classifiers], 244 | "optimizers_state_dict": [c.state_dict() for c in classifiers_optimizers], 245 | "best_seq_val_recall5": sequence_best_r5 246 | }, sequence_is_best, output_folder) 247 | recalls, recalls_str = evals.test(args, test_ds, model) 248 | logging.info(f"Recalls on test set: {recalls_str}") 249 | 250 | logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}") 251 | #### Test best model on test set 252 | best_model_state_dict = torch.load(join(output_folder, "best_model.pth")) 253 | model.load_state_dict(best_model_state_dict) 254 | 255 | recalls, recalls_str = evals.test(args, test_ds, model) 256 | logging.info(f"Recalls on test set: {recalls_str}") 257 | --------------------------------------------------------------------------------
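
Hypothetical usage sketch (not a file of this repository): the snippet below only reuses calls that already appear in train_double_dataset.py (parse_arguments, JistModel, model(...), model.aggregate(...), and the best_model.pth written by utils.cp_utils.save_checkpoint) to show how a trained checkpoint could turn one sequence of frames into a single sequence descriptor. The checkpoint path and the random frames are placeholders, and the exact output shape of model.aggregate is inferred from how the training loop reshapes its output.

import torch
from jist.models import JistModel
from jist.utils import parse_arguments

# Same entry points used by train_double_dataset.py; run e.g. with
#   --aggregation_type seqgem --seq_length 5 --backbone ResNet18
args = parse_arguments()
model = JistModel(args, agg_type=args.aggregation_type)

# best_model.pth is written during training by utils.cp_utils.save_checkpoint;
# the path below is a placeholder for an actual run folder.
state_dict = torch.load("logs/default/<run_timestamp>/best_model.pth", map_location=args.device)
model.load_state_dict(state_dict)
model = model.to(args.device).eval()

with torch.no_grad():
    # One sequence of seq_length frames (random tensors as a stand-in for frames
    # preprocessed with configure_transform), shape (seq_len, 3, H, W).
    frames = torch.rand(args.seq_length, 3, *args.img_shape, device=args.device)
    per_frame_descriptors = model(frames)                          # frame-level descriptors
    sequence_descriptor = model.aggregate(per_frame_descriptors)   # one descriptor per sequence
print(sequence_descriptor.shape)  # expected: (1, model.aggregation_dim)

The flatten-then-aggregate pattern mirrors the training loop, which reshapes batches of shape (bsz, seq_len*(nNeg + 2), 3, H, W) to (-1, 3, H, W) before calling model(...) and model.aggregate(...).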