├── .gitignore ├── Colab_examples.ipynb ├── README.md ├── datasets ├── __init__.py ├── test_dataset.py └── train_dataset.py ├── download_datasets.py ├── eval.py ├── main.py ├── parser.py ├── requirements.txt ├── utils.py └── visualizations.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .idea 3 | .spyproject 4 | __pycache__ 5 | .ipynb_checkpoints 6 | LOGS 7 | *logs* -------------------------------------------------------------------------------- /Colab_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | }, 15 | "accelerator": "GPU", 16 | "gpuClass": "standard" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "## Setup datasets on Colab - only once" 23 | ], 24 | "metadata": { 25 | "id": "uJWTe9GjRHQJ" 26 | } 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "source": [ 31 | "The following code must be executed only the first time that you open Colab; it is needed to unzip the datasets that we provided into your GDrive account so that you can access them in the future\n", 32 | "\n", 33 | "NB: BEFORE running this notebook, you must copy the datasets zip into your GDrive. You can use the link that we provided and simply click 'create a copy'. Make sure that you have enough space (roughly 8 GBs)" 34 | ], 35 | "metadata": { 36 | "id": "NhknBP6cRLTm" 37 | } 38 | }, 39 | { 40 | "cell_type": "code", 41 | "source": [ 42 | "from google.colab import drive\n", 43 | "drive.mount('/content/drive')" 44 | ], 45 | "metadata": { 46 | "id": "TwmjD6mRcPtc", 47 | "colab": { 48 | "base_uri": "https://localhost:8080/" 49 | }, 50 | "outputId": "fbff158e-ed41-4590-8555-f5e0cbf271c8" 51 | }, 52 | "execution_count": 5, 53 | "outputs": [ 54 | { 55 | "output_type": "stream", 56 | "name": "stdout", 57 | "text": [ 58 | "/content/drive/MyDrive/datasets\n" 59 | ] 60 | } 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "source": [ 66 | "# move into your mounted GDrive filesystem\n", 67 | "%cd /content/drive/MyDrive/datasets\n", 68 | "\n", 69 | "# unzip the datasets to /content, so you don't risk going over the GDrive storage limit\n", 70 | "# this can take a few minutes\n", 71 | "!unzip -q gsv_xs.zip -d /content\n", 72 | "!unzip -q tokyo_xs.zip -d /content\n", 73 | "!unzip -q sf_xs.zip -d /content" 74 | ], 75 | "metadata": { 76 | "id": "6Hc56Wxkpcp7" 77 | }, 78 | "execution_count": 6, 79 | "outputs": [] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "source": [ 84 | "# now remove the zips from your GDrive\n", 85 | "%rm gsv_xs.zip tokyo_xs.zip sf_xs.zip\n", 86 | "# and move the unzipped datasets into your GDrive; in this way the next time\n", 87 | "# you open Colab you can directly access them\n", 88 | "!mv /content/*_xs /content/drive/MyDrive/datasets" 89 | ], 90 | "metadata": { 91 | "id": "PCpsbwtuQmtT" 92 | }, 93 | "execution_count": null, 94 | "outputs": [] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "source": [ 99 | "# now download the code. It is better if you fork this repo, so that you can push your\n", 100 | "# modifications. 
Inside the repo folder you will also find logs of the experiments\n", 101 | "%cd /content\n", 102 | "!git clone https://github.com/gmberton/Simple_VPR_codebase\n", 103 | "%cd /content/Simple_VPR_codebase" 104 | ], 105 | "metadata": { 106 | "id": "1Eb674sQRo_E" 107 | }, 108 | "execution_count": null, 109 | "outputs": [] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "source": [ 114 | "## Setup environment - every time" 115 | ], 116 | "metadata": { 117 | "id": "718c2aALRZzH" 118 | } 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "source": [ 123 | "Every time you restart Colab's kernel, you have to re-install packages and download anything that was not saved into your GDrive" 124 | ], 125 | "metadata": { 126 | "id": "YpOkk8I0RiPU" 127 | } 128 | }, 129 | { 130 | "cell_type": "code", 131 | "source": [ 132 | "from google.colab import drive\n", 133 | "drive.mount('/content/drive')" 134 | ], 135 | "metadata": { 136 | "id": "u9eZYfxzSE8d" 137 | }, 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "source": [ 144 | "%cd /content/Simple_VPR_codebase" 145 | ], 146 | "metadata": { 147 | "colab": { 148 | "base_uri": "https://localhost:8080/" 149 | }, 150 | "id": "xJJYyGeemmGm", 151 | "outputId": "1fe51b7b-9114-437e-f815-967f357cf359" 152 | }, 153 | "execution_count": 20, 154 | "outputs": [ 155 | { 156 | "output_type": "stream", 157 | "name": "stdout", 158 | "text": [ 159 | "/content/Simple_VPR_codebase\n" 160 | ] 161 | } 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "source": [ 167 | "!pip install -r requirements.txt" 168 | ], 169 | "metadata": { 170 | "colab": { 171 | "base_uri": "https://localhost:8080/" 172 | }, 173 | "id": "IwpebDBcDLrO", 174 | "outputId": "e3e97731-c248-41c0-e552-7e095125dfdd" 175 | }, 176 | "execution_count": 3, 177 | "outputs": [ 178 | { 179 | "output_type": "stream", 180 | "name": "stdout", 181 | "text": [ 182 | "imageio==2.25.1\n", 183 | "imageio-ffmpeg==0.4.8\n", 184 | "imagesize==1.4.1\n", 185 | "scikit-image==0.19.3\n" 186 | ] 187 | } 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "source": [ 193 | "# this is an example of a basic experiments. You can choose to validate on tokyo_xs or sf_xs\n", 194 | "\n", 195 | "!python main.py --train_path /path/to/datasets/gsv_xs --val_path /path/to/datasets/tokyo_xs/test --test_path /path/to/datasets/tokyo_xs/test --num_workers 2" 196 | ], 197 | "metadata": { 198 | "colab": { 199 | "base_uri": "https://localhost:8080/" 200 | }, 201 | "id": "zwlYsjznDLJS", 202 | "outputId": "7a87da62-b537-42e0-eaba-ec4e56c4136e" 203 | }, 204 | "execution_count": 10, 205 | "outputs": [ 206 | { 207 | "output_type": "stream", 208 | "name": "stdout", 209 | "text": [ 210 | "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", 211 | " warnings.warn(_create_warning_msg(\n", 212 | "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n", 213 | "100% 44.7M/44.7M [00:00<00:00, 116MB/s]\n", 214 | "Using 16bit None Automatic Mixed Precision (AMP)\n", 215 | "GPU available: True (cuda), used: True\n", 216 | "TPU available: False, using: 0 TPU cores\n", 217 | "IPU available: False, using: 0 IPUs\n", 218 | "HPU available: False, using: 0 HPUs\n", 219 | "Missing logger folder: LOGS/lightning_logs\n", 220 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 221 | "2023-03-28 12:27:39.447299: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n", 222 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 223 | "2023-03-28 12:27:43.216432: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/lib/python3.9/dist-packages/cv2/../../lib64:/usr/lib64-nvidia\n", 224 | "2023-03-28 12:27:43.216907: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/lib/python3.9/dist-packages/cv2/../../lib64:/usr/lib64-nvidia\n", 225 | "2023-03-28 12:27:43.216942: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", 226 | "Validation: 0it [00:00, ?it/s]/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", 227 | " warnings.warn(_create_warning_msg(\n", 228 | "Validation DataLoader 0: 100% 205/205 [00:41<00:00, 4.93it/s]R@1: 15.6, R@5: 35.6, R@10: 47.6, R@20: 57.8\n", 229 | "Validation DataLoader 0: 100% 205/205 [00:42<00:00, 4.87it/s]\n", 230 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 231 | "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 232 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 233 | "│\u001b[36m \u001b[0m\u001b[36m R@1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 15.555555555555555 \u001b[0m\u001b[35m \u001b[0m│\n", 234 | "│\u001b[36m \u001b[0m\u001b[36m R@5 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 35.55555555555556 \u001b[0m\u001b[35m \u001b[0m│\n", 235 | "└───────────────────────────┴───────────────────────────┘\n", 236 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 237 | "\n", 238 | " | Name | Type | Params\n", 239 | "--------------------------------------------\n", 240 | "0 | model | ResNet | 11.4 M\n", 241 | "1 | loss_fn | ContrastiveLoss | 0 \n", 242 | "--------------------------------------------\n", 243 | "11.4 M Trainable params\n", 244 | "0 Non-trainable params\n", 245 | "11.4 M Total params\n", 246 | "22.878 Total estimated model params size (MB)\n", 247 | "Epoch 0: 76% 900/1182 [17:18<05:25, 1.15s/it, loss=0.984, v_num=0]/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py:48: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", 248 | " rank_zero_warn(\"Detected KeyboardInterrupt, attempting graceful shutdown...\")\n", 249 | "Traceback (most recent call last):\n", 250 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py\", line 38, in _call_and_handle_interrupt\n", 251 | " return trainer_fn(*args, **kwargs)\n", 252 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py\", line 650, in _fit_impl\n", 253 | " self._run(model, ckpt_path=self.ckpt_path)\n", 254 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py\", line 1112, in _run\n", 255 | " results = self._run_stage()\n", 256 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py\", line 1191, in _run_stage\n", 257 | " self._run_train()\n", 258 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py\", line 1214, in _run_train\n", 259 | " self.fit_loop.run()\n", 260 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/loop.py\", line 199, in run\n", 261 | " self.advance(*args, **kwargs)\n", 262 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/fit_loop.py\", line 267, in advance\n", 263 | " self._outputs = self.epoch_loop.run(self._data_fetcher)\n", 264 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/loop.py\", line 199, in run\n", 265 | " self.advance(*args, **kwargs)\n", 266 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py\", line 187, in advance\n", 267 | " batch = next(data_fetcher)\n", 268 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/fetching.py\", line 184, in __next__\n", 269 | " return 
self.fetching_function()\n", 270 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/fetching.py\", line 265, in fetching_function\n", 271 | " self._fetch_next_batch(self.dataloader_iter)\n", 272 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/fetching.py\", line 280, in _fetch_next_batch\n", 273 | " batch = next(iterator)\n", 274 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/supporters.py\", line 569, in __next__\n", 275 | " return self.request_next_batch(self.loader_iters)\n", 276 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/supporters.py\", line 581, in request_next_batch\n", 277 | " return apply_to_collection(loader_iters, Iterator, next)\n", 278 | " File \"/usr/local/lib/python3.9/dist-packages/lightning_utilities/core/apply_func.py\", line 51, in apply_to_collection\n", 279 | " return function(data, *args, **kwargs)\n", 280 | " File \"/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py\", line 628, in __next__\n", 281 | " data = self._next_data()\n", 282 | " File \"/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py\", line 1316, in _next_data\n", 283 | " idx, data = self._get_data()\n", 284 | " File \"/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py\", line 1282, in _get_data\n", 285 | " success, data = self._try_get_data()\n", 286 | " File \"/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py\", line 1120, in _try_get_data\n", 287 | " data = self._data_queue.get(timeout=timeout)\n", 288 | " File \"/usr/lib/python3.9/multiprocessing/queues.py\", line 113, in get\n", 289 | " if not self._poll(timeout):\n", 290 | " File \"/usr/lib/python3.9/multiprocessing/connection.py\", line 257, in poll\n", 291 | " return self._poll(timeout)\n", 292 | " File \"/usr/lib/python3.9/multiprocessing/connection.py\", line 424, in _poll\n", 293 | " r = wait([self], timeout)\n", 294 | " File \"/usr/lib/python3.9/multiprocessing/connection.py\", line 931, in wait\n", 295 | " ready = selector.select(timeout)\n", 296 | " File \"/usr/lib/python3.9/selectors.py\", line 416, in select\n", 297 | " fd_event_list = self._selector.poll(timeout)\n", 298 | "KeyboardInterrupt\n", 299 | "\n", 300 | "During handling of the above exception, another exception occurred:\n", 301 | "\n", 302 | "Traceback (most recent call last):\n", 303 | " File \"/content/Simple_VPR_codebase/main.py\", line 140, in \n", 304 | " trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)\n", 305 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py\", line 608, in fit\n", 306 | " call._call_and_handle_interrupt(\n", 307 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py\", line 54, in _call_and_handle_interrupt\n", 308 | " logger.finalize(\"failed\")\n", 309 | " File \"/usr/local/lib/python3.9/dist-packages/lightning_utilities/core/rank_zero.py\", line 27, in wrapped_fn\n", 310 | " return fn(*args, **kwargs)\n", 311 | " File \"/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loggers/tensorboard.py\", line 218, in finalize\n", 312 | " super().finalize(status)\n", 313 | " File \"/usr/local/lib/python3.9/dist-packages/lightning_utilities/core/rank_zero.py\", line 27, in wrapped_fn\n", 314 | " return fn(*args, **kwargs)\n", 315 | " File \"/usr/local/lib/python3.9/dist-packages/lightning_fabric/loggers/tensorboard.py\", line 277, in finalize\n", 316 | " 
self.experiment.close()\n",
317 |     "  File \"/usr/local/lib/python3.9/dist-packages/torch/utils/tensorboard/writer.py\", line 1207, in close\n",
318 |     "    writer.close()\n",
319 |     "  File \"/usr/local/lib/python3.9/dist-packages/torch/utils/tensorboard/writer.py\", line 156, in close\n",
320 |     "    self.event_writer.close()\n",
321 |     "  File \"/usr/local/lib/python3.9/dist-packages/tensorboard/summary/writer/event_file_writer.py\", line 130, in close\n",
322 |     "    self._async_writer.close()\n",
323 |     "  File \"/usr/local/lib/python3.9/dist-packages/tensorboard/summary/writer/event_file_writer.py\", line 185, in close\n",
324 |     "    self._worker.stop()\n",
325 |     "  File \"/usr/local/lib/python3.9/dist-packages/tensorboard/summary/writer/event_file_writer.py\", line 214, in stop\n",
326 |     "    self.join()\n",
327 |     "  File \"/usr/lib/python3.9/threading.py\", line 1060, in join\n",
328 |     "    self._wait_for_tstate_lock()\n",
329 |     "  File \"/usr/lib/python3.9/threading.py\", line 1080, in _wait_for_tstate_lock\n",
330 |     "    if lock.acquire(block, timeout):\n",
331 |     "KeyboardInterrupt\n",
332 |     "Epoch 0:  76%|███████▌  | 900/1182 [17:38<05:31, 1.18s/it, loss=0.984, v_num=0]\n",
333 |     "^C\n"
334 |    ]
335 |   }
336 |  ]
337 | }
338 | ]
339 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simple_VPR_codebase
2 | 
3 | This repository serves as a starting point to implement a VPR pipeline. It allows you to train a simple
4 | ResNet-18 on the GSV dataset. It relies on the [pytorch_metric_learning](https://kevinmusgrave.github.io/pytorch-metric-learning/)
5 | library.
6 | 
7 | ## Download datasets
8 | NB: if you are using Colab, skip this section
9 | 
10 | The following script:
11 | 
12 | > python download_datasets.py
13 | 
14 | allows you to download GSV_xs, SF_xs, and tokyo_xs, which are reduced versions of the GSVCities, SF-XL, and Tokyo247 datasets, respectively.
15 | 
16 | ## Install dependencies
17 | NB: if you are using Colab, skip this section
18 | 
19 | You can install the required packages by running
20 | > pip install -r requirements.txt
21 | 
22 | 
23 | ## Run an experiment
24 | You can choose to validate/test on sf_xs or tokyo_xs.
25 | 
26 | 
27 | > python main.py --train_path /path/to/datasets/gsv_xs --val_path /path/to/datasets/tokyo_xs/test --test_path /path/to/datasets/tokyo_xs/test --exp_name expname
28 | 
29 | ## Resuming from checkpoint
30 | 
31 | The code will save the best (according to validation score) and last models. If your experiment dies and you want to resume from where you left off, you can simply re-run your experiment passing the argument `--checkpoint model_path`. You can find the model checkpoints under `logs/lightning_logs/exp_name/checkpoints`.
32 | 
33 | ## Logging
34 | 
35 | The code will log everything under the directory `logs/lightning_logs/exp_name`. You will find the models under `checkpoints`.
36 | All the textual outputs generated by the code are saved in 2 files, namely `logs/lightning_logs/exp_name/debug.log` and `logs/lightning_logs/exp_name/info.log`, where typically info.log contains only relevant info whereas debug.log is a superset of it and contains additional (typically useless) prints.
37 | If you want to add any prints to the code, you can do so using the functions `logging.debug` or `logging.info`.
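For instance, anywhere in the code you can write something like the following (a minimal sketch; the message strings are just placeholders):

```python
import logging

logging.info("Validation finished, computing recalls")  # goes to info.log and debug.log
logging.debug("This verbose message goes only to debug.log")
```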
38 | Finally, in this directory you will see some binary files generated by tensorboard, which you can use with the proper library.
39 | Once you install tensorboard via pip (check the documentation on how to do it), you can download the `logs` directory to your local machine and inspect the logs using `tensorboard --logdir logs/lightning_logs`. It will launch a webserver on localhost:6006, which you can inspect with your browser.
40 | 
41 | ## Running evaluations
42 | 
43 | Once you have trained your models, you can run an inference-only step using the `eval.py` script, passing the `--checkpoint` argument to specify the model checkpoint to load.
44 | 
45 | > python eval.py --checkpoint logs/lightning_logs/exp_name/checkpoints/_epoch\(01\)_R@1\[30.4762\]_R@5\[49.2063\].ckpt --train_path data/gsv_xs --val_path data/tokyo_xs/test --test_path data/tokyo_xs/test --exp_name test_model
47 | 
48 | ## Usage on Colab
49 | 
50 | We provide you with the notebook `Colab_examples.ipynb`.
51 | It shows you how to attach your GDrive file system to Colab, unzip the datasets, install packages and run your first experiment.
52 | 
53 | NB: BEFORE running this notebook, you must copy the datasets zip into your GDrive. You can use the [link](https://drive.google.com/drive/folders/1Ucy9JONT26EjDAjIJFhuL9qeLxgSZKmf?usp=sharing) that we provided and simply click 'create a copy'. Make sure that you have enough space (roughly 8 GB).
54 | 
55 | NB^2: you can ignore the dataset `robotcar_one_every_2m`.
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gmberton/Simple_VPR_codebase/fce38f1570d75896d8f7cc4cbb9a2c3cacfa7f6b/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/test_dataset.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import numpy as np
4 | from glob import glob
5 | from PIL import Image
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | from sklearn.neighbors import NearestNeighbors
9 | 
10 | 
11 | def open_image(path):
12 |     return Image.open(path).convert("RGB")
13 | 
14 | 
15 | class TestDataset(data.Dataset):
16 |     def __init__(self, dataset_folder, database_folder="database",
17 |                  queries_folder="queries", positive_dist_threshold=25):
18 |         """Dataset with images from database and queries, used for validation and test.
19 |         Parameters
20 |         ----------
21 |         dataset_folder : str, should contain the path to the val or test set,
22 |             which contains the folders {database_folder} and {queries_folder}.
23 |         database_folder : str, name of the folder with the database.
24 |         queries_folder : str, name of the folder with the queries.
25 |         positive_dist_threshold : int, distance in meters for a prediction to
26 |             be considered a positive.
27 | """ 28 | super().__init__() 29 | self.dataset_folder = dataset_folder 30 | self.database_folder = os.path.join(dataset_folder, database_folder) 31 | self.queries_folder = os.path.join(dataset_folder, queries_folder) 32 | self.dataset_name = os.path.basename(dataset_folder) 33 | 34 | if not os.path.exists(self.dataset_folder): 35 | raise FileNotFoundError(f"Folder {self.dataset_folder} does not exist") 36 | if not os.path.exists(self.database_folder): 37 | raise FileNotFoundError(f"Folder {self.database_folder} does not exist") 38 | if not os.path.exists(self.queries_folder): 39 | raise FileNotFoundError(f"Folder {self.queries_folder} does not exist") 40 | 41 | self.base_transform = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 44 | ]) 45 | 46 | #### Read paths and UTM coordinates for all images. 47 | self.database_paths = sorted(glob(os.path.join(self.database_folder, "**", "*.jpg"), recursive=True)) 48 | self.queries_paths = sorted(glob(os.path.join(self.queries_folder, "**", "*.jpg"), recursive=True)) 49 | if len(self.database_paths) == 0: 50 | raise FileNotFoundError(f"There are no images under {self.database_folder} , you should change this path") 51 | if len(self.queries_paths) == 0: 52 | raise FileNotFoundError(f"There are no images under {self.queries_paths} , you should change this path") 53 | # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg 54 | self.database_utms = np.array \ 55 | ([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float) 56 | self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype \ 57 | (float) 58 | 59 | # Find positives_per_query, which are within positive_dist_threshold (default 25 meters) 60 | knn = NearestNeighbors(n_jobs=-1) 61 | knn.fit(self.database_utms) 62 | self.positives_per_query = knn.radius_neighbors(self.queries_utms, 63 | radius=positive_dist_threshold, 64 | return_distance=False) 65 | 66 | self.images_paths = [p for p in self.database_paths] 67 | self.images_paths += [p for p in self.queries_paths] 68 | 69 | self.database_num = len(self.database_paths) 70 | self.queries_num = len(self.queries_paths) 71 | 72 | def __getitem__(self, index): 73 | image_path = self.images_paths[index] 74 | pil_img = open_image(image_path) 75 | normalized_img = self.base_transform(pil_img) 76 | return normalized_img, index 77 | 78 | def __len__(self): 79 | return len(self.images_paths) 80 | 81 | def __repr__(self): 82 | return f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >" 83 | 84 | def get_positives(self): 85 | return self.positives_per_query 86 | 87 | -------------------------------------------------------------------------------- /datasets/train_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as tfm 8 | from collections import defaultdict 9 | 10 | default_transform = tfm.Compose([ 11 | tfm.ToTensor(), 12 | tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 13 | ]) 14 | 15 | 16 | class TrainDataset(Dataset): 17 | def __init__( 18 | self, 19 | dataset_folder, 20 | img_per_place=4, 21 | min_img_per_place=4, 22 | transform=default_transform, 23 | ): 24 | super().__init__() 25 | self.dataset_folder = 
dataset_folder 26 | self.images_paths = sorted(glob(f"{dataset_folder}/**/*.jpg", recursive=True)) 27 | if len(self.images_paths) == 0: 28 | raise FileNotFoundError(f"There are no images under {dataset_folder} , you should change this path") 29 | self.dict_place_paths = defaultdict(list) 30 | for image_path in self.images_paths: 31 | place_id = image_path.split("@")[-2] 32 | self.dict_place_paths[place_id].append(image_path) 33 | 34 | assert img_per_place <= min_img_per_place, \ 35 | f"img_per_place should be less than {min_img_per_place}" 36 | self.img_per_place = img_per_place 37 | self.transform = transform 38 | 39 | # keep only places depicted by at least min_img_per_place images 40 | for place_id in list(self.dict_place_paths.keys()): 41 | all_paths_from_place_id = self.dict_place_paths[place_id] 42 | if len(all_paths_from_place_id) < min_img_per_place: 43 | del self.dict_place_paths[place_id] 44 | self.places_ids = sorted(list(self.dict_place_paths.keys())) 45 | self.total_num_images = sum([len(paths) for paths in self.dict_place_paths.values()]) 46 | 47 | def __getitem__(self, index): 48 | place_id = self.places_ids[index] 49 | all_paths_from_place_id = self.dict_place_paths[place_id] 50 | chosen_paths = np.random.choice(all_paths_from_place_id, self.img_per_place) 51 | images = [Image.open(path).convert('RGB') for path in chosen_paths] 52 | images = [self.transform(img) for img in images] 53 | return torch.stack(images), torch.tensor(index).repeat(self.img_per_place) 54 | 55 | def __len__(self): 56 | """Denotes the total number of places (not images)""" 57 | return len(self.places_ids) 58 | -------------------------------------------------------------------------------- /download_datasets.py: -------------------------------------------------------------------------------- 1 | 2 | URLS = { 3 | "tokyo_xs": "https://drive.google.com/file/d/15QB3VNKj93027UAQWv7pzFQO1JDCdZj2/view?usp=share_link", 4 | "sf_xs": "https://drive.google.com/file/d/1tQqEyt3go3vMh4fj_LZrRcahoTbzzH-y/view?usp=share_link", 5 | "gsv_xs": "https://drive.google.com/file/d/1q7usSe9_5xV5zTfN-1In4DlmF5ReyU_A/view?usp=share_link" 6 | } 7 | 8 | import os 9 | import gdown 10 | import shutil 11 | 12 | os.makedirs("data", exist_ok=True) 13 | for dataset_name, url in URLS.items(): 14 | print(f"Downloading {dataset_name}") 15 | zip_filepath = f"data/{dataset_name}.zip" 16 | gdown.download(url, zip_filepath, fuzzy=True) 17 | shutil.unpack_archive(zip_filepath, extract_dir="data") 18 | os.remove(zip_filepath) 19 | 20 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | from torch.utils.data.dataloader import DataLoader 4 | from pytorch_lightning import loggers as pl_loggers 5 | from os.path import join 6 | 7 | import utils 8 | import parser 9 | from datasets.test_dataset import TestDataset 10 | from main import LightningModel 11 | 12 | def get_datasets_and_dataloaders(args): 13 | val_dataset = TestDataset(dataset_folder=args.val_path) 14 | test_dataset = TestDataset(dataset_folder=args.test_path) 15 | val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False) 16 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False) 17 | return val_dataset, test_dataset, val_loader, test_loader 18 | 19 | 20 | if __name__ == '__main__': 21 | args = parser.parse_arguments() 22 | 
utils.setup_logging(join('logs', args.exp_name), console='info') 23 | 24 | val_dataset, test_dataset, val_loader, test_loader = get_datasets_and_dataloaders(args) 25 | model = LightningModel(val_dataset, test_dataset, args.descriptors_dim, args.num_preds_to_save, args.save_only_wrong_preds) 26 | 27 | tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/", version=args.exp_name) 28 | # Instantiate a trainer 29 | trainer = pl.Trainer( 30 | accelerator='gpu', 31 | devices=[0], 32 | default_root_dir='./logs', # Tensorflow can be used to viz 33 | num_sanity_val_steps=0, # runs a validation step before stating training 34 | precision=16, # we use half precision to reduce memory usage 35 | max_epochs=args.max_epochs, 36 | check_val_every_n_epoch=1, # run validation every epoch 37 | logger=tb_logger, # log through tensorboard 38 | reload_dataloaders_every_n_epochs=1, # we reload the dataset to shuffle the order 39 | log_every_n_steps=20, 40 | ) 41 | trainer.test(model=model, dataloaders=test_loader, ckpt_path=args.checkpoint) 42 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import torchvision.models 5 | import pytorch_lightning as pl 6 | from torchvision import transforms as tfm 7 | from pytorch_metric_learning import losses 8 | from torch.utils.data.dataloader import DataLoader 9 | from pytorch_lightning.callbacks import ModelCheckpoint 10 | from pytorch_lightning import loggers as pl_loggers 11 | import logging 12 | from os.path import join 13 | 14 | import utils 15 | import parser 16 | from datasets.test_dataset import TestDataset 17 | from datasets.train_dataset import TrainDataset 18 | 19 | 20 | class LightningModel(pl.LightningModule): 21 | def __init__(self, val_dataset, test_dataset, descriptors_dim=512, num_preds_to_save=0, save_only_wrong_preds=True): 22 | super().__init__() 23 | self.val_dataset = val_dataset 24 | self.test_dataset = test_dataset 25 | self.num_preds_to_save = num_preds_to_save 26 | self.save_only_wrong_preds = save_only_wrong_preds 27 | # Use a pretrained model 28 | self.model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT) 29 | # Change the output of the FC layer to the desired descriptors dimension 30 | self.model.fc = torch.nn.Linear(self.model.fc.in_features, descriptors_dim) 31 | # Set the loss function 32 | self.loss_fn = losses.ContrastiveLoss(pos_margin=0, neg_margin=1) 33 | 34 | def forward(self, images): 35 | descriptors = self.model(images) 36 | return descriptors 37 | 38 | def configure_optimizers(self): 39 | optimizers = torch.optim.SGD(self.parameters(), lr=0.001, weight_decay=0.001, momentum=0.9) 40 | return optimizers 41 | 42 | # The loss function call (this method will be called at each training iteration) 43 | def loss_function(self, descriptors, labels): 44 | loss = self.loss_fn(descriptors, labels) 45 | return loss 46 | 47 | # This is the training step that's executed at each iteration 48 | def training_step(self, batch, batch_idx): 49 | images, labels = batch 50 | num_places, num_images_per_place, C, H, W = images.shape 51 | images = images.view(num_places * num_images_per_place, C, H, W) 52 | labels = labels.view(num_places * num_images_per_place) 53 | 54 | # Feed forward the batch to the model 55 | descriptors = self(images) # Here we are calling the method forward that we defined above 56 | loss = self.loss_function(descriptors, labels) 
# Call the loss_function we defined above 57 | 58 | self.log('loss', loss.item(), logger=True) 59 | return {'loss': loss} 60 | 61 | # For validation and test, we iterate step by step over the validation set 62 | def inference_step(self, batch): 63 | images, _ = batch 64 | descriptors = self(images) 65 | return descriptors.cpu().numpy().astype(np.float32) 66 | 67 | def validation_step(self, batch, batch_idx): 68 | return self.inference_step(batch) 69 | 70 | def test_step(self, batch, batch_idx): 71 | return self.inference_step(batch) 72 | 73 | def validation_epoch_end(self, all_descriptors): 74 | return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val') 75 | 76 | def test_epoch_end(self, all_descriptors): 77 | return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test', self.num_preds_to_save) 78 | 79 | def inference_epoch_end(self, all_descriptors, inference_dataset, split, num_preds_to_save=0): 80 | """all_descriptors contains database then queries descriptors""" 81 | all_descriptors = np.concatenate(all_descriptors) 82 | queries_descriptors = all_descriptors[inference_dataset.database_num : ] 83 | database_descriptors = all_descriptors[ : inference_dataset.database_num] 84 | 85 | recalls, recalls_str = utils.compute_recalls( 86 | inference_dataset, queries_descriptors, database_descriptors, 87 | self.logger.log_dir, num_preds_to_save, self.save_only_wrong_preds 88 | ) 89 | # print(recalls_str) 90 | logging.info(f"Epoch[{self.current_epoch:02d}]): " + 91 | f"recalls: {recalls_str}") 92 | 93 | self.log(f'{split}/R@1', recalls[0], prog_bar=False, logger=True) 94 | self.log(f'{split}/R@5', recalls[1], prog_bar=False, logger=True) 95 | 96 | def get_datasets_and_dataloaders(args): 97 | train_transform = tfm.Compose([ 98 | tfm.RandAugment(num_ops=3), 99 | tfm.ToTensor(), 100 | tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 101 | ]) 102 | train_dataset = TrainDataset( 103 | dataset_folder=args.train_path, 104 | img_per_place=args.img_per_place, 105 | min_img_per_place=args.min_img_per_place, 106 | transform=train_transform 107 | ) 108 | val_dataset = TestDataset(dataset_folder=args.val_path) 109 | test_dataset = TestDataset(dataset_folder=args.test_path) 110 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) 111 | val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False) 112 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False) 113 | return train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader 114 | 115 | 116 | if __name__ == '__main__': 117 | args = parser.parse_arguments() 118 | utils.setup_logging(join('logs', 'lightning_logs', args.exp_name), console='info') 119 | 120 | train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader = get_datasets_and_dataloaders(args) 121 | model = LightningModel(val_dataset, test_dataset, args.descriptors_dim, args.num_preds_to_save, args.save_only_wrong_preds) 122 | 123 | # Model params saving using Pytorch Lightning. 
Save the best model according to Recall@1, and also the last one
124 |     checkpoint_cb = ModelCheckpoint(
125 |         monitor='val/R@1',
126 |         filename='_epoch({epoch:02d})_R@1[{val/R@1:.4f}]_R@5[{val/R@5:.4f}]',
127 |         auto_insert_metric_name=False,
128 |         save_weights_only=False,
129 |         save_top_k=1,
130 |         save_last=True,
131 |         mode='max'
132 |     )
133 | 
134 |     tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/", version=args.exp_name)
135 | 
136 |     # Instantiate a trainer
137 |     trainer = pl.Trainer(
138 |         accelerator='gpu',
139 |         devices=[0],
140 |         default_root_dir='./logs',  # the logs can be visualized with TensorBoard
141 |         num_sanity_val_steps=0,  # 0 skips the sanity validation run before training starts
142 |         precision=16,  # we use half precision to reduce memory usage
143 |         max_epochs=args.max_epochs,
144 |         check_val_every_n_epoch=1,  # run validation every epoch
145 |         logger=tb_logger,  # log through tensorboard
146 |         callbacks=[checkpoint_cb],  # we only run the checkpointing callback (you can add more)
147 |         reload_dataloaders_every_n_epochs=1,  # we reload the dataset to shuffle the order
148 |         log_every_n_steps=20,
149 |     )
150 |     trainer.validate(model=model, dataloaders=val_loader, ckpt_path=args.checkpoint)
151 |     trainer.fit(model=model, ckpt_path=args.checkpoint, train_dataloaders=train_loader, val_dataloaders=val_loader)
152 |     trainer.test(model=model, dataloaders=test_loader, ckpt_path='best')
153 | 
154 | 
--------------------------------------------------------------------------------
/parser.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | 
4 | 
5 | def parse_arguments():
6 |     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7 |     # experiment
8 |     parser.add_argument("--exp_name", type=str, default="default",
9 |                         help="exp name")
10 |     parser.add_argument("--checkpoint", type=str, default=None,
11 |                         help="checkpoint path")
12 | 
13 |     # Training parameters
14 |     parser.add_argument("--batch_size", type=int, default=64,
15 |                         help="The number of places to use per iteration (one place is N images)")
16 |     parser.add_argument("--img_per_place", type=int, default=4,
17 |                         help="The effective batch size is (batch_size * img_per_place)")
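    # For example, with the defaults each training iteration loads
    # batch_size=64 places x img_per_place=4 images = 256 images,
    # i.e. the loss sees an effective batch size of 256.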
18 |     parser.add_argument("--min_img_per_place", type=int, default=4,
19 |                         help="places with fewer than min_img_per_place images are removed")
20 |     parser.add_argument("--max_epochs", type=int, default=20,
21 |                         help="stop when training reaches max_epochs")
22 |     parser.add_argument("--num_workers", type=int, default=8,
23 |                         help="number of processes to use for data loading / preprocessing")
24 | 
25 |     # Architecture parameters
26 |     parser.add_argument("--descriptors_dim", type=int, default=512,
27 |                         help="dimensionality of the output descriptors")
28 | 
29 |     # Visualizations parameters
30 |     parser.add_argument("--num_preds_to_save", type=int, default=0,
31 |                         help="At the end of training, save N preds for each query. "
32 |                              "Try with a small number like 3")
33 |     parser.add_argument("--save_only_wrong_preds", action="store_true",
34 |                         help="When saving preds (if num_preds_to_save != 0) save only "
35 |                              "preds for difficult queries, i.e. with incorrect first prediction")
36 | 
37 |     # Paths parameters
38 |     parser.add_argument("--train_path", type=str, default="data/gsv_xs/train",
39 |                         help="path to train set")
40 |     parser.add_argument("--val_path", type=str, default="data/sf_xs/val",
41 |                         help="path to val set (must contain database and queries)")
42 |     parser.add_argument("--test_path", type=str, default="data/sf_xs/test",
43 |                         help="path to test set (must contain database and queries)")
44 | 
45 |     args = parser.parse_args()
46 |     return args
47 | 
48 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu116
2 | 
3 | torch==1.13.1+cu116
4 | torchvision==0.14.1+cu116
5 | faiss-cpu==1.7.3
6 | pytorch-lightning==1.9.4
7 | pytorch-metric-learning==2.0.1
8 | opencv-python==4.7.0.72
9 | scikit-image==0.19.3
10 | googledrivedownloader==0.4
11 | gdown==4.7.1
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import faiss
3 | import logging
4 | import numpy as np
5 | from typing import Tuple
6 | from torch.utils.data import Dataset
7 | import os
8 | from os.path import join
9 | import sys
10 | import traceback
11 | 
12 | import visualizations
13 | 
14 | 
15 | # Compute R@1, R@5, R@10, R@20
16 | RECALL_VALUES = [1, 5, 10, 20]
17 | 
18 | def compute_recalls(eval_ds: Dataset, queries_descriptors : np.ndarray, database_descriptors : np.ndarray,
19 |                     output_folder : str = None, num_preds_to_save : int = 0,
20 |                     save_only_wrong_preds : bool = True) -> Tuple[np.ndarray, str]:
21 |     """Compute the recalls given the queries and database descriptors. The dataset is needed to know the ground truth
22 |     positives for each query."""
23 | 
24 |     # Use a kNN to find predictions
25 |     faiss_index = faiss.IndexFlatL2(queries_descriptors.shape[1])
26 |     faiss_index.add(database_descriptors)
27 |     del database_descriptors
28 | 
29 |     logging.debug("Calculating recalls")
30 |     _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))
31 | 
32 |     #### For each query, check if the predictions are correct
33 |     positives_per_query = eval_ds.get_positives()
34 |     recalls = np.zeros(len(RECALL_VALUES))
35 |     for query_index, preds in enumerate(predictions):
36 |         for i, n in enumerate(RECALL_VALUES):
37 |             if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
38 |                 recalls[i:] += 1
39 |                 break
40 |     # Divide by queries_num and multiply by 100, so the recalls are in percentages
41 |     recalls = recalls / eval_ds.queries_num * 100
42 |     recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])
43 | 
44 |     # Save visualizations of predictions
45 |     if num_preds_to_save != 0:
46 |         # For each query save num_preds_to_save predictions
47 |         visualizations.save_preds(predictions[:, :num_preds_to_save], eval_ds, output_folder, save_only_wrong_preds)
48 | 
49 |     return recalls, recalls_str
50 | 
51 | 
52 | def setup_logging(save_dir, console="debug",
53 |                   info_filename="info.log", debug_filename="debug.log"):
54 |     """Set up logging files and console output.
55 |     Creates one file for INFO logs and one for DEBUG logs.
56 |     Args:
57 |         save_dir (str): the folder where the log files are saved (it is created by this function).
58 |         console (str):
59 |             if == "debug" prints on console debug messages and higher
60 |             if == "info" prints on console info messages and higher
61 |             if == None does not use console (useful when a logger has already been set)
62 |         info_filename (str): the name of the info file. if None, don't create info file
63 |         debug_filename (str): the name of the debug file. if None, don't create debug file
64 |     """
65 |     if os.path.exists(save_dir):
66 |         raise FileExistsError(f"{save_dir} already exists!")
67 |     os.makedirs(save_dir, exist_ok=True)
68 |     # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
69 |     base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
70 |     logger = logging.getLogger('')
71 |     logger.setLevel(logging.DEBUG)
72 | 
73 |     if info_filename is not None:
74 |         info_file_handler = logging.FileHandler(join(save_dir, info_filename))
75 |         info_file_handler.setLevel(logging.INFO)
76 |         info_file_handler.setFormatter(base_formatter)
77 |         logger.addHandler(info_file_handler)
78 | 
79 |     if debug_filename is not None:
80 |         debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))
81 |         debug_file_handler.setLevel(logging.DEBUG)
82 |         debug_file_handler.setFormatter(base_formatter)
83 |         logger.addHandler(debug_file_handler)
84 | 
85 |     if console is not None:
86 |         console_handler = logging.StreamHandler()
87 |         if console == "debug":
88 |             console_handler.setLevel(logging.DEBUG)
89 |         if console == "info":
90 |             console_handler.setLevel(logging.INFO)
91 |         console_handler.setFormatter(base_formatter)
92 |         logger.addHandler(console_handler)
93 | 
94 |     def exception_handler(type_, value, tb):
95 |         logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
96 |     sys.excepthook = exception_handler
97 | 
--------------------------------------------------------------------------------
/visualizations.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import cv2
4 | import numpy as np
5 | from tqdm import tqdm
6 | from skimage.transform import rescale
7 | from PIL import Image, ImageDraw, ImageFont
8 | 
9 | 
10 | # Height and width of a single image
11 | H = 512
12 | W = 512
13 | TEXT_H = 175
14 | FONTSIZE = 80
15 | SPACE = 50  # Space between two images
16 | 
17 | 
18 | def write_labels_to_image(labels=["text1", "text2"]):
19 |     """Creates an image with the given text labels, spaced horizontally one per image column."""
20 |     font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", FONTSIZE)
21 |     img = Image.new('RGB', ((W * len(labels)) + SPACE * (len(labels) - 1), TEXT_H), (1, 1, 1))
22 |     d = ImageDraw.Draw(img)
23 |     for i, text in enumerate(labels):
24 |         _, _, w, h = d.textbbox((0, 0), text, font=font)
25 |         d.text(((W + SPACE) * i + W // 2 - w // 2, 1), text, fill=(0, 0, 0), font=font)
26 |     return np.array(img)
27 | 
28 | 
29 | def draw(img, c=(0, 255, 0), thickness=20):
30 |     """Draw a colored (usually red or green) box around an image."""
31 |     p = np.array([[0, 0], [0, img.shape[0]], [img.shape[1], img.shape[0]], [img.shape[1], 0]])
32 |     for i in range(3):
33 |         cv2.line(img, (p[i, 0], p[i, 1]), (p[i+1, 0], p[i+1, 1]), c, thickness=thickness*2)
34 |     return cv2.line(img, (p[3, 0], p[3, 1]), (p[0, 0], p[0, 1]), c, thickness=thickness*2)
35 | 
36 | 
37 | def build_prediction_image(images_paths, preds_correct=None):
38 |     """Build a row of images, where the first is the query and the rest are predictions.
39 |     For each image, if is_correct then draw a green/red box.
40 | """ 41 | assert len(images_paths) == len(preds_correct) 42 | labels = ["Query"] + [f"Pr{i} - {is_correct}" for i, is_correct in enumerate(preds_correct[1:])] 43 | num_images = len(images_paths) 44 | images = [np.array(Image.open(path)) for path in images_paths] 45 | for img, correct in zip(images, preds_correct): 46 | if correct is None: 47 | continue 48 | color = (0, 255, 0) if correct else (255, 0, 0) 49 | draw(img, color) 50 | concat_image = np.ones([H, (num_images*W)+((num_images-1)*SPACE), 3]) 51 | rescaleds = [rescale(i, [min(H/i.shape[0], W/i.shape[1]), min(H/i.shape[0], W/i.shape[1]), 1]) for i in images] 52 | for i, image in enumerate(rescaleds): 53 | pad_width = (W - image.shape[1] + 1) // 2 54 | pad_height = (H - image.shape[0] + 1) // 2 55 | image = np.pad(image, [[pad_height, pad_height], [pad_width, pad_width], [0, 0]], constant_values=1)[:H, :W] 56 | concat_image[: , i*(W+SPACE) : i*(W+SPACE)+W] = image 57 | labels_image = write_labels_to_image(labels) 58 | final_image = np.concatenate([labels_image, concat_image]) 59 | final_image = Image.fromarray((final_image*255).astype(np.uint8)) 60 | return final_image 61 | 62 | 63 | def save_file_with_paths(query_path, preds_paths, positives_paths, output_path): 64 | file_content = [] 65 | file_content.append("Query path:") 66 | file_content.append(query_path + "\n") 67 | file_content.append("Predictions paths:") 68 | file_content.append("\n".join(preds_paths) + "\n") 69 | file_content.append("Positives paths:") 70 | file_content.append("\n".join(positives_paths) + "\n") 71 | with open(output_path, "w") as file: 72 | _ = file.write("\n".join(file_content)) 73 | 74 | 75 | def save_preds(predictions, eval_ds, output_folder, save_only_wrong_preds=None): 76 | """For each query, save an image containing the query and its predictions, 77 | and a file with the paths of the query, its predictions and its positives. 78 | 79 | Parameters 80 | ---------- 81 | predictions : np.array of shape [num_queries x num_preds_to_viz], with the preds 82 | for each query 83 | eval_ds : TestDataset 84 | output_folder : str / Path with the path to save the predictions 85 | save_only_wrong_preds : bool, if True save only the wrongly predicted queries, 86 | i.e. 
the ones where the first pred is incorrect (further than 25 m)
87 |     """
88 |     positives_per_query = eval_ds.get_positives()
89 |     os.makedirs(f"{output_folder}/preds", exist_ok=True)
90 |     for query_index, preds in enumerate(tqdm(predictions, ncols=80, desc=f"Saving preds in {output_folder}")):
91 |         query_path = eval_ds.queries_paths[query_index]
92 |         list_of_images_paths = [query_path]
93 |         # List of None (query), True (correct preds) or False (wrong preds)
94 |         preds_correct = [None]
95 |         for pred_index, pred in enumerate(preds):
96 |             pred_path = eval_ds.database_paths[pred]
97 |             list_of_images_paths.append(pred_path)
98 |             is_correct = pred in positives_per_query[query_index]
99 |             preds_correct.append(is_correct)
100 | 
101 |         if save_only_wrong_preds and preds_correct[1]:
102 |             continue
103 | 
104 |         prediction_image = build_prediction_image(list_of_images_paths, preds_correct)
105 |         pred_image_path = f"{output_folder}/preds/{query_index:03d}.jpg"
106 |         prediction_image.save(pred_image_path)
107 | 
108 |         positives_paths = [eval_ds.database_paths[idx] for idx in positives_per_query[query_index]]
109 |         save_file_with_paths(
110 |             query_path=list_of_images_paths[0],
111 |             preds_paths=list_of_images_paths[1:],
112 |             positives_paths=positives_paths,
113 |             output_path=f"{output_folder}/preds/{query_index:03d}.txt"
114 |         )
115 | 
116 | 
--------------------------------------------------------------------------------
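A minimal end-to-end usage sketch of the evaluation utilities above (illustrative only: the dataset path, the random 512-dim descriptors and the `LOGS/example` output folder are placeholder assumptions, not part of the codebase; in the real pipeline the descriptors come from the trained model):

```python
import numpy as np

import utils
from datasets.test_dataset import TestDataset

test_set = TestDataset(dataset_folder="data/tokyo_xs/test")
# faiss expects float32 arrays; database descriptors first, then queries.
database_descriptors = np.random.randn(test_set.database_num, 512).astype(np.float32)
queries_descriptors = np.random.randn(test_set.queries_num, 512).astype(np.float32)

recalls, recalls_str = utils.compute_recalls(
    test_set, queries_descriptors, database_descriptors,
    output_folder="LOGS/example", num_preds_to_save=3, save_only_wrong_preds=True)
print(recalls_str)  # e.g. "R@1: 15.6, R@5: 35.6, R@10: 47.6, R@20: 57.8"
```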