├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── architecture.jpg
├── cache_dataset_datastruct.py
├── evaluation.py
├── jist
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── test_dataset.py
│   │   └── train_dataset.py
│   ├── evals
│   │   ├── __init__.py
│   │   ├── cp_test.py
│   │   └── test.py
│   ├── models
│   │   ├── __init__.py
│   │   └── network.py
│   └── utils
│       ├── __init__.py
│       ├── augmentations.py
│       ├── commons.py
│       ├── cosface_loss.py
│       ├── cp_utils.py
│       ├── data.py
│       ├── logging.py
│       ├── parser.py
│       └── utils.py
├── requirements.txt
└── train_double_dataset.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom
2 | *.ipynb_checkpoints
3 | *.ipynb
4 | mytest.py
5 | saved_objects/
6 | downloaded/
7 | *secret*
8 | logs/
9 | cache
10 | *corrupt*
11 | *fake*
12 | 2legion*
13 | out.txt
14 | vit_legion.job
15 | ign*.*
16 | test
17 | *.pkl
18 | testing*.py
19 | *.npy
20 | f_*launcher.py
21 | jobs/
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | pip-wheel-metadata/
46 | share/python-wheels/
47 | *.egg-info/
48 | .installed.cfg
49 | *.egg
50 | MANIFEST
51 |
52 | # PyInstaller
53 | # Usually these files are written by a python script from a template
54 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
55 | *.manifest
56 | *.spec
57 |
58 | # Installer logs
59 | pip-log.txt
60 | pip-delete-this-directory.txt
61 |
62 | # Unit test / coverage reports
63 | htmlcov/
64 | .tox/
65 | .nox/
66 | .coverage
67 | .coverage.*
68 | .cache
69 | nosetests.xml
70 | coverage.xml
71 | *.cover
72 | *.py,cover
73 | .hypothesis/
74 | .pytest_cache/
75 |
76 | # Translations
77 | *.mo
78 | *.pot
79 |
80 | # Django stuff:
81 | *.log
82 | local_settings.py
83 | db.sqlite3
84 | db.sqlite3-journal
85 |
86 | # Flask stuff:
87 | instance/
88 | .webassets-cache
89 |
90 | # Scrapy stuff:
91 | .scrapy
92 |
93 | # Sphinx documentation
94 | docs/_build/
95 |
96 | # PyBuilder
97 | target/
98 |
99 | # Jupyter Notebook
100 | .ipynb_checkpoints
101 |
102 | # IPython
103 | profile_default/
104 | ipython_config.py
105 |
106 | # pyenv
107 | .python-version
108 |
109 | # pipenv
110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
113 | # install all needed dependencies.
114 | #Pipfile.lock
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # Mac and PyCharm files
154 | .DS_Store
155 | .idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Gabriele Trivigno, Gabriele Berton
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JIST: Joint Image and Sequence Training for Sequential Visual Place Recognition
2 |
3 | ![Overview of the JIST framework](assets/architecture.jpg)
4 |
5 | Overview of the JIST framework.
6 |
7 |
8 | This is the official repository for the paper "[JIST: Joint Image and Sequence Training for Sequential Visual Place Recognition](https://ieeexplore.ieee.org/document/10339796)". It can be used to reproduce the results from the paper, as well as to experiment with different aggregation methods for sequential descriptors in Visual Place Recognition.
9 |
10 | ## Install Locally
11 | Create your local environment and then install the required packages using:
12 | ``` bash
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ## Datasets
17 | The experiments in the paper use two main datasets: Mapillary Street-Level Sequences (MSLS) and Oxford RobotCar.
18 | To download them you can refer to the repo of our previous paper:
19 | - for MSLS, download the dataset from the official website and format it following the instructions in that repo
20 | - for RobotCar, that repo provides a link (Google Drive) to an already pre-processed version of the dataset
21 |
22 | In this work we also use the SF-XL dataset, which is about 1 TB in total, although we only use the `processed` subset (as done in CosPlace), which is around 400 GB. More info can be found in the [CosPlace](https://github.com/gmberton/CosPlace) repository.
23 | You can request the download, specifying the processed version of the dataset, using [this form](https://forms.gle/wpyDzhDyoWLQygAT9).
24 |
25 | ## Run Experiments
26 | Once the datasets are ready, you can run the experiments with the architecture of your choice.
27 |
28 | **NB**: building MSLS sequences requires some heavy pre-processing to create the needed data structures. The dataset class caches them automatically,
29 | so they are computed only the first time. The very first experiment you launch will therefore take 1-2 hours to build these structures, which are
30 | saved in a `cache` directory; subsequent experiments will then start quickly. Note that this procedure caches everything with relative paths,
31 | so if you want to run experiments on multiple machines you can simply copy the `cache` directory.
32 | Finally, note that these data structures must be computed for each sequence length, so `cache` may end up containing one file per sequence length
33 | that you want to experiment with. You can also precompute them with the following command:
34 | ```bash
35 | python cache_dataset_datastruct.py --msls_folder /path/to/msls --seq_len SL
36 | ```
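For instance, with `--seq_len 5` (and no city filtering) the dataset class will create and re-use `cache/msls_seq5_.torch`, following the filename pattern used in `jist/datasets/dataset.py`.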
37 |
38 | To replicate our results you can train a model as follows:
39 |
40 | ```bash
41 | python train_double_dataset.py --exp_name exp_name \
42 | --dataset_folder /path/to/sf_xl/processed \
43 | --seq_dataset_path /path/to/msls \
44 | --aggregation_type seqgem
45 | ```
46 |
47 | ### Evaluate trained models
48 | It is possible to evaluate the trained models using:
49 | ``` bash
50 | python evaluation.py \
51 |     --resume_model /path/to/best_model.pth --seq_dataset_path /path/to/dataset
52 | ```
53 |
54 | ### Download trained models
55 | You can download our JIST model with a ResNet-18 backbone and SeqGeM from [this Google Drive link](https://drive.google.com/file/d/1F6eVrR-0LseE-tfbT8Y8WT92_O11lt5o/view?usp=sharing).
56 |
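If you prefer to load the downloaded checkpoint directly from Python rather than through `evaluation.py`, here is a minimal sketch (the backbone name, descriptor dimension and sequence length below are assumptions and must match the checkpoint; `SimpleNamespace` simply stands in for the parsed command-line arguments):

```python
import torch
from types import SimpleNamespace

from jist.models import JistModel

# Assumed settings: adjust them to match the checkpoint you downloaded.
args = SimpleNamespace(backbone="ResNet18", fc_output_dim=512, seq_length=5)

model = JistModel(args, agg_type="seqgem")  # SeqGeM aggregation head
state_dict = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model = model.eval()
```

From there, `model(frames)` computes per-frame descriptors and `model.aggregate(...)` turns them into a single sequence descriptor, as done in `jist/evals/test.py`.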
57 | ## Acknowledgements
58 |
59 | Parts of this repo are inspired by the following repositories:
60 | - [SeqVLAD](https://github.com/vandal-vpr/vg-transformers)
61 | - [CosFace](https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py)
62 | - [CosPlace](https://github.com/gmberton/CosPlace)
63 |
64 | ## Cite
65 | Here is the BibTeX entry to cite our paper:
66 |
67 | ```
68 | @ARTICLE{Berton_2023_Jist,
69 | author={Berton, Gabriele and Trivigno, Gabriele and Caputo, Barbara and Masone, Carlo},
70 | journal={IEEE Robotics and Automation Letters},
71 | title={JIST: Joint Image and Sequence Training for Sequential Visual Place Recognition},
72 | year={2023},
73 | volume={},
74 | number={},
75 | pages={1-8},
76 | doi={10.1109/LRA.2023.3339058}
77 | }
78 | ```
79 |
--------------------------------------------------------------------------------
/assets/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ga1i13o/JIST/02014f8934fdd0203907dad41a2ead1f7484298f/assets/architecture.jpg
--------------------------------------------------------------------------------
/cache_dataset_datastruct.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from jist.datasets import TrainDataset
3 |
4 |
5 | def main(args):
6 | folder = args.msls_folder
7 | seq_len = args.seq_len
8 |
9 | print(f'Caching dataset with seq len {seq_len}')
10 | triplets_ds = TrainDataset(cities='', dataset_folder=folder, split='train',
11 | seq_len=seq_len, pos_thresh=10, neg_thresh=25)
12 |
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser(description="Sequence Visual Geolocalization",
16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17 | parser.add_argument('--msls_folder', type=str, required=True)
18 | parser.add_argument('--seq_len', type=int, default=5)
19 | args = parser.parse_args()
20 |
21 | main(args)
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from datetime import datetime
3 | import torch
4 |
5 | from jist.datasets import BaseDataset
6 | from jist import utils, evals
7 | from jist.models import JistModel
8 |
9 |
10 | def evaluation(args):
11 | start_time = datetime.now()
12 | args.output_folder = f"test/{args.exp_name}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
13 | utils.setup_logging(args.output_folder, console="info")
14 | logging.info(f"Arguments: {args}")
15 | logging.info(f"The outputs are being saved in {args.output_folder}")
16 |
17 | ### Definition of the model
18 | model = JistModel(args, agg_type=args.aggregation_type)
19 |
20 |     if args.resume_model is not None:
21 | logging.debug(f"Loading model from {args.resume_model}")
22 | model_state_dict = torch.load(args.resume_model)
23 | model.load_state_dict(model_state_dict)
24 |
25 | model = model.to(args.device)
26 |
27 | meta = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
28 | img_shape = (args.img_shape[0], args.img_shape[1])
29 | transform = utils.configure_transform(image_dim=img_shape, meta=meta)
30 |
31 | eval_ds = BaseDataset(dataset_folder=args.seq_dataset_path, split='test',
32 | base_transform=transform, seq_len=args.seq_length,
33 | pos_thresh=args.val_posDistThr, reverse_frames=args.reverse)
34 | logging.info(f"Test set: {eval_ds}")
35 |
36 | logging.info(f"Backbone output channels are {model.features_dim}, features descriptor dim is {model.fc_output_dim}, "
37 | f"sequence descriptor dim is {model.aggregation_dim}")
38 |
39 | _, recalls_str = evals.test(args, eval_ds, model)
40 | logging.info(f"Recalls on test set: {recalls_str}")
41 | logging.info(f"Finished in {str(datetime.now() - start_time)[:-7]}")
42 |
43 |
44 | if __name__ == "__main__":
45 | args = utils.parse_arguments()
46 | evaluation(args)
47 |
--------------------------------------------------------------------------------
/jist/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets
2 | from . import models
3 | from . import evals
4 | from . import utils
5 |
--------------------------------------------------------------------------------
/jist/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['dataset', 'train_dataset', 'test_dataset']
2 |
3 |
4 | from .dataset import BaseDataset, TrainDataset, collate_fn
5 | from .train_dataset import CosplaceTrainDataset
6 | from .test_dataset import TestDataset
7 |
--------------------------------------------------------------------------------
/jist/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import faiss
4 | import logging
5 | import numpy as np
6 | from glob import glob
7 | from tqdm import tqdm
8 | from PIL import Image
9 | from os.path import join
10 | import torch.utils.data as data
11 | from torch.utils.data.dataset import Subset
12 | from sklearn.neighbors import NearestNeighbors
13 | from torch.utils.data.dataloader import DataLoader
14 |
15 | from jist.utils import RAMEfficient2DMatrix
16 |
17 |
18 | def collate_fn(batch):
19 | """Creates mini-batch tensors from the list of tuples (images,
20 | triplets_local_indexes, triplets_global_indexes).
21 | triplets_local_indexes are the indexes referring to each triplet within images.
22 | triplets_global_indexes are the global indexes of each image.
23 | Args:
24 | batch: list of tuple (images, triplets_local_indexes, triplets_global_indexes).
25 | considering each query to have 10 negatives (negs_num_per_query=10):
26 | - images: torch tensor of shape (12, 3, h, w).
27 | - triplets_local_indexes: torch tensor of shape (10, 3).
28 | - triplets_global_indexes: torch tensor of shape (12).
29 | Returns:
30 | images: torch tensor of shape (batch_size*12, 3, h, w).
31 | triplets_local_indexes: torch tensor of shape (batch_size*10, 3).
32 | triplets_global_indexes: torch tensor of shape (batch_size, 12).
33 | """
34 | images = torch.cat([e[0] for e in batch])
35 | triplets_local_indexes = torch.cat([e[1][None] for e in batch])
36 | triplets_global_indexes = torch.cat([e[2][None] for e in batch])
37 | for i, (local_indexes, global_indexes) in enumerate(zip(triplets_local_indexes, triplets_global_indexes)):
38 | local_indexes += len(global_indexes) * i # Increment local indexes by offset (len(global_indexes) is 12)
39 | return images, torch.cat(tuple(triplets_local_indexes)), triplets_global_indexes
40 |
41 |
42 | class BaseDataset(data.Dataset):
43 | def __init__(self, cities='', dataset_folder="datasets", split="train", base_transform=None,
44 | seq_len=3, pos_thresh=25, neg_thresh=25, reverse_frames=False):
45 | super().__init__()
46 | self.dataset_folder = join(dataset_folder, split)
47 | self.seq_len = seq_len
48 |
49 | # the algorithm to create sequences works with odd-numbered sequences only
50 | cut_last_frame = ((seq_len % 2) == 0)
51 | if cut_last_frame:
52 | self.seq_len += 1
53 | self.base_transform = base_transform
54 |
55 | if not os.path.exists(self.dataset_folder): raise FileNotFoundError(
56 | f"Folder {self.dataset_folder} does not exist")
57 |
58 | self.init_data(cities, split, pos_thresh, neg_thresh)
59 | if reverse_frames:
60 | self.db_paths = [",".join(path.split(',')[::-1]) for path in self.db_paths]
61 | self.images_paths = self.db_paths + self.q_paths
62 | self.database_num = len(self.db_paths)
63 | self.queries_num = len(self.qIdx)
64 |
65 | if cut_last_frame:
66 | self.__cut_last_frame()
67 |
68 | def init_data(self, cities, split, pos_thresh, neg_thresh):
69 | if ('msls' in self.dataset_folder) and (split == 'train'):
70 | os.makedirs('cache', exist_ok=True)
71 | cache_file = f'cache/msls_seq{self.seq_len}_{cities}.torch'
72 | if os.path.isfile(cache_file):
73 | logging.info(f'Loading cached data from {cache_file}...')
74 | cache_dict = torch.load(cache_file)
75 | self.__dict__.update(cache_dict)
76 | return
77 | else:
78 | logging.info('Data structures were not cached, building them now...')
79 | #### Read paths and UTM coordinates for all images.
80 | database_folder = join(self.dataset_folder, "database")
81 | queries_folder = join(self.dataset_folder, "queries")
82 |
83 | cities = cities
84 | self.db_paths, all_db_paths, db_idx_frame_to_seq = build_sequences(database_folder,
85 | seq_len=self.seq_len, cities=cities,
86 | desc='loading database...')
87 | self.q_paths, all_q_paths, q_idx_frame_to_seq = build_sequences(queries_folder,
88 | seq_len=self.seq_len, cities=cities,
89 | desc='loading queries...')
90 |
91 | q_unique_idxs = np.unique([idx for seq_frames_idx in q_idx_frame_to_seq for idx in seq_frames_idx])
92 | db_unique_idxs = np.unique([idx for seq_frames_idx in db_idx_frame_to_seq for idx in seq_frames_idx])
93 |
94 | self.database_utms = np.array(
95 | [(path.split("@")[1], path.split("@")[2]) for path in all_db_paths[db_unique_idxs]]).astype(np.float64)
96 | self.queries_utms = np.array(
97 | [(path.split("@")[1], path.split("@")[2]) for path in all_q_paths[q_unique_idxs]]).astype(
98 | np.float64)
99 |
100 | knn = NearestNeighbors(n_jobs=-1)
101 | knn.fit(self.database_utms)
102 | self.hard_positives_per_query = knn.radius_neighbors(self.queries_utms,
103 | radius=pos_thresh,
104 | return_distance=False)
105 | if split == 'train':
106 |             # Find soft_positives_per_query, which are within neg_thresh (default 25 meters)
107 | knn = NearestNeighbors(n_jobs=-1)
108 | knn.fit(self.database_utms)
109 | self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms,
110 | radius=neg_thresh,
111 | return_distance=False)
112 | self.qIdx = []
113 | self.pIdx = []
114 | self.nonNegIdx = []
115 | self.q_without_pos = 0
116 | for q in tqdm(range(len(q_idx_frame_to_seq)), ncols=100, desc='Finding positives and negatives...'):
117 | q_frame_idxs = q_idx_frame_to_seq[q]
118 | unique_q_frame_idxs = np.where(np.in1d(q_unique_idxs, q_frame_idxs))
119 |
120 | p_uniq_frame_idxs = np.unique(
121 | [p for pos in self.hard_positives_per_query[unique_q_frame_idxs] for p in pos])
122 |
123 | if len(p_uniq_frame_idxs) > 0:
124 | # p_seq_idx = np.where(np.in1d(db_unique_idxs, p_uniq_frame_idxs))[0]
125 | p_seq_idx = np.where(np.in1d(db_idx_frame_to_seq, db_unique_idxs[p_uniq_frame_idxs])
126 | .reshape(db_idx_frame_to_seq.shape))[0]
127 |
128 | self.qIdx.append(q)
129 | self.pIdx.append(np.unique(p_seq_idx))
130 |
131 | if split == 'train':
132 | nonNeg_uniq_frame_idxs = np.unique(
133 | [p for pos in self.soft_positives_per_query[unique_q_frame_idxs] for p in pos])
134 | #nonNeg_seq_idx = np.where(np.in1d(db_unique_idxs, nonNeg_uniq_frame_idxs))
135 | #self.nonNegIdx.append(nonNeg_seq_idx)
136 | nonNeg_seq_idx = np.where(np.in1d(db_idx_frame_to_seq, db_unique_idxs[nonNeg_uniq_frame_idxs])
137 | .reshape(db_idx_frame_to_seq.shape))[0]
138 | self.nonNegIdx.append(np.unique(nonNeg_seq_idx))
139 | else:
140 | self.q_without_pos += 1
141 |
142 | self.qIdx = np.array(self.qIdx)
143 | self.pIdx = np.array(self.pIdx, dtype=object)
144 | if split == 'train':
145 | save_dict = {
146 | 'db_paths': self.db_paths,
147 | 'q_paths': self.q_paths,
148 | 'database_utms': self.database_utms,
149 | 'queries_utms': self.queries_utms,
150 | 'hard_positives_per_query': self.hard_positives_per_query,
151 | 'soft_positives_per_query': self.soft_positives_per_query,
152 | 'qIdx': self.qIdx,
153 | 'pIdx': self.pIdx,
154 | 'nonNegIdx': self.nonNegIdx,
155 | 'q_without_pos': self.q_without_pos
156 | }
157 | torch.save(save_dict, cache_file)
158 |
159 | def __cut_last_frame(self):
160 | for i, seq in enumerate(self.images_paths):
161 | self.images_paths[i] = ','.join((seq.split(',')[:-1]))
162 | for i, seq in enumerate(self.db_paths):
163 | self.db_paths[i] = ','.join((seq.split(',')[:-1]))
164 | for i, seq in enumerate(self.q_paths):
165 | self.q_paths[i] = ','.join((seq.split(',')[:-1]))
166 |
167 | def __getitem__(self, index):
168 | old_index = index
169 | if index >= self.database_num:
170 | q_index = index - self.database_num
171 | index = self.qIdx[q_index] + self.database_num
172 |
173 | img = torch.stack([self.base_transform(Image.open(join(self.dataset_folder, im))) for im in self.images_paths[index].split(',')])
174 |
175 | return img, index, old_index
176 |
177 | def __len__(self):
178 | return len(self.images_paths)
179 |
180 | def __repr__(self):
181 | return (
182 | f"< {self.__class__.__name__}, ' #database: {self.database_num}; #queries: {self.queries_num} >")
183 |
184 | def get_positives(self):
185 | return self.pIdx
186 |
187 |
188 | def filter_by_cities(x, cities):
189 | for city in cities:
190 | if x.find(city) > 0:
191 | return True
192 | return False
193 |
194 |
195 | def build_sequences(folder, seq_len=3, cities='', desc='loading'):
196 | if cities != '':
197 | if not isinstance(cities, list):
198 | cities = [cities]
199 | base_path = os.path.dirname(folder)
200 | paths = []
201 | all_paths = []
202 | idx_frame_to_seq = []
203 | seqs_folders = sorted(glob(join(folder, '*'), recursive=True))
204 | for seq in tqdm(seqs_folders, ncols=100, desc=desc):
205 | start_index = len(all_paths)
206 | frame_nums = np.array(list(map(lambda x: int(x.split('@')[4]), sorted(glob(join(seq, '*'))))))
207 | # seq_paths = np.array(sorted(glob(join(seq, '*'))))
208 | # read the full paths, then keep only relative ones. allows caching only rel. paths
209 | full_seq_paths = sorted(glob(join(seq, '*')))
210 | seq_paths = np.array([s_p.replace(f'{base_path}/', '') for s_p in full_seq_paths])
211 |
212 | if cities != '':
213 | sample_path = seq_paths[0]
214 | if not filter_by_cities(sample_path, cities):
215 | continue
216 |
217 | sorted_idx_frames = np.argsort(frame_nums)
218 | all_paths += list(seq_paths[sorted_idx_frames])
219 | for idx, frame_num in enumerate(frame_nums):
220 | if idx < (seq_len // 2) or idx >= (len(frame_nums) - seq_len // 2): continue
221 | # find surrounding frames in sequence
222 | seq_idx = np.arange(-seq_len // 2, seq_len // 2) + 1 + idx
223 | if (np.diff(frame_nums[sorted_idx_frames][seq_idx]) == 1).all():
224 | paths.append(",".join(seq_paths[sorted_idx_frames][seq_idx]))
225 | idx_frame_to_seq.append(seq_idx + start_index)
226 |
227 | return paths, np.array(all_paths), np.array(idx_frame_to_seq)
228 |
229 |
230 | class TrainDataset(BaseDataset):
231 | def __init__(self, cities='', dataset_folder="datasets", split="train", base_transform=None,
232 | seq_len=3, pos_thresh=25, neg_thresh=25, infer_batch_size=8,
233 | num_workers=3, img_shape=(480, 640),
234 | cached_negatives=1000, cached_queries=1000, nNeg=10):
235 | super().__init__(dataset_folder=dataset_folder, split=split, cities=cities, base_transform=base_transform,
236 | seq_len=seq_len, pos_thresh=pos_thresh, neg_thresh=neg_thresh)
237 | self.cached_negatives = cached_negatives # Number of negatives to randomly sample
238 | self.num_workers = num_workers
239 | self.cached_queries = cached_queries
240 | self.device = torch.device('cuda') if torch.cuda.is_available() else "cpu"
241 | self.bs = infer_batch_size
242 | self.img_shape = img_shape
243 | self.nNeg = nNeg # Number of negatives per query in each batch
244 | self.is_inference = False
245 | self.query_transform = self.base_transform
246 |
247 | def __getitem__(self, index):
248 | if self.is_inference:
249 | # At inference time return the single image. This is used for caching or computing NetVLAD's clusters
250 | return super().__getitem__(index)
251 | query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index],
252 | (1, 1, self.nNeg))
253 |
254 | query = torch.stack(
255 | [self.base_transform(Image.open(join(self.dataset_folder, im))) for im in self.q_paths[query_index].split(',')])
256 |
257 | positive = torch.stack(
258 | [self.base_transform(Image.open(join(self.dataset_folder, im))) for im in self.db_paths[best_positive_index].split(',')])
259 |
260 | negatives = [torch.stack([self.base_transform(Image.open(join(self.dataset_folder, im)))for im in self.db_paths[idx].split(',')])
261 | for idx in neg_indexes]
262 |
263 | images = torch.stack((query, positive, *negatives), 0)
264 | triplets_local_indexes = torch.empty((0, 3), dtype=torch.int)
265 | for neg_num in range(len(neg_indexes)):
266 | triplets_local_indexes = torch.cat(
267 | (triplets_local_indexes, torch.tensor([0, 1, 2 + neg_num]).reshape(1, 3)))
268 | return images, triplets_local_indexes, self.triplets_global_indexes[index]
269 |
270 | def __len__(self):
271 | if self.is_inference:
272 | # At inference time return the number of images. This is used for caching or computing NetVLAD's clusters
273 | return super().__len__()
274 | else:
275 | return len(self.triplets_global_indexes)
276 |
277 | def compute_triplets(self, model):
278 | self.is_inference = True
279 | self.compute_triplets_partial(model)
280 | self.is_inference = False
281 |
282 | def compute_cache(self, model, subset_ds, cache_shape):
283 | subset_dl = DataLoader(dataset=subset_ds, num_workers=self.num_workers,
284 | batch_size=self.bs, shuffle=False,
285 | pin_memory=(self.device == "cuda"))
286 | model = model.eval()
287 | cache = RAMEfficient2DMatrix(cache_shape, dtype=np.float32)
288 | with torch.no_grad():
289 | with torch.cuda.amp.autocast():
290 | for images, indexes, _ in tqdm(subset_dl, desc="cache feat extraction", ncols=100):
291 | images = images.view(-1, 3, self.img_shape[0], self.img_shape[1])
292 | frames_features = model(images.to(self.device))
293 | aggregated_features = model.aggregate(frames_features)
294 | cache[indexes.numpy()] = aggregated_features.cpu().numpy()
295 | return cache
296 |
297 | def get_best_positive_index(self, qidx, cache, query_features):
298 | positives_features = cache[self.pIdx[qidx]]
299 | faiss_index = faiss.IndexFlatL2(len(query_features))
300 | faiss_index.add(positives_features)
301 | # Search the best positive (within 10 meters AND nearest in features space)
302 | _, best_positive_num = faiss_index.search(query_features.reshape(1, -1), 1)
303 | best_positive_index = self.pIdx[qidx][best_positive_num[0]]
304 | return best_positive_index
305 |
306 | def get_hardest_negatives_indexes(self, cache, query_features, neg_indexes):
307 | neg_features = cache[neg_indexes]
308 |
309 | faiss_index = faiss.IndexFlatL2(len(query_features))
310 | faiss_index.add(neg_features)
311 |
312 | _, neg_nums = faiss_index.search(query_features.reshape(1, -1), self.nNeg)
313 | neg_nums = neg_nums.reshape(-1)
314 | neg_idxs = neg_indexes[neg_nums].astype(np.int32)
315 |
316 | return neg_idxs
317 |
318 | def compute_triplets_partial(self, model):
319 | self.triplets_global_indexes = []
320 | # Take 1000 random queries
321 | sampled_queries_indexes = np.random.choice(self.queries_num, self.cached_queries, replace=False)
322 | # Sample 1000 random database images for the negatives
323 | sampled_database_indexes = np.random.choice(self.database_num, self.cached_negatives, replace=False)
324 |
325 | positives_indexes = np.unique([idx for db_idx in self.pIdx[sampled_queries_indexes] for idx in db_idx])
326 | database_indexes = list(sampled_database_indexes) + list(positives_indexes)
327 | subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num))
328 | cache = self.compute_cache(model, subset_ds, cache_shape=[len(self), model.aggregation_dim])
329 |
330 | for q in tqdm(sampled_queries_indexes, desc="computing hard negatives", ncols=100):
331 | qidx = self.qIdx[q] + self.database_num
332 | query_features = cache[qidx]
333 |
334 | best_positive_index = self.get_best_positive_index(q, cache, query_features)
335 | if isinstance(best_positive_index, np.ndarray):
336 | best_positive_index = best_positive_index[0]
337 | # Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives
338 | soft_positives = self.nonNegIdx[q]
339 | neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True)
340 | # Take all database images that are negatives and are within the sampled database images
341 | neg_indexes = self.get_hardest_negatives_indexes(cache, query_features, neg_indexes)
342 | self.triplets_global_indexes.append((self.qIdx[q], best_positive_index, *neg_indexes))
343 |
344 | # self.triplets_global_indexes is a tensor of shape [1000, 12]
345 | self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)
346 |
347 |
--------------------------------------------------------------------------------
/jist/datasets/test_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | from glob import glob
5 | from PIL import Image
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | from sklearn.neighbors import NearestNeighbors
9 |
10 |
11 | def open_image(path):
12 | return Image.open(path).convert("RGB")
13 |
14 |
15 | class TestDataset(data.Dataset):
16 | def __init__(self, dataset_folder, database_folder="database",
17 | queries_folder="queries", positive_dist_threshold=25):
18 | """Dataset with images from database and queries, used for validation and test.
19 | Parameters
20 | ----------
21 | dataset_folder : str, should contain the path to the val or test set,
22 | which contains the folders {database_folder} and {queries_folder}.
23 | database_folder : str, name of folder with the database.
24 | queries_folder : str, name of folder with the queries.
25 | positive_dist_threshold : int, distance in meters for a prediction to
26 | be considered a positive.
27 | """
28 | super().__init__()
29 | self.dataset_folder = dataset_folder
30 | self.database_folder = os.path.join(dataset_folder, database_folder)
31 | self.queries_folder = os.path.join(dataset_folder, queries_folder)
32 | self.dataset_name = os.path.basename(dataset_folder)
33 |
34 | if not os.path.exists(self.dataset_folder):
35 | raise FileNotFoundError(f"Folder {self.dataset_folder} does not exist")
36 | if not os.path.exists(self.database_folder):
37 | raise FileNotFoundError(f"Folder {self.database_folder} does not exist")
38 | if not os.path.exists(self.queries_folder):
39 | raise FileNotFoundError(f"Folder {self.queries_folder} does not exist")
40 |
41 | self.base_transform = transforms.Compose([
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
44 | ])
45 |
46 | #### Read paths and UTM coordinates for all images.
47 | self.database_paths = sorted(glob(os.path.join(self.database_folder, "**", "*.jpg"), recursive=True))
48 | self.queries_paths = sorted(glob(os.path.join(self.queries_folder, "**", "*.jpg"), recursive=True))
49 |
50 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg
51 | self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(np.float64)
52 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(np.float64)
53 |
54 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters)
55 | knn = NearestNeighbors(n_jobs=-1)
56 | knn.fit(self.database_utms)
57 | self.positives_per_query = knn.radius_neighbors(self.queries_utms,
58 | radius=positive_dist_threshold,
59 | return_distance=False)
60 |
61 | self.images_paths = [p for p in self.database_paths]
62 | self.images_paths += [p for p in self.queries_paths]
63 |
64 | self.database_num = len(self.database_paths)
65 | self.queries_num = len(self.queries_paths)
66 |
67 | def __getitem__(self, index):
68 | image_path = self.images_paths[index]
69 | pil_img = open_image(image_path)
70 | normalized_img = self.base_transform(pil_img)
71 | return normalized_img, index
72 |
73 | def __len__(self):
74 | return len(self.images_paths)
75 |
76 | def __repr__(self):
77 | return (f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >")
78 |
79 | def get_positives(self):
80 | return self.positives_per_query
81 |
82 |
--------------------------------------------------------------------------------
/jist/datasets/train_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import random
5 | import logging
6 | import numpy as np
7 | from PIL import Image
8 | from PIL import ImageFile
9 | import torchvision.transforms as T
10 | from collections import defaultdict
11 |
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 |
15 | def open_image(path):
16 | return Image.open(path).convert("RGB")
17 |
18 |
19 | class CosplaceTrainDataset(torch.utils.data.Dataset):
20 | def __init__(self, args, dataset_folder, M=10, alpha=30, N=5, L=2,
21 | current_group=0, min_images_per_class=10):
22 | """
23 | Parameters (please check our paper for a clearer explanation of the parameters).
24 | ----------
25 | args : args for data augmentation
26 | dataset_folder : str, the path of the folder with the train images.
27 | M : int, the length of the side of each cell in meters.
28 | alpha : int, size of each class in degrees.
29 | N : int, distance (M-wise) between two classes of the same group.
30 | L : int, distance (alpha-wise) between two classes of the same group.
31 | current_group : int, which one of the groups to consider.
32 |         min_images_per_class : int, minimum number of images in a class.
33 | """
34 | super().__init__()
35 | self.M = M
36 | self.alpha = alpha
37 | self.N = N
38 | self.L = L
39 | self.current_group = current_group
40 | self.dataset_folder = dataset_folder
41 | self.augmentation_device = args.augmentation_device
42 |
43 | # dataset_name should be "processed" (if you're using SF-XL)
44 | dataset_name = os.path.basename(args.dataset_folder.strip('/'))
45 | filename = f"cache/{dataset_name}_M{M}_N{N}_mipc{min_images_per_class}.torch"
46 | if not os.path.exists(filename):
47 | os.makedirs("cache", exist_ok=True)
48 | logging.info(f"Cached dataset {filename} does not exist, I'll create it now.")
49 | self.initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename)
50 | elif current_group == 0:
51 | logging.info(f"Using cached dataset {filename}")
52 |
53 | self.base_path = dataset_folder[:dataset_folder.find(dataset_name)]
54 | classes_per_group, self.images_per_class = torch.load(filename)
55 | if current_group >= len(classes_per_group):
56 | raise ValueError(f"With this configuration there are only {len(classes_per_group)} " +
57 | f"groups, therefore I can't create the {current_group}th group. " +
58 | "You should reduce the number of groups in --groups_num")
59 |
60 | self.classes_ids = classes_per_group[current_group]
61 | if self.augmentation_device == "cpu":
62 | self.transform = T.Compose([
63 | T.ColorJitter(brightness=args.brightness,
64 | contrast=args.contrast,
65 | saturation=args.saturation,
66 | hue=args.hue),
67 | T.RandomResizedCrop([512, 512], scale=[1-args.random_resized_crop, 1]),
68 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
69 | ])
70 |
71 |
72 | def __getitem__(self, class_num):
73 | # This function takes as input the class_num instead of the index of
74 | # the image. This way each class is equally represented during training.
75 |
76 | class_id = self.classes_ids[class_num]
77 | # Pick a random image among those in this class.
78 | image_path = random.choice(self.images_per_class[class_id])
79 | image_path = os.path.join(self.base_path, image_path)
80 | try:
81 | pil_image = open_image(image_path)
82 | except Exception as e:
83 | logging.info(f"ERROR image {image_path} couldn't be opened, it might be corrupted.")
84 | raise e
85 |
86 | tensor_image = T.functional.to_tensor(pil_image)
87 | assert tensor_image.shape == torch.Size([3, 512, 512]), \
88 | f"Image {image_path} should have shape [3, 512, 512] but has {tensor_image.shape}."
89 |
90 | if self.augmentation_device == "cpu":
91 | tensor_image = self.transform(tensor_image)
92 |
93 | return tensor_image, class_num, image_path
94 |
95 |
96 | def get_images_num(self):
97 | """Return the number of images within this group."""
98 | return sum([len(self.images_per_class[c]) for c in self.classes_ids])
99 |
100 |
101 | def __len__(self):
102 | """Return the number of classes within this group."""
103 | return len(self.classes_ids)
104 |
105 |
106 | @staticmethod
107 | def initialize(dataset_folder, M, N, alpha, L, min_images_per_class, filename):
108 | logging.debug(f"Searching training images in {dataset_folder}")
109 |
110 | all_paths = os.path.join(dataset_folder, '../../all_images_paths.txt')
111 | with open(all_paths, "r") as file:
112 | lines = file.readlines()
113 | images_paths = [l.replace("\n", "") for l in lines if l.startswith("processed/train")]
114 | logging.debug(f"Found {len(images_paths)} images")
115 |
116 | logging.debug("For each image, get its UTM east, UTM north and heading from its path")
117 | images_metadatas = [p.split("@") for p in images_paths]
118 | # field 1 is UTM east, field 2 is UTM north, field 9 is heading
119 | utmeast_utmnorth_heading = [(m[1], m[2], m[9]) for m in images_metadatas]
120 | utmeast_utmnorth_heading = np.array(utmeast_utmnorth_heading).astype(np.float64)
121 |
122 | logging.debug("For each image, get class and group to which it belongs")
123 | class_id__group_id = [CosplaceTrainDataset.get__class_id__group_id(*m, M, alpha, N, L)
124 | for m in utmeast_utmnorth_heading]
125 |
126 | logging.debug("Group together images belonging to the same class")
127 | images_per_class = defaultdict(list)
128 | for image_path, (class_id, _) in zip(images_paths, class_id__group_id):
129 | images_per_class[class_id].append(image_path)
130 |
131 | # Images_per_class is a dict where the key is class_id, and the value
132 | # is a list with the paths of images within that class.
133 | images_per_class = {k: v for k, v in images_per_class.items() if len(v) >= min_images_per_class}
134 |
135 | logging.debug("Group together classes belonging to the same group")
136 | # Classes_per_group is a dict where the key is group_id, and the value
137 | # is a list with the class_ids belonging to that group.
138 | classes_per_group = defaultdict(set)
139 | for class_id, group_id in class_id__group_id:
140 | if class_id not in images_per_class:
141 | continue # Skip classes with too few images
142 | classes_per_group[group_id].add(class_id)
143 |
144 | # Convert classes_per_group to a list of lists.
145 | # Each sublist represents the classes within a group.
146 | classes_per_group = [list(c) for c in classes_per_group.values()]
147 |
148 | torch.save((classes_per_group, images_per_class), filename)
149 |
150 |
151 | @staticmethod
152 | def get__class_id__group_id(utm_east, utm_north, heading, M, alpha, N, L):
153 | """Return class_id and group_id for a given point.
154 | The class_id is a triplet (tuple) of UTM_east, UTM_north and
155 | heading (e.g. (396520, 4983800,120)).
156 | The group_id represents the group to which the class belongs
157 | (e.g. (0, 1, 0)), and it is between (0, 0, 0) and (N, N, L).
158 | """
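        # Illustrative worked example (values are made up): with M=10, alpha=30, N=5, L=2,
        # a point at utm_east=396524, utm_north=4983807, heading=125 is assigned
        # class_id = (396520, 4983800, 120) and
        # group_id = (396520 % 50 // 10, 4983800 % 50 // 10, 120 % 60 // 30) = (2, 0, 0).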
159 | rounded_utm_east = int(utm_east // M * M) # Rounded to nearest lower multiple of M
160 | rounded_utm_north = int(utm_north // M * M)
161 | rounded_heading = int(heading // alpha * alpha)
162 |
163 | class_id = (rounded_utm_east, rounded_utm_north, rounded_heading)
164 | # group_id goes from (0, 0, 0) to (N, N, L)
165 | group_id = (rounded_utm_east % (M * N) // M,
166 | rounded_utm_north % (M * N) // M,
167 | rounded_heading % (alpha * L) // alpha)
168 | return class_id, group_id
169 |
170 |
--------------------------------------------------------------------------------
/jist/evals/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['test', 'cp_test']
2 |
3 |
4 | from .test import test
5 | from .cp_test import cosplace_test
6 |
7 |
--------------------------------------------------------------------------------
/jist/evals/cp_test.py:
--------------------------------------------------------------------------------
1 |
2 | import faiss
3 | import torch
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data.dataset import Subset
9 |
10 |
11 | # Compute R@1, R@5, R@10, R@20
12 | RECALL_VALUES = [1, 5, 10, 20]
13 |
14 |
15 | def cosplace_test(args, eval_ds, model):
16 | """Compute descriptors of the given dataset and compute the recalls."""
17 |
18 | model = model.eval()
19 | with torch.no_grad():
20 | logging.debug("Extracting database descriptors for evaluation/testing")
21 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
22 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
23 | batch_size=args.infer_batch_size, pin_memory=(args.device=="cuda"))
24 | all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32")
25 | for images, indices in tqdm(database_dataloader, ncols=100):
26 | descriptors = model(images.to(args.device))
27 | descriptors = descriptors.cpu().numpy()
28 | all_descriptors[indices.numpy(), :] = descriptors
29 |
30 | logging.debug("Extracting queries descriptors for evaluation/testing")
31 | queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))
32 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
33 | batch_size=args.infer_batch_size, pin_memory=(args.device=="cuda"))
34 | for images, indices in tqdm(queries_dataloader, ncols=100):
35 | descriptors = model(images.to(args.device))
36 | descriptors = descriptors.cpu().numpy()
37 | all_descriptors[indices.numpy(), :] = descriptors
38 |
39 | queries_descriptors = all_descriptors[eval_ds.database_num:]
40 | database_descriptors = all_descriptors[:eval_ds.database_num]
41 |
42 |     # Use a kNN to find predictions
43 | faiss_index = faiss.IndexFlatL2(args.fc_output_dim)
44 | faiss_index.add(database_descriptors)
45 | del database_descriptors, all_descriptors
46 |
47 | logging.debug("Calculating recalls")
48 | _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))
49 |
50 | #### For each query, check if the predictions are correct
51 | positives_per_query = eval_ds.get_positives()
52 | recalls = np.zeros(len(RECALL_VALUES))
53 | for query_index, preds in enumerate(predictions):
54 | for i, n in enumerate(RECALL_VALUES):
55 | if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
56 | recalls[i:] += 1
57 | break
58 | # Divide by queries_num and multiply by 100, so the recalls are in percentages
59 | recalls = recalls / eval_ds.queries_num * 100
60 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])
61 | return recalls, recalls_str
62 |
63 |
--------------------------------------------------------------------------------
/jist/evals/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import numpy as np
4 | from tqdm import tqdm
5 | import faiss
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data.dataset import Subset
8 |
9 |
10 | def test(args, eval_ds, model):
11 | model = model.eval()
12 |
13 | query_num = eval_ds.queries_num
14 | gallery_num = eval_ds.database_num
15 | all_features = np.empty((query_num + gallery_num, model.aggregation_dim), dtype=np.float32)
16 |
17 | with torch.no_grad():
18 | logging.debug("Extracting gallery features for evaluation/testing")
19 | database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
20 | database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=4,
21 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
22 |
23 | for images, indices, _ in tqdm(database_dataloader, ncols=100):
24 | images = images.contiguous().view(-1, 3, args.img_shape[0], args.img_shape[1])
25 | frames_features = model(images.to(args.device))
26 | aggregated_features = model.aggregate(frames_features)
27 | all_features[indices.numpy(), :] = aggregated_features.cpu().numpy()
28 |
29 | logging.debug("Extracting queries features for evaluation/testing")
30 | queries_subset_ds = Subset(eval_ds,
31 | list(range(eval_ds.database_num, eval_ds.database_num + eval_ds.queries_num)))
32 | queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=4,
33 | batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
34 |
35 | for images, _, indices in tqdm(queries_dataloader, ncols=100):
36 | images = images.contiguous().view(-1, 3, args.img_shape[0], args.img_shape[1])
37 | frames_features = model(images.to(args.device))
38 | aggregated_features = model.aggregate(frames_features)
39 | all_features[indices.numpy(), :] = aggregated_features.cpu().numpy()
40 |
41 | torch.cuda.empty_cache()
42 | queries_features = all_features[eval_ds.database_num:]
43 | gallery_features = all_features[:eval_ds.database_num]
44 |
45 | faiss_index = faiss.IndexFlatL2(model.aggregation_dim)
46 | faiss_index.add(gallery_features)
47 |
48 | logging.debug("Calculating recalls")
49 | _, predictions = faiss_index.search(queries_features, 10)
50 |
51 | # For each query, check if the predictions are correct
52 | positives_per_query = eval_ds.pIdx
53 | recall_values = [1, 5, 10] # recall@1, recall@5, recall@10
54 | recalls = np.zeros(len(recall_values))
55 | for query_index, pred in enumerate(predictions):
56 | for i, n in enumerate(recall_values):
57 | if np.any(np.in1d(pred[:n], positives_per_query[query_index])):
58 | recalls[i:] += 1
59 | break
60 | # Divide by the number of queries*100, so the recalls are in percentages
61 | recalls = recalls / len(eval_ds.qIdx) * 100
62 | recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(recall_values, recalls)])
63 | return recalls, recalls_str
64 |
--------------------------------------------------------------------------------
/jist/models/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['network']
2 |
3 |
4 | from jist.models.network import JistModel
5 |
--------------------------------------------------------------------------------
/jist/models/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import einops
5 |
6 |
7 | def seq_gem(x, p=torch.ones(1)*3, eps: float = 1e-6):
8 | B, D, SL = x.shape
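    # Generalized-mean (GeM) pooling over the SL frames of the sequence: clamp for
    # numerical stability, raise to the (learnable) power p, average over the sequence
    # dimension with avg_pool1d, then take the p-th root. Output shape: (B, D, 1).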
9 | return F.avg_pool1d(x.clamp(min=eps).pow(p), SL).pow(1./p)
10 |
11 |
12 | class SeqGeM(nn.Module):
13 | def __init__(self, p=3, eps=1e-6):
14 | super().__init__()
15 | self.p = torch.nn.Parameter(torch.ones(1)*p)
16 | self.eps = eps
17 | def forward(self, x):
18 | B, SL, D = x.shape
19 | x = einops.rearrange(x, "b sl d -> b d sl")
20 | x = seq_gem(x, p=self.p, eps=self.eps)
21 | assert x.shape == torch.Size([B, D, 1]), f"{x.shape}"
22 | return x[:, :, 0]
23 | def __repr__(self):
24 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})"
25 |
26 |
27 | class JistModel(nn.Module):
28 | def __init__(self, args, agg_type="concat"):
29 | super().__init__()
30 | self.model = torch.hub.load("gmberton/cosplace", "get_trained_model",
31 | backbone=args.backbone, fc_output_dim=args.fc_output_dim)
32 | for name, param in self.model.named_parameters():
33 | if name.startswith("backbone.7"): # Train only last residual block
34 | break
35 | param.requires_grad = False
36 |         assert name.startswith("backbone.7"), "Are you using a ResNet? This only works with ResNets"
37 |
38 | self.features_dim = self.model.aggregation[3].in_features
39 | self.fc_output_dim = self.model.aggregation[3].out_features
40 | self.seq_length = args.seq_length
41 | if agg_type == "concat":
42 | self.aggregation_dim = self.fc_output_dim * args.seq_length
43 | if agg_type == "mean":
44 | self.aggregation_dim = self.fc_output_dim
45 | if agg_type == "max":
46 | self.aggregation_dim = self.fc_output_dim
47 | if agg_type == "conv1d":
48 | self.conv1d = torch.nn.Conv1d(self.fc_output_dim, self.fc_output_dim, self.seq_length)
49 | self.aggregation_dim = self.fc_output_dim
50 | if agg_type in ["simplefc", "meanfc"]:
51 | self.aggregation_dim = self.fc_output_dim
52 | self.final_fc = torch.nn.Linear(self.fc_output_dim * args.seq_length, self.fc_output_dim, bias=False)
53 | if agg_type == "meanfc":
54 | # Initialize as a mean pooling over the frames
55 | weights = torch.zeros_like(self.final_fc.weight)
56 | for i in range(self.fc_output_dim):
57 | for j in range(args.seq_length):
58 | weights[i, j * self.fc_output_dim + i] = 1 / args.seq_length
59 | self.final_fc.weight = torch.nn.Parameter(weights)
60 | if agg_type == "seqgem":
61 | self.aggregation_dim = self.fc_output_dim
62 | self.seq_gem = SeqGeM()
63 |
64 | self.agg_type = agg_type
65 |
66 | def forward(self, x):
67 | return self.model(x)
68 |
69 | def aggregate(self, frames_features):
70 | if self.agg_type == "concat":
71 | concat_features = einops.rearrange(frames_features, "(b sl) d -> b (sl d)", sl=self.seq_length)
72 | return concat_features
73 | if self.agg_type == "mean":
74 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length)
75 | return aggregated_features.mean(1)
76 | if self.agg_type == "max":
77 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length)
78 | return aggregated_features.max(1)[0]
79 | if self.agg_type == "conv1d":
80 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length)
81 | aggregated_features = einops.rearrange(aggregated_features, "b sl d -> b d sl", sl=self.seq_length)
82 | features = self.conv1d(aggregated_features)
83 | if len(features.shape) > 2 and features.shape[2] == 1:
84 | features = features[:, :, 0]
85 | return features
86 | if self.agg_type in ["simplefc", "meanfc"]:
87 | concat_features = einops.rearrange(frames_features, "(b sl) d -> b (sl d)", sl=self.seq_length)
88 | return self.final_fc(concat_features)
89 | if self.agg_type == "seqgem":
90 | aggregated_features = einops.rearrange(frames_features, "(b sl) d -> b sl d", sl=self.seq_length)
91 | return self.seq_gem(aggregated_features)
92 |
--------------------------------------------------------------------------------
/jist/utils/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['parser', 'logging', 'utils', 'data', 'commons', 'augmentations']
2 |
3 |
4 | from .parser import parse_arguments
5 | from .logging import setup_logging
6 | from .utils import save_checkpoint, resume_train, load_pretrained_backbone, configure_transform
7 | from .data import RAMEfficient2DMatrix
8 | from .commons import InfiniteDataLoader, make_deterministic, delete_model_gradients
9 | from .cp_utils import move_to_device
10 | from .augmentations import DeviceAgnosticColorJitter, DeviceAgnosticRandomResizedCrop
11 | from .cosface_loss import MarginCosineProduct
12 |
--------------------------------------------------------------------------------
/jist/utils/augmentations.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torchvision.transforms as T
4 |
5 |
6 | class DeviceAgnosticColorJitter(T.ColorJitter):
7 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
8 | """This is the same as T.ColorJitter but it only accepts batches of images and works on GPU"""
9 | super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
10 | def forward(self, images):
11 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}"
12 | B, C, H, W = images.shape
13 | # Applies a different color jitter to each image
14 | color_jitter = super(DeviceAgnosticColorJitter, self).forward
15 | augmented_images = [color_jitter(img).unsqueeze(0) for img in images]
16 | augmented_images = torch.cat(augmented_images)
17 | assert augmented_images.shape == torch.Size([B, C, H, W])
18 | return augmented_images
19 |
20 | class DeviceAgnosticRandomResizedCrop(T.RandomResizedCrop):
21 | def __init__(self, size, scale):
22 | """This is the same as T.RandomResizedCrop but it only accepts batches of images and works on GPU"""
23 | super().__init__(size=size, scale=scale)
24 | def forward(self, images):
25 | assert len(images.shape) == 4, f"images should be a batch of images, but it has shape {images.shape}"
26 | B, C, H, W = images.shape
27 |         # Applies a different random resized crop to each image
28 | random_resized_crop = super(DeviceAgnosticRandomResizedCrop, self).forward
29 | augmented_images = [random_resized_crop(img).unsqueeze(0) for img in images]
30 | augmented_images = torch.cat(augmented_images)
31 | return augmented_images
32 |
33 |
34 | if __name__ == "__main__":
35 | """
36 | You can run this script to visualize the transformations, and verify that
37 | the augmentations are applied individually on each image of the batch.
38 | """
39 | from PIL import Image
40 | # Import skimage in here, so it is not necessary to install it unless you run this script
41 | from skimage import data
42 |
43 | # Initialize DeviceAgnosticRandomResizedCrop
44 | random_crop = DeviceAgnosticRandomResizedCrop(size=[256, 256], scale=[0.5, 1])
45 | # Create a batch with 2 astronaut images
46 | pil_image = Image.fromarray(data.astronaut())
47 | tensor_image = T.functional.to_tensor(pil_image).unsqueeze(0)
48 | images_batch = torch.cat([tensor_image, tensor_image])
49 | # Apply augmentation (individually on each of the 2 images)
50 | augmented_batch = random_crop(images_batch)
51 | # Convert to PIL images
52 | augmented_image_0 = T.functional.to_pil_image(augmented_batch[0])
53 | augmented_image_1 = T.functional.to_pil_image(augmented_batch[1])
54 | # Visualize the original image, as well as the two augmented ones
55 | pil_image.show()
56 | augmented_image_0.show()
57 | augmented_image_1.show()
58 |
59 |
--------------------------------------------------------------------------------
/jist/utils/commons.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import random
4 | import numpy as np
5 |
6 |
7 | class InfiniteDataLoader(torch.utils.data.DataLoader):
8 | def __init__(self, *args, **kwargs):
9 | super().__init__(*args, **kwargs)
10 | self.dataset_iterator = super().__iter__()
11 |
12 | def __iter__(self):
13 | return self
14 |
15 | def __next__(self):
16 | try:
17 | batch = next(self.dataset_iterator)
18 | except StopIteration:
19 | self.dataset_iterator = super().__iter__()
20 | batch = next(self.dataset_iterator)
21 | return batch
22 |
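# Usage sketch (illustrative, variable names are placeholders): since __iter__ returns
# self and __next__ transparently restarts the underlying iterator at the end of each
# epoch, a training loop can simply call next() for as many iterations as it likes:
#   loader = InfiniteDataLoader(train_dataset, batch_size=32, shuffle=True)
#   for _ in range(num_iterations):
#       batch = next(loader)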
23 |
24 | def delete_model_gradients(model):
25 | """Set gradients to None to free some GPU memory. Useful before inference.
26 | Note that using optimizer.zero_grad() sets them to 0, which keeps using GPU space.
27 | """
28 | for param in model.parameters():
29 | param.grad = None
30 |
31 |
32 | def make_deterministic(seed=0):
33 | """Make results deterministic. If seed == -1, do not make deterministic.
34 | Running your script in a deterministic way might slow it down.
35 | Note that for some packages (eg: sklearn's PCA) this function is not enough.
36 | """
37 | seed = int(seed)
38 | if seed == -1:
39 | return
40 | random.seed(seed)
41 | np.random.seed(seed)
42 | torch.manual_seed(seed)
43 | torch.cuda.manual_seed_all(seed)
44 | torch.backends.cudnn.deterministic = True
45 | torch.backends.cudnn.benchmark = False
46 |
--------------------------------------------------------------------------------
/jist/utils/cosface_loss.py:
--------------------------------------------------------------------------------
1 |
2 | # Based on https://github.com/MuggleWang/CosFace_pytorch/blob/master/layer.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import Parameter
7 |
8 | def cosine_sim(x1, x2, dim=1, eps=1e-8):
9 | ip = torch.mm(x1, x2.t())
10 | w1 = torch.norm(x1, 2, dim)
11 | w2 = torch.norm(x2, 2, dim)
12 | return ip / torch.ger(w1,w2).clamp(min=eps)
13 |
14 | class MarginCosineProduct(nn.Module):
15 |     """Implementation of large margin cosine distance:
16 | Args:
17 | in_features: size of each input sample
18 | out_features: size of each output sample
19 | s: norm of input feature
20 | m: margin
21 | """
22 | def __init__(self, in_features, out_features, s=30.0, m=0.40):
23 | super().__init__()
24 | self.in_features = in_features
25 | self.out_features = out_features
26 | self.s = s
27 | self.m = m
28 | self.weight = Parameter(torch.Tensor(out_features, in_features))
29 | nn.init.xavier_uniform_(self.weight)
30 | def forward(self, input, label):
31 | cosine = cosine_sim(input, self.weight)
32 | one_hot = torch.zeros_like(cosine)
33 | one_hot.scatter_(1, label.view(-1, 1), 1.0)
34 | output = self.s * (cosine - one_hot * self.m)
35 | return output
36 | def __repr__(self):
37 | return self.__class__.__name__ + '(' \
38 | + 'in_features=' + str(self.in_features) \
39 | + ', out_features=' + str(self.out_features) \
40 | + ', s=' + str(self.s) \
41 | + ', m=' + str(self.m) + ')'
42 |
43 |
--------------------------------------------------------------------------------
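
MarginCosineProduct is meant to be paired with a standard cross-entropy loss, as the training script does; a minimal sketch with random descriptors (the sizes are illustrative):

import torch
from jist.utils import MarginCosineProduct

num_classes, feat_dim = 100, 512                  # illustrative sizes
classifier = MarginCosineProduct(feat_dim, num_classes, s=30.0, m=0.40)
descriptors = torch.randn(8, feat_dim)            # stand-in for the model's output
labels = torch.randint(0, num_classes, (8,))
logits = classifier(descriptors, labels)          # scaled cosine similarities, margin applied to the true class
loss = torch.nn.functional.cross_entropy(logits, labels)
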
/jist/utils/cp_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import shutil
4 | import logging
5 |
6 |
7 | def move_to_device(optimizer, device):
8 | for state in optimizer.state.values():
9 | for k, v in state.items():
10 | if torch.is_tensor(v):
11 | state[k] = v.to(device)
12 |
13 |
14 | def save_checkpoint(state, is_best, output_folder, ckpt_filename="last_checkpoint.pth"):
15 | # TODO it would be better to move weights to cpu before saving
16 | checkpoint_path = f"{output_folder}/{ckpt_filename}"
17 | torch.save(state, checkpoint_path)
18 | if is_best:
19 | torch.save(state["model_state_dict"], f"{output_folder}/best_model.pth")
20 |
21 |
22 | def resume_train(args, output_folder, model, model_optimizer, classifiers, classifiers_optimizers):
23 | """Load model, optimizer, and other training parameters"""
24 | logging.info(f"Loading checkpoint: {args.resume_train}")
25 | checkpoint = torch.load(args.resume_train)
26 | start_epoch_num = checkpoint["epoch_num"]
27 |
28 | model_state_dict = checkpoint["model_state_dict"]
29 | model.load_state_dict(model_state_dict)
30 |
31 | model = model.to(args.device)
32 | model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
33 |
34 | assert args.groups_num == len(classifiers) == len(classifiers_optimizers) == len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \
35 | f"{args.groups_num} , {len(classifiers)} , {len(classifiers_optimizers)} , {len(checkpoint['classifiers_state_dict'])} , {len(checkpoint['optimizers_state_dict'])}"
36 |
37 | for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]):
38 | # Move classifiers to GPU before loading their optimizers
39 | c = c.to(args.device)
40 | c.load_state_dict(sd)
41 | for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]):
42 | c.load_state_dict(sd)
43 | for c in classifiers:
44 | # Move classifiers back to CPU to save some GPU memory
45 | c = c.cpu()
46 |
47 | sequence_best_r5 = checkpoint["best_seq_val_recall5"]
48 |
49 | # Copy best model to current output_folder
50 | shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder)
51 |
52 | return model, model_optimizer, classifiers, classifiers_optimizers, sequence_best_r5, start_epoch_num
53 |
--------------------------------------------------------------------------------
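
A minimal sketch of move_to_device with a toy model: it relocates the optimizer's internal state tensors (e.g. Adam's moment buffers), which moving the model alone does not touch.

import torch
from jist.utils import move_to_device

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()                                   # creates the per-parameter state tensors
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
move_to_device(optimizer, device)                  # keep the optimizer state on the same device
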
/jist/utils/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class RAMEfficient2DMatrix:
5 | """This class behaves similarly to a numpy.ndarray initialized
6 | with np.zeros(), but is implemented to save RAM when the rows
7 | within the 2D array are sparse. In this case it's needed because
8 |     we don't always compute features for each image, just for a few of
9 |     them."""
10 |
11 | def __init__(self, shape, dtype=np.float32):
12 | self.shape = shape
13 | self.dtype = dtype
14 | self.matrix = [None] * shape[0]
15 |
16 | def __setitem__(self, indexes, vals):
17 | assert vals.shape[1] == self.shape[1], f"{vals.shape[1]} {self.shape[1]}"
18 | for i, val in zip(indexes, vals):
19 | self.matrix[i] = val.astype(self.dtype, copy=False)
20 |
21 | def __getitem__(self, index):
22 | if hasattr(index, "__len__"):
23 | return np.array([self.matrix[i] for i in index])
24 | else:
25 | return self.matrix[index]
26 |
--------------------------------------------------------------------------------
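
A minimal sketch of RAMEfficient2DMatrix (sizes are illustrative): only the rows that are actually written consume memory, the rest stay None.

import numpy as np
from jist.utils.data import RAMEfficient2DMatrix

features = RAMEfficient2DMatrix((10000, 512), dtype=np.float32)
indexes = np.array([3, 42, 999])
features[indexes] = np.random.randn(3, 512).astype(np.float32)
one_row = features[42]            # a (512,) array
some_rows = features[indexes]     # stacked into a (3, 512) array
unset_row = features[0]           # rows that were never written are still None
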
/jist/utils/logging.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import sys
4 | import traceback
5 |
6 |
7 | def setup_logging(output_folder, console="debug",
8 | info_filename="info.log", debug_filename="debug.log"):
9 | """Set up logging files and console output.
10 | Creates one file for INFO logs and one for DEBUG logs.
11 | Args:
12 |         output_folder (str): folder where the log files are saved (created if it does not exist).
13 |         console (str):
14 | if == "debug" prints on console debug messages and higher
15 | if == "info" prints on console info messages and higher
16 | if == None does not use console (useful when a logger has already been set)
17 | info_filename (str): the name of the info file. if None, don't create info file
18 | debug_filename (str): the name of the debug file. if None, don't create debug file
19 | """
20 | os.makedirs(output_folder, exist_ok=True)
21 | # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
22 | import warnings
23 | warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
24 | logging.getLogger('matplotlib.font_manager').disabled = True
25 | logging.getLogger('shapely').disabled = True
26 | logging.getLogger('shapely.geometry').disabled = True
27 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
28 | logger = logging.getLogger('')
29 | logger.setLevel(logging.DEBUG)
30 | logging.getLogger('PIL').setLevel(logging.INFO) # turn off logging tag for some images
31 |
32 | if info_filename != None:
33 | info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}')
34 | info_file_handler.setLevel(logging.INFO)
35 | info_file_handler.setFormatter(base_formatter)
36 | logger.addHandler(info_file_handler)
37 |
38 | if debug_filename != None:
39 | debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}')
40 | debug_file_handler.setLevel(logging.DEBUG)
41 | debug_file_handler.setFormatter(base_formatter)
42 | logger.addHandler(debug_file_handler)
43 |
44 | if console != None:
45 | console_handler = logging.StreamHandler()
46 | if console == "debug": console_handler.setLevel(logging.DEBUG)
47 | if console == "info": console_handler.setLevel(logging.INFO)
48 | console_handler.setFormatter(base_formatter)
49 | logger.addHandler(console_handler)
50 |
51 | def exception_handler(type_, value, tb):
52 |         logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
53 | sys.excepthook = exception_handler
--------------------------------------------------------------------------------
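
A minimal sketch of setting up logging for a run (the output folder name is illustrative):

import logging
from jist.utils import setup_logging

setup_logging("logs/example_run", console="info")
logging.info("written to the console, info.log and debug.log")
logging.debug("written to debug.log only, since the console level here is INFO")
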
/jist/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 |
5 |
6 | def _get_device_count():
7 | if torch.cuda.is_available():
8 | return torch.cuda.device_count()
9 | return -1
10 |
11 |
12 | def parse_arguments():
13 | parser = argparse.ArgumentParser(description="Sequence Visual Geolocalization",
14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
15 | # reproducibility
16 | parser.add_argument('--seed', type=int, default=0)
17 |
18 | # dataset
19 | parser.add_argument("--city", type=str, default='', help='subset of cities from train set')
20 | parser.add_argument("--seq_length", type=int, default=5,
21 | help="Number of images in each sequence")
22 | parser.add_argument("--reverse", action='store_true', default=False, help='reverse DB sequences frames')
23 | parser.add_argument("--cut_last_frame", action='store_true', default=False, help='cut last sequence frame')
24 | parser.add_argument("--val_posDistThr", type=int, default=25, help="_")
25 | parser.add_argument("--train_posDistThr", type=int, default=10, help="_")
26 | parser.add_argument("--negDistThr", type=int, default=25, help="_")
27 | parser.add_argument('--img_shape', type=int, default=[480, 640], nargs=2,
28 | help="Resizing shape for images (HxW).")
29 |
30 | # about triplets and mining
31 | parser.add_argument("--nNeg", type=int, default=5,
32 | help="How many negatives to consider per each query in the loss")
33 | parser.add_argument("--cached_negatives", type=int, default=3000,
34 | help="How many negatives to use to compute the hardest ones")
35 | parser.add_argument("--cached_queries", type=int, default=1000,
36 | help="How many queries to keep cached")
37 | parser.add_argument("--queries_per_epoch", type=int, default=5000,
38 | help="How many queries to consider for one epoch. Must be multiple of cached_queries")
39 |
40 | # models
41 | parser.add_argument("--resume", type=str, default=None,
42 | help="Path to load checkpoint from, for resuming training or testing.")
43 | parser.add_argument("--pretrain_model", type=str, default=None,
44 | help="Path to load pretrained model from.")
45 |
46 | # training pars
47 | parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
48 | parser.add_argument("--num_workers", type=int, default=8,
49 | help="num_workers for all dataloaders")
50 |
51 | parser.add_argument("--num_sub_epochs", type=int, default=10,
52 | help="How many times to recompute cache per epoch.")
53 | parser.add_argument("--train_batch_size", type=int, default=4,
54 | help="Number of triplets: (query + pos + negs) * seq_length.")
55 | parser.add_argument("--infer_batch_size", type=int, default=8,
56 | help="Batch size for inference (caching and testing)")
57 |
58 | parser.add_argument("--epochs_num", type=int, default=5,
59 | help="number of epochs to train for")
60 |
61 | parser.add_argument("--margin", type=float, default=0.1,
62 | help="margin for the triplet loss")
63 |     parser.add_argument("--lambda_triplet", type=float, default=10000, help="weight of the triplet loss")
64 |     parser.add_argument("--lambda_im2im", type=float, default=100, help="weight of the image-to-image (CosPlace) loss")
65 |
66 | # PATHS
67 | parser.add_argument("--seq_dataset_path", type=str, help="Path of the seq2seq dataset")
68 |     parser.add_argument("--dataset_folder", type=str,  # should end in 'sf_xl/processed'
69 | help="path of the SF-XL processed folder with train/val/test sets")
70 | parser.add_argument("--exp_name", type=str, default="default",
71 | help="Folder name of the current run (saved in ./runs/)")
72 |
73 | # CosPlace Groups parameters
74 | parser.add_argument("--M", type=int, default=10, help="_")
75 | parser.add_argument("--alpha", type=int, default=30, help="_")
76 | parser.add_argument("--N", type=int, default=5, help="_")
77 | parser.add_argument("--L", type=int, default=2, help="_")
78 | parser.add_argument("--groups_num", type=int, default=8, help="_")
79 | parser.add_argument("--min_images_per_class", type=int, default=10, help="_")
80 | # Model parameters
81 | parser.add_argument("--backbone", type=str, default="ResNet18",
82 | choices=["ResNet18", "ResNet50", "ResNet101", "VGG16"], help="_")
83 | parser.add_argument("--aggregation_type", type=str, default="seqgem",
84 | choices=["concat", "mean", "max", "simplefc", "conv1d", "meanfc", "seqgem"], help="_")
85 | parser.add_argument("--fc_output_dim", type=int, default=512,
86 | help="Output dimension of final fully connected layer")
87 | parser.add_argument("--augmentation_device", type=str, default="cuda",
88 | choices=["cuda", "cpu"],
89 | help="on which device to run data augmentation")
90 | parser.add_argument("--cp_batch_size", type=int, default=64, help="_")
91 | parser.add_argument("--lr", type=float, default=0.00001, help="_")
92 | parser.add_argument("--classifiers_lr", type=float, default=0.01, help="_")
93 | # Data augmentation
94 | parser.add_argument("--brightness", type=float, default=0.7, help="_")
95 | parser.add_argument("--contrast", type=float, default=0.7, help="_")
96 | parser.add_argument("--hue", type=float, default=0.5, help="_")
97 | parser.add_argument("--saturation", type=float, default=0.7, help="_")
98 | parser.add_argument("--random_resized_crop", type=float, default=0.5, help="_")
99 | # Resume parameters
100 | parser.add_argument("--resume_train", type=str, default=None,
101 | help="path to checkpoint to resume, e.g. logs/.../last_checkpoint.pth")
102 | parser.add_argument("--resume_model", type=str, default=None,
103 | help="path to model to resume, e.g. logs/.../best_model.pth")
104 |
105 | args = parser.parse_args()
106 |
107 | if args.dataset_folder is not None:
108 | args.train_set_folder = os.path.join(args.dataset_folder, "train")
109 | if not os.path.exists(args.train_set_folder):
110 | raise FileNotFoundError(f"Folder {args.train_set_folder} does not exist")
111 |
112 | args.val_set_folder = os.path.join(args.dataset_folder, "val")
113 | if not os.path.exists(args.val_set_folder):
114 | raise FileNotFoundError(f"Folder {args.val_set_folder} does not exist")
115 |
116 | args.test_set_folder = os.path.join(args.dataset_folder, "test")
117 | if not os.path.exists(args.test_set_folder):
118 | raise FileNotFoundError(f"Folder {args.test_set_folder} does not exist")
119 |
120 | if args.queries_per_epoch % args.cached_queries != 0:
121 |         raise ValueError("Please ensure that queries_per_epoch is divisible by cached_queries, " +
122 |                          f"because {args.queries_per_epoch} is not divisible by {args.cached_queries}")
123 |
124 | return args
125 |
126 |
--------------------------------------------------------------------------------
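
A minimal sketch of parsing the options programmatically (the paths are placeholders); --dataset_folder is omitted because, when given, it must point to a folder containing existing train/, val/ and test/ subfolders.

import sys
from jist.utils import parse_arguments

sys.argv = ["train_double_dataset.py",
            "--seq_dataset_path", "/path/to/sequence_dataset",  # placeholder path
            "--backbone", "ResNet18",
            "--aggregation_type", "seqgem",
            "--exp_name", "example_run"]
args = parse_arguments()
print(args.seq_length, args.fc_output_dim)  # defaults: 5 and 512
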
/jist/utils/utils.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from torchvision import transforms
3 | import torch
4 | import logging
5 |
6 |
7 | def save_checkpoint(args, state, is_best, filename):
8 | model_path = f"{args.output_folder}/{filename}"
9 | torch.save(state, model_path)
10 | if is_best:
11 | shutil.copyfile(model_path, f"{args.output_folder}/best_model.pth")
12 |
13 |
14 | def resume_train(args, model, optimizer=None, strict=False):
15 | """Load model, optimizer, and other training parameters"""
16 | logging.debug(f"Loading checkpoint: {args.resume}")
17 | checkpoint = torch.load(args.resume)
18 | start_epoch_num = checkpoint["epoch_num"]
19 | model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
20 | if optimizer:
21 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
22 | best_r5 = checkpoint["best_r5"]
23 | logging.debug(f"Loaded checkpoint: start_epoch_num = {start_epoch_num}, " \
24 | f"current_best_R@5 = {best_r5:.1f}")
25 | if args.resume.endswith("last_model.pth"): # Copy best model to current output_folder
26 | shutil.copy(args.resume.replace("last_model.pth", "best_model.pth"), args.output_folder)
27 | return model, optimizer, best_r5, start_epoch_num
28 |
29 |
30 | def load_pretrained_backbone(args, model):
31 | """Load a pretrained backbone"""
32 | logging.debug(f"Loading checkpoint: {args.pretrain_model}")
33 | checkpoint = torch.load(args.pretrain_model)
34 | model.load_state_dict(checkpoint["model_state_dict"], strict=False)
35 | return model
36 |
37 |
38 | def configure_transform(image_dim, meta):
39 | normalize = transforms.Normalize(mean=meta['mean'], std=meta['std'])
40 | transform = transforms.Compose([
41 | transforms.Resize(image_dim),
42 | transforms.ToTensor(),
43 | normalize,
44 | ])
45 |
46 | return transform
47 |
--------------------------------------------------------------------------------
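
A minimal sketch of configure_transform with the ImageNet statistics used by the training script (the image itself is a placeholder):

from PIL import Image
from jist.utils import configure_transform

meta = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
transform = configure_transform(image_dim=(480, 640), meta=meta)
img = Image.new("RGB", (640, 480))
tensor = transform(img)          # a normalized tensor of shape (3, 480, 640)
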
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 |
3 | torch==1.13.1+cu117
4 | torchvision==0.14.1+cu117
5 | faiss-cpu==1.7.3
6 | scikit-learn==1.2.0
7 | scipy==1.10.0
8 | transformers==4.25.1
9 | timm==0.6.12
10 | torchmetrics==0.11.1
11 | einops
12 | tqdm
--------------------------------------------------------------------------------
/train_double_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import torch
4 | import logging
5 | import torchmetrics
6 | from torch import nn
7 | import numpy as np
8 | from tqdm import tqdm
9 | import multiprocessing
10 | from datetime import datetime
11 | import torchvision.transforms as T
12 | from torch.utils.data.dataloader import DataLoader
13 | from os.path import join
14 | torch.backends.cudnn.benchmark = True  # Provides a speedup
15 |
16 | # ours
17 | from jist.datasets import (BaseDataset, TrainDataset, collate_fn,
18 | CosplaceTrainDataset, TestDataset)
19 | from jist import utils
20 | from jist.models import JistModel
21 | from jist import evals
22 | from jist.utils import (parse_arguments, setup_logging, MarginCosineProduct,
23 | configure_transform, delete_model_gradients,
24 | InfiniteDataLoader, make_deterministic, move_to_device)
25 |
26 | args = parse_arguments()
27 | start_time = datetime.now()
28 | output_folder = f"logs/{args.exp_name}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
29 | setup_logging(output_folder)
30 | make_deterministic(args.seed)
31 | logging.info(" ".join(sys.argv))
32 | logging.info(f"Arguments: {args}")
33 | logging.info(f"The outputs are being saved in {output_folder}")
34 |
35 | # Model
36 | model = JistModel(args, agg_type=args.aggregation_type)
37 | logging.info(f"There are {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs.")
38 |
39 | if args.resume_model != None:
40 | logging.debug(f"Loading model from {args.resume_model}")
41 | model_state_dict = torch.load(args.resume_model)
42 | model.load_state_dict(model_state_dict)
43 |
44 | model = model.to(args.device).train()
45 |
46 | #### Optimizer
47 | criterion = torch.nn.CrossEntropyLoss()
48 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
49 | criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction="sum")
50 |
51 | #|||||||||||||||||||||||| Datasets
52 | #### Datasets cosplace
53 | groups = [CosplaceTrainDataset(args, args.train_set_folder, M=args.M, alpha=args.alpha, N=args.N, L=args.L,
54 | current_group=n, min_images_per_class=args.min_images_per_class) for n in range(args.groups_num)]
55 | # Each group has its own classifier, which depends on the number of classes in the group
56 | classifiers = [MarginCosineProduct(args.fc_output_dim, len(group)) for group in groups]
57 | classifiers_optimizers = [torch.optim.Adam(classifier.parameters(), lr=args.classifiers_lr) for classifier in classifiers]
58 |
59 | logging.info(f"Using {len(groups)} groups")
60 | logging.info(f"The {len(groups)} groups have respectively the following number of classes {[len(g) for g in groups]}")
61 | logging.info(f"The {len(groups)} groups have respectively the following number of images {[g.get_images_num() for g in groups]}")
62 |
63 | cp_val_ds = TestDataset(args.val_set_folder, positive_dist_threshold=args.val_posDistThr)
64 | logging.info(f"Validation set: {cp_val_ds}")
65 |
66 | #### Datasets sequence
67 | # get transform
68 | meta = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
69 | img_shape = (args.img_shape[0], args.img_shape[1])
70 | transform = configure_transform(image_dim=img_shape, meta=meta)
71 |
72 | logging.info("Loading train set...")
73 | triplets_ds = TrainDataset(cities=args.city, dataset_folder=args.seq_dataset_path, split='train',
74 | base_transform=transform, seq_len=args.seq_length,
75 | pos_thresh=args.train_posDistThr, neg_thresh=args.negDistThr, infer_batch_size=args.infer_batch_size,
76 | num_workers=args.num_workers, img_shape=args.img_shape,
77 | cached_negatives=args.cached_negatives,
78 | cached_queries=args.cached_queries, nNeg=args.nNeg)
79 |
80 | logging.info(f"Train set: {triplets_ds}")
81 | logging.info("Loading val set...")
82 | val_ds = BaseDataset(dataset_folder=args.seq_dataset_path, split='val',
83 | base_transform=transform, seq_len=args.seq_length,
84 | pos_thresh=args.val_posDistThr)
85 | logging.info(f"Val set: {val_ds}")
86 |
87 | logging.info("Loading test set...")
88 | test_ds = BaseDataset(dataset_folder=args.seq_dataset_path, split='test',
89 | base_transform=transform, seq_len=args.seq_length,
90 | pos_thresh=args.val_posDistThr)
91 | logging.info(f"Test set: {test_ds}")
92 |
93 |
94 | #### Resume
95 | if args.resume_train:
96 | model, model_optimizer, classifiers, classifiers_optimizers, sequence_best_r5, start_epoch_num = \
97 | utils.cp_utils.resume_train(args, output_folder, model, model_optimizer, classifiers, classifiers_optimizers)
98 |
99 | epoch_num = start_epoch_num - 1
100 | best_val_recall1 = 0
101 | logging.info(f"Resuming from epoch {start_epoch_num} with best seq R@5 {sequence_best_r5:.1f} from checkpoint {args.resume_train}")
102 | else:
103 | best_val_recall1 = start_epoch_num = sequence_best_r5 = 0
104 |
105 | #### Train / evaluation loop
106 | iterations_per_epoch = args.cached_queries // args.train_batch_size * args.num_sub_epochs
107 | logging.info("Start training ...")
108 | logging.info(f"There are {len(groups[0])} classes for the first group, " +
109 | f"each epoch has {iterations_per_epoch} iterations " +
110 | f"with batch_size {args.cp_batch_size}, therefore the model sees each class (on average) " +
111 | f"{iterations_per_epoch * args.cp_batch_size / len(groups[0]):.1f} times per epoch")
112 | logging.info(f"Backbone output channels are {model.features_dim}, features descriptor dim is {model.fc_output_dim}, "
113 | f"sequence descriptor dim is {model.aggregation_dim}")
114 |
115 | gpu_augmentation = T.Compose([
116 | utils.augmentations.DeviceAgnosticColorJitter(brightness=args.brightness, contrast=args.contrast,
117 | saturation=args.saturation, hue=args.hue),
118 | utils.augmentations.DeviceAgnosticRandomResizedCrop([512, 512],
119 | scale=[1-args.random_resized_crop, 1]),
120 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
121 | ])
122 |
123 | scaler = torch.cuda.amp.GradScaler()
124 | for epoch_num in range(start_epoch_num, args.epochs_num):
125 |
126 | epoch_start_time = datetime.now()
127 | # Select classifier and dataloader according to epoch
128 | current_group_num = epoch_num % args.groups_num
129 | classifiers[current_group_num] = classifiers[current_group_num].to(args.device)
130 | move_to_device(classifiers_optimizers[current_group_num], args.device)
131 | dataloader = InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers, drop_last=True,
132 | batch_size=args.cp_batch_size, shuffle=True, pin_memory=(args.device == "cuda"))
133 | dataloader_iterator = iter(dataloader)
134 | model = model.train()
135 |
136 | sequence_mean_loss = torchmetrics.MeanMetric()
137 | cosplace_mean_loss = torchmetrics.MeanMetric()
138 |
139 | seq_epoch_losses = np.zeros((0, 1), dtype=np.float32)
140 |
141 | for num_sub_epoch in range(args.num_sub_epochs):
142 | logging.debug(f"Cache: {num_sub_epoch + 1} / {args.num_sub_epochs}")
143 |
144 | # creates triplets on the smaller cache set
145 | triplets_ds.compute_triplets(model)
146 | triplets_dl = DataLoader(dataset=triplets_ds, num_workers=args.num_workers,
147 | batch_size=args.train_batch_size,
148 | collate_fn=collate_fn,
149 | pin_memory=(args.device == "cuda"),
150 | drop_last=True)
151 |
152 | model = model.train()
153 | tqdm_bar = tqdm(triplets_dl, ncols=100)
154 | for images, _, _ in tqdm_bar:
155 | model_optimizer.zero_grad()
156 | classifiers_optimizers[current_group_num].zero_grad()
157 |
158 | if args.lambda_triplet != 0:
159 | #### ITERATION ON SEQUENCES
160 | # images shape: (bsz, seq_len*(nNeg + 2), 3, H, W)
161 | # triplets_local_indexes shape: (bsz, nNeg+2) -> contains -1 for query, 1 for pos, 0 for neg
162 | # reshape images to only have 4-d
163 | images = images.view(-1, 3, *img_shape)
164 | # features : (bsz*(nNeg+2), model_output_size)
165 | with torch.cuda.amp.autocast():
166 | features = model(images.to(args.device))
167 | features = model.aggregate(features)
168 |
169 | # Compute loss by passing the triplets one by one
170 | sequence_loss = 0
171 | features = features.view(args.train_batch_size, -1, model.aggregation_dim)
172 | for b in range(args.train_batch_size):
173 | query = features[b:b + 1, 0] # size (1, output_dim)
174 | pos = features[b:b + 1, 1] # size (1, output_dim)
175 | negatives = features[b, 2:] # size (nNeg, output_dim)
176 |                     # negatives holds nNeg descriptors while query and pos hold one each;
177 |                     # TripletMarginLoss broadcasts, so this equals summing the loss over all negatives
178 | sequence_loss += criterion_triplet(query, pos, negatives)
179 | del images, features
180 | sequence_loss /= (args.train_batch_size * args.nNeg)
181 | sequence_loss *= args.lambda_triplet
182 | scaler.scale(sequence_loss).backward()
183 | sequence_mean_loss.update(sequence_loss.item())
184 | del sequence_loss
185 | else:
186 | sequence_mean_loss.update(-1)
187 |
188 | if args.lambda_im2im != 0:
189 | #### ITERATION ON COSPLACE
190 | images, targets, _ = next(dataloader_iterator)
191 | images, targets = images.to(args.device), targets.to(args.device)
192 |
193 | if args.augmentation_device == "cuda":
194 | images = gpu_augmentation(images)
195 |
196 | with torch.cuda.amp.autocast():
197 | descriptors = model(images)
198 | output = classifiers[current_group_num](descriptors, targets)
199 | cosplace_loss = criterion(output, targets)
200 | del output, images, descriptors, targets
201 | cosplace_loss *= args.lambda_im2im
202 | scaler.scale(cosplace_loss).backward()
203 | cosplace_mean_loss.update(cosplace_loss.item())
204 | del cosplace_loss
205 | else:
206 | cosplace_mean_loss.update(-1)
207 |
208 | scaler.step(model_optimizer)
209 | if args.lambda_im2im != 0:
210 | scaler.step(classifiers_optimizers[current_group_num])
211 | scaler.update()
212 | tqdm_bar.set_description(f"seq_loss: {sequence_mean_loss.compute():.4f} - cos_loss: {cosplace_mean_loss.compute():.2f}")
213 |
214 | logging.debug(f"Epoch[{epoch_num:02d}]({num_sub_epoch + 1}/{args.num_sub_epochs}): " +
215 | f"epoch sequence loss = {sequence_mean_loss.compute():.4f} - " +
216 | f"epoch cosplace loss = {cosplace_mean_loss.compute():.4f}")
217 |
218 | classifiers[current_group_num] = classifiers[current_group_num].cpu()
219 | move_to_device(classifiers_optimizers[current_group_num], "cpu")
220 | delete_model_gradients(model)
221 |
222 | #### Evaluation CosPlace
223 | cosplace_recalls, cosplace_recalls_str = evals.cosplace_test(args, cp_val_ds, model)
224 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, cosPlace {cp_val_ds}: {cosplace_recalls_str[:20]}")
225 | cosplace_is_best = cosplace_recalls[0] > best_val_recall1
226 |     best_val_recall1 = max(cosplace_recalls[0], best_val_recall1)  # keep the running best R@1
227 |
228 | #### Evaluation Sequence
229 | sequence_recalls, sequence_recalls_str = evals.test(args, val_ds, model)
230 | logging.info(f"Epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, sequence {val_ds}: {sequence_recalls_str}")
231 | sequence_is_best = sequence_recalls[1] > sequence_best_r5
232 |
233 | if sequence_is_best:
234 | logging.info(f"Improved: previous best R@5 = {sequence_best_r5:.1f}, current R@5 = {sequence_recalls[1]:.1f}")
235 | sequence_best_r5 = sequence_recalls[1]
236 | else:
237 | logging.info(f"Not improved: best R@5 = {sequence_best_r5:.1f}, current R@5 = {sequence_recalls[1]:.1f}")
238 |
239 | utils.cp_utils.save_checkpoint({
240 | "epoch_num": epoch_num + 1,
241 | "model_state_dict": model.state_dict(),
242 | "optimizer_state_dict": model_optimizer.state_dict(),
243 | "classifiers_state_dict": [c.state_dict() for c in classifiers],
244 | "optimizers_state_dict": [c.state_dict() for c in classifiers_optimizers],
245 | "best_seq_val_recall5": sequence_best_r5
246 | }, sequence_is_best, output_folder)
247 | recalls, recalls_str = evals.test(args, test_ds, model)
248 | logging.info(f"Recalls on test set: {recalls_str}")
249 |
250 | logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}")
251 | #### Test best model on test set
252 | best_model_state_dict = torch.load(join(output_folder, "best_model.pth"))
253 | model.load_state_dict(best_model_state_dict)
254 |
255 | recalls, recalls_str = evals.test(args, test_ds, model)
256 | logging.info(f"Recalls on test set: {recalls_str}")
257 |
--------------------------------------------------------------------------------