├── .gitignore ├── LICENSE ├── README.md ├── dataloaders ├── CliqueMapillaryDataset.py ├── GSVCitiesDataloader.py ├── GSVCitiesDataset.py ├── MapillaryDataset.py ├── PittsburgDataset.py └── val │ ├── MapillaryDataset.py │ ├── NordlandDataset.py │ ├── PittsburghDataset.py │ └── SPEDDataset.py ├── datasets ├── Nordland │ ├── Nordland_dbImages.npy │ ├── Nordland_gt.npy │ └── Nordland_qImages.npy ├── Pittsburgh │ ├── pitts250k_test_dbImages.npy │ ├── pitts250k_test_gt.npy │ ├── pitts250k_test_qImages.npy │ ├── pitts30k_test_dbImages.npy │ ├── pitts30k_test_gt.npy │ ├── pitts30k_test_qImages.npy │ ├── pitts30k_val_dbImages.npy │ ├── pitts30k_val_gt.npy │ └── pitts30k_val_qImages.npy ├── SPED │ ├── SPED_dbImages.npy │ ├── SPED_gt.npy │ └── SPED_qImages.npy └── msls_val │ ├── msls_val_dbImages.npy │ ├── msls_val_pIdx.npy │ ├── msls_val_qIdx.npy │ └── msls_val_qImages.npy ├── environment.yml ├── eval.py ├── hubconf.py ├── main.py ├── models ├── __init__.py ├── aggregators │ ├── __init__.py │ ├── convap.py │ ├── cosplace.py │ ├── gem.py │ ├── mixvpr.py │ └── salad.py ├── backbones │ ├── __init__.py │ ├── dinov2.py │ └── resnet.py └── helper.py ├── utils ├── __init__.py ├── losses.py └── validation.py └── vpr_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights.
Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. 
Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. 
This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 
221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. 
If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. 
Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. 
If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 
633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Close, But Not There: Boosting Geographic Distance Sensitivity in Visual Place Recognition 2 | Sergio Izquierdo, Javier Civera 3 | 4 | Code and models for the ECCV 2024 paper "Close, But Not There: Boosting Geographic Distance Sensitivity in Visual Place Recognition" (CliqueMining) 5 | 6 | ## Summary 7 | 8 | In this repo, we include CliqueMining, a novel mining pipeline that builds a graph of highly similar images and samples cliques from it (each clique representing a distinct place) to form very challenging training batches. This technique improves performance on many common VPR benchmarks. 9 | 10 | For more details, check the [paper](https://arxiv.org/abs/2407.02422). 11 | 12 | ## Weights 13 | 14 | You can download the weights of the trained model [here](https://drive.google.com/file/d/1B06ysb-Wjb4KDcNrl-7pyj1mJve1jqdk/view?usp=sharing). To evaluate, follow the same steps as with [SALAD](https://github.com/serizba/salad).
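
To illustrate the core idea, here is a minimal sketch of the geographic clique-sampling step (the function name, toy coordinates, and defaults are illustrative, not the exact interface; see `dataloaders/CliqueMapillaryDataset.py` for the actual pipeline):

```python
import numpy as np
import networkx as nx
from scipy.spatial.distance import pdist, squareform

def sample_place_clique(utm_coords, num_images_per_place=4, same_place_threshold=20.0):
    """Sample one place: a clique of images whose UTM positions all lie
    within same_place_threshold meters of each other."""
    # Connect two images if they were taken closer than the threshold
    adjacency = squareform(pdist(utm_coords)) < same_place_threshold
    np.fill_diagonal(adjacency, False)  # drop self-loops
    graph = nx.Graph(adjacency)

    # Keep the first clique large enough to represent a place
    for clique in nx.find_cliques(graph):
        if len(clique) >= num_images_per_place:
            return np.random.choice(clique, num_images_per_place, replace=False)
    return None  # no clique of the requested size in this pool

# Toy example: five images within a few meters of each other, one far away
coords = np.array([[0., 0.], [3., 4.], [5., 5.], [8., 2.], [6., 9.], [500., 500.]])
print(sample_place_clique(coords))  # e.g. [1 4 0 2]
```

Batches are assembled from such cliques drawn around visually similar places, forcing the model to separate images that look alike but are geographically distinct.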
15 | 16 | -------------------------------------------------------------------------------- /dataloaders/CliqueMapillaryDataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | from PIL import Image, ImageFile, UnidentifiedImageError 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as T 8 | import numpy as np 9 | import tqdm 10 | 11 | import concurrent.futures 12 | from scipy.spatial.distance import cdist, pdist, squareform 13 | import networkx 14 | 15 | default_transform = T.Compose([ 16 | T.ToTensor(), 17 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 18 | ]) 19 | 20 | # NOTE: Hard coded path to dataset folder 21 | BASE_PATH = '../data/msls/train_val/' 22 | 23 | if not Path(BASE_PATH).exists(): 24 | raise FileNotFoundError( 25 | 'BASE_PATH is hardcoded, please adjust it to point to the MSLS train_val folder') 26 | 27 | def load_city_df(base_path): 28 | # Load cities 29 | city_df = {} 30 | for city in (Path(base_path)).iterdir(): 31 | 32 | # Database 33 | db = pd.read_csv(city / 'database' / 'postprocessed.csv') 34 | db = db.join( 35 | pd.read_csv(city / 'database' / 'raw.csv')[['pano', 'key']].set_index('key'), 36 | on='key' 37 | ) 38 | db.insert(0, 'query', False) 39 | 40 | # Query 41 | q = pd.read_csv(city / 'query' / 'postprocessed.csv') 42 | q = q.join( 43 | pd.read_csv(city / 'query' / 'raw.csv')[['pano', 'key']].set_index('key'), 44 | on='key' 45 | ) 46 | q.insert(0, 'query', True) 47 | 48 | df = pd.concat([db, q]) 49 | 50 | # Remove where pano is True 51 | df = df[df['pano'] == False] 52 | 53 | city_df[city.name] = df 54 | 55 | return city_df 56 | 57 | def compute_cluster_descriptors(city_df, model, descriptor_size=8192 + 256, batch_size=64): 58 | 59 | class MSLSDataset(torch.utils.data.Dataset): 60 | def __init__(self, rows, city_path): 61 | self.rows = rows 62 | self.city_path = city_path 63 | 64 | self.valid_transform = T.Compose([ 65 | T.Resize((322, 322), interpolation=T.InterpolationMode.BILINEAR), 66 | T.ToTensor(), 67 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 68 | ]) 69 | 70 | def __len__(self): 71 | return len(self.rows) 72 | 73 | def __getitem__(self, idx): 74 | row = self.rows.iloc[idx] 75 | path = Path(BASE_PATH) / self.city_path / ('query' if row['query'] else 'database') / 'images' / f'{row["key"]}.jpg' 76 | try: 77 | img = Image.open(path) 78 | except (UnidentifiedImageError, OSError): 79 | print(f'Image {path} could not be loaded') 80 | img = Image.new('RGB', (322, 322)) 81 | img = self.valid_transform(img) 82 | return img, row['unique_cluster'] 83 | 84 | 85 | cluster_descriptors_dict = {} 86 | for city, df in tqdm.tqdm(city_df.items(), desc='Computing cluster descriptors'): 87 | 88 | # Create dataloader with one sample per cluster 89 | msls = MSLSDataset(df.groupby('unique_cluster').sample(1), city) 90 | dataloader = torch.utils.data.DataLoader( 91 | dataset=msls, 92 | batch_size=batch_size, 93 | num_workers=8, 94 | drop_last=False, 95 | pin_memory=True, 96 | shuffle=False 97 | ) 98 | 99 | cluster_descriptors = torch.zeros((df.unique_cluster.max() + 1, descriptor_size)).cuda() 100 | 101 | # Compute descriptors for each cluster 102 | with torch.no_grad(): 103 | for batch in dataloader: 104 | img, clusters = batch 105 | img = img.cuda() 106 | descriptors = model(img) 107 | cluster_descriptors[clusters] = descriptors 108 | 109 | cluster_descriptors_dict[city] = cluster_descriptors.cpu().numpy()
110 | 111 | return cluster_descriptors_dict 112 | 113 | 114 | def create_dataset_part( 115 | cluster_descriptors_dict, 116 | city_df, 117 | num_batches=100, 118 | batch_size=60, 119 | num_images_per_place=4, 120 | sampled_similar_places=15, 121 | same_place_threshold=20.0, 122 | ): 123 | 124 | import os 125 | import time 126 | np.random.seed((os.getpid() * int(time.time())) % 123456789)  # re-seed per worker process so parallel workers draw different batches 127 | 128 | images = np.zeros((num_batches, batch_size, num_images_per_place), dtype=object) 129 | 130 | for i in tqdm.tqdm(range(num_batches)): 131 | 132 | cities_this_batch = [] 133 | 134 | batch_idx = 0 135 | while batch_idx < batch_size: 136 | 137 | cities_to_sample = [c for c in cluster_descriptors_dict.keys()] 138 | num_clusters = np.array([d.shape[0] for c, d in cluster_descriptors_dict.items()]) 139 | 140 | city = np.random.choice(cities_to_sample, p=num_clusters/num_clusters.sum()) 141 | 142 | # Don't sample a city already used in this batch 143 | while city in cities_this_batch: 144 | city = np.random.choice(cities_to_sample, p=num_clusters/num_clusters.sum()) 145 | cities_this_batch.append(city) 146 | 147 | 148 | df = city_df[city] 149 | descriptor = cluster_descriptors_dict[city] 150 | 151 | # Sample a random cluster 152 | place_id = np.random.choice(df.unique_cluster.unique()) 153 | 154 | # Compute similarity between the selected cluster and all the others 155 | distances = cdist(descriptor[place_id, None, :], descriptor)[0] 156 | # Normalize distances as probabilities (where min distance is max probability) 157 | distances[distances != 0] = distances.max() - distances[distances != 0] 158 | distances = distances / distances.sum() 159 | 160 | # Sample similar places 161 | other_places = np.random.choice(np.arange(df.unique_cluster.max() + 1), size=sampled_similar_places, p=distances, replace=False) 162 | other_places = np.concatenate([np.array([place_id]), other_places]) 163 | 164 | df = df[df['unique_cluster'].isin(other_places)] 165 | 166 | # Create adjacency matrix from UTM coordinates (two places are connected if they are closer than same_place_threshold) 167 | utms = squareform(pdist(df[['easting', 'northing']].values)) < same_place_threshold 168 | 169 | while batch_idx < batch_size: 170 | 171 | # Find a clique of at least num_images_per_place 172 | for c in networkx.find_cliques(networkx.Graph(utms)): 173 | if len(c) >= num_images_per_place: 174 | clique = np.random.choice(c, num_images_per_place, replace=False) 175 | break 176 | else: 177 | break 178 | 179 | neighbors = np.unique(np.where(utms[clique, :])[1]) 180 | 181 | # Append place to batch 182 | rows = df.iloc[list(clique)] 183 | images[i, batch_idx] = np.char.add(np.char.add(np.where(rows['query'].values, f'{city}/query/images/', f'{city}/database/images/').astype(' (0000013 and 0500013) 72 | # We suppose that there is no city with more than 73 | # 99999 images and there won't be more than 99 cities 74 | # TODO: rename the dataset and hardcode these prefixes 75 | prefix = i 76 | tmp_df['place_id'] = tmp_df['place_id'] + (prefix * 10**5) 77 | tmp_df = tmp_df.sample(frac=1) # shuffle the city dataframe 78 | 79 | df = pd.concat([df, tmp_df], ignore_index=True) 80 | 81 | # keep only places depicted by at least min_img_per_place images 82 | res = df[df.groupby('place_id')['place_id'].transform( 83 | 'size') >= self.min_img_per_place] 84 | return res.set_index('place_id') 85 | 86 | def __getitem__(self, index): 87 | place_id = self.places_ids[index] 88 | 89 | # get the place in form of a dataframe (each row corresponds to one image) 90 | place = 
self.dataframe.loc[place_id] 91 | 92 | # sample K images (rows) from this place 93 | # we can either sort and take the most recent K images 94 | # or randomly sample them 95 | if self.random_sample_from_each_place: 96 | place = place.sample(n=self.img_per_place) 97 | else: # always get the same most recent images 98 | place = place.sort_values( 99 | by=['year', 'month', 'lat'], ascending=False) 100 | place = place[: self.img_per_place] 101 | 102 | imgs = [] 103 | for i, row in place.iterrows(): 104 | img_name = self.get_img_name(row) 105 | img_path = self.base_path + 'Images/' + \ 106 | row['city_id'] + '/' + img_name 107 | img = self.image_loader(img_path) 108 | 109 | if self.transform is not None: 110 | img = self.transform(img) 111 | 112 | imgs.append(img) 113 | 114 | # NOTE: contrary to image classification where __getitem__ returns only one image 115 | # in GSVCities, we return a place, which is a Tensor of K images (K=self.img_per_place) 116 | # this will return a Tensor of shape [K, channels, height, width]. This needs to be taken into account 117 | # in the Dataloader (which will yield batches of shape [BS, K, channels, height, width]) 118 | return torch.stack(imgs), torch.tensor(place_id).repeat(self.img_per_place) 119 | 120 | def __len__(self): 121 | '''Denotes the total number of places (not images)''' 122 | return len(self.places_ids) 123 | 124 | @staticmethod 125 | def image_loader(path): 126 | try: 127 | return Image.open(path).convert('RGB') 128 | except UnidentifiedImageError: 129 | print(f'Image {path} could not be loaded') 130 | return Image.new('RGB', (224, 224)) 131 | 132 | @staticmethod 133 | def get_img_name(row): 134 | # given a row from the dataframe 135 | # return the corresponding image name 136 | 137 | city = row['city_id'] 138 | 139 | # now remove the two digits we added to the id 140 | # they were artificially added to make the ids different 141 | # across cities 142 | pl_id = row.name % 10**5 #row.name is the index of the row, not to be confused with image name 143 | pl_id = str(pl_id).zfill(7) 144 | 145 | panoid = row['panoid'] 146 | year = str(row['year']).zfill(4) 147 | month = str(row['month']).zfill(2) 148 | northdeg = str(row['northdeg']).zfill(3) 149 | lat, lon = str(row['lat']), str(row['lon']) 150 | name = city+'_'+pl_id+'_'+year+'_'+month+'_' + \ 151 | northdeg+'_'+lat+'_'+lon+'_'+panoid+'.jpg' 152 | return name 153 | -------------------------------------------------------------------------------- /dataloaders/MapillaryDataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | # NOTE: you need to download the mapillary_sls dataset from https://github.com/FrederikWarburg/mapillary_sls 7 | # make sure the path where the mapillary_sls validation dataset resides on your computer is correct. 8 | # the folder named train_val should reside in DATASET_ROOT path (that's the only folder you need from mapillary_sls) 9 | # I hardcoded the groundtruth for image-to-image evaluation, otherwise it would take ages to run the groundtruth script at each epoch.
10 | DATASET_ROOT = '../data/mapillary/' 11 | 12 | path_obj = Path(DATASET_ROOT) 13 | if not path_obj.exists(): 14 | raise Exception('Please make sure the path to mapillary_sls dataset is correct') 15 | 16 | if not path_obj.joinpath('train_val').exists(): 17 | raise Exception(f'Please make sure the directory train_val from mapillary_sls dataset is situated in the directory {DATASET_ROOT}') 18 | 19 | class MSLS(Dataset): 20 | def __init__(self, input_transform = None): 21 | 22 | self.input_transform = input_transform 23 | 24 | # hard coded reference image names, this avoids the hassle of listing them at each epoch. 25 | self.dbImages = np.load('./datasets/msls_val/msls_val_dbImages.npy') 26 | 27 | # hard coded query image names. 28 | self.qImages = np.load('./datasets/msls_val/msls_val_qImages.npy') 29 | 30 | # hard coded index of query images 31 | self.qIdx = np.load('./datasets/msls_val/msls_val_qIdx.npy') 32 | 33 | # hard coded groundtruth (correspondence between each query and its matches) 34 | self.pIdx = np.load('./datasets/msls_val/msls_val_pIdx.npy', allow_pickle=True) 35 | 36 | # concatenate reference images then query images so that we can use only one dataloader 37 | self.images = np.concatenate((self.dbImages, self.qImages[self.qIdx])) 38 | 39 | # we need to keep the number of references so that we can split references-queries 40 | # when calculating recall@K 41 | self.num_references = len(self.dbImages) 42 | 43 | def __getitem__(self, index): 44 | img = Image.open(DATASET_ROOT+self.images[index]) 45 | 46 | if self.input_transform: 47 | img = self.input_transform(img) 48 | 49 | return img, index 50 | 51 | def __len__(self): 52 | return len(self.images) -------------------------------------------------------------------------------- /dataloaders/PittsburgDataset.py: -------------------------------------------------------------------------------- 1 | from os.path import join, exists 2 | from collections import namedtuple 3 | from scipy.io import loadmat 4 | 5 | import torchvision.transforms as T 6 | import torch.utils.data as data 7 | 8 | 9 | from PIL import Image, UnidentifiedImageError 10 | from sklearn.neighbors import NearestNeighbors 11 | 12 | root_dir = '../data/Pittsburgh/' 13 | 14 | if not exists(root_dir): 15 | raise FileNotFoundError( 16 | 'root_dir is hardcoded, please adjust to point to Pittsburgh dataset') 17 | 18 | struct_dir = join(root_dir, 'datasets/') 19 | queries_dir = join(root_dir, 'queries_real') 20 | 21 | 22 | def input_transform(image_size=None): 23 | return T.Compose([ 24 | T.Resize(image_size), # interpolation=T.InterpolationMode.BICUBIC 25 | T.ToTensor(), 26 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 27 | ]) 28 | 29 | 30 | 31 | def get_whole_val_set(input_transform): 32 | structFile = join(struct_dir, 'pitts30k_val.mat') 33 | return WholeDatasetFromStruct(structFile, input_transform=input_transform) 34 | 35 | 36 | def get_250k_val_set(input_transform): 37 | structFile = join(struct_dir, 'pitts250k_val.mat') 38 | return WholeDatasetFromStruct(structFile, input_transform=input_transform) 39 | 40 | 41 | def get_whole_test_set(input_transform): 42 | structFile = join(struct_dir, 'pitts30k_test.mat') 43 | return WholeDatasetFromStruct(structFile, input_transform=input_transform) 44 | 45 | 46 | def get_250k_test_set(input_transform): 47 | structFile = join(struct_dir, 'pitts250k_test.mat') 48 | return WholeDatasetFromStruct(structFile, input_transform=input_transform) 49 | 50 | def get_whole_training_set(onlyDB=False): 51 | structFile = 
join(struct_dir, 'pitts30k_train.mat') 52 | return WholeDatasetFromStruct(structFile, 53 | input_transform=input_transform(), 54 | onlyDB=onlyDB) 55 | 56 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 57 | 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 58 | 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr']) 59 | 60 | 61 | def parse_dbStruct(path): 62 | mat = loadmat(path) 63 | matStruct = mat['dbStruct'].item() 64 | 65 | if '250k' in path.split('/')[-1]: 66 | dataset = 'pitts250k' 67 | else: 68 | dataset = 'pitts30k' 69 | 70 | whichSet = matStruct[0].item() 71 | 72 | dbImage = [f[0].item() for f in matStruct[1]] 73 | utmDb = matStruct[2].T 74 | 75 | qImage = [f[0].item() for f in matStruct[3]] 76 | utmQ = matStruct[4].T 77 | 78 | numDb = matStruct[5].item() 79 | numQ = matStruct[6].item() 80 | 81 | posDistThr = matStruct[7].item() 82 | posDistSqThr = matStruct[8].item() 83 | nonTrivPosDistSqThr = matStruct[9].item() 84 | 85 | return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, 86 | utmQ, numDb, numQ, posDistThr, 87 | posDistSqThr, nonTrivPosDistSqThr) 88 | 89 | 90 | class WholeDatasetFromStruct(data.Dataset): 91 | def __init__(self, structFile, input_transform=None, onlyDB=False): 92 | super().__init__() 93 | 94 | self.input_transform = input_transform 95 | 96 | self.dbStruct = parse_dbStruct(structFile) 97 | self.images = [join(root_dir, dbIm) for dbIm in self.dbStruct.dbImage] 98 | if not onlyDB: 99 | self.images += [join(queries_dir, qIm) 100 | for qIm in self.dbStruct.qImage] 101 | 102 | self.whichSet = self.dbStruct.whichSet 103 | self.dataset = self.dbStruct.dataset 104 | 105 | self.positives = None 106 | self.distances = None 107 | 108 | def __getitem__(self, index): 109 | 110 | try: 111 | img = Image.open(self.images[index]) 112 | except UnidentifiedImageError: 113 | print(f'Image {self.images[index]} could not be loaded') 114 | img = Image.new('RGB', (224, 224)) 115 | 116 | if self.input_transform: 117 | img = self.input_transform(img) 118 | 119 | return img, index 120 | 121 | def __len__(self): 122 | return len(self.images) 123 | 124 | def getPositives(self): 125 | # positives for evaluation are those within trivial threshold range 126 | # fit NN to find them, search by radius 127 | if self.positives is None: 128 | knn = NearestNeighbors(n_jobs=-1) 129 | knn.fit(self.dbStruct.utmDb) 130 | 131 | self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, 132 | radius=self.dbStruct.posDistThr) 133 | 134 | return self.positives 135 | -------------------------------------------------------------------------------- /dataloaders/val/MapillaryDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | DATASET_ROOT = '../data/mapillary/' 7 | GT_ROOT = './datasets/' # BECAREFUL, this is the ground truth that comes with GSV-Cities 8 | 9 | class MSLS(Dataset): 10 | def __init__(self, input_transform = None): 11 | 12 | 13 | self.input_transform = input_transform 14 | 15 | self.dbImages = np.load(GT_ROOT+'msls_val/msls_val_dbImages.npy') 16 | self.qIdx = np.load(GT_ROOT+'msls_val/msls_val_qIdx.npy') 17 | self.qImages = np.load(GT_ROOT+'msls_val/msls_val_qImages.npy') 18 | self.ground_truth = np.load(GT_ROOT+'msls_val/msls_val_pIdx.npy', allow_pickle=True) 19 | 20 | # reference images then query images 21 | self.images = np.concatenate((self.dbImages, self.qImages[self.qIdx])) 22 | self.num_references = 
23 |         self.num_queries = len(self.qImages[self.qIdx])
24 | 
25 |     def __getitem__(self, index):
26 |         img = Image.open(DATASET_ROOT + self.images[index])
27 | 
28 |         if self.input_transform:
29 |             img = self.input_transform(img)
30 | 
31 |         return img, index
32 | 
33 |     def __len__(self):
34 |         return len(self.images)
--------------------------------------------------------------------------------
/dataloaders/val/NordlandDataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from PIL import Image, ImageFile
4 | ImageFile.LOAD_TRUNCATED_IMAGES = True
5 | from torch.utils.data import Dataset
6 | 
7 | # NOTE: you need to download the Nordland dataset from https://surfdrive.surf.nl/files/index.php/s/sbZRXzYe3l0v67W
8 | # this link is shared and maintained by the authors of VPR_Bench: https://github.com/MubarizZaffar/VPR-Bench
9 | # the folders named ref and query should reside in the DATASET_ROOT path
10 | # I hardcoded the image names and ground truth for faster evaluation
11 | # performance is exactly the same as if you use VPR-Bench.
12 | 
13 | DATASET_ROOT = '../data/Nordland/'
14 | GT_ROOT = './datasets/'  # BE CAREFUL, this is the ground truth that comes with GSV-Cities
15 | 
16 | path_obj = Path(DATASET_ROOT)
17 | if not path_obj.exists():
18 |     raise Exception(f'Please make sure the path {DATASET_ROOT} to the Nordland dataset is correct')
19 | 
20 | if not path_obj.joinpath('ref').exists() or not path_obj.joinpath('query').exists():
21 |     raise Exception(f'Please make sure the directories query and ref are situated in the directory {DATASET_ROOT}')
22 | 
23 | class NordlandDataset(Dataset):
24 |     def __init__(self, input_transform = None):
25 | 
26 | 
27 |         self.input_transform = input_transform
28 | 
29 |         # reference image names
30 |         self.dbImages = np.load(GT_ROOT+'Nordland/Nordland_dbImages.npy')
31 | 
32 |         # query image names
33 |         self.qImages = np.load(GT_ROOT+'Nordland/Nordland_qImages.npy')
34 | 
35 |         # ground truth
36 |         self.ground_truth = np.load(GT_ROOT+'Nordland/Nordland_gt.npy', allow_pickle=True)
37 | 
38 |         # reference images then query images
39 |         self.images = np.concatenate((self.dbImages, self.qImages))
40 | 
41 |         self.num_references = len(self.dbImages)
42 |         self.num_queries = len(self.qImages)
43 | 
44 | 
45 |     def __getitem__(self, index):
46 |         img = Image.open(DATASET_ROOT+self.images[index])
47 | 
48 |         if self.input_transform:
49 |             img = self.input_transform(img)
50 | 
51 |         return img, index
52 | 
53 |     def __len__(self):
54 |         return len(self.images)
--------------------------------------------------------------------------------
/dataloaders/val/PittsburghDataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | 
6 | # NOTE: you need to download the Pittsburgh dataset (pitts30k / pitts250k)
7 | # and point DATASET_ROOT below to it.
8 | # I hardcoded the image names and ground truth for faster evaluation
9 | 
10 | 
11 | DATASET_ROOT = '../data/Pittsburgh/'
12 | GT_ROOT = './datasets/'  # BE CAREFUL, this is the ground truth that comes with GSV-Cities
13 | 
14 | path_obj = Path(DATASET_ROOT)
15 | if not path_obj.exists():
16 |     raise Exception(f'Please make sure the path {DATASET_ROOT} to the Pittsburgh dataset is correct')
17 | 
18 | 
19 | 
20 | 
21 | class PittsburghDataset(Dataset):
22 |     def __init__(self, which_ds='pitts30k_test', input_transform = None):
23 | 
24 |         assert which_ds.lower() in ['pitts30k_val', 'pitts30k_test', 'pitts250k_test']
25 | 
26 |         self.input_transform = input_transform
27 | 
28 |         # reference image names
29 |         self.dbImages = np.load(GT_ROOT+f'Pittsburgh/{which_ds}_dbImages.npy')
30 | 
31 |         # query image names
32 |         self.qImages = np.load(GT_ROOT+f'Pittsburgh/{which_ds}_qImages.npy')
33 | 
34 |         # ground truth
35 |         self.ground_truth = np.load(GT_ROOT+f'Pittsburgh/{which_ds}_gt.npy', allow_pickle=True)
36 | 
37 |         # reference images then query images
38 |         self.images = np.concatenate((self.dbImages, self.qImages))
39 | 
40 |         self.num_references = len(self.dbImages)
41 |         self.num_queries = len(self.qImages)
42 | 
43 | 
44 |     def __getitem__(self, index):
45 |         img = Image.open(DATASET_ROOT+self.images[index])
46 | 
47 |         if self.input_transform:
48 |             img = self.input_transform(img)
49 | 
50 |         return img, index
51 | 
52 |     def __len__(self):
53 |         return len(self.images)
--------------------------------------------------------------------------------
/dataloaders/val/SPEDDataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | 
6 | # NOTE: you need to download the SPED dataset from https://surfdrive.surf.nl/files/index.php/s/sbZRXzYe3l0v67W
7 | # this link is shared and maintained by the authors of VPR_Bench: https://github.com/MubarizZaffar/VPR-Bench
8 | # the folders named ref and query should reside in the DATASET_ROOT path
9 | # I hardcoded the image names and ground truth for faster evaluation
10 | # performance is exactly the same as if you use VPR-Bench.
11 | 
12 | DATASET_ROOT = '../data/SPEDTEST/'
13 | GT_ROOT = './datasets/'  # BE CAREFUL, this is the ground truth that comes with GSV-Cities
14 | 
15 | path_obj = Path(DATASET_ROOT)
16 | if not path_obj.exists():
17 |     raise Exception(f'Please make sure the path {DATASET_ROOT} to the SPED dataset is correct')
18 | 
19 | if not path_obj.joinpath('ref').exists() or not path_obj.joinpath('query').exists():
20 |     raise Exception(f'Please make sure the directories query and ref are situated in the directory {DATASET_ROOT}')
21 | 
22 | class SPEDDataset(Dataset):
23 |     def __init__(self, input_transform = None):
24 | 
25 | 
26 |         self.input_transform = input_transform
27 | 
28 |         # reference image names
29 |         self.dbImages = np.load(GT_ROOT+'SPED/SPED_dbImages.npy')
30 | 
31 |         # query image names
32 |         self.qImages = np.load(GT_ROOT+'SPED/SPED_qImages.npy')
33 | 
34 |         # ground truth
35 |         self.ground_truth = np.load(GT_ROOT+'SPED/SPED_gt.npy', allow_pickle=True)
36 | 
37 |         # reference images then query images
38 |         self.images = np.concatenate((self.dbImages, self.qImages))
39 | 
40 |         self.num_references = len(self.dbImages)
41 |         self.num_queries = len(self.qImages)
42 | 
43 | 
44 |     def __getitem__(self, index):
45 |         img = Image.open(DATASET_ROOT+self.images[index])
46 | 
47 |         if self.input_transform:
48 |             img = self.input_transform(img)
49 | 
50 |         return img, index
51 | 
52 |     def __len__(self):
53 |         return len(self.images)
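The datasets/ entries that follow are the hardcoded ground-truth files consumed by the loaders above. As a rough sanity-check sketch (the layout is inferred from how the val datasets use these arrays): dbImages/qImages hold image paths relative to each DATASET_ROOT, and each gt[i] holds the reference indices that match query i:

import numpy as np

dbImages = np.load('./datasets/SPED/SPED_dbImages.npy')
qImages = np.load('./datasets/SPED/SPED_qImages.npy')
gt = np.load('./datasets/SPED/SPED_gt.npy', allow_pickle=True)

print(len(dbImages), 'references,', len(qImages), 'queries')
print(dbImages[0])   # a path relative to DATASET_ROOT
print(gt[0])         # indices into dbImages that are correct matches for query 0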
--------------------------------------------------------------------------------
/datasets/Nordland/Nordland_dbImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Nordland/Nordland_dbImages.npy
--------------------------------------------------------------------------------
/datasets/Nordland/Nordland_gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Nordland/Nordland_gt.npy
--------------------------------------------------------------------------------
/datasets/Nordland/Nordland_qImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Nordland/Nordland_qImages.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts250k_test_dbImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts250k_test_dbImages.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts250k_test_gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts250k_test_gt.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts250k_test_qImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts250k_test_qImages.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts30k_test_dbImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts30k_test_dbImages.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts30k_test_gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts30k_test_gt.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts30k_test_qImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts30k_test_qImages.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts30k_val_dbImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts30k_val_dbImages.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts30k_val_gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts30k_val_gt.npy
--------------------------------------------------------------------------------
/datasets/Pittsburgh/pitts30k_val_qImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/Pittsburgh/pitts30k_val_qImages.npy
--------------------------------------------------------------------------------
/datasets/SPED/SPED_dbImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/SPED/SPED_dbImages.npy
--------------------------------------------------------------------------------
/datasets/SPED/SPED_gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/SPED/SPED_gt.npy
--------------------------------------------------------------------------------
/datasets/SPED/SPED_qImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/SPED/SPED_qImages.npy
--------------------------------------------------------------------------------
/datasets/msls_val/msls_val_dbImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/msls_val/msls_val_dbImages.npy
--------------------------------------------------------------------------------
/datasets/msls_val/msls_val_pIdx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/msls_val/msls_val_pIdx.npy
--------------------------------------------------------------------------------
/datasets/msls_val/msls_val_qIdx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/msls_val/msls_val_qIdx.npy
--------------------------------------------------------------------------------
/datasets/msls_val/msls_val_qImages.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/datasets/msls_val/msls_val_qImages.npy
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: salad
2 | channels:
3 |   - defaults
4 |   - xformers
5 |   - pytorch
6 |   - nvidia
7 | dependencies:
8 |   - python=3.10
9 |   - pytorch::pytorch==2.1.0
10 |   - pytorch::pytorch-cuda=12.1
11 |   - pytorch::torchvision==0.16.0
12 |   - xformers
13 |   - pip
14 |   - pip:
15 |     - faiss-gpu==1.7.2
16 |     - pandas==2.1.3
17 |     - prettytable==3.9.0
18 |     - pytorch-lightning==2.1.2
19 |     - pytorch-metric-learning==2.3.0
20 |     - torchmetrics==1.2.0
21 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import torchvision.transforms as T
4 | from tqdm import tqdm
5 | import argparse
6 | 
7 | 
8 | from vpr_model import VPRModel
9 | from utils.validation import get_validation_recalls
10 | # Dataloaders
11 | from dataloaders.val.NordlandDataset import NordlandDataset
12 | from dataloaders.val.MapillaryDataset import MSLS
13 | from dataloaders.val.PittsburghDataset import PittsburghDataset
14 | from dataloaders.val.SPEDDataset import SPEDDataset
15 | 
16 | VAL_DATASETS = ['MSLS', 'pitts30k_test', 'pitts250k_test', 'Nordland', 'SPED']
17 | 
18 | 
19 | def input_transform(image_size=None):
20 |     MEAN = [0.485, 0.456, 0.406]; STD = [0.229, 0.224, 0.225]
21 |     if image_size:
22 |         return T.Compose([
23 |             T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
24 |             T.ToTensor(),
25 |             T.Normalize(mean=MEAN, std=STD)
26 |         ])
27 |     else:
28 |         return T.Compose([
29 |             T.ToTensor(),
30 |             T.Normalize(mean=MEAN, std=STD)
31 |         ])
32 | 
33 | def get_val_dataset(dataset_name, image_size=None):
34 |     dataset_name = dataset_name.lower()
35 |     transform = input_transform(image_size=image_size)
36 | 
37 |     if 'nordland' in dataset_name:
38 |         ds = NordlandDataset(input_transform=transform)
39 | 
40 |     # elif 'msls_test' in dataset_name:
41 |     #     ds = MSLSTest(input_transform=transform)  # MSLSTest is not included in this repository
42 | 
43 |     elif 'msls' in dataset_name:
44 |         ds = MSLS(input_transform=transform)
45 | 
46 |     elif 'pitts' in dataset_name:
47 |         ds = PittsburghDataset(which_ds=dataset_name, input_transform=transform)
48 | 
49 |     elif 'sped' in dataset_name:
50 |         ds = SPEDDataset(input_transform=transform)
51 |     else:
52 |         raise ValueError(f'Unknown dataset name: {dataset_name}')
53 | 
54 |     num_references = ds.num_references
55 |     num_queries = ds.num_queries
56 |     ground_truth = ds.ground_truth
57 |     return ds, num_references, num_queries, ground_truth
58 | 
59 | def get_descriptors(model, dataloader, device):
60 |     descriptors = []
61 |     with torch.no_grad():
62 |         with torch.autocast(device_type='cuda', dtype=torch.float16):
63 |             for batch in tqdm(dataloader, 'Calculating descriptors...'):
64 |                 imgs, labels = batch
65 |                 output = model(imgs.to(device)).cpu()
66 |                 descriptors.append(output)
67 | 
68 |     return torch.cat(descriptors)
69 | 
70 | def load_model(ckpt_path):
71 |     model = VPRModel(
72 |         backbone_arch='dinov2_vitb14',
73 |         backbone_config={
74 |             'num_trainable_blocks': 4,
75 |             'return_token': True,
76 |             'norm_layer': True,
77 |         },
78 |         agg_arch='SALAD',
79 |         agg_config={
80 |             'num_channels': 768,
81 |             'num_clusters': 64,
82 |             'cluster_dim': 128,
83 |             'token_dim': 256,
84 |         },
85 |     )
86 | 
87 |     model.load_state_dict(torch.load(ckpt_path)['state_dict'])
88 |     model = model.eval()
89 |     model = model.to('cuda')
90 |     print(f'Loaded model from {ckpt_path} successfully!')
91 |     return model
92 | 
93 | def parse_args():
94 |     parser = argparse.ArgumentParser(
95 |         description="Eval VPR model",
96 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter
97 |     )
98 |     # Model parameters
99 |     parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint")
100 | 
101 |     # Dataset parameters
102 |     parser.add_argument(
103 |         '--val_datasets',
104 |         nargs='+',
105 |         default=VAL_DATASETS,
106 |         help='Validation datasets to use',
107 |         choices=VAL_DATASETS,
108 |     )
109 |     parser.add_argument('--image_size', nargs='*', default=None, help='Image size (int, tuple or None)')
110 |     parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
111 | 
112 |     args = parser.parse_args()
113 | 
114 |     # Parse the image size
115 |     if args.image_size:
116 |         if len(args.image_size) == 1:
117 |             args.image_size = (args.image_size[0], args.image_size[0])
118 |         elif len(args.image_size) == 2:
119 |             args.image_size = tuple(args.image_size)
120 |         else:
121 |             raise ValueError('Invalid image size, must be int, tuple or None')
122 | 
123 |         args.image_size = tuple(map(int, args.image_size))
124 | 
125 |     return args
126 | 
127 | 
128 | if __name__ == '__main__':
129 | 
130 |     torch.backends.cudnn.benchmark = True
131 | 
132 |     args = parse_args()
133 | 
134 |     model = load_model(args.ckpt_path)
135 | 
136 |     for val_name in args.val_datasets:
137 |         val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_name, args.image_size)
138 |         val_loader = DataLoader(val_dataset, num_workers=16, batch_size=args.batch_size, shuffle=False, pin_memory=True)
139 | 
140 |         print(f'Evaluating on {val_name}')
141 |         descriptors = get_descriptors(model, val_loader, 'cuda')
142 | 
143 |         print(f'Descriptor dimension {descriptors.shape[1]}')
144 |         r_list = descriptors[ : num_references]
145 |         q_list = descriptors[num_references : ]
146 | 
147 |         print('total_size', descriptors.shape[0], num_queries + num_references)
148 | 
149 |         preds = get_validation_recalls(
150 |             r_list=r_list,
151 |             q_list=q_list,
152 |             k_values=[1, 5, 10, 15, 20, 25],
153 |             gt=ground_truth,
154 |             print_results=True,
155 |             dataset_name=val_name,
156 |             faiss_gpu=False,
157 |             testing=False,
158 |         )
159 | 
160 |         del descriptors
161 |         print('========> DONE!\n\n')
162 | 
163 | 
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch']
2 | 
3 | import torch
4 | from vpr_model import VPRModel
5 | from models.backbones.dinov2 import DINOV2_ARCHS
6 | 
7 | 
8 | def dinov2_salad(
9 |     backbone: str = "dinov2_vitb14",
10 |     pretrained=True,
11 |     backbone_args=None,
12 |     agg_args=None,
13 | ) -> torch.nn.Module:
14 |     """Return a DINOv2 SALAD model.
15 | 
16 |     Args:
17 |         backbone (str): DINOv2 encoder to use ('dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14').
18 |         pretrained (bool): If True, returns a model pre-trained on GSV-Cities (only available for 'dinov2_vitb14').
19 |         backbone_args (dict): Arguments for the backbone (check models.backbones.dinov2).
20 |         agg_args (dict): Arguments for the aggregation module (check models.aggregators.salad).
21 |     Returns:
22 |         model (torch.nn.Module): the model.
23 |     """
24 |     assert backbone in DINOV2_ARCHS.keys(), f"Parameter `backbone` is set to {backbone} but it must be one of {list(DINOV2_ARCHS.keys())}"
25 |     assert not pretrained or backbone == "dinov2_vitb14", f"Parameter `pretrained` can only be set to True if backbone is 'dinov2_vitb14', but it is set to {backbone}"
26 | 
27 | 
28 |     backbone_args = backbone_args or {
29 |         'num_trainable_blocks': 4,
30 |         'return_token': True,
31 |         'norm_layer': True,
32 |     }
33 |     agg_args = agg_args or {
34 |         'num_channels': DINOV2_ARCHS[backbone],
35 |         'num_clusters': 64,
36 |         'cluster_dim': 128,
37 |         'token_dim': 256,
38 |     }
39 |     model = VPRModel(
40 |         backbone_arch=backbone,
41 |         backbone_config=backbone_args,
42 |         agg_arch='SALAD',
43 |         agg_config=agg_args,
44 |     )
45 |     if pretrained:
46 |         model.load_state_dict(
47 |             torch.hub.load_state_dict_from_url(
48 |                 'https://github.com/serizba/salad/releases/download/v1.0.0/dino_salad.ckpt',
49 |                 map_location=torch.device('cpu')
50 |             )
51 |         )
52 |     return model
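With hubconf.py in place, the model can be pulled through the torch.hub API. A minimal sketch (run from the repository root; downloads the DINOv2 backbone and the released dino_salad.ckpt on first use):

import torch

model = torch.hub.load('.', 'dinov2_salad', source='local').eval()

# input height/width must be divisible by 14 (the DINOv2 patch size)
with torch.no_grad():
    desc = model(torch.randn(1, 3, 224, 224))
print(desc.shape)  # [1, 8448] with the defaults above: 64 clusters * 128 dims + a 256-dim token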
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | 
3 | from vpr_model import VPRModel
4 | from dataloaders.GSVCitiesDataloader import GSVCitiesDataModule
5 | 
6 | if __name__ == '__main__':
7 |     datamodule = GSVCitiesDataModule(
8 |         batch_size=30,
9 |         img_per_place=4,
10 |         min_img_per_place=4,
11 |         shuffle_all=False,  # shuffle all images, or keep shuffling in-city only
12 |         random_sample_from_each_place=True,
13 |         image_size=(224, 224),
14 |         num_workers=10,
15 |         show_data_stats=True,
16 |         val_set_names=['pitts30k_val', 'pitts30k_test', 'msls_val'],
17 |         clique_mapillary_args={
18 |             'same_place_threshold': 25.0,
19 |             # We create more batches than required so
20 |             # that we can shuffle the dataset after each epoch
21 |             'num_batches': 4000,
22 |             'num_processes': 10,
23 |         }
24 |     )
25 | 
26 |     model = VPRModel(
27 |         #---- Encoder
28 |         backbone_arch='dinov2_vitb14',
29 |         backbone_config={
30 |             'num_trainable_blocks': 4,
31 |             'return_token': True,
32 |             'norm_layer': True,
33 |         },
34 |         agg_arch='SALAD',
35 |         agg_config={
36 |             'num_channels': 768,
37 |             'num_clusters': 64,
38 |             'cluster_dim': 128,
39 |             'token_dim': 256,
40 |         },
41 |         lr=6e-5,
42 |         optimizer='adamw',
43 |         weight_decay=9.5e-9,  # 0.001 for sgd, 0 for adam
44 |         momentum=0.9,
45 |         lr_sched='linear',
46 |         lr_sched_args={
47 |             'start_factor': 1,
48 |             'end_factor': 0.2,
49 |             'total_iters': 4000,
50 |         },
51 | 
52 |         #----- Loss functions
53 |         # example: ContrastiveLoss, TripletMarginLoss, MultiSimilarityLoss,
54 |         # FastAPLoss, CircleLoss, SupConLoss,
55 |         loss_name='MultiSimilarityLoss',
56 |         miner_name='MultiSimilarityMiner',  # example: TripletMarginMiner, MultiSimilarityMiner, PairMarginMiner
57 |         miner_margin=0.1,
58 |         faiss_gpu=False
59 |     )
60 | 
61 |     # model checkpointing using PyTorch Lightning:
62 |     # we save the best 3 models according to Recall@1 on Pittsburgh val
63 |     checkpoint_cb = pl.callbacks.ModelCheckpoint(
64 |         monitor='pitts30k_val/R1',
65 |         filename=f'{model.encoder_arch}' + '_({epoch:02d})_R1[{pitts30k_val/R1:.4f}]_R5[{pitts30k_val/R5:.4f}]',
66 |         auto_insert_metric_name=False,
67 |         save_weights_only=True,
68 |         save_top_k=3,
69 |         save_last=True,
70 |         mode='max'
71 |     )
72 | 
73 |     #------------------
74 |     # we instantiate a trainer
75 |     trainer = pl.Trainer(
76 |         accelerator='gpu',
77 |         devices=1,
78 |         default_root_dir='./logs/',  # TensorBoard can be used to visualize the logs
79 |         num_nodes=1,
80 |         num_sanity_val_steps=0,  # number of validation steps to run before training starts (0 disables the sanity check)
81 |         precision='16-mixed',  # we use half precision to reduce memory usage
82 |         max_epochs=4,
83 |         check_val_every_n_epoch=1,  # run validation every epoch
84 |         callbacks=[checkpoint_cb],  # we only run the checkpointing callback (you can add more)
85 |         reload_dataloaders_every_n_epochs=1,  # we reload the dataset to shuffle its order
86 |         log_every_n_steps=20,
87 |     )
88 | 
89 |     # we call the trainer, giving it the model and the datamodule
90 |     trainer.fit(model=model, datamodule=datamodule)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serizba/cliquemining/f7c4e8f31c49d60b8c82b11b703051d8973cbe72/models/__init__.py
--------------------------------------------------------------------------------
/models/aggregators/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosplace import CosPlace
2 | from .convap import ConvAP
3 | from .gem import GeMPool
4 | from .mixvpr import MixVPR
5 | from .salad import SALAD
6 | 
--------------------------------------------------------------------------------
/models/aggregators/convap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | 
5 | 
6 | class ConvAP(nn.Module):
7 |     """Implementation of ConvAP as of https://arxiv.org/pdf/2210.10239.pdf
8 | 
9 |     Args:
10 |         in_channels (int): number of channels in the input of ConvAP
11 |         out_channels (int, optional): number of channels that ConvAP outputs. Defaults to 512.
12 |         s1 (int, optional): spatial height of the adaptive average pooling. Defaults to 2.
13 |         s2 (int, optional): spatial width of the adaptive average pooling. Defaults to 2.
14 |     """
15 |     def __init__(self, in_channels, out_channels=512, s1=2, s2=2):
16 |         super(ConvAP, self).__init__()
17 |         self.channel_pool = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
18 |         self.AAP = nn.AdaptiveAvgPool2d((s1, s2))
19 | 
20 |     def forward(self, x):
21 |         x = self.channel_pool(x)
22 |         x = self.AAP(x)
23 |         x = F.normalize(x.flatten(1), p=2, dim=1)
24 |         return x
25 | 
26 | 
27 | if __name__ == '__main__':
28 |     x = torch.randn(4, 2048, 10, 10)
29 |     m = ConvAP(2048, 512)
30 |     r = m(x)
31 |     print(r.shape)
--------------------------------------------------------------------------------
/models/aggregators/cosplace.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | 
5 | class GeM(nn.Module):
6 |     """Implementation of GeM as in https://github.com/filipradenovic/cnnimageretrieval-pytorch
7 |     """
8 |     def __init__(self, p=3, eps=1e-6):
9 |         super().__init__()
10 |         self.p = nn.Parameter(torch.ones(1)*p)
11 |         self.eps = eps
12 | 
13 |     def forward(self, x):
14 |         return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
15 | 
16 | class CosPlace(nn.Module):
17 |     """
18 |     CosPlace aggregation layer as implemented in https://github.com/gmberton/CosPlace/blob/main/model/network.py
19 | 
20 |     Args:
21 |         in_dim: number of channels of the input
22 |         out_dim: dimension of the output descriptor
23 |     """
24 |     def __init__(self, in_dim, out_dim):
25 |         super().__init__()
26 |         self.gem = GeM()
27 |         self.fc = nn.Linear(in_dim, out_dim)
28 | 
29 |     def forward(self, x):
30 |         x = F.normalize(x, p=2, dim=1)
31 |         x = self.gem(x)
32 |         x = x.flatten(1)
33 |         x = self.fc(x)
34 |         x = F.normalize(x, p=2, dim=1)
35 |         return x
36 | 
37 | if __name__ == '__main__':
38 |     x = torch.randn(4, 2048, 10, 10)
39 |     m = CosPlace(2048, 512)
40 |     r = m(x)
41 |     print(r.shape)
--------------------------------------------------------------------------------
/models/aggregators/gem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | 
5 | class GeMPool(nn.Module):
6 |     """Implementation of GeM as in https://github.com/filipradenovic/cnnimageretrieval-pytorch
7 |     we add flatten and norm so that we can use it as one aggregation layer.
8 |     """
9 |     def __init__(self, p=3, eps=1e-6):
10 |         super().__init__()
11 |         self.p = nn.Parameter(torch.ones(1)*p)
12 |         self.eps = eps
13 | 
14 |     def forward(self, x):
15 |         x = F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
16 |         x = x.flatten(1)
17 |         return F.normalize(x, p=2, dim=1)
--------------------------------------------------------------------------------
/models/aggregators/mixvpr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | 
5 | import numpy as np
6 | 
7 | 
8 | class FeatureMixerLayer(nn.Module):
9 |     def __init__(self, in_dim, mlp_ratio=1):
10 |         super().__init__()
11 |         self.mix = nn.Sequential(
12 |             nn.LayerNorm(in_dim),
13 |             nn.Linear(in_dim, int(in_dim * mlp_ratio)),
14 |             nn.ReLU(),
15 |             nn.Linear(int(in_dim * mlp_ratio), in_dim),
16 |         )
17 | 
18 |         for m in self.modules():
19 |             if isinstance(m, (nn.Linear)):
20 |                 nn.init.trunc_normal_(m.weight, std=0.02)
21 |                 if m.bias is not None:
22 |                     nn.init.zeros_(m.bias)
23 | 
24 |     def forward(self, x):
25 |         return x + self.mix(x)
26 | 
27 | 
28 | class MixVPR(nn.Module):
29 |     def __init__(self,
30 |                  in_channels=1024,
31 |                  in_h=20,
32 |                  in_w=20,
33 |                  out_channels=512,
34 |                  mix_depth=1,
35 |                  mlp_ratio=1,
36 |                  out_rows=4,
37 |                  ) -> None:
38 |         super().__init__()
39 | 
40 |         self.in_h = in_h  # height of input feature maps
41 |         self.in_w = in_w  # width of input feature maps
42 |         self.in_channels = in_channels  # depth of input feature maps
43 | 
44 |         self.out_channels = out_channels  # depth-wise projection dimension
45 |         self.out_rows = out_rows  # row-wise projection dimension
46 | 
47 |         self.mix_depth = mix_depth  # L, the number of stacked FeatureMixers
48 |         self.mlp_ratio = mlp_ratio  # ratio of the mid projection layer in the mixer block
49 | 
50 |         hw = in_h*in_w
51 |         self.mix = nn.Sequential(*[
52 |             FeatureMixerLayer(in_dim=hw, mlp_ratio=mlp_ratio)
53 |             for _ in range(self.mix_depth)
54 |         ])
55 |         self.channel_proj = nn.Linear(in_channels, out_channels)
56 |         self.row_proj = nn.Linear(hw, out_rows)
57 | 
58 |     def forward(self, x):
59 |         x = x.flatten(2)
60 |         x = self.mix(x)
61 |         x = x.permute(0, 2, 1)
62 |         x = self.channel_proj(x)
63 |         x = x.permute(0, 2, 1)
64 |         x = self.row_proj(x)
65 |         x = F.normalize(x.flatten(1), p=2, dim=-1)
66 |         return x
67 | 
68 | 
69 | # -------------------------------------------------------------------------------
70 | 
71 | def print_nb_params(m):
72 |     model_parameters = filter(lambda p: p.requires_grad, m.parameters())
73 |     params = sum([np.prod(p.size()) for p in model_parameters])
74 |     print(f'Trainable parameters: {params/1e6:.3}M')
75 | 
76 | 
77 | def main():
78 |     x = torch.randn(1, 1024, 20, 20)
79 |     agg = MixVPR(
80 |         in_channels=1024,
81 |         in_h=20,
82 |         in_w=20,
83 |         out_channels=1024,
84 |         mix_depth=4,
85 |         mlp_ratio=1,
86 |         out_rows=4)
87 | 
88 |     print_nb_params(agg)
89 |     output = agg(x)
90 |     print(output.shape)
91 | 
92 | 
93 | if __name__ == '__main__':
94 |     main()
95 | 
--------------------------------------------------------------------------------
/models/aggregators/salad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | # Code from SuperGlue (https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/models/superglue.py)
5 | def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
6 |     """ Perform Sinkhorn Normalization in Log-space for stability"""
7 |     u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
8 |     for _ in range(iters):
9 |         u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
10 |         v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
11 |     return Z + u.unsqueeze(2) + v.unsqueeze(1)
12 | 
13 | # Code from SuperGlue (https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/models/superglue.py)
14 | def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
15 |     """ Perform Differentiable Optimal Transport in Log-space for stability"""
16 |     b, m, n = scores.shape
17 |     one = scores.new_tensor(1)
18 |     ms, ns, bs = (m*one).to(scores), (n*one).to(scores), ((n-m)*one).to(scores)
19 | 
20 |     bins = alpha.expand(b, 1, n)
21 |     alpha = alpha.expand(b, 1, 1)
22 | 
23 |     couplings = torch.cat([scores, bins], 1)
24 | 
25 |     norm = - (ms + ns).log()
26 |     log_mu = torch.cat([norm.expand(m), bs.log()[None] + norm])
27 |     log_nu = norm.expand(n)
28 |     log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
29 | 
30 |     Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
31 |     Z = Z - norm  # multiply probabilities by M+N
32 |     return Z
33 | 
34 | 
35 | class SALAD(nn.Module):
36 |     """
37 |     This class represents the Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD) model.
38 | 
39 |     Attributes:
40 |         num_channels (int): The number of channels of the inputs (d).
41 |         num_clusters (int): The number of clusters in the model (m).
42 |         cluster_dim (int): The number of channels of the clusters (l).
43 |         token_dim (int): The dimension of the global scene token (g).
44 |         dropout (float): The dropout rate.
45 |     """
46 |     def __init__(self,
47 |             num_channels=1536,
48 |             num_clusters=64,
49 |             cluster_dim=128,
50 |             token_dim=256,
51 |             dropout=0.3,
52 |         ) -> None:
53 |         super().__init__()
54 | 
55 |         self.num_channels = num_channels
56 |         self.num_clusters = num_clusters
57 |         self.cluster_dim = cluster_dim
58 |         self.token_dim = token_dim
59 | 
60 |         if dropout > 0:
61 |             dropout = nn.Dropout(dropout)
62 |         else:
63 |             dropout = nn.Identity()
64 | 
65 |         # MLP for the global scene token g
66 |         self.token_features = nn.Sequential(
67 |             nn.Linear(self.num_channels, 512),
68 |             nn.ReLU(),
69 |             nn.Linear(512, self.token_dim)
70 |         )
71 |         # MLP for the local features f_i
72 |         self.cluster_features = nn.Sequential(
73 |             nn.Conv2d(self.num_channels, 512, 1),
74 |             dropout,
75 |             nn.ReLU(),
76 |             nn.Conv2d(512, self.cluster_dim, 1)
77 |         )
78 |         # MLP for the score matrix S
79 |         self.score = nn.Sequential(
80 |             nn.Conv2d(self.num_channels, 512, 1),
81 |             dropout,
82 |             nn.ReLU(),
83 |             nn.Conv2d(512, self.num_clusters, 1),
84 |         )
85 |         # Dustbin parameter z
86 |         self.dust_bin = nn.Parameter(torch.tensor(1.))
87 | 
88 | 
89 |     def forward(self, x):
90 |         """
91 |         x (tuple): A tuple containing two elements, f and t.
92 |             (torch.Tensor): The feature tensors (t_i) [B, C, H // 14, W // 14].
93 |             (torch.Tensor): The token tensor (t_{n+1}) [B, C].
94 | 
95 |         Returns:
96 |             f (torch.Tensor): The global descriptor [B, m*l + g]
97 |         """
98 |         x, t = x  # Extract features and token
99 | 
100 |         f = self.cluster_features(x).flatten(2)
101 |         p = self.score(x).flatten(2)
102 |         t = self.token_features(t)
103 | 
104 |         # Sinkhorn algorithm
105 |         p = log_optimal_transport(p, self.dust_bin, 3)
106 |         p = torch.exp(p)
107 |         # Normalize to maintain mass
108 |         p = p[:, :-1, :]
109 | 
110 | 
111 |         p = p.unsqueeze(1).repeat(1, self.cluster_dim, 1, 1)
112 |         f = f.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)
113 | 
114 |         f = torch.cat([
115 |             nn.functional.normalize(t, p=2, dim=-1),
116 |             nn.functional.normalize((f * p).sum(dim=-1), p=2, dim=1).flatten(1)
117 |         ], dim=-1)
118 | 
119 |         return nn.functional.normalize(f, p=2, dim=-1)
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import ResNet
2 | from .dinov2 import DINOv2
3 | 
--------------------------------------------------------------------------------
/models/backbones/dinov2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | DINOV2_ARCHS = {
5 |     'dinov2_vits14': 384,
6 |     'dinov2_vitb14': 768,
7 |     'dinov2_vitl14': 1024,
8 |     'dinov2_vitg14': 1536,
9 | }
10 | 
11 | class DINOv2(nn.Module):
12 |     """
13 |     DINOv2 model
14 | 
15 |     Args:
16 |         model_name (str): The name of the model architecture;
17 |             should be one of ('dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14')
18 |         num_trainable_blocks (int): The number of last blocks in the model that are trainable.
19 |         norm_layer (bool): If True, a normalization layer is applied in the forward pass.
20 |         return_token (bool): If True, the forward pass returns both the feature map and the token.
21 |     """
22 |     def __init__(
23 |             self,
24 |             model_name='dinov2_vitb14',
25 |             num_trainable_blocks=2,
26 |             norm_layer=False,
27 |             return_token=False
28 |         ):
29 |         super().__init__()
30 | 
31 |         assert model_name in DINOV2_ARCHS.keys(), f'Unknown model name {model_name}'
32 |         self.model = torch.hub.load('facebookresearch/dinov2', model_name)
33 |         self.num_channels = DINOV2_ARCHS[model_name]
34 |         self.num_trainable_blocks = num_trainable_blocks
35 |         self.norm_layer = norm_layer
36 |         self.return_token = return_token
37 | 
38 | 
39 |     def forward(self, x):
40 |         """
41 |         The forward method for the DINOv2 class
42 | 
43 |         Parameters:
44 |             x (torch.Tensor): The input tensor [B, 3, H, W]. H and W should be divisible by 14.
45 | 
46 |         Returns:
47 |             f (torch.Tensor): The feature map [B, C, H // 14, W // 14].
48 |             t (torch.Tensor): The token [B, C]. This is only returned if return_token is True.
49 |         """
50 | 
51 |         B, C, H, W = x.shape
52 | 
53 |         x = self.model.prepare_tokens_with_masks(x)
54 | 
55 |         # First blocks are frozen
56 |         with torch.no_grad():
57 |             for blk in self.model.blocks[:-self.num_trainable_blocks]:
58 |                 x = blk(x)
59 |         x = x.detach()
60 | 
61 |         # Last blocks are trained
62 |         for blk in self.model.blocks[-self.num_trainable_blocks:]:
63 |             x = blk(x)
64 | 
65 |         if self.norm_layer:
66 |             x = self.model.norm(x)
67 | 
68 |         t = x[:, 0]
69 |         f = x[:, 1:]
70 | 
71 |         # Reshape to (B, C, H, W)
72 |         f = f.reshape((B, H // 14, W // 14, self.num_channels)).permute(0, 3, 1, 2)
73 | 
74 |         if self.return_token:
75 |             return f, t
76 |         return f
77 | 
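A small shape check for the two outputs of this backbone (a sketch; run from the repository root, fetches the DINOv2 weights via torch.hub on first use):

import torch
from models.backbones.dinov2 import DINOv2

backbone = DINOv2(model_name='dinov2_vitb14', return_token=True).eval()
x = torch.randn(1, 3, 224, 224)   # 224 = 16 * 14, so the patch grid is 16x16
with torch.no_grad():
    f, t = backbone(x)
print(f.shape)  # torch.Size([1, 768, 16, 16]) -- feature map
print(t.shape)  # torch.Size([1, 768])         -- CLS token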
--------------------------------------------------------------------------------
/models/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import numpy as np
5 | 
6 | class ResNet(nn.Module):
7 |     def __init__(self,
8 |                  model_name='resnet50',
9 |                  pretrained=True,
10 |                  layers_to_freeze=2,
11 |                  layers_to_crop=[],
12 |                  ):
13 |         """Class representing the resnet backbone used in the pipeline.
14 |         we consider the resnet network as a list of 5 blocks (from 0 to 4);
15 |         layer 0 is the first conv+bn, and the other layers (1 to 4) are the rest of the residual blocks.
16 |         we don't take into account the global pooling and the last fc
17 | 
18 |         Args:
19 |             model_name (str, optional): The architecture of the resnet backbone to instantiate. Defaults to 'resnet50'.
20 |             pretrained (bool, optional): Whether pretrained or not. Defaults to True.
21 |             layers_to_freeze (int, optional): The number of residual blocks to freeze (starting from 0). Defaults to 2.
22 |             layers_to_crop (list, optional): Which residual layers to crop; for example, [3, 4] will crop the third and fourth res blocks. Defaults to [].
23 | 
24 |         Raises:
25 |             NotImplementedError: if the model_name corresponds to an unknown architecture.
26 |         """
27 |         super().__init__()
28 |         self.model_name = model_name.lower()
29 |         self.layers_to_freeze = layers_to_freeze
30 | 
31 |         if pretrained:
32 |             # the new naming convention of pretrained weights (you can change to V2 if desired)
33 |             weights = 'IMAGENET1K_V1'
34 |         else:
35 |             weights = None
36 | 
37 |         if 'swsl' in model_name or 'ssl' in model_name:
38 |             # These are the semi-supervised and weakly semi-supervised weights from Facebook
39 |             self.model = torch.hub.load(
40 |                 'facebookresearch/semi-supervised-ImageNet1K-models', model_name)
41 |         else:
42 |             if 'resnext50' in model_name:
43 |                 self.model = torchvision.models.resnext50_32x4d(weights=weights)
44 |             elif 'resnet50' in model_name:
45 |                 self.model = torchvision.models.resnet50(weights=weights)
46 |             elif '101' in model_name:
47 |                 self.model = torchvision.models.resnet101(weights=weights)
48 |             elif '152' in model_name:
49 |                 self.model = torchvision.models.resnet152(weights=weights)
50 |             elif '34' in model_name:
51 |                 self.model = torchvision.models.resnet34(weights=weights)
52 |             elif '18' in model_name:
53 |                 # self.model = torchvision.models.resnet18(pretrained=False)
54 |                 self.model = torchvision.models.resnet18(weights=weights)
55 |             elif 'wide_resnet50_2' in model_name:
56 |                 self.model = torchvision.models.wide_resnet50_2(weights=weights)
57 |             else:
58 |                 raise NotImplementedError(
59 |                     'Backbone architecture not recognized!')
60 | 
61 |         # freeze only if the model is pretrained
62 |         if pretrained:
63 |             if layers_to_freeze >= 0:
64 |                 self.model.conv1.requires_grad_(False)
65 |                 self.model.bn1.requires_grad_(False)
66 |             if layers_to_freeze >= 1:
67 |                 self.model.layer1.requires_grad_(False)
68 |             if layers_to_freeze >= 2:
69 |                 self.model.layer2.requires_grad_(False)
70 |             if layers_to_freeze >= 3:
71 |                 self.model.layer3.requires_grad_(False)
72 | 
73 |         # remove the avgpool and, most importantly, the fc layer
74 |         self.model.avgpool = None
75 |         self.model.fc = None
76 | 
77 |         if 4 in layers_to_crop:
78 |             self.model.layer4 = None
79 |         if 3 in layers_to_crop:
80 |             self.model.layer3 = None
81 | 
82 |         out_channels = 2048
83 |         if '34' in model_name or '18' in model_name:
84 |             out_channels = 512
85 | 
86 |         self.out_channels = out_channels // 2 if self.model.layer4 is None else out_channels
87 |         self.out_channels = self.out_channels // 2 if self.model.layer3 is None else self.out_channels
88 | 
89 |     def forward(self, x):
90 |         x = self.model.conv1(x)
91 |         x = self.model.bn1(x)
92 |         x = self.model.relu(x)
93 |         x = self.model.maxpool(x)
94 |         x = self.model.layer1(x)
95 |         x = self.model.layer2(x)
96 |         if self.model.layer3 is not None:
97 |             x = self.model.layer3(x)
98 |         if self.model.layer4 is not None:
99 |             x = self.model.layer4(x)
100 |         return x
101 | 
--------------------------------------------------------------------------------
/models/helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from models import aggregators
3 | from models import backbones
4 | 
5 | 
6 | def get_backbone(
7 |         backbone_arch='resnet50',
8 |         backbone_config={}
9 | ):
10 |     """Helper function that returns the backbone given its name
11 | 
12 |     Args:
13 |         backbone_arch (str, optional): the name of the backbone. Defaults to 'resnet50'.
14 |         backbone_config (dict, optional): this must contain all the arguments needed to instantiate the backbone class. Defaults to {}.
15 | 
16 |     Returns:
17 |         nn.Module: the backbone as an nn.Module object
18 |     """
19 |     if 'resnet' in backbone_arch.lower():
20 |         return backbones.ResNet(backbone_arch, **backbone_config)
21 | 
22 |     elif 'dinov2' in backbone_arch.lower():
23 |         return backbones.DINOv2(model_name=backbone_arch, **backbone_config)
24 | 
25 |     raise NotImplementedError(f'Backbone architecture <{backbone_arch}> is not recognized!')
26 | 
27 | def get_aggregator(agg_arch='ConvAP', agg_config={}):
28 |     """Helper function that returns the aggregation layer given its name.
29 |     If you happen to make your own aggregator, you might need to add a call
30 |     to this helper function.
31 | 
32 |     Args:
33 |         agg_arch (str, optional): the name of the aggregator. Defaults to 'ConvAP'.
34 |         agg_config (dict, optional): this must contain all the arguments needed to instantiate the aggregator class. Defaults to {}.
35 | 
36 |     Returns:
37 |         nn.Module: the aggregation layer
38 |     """
39 | 
40 |     if 'cosplace' in agg_arch.lower():
41 |         assert 'in_dim' in agg_config
42 |         assert 'out_dim' in agg_config
43 |         return aggregators.CosPlace(**agg_config)
44 | 
45 |     elif 'gem' in agg_arch.lower():
46 |         if agg_config == {}:
47 |             agg_config['p'] = 3
48 |         else:
49 |             assert 'p' in agg_config
50 |         return aggregators.GeMPool(**agg_config)
51 | 
52 |     elif 'convap' in agg_arch.lower():
53 |         assert 'in_channels' in agg_config
54 |         return aggregators.ConvAP(**agg_config)
55 | 
56 |     elif 'mixvpr' in agg_arch.lower():
57 |         assert 'in_channels' in agg_config
58 |         assert 'out_channels' in agg_config
59 |         assert 'in_h' in agg_config
60 |         assert 'in_w' in agg_config
61 |         assert 'mix_depth' in agg_config
62 |         return aggregators.MixVPR(**agg_config)
63 | 
64 |     elif 'salad' in agg_arch.lower():
65 |         assert 'num_channels' in agg_config
66 |         assert 'num_clusters' in agg_config
67 |         assert 'cluster_dim' in agg_config
68 |         assert 'token_dim' in agg_config
69 |         return aggregators.SALAD(**agg_config)
70 | 
71 |     raise NotImplementedError(f'Aggregator architecture <{agg_arch}> is not recognized!')
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import get_miner, get_loss
2 | from .validation import get_validation_recalls
3 | 
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | from pytorch_metric_learning import losses, miners
2 | from pytorch_metric_learning.distances import CosineSimilarity, DotProductSimilarity
3 | 
4 | def get_loss(loss_name):
5 |     if loss_name == 'SupConLoss': return losses.SupConLoss(temperature=0.07)
6 |     if loss_name == 'CircleLoss': return losses.CircleLoss(m=0.4, gamma=80)  # these are the params for image retrieval
7 |     if loss_name == 'MultiSimilarityLoss': return losses.MultiSimilarityLoss(alpha=1.0, beta=50, base=0.0, distance=DotProductSimilarity())
8 |     if loss_name == 'ContrastiveLoss': return losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
9 |     if loss_name == 'Lifted': return losses.GeneralizedLiftedStructureLoss(neg_margin=0, pos_margin=1, distance=DotProductSimilarity())
10 |     if loss_name == 'FastAPLoss': return losses.FastAPLoss(num_bins=30)
11 |     if loss_name == 'NTXentLoss': return losses.NTXentLoss(temperature=0.07)  # the MoCo paper uses 0.07, while SimCLR uses 0.5
12 |     if loss_name == 'TripletMarginLoss': return losses.TripletMarginLoss(margin=0.1, swap=False, smooth_loss=False, triplets_per_anchor='all')  # or an int, for example 100
13 |     if loss_name == 'CentroidTripletLoss': return losses.CentroidTripletLoss(margin=0.05,
14 |                                                                              swap=False,
15 |                                                                              smooth_loss=False,
16 |                                                                              triplets_per_anchor="all",)
17 |     raise NotImplementedError(f'Sorry, <{loss_name}> loss function is not implemented!')
18 | 
19 | def get_miner(miner_name, margin=0.1):
20 |     if miner_name == 'TripletMarginMiner': return miners.TripletMarginMiner(margin=margin, type_of_triplets="semihard")  # all, hard, semihard, easy
21 |     if miner_name == 'MultiSimilarityMiner': return miners.MultiSimilarityMiner(epsilon=margin, distance=CosineSimilarity())
22 |     if miner_name == 'PairMarginMiner': return miners.PairMarginMiner(pos_margin=0.7, neg_margin=0.3, distance=DotProductSimilarity())
23 |     return None
24 | 
--------------------------------------------------------------------------------
/utils/validation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import faiss
3 | import faiss.contrib.torch_utils
4 | from prettytable import PrettyTable
5 | 
6 | 
7 | def get_validation_recalls(r_list, q_list, k_values, gt, print_results=True, faiss_gpu=False, dataset_name='dataset without name ?', testing=False):
8 | 
9 |     embed_size = r_list.shape[1]
10 |     if faiss_gpu:
11 |         res = faiss.StandardGpuResources()
12 |         flat_config = faiss.GpuIndexFlatConfig()
13 |         flat_config.useFloat16 = True
14 |         flat_config.device = 0
15 |         faiss_index = faiss.GpuIndexFlatL2(res, embed_size, flat_config)
16 |     # build the index
17 |     else:
18 |         faiss_index = faiss.IndexFlatL2(embed_size)
19 | 
20 |     # add the references
21 |     faiss_index.add(r_list)
22 | 
23 |     # search for the queries in the index
24 |     _, predictions = faiss_index.search(q_list, max(k_values))
25 | 
26 |     if testing:
27 |         return predictions
28 | 
29 |     # start calculating recall@K
30 |     correct_at_k = np.zeros(len(k_values))
31 |     for q_idx, pred in enumerate(predictions):
32 |         for i, n in enumerate(k_values):
33 |             # if in top N, then also in top NN, where NN > N
34 |             if np.any(np.in1d(pred[:n], gt[q_idx])):
35 |                 correct_at_k[i:] += 1
36 |                 break
37 | 
38 |     correct_at_k = correct_at_k / len(predictions)
39 |     d = {k: v for (k, v) in zip(k_values, correct_at_k)}
40 | 
41 |     if print_results:
42 |         print()  # print a new line
43 |         table = PrettyTable()
44 |         table.field_names = ['K'] + [str(k) for k in k_values]
45 |         table.add_row(['Recall@K'] + [f'{100*v:.2f}' for v in correct_at_k])
46 |         print(table.get_string(title=f"Performances on {dataset_name}"))
47 | 
48 |     return d
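get_validation_recalls operates on raw descriptor matrices, so it can be exercised in isolation. A toy sketch with synthetic descriptors (not real data):

import torch
from utils.validation import get_validation_recalls

refs = torch.randn(100, 64)                       # 100 reference descriptors
queries = refs[:10] + 0.01 * torch.randn(10, 64)  # 10 queries, each close to reference i
gt = [[i] for i in range(10)]                     # query i should retrieve reference i
recalls = get_validation_recalls(refs, queries, k_values=[1, 5], gt=gt, print_results=True)
print(recalls)                                    # recall@1 should be close to 1.0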
--------------------------------------------------------------------------------
/vpr_model.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | from torch.optim import lr_scheduler
4 | 
5 | import utils
6 | from models import helper
7 | 
8 | 
9 | class VPRModel(pl.LightningModule):
10 |     """This is the main model for Visual Place Recognition;
11 |     we use PyTorch Lightning for modularity purposes.
12 | 
13 |     See __init__ below for the full list of hyperparameters.
14 |     """
15 | 
16 | 
17 |     def __init__(self,
18 |                  #---- Backbone
19 |                  backbone_arch='resnet50',
20 |                  backbone_config={},
21 | 
22 |                  #---- Aggregator
23 |                  agg_arch='ConvAP',
24 |                  agg_config={},
25 | 
26 |                  #---- Train hyperparameters
27 |                  lr=0.03,
28 |                  optimizer='sgd',
29 |                  weight_decay=1e-3,
30 |                  momentum=0.9,
31 |                  lr_sched='linear',
32 |                  lr_sched_args = {
33 |                      'start_factor': 1,
34 |                      'end_factor': 0.2,
35 |                      'total_iters': 4000,
36 |                  },
37 | 
38 |                  #----- Loss
39 |                  loss_name='MultiSimilarityLoss',
40 |                  miner_name='MultiSimilarityMiner',
41 |                  miner_margin=0.1,
42 |                  faiss_gpu=False
43 |                  ):
44 |         super().__init__()
45 | 
46 |         # Backbone
47 |         self.encoder_arch = backbone_arch
48 |         self.backbone_config = backbone_config
49 | 
50 |         # Aggregator
51 |         self.agg_arch = agg_arch
52 |         self.agg_config = agg_config
53 | 
54 |         # Train hyperparameters
55 |         self.lr = lr
56 |         self.optimizer = optimizer
57 |         self.weight_decay = weight_decay
58 |         self.momentum = momentum
59 |         self.lr_sched = lr_sched
60 |         self.lr_sched_args = lr_sched_args
61 | 
62 |         # Loss
63 |         self.loss_name = loss_name
64 |         self.miner_name = miner_name
65 |         self.miner_margin = miner_margin
66 | 
67 |         self.save_hyperparameters()  # write hyperparams into a file
68 | 
69 |         self.loss_fn = utils.get_loss(loss_name)
70 |         self.miner = utils.get_miner(miner_name, miner_margin)
71 |         self.batch_acc = []  # we will keep track of the % of trivial pairs/triplets at the loss level
72 | 
73 |         self.faiss_gpu = faiss_gpu
74 | 
75 |         # ----------------------------------
76 |         # get the backbone and the aggregator
77 |         self.backbone = helper.get_backbone(backbone_arch, backbone_config)
78 |         self.aggregator = helper.get_aggregator(agg_arch, agg_config)
79 | 
80 |         # For validation in Lightning v2.0.0
81 |         self.val_outputs = []
82 | 
83 |     # the forward pass of the lightning model
84 |     def forward(self, x):
85 |         x = self.backbone(x)
86 |         x = self.aggregator(x)
87 |         return x
88 | 
89 |     # configure the optimizer
90 |     def configure_optimizers(self):
91 |         if self.optimizer.lower() == 'sgd':
92 |             optimizer = torch.optim.SGD(
93 |                 self.parameters(),
94 |                 lr=self.lr,
95 |                 weight_decay=self.weight_decay,
96 |                 momentum=self.momentum
97 |             )
98 |         elif self.optimizer.lower() == 'adamw':
99 |             optimizer = torch.optim.AdamW(
100 |                 self.parameters(),
101 |                 lr=self.lr,
102 |                 weight_decay=self.weight_decay
103 |             )
104 |         elif self.optimizer.lower() == 'adam':
105 |             optimizer = torch.optim.Adam(
106 |                 self.parameters(),
107 |                 lr=self.lr,
108 |                 weight_decay=self.weight_decay
109 |             )
110 |         else:
111 |             raise ValueError(f'Optimizer {self.optimizer} has not been added to "configure_optimizers()"')
112 | 
113 | 
114 |         if self.lr_sched.lower() == 'multistep':
115 |             scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_sched_args['milestones'], gamma=self.lr_sched_args['gamma'])
116 |         elif self.lr_sched.lower() == 'cosine':
117 |             scheduler = lr_scheduler.CosineAnnealingLR(optimizer, self.lr_sched_args['T_max'])
118 |         elif self.lr_sched.lower() == 'linear':
119 |             scheduler = lr_scheduler.LinearLR(
120 |                 optimizer,
121 |                 start_factor=self.lr_sched_args['start_factor'],
122 |                 end_factor=self.lr_sched_args['end_factor'],
123 |                 total_iters=self.lr_sched_args['total_iters']
124 |             )
125 |         else: raise ValueError(f'Scheduler {self.lr_sched} has not been added to "configure_optimizers()"')
126 |         return [optimizer], [scheduler]
127 | 
128 |     # configure the optimizer step; takes the warmup stage into account
129 |     def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
130 |         # warm up lr
131 |         optimizer.step(closure=optimizer_closure)
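        # note that the scheduler is stepped after every optimizer step, so
        # LinearLR's total_iters is measured in training iterations, not epochs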
132 |         self.lr_schedulers().step()
133 | 
134 |     # The loss function call (this method will be called at each training iteration)
135 |     def loss_function(self, descriptors, labels):
136 |         # we mine the pairs/triplets if there is an online mining strategy
137 |         if self.miner is not None:
138 |             miner_outputs = self.miner(descriptors, labels)
139 |             loss = self.loss_fn(descriptors, labels, miner_outputs)
140 | 
141 |             # calculate the % of trivial pairs/triplets
142 |             # which do not contribute to the loss value
143 |             nb_samples = descriptors.shape[0]
144 |             nb_mined = len(set(miner_outputs[0].detach().cpu().numpy()))
145 |             batch_acc = 1.0 - (nb_mined/nb_samples)
146 | 
147 |         else:  # no online mining
148 |             loss = self.loss_fn(descriptors, labels)
149 |             batch_acc = 0.0
150 |             if type(loss) == tuple:
151 |                 # some losses do the online mining inside (they don't need a miner object),
152 |                 # so they return the loss and the batch accuracy.
153 |                 # for example, if you are developing a new loss function, you might be better off
154 |                 # doing the online mining strategy inside the forward function of the loss class,
155 |                 # and returning a tuple containing the loss value and the batch accuracy (the % of valid pairs or triplets)
156 |                 loss, batch_acc = loss
157 | 
158 |         # keep the accuracy of every batch and later reset it at epoch start
159 |         self.batch_acc.append(batch_acc)
160 |         # log it
161 |         self.log('b_acc', sum(self.batch_acc) /
162 |                  len(self.batch_acc), prog_bar=True, logger=True)
163 |         return loss
164 | 
165 |     # This is the training step that's executed at each iteration
166 |     def training_step(self, batch, batch_idx):
167 |         # places, labels, types = batch
168 | 
169 |         places_1, labels_1 = batch['GSVCities']
170 |         places_2, labels_2 = batch['MSLS']
171 | 
172 |         BS, N, ch, h, w = places_1.shape
173 | 
174 |         # Labels 2 should be adjusted to be unique
175 |         labels_2 += labels_1.max() + 1
176 | 
177 |         # Note that GSVCities yields places (each containing N images),
178 |         # which means the dataloader will return a batch containing BS places
179 |         images = torch.concat([places_1, places_2], dim=0).view((places_1.size(0) + places_2.size(0))*N, ch, h, w)
180 |         labels = torch.concat([labels_1, labels_2], dim=0).view(-1)
181 | 
182 |         # Feed the batch forward through the model
183 |         descriptors = self(images)  # Here we are calling the forward method that we defined above
184 | 
185 |         if torch.isnan(descriptors).any():
186 |             raise ValueError('NaNs in descriptors')
187 | 
188 |         # Split the loss between the two sources (the first places_1.size(0)*N descriptors come from GSVCities)
189 |         loss_1 = self.loss_function(descriptors[:places_1.size(0)*N], labels[:places_1.size(0)*N])
190 |         loss_2 = self.loss_function(descriptors[places_1.size(0)*N:], labels[places_1.size(0)*N:])
191 |         loss = loss_1 + loss_2
192 | 
193 |         self.log('loss', loss.item(), logger=True, prog_bar=True)
194 |         return {'loss': loss}
195 | 
196 |     def on_train_epoch_end(self):
197 |         # we empty the batch_acc list for the next epoch
198 |         self.batch_acc = []
199 | 
200 |     # For validation, we will also iterate step by step over the validation set;
201 |     # this is the way PyTorch Lightning is made. All about modularity, folks.
202 |     def validation_step(self, batch, batch_idx, dataloader_idx=None):
203 |         places, _ = batch
204 |         descriptors = self(places)
205 |         self.val_outputs[dataloader_idx].append(descriptors.detach().cpu())
206 |         return descriptors.detach().cpu()
207 | 
208 |     def on_validation_epoch_start(self):
209 |         # reset the outputs list
210 |         self.val_outputs = [[] for _ in range(len(self.trainer.datamodule.val_datasets))]
211 | 
212 |     def on_validation_epoch_end(self):
213 |         """this returns descriptors in their order,
214 |         depending on how the validation dataset is implemented;
215 |         for this project (MSLS val, Pittsburgh val), it is always references then queries
216 |         [R1, R2, ..., Rn, Q1, Q2, ...]
217 |         """
218 |         val_step_outputs = self.val_outputs
219 | 
220 |         dm = self.trainer.datamodule
221 |         # The following line is a hack: if we have only one validation set, then
222 |         # we need to put the outputs in a list (PyTorch Lightning does not do it presently)
223 |         if len(dm.val_datasets) == 1:  # we need to put the outputs in a list
224 |             val_step_outputs = [val_step_outputs]
225 | 
226 |         for i, (val_set_name, val_dataset) in enumerate(zip(dm.val_set_names, dm.val_datasets)):
227 |             feats = torch.concat(val_step_outputs[i], dim=0)
228 | 
229 |             if 'pitts' in val_set_name:
230 |                 # split into references and queries
231 |                 num_references = val_dataset.dbStruct.numDb
232 |                 positives = val_dataset.getPositives()
233 |             elif 'msls' in val_set_name:
234 |                 # split into references and queries
235 |                 num_references = val_dataset.num_references
236 |                 positives = val_dataset.pIdx
237 |             else:
238 |                 print(f'Please implement validation_epoch_end for {val_set_name}')
239 |                 raise NotImplementedError
240 | 
241 |             r_list = feats[ : num_references]
242 |             q_list = feats[num_references : ]
243 |             pitts_dict = utils.get_validation_recalls(
244 |                 r_list=r_list,
245 |                 q_list=q_list,
246 |                 k_values=[1, 5, 10, 15, 20, 50, 100],
247 |                 gt=positives,
248 |                 print_results=True,
249 |                 dataset_name=val_set_name,
250 |                 faiss_gpu=self.faiss_gpu
251 |             )
252 |             del r_list, q_list, feats, num_references, positives
253 | 
254 |             self.log(f'{val_set_name}/R1', pitts_dict[1], prog_bar=False, logger=True)
255 |             self.log(f'{val_set_name}/R5', pitts_dict[5], prog_bar=False, logger=True)
256 |             self.log(f'{val_set_name}/R10', pitts_dict[10], prog_bar=False, logger=True)
257 |             print('\n\n')
258 | 
259 |         # reset the outputs list
260 |         self.val_outputs = []
--------------------------------------------------------------------------------
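As an end-to-end sketch of how vpr_model.py composes with models/helper.py, here is a descriptor-shape check using a lightweight ResNet/ConvAP configuration (the values are illustrative, not the training configuration from main.py):

import torch
from vpr_model import VPRModel

model = VPRModel(
    backbone_arch='resnet18',
    backbone_config={'pretrained': False},  # avoid any weight download for this sketch
    agg_arch='ConvAP',
    agg_config={'in_channels': 512, 'out_channels': 256},
).eval()

# a batch of 2 places x 4 images, as GSVCitiesDataset yields them
places = torch.randn(2, 4, 3, 224, 224)   # [BS, K, C, H, W]
BS, K, ch, h, w = places.shape
with torch.no_grad():
    descriptors = model(places.view(BS * K, ch, h, w))
print(descriptors.shape)  # [8, 1024]: ConvAP outputs out_channels * s1 * s2 = 256 * 2 * 2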