├── .flake8 ├── .gitattributes ├── .github └── workflows │ └── code-quality.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── DSC_0410.JPG ├── DSC_0411.JPG ├── architecture.svg ├── benchmark.png ├── benchmark_cpu.png ├── easy_hard.jpg ├── sacre_coeur1.jpg ├── sacre_coeur2.jpg └── teaser.svg ├── benchmark.py ├── demo.ipynb ├── lightglue ├── __init__.py ├── aliked.py ├── disk.py ├── dog_hardnet.py ├── lightglue.py ├── sift.py ├── superpoint.py ├── utils.py └── viz2d.py ├── pyproject.toml └── requirements.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 4 | exclude = .git,__pycache__,build,.venv/ 5 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /.github/workflows/code-quality.yml: -------------------------------------------------------------------------------- 1 | name: Format and Lint Checks 2 | on: 3 | push: 4 | branches: 5 | - main 6 | paths: 7 | - '*.py' 8 | pull_request: 9 | types: [ assigned, opened, synchronize, reopened ] 10 | jobs: 11 | check: 12 | name: Format and Lint Checks 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: '3.10' 19 | cache: 'pip' 20 | - run: python -m pip install --upgrade pip 21 | - run: python -m pip install .[dev] 22 | - run: python -m flake8 . 23 | - run: python -m isort . --check-only --diff 24 | - run: python -m black . --check --diff 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /outputs/ 3 | /lightglue/weights/ 4 | *-checkpoint.ipynb 5 | *.pth 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | .idea/ 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 ETH Zurich 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

LightGlue ⚡️
Local Feature Matching at Light Speed

3 |

4 | Philipp Lindenberger 5 | · 6 | Paul-Edouard Sarlin 7 | · 8 | Marc Pollefeys 9 |

10 |

11 |

ICCV 2023

12 | Paper | 13 | Colab | 14 | Poster | 15 | Train your own! 16 |

17 | 18 |

19 |

20 | example 21 |
22 | LightGlue is a deep neural network that matches sparse local features across image pairs.
An adaptive mechanism makes it fast for easy pairs (top) and reduces the computational complexity for difficult ones (bottom).
23 |

24 | 25 | ## 26 | 27 | This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf). 28 | 29 | We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features. 30 | The training and evaluation code can be found in our library [glue-factory](https://github.com/cvg/glue-factory/). 31 | 32 | ## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb) 33 | 34 | Install this repo using pip: 35 | 36 | ```bash 37 | git clone https://github.com/cvg/LightGlue.git && cd LightGlue 38 | python -m pip install -e . 39 | ``` 40 | 41 | We provide a [demo notebook](demo.ipynb) which shows how to perform feature extraction and matching on an image pair. 42 | 43 | Here is a minimal script to match two images: 44 | 45 | ```python 46 | from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet 47 | from lightglue.utils import load_image, rbd 48 | 49 | # SuperPoint+LightGlue 50 | extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor 51 | matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher 52 | 53 | # or DISK+LightGlue, ALIKED+LightGlue or SIFT+LightGlue 54 | extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor 55 | matcher = LightGlue(features='disk').eval().cuda() # load the matcher 56 | 57 | # load each image as a torch.Tensor on GPU with shape (3,H,W), normalized in [0,1] 58 | image0 = load_image('path/to/image_0.jpg').cuda() 59 | image1 = load_image('path/to/image_1.jpg').cuda() 60 | 61 | # extract local features 62 | feats0 = extractor.extract(image0) # auto-resize the image, disable with resize=None 63 | feats1 = extractor.extract(image1) 64 | 65 | # match the features 66 | matches01 = matcher({'image0': feats0, 'image1': feats1}) 67 | feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension 68 | matches = matches01['matches'] # indices with shape (K,2) 69 | points0 = feats0['keypoints'][matches[..., 0]] # coordinates in image #0, shape (K,2) 70 | points1 = feats1['keypoints'][matches[..., 1]] # coordinates in image #1, shape (K,2) 71 | ``` 72 | 73 | We also provide a convenience method to match a pair of images: 74 | 75 | ```python 76 | from lightglue import match_pair 77 | feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1) 78 | ``` 79 | 80 | ## 81 | 82 |

83 | Logo 84 |
85 | LightGlue can adjust its depth (number of layers) and width (number of keypoints) per image pair, with a marginal impact on accuracy. 86 |

87 | 88 | ## Advanced configuration 89 | 90 |
91 | [Detail of all parameters - click to expand] 92 | 93 | - ```n_layers```: Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers). 94 | - ```flash```: Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available). 95 | - ```mp```: Enable mixed-precision inference. Default: False (off). 96 | - ```depth_confidence```: Controls the early stopping. A lower value stops more often at earlier layers. Default: 0.95, disable with -1. 97 | - ```width_confidence```: Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1. 98 | - ```filter_threshold```: Match confidence. Increase this value to obtain fewer, but stronger matches. Default: 0.1. 99 | 100 |
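All of the options above are plain constructor arguments of `LightGlue`. A small sketch combining several of them; it assumes every listed option is accepted as a keyword argument, as the examples below already demonstrate for the confidence values:

```python
matcher = LightGlue(
    features='superpoint',
    n_layers=9,            # keep all layers; lower this for faster but less accurate matching
    mp=True,               # mixed-precision inference
    filter_threshold=0.2,  # keep fewer but more confident matches
).eval().cuda()
```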
101 | 102 | The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms: 103 | ```python 104 | extractor = SuperPoint(max_num_keypoints=None) 105 | matcher = LightGlue(features='superpoint', depth_confidence=-1, width_confidence=-1) 106 | ``` 107 | 108 | To increase the speed with a small drop in accuracy, decrease the number of keypoints and lower the adaptive thresholds: 109 | ```python 110 | extractor = SuperPoint(max_num_keypoints=1024) 111 | matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95) 112 | ``` 113 | 114 | The maximum speed is obtained with a combination of: 115 | - [FlashAttention](https://arxiv.org/abs/2205.14135): automatically used when ```torch >= 2.0``` or if [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features). 116 | - PyTorch compilation, available when ```torch >= 2.0```: 117 | ```python 118 | matcher = matcher.eval().cuda() 119 | matcher.compile(mode='reduce-overhead') 120 | ``` 121 | For inputs with fewer than 1536 keypoints (determined experimentally), this compiles LightGlue but disables point pruning (large overhead). For larger input sizes, it automatically falls back to eager mode with point pruning. Adaptive depth is supported for any input size. 122 | 123 | ## Benchmark 124 | 125 | 126 |

127 | Logo 128 |
129 | Benchmark results on GPU (RTX 3080). With compilation and adaptivity, LightGlue runs at 150 FPS @ 1024 keypoints and 50 FPS @ 4096 keypoints per image. This is a 4-10x speedup over SuperGlue. 130 |

131 | 132 |

133 | Logo 134 |
135 | Benchmark results on CPU (Intel i7 10700K). LightGlue runs at 20 FPS @ 512 keypoints. 136 |

137 | 138 | Obtain the same plots for your setup using our [benchmark script](benchmark.py): 139 | ``` 140 | python benchmark.py [--device cuda] [--add_superglue] [--num_keypoints 512 1024 2048 4096] [--compile] 141 | ``` 142 | 143 |
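For example, the following run (using flags defined in the argument parser of `benchmark.py`) benchmarks on CPU and writes the resulting figure to disk:
```
python benchmark.py --device cpu --num_keypoints 512 1024 2048 --save benchmark_cpu.png
```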
144 | [Performance tip - click to expand] 145 | 146 | Note: **Point pruning** introduces an overhead that sometimes outweighs its benefits. 147 | Point pruning is thus enabled only when there are more than N keypoints in an image, where N is hardware-dependent. 148 | We provide defaults optimized for current hardware (RTX 30xx GPUs). 149 | We suggest running the benchmark script and adjusting the thresholds for your hardware by updating `LightGlue.pruning_keypoint_thresholds['cuda']`. 150 | 151 |
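For instance, once you have found the break-even point for your hardware with `benchmark.py`, the class-level threshold can be overridden before building the matcher; the value below is purely illustrative:

```python
from lightglue import LightGlue

# Illustrative number only: enable point pruning whenever an image has more than
# 1024 keypoints on CUDA devices. Pick the value from your own benchmark run.
LightGlue.pruning_keypoint_thresholds['cuda'] = 1024
matcher = LightGlue(features='superpoint').eval().cuda()
```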
152 | 153 | ## Training and evaluation 154 | 155 | With [Glue Factory](https://github.com/cvg/glue-factory), you can train LightGlue with your own local features, on your own dataset! 156 | You can also evaluate it and other baselines on standard benchmarks like HPatches and MegaDepth. 157 | 158 | ## Other links 159 | - [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization. 160 | - [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange (ONNX) format with support for TensorRT and OpenVINO. 161 | - [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue. 162 | - [kornia](https://kornia.readthedocs.io) now exposes LightGlue via the interfaces [`LightGlue`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlue) and [`LightGlueMatcher`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlueMatcher). 163 | 164 | ## BibTeX citation 165 | If you use any ideas from the paper or code from this repo, please consider citing: 166 | 167 | ```txt 168 | @inproceedings{lindenberger2023lightglue, 169 | author = {Philipp Lindenberger and 170 | Paul-Edouard Sarlin and 171 | Marc Pollefeys}, 172 | title = {{LightGlue: Local Feature Matching at Light Speed}}, 173 | booktitle = {ICCV}, 174 | year = {2023} 175 | } 176 | ``` 177 | 178 | 179 | ## License 180 | The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)). [ALIKED](https://github.com/Shiaoming/ALIKED) was published under a BSD-3-Clause license. 
181 | 
--------------------------------------------------------------------------------
/assets/DSC_0410.JPG:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/DSC_0410.JPG
--------------------------------------------------------------------------------
/assets/DSC_0411.JPG:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/DSC_0411.JPG
--------------------------------------------------------------------------------
/assets/benchmark.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/benchmark.png
--------------------------------------------------------------------------------
/assets/benchmark_cpu.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/benchmark_cpu.png
--------------------------------------------------------------------------------
/assets/easy_hard.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/easy_hard.jpg
--------------------------------------------------------------------------------
/assets/sacre_coeur1.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/sacre_coeur1.jpg
--------------------------------------------------------------------------------
/assets/sacre_coeur2.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cvg/LightGlue/edb2b838efb2ecfe3f88097c5fad9887d95aedad/assets/sacre_coeur2.jpg
--------------------------------------------------------------------------------
/assets/teaser.svg:
--------------------------------------------------------------------------------
 [SVG plot markup omitted: Matplotlib v3.7.1 teaser figure, created 2023-06-25T11:23:59, image/svg+xml]
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
 1 | # Benchmark script for LightGlue on real images 2 | import argparse 3 | import time 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import torch._dynamo 11 | 12 | from lightglue import LightGlue, SuperPoint 13 | from lightglue.utils import load_image 14 | 15 | torch.set_grad_enabled(False) 16 | 17 | 18 | def measure(matcher, data, device="cuda", r=100): 19 | timings = np.zeros((r, 1)) 20 | if device.type == "cuda": 21 | starter = torch.cuda.Event(enable_timing=True) 22 | ender = torch.cuda.Event(enable_timing=True) 23 | # warmup 24 | for _ in range(10): 25 | _ = matcher(data) 26 | # measurements 27 | with torch.no_grad(): 28 | for rep in range(r): 29 | if device.type == "cuda": 30 | starter.record() 31 | _ = matcher(data) 32 | ender.record() 33 | # sync gpu 34 | torch.cuda.synchronize() 35 | curr_time = starter.elapsed_time(ender) 36 | else: 37 | start = time.perf_counter() 38 | _ = matcher(data) 39 | curr_time = (time.perf_counter() - start) * 1e3 40 | timings[rep] = curr_time 41 | mean_syn = np.sum(timings) / r 42 | std_syn = np.std(timings) 43 | return {"mean": mean_syn,
"std": std_syn} 44 | 45 | 46 | def print_as_table(d, title, cnames): 47 | print() 48 | header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames]) 49 | print(header) 50 | print("-" * len(header)) 51 | for k, l in d.items(): 52 | print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l])) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description="Benchmark script for LightGlue") 57 | parser.add_argument( 58 | "--device", 59 | choices=["auto", "cuda", "cpu", "mps"], 60 | default="auto", 61 | help="device to benchmark on", 62 | ) 63 | parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs") 64 | parser.add_argument( 65 | "--no_flash", action="store_true", help="disable FlashAttention" 66 | ) 67 | parser.add_argument( 68 | "--no_prune_thresholds", 69 | action="store_true", 70 | help="disable pruning thresholds (i.e. always do pruning)", 71 | ) 72 | parser.add_argument( 73 | "--add_superglue", 74 | action="store_true", 75 | help="add SuperGlue to the benchmark (requires hloc)", 76 | ) 77 | parser.add_argument( 78 | "--measure", default="time", choices=["time", "log-time", "throughput"] 79 | ) 80 | parser.add_argument( 81 | "--repeat", "--r", type=int, default=100, help="repetitions of measurements" 82 | ) 83 | parser.add_argument( 84 | "--num_keypoints", 85 | nargs="+", 86 | type=int, 87 | default=[256, 512, 1024, 2048, 4096], 88 | help="number of keypoints (list separated by spaces)", 89 | ) 90 | parser.add_argument( 91 | "--matmul_precision", default="highest", choices=["highest", "high", "medium"] 92 | ) 93 | parser.add_argument( 94 | "--save", default=None, type=str, help="path where figure should be saved" 95 | ) 96 | args = parser.parse_intermixed_args() 97 | 98 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 99 | if args.device != "auto": 100 | device = torch.device(args.device) 101 | 102 | print("Running benchmark on device:", device) 103 | 104 | images = Path("assets") 105 | inputs = { 106 | "easy": ( 107 | load_image(images / "DSC_0411.JPG"), 108 | load_image(images / "DSC_0410.JPG"), 109 | ), 110 | "difficult": ( 111 | load_image(images / "sacre_coeur1.jpg"), 112 | load_image(images / "sacre_coeur2.jpg"), 113 | ), 114 | } 115 | 116 | configs = { 117 | "LightGlue-full": { 118 | "depth_confidence": -1, 119 | "width_confidence": -1, 120 | }, 121 | # 'LG-prune': { 122 | # 'width_confidence': -1, 123 | # }, 124 | # 'LG-depth': { 125 | # 'depth_confidence': -1, 126 | # }, 127 | "LightGlue-adaptive": {}, 128 | } 129 | 130 | if args.compile: 131 | configs = {**configs, **{k + "-compile": v for k, v in configs.items()}} 132 | 133 | sg_configs = { 134 | # 'SuperGlue': {}, 135 | "SuperGlue-fast": {"sinkhorn_iterations": 5} 136 | } 137 | 138 | torch.set_float32_matmul_precision(args.matmul_precision) 139 | 140 | results = {k: defaultdict(list) for k, v in inputs.items()} 141 | 142 | extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1) 143 | extractor = extractor.eval().to(device) 144 | figsize = (len(inputs) * 4.5, 4.5) 145 | fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize) 146 | axes = axes if len(inputs) > 1 else [axes] 147 | fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})") 148 | 149 | for title, ax in zip(inputs.keys(), axes): 150 | ax.set_xscale("log", base=2) 151 | bases = [2**x for x in range(7, 16)] 152 | ax.set_xticks(bases, bases) 153 | ax.grid(which="major") 154 | if args.measure == "log-time": 155 | ax.set_yscale("log") 156 | yticks = 
[10**x for x in range(6)] 157 | ax.set_yticks(yticks, yticks) 158 | mpos = [10**x * i for x in range(6) for i in range(2, 10)] 159 | mlabel = [ 160 | 10**x * i if i in [2, 5] else None 161 | for x in range(6) 162 | for i in range(2, 10) 163 | ] 164 | ax.set_yticks(mpos, mlabel, minor=True) 165 | ax.grid(which="minor", linewidth=0.2) 166 | ax.set_title(title) 167 | 168 | ax.set_xlabel("# keypoints") 169 | if args.measure == "throughput": 170 | ax.set_ylabel("Throughput [pairs/s]") 171 | else: 172 | ax.set_ylabel("Latency [ms]") 173 | 174 | for name, conf in configs.items(): 175 | print("Run benchmark for:", name) 176 | torch.cuda.empty_cache() 177 | matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf) 178 | if args.no_prune_thresholds: 179 | matcher.pruning_keypoint_thresholds = { 180 | k: -1 for k in matcher.pruning_keypoint_thresholds 181 | } 182 | matcher = matcher.eval().to(device) 183 | if name.endswith("compile"): 184 | import torch._dynamo 185 | 186 | torch._dynamo.reset() # avoid buffer overflow 187 | matcher.compile() 188 | for pair_name, ax in zip(inputs.keys(), axes): 189 | image0, image1 = [x.to(device) for x in inputs[pair_name]] 190 | runtimes = [] 191 | for num_kpts in args.num_keypoints: 192 | extractor.conf.max_num_keypoints = num_kpts 193 | feats0 = extractor.extract(image0) 194 | feats1 = extractor.extract(image1) 195 | runtime = measure( 196 | matcher, 197 | {"image0": feats0, "image1": feats1}, 198 | device=device, 199 | r=args.repeat, 200 | )["mean"] 201 | results[pair_name][name].append( 202 | 1000 / runtime if args.measure == "throughput" else runtime 203 | ) 204 | ax.plot( 205 | args.num_keypoints, results[pair_name][name], label=name, marker="o" 206 | ) 207 | del matcher, feats0, feats1 208 | 209 | if args.add_superglue: 210 | from hloc.matchers.superglue import SuperGlue 211 | 212 | for name, conf in sg_configs.items(): 213 | print("Run benchmark for:", name) 214 | matcher = SuperGlue(conf) 215 | matcher = matcher.eval().to(device) 216 | for pair_name, ax in zip(inputs.keys(), axes): 217 | image0, image1 = [x.to(device) for x in inputs[pair_name]] 218 | runtimes = [] 219 | for num_kpts in args.num_keypoints: 220 | extractor.conf.max_num_keypoints = num_kpts 221 | feats0 = extractor.extract(image0) 222 | feats1 = extractor.extract(image1) 223 | data = { 224 | "image0": image0[None], 225 | "image1": image1[None], 226 | **{k + "0": v for k, v in feats0.items()}, 227 | **{k + "1": v for k, v in feats1.items()}, 228 | } 229 | data["scores0"] = data["keypoint_scores0"] 230 | data["scores1"] = data["keypoint_scores1"] 231 | data["descriptors0"] = ( 232 | data["descriptors0"].transpose(-1, -2).contiguous() 233 | ) 234 | data["descriptors1"] = ( 235 | data["descriptors1"].transpose(-1, -2).contiguous() 236 | ) 237 | runtime = measure(matcher, data, device=device, r=args.repeat)[ 238 | "mean" 239 | ] 240 | results[pair_name][name].append( 241 | 1000 / runtime if args.measure == "throughput" else runtime 242 | ) 243 | ax.plot( 244 | args.num_keypoints, results[pair_name][name], label=name, marker="o" 245 | ) 246 | del matcher, data, image0, image1, feats0, feats1 247 | 248 | for name, runtimes in results.items(): 249 | print_as_table(runtimes, name, args.num_keypoints) 250 | 251 | axes[0].legend() 252 | fig.tight_layout() 253 | if args.save: 254 | plt.savefig(args.save, dpi=fig.dpi) 255 | plt.show() 256 | -------------------------------------------------------------------------------- /lightglue/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .aliked import ALIKED # noqa 2 | from .disk import DISK # noqa 3 | from .dog_hardnet import DoGHardNet # noqa 4 | from .lightglue import LightGlue # noqa 5 | from .sift import SIFT # noqa 6 | from .superpoint import SuperPoint # noqa 7 | from .utils import match_pair # noqa 8 | -------------------------------------------------------------------------------- /lightglue/aliked.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | 3 | # Copyright (c) 2022, Zhao Xiaoming 4 | # All rights reserved. 5 | 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | 9 | # 1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | 12 | # 2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | 16 | # 3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | # Authors: 32 | # Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. 
Chen, Qingsong Xu, and Zhengguo Li 33 | # Code from https://github.com/Shiaoming/ALIKED 34 | 35 | from typing import Callable, Optional 36 | 37 | import torch 38 | import torch.nn.functional as F 39 | import torchvision 40 | from kornia.color import grayscale_to_rgb 41 | from torch import nn 42 | from torch.nn.modules.utils import _pair 43 | from torchvision.models import resnet 44 | 45 | from .utils import Extractor 46 | 47 | 48 | def get_patches( 49 | tensor: torch.Tensor, required_corners: torch.Tensor, ps: int 50 | ) -> torch.Tensor: 51 | c, h, w = tensor.shape 52 | corner = (required_corners - ps / 2 + 1).long() 53 | corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps) 54 | corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps) 55 | offset = torch.arange(0, ps) 56 | 57 | kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {} 58 | x, y = torch.meshgrid(offset, offset, **kw) 59 | patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2) 60 | patches = patches.to(corner) + corner[None, None] 61 | pts = patches.reshape(-1, 2) 62 | sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]] 63 | sampled = sampled.reshape(ps, ps, -1, c) 64 | assert sampled.shape[:3] == patches.shape[:3] 65 | return sampled.permute(2, 3, 0, 1) 66 | 67 | 68 | def simple_nms(scores: torch.Tensor, nms_radius: int): 69 | """Fast Non-maximum suppression to remove nearby points""" 70 | 71 | zeros = torch.zeros_like(scores) 72 | max_mask = scores == torch.nn.functional.max_pool2d( 73 | scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius 74 | ) 75 | 76 | for _ in range(2): 77 | supp_mask = ( 78 | torch.nn.functional.max_pool2d( 79 | max_mask.float(), 80 | kernel_size=nms_radius * 2 + 1, 81 | stride=1, 82 | padding=nms_radius, 83 | ) 84 | > 0 85 | ) 86 | supp_scores = torch.where(supp_mask, zeros, scores) 87 | new_max_mask = supp_scores == torch.nn.functional.max_pool2d( 88 | supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius 89 | ) 90 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 91 | return torch.where(max_mask, scores, zeros) 92 | 93 | 94 | class DKD(nn.Module): 95 | def __init__( 96 | self, 97 | radius: int = 2, 98 | top_k: int = 0, 99 | scores_th: float = 0.2, 100 | n_limit: int = 20000, 101 | ): 102 | """ 103 | Args: 104 | radius: soft detection radius, kernel size is (2 * radius + 1) 105 | top_k: top_k > 0: return top k keypoints 106 | scores_th: top_k <= 0 threshold mode: 107 | scores_th > 0: return keypoints with scores>scores_th 108 | else: return keypoints with scores > scores.mean() 109 | n_limit: max number of keypoint in threshold mode 110 | """ 111 | super().__init__() 112 | self.radius = radius 113 | self.top_k = top_k 114 | self.scores_th = scores_th 115 | self.n_limit = n_limit 116 | self.kernel_size = 2 * self.radius + 1 117 | self.temperature = 0.1 # tuned temperature 118 | self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius) 119 | # local xy grid 120 | x = torch.linspace(-self.radius, self.radius, self.kernel_size) 121 | # (kernel_size*kernel_size) x 2 : (w,h) 122 | kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {} 123 | self.hw_grid = ( 124 | torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]] 125 | ) 126 | 127 | def forward( 128 | self, 129 | scores_map: torch.Tensor, 130 | sub_pixel: bool = True, 131 | image_size: Optional[torch.Tensor] = None, 132 | ): 133 | """ 134 | :param scores_map: Bx1xHxW 135 | :param descriptor_map: BxCxHxW 136 | :param sub_pixel: whether to use sub-pixel 
keypoint detection 137 | :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1 138 | """ 139 | b, c, h, w = scores_map.shape 140 | scores_nograd = scores_map.detach() 141 | nms_scores = simple_nms(scores_nograd, self.radius) 142 | 143 | # remove border 144 | nms_scores[:, :, : self.radius, :] = 0 145 | nms_scores[:, :, :, : self.radius] = 0 146 | if image_size is not None: 147 | for i in range(scores_map.shape[0]): 148 | w, h = image_size[i].long() 149 | nms_scores[i, :, h.item() - self.radius :, :] = 0 150 | nms_scores[i, :, :, w.item() - self.radius :] = 0 151 | else: 152 | nms_scores[:, :, -self.radius :, :] = 0 153 | nms_scores[:, :, :, -self.radius :] = 0 154 | 155 | # detect keypoints without grad 156 | if self.top_k > 0: 157 | topk = torch.topk(nms_scores.view(b, -1), self.top_k) 158 | indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k 159 | else: 160 | if self.scores_th > 0: 161 | masks = nms_scores > self.scores_th 162 | if masks.sum() == 0: 163 | th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th 164 | masks = nms_scores > th.reshape(b, 1, 1, 1) 165 | else: 166 | th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th 167 | masks = nms_scores > th.reshape(b, 1, 1, 1) 168 | masks = masks.reshape(b, -1) 169 | 170 | indices_keypoints = [] # list, B x (any size) 171 | scores_view = scores_nograd.reshape(b, -1) 172 | for mask, scores in zip(masks, scores_view): 173 | indices = mask.nonzero()[:, 0] 174 | if len(indices) > self.n_limit: 175 | kpts_sc = scores[indices] 176 | sort_idx = kpts_sc.sort(descending=True)[1] 177 | sel_idx = sort_idx[: self.n_limit] 178 | indices = indices[sel_idx] 179 | indices_keypoints.append(indices) 180 | 181 | wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device) 182 | 183 | keypoints = [] 184 | scoredispersitys = [] 185 | kptscores = [] 186 | if sub_pixel: 187 | # detect soft keypoints with grad backpropagation 188 | patches = self.unfold(scores_map) # B x (kernel**2) x (H*W) 189 | self.hw_grid = self.hw_grid.to(scores_map) # to device 190 | for b_idx in range(b): 191 | patch = patches[b_idx].t() # (H*W) x (kernel**2) 192 | indices_kpt = indices_keypoints[ 193 | b_idx 194 | ] # one dimension vector, say its size is M 195 | patch_scores = patch[indices_kpt] # M x (kernel**2) 196 | keypoints_xy_nms = torch.stack( 197 | [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")], 198 | dim=1, 199 | ) # Mx2 200 | 201 | # max is detached to prevent undesired backprop loops in the graph 202 | max_v = patch_scores.max(dim=1).values.detach()[:, None] 203 | x_exp = ( 204 | (patch_scores - max_v) / self.temperature 205 | ).exp() # M * (kernel**2), in [0, 1] 206 | 207 | # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} } 208 | xy_residual = ( 209 | x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] 210 | ) # Soft-argmax, Mx2 211 | 212 | hw_grid_dist2 = ( 213 | torch.norm( 214 | (self.hw_grid[None, :, :] - xy_residual[:, None, :]) 215 | / self.radius, 216 | dim=-1, 217 | ) 218 | ** 2 219 | ) 220 | scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1) 221 | 222 | # compute result keypoints 223 | keypoints_xy = keypoints_xy_nms + xy_residual 224 | keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1) 225 | 226 | kptscore = torch.nn.functional.grid_sample( 227 | scores_map[b_idx].unsqueeze(0), 228 | keypoints_xy.view(1, 1, -1, 2), 229 | mode="bilinear", 230 | align_corners=True, 231 | )[ 232 | 0, 0, 0, : 233 | ] # CxN 234 | 235 | 
keypoints.append(keypoints_xy) 236 | scoredispersitys.append(scoredispersity) 237 | kptscores.append(kptscore) 238 | else: 239 | for b_idx in range(b): 240 | indices_kpt = indices_keypoints[ 241 | b_idx 242 | ] # one dimension vector, say its size is M 243 | # To avoid warning: UserWarning: __floordiv__ is deprecated 244 | keypoints_xy_nms = torch.stack( 245 | [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")], 246 | dim=1, 247 | ) # Mx2 248 | keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1) 249 | kptscore = torch.nn.functional.grid_sample( 250 | scores_map[b_idx].unsqueeze(0), 251 | keypoints_xy.view(1, 1, -1, 2), 252 | mode="bilinear", 253 | align_corners=True, 254 | )[ 255 | 0, 0, 0, : 256 | ] # CxN 257 | keypoints.append(keypoints_xy) 258 | scoredispersitys.append(kptscore) # for jit.script compatability 259 | kptscores.append(kptscore) 260 | 261 | return keypoints, scoredispersitys, kptscores 262 | 263 | 264 | class InputPadder(object): 265 | """Pads images such that dimensions are divisible by 8""" 266 | 267 | def __init__(self, h: int, w: int, divis_by: int = 8): 268 | self.ht = h 269 | self.wd = w 270 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by 271 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by 272 | self._pad = [ 273 | pad_wd // 2, 274 | pad_wd - pad_wd // 2, 275 | pad_ht // 2, 276 | pad_ht - pad_ht // 2, 277 | ] 278 | 279 | def pad(self, x: torch.Tensor): 280 | assert x.ndim == 4 281 | return F.pad(x, self._pad, mode="replicate") 282 | 283 | def unpad(self, x: torch.Tensor): 284 | assert x.ndim == 4 285 | ht = x.shape[-2] 286 | wd = x.shape[-1] 287 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 288 | return x[..., c[0] : c[1], c[2] : c[3]] 289 | 290 | 291 | class DeformableConv2d(nn.Module): 292 | def __init__( 293 | self, 294 | in_channels, 295 | out_channels, 296 | kernel_size=3, 297 | stride=1, 298 | padding=1, 299 | bias=False, 300 | mask=False, 301 | ): 302 | super(DeformableConv2d, self).__init__() 303 | 304 | self.padding = padding 305 | self.mask = mask 306 | 307 | self.channel_num = ( 308 | 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size 309 | ) 310 | self.offset_conv = nn.Conv2d( 311 | in_channels, 312 | self.channel_num, 313 | kernel_size=kernel_size, 314 | stride=stride, 315 | padding=self.padding, 316 | bias=True, 317 | ) 318 | 319 | self.regular_conv = nn.Conv2d( 320 | in_channels=in_channels, 321 | out_channels=out_channels, 322 | kernel_size=kernel_size, 323 | stride=stride, 324 | padding=self.padding, 325 | bias=bias, 326 | ) 327 | 328 | def forward(self, x): 329 | h, w = x.shape[2:] 330 | max_offset = max(h, w) / 4.0 331 | 332 | out = self.offset_conv(x) 333 | if self.mask: 334 | o1, o2, mask = torch.chunk(out, 3, dim=1) 335 | offset = torch.cat((o1, o2), dim=1) 336 | mask = torch.sigmoid(mask) 337 | else: 338 | offset = out 339 | mask = None 340 | offset = offset.clamp(-max_offset, max_offset) 341 | x = torchvision.ops.deform_conv2d( 342 | input=x, 343 | offset=offset, 344 | weight=self.regular_conv.weight, 345 | bias=self.regular_conv.bias, 346 | padding=self.padding, 347 | mask=mask, 348 | ) 349 | return x 350 | 351 | 352 | def get_conv( 353 | inplanes, 354 | planes, 355 | kernel_size=3, 356 | stride=1, 357 | padding=1, 358 | bias=False, 359 | conv_type="conv", 360 | mask=False, 361 | ): 362 | if conv_type == "conv": 363 | conv = nn.Conv2d( 364 | inplanes, 365 | planes, 366 | kernel_size=kernel_size, 367 | 
stride=stride, 368 | padding=padding, 369 | bias=bias, 370 | ) 371 | elif conv_type == "dcn": 372 | conv = DeformableConv2d( 373 | inplanes, 374 | planes, 375 | kernel_size=kernel_size, 376 | stride=stride, 377 | padding=_pair(padding), 378 | bias=bias, 379 | mask=mask, 380 | ) 381 | else: 382 | raise TypeError 383 | return conv 384 | 385 | 386 | class ConvBlock(nn.Module): 387 | def __init__( 388 | self, 389 | in_channels, 390 | out_channels, 391 | gate: Optional[Callable[..., nn.Module]] = None, 392 | norm_layer: Optional[Callable[..., nn.Module]] = None, 393 | conv_type: str = "conv", 394 | mask: bool = False, 395 | ): 396 | super().__init__() 397 | if gate is None: 398 | self.gate = nn.ReLU(inplace=True) 399 | else: 400 | self.gate = gate 401 | if norm_layer is None: 402 | norm_layer = nn.BatchNorm2d 403 | self.conv1 = get_conv( 404 | in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask 405 | ) 406 | self.bn1 = norm_layer(out_channels) 407 | self.conv2 = get_conv( 408 | out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask 409 | ) 410 | self.bn2 = norm_layer(out_channels) 411 | 412 | def forward(self, x): 413 | x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W 414 | x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W 415 | return x 416 | 417 | 418 | # modified based on torchvision\models\resnet.py#27->BasicBlock 419 | class ResBlock(nn.Module): 420 | expansion: int = 1 421 | 422 | def __init__( 423 | self, 424 | inplanes: int, 425 | planes: int, 426 | stride: int = 1, 427 | downsample: Optional[nn.Module] = None, 428 | groups: int = 1, 429 | base_width: int = 64, 430 | dilation: int = 1, 431 | gate: Optional[Callable[..., nn.Module]] = None, 432 | norm_layer: Optional[Callable[..., nn.Module]] = None, 433 | conv_type: str = "conv", 434 | mask: bool = False, 435 | ) -> None: 436 | super(ResBlock, self).__init__() 437 | if gate is None: 438 | self.gate = nn.ReLU(inplace=True) 439 | else: 440 | self.gate = gate 441 | if norm_layer is None: 442 | norm_layer = nn.BatchNorm2d 443 | if groups != 1 or base_width != 64: 444 | raise ValueError("ResBlock only supports groups=1 and base_width=64") 445 | if dilation > 1: 446 | raise NotImplementedError("Dilation > 1 not supported in ResBlock") 447 | # Both self.conv1 and self.downsample layers 448 | # downsample the input when stride != 1 449 | self.conv1 = get_conv( 450 | inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask 451 | ) 452 | self.bn1 = norm_layer(planes) 453 | self.conv2 = get_conv( 454 | planes, planes, kernel_size=3, conv_type=conv_type, mask=mask 455 | ) 456 | self.bn2 = norm_layer(planes) 457 | self.downsample = downsample 458 | self.stride = stride 459 | 460 | def forward(self, x: torch.Tensor) -> torch.Tensor: 461 | identity = x 462 | 463 | out = self.conv1(x) 464 | out = self.bn1(out) 465 | out = self.gate(out) 466 | 467 | out = self.conv2(out) 468 | out = self.bn2(out) 469 | 470 | if self.downsample is not None: 471 | identity = self.downsample(x) 472 | 473 | out += identity 474 | out = self.gate(out) 475 | 476 | return out 477 | 478 | 479 | class SDDH(nn.Module): 480 | def __init__( 481 | self, 482 | dims: int, 483 | kernel_size: int = 3, 484 | n_pos: int = 8, 485 | gate=nn.ReLU(), 486 | conv2D=False, 487 | mask=False, 488 | ): 489 | super(SDDH, self).__init__() 490 | self.kernel_size = kernel_size 491 | self.n_pos = n_pos 492 | self.conv2D = conv2D 493 | self.mask = mask 494 | 495 | self.get_patches_func = get_patches 496 | 497 | # estimate 
offsets 498 | self.channel_num = 3 * n_pos if mask else 2 * n_pos 499 | self.offset_conv = nn.Sequential( 500 | nn.Conv2d( 501 | dims, 502 | self.channel_num, 503 | kernel_size=kernel_size, 504 | stride=1, 505 | padding=0, 506 | bias=True, 507 | ), 508 | gate, 509 | nn.Conv2d( 510 | self.channel_num, 511 | self.channel_num, 512 | kernel_size=1, 513 | stride=1, 514 | padding=0, 515 | bias=True, 516 | ), 517 | ) 518 | 519 | # sampled feature conv 520 | self.sf_conv = nn.Conv2d( 521 | dims, dims, kernel_size=1, stride=1, padding=0, bias=False 522 | ) 523 | 524 | # convM 525 | if not conv2D: 526 | # deformable desc weights 527 | agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims)) 528 | self.register_parameter("agg_weights", agg_weights) 529 | else: 530 | self.convM = nn.Conv2d( 531 | dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False 532 | ) 533 | 534 | def forward(self, x, keypoints): 535 | # x: [B,C,H,W] 536 | # keypoints: list, [[N_kpts,2], ...] (w,h) 537 | b, c, h, w = x.shape 538 | wh = torch.tensor([[w - 1, h - 1]], device=x.device) 539 | max_offset = max(h, w) / 4.0 540 | 541 | offsets = [] 542 | descriptors = [] 543 | # get offsets for each keypoint 544 | for ib in range(b): 545 | xi, kptsi = x[ib], keypoints[ib] 546 | kptsi_wh = (kptsi / 2 + 0.5) * wh 547 | N_kpts = len(kptsi) 548 | 549 | if self.kernel_size > 1: 550 | patch = self.get_patches_func( 551 | xi, kptsi_wh.long(), self.kernel_size 552 | ) # [N_kpts, C, K, K] 553 | else: 554 | kptsi_wh_long = kptsi_wh.long() 555 | patch = ( 556 | xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]] 557 | .permute(1, 0) 558 | .reshape(N_kpts, c, 1, 1) 559 | ) 560 | 561 | offset = self.offset_conv(patch).clamp( 562 | -max_offset, max_offset 563 | ) # [N_kpts, 2*n_pos, 1, 1] 564 | if self.mask: 565 | offset = ( 566 | offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1) 567 | ) # [N_kpts, n_pos, 3] 568 | offset = offset[:, :, :-1] # [N_kpts, n_pos, 2] 569 | mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos] 570 | else: 571 | offset = ( 572 | offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1) 573 | ) # [N_kpts, n_pos, 2] 574 | offsets.append(offset) # for visualization 575 | 576 | # get sample positions 577 | pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2] 578 | pos = 2.0 * pos / wh[None] - 1 579 | pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2) 580 | 581 | # sample features 582 | features = F.grid_sample( 583 | xi.unsqueeze(0), pos, mode="bilinear", align_corners=True 584 | ) # [1,C,(N_kpts*n_pos),1] 585 | features = features.reshape(c, N_kpts, self.n_pos, 1).permute( 586 | 1, 0, 2, 3 587 | ) # [N_kpts, C, n_pos, 1] 588 | if self.mask: 589 | features = torch.einsum("ncpo,np->ncpo", features, mask_weight) 590 | 591 | features = torch.selu_(self.sf_conv(features)).squeeze( 592 | -1 593 | ) # [N_kpts, C, n_pos] 594 | # convM 595 | if not self.conv2D: 596 | descs = torch.einsum( 597 | "ncp,pcd->nd", features, self.agg_weights 598 | ) # [N_kpts, C] 599 | else: 600 | features = features.reshape(N_kpts, -1)[ 601 | :, :, None, None 602 | ] # [N_kpts, C*n_pos, 1, 1] 603 | descs = self.convM(features).squeeze() # [N_kpts, C] 604 | 605 | # normalize 606 | descs = F.normalize(descs, p=2.0, dim=1) 607 | descriptors.append(descs) 608 | 609 | return descriptors, offsets 610 | 611 | 612 | class ALIKED(Extractor): 613 | default_conf = { 614 | "model_name": "aliked-n16", 615 | "max_num_keypoints": -1, 616 | "detection_threshold": 0.2, 617 | "nms_radius": 2, 618 | } 619 | 620 | checkpoint_url = 
"https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth" 621 | 622 | n_limit_max = 20000 623 | 624 | # c1, c2, c3, c4, dim, K, M 625 | cfgs = { 626 | "aliked-t16": [8, 16, 32, 64, 64, 3, 16], 627 | "aliked-n16": [16, 32, 64, 128, 128, 3, 16], 628 | "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16], 629 | "aliked-n32": [16, 32, 64, 128, 128, 3, 32], 630 | } 631 | preprocess_conf = { 632 | "resize": 1024, 633 | } 634 | 635 | required_data_keys = ["image"] 636 | 637 | def __init__(self, **conf): 638 | super().__init__(**conf) # Update with default configuration. 639 | conf = self.conf 640 | c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name] 641 | conv_types = ["conv", "conv", "dcn", "dcn"] 642 | conv2D = False 643 | mask = False 644 | 645 | # build model 646 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 647 | self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4) 648 | self.norm = nn.BatchNorm2d 649 | self.gate = nn.SELU(inplace=True) 650 | self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0]) 651 | self.block2 = self.get_resblock(c1, c2, conv_types[1], mask) 652 | self.block3 = self.get_resblock(c2, c3, conv_types[2], mask) 653 | self.block4 = self.get_resblock(c3, c4, conv_types[3], mask) 654 | 655 | self.conv1 = resnet.conv1x1(c1, dim // 4) 656 | self.conv2 = resnet.conv1x1(c2, dim // 4) 657 | self.conv3 = resnet.conv1x1(c3, dim // 4) 658 | self.conv4 = resnet.conv1x1(dim, dim // 4) 659 | self.upsample2 = nn.Upsample( 660 | scale_factor=2, mode="bilinear", align_corners=True 661 | ) 662 | self.upsample4 = nn.Upsample( 663 | scale_factor=4, mode="bilinear", align_corners=True 664 | ) 665 | self.upsample8 = nn.Upsample( 666 | scale_factor=8, mode="bilinear", align_corners=True 667 | ) 668 | self.upsample32 = nn.Upsample( 669 | scale_factor=32, mode="bilinear", align_corners=True 670 | ) 671 | self.score_head = nn.Sequential( 672 | resnet.conv1x1(dim, 8), 673 | self.gate, 674 | resnet.conv3x3(8, 4), 675 | self.gate, 676 | resnet.conv3x3(4, 4), 677 | self.gate, 678 | resnet.conv3x3(4, 1), 679 | ) 680 | self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask) 681 | self.dkd = DKD( 682 | radius=conf.nms_radius, 683 | top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints, 684 | scores_th=conf.detection_threshold, 685 | n_limit=conf.max_num_keypoints 686 | if conf.max_num_keypoints > 0 687 | else self.n_limit_max, 688 | ) 689 | 690 | state_dict = torch.hub.load_state_dict_from_url( 691 | self.checkpoint_url.format(conf.model_name), map_location="cpu" 692 | ) 693 | self.load_state_dict(state_dict, strict=True) 694 | 695 | def get_resblock(self, c_in, c_out, conv_type, mask): 696 | return ResBlock( 697 | c_in, 698 | c_out, 699 | 1, 700 | nn.Conv2d(c_in, c_out, 1), 701 | gate=self.gate, 702 | norm_layer=self.norm, 703 | conv_type=conv_type, 704 | mask=mask, 705 | ) 706 | 707 | def extract_dense_map(self, image): 708 | # Pads images such that dimensions are divisible by 709 | div_by = 2**5 710 | padder = InputPadder(image.shape[-2], image.shape[-1], div_by) 711 | image = padder.pad(image) 712 | 713 | # ================================== feature encoder 714 | x1 = self.block1(image) # B x c1 x H x W 715 | x2 = self.pool2(x1) 716 | x2 = self.block2(x2) # B x c2 x H/2 x W/2 717 | x3 = self.pool4(x2) 718 | x3 = self.block3(x3) # B x c3 x H/8 x W/8 719 | x4 = self.pool4(x3) 720 | x4 = self.block4(x4) # B x dim x H/32 x W/32 721 | # ================================== feature aggregation 722 | x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W 
723 | x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2 724 | x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8 725 | x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32 726 | x2_up = self.upsample2(x2) # B x dim//4 x H x W 727 | x3_up = self.upsample8(x3) # B x dim//4 x H x W 728 | x4_up = self.upsample32(x4) # B x dim//4 x H x W 729 | x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1) 730 | # ================================== score head 731 | score_map = torch.sigmoid(self.score_head(x1234)) 732 | feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1) 733 | 734 | # Unpads images 735 | feature_map = padder.unpad(feature_map) 736 | score_map = padder.unpad(score_map) 737 | 738 | return feature_map, score_map 739 | 740 | def forward(self, data: dict) -> dict: 741 | image = data["image"] 742 | if image.shape[1] == 1: 743 | image = grayscale_to_rgb(image) 744 | feature_map, score_map = self.extract_dense_map(image) 745 | keypoints, kptscores, scoredispersitys = self.dkd( 746 | score_map, image_size=data.get("image_size") 747 | ) 748 | descriptors, offsets = self.desc_head(feature_map, keypoints) 749 | 750 | _, _, h, w = image.shape 751 | wh = torch.tensor([w - 1, h - 1], device=image.device) 752 | # no padding required 753 | # we can set detection_threshold=-1 and conf.max_num_keypoints > 0 754 | return { 755 | "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2 756 | "descriptors": torch.stack(descriptors), # B x N x D 757 | "keypoint_scores": torch.stack(kptscores), # B x N 758 | } 759 | -------------------------------------------------------------------------------- /lightglue/disk.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | import torch 3 | 4 | from .utils import Extractor 5 | 6 | 7 | class DISK(Extractor): 8 | default_conf = { 9 | "weights": "depth", 10 | "max_num_keypoints": None, 11 | "desc_dim": 128, 12 | "nms_window_size": 5, 13 | "detection_threshold": 0.0, 14 | "pad_if_not_divisible": True, 15 | } 16 | 17 | preprocess_conf = { 18 | "resize": 1024, 19 | "grayscale": False, 20 | } 21 | 22 | required_data_keys = ["image"] 23 | 24 | def __init__(self, **conf) -> None: 25 | super().__init__(**conf) # Update with default configuration. 
26 | self.model = kornia.feature.DISK.from_pretrained(self.conf.weights) 27 | 28 | def forward(self, data: dict) -> dict: 29 | """Compute keypoints, scores, descriptors for image""" 30 | for key in self.required_data_keys: 31 | assert key in data, f"Missing key {key} in data" 32 | image = data["image"] 33 | if image.shape[1] == 1: 34 | image = kornia.color.grayscale_to_rgb(image) 35 | features = self.model( 36 | image, 37 | n=self.conf.max_num_keypoints, 38 | window_size=self.conf.nms_window_size, 39 | score_threshold=self.conf.detection_threshold, 40 | pad_if_not_divisible=self.conf.pad_if_not_divisible, 41 | ) 42 | keypoints = [f.keypoints for f in features] 43 | scores = [f.detection_scores for f in features] 44 | descriptors = [f.descriptors for f in features] 45 | del features 46 | 47 | keypoints = torch.stack(keypoints, 0) 48 | scores = torch.stack(scores, 0) 49 | descriptors = torch.stack(descriptors, 0) 50 | 51 | return { 52 | "keypoints": keypoints.to(image).contiguous(), 53 | "keypoint_scores": scores.to(image).contiguous(), 54 | "descriptors": descriptors.to(image).contiguous(), 55 | } 56 | -------------------------------------------------------------------------------- /lightglue/dog_hardnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.color import rgb_to_grayscale 3 | from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori 4 | 5 | from .sift import SIFT 6 | 7 | 8 | class DoGHardNet(SIFT): 9 | required_data_keys = ["image"] 10 | 11 | def __init__(self, **conf): 12 | super().__init__(**conf) 13 | self.laf_desc = LAFDescriptor(HardNet(True)).eval() 14 | 15 | def forward(self, data: dict) -> dict: 16 | image = data["image"] 17 | if image.shape[1] == 3: 18 | image = rgb_to_grayscale(image) 19 | device = image.device 20 | self.laf_desc = self.laf_desc.to(device) 21 | self.laf_desc.descriptor = self.laf_desc.descriptor.eval() 22 | pred = [] 23 | if "image_size" in data.keys(): 24 | im_size = data.get("image_size").long() 25 | else: 26 | im_size = None 27 | for k in range(len(image)): 28 | img = image[k] 29 | if im_size is not None: 30 | w, h = data["image_size"][k] 31 | img = img[:, : h.to(torch.int32), : w.to(torch.int32)] 32 | p = self.extract_single_image(img) 33 | lafs = laf_from_center_scale_ori( 34 | p["keypoints"].reshape(1, -1, 2), 35 | 6.0 * p["scales"].reshape(1, -1, 1, 1), 36 | torch.rad2deg(p["oris"]).reshape(1, -1, 1), 37 | ).to(device) 38 | p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128) 39 | pred.append(p) 40 | pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} 41 | return pred 42 | -------------------------------------------------------------------------------- /lightglue/lightglue.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from types import SimpleNamespace 4 | from typing import Callable, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | try: 12 | from flash_attn.modules.mha import FlashCrossAttention 13 | except ModuleNotFoundError: 14 | FlashCrossAttention = None 15 | 16 | if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"): 17 | FLASH_AVAILABLE = True 18 | else: 19 | FLASH_AVAILABLE = False 20 | 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 25 | def 
normalize_keypoints( 26 | kpts: torch.Tensor, size: Optional[torch.Tensor] = None 27 | ) -> torch.Tensor: 28 | if size is None: 29 | size = 1 + kpts.max(-2).values - kpts.min(-2).values 30 | elif not isinstance(size, torch.Tensor): 31 | size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype) 32 | size = size.to(kpts) 33 | shift = size / 2 34 | scale = size.max(-1).values / 2 35 | kpts = (kpts - shift[..., None, :]) / scale[..., None, None] 36 | return kpts 37 | 38 | 39 | def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]: 40 | if length <= x.shape[-2]: 41 | return x, torch.ones_like(x[..., :1], dtype=torch.bool) 42 | pad = torch.ones( 43 | *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype 44 | ) 45 | y = torch.cat([x, pad], dim=-2) 46 | mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device) 47 | mask[..., : x.shape[-2], :] = True 48 | return y, mask 49 | 50 | 51 | def rotate_half(x: torch.Tensor) -> torch.Tensor: 52 | x = x.unflatten(-1, (-1, 2)) 53 | x1, x2 = x.unbind(dim=-1) 54 | return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) 55 | 56 | 57 | def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 58 | return (t * freqs[0]) + (rotate_half(t) * freqs[1]) 59 | 60 | 61 | class LearnableFourierPositionalEncoding(nn.Module): 62 | def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: 63 | super().__init__() 64 | F_dim = F_dim if F_dim is not None else dim 65 | self.gamma = gamma 66 | self.Wr = nn.Linear(M, F_dim // 2, bias=False) 67 | nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) 68 | 69 | def forward(self, x: torch.Tensor) -> torch.Tensor: 70 | """encode position vector""" 71 | projected = self.Wr(x) 72 | cosines, sines = torch.cos(projected), torch.sin(projected) 73 | emb = torch.stack([cosines, sines], 0).unsqueeze(-3) 74 | return emb.repeat_interleave(2, dim=-1) 75 | 76 | 77 | class TokenConfidence(nn.Module): 78 | def __init__(self, dim: int) -> None: 79 | super().__init__() 80 | self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) 81 | 82 | def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): 83 | """get confidence tokens""" 84 | return ( 85 | self.token(desc0.detach()).squeeze(-1), 86 | self.token(desc1.detach()).squeeze(-1), 87 | ) 88 | 89 | 90 | class Attention(nn.Module): 91 | def __init__(self, allow_flash: bool) -> None: 92 | super().__init__() 93 | if allow_flash and not FLASH_AVAILABLE: 94 | warnings.warn( 95 | "FlashAttention is not available. 
For optimal speed, " 96 | "consider installing torch >= 2.0 or flash-attn.", 97 | stacklevel=2, 98 | ) 99 | self.enable_flash = allow_flash and FLASH_AVAILABLE 100 | self.has_sdp = hasattr(F, "scaled_dot_product_attention") 101 | if allow_flash and FlashCrossAttention: 102 | self.flash_ = FlashCrossAttention() 103 | if self.has_sdp: 104 | torch.backends.cuda.enable_flash_sdp(allow_flash) 105 | 106 | def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 107 | if q.shape[-2] == 0 or k.shape[-2] == 0: 108 | return q.new_zeros((*q.shape[:-1], v.shape[-1])) 109 | if self.enable_flash and q.device.type == "cuda": 110 | # use torch 2.0 scaled_dot_product_attention with flash 111 | if self.has_sdp: 112 | args = [x.half().contiguous() for x in [q, k, v]] 113 | v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype) 114 | return v if mask is None else v.nan_to_num() 115 | else: 116 | assert mask is None 117 | q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]] 118 | m = self.flash_(q.half(), torch.stack([k, v], 2).half()) 119 | return m.transpose(-2, -3).to(q.dtype).clone() 120 | elif self.has_sdp: 121 | args = [x.contiguous() for x in [q, k, v]] 122 | v = F.scaled_dot_product_attention(*args, attn_mask=mask) 123 | return v if mask is None else v.nan_to_num() 124 | else: 125 | s = q.shape[-1] ** -0.5 126 | sim = torch.einsum("...id,...jd->...ij", q, k) * s 127 | if mask is not None: 128 | sim.masked_fill(~mask, -float("inf")) 129 | attn = F.softmax(sim, -1) 130 | return torch.einsum("...ij,...jd->...id", attn, v) 131 | 132 | 133 | class SelfBlock(nn.Module): 134 | def __init__( 135 | self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True 136 | ) -> None: 137 | super().__init__() 138 | self.embed_dim = embed_dim 139 | self.num_heads = num_heads 140 | assert self.embed_dim % num_heads == 0 141 | self.head_dim = self.embed_dim // num_heads 142 | self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) 143 | self.inner_attn = Attention(flash) 144 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 145 | self.ffn = nn.Sequential( 146 | nn.Linear(2 * embed_dim, 2 * embed_dim), 147 | nn.LayerNorm(2 * embed_dim, elementwise_affine=True), 148 | nn.GELU(), 149 | nn.Linear(2 * embed_dim, embed_dim), 150 | ) 151 | 152 | def forward( 153 | self, 154 | x: torch.Tensor, 155 | encoding: torch.Tensor, 156 | mask: Optional[torch.Tensor] = None, 157 | ) -> torch.Tensor: 158 | qkv = self.Wqkv(x) 159 | qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) 160 | q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] 161 | q = apply_cached_rotary_emb(encoding, q) 162 | k = apply_cached_rotary_emb(encoding, k) 163 | context = self.inner_attn(q, k, v, mask=mask) 164 | message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2)) 165 | return x + self.ffn(torch.cat([x, message], -1)) 166 | 167 | 168 | class CrossBlock(nn.Module): 169 | def __init__( 170 | self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True 171 | ) -> None: 172 | super().__init__() 173 | self.heads = num_heads 174 | dim_head = embed_dim // num_heads 175 | self.scale = dim_head**-0.5 176 | inner_dim = dim_head * num_heads 177 | self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) 178 | self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) 179 | self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) 180 | self.ffn = nn.Sequential( 181 | nn.Linear(2 * embed_dim, 2 * embed_dim), 182 | nn.LayerNorm(2 * embed_dim, 
elementwise_affine=True), 183 | nn.GELU(), 184 | nn.Linear(2 * embed_dim, embed_dim), 185 | ) 186 | if flash and FLASH_AVAILABLE: 187 | self.flash = Attention(True) 188 | else: 189 | self.flash = None 190 | 191 | def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): 192 | return func(x0), func(x1) 193 | 194 | def forward( 195 | self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None 196 | ) -> List[torch.Tensor]: 197 | qk0, qk1 = self.map_(self.to_qk, x0, x1) 198 | v0, v1 = self.map_(self.to_v, x0, x1) 199 | qk0, qk1, v0, v1 = map( 200 | lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), 201 | (qk0, qk1, v0, v1), 202 | ) 203 | if self.flash is not None and qk0.device.type == "cuda": 204 | m0 = self.flash(qk0, qk1, v1, mask) 205 | m1 = self.flash( 206 | qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None 207 | ) 208 | else: 209 | qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 210 | sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1) 211 | if mask is not None: 212 | sim = sim.masked_fill(~mask, -float("inf")) 213 | attn01 = F.softmax(sim, dim=-1) 214 | attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) 215 | m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) 216 | m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) 217 | if mask is not None: 218 | m0, m1 = m0.nan_to_num(), m1.nan_to_num() 219 | m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) 220 | m0, m1 = self.map_(self.to_out, m0, m1) 221 | x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) 222 | x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) 223 | return x0, x1 224 | 225 | 226 | class TransformerLayer(nn.Module): 227 | def __init__(self, *args, **kwargs): 228 | super().__init__() 229 | self.self_attn = SelfBlock(*args, **kwargs) 230 | self.cross_attn = CrossBlock(*args, **kwargs) 231 | 232 | def forward( 233 | self, 234 | desc0, 235 | desc1, 236 | encoding0, 237 | encoding1, 238 | mask0: Optional[torch.Tensor] = None, 239 | mask1: Optional[torch.Tensor] = None, 240 | ): 241 | if mask0 is not None and mask1 is not None: 242 | return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1) 243 | else: 244 | desc0 = self.self_attn(desc0, encoding0) 245 | desc1 = self.self_attn(desc1, encoding1) 246 | return self.cross_attn(desc0, desc1) 247 | 248 | # This part is compiled and allows padding inputs 249 | def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1): 250 | mask = mask0 & mask1.transpose(-1, -2) 251 | mask0 = mask0 & mask0.transpose(-1, -2) 252 | mask1 = mask1 & mask1.transpose(-1, -2) 253 | desc0 = self.self_attn(desc0, encoding0, mask0) 254 | desc1 = self.self_attn(desc1, encoding1, mask1) 255 | return self.cross_attn(desc0, desc1, mask) 256 | 257 | 258 | def sigmoid_log_double_softmax( 259 | sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor 260 | ) -> torch.Tensor: 261 | """create the log assignment matrix from logits and similarity""" 262 | b, m, n = sim.shape 263 | certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2) 264 | scores0 = F.log_softmax(sim, 2) 265 | scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) 266 | scores = sim.new_full((b, m + 1, n + 1), 0) 267 | scores[:, :m, :n] = scores0 + scores1 + certainties 268 | scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1)) 269 | scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1)) 270 | return scores 271 | 272 | 273 | class MatchAssignment(nn.Module): 274 | def __init__(self, dim: int) 
-> None: 275 | super().__init__() 276 | self.dim = dim 277 | self.matchability = nn.Linear(dim, 1, bias=True) 278 | self.final_proj = nn.Linear(dim, dim, bias=True) 279 | 280 | def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): 281 | """build assignment matrix from descriptors""" 282 | mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) 283 | _, _, d = mdesc0.shape 284 | mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25 285 | sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1) 286 | z0 = self.matchability(desc0) 287 | z1 = self.matchability(desc1) 288 | scores = sigmoid_log_double_softmax(sim, z0, z1) 289 | return scores, sim 290 | 291 | def get_matchability(self, desc: torch.Tensor): 292 | return torch.sigmoid(self.matchability(desc)).squeeze(-1) 293 | 294 | 295 | def filter_matches(scores: torch.Tensor, th: float): 296 | """obtain matches from a log assignment matrix [Bx M+1 x N+1]""" 297 | max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) 298 | m0, m1 = max0.indices, max1.indices 299 | indices0 = torch.arange(m0.shape[1], device=m0.device)[None] 300 | indices1 = torch.arange(m1.shape[1], device=m1.device)[None] 301 | mutual0 = indices0 == m1.gather(1, m0) 302 | mutual1 = indices1 == m0.gather(1, m1) 303 | max0_exp = max0.values.exp() 304 | zero = max0_exp.new_tensor(0) 305 | mscores0 = torch.where(mutual0, max0_exp, zero) 306 | mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) 307 | valid0 = mutual0 & (mscores0 > th) 308 | valid1 = mutual1 & valid0.gather(1, m1) 309 | m0 = torch.where(valid0, m0, -1) 310 | m1 = torch.where(valid1, m1, -1) 311 | return m0, m1, mscores0, mscores1 312 | 313 | 314 | class LightGlue(nn.Module): 315 | default_conf = { 316 | "name": "lightglue", # just for interfacing 317 | "input_dim": 256, # input descriptor dimension (autoselected from weights) 318 | "descriptor_dim": 256, 319 | "add_scale_ori": False, 320 | "n_layers": 9, 321 | "num_heads": 4, 322 | "flash": True, # enable FlashAttention if available. 323 | "mp": False, # enable mixed precision 324 | "depth_confidence": 0.95, # early stopping, disable with -1 325 | "width_confidence": 0.99, # point pruning, disable with -1 326 | "filter_threshold": 0.1, # match threshold 327 | "weights": None, 328 | } 329 | 330 | # Point pruning involves an overhead (gather). 331 | # Therefore, we only activate it if there are enough keypoints. 
332 | pruning_keypoint_thresholds = { 333 | "cpu": -1, 334 | "mps": -1, 335 | "cuda": 1024, 336 | "flash": 1536, 337 | } 338 | 339 | required_data_keys = ["image0", "image1"] 340 | 341 | version = "v0.1_arxiv" 342 | url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth" 343 | 344 | features = { 345 | "superpoint": { 346 | "weights": "superpoint_lightglue", 347 | "input_dim": 256, 348 | }, 349 | "disk": { 350 | "weights": "disk_lightglue", 351 | "input_dim": 128, 352 | }, 353 | "aliked": { 354 | "weights": "aliked_lightglue", 355 | "input_dim": 128, 356 | }, 357 | "sift": { 358 | "weights": "sift_lightglue", 359 | "input_dim": 128, 360 | "add_scale_ori": True, 361 | }, 362 | "doghardnet": { 363 | "weights": "doghardnet_lightglue", 364 | "input_dim": 128, 365 | "add_scale_ori": True, 366 | }, 367 | } 368 | 369 | def __init__(self, features="superpoint", **conf) -> None: 370 | super().__init__() 371 | self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) 372 | if features is not None: 373 | if features not in self.features: 374 | raise ValueError( 375 | f"Unsupported features: {features} not in " 376 | f"{{{','.join(self.features)}}}" 377 | ) 378 | for k, v in self.features[features].items(): 379 | setattr(conf, k, v) 380 | 381 | if conf.input_dim != conf.descriptor_dim: 382 | self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) 383 | else: 384 | self.input_proj = nn.Identity() 385 | 386 | head_dim = conf.descriptor_dim // conf.num_heads 387 | self.posenc = LearnableFourierPositionalEncoding( 388 | 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim 389 | ) 390 | 391 | h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim 392 | 393 | self.transformers = nn.ModuleList( 394 | [TransformerLayer(d, h, conf.flash) for _ in range(n)] 395 | ) 396 | 397 | self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) 398 | self.token_confidence = nn.ModuleList( 399 | [TokenConfidence(d) for _ in range(n - 1)] 400 | ) 401 | self.register_buffer( 402 | "confidence_thresholds", 403 | torch.Tensor( 404 | [self.confidence_threshold(i) for i in range(self.conf.n_layers)] 405 | ), 406 | ) 407 | 408 | state_dict = None 409 | if features is not None: 410 | fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth" 411 | state_dict = torch.hub.load_state_dict_from_url( 412 | self.url.format(self.version, features), file_name=fname 413 | ) 414 | self.load_state_dict(state_dict, strict=False) 415 | elif conf.weights is not None: 416 | path = Path(__file__).parent 417 | path = path / "weights/{}.pth".format(self.conf.weights) 418 | state_dict = torch.load(str(path), map_location="cpu") 419 | 420 | if state_dict: 421 | # rename old state dict entries 422 | for i in range(self.conf.n_layers): 423 | pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" 424 | state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} 425 | pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" 426 | state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} 427 | self.load_state_dict(state_dict, strict=False) 428 | 429 | # static lengths LightGlue is compiled for (only used with torch.compile) 430 | self.static_lengths = None 431 | 432 | def compile( 433 | self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536] 434 | ): 435 | if self.conf.width_confidence != -1: 436 | warnings.warn( 437 | "Point pruning is partially disabled for compiled forward.", 438 | stacklevel=2, 439 | ) 440 | 441 | 
torch._inductor.cudagraph_mark_step_begin() 442 | for i in range(self.conf.n_layers): 443 | self.transformers[i].masked_forward = torch.compile( 444 | self.transformers[i].masked_forward, mode=mode, fullgraph=True 445 | ) 446 | 447 | self.static_lengths = static_lengths 448 | 449 | def forward(self, data: dict) -> dict: 450 | """ 451 | Match keypoints and descriptors between two images 452 | 453 | Input (dict): 454 | image0: dict 455 | keypoints: [B x M x 2] 456 | descriptors: [B x M x D] 457 | image: [B x C x H x W] or image_size: [B x 2] 458 | image1: dict 459 | keypoints: [B x N x 2] 460 | descriptors: [B x N x D] 461 | image: [B x C x H x W] or image_size: [B x 2] 462 | Output (dict): 463 | matches0: [B x M] 464 | matching_scores0: [B x M] 465 | matches1: [B x N] 466 | matching_scores1: [B x N] 467 | matches: List[[Si x 2]] 468 | scores: List[[Si]] 469 | stop: int 470 | prune0: [B x M] 471 | prune1: [B x N] 472 | """ 473 | with torch.autocast(enabled=self.conf.mp, device_type="cuda"): 474 | return self._forward(data) 475 | 476 | def _forward(self, data: dict) -> dict: 477 | for key in self.required_data_keys: 478 | assert key in data, f"Missing key {key} in data" 479 | data0, data1 = data["image0"], data["image1"] 480 | kpts0, kpts1 = data0["keypoints"], data1["keypoints"] 481 | b, m, _ = kpts0.shape 482 | b, n, _ = kpts1.shape 483 | device = kpts0.device 484 | size0, size1 = data0.get("image_size"), data1.get("image_size") 485 | kpts0 = normalize_keypoints(kpts0, size0).clone() 486 | kpts1 = normalize_keypoints(kpts1, size1).clone() 487 | 488 | if self.conf.add_scale_ori: 489 | kpts0 = torch.cat( 490 | [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1 491 | ) 492 | kpts1 = torch.cat( 493 | [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1 494 | ) 495 | desc0 = data0["descriptors"].detach().contiguous() 496 | desc1 = data1["descriptors"].detach().contiguous() 497 | 498 | assert desc0.shape[-1] == self.conf.input_dim 499 | assert desc1.shape[-1] == self.conf.input_dim 500 | 501 | if torch.is_autocast_enabled(): 502 | desc0 = desc0.half() 503 | desc1 = desc1.half() 504 | 505 | mask0, mask1 = None, None 506 | c = max(m, n) 507 | do_compile = self.static_lengths and c <= max(self.static_lengths) 508 | if do_compile: 509 | kn = min([k for k in self.static_lengths if k >= c]) 510 | desc0, mask0 = pad_to_length(desc0, kn) 511 | desc1, mask1 = pad_to_length(desc1, kn) 512 | kpts0, _ = pad_to_length(kpts0, kn) 513 | kpts1, _ = pad_to_length(kpts1, kn) 514 | desc0 = self.input_proj(desc0) 515 | desc1 = self.input_proj(desc1) 516 | # cache positional embeddings 517 | encoding0 = self.posenc(kpts0) 518 | encoding1 = self.posenc(kpts1) 519 | 520 | # GNN + final_proj + assignment 521 | do_early_stop = self.conf.depth_confidence > 0 522 | do_point_pruning = self.conf.width_confidence > 0 and not do_compile 523 | pruning_th = self.pruning_min_kpts(device) 524 | if do_point_pruning: 525 | ind0 = torch.arange(0, m, device=device)[None] 526 | ind1 = torch.arange(0, n, device=device)[None] 527 | # We store the index of the layer at which pruning is detected. 
528 | prune0 = torch.ones_like(ind0) 529 | prune1 = torch.ones_like(ind1) 530 | token0, token1 = None, None 531 | for i in range(self.conf.n_layers): 532 | if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints 533 | break 534 | desc0, desc1 = self.transformers[i]( 535 | desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1 536 | ) 537 | if i == self.conf.n_layers - 1: 538 | continue # no early stopping or adaptive width at last layer 539 | 540 | if do_early_stop: 541 | token0, token1 = self.token_confidence[i](desc0, desc1) 542 | if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n): 543 | break 544 | if do_point_pruning and desc0.shape[-2] > pruning_th: 545 | scores0 = self.log_assignment[i].get_matchability(desc0) 546 | prunemask0 = self.get_pruning_mask(token0, scores0, i) 547 | keep0 = torch.where(prunemask0)[1] 548 | ind0 = ind0.index_select(1, keep0) 549 | desc0 = desc0.index_select(1, keep0) 550 | encoding0 = encoding0.index_select(-2, keep0) 551 | prune0[:, ind0] += 1 552 | if do_point_pruning and desc1.shape[-2] > pruning_th: 553 | scores1 = self.log_assignment[i].get_matchability(desc1) 554 | prunemask1 = self.get_pruning_mask(token1, scores1, i) 555 | keep1 = torch.where(prunemask1)[1] 556 | ind1 = ind1.index_select(1, keep1) 557 | desc1 = desc1.index_select(1, keep1) 558 | encoding1 = encoding1.index_select(-2, keep1) 559 | prune1[:, ind1] += 1 560 | 561 | if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints 562 | m0 = desc0.new_full((b, m), -1, dtype=torch.long) 563 | m1 = desc1.new_full((b, n), -1, dtype=torch.long) 564 | mscores0 = desc0.new_zeros((b, m)) 565 | mscores1 = desc1.new_zeros((b, n)) 566 | matches = desc0.new_empty((b, 0, 2), dtype=torch.long) 567 | mscores = desc0.new_empty((b, 0)) 568 | if not do_point_pruning: 569 | prune0 = torch.ones_like(mscores0) * self.conf.n_layers 570 | prune1 = torch.ones_like(mscores1) * self.conf.n_layers 571 | return { 572 | "matches0": m0, 573 | "matches1": m1, 574 | "matching_scores0": mscores0, 575 | "matching_scores1": mscores1, 576 | "stop": i + 1, 577 | "matches": matches, 578 | "scores": mscores, 579 | "prune0": prune0, 580 | "prune1": prune1, 581 | } 582 | 583 | desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding 584 | scores, _ = self.log_assignment[i](desc0, desc1) 585 | m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) 586 | matches, mscores = [], [] 587 | for k in range(b): 588 | valid = m0[k] > -1 589 | m_indices_0 = torch.where(valid)[0] 590 | m_indices_1 = m0[k][valid] 591 | if do_point_pruning: 592 | m_indices_0 = ind0[k, m_indices_0] 593 | m_indices_1 = ind1[k, m_indices_1] 594 | matches.append(torch.stack([m_indices_0, m_indices_1], -1)) 595 | mscores.append(mscores0[k][valid]) 596 | 597 | # TODO: Remove when hloc switches to the compact format. 
598 | if do_point_pruning: 599 | m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype) 600 | m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype) 601 | m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) 602 | m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) 603 | mscores0_ = torch.zeros((b, m), device=mscores0.device) 604 | mscores1_ = torch.zeros((b, n), device=mscores1.device) 605 | mscores0_[:, ind0] = mscores0 606 | mscores1_[:, ind1] = mscores1 607 | m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_ 608 | else: 609 | prune0 = torch.ones_like(mscores0) * self.conf.n_layers 610 | prune1 = torch.ones_like(mscores1) * self.conf.n_layers 611 | 612 | return { 613 | "matches0": m0, 614 | "matches1": m1, 615 | "matching_scores0": mscores0, 616 | "matching_scores1": mscores1, 617 | "stop": i + 1, 618 | "matches": matches, 619 | "scores": mscores, 620 | "prune0": prune0, 621 | "prune1": prune1, 622 | } 623 | 624 | def confidence_threshold(self, layer_index: int) -> float: 625 | """scaled confidence threshold""" 626 | threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers) 627 | return np.clip(threshold, 0, 1) 628 | 629 | def get_pruning_mask( 630 | self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int 631 | ) -> torch.Tensor: 632 | """mask points which should be removed""" 633 | keep = scores > (1 - self.conf.width_confidence) 634 | if confidences is not None: # Low-confidence points are never pruned. 635 | keep |= confidences <= self.confidence_thresholds[layer_index] 636 | return keep 637 | 638 | def check_if_stop( 639 | self, 640 | confidences0: torch.Tensor, 641 | confidences1: torch.Tensor, 642 | layer_index: int, 643 | num_points: int, 644 | ) -> torch.Tensor: 645 | """evaluate stopping condition""" 646 | confidences = torch.cat([confidences0, confidences1], -1) 647 | threshold = self.confidence_thresholds[layer_index] 648 | ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points 649 | return ratio_confident > self.conf.depth_confidence 650 | 651 | def pruning_min_kpts(self, device: torch.device): 652 | if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda": 653 | return self.pruning_keypoint_thresholds["flash"] 654 | else: 655 | return self.pruning_keypoint_thresholds[device.type] 656 | -------------------------------------------------------------------------------- /lightglue/sift.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from kornia.color import rgb_to_grayscale 7 | from packaging import version 8 | 9 | try: 10 | import pycolmap 11 | except ImportError: 12 | pycolmap = None 13 | 14 | from .utils import Extractor 15 | 16 | 17 | def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None): 18 | h, w = image_shape 19 | ij = np.round(points - 0.5).astype(int).T[::-1] 20 | 21 | # Remove duplicate points (identical coordinates). 22 | # Pick highest scale or score 23 | s = scales if scores is None else scores 24 | buffer = np.zeros((h, w)) 25 | np.maximum.at(buffer, tuple(ij), s) 26 | keep = np.where(buffer[tuple(ij)] == s)[0] 27 | 28 | # Pick lowest angle (arbitrary). 
29 | ij = ij[:, keep] 30 | buffer[:] = np.inf 31 | o_abs = np.abs(angles[keep]) 32 | np.minimum.at(buffer, tuple(ij), o_abs) 33 | mask = buffer[tuple(ij)] == o_abs 34 | ij = ij[:, mask] 35 | keep = keep[mask] 36 | 37 | if nms_radius > 0: 38 | # Apply NMS on the remaining points 39 | buffer[:] = 0 40 | buffer[tuple(ij)] = s[keep] # scores or scale 41 | 42 | local_max = torch.nn.functional.max_pool2d( 43 | torch.from_numpy(buffer).unsqueeze(0), 44 | kernel_size=nms_radius * 2 + 1, 45 | stride=1, 46 | padding=nms_radius, 47 | ).squeeze(0) 48 | is_local_max = buffer == local_max.numpy() 49 | keep = keep[is_local_max[tuple(ij)]] 50 | return keep 51 | 52 | 53 | def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor: 54 | x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps) 55 | x.clip_(min=eps).sqrt_() 56 | return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps) 57 | 58 | 59 | def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray: 60 | """ 61 | Detect keypoints using OpenCV Detector. 62 | Optionally, perform description. 63 | Args: 64 | features: OpenCV based keypoints detector and descriptor 65 | image: Grayscale image of uint8 data type 66 | Returns: 67 | keypoints: 1D array of detected cv2.KeyPoint 68 | scores: 1D array of responses 69 | descriptors: 1D array of descriptors 70 | """ 71 | detections, descriptors = features.detectAndCompute(image, None) 72 | points = np.array([k.pt for k in detections], dtype=np.float32) 73 | scores = np.array([k.response for k in detections], dtype=np.float32) 74 | scales = np.array([k.size for k in detections], dtype=np.float32) 75 | angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32)) 76 | return points, scores, scales, angles, descriptors 77 | 78 | 79 | class SIFT(Extractor): 80 | default_conf = { 81 | "rootsift": True, 82 | "nms_radius": 0, # None to disable filtering entirely. 83 | "max_num_keypoints": 4096, 84 | "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda} 85 | "detection_threshold": 0.0066667, # from COLMAP 86 | "edge_threshold": 10, 87 | "first_octave": -1, # only used by pycolmap, the default of COLMAP 88 | "num_octaves": 4, 89 | } 90 | 91 | preprocess_conf = { 92 | "resize": 1024, 93 | } 94 | 95 | required_data_keys = ["image"] 96 | 97 | def __init__(self, **conf): 98 | super().__init__(**conf) # Update with default configuration. 99 | backend = self.conf.backend 100 | if backend.startswith("pycolmap"): 101 | if pycolmap is None: 102 | raise ImportError( 103 | "Cannot find module pycolmap: install it with pip" 104 | "or use backend=opencv." 105 | ) 106 | options = { 107 | "peak_threshold": self.conf.detection_threshold, 108 | "edge_threshold": self.conf.edge_threshold, 109 | "first_octave": self.conf.first_octave, 110 | "num_octaves": self.conf.num_octaves, 111 | "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy. 
112 | } 113 | device = ( 114 | "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "") 115 | ) 116 | if ( 117 | backend == "pycolmap_cpu" or not pycolmap.has_cuda 118 | ) and pycolmap.__version__ < "0.5.0": 119 | warnings.warn( 120 | "The pycolmap CPU SIFT is buggy in version < 0.5.0, " 121 | "consider upgrading pycolmap or use the CUDA version.", 122 | stacklevel=1, 123 | ) 124 | else: 125 | options["max_num_features"] = self.conf.max_num_keypoints 126 | self.sift = pycolmap.Sift(options=options, device=device) 127 | elif backend == "opencv": 128 | self.sift = cv2.SIFT_create( 129 | contrastThreshold=self.conf.detection_threshold, 130 | nfeatures=self.conf.max_num_keypoints, 131 | edgeThreshold=self.conf.edge_threshold, 132 | nOctaveLayers=self.conf.num_octaves, 133 | ) 134 | else: 135 | backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"} 136 | raise ValueError( 137 | f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}." 138 | ) 139 | 140 | def extract_single_image(self, image: torch.Tensor): 141 | image_np = image.cpu().numpy().squeeze(0) 142 | 143 | if self.conf.backend.startswith("pycolmap"): 144 | if version.parse(pycolmap.__version__) >= version.parse("0.5.0"): 145 | detections, descriptors = self.sift.extract(image_np) 146 | scores = None # Scores are not exposed by COLMAP anymore. 147 | else: 148 | detections, scores, descriptors = self.sift.extract(image_np) 149 | keypoints = detections[:, :2] # Keep only (x, y). 150 | scales, angles = detections[:, -2:].T 151 | if scores is not None and ( 152 | self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda 153 | ): 154 | # Set the scores as a combination of abs. response and scale. 155 | scores = np.abs(scores) * scales 156 | elif self.conf.backend == "opencv": 157 | # TODO: Check if opencv keypoints are already in corner convention 158 | keypoints, scores, scales, angles, descriptors = run_opencv_sift( 159 | self.sift, (image_np * 255.0).astype(np.uint8) 160 | ) 161 | pred = { 162 | "keypoints": keypoints, 163 | "scales": scales, 164 | "oris": angles, 165 | "descriptors": descriptors, 166 | } 167 | if scores is not None: 168 | pred["keypoint_scores"] = scores 169 | 170 | # sometimes pycolmap returns points outside the image. 
We remove them 171 | if self.conf.backend.startswith("pycolmap"): 172 | is_inside = ( 173 | pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]]) 174 | ).all(-1) 175 | pred = {k: v[is_inside] for k, v in pred.items()} 176 | 177 | if self.conf.nms_radius is not None: 178 | keep = filter_dog_point( 179 | pred["keypoints"], 180 | pred["scales"], 181 | pred["oris"], 182 | image_np.shape, 183 | self.conf.nms_radius, 184 | scores=pred.get("keypoint_scores"), 185 | ) 186 | pred = {k: v[keep] for k, v in pred.items()} 187 | 188 | pred = {k: torch.from_numpy(v) for k, v in pred.items()} 189 | if scores is not None: 190 | # Keep the k keypoints with highest score 191 | num_points = self.conf.max_num_keypoints 192 | if num_points is not None and len(pred["keypoints"]) > num_points: 193 | indices = torch.topk(pred["keypoint_scores"], num_points).indices 194 | pred = {k: v[indices] for k, v in pred.items()} 195 | 196 | return pred 197 | 198 | def forward(self, data: dict) -> dict: 199 | image = data["image"] 200 | if image.shape[1] == 3: 201 | image = rgb_to_grayscale(image) 202 | device = image.device 203 | image = image.cpu() 204 | pred = [] 205 | for k in range(len(image)): 206 | img = image[k] 207 | if "image_size" in data.keys(): 208 | # avoid extracting points in padded areas 209 | w, h = data["image_size"][k] 210 | img = img[:, :h, :w] 211 | p = self.extract_single_image(img) 212 | pred.append(p) 213 | pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} 214 | if self.conf.rootsift: 215 | pred["descriptors"] = sift_to_rootsift(pred["descriptors"]) 216 | return pred 217 | -------------------------------------------------------------------------------- /lightglue/superpoint.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. 
THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | # Adapted by Remi Pautrat, Philipp Lindenberger 44 | 45 | import torch 46 | from kornia.color import rgb_to_grayscale 47 | from torch import nn 48 | 49 | from .utils import Extractor 50 | 51 | 52 | def simple_nms(scores, nms_radius: int): 53 | """Fast Non-maximum suppression to remove nearby points""" 54 | assert nms_radius >= 0 55 | 56 | def max_pool(x): 57 | return torch.nn.functional.max_pool2d( 58 | x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius 59 | ) 60 | 61 | zeros = torch.zeros_like(scores) 62 | max_mask = scores == max_pool(scores) 63 | for _ in range(2): 64 | supp_mask = max_pool(max_mask.float()) > 0 65 | supp_scores = torch.where(supp_mask, zeros, scores) 66 | new_max_mask = supp_scores == max_pool(supp_scores) 67 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 68 | return torch.where(max_mask, scores, zeros) 69 | 70 | 71 | def top_k_keypoints(keypoints, scores, k): 72 | if k >= len(keypoints): 73 | return keypoints, scores 74 | scores, indices = torch.topk(scores, k, dim=0, sorted=True) 75 | return keypoints[indices], scores 76 | 77 | 78 | def sample_descriptors(keypoints, descriptors, s: int = 8): 79 | """Interpolate descriptors at keypoint locations""" 80 | b, c, h, w = descriptors.shape 81 | keypoints = keypoints - s / 2 + 0.5 82 | keypoints /= torch.tensor( 83 | [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], 84 | ).to( 85 | keypoints 86 | )[None] 87 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 88 | args = {"align_corners": True} if torch.__version__ >= "1.3" else {} 89 | descriptors = torch.nn.functional.grid_sample( 90 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args 91 | ) 92 | descriptors = torch.nn.functional.normalize( 93 | descriptors.reshape(b, c, -1), p=2, dim=1 94 | ) 95 | return descriptors 96 | 97 | 98 | class SuperPoint(Extractor): 99 | """SuperPoint Convolutional Detector and Descriptor 100 | 101 | SuperPoint: Self-Supervised Interest Point Detection and 102 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 103 | Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629 104 | 105 | """ 106 | 107 | default_conf = { 108 | "descriptor_dim": 256, 109 | "nms_radius": 4, 110 | "max_num_keypoints": None, 111 | "detection_threshold": 0.0005, 112 | "remove_borders": 4, 113 | } 114 | 115 | preprocess_conf = { 116 | "resize": 1024, 117 | } 118 | 119 | required_data_keys = ["image"] 120 | 121 | def __init__(self, **conf): 122 | super().__init__(**conf) # Update with default configuration. 
123 | self.relu = nn.ReLU(inplace=True) 124 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 125 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 126 | 127 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 128 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 129 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 130 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 131 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 132 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 133 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 134 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 135 | 136 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 137 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 138 | 139 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 140 | self.convDb = nn.Conv2d( 141 | c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0 142 | ) 143 | 144 | url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa 145 | self.load_state_dict(torch.hub.load_state_dict_from_url(url)) 146 | 147 | if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0: 148 | raise ValueError("max_num_keypoints must be positive or None") 149 | 150 | def forward(self, data: dict) -> dict: 151 | """Compute keypoints, scores, descriptors for image""" 152 | for key in self.required_data_keys: 153 | assert key in data, f"Missing key {key} in data" 154 | image = data["image"] 155 | if image.shape[1] == 3: 156 | image = rgb_to_grayscale(image) 157 | 158 | # Shared Encoder 159 | x = self.relu(self.conv1a(image)) 160 | x = self.relu(self.conv1b(x)) 161 | x = self.pool(x) 162 | x = self.relu(self.conv2a(x)) 163 | x = self.relu(self.conv2b(x)) 164 | x = self.pool(x) 165 | x = self.relu(self.conv3a(x)) 166 | x = self.relu(self.conv3b(x)) 167 | x = self.pool(x) 168 | x = self.relu(self.conv4a(x)) 169 | x = self.relu(self.conv4b(x)) 170 | 171 | # Compute the dense keypoint scores 172 | cPa = self.relu(self.convPa(x)) 173 | scores = self.convPb(cPa) 174 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 175 | b, _, h, w = scores.shape 176 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 177 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) 178 | scores = simple_nms(scores, self.conf.nms_radius) 179 | 180 | # Discard keypoints near the image borders 181 | if self.conf.remove_borders: 182 | pad = self.conf.remove_borders 183 | scores[:, :pad] = -1 184 | scores[:, :, :pad] = -1 185 | scores[:, -pad:] = -1 186 | scores[:, :, -pad:] = -1 187 | 188 | # Extract keypoints 189 | best_kp = torch.where(scores > self.conf.detection_threshold) 190 | scores = scores[best_kp] 191 | 192 | # Separate into batches 193 | keypoints = [ 194 | torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b) 195 | ] 196 | scores = [scores[best_kp[0] == i] for i in range(b)] 197 | 198 | # Keep the k keypoints with highest score 199 | if self.conf.max_num_keypoints is not None: 200 | keypoints, scores = list( 201 | zip( 202 | *[ 203 | top_k_keypoints(k, s, self.conf.max_num_keypoints) 204 | for k, s in zip(keypoints, scores) 205 | ] 206 | ) 207 | ) 208 | 209 | # Convert (h, w) to (x, y) 210 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 211 | 212 | # Compute the dense descriptors 213 | cDa = 
self.relu(self.convDa(x)) 214 | descriptors = self.convDb(cDa) 215 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) 216 | 217 | # Extract descriptors 218 | descriptors = [ 219 | sample_descriptors(k[None], d[None], 8)[0] 220 | for k, d in zip(keypoints, descriptors) 221 | ] 222 | 223 | return { 224 | "keypoints": torch.stack(keypoints, 0), 225 | "keypoint_scores": torch.stack(scores, 0), 226 | "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(), 227 | } 228 | -------------------------------------------------------------------------------- /lightglue/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc as collections 2 | from pathlib import Path 3 | from types import SimpleNamespace 4 | from typing import Callable, List, Optional, Tuple, Union 5 | 6 | import cv2 7 | import kornia 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class ImagePreprocessor: 13 | default_conf = { 14 | "resize": None, # target edge length, None for no resizing 15 | "side": "long", 16 | "interpolation": "bilinear", 17 | "align_corners": None, 18 | "antialias": True, 19 | } 20 | 21 | def __init__(self, **conf) -> None: 22 | super().__init__() 23 | self.conf = {**self.default_conf, **conf} 24 | self.conf = SimpleNamespace(**self.conf) 25 | 26 | def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 27 | """Resize and preprocess an image, return image and resize scale""" 28 | h, w = img.shape[-2:] 29 | if self.conf.resize is not None: 30 | img = kornia.geometry.transform.resize( 31 | img, 32 | self.conf.resize, 33 | side=self.conf.side, 34 | antialias=self.conf.antialias, 35 | align_corners=self.conf.align_corners, 36 | ) 37 | scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) 38 | return img, scale 39 | 40 | 41 | def map_tensor(input_, func: Callable): 42 | string_classes = (str, bytes) 43 | if isinstance(input_, string_classes): 44 | return input_ 45 | elif isinstance(input_, collections.Mapping): 46 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 47 | elif isinstance(input_, collections.Sequence): 48 | return [map_tensor(sample, func) for sample in input_] 49 | elif isinstance(input_, torch.Tensor): 50 | return func(input_) 51 | else: 52 | return input_ 53 | 54 | 55 | def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True): 56 | """Move batch (dict) to device""" 57 | 58 | def _func(tensor): 59 | return tensor.to(device=device, non_blocking=non_blocking).detach() 60 | 61 | return map_tensor(batch, _func) 62 | 63 | 64 | def rbd(data: dict) -> dict: 65 | """Remove batch dimension from elements in data""" 66 | return { 67 | k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v 68 | for k, v in data.items() 69 | } 70 | 71 | 72 | def read_image(path: Path, grayscale: bool = False) -> np.ndarray: 73 | """Read an image from path as RGB or grayscale""" 74 | if not Path(path).exists(): 75 | raise FileNotFoundError(f"No image at path {path}.") 76 | mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR 77 | image = cv2.imread(str(path), mode) 78 | if image is None: 79 | raise IOError(f"Could not read image at {path}.") 80 | if not grayscale: 81 | image = image[..., ::-1] 82 | return image 83 | 84 | 85 | def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor: 86 | """Normalize the image tensor and reorder the dimensions.""" 87 | if image.ndim == 3: 88 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 
89 | elif image.ndim == 2: 90 | image = image[None] # add channel axis 91 | else: 92 | raise ValueError(f"Not an image: {image.shape}") 93 | return torch.tensor(image / 255.0, dtype=torch.float) 94 | 95 | 96 | def resize_image( 97 | image: np.ndarray, 98 | size: Union[List[int], int], 99 | fn: str = "max", 100 | interp: Optional[str] = "area", 101 | ) -> np.ndarray: 102 | """Resize an image to a fixed size, or according to max or min edge.""" 103 | h, w = image.shape[:2] 104 | 105 | fn = {"max": max, "min": min}[fn] 106 | if isinstance(size, int): 107 | scale = size / fn(h, w) 108 | h_new, w_new = int(round(h * scale)), int(round(w * scale)) 109 | scale = (w_new / w, h_new / h) 110 | elif isinstance(size, (tuple, list)): 111 | h_new, w_new = size 112 | scale = (w_new / w, h_new / h) 113 | else: 114 | raise ValueError(f"Incorrect new size: {size}") 115 | mode = { 116 | "linear": cv2.INTER_LINEAR, 117 | "cubic": cv2.INTER_CUBIC, 118 | "nearest": cv2.INTER_NEAREST, 119 | "area": cv2.INTER_AREA, 120 | }[interp] 121 | return cv2.resize(image, (w_new, h_new), interpolation=mode), scale 122 | 123 | 124 | def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor: 125 | image = read_image(path) 126 | if resize is not None: 127 | image, _ = resize_image(image, resize, **kwargs) 128 | return numpy_image_to_torch(image) 129 | 130 | 131 | class Extractor(torch.nn.Module): 132 | def __init__(self, **conf): 133 | super().__init__() 134 | self.conf = SimpleNamespace(**{**self.default_conf, **conf}) 135 | 136 | @torch.no_grad() 137 | def extract(self, img: torch.Tensor, **conf) -> dict: 138 | """Perform extraction with online resizing""" 139 | if img.dim() == 3: 140 | img = img[None] # add batch dim 141 | assert img.dim() == 4 and img.shape[0] == 1 142 | shape = img.shape[-2:][::-1] 143 | img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) 144 | feats = self.forward({"image": img}) 145 | feats["image_size"] = torch.tensor(shape)[None].to(img).float() 146 | feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 147 | return feats 148 | 149 | 150 | def match_pair( 151 | extractor, 152 | matcher, 153 | image0: torch.Tensor, 154 | image1: torch.Tensor, 155 | device: str = "cpu", 156 | **preprocess, 157 | ): 158 | """Match a pair of images (image0, image1) with an extractor and matcher""" 159 | feats0 = extractor.extract(image0, **preprocess) 160 | feats1 = extractor.extract(image1, **preprocess) 161 | matches01 = matcher({"image0": feats0, "image1": feats1}) 162 | data = [feats0, feats1, matches01] 163 | # remove batch dim and move to target device 164 | feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data] 165 | return feats0, feats1, matches01 166 | -------------------------------------------------------------------------------- /lightglue/viz2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2D visualization primitives based on Matplotlib. 3 | 1) Plot images with `plot_images`. 4 | 2) Call `plot_keypoints` or `plot_matches` any number of times. 5 | 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. 
6 | """ 7 | 8 | import matplotlib 9 | import matplotlib.patheffects as path_effects 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def cm_RdGn(x): 16 | """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" 17 | x = np.clip(x, 0, 1)[..., None] * 2 18 | c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) 19 | return np.clip(c, 0, 1) 20 | 21 | 22 | def cm_BlRdGn(x_): 23 | """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" 24 | x = np.clip(x_, 0, 1)[..., None] * 2 25 | c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) 26 | 27 | xn = -np.clip(x_, -1, 0)[..., None] * 2 28 | cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) 29 | out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) 30 | return out 31 | 32 | 33 | def cm_prune(x_): 34 | """Custom colormap to visualize pruning""" 35 | if isinstance(x_, torch.Tensor): 36 | x_ = x_.cpu().numpy() 37 | max_i = max(x_) 38 | norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9) 39 | return cm_BlRdGn(norm_x) 40 | 41 | 42 | def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): 43 | """Plot a set of images horizontally. 44 | Args: 45 | imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W). 46 | titles: a list of strings, as titles for each image. 47 | cmaps: colormaps for monochrome images. 48 | adaptive: whether the figure size should fit the image aspect ratios. 49 | """ 50 | # conversion to (H, W, 3) for torch.Tensor 51 | imgs = [ 52 | img.permute(1, 2, 0).cpu().numpy() 53 | if (isinstance(img, torch.Tensor) and img.dim() == 3) 54 | else img 55 | for img in imgs 56 | ] 57 | 58 | n = len(imgs) 59 | if not isinstance(cmaps, (list, tuple)): 60 | cmaps = [cmaps] * n 61 | 62 | if adaptive: 63 | ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H 64 | else: 65 | ratios = [4 / 3] * n 66 | figsize = [sum(ratios) * 4.5, 4.5] 67 | fig, ax = plt.subplots( 68 | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} 69 | ) 70 | if n == 1: 71 | ax = [ax] 72 | for i in range(n): 73 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) 74 | ax[i].get_yaxis().set_ticks([]) 75 | ax[i].get_xaxis().set_ticks([]) 76 | ax[i].set_axis_off() 77 | for spine in ax[i].spines.values(): # remove frame 78 | spine.set_visible(False) 79 | if titles: 80 | ax[i].set_title(titles[i]) 81 | fig.tight_layout(pad=pad) 82 | 83 | 84 | def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0): 85 | """Plot keypoints for existing images. 86 | Args: 87 | kpts: list of ndarrays of size (N, 2). 88 | colors: string, or list of list of tuples (one for each keypoints). 89 | ps: size of the keypoints as float. 90 | """ 91 | if not isinstance(colors, list): 92 | colors = [colors] * len(kpts) 93 | if not isinstance(a, list): 94 | a = [a] * len(kpts) 95 | if axes is None: 96 | axes = plt.gcf().axes 97 | for ax, k, c, alpha in zip(axes, kpts, colors, a): 98 | if isinstance(k, torch.Tensor): 99 | k = k.cpu().numpy() 100 | ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha) 101 | 102 | 103 | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None): 104 | """Plot matches for a pair of existing images. 105 | Args: 106 | kpts0, kpts1: corresponding keypoints of size (N, 2). 107 | color: color of each match, string or RGB tuple. Random if not given. 108 | lw: width of the lines. 
109 | ps: size of the end points (no endpoint if ps=0) 110 | indices: indices of the images to draw the matches on. 111 | a: alpha opacity of the match lines. 112 | """ 113 | fig = plt.gcf() 114 | if axes is None: 115 | ax = fig.axes 116 | ax0, ax1 = ax[0], ax[1] 117 | else: 118 | ax0, ax1 = axes 119 | if isinstance(kpts0, torch.Tensor): 120 | kpts0 = kpts0.cpu().numpy() 121 | if isinstance(kpts1, torch.Tensor): 122 | kpts1 = kpts1.cpu().numpy() 123 | assert len(kpts0) == len(kpts1) 124 | if color is None: 125 | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() 126 | elif len(color) > 0 and not isinstance(color[0], (tuple, list)): 127 | color = [color] * len(kpts0) 128 | 129 | if lw > 0: 130 | for i in range(len(kpts0)): 131 | line = matplotlib.patches.ConnectionPatch( 132 | xyA=(kpts0[i, 0], kpts0[i, 1]), 133 | xyB=(kpts1[i, 0], kpts1[i, 1]), 134 | coordsA=ax0.transData, 135 | coordsB=ax1.transData, 136 | axesA=ax0, 137 | axesB=ax1, 138 | zorder=1, 139 | color=color[i], 140 | linewidth=lw, 141 | clip_on=True, 142 | alpha=a, 143 | label=None if labels is None else labels[i], 144 | picker=5.0, 145 | ) 146 | line.set_annotation_clip(True) 147 | fig.add_artist(line) 148 | 149 | # freeze the axes to prevent the transform to change 150 | ax0.autoscale(enable=False) 151 | ax1.autoscale(enable=False) 152 | 153 | if ps > 0: 154 | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 155 | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 156 | 157 | 158 | def add_text( 159 | idx, 160 | text, 161 | pos=(0.01, 0.99), 162 | fs=15, 163 | color="w", 164 | lcolor="k", 165 | lwidth=2, 166 | ha="left", 167 | va="top", 168 | ): 169 | ax = plt.gcf().axes[idx] 170 | t = ax.text( 171 | *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes 172 | ) 173 | if lcolor is not None: 174 | t.set_path_effects( 175 | [ 176 | path_effects.Stroke(linewidth=lwidth, foreground=lcolor), 177 | path_effects.Normal(), 178 | ] 179 | ) 180 | 181 | 182 | def save_plot(path, **kw): 183 | """Save the current figure without any white margin.""" 184 | plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) 185 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lightglue" 3 | description = "LightGlue: Local Feature Matching at Light Speed" 4 | version = "0.0" 5 | authors = [ 6 | {name = "Philipp Lindenberger"}, 7 | {name = "Paul-Edouard Sarlin"}, 8 | ] 9 | readme = "README.md" 10 | requires-python = ">=3.6" 11 | license = {file = "LICENSE"} 12 | classifiers = [ 13 | "Programming Language :: Python :: 3", 14 | "License :: OSI Approved :: Apache Software License", 15 | "Operating System :: OS Independent", 16 | ] 17 | urls = {Repository = "https://github.com/cvg/LightGlue/"} 18 | dynamic = ["dependencies"] 19 | 20 | [project.optional-dependencies] 21 | dev = ["black==23.12.1", "flake8", "isort"] 22 | 23 | [tool.setuptools] 24 | packages = ["lightglue"] 25 | 26 | [tool.setuptools.dynamic] 27 | dependencies = {file = ["requirements.txt"]} 28 | 29 | [tool.isort] 30 | profile = "black" 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.1 2 | torchvision>=0.3 3 | numpy 4 | opencv-python 5 | matplotlib 6 | kornia>=0.6.11 --------------------------------------------------------------------------------
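Taken together, the helpers in lightglue/utils.py form an end-to-end matching pipeline: load_image reads an image into a normalized CxHxW float tensor, Extractor.extract adds the batch dimension, resizes through ImagePreprocessor, and maps the keypoints back to the original resolution, and match_pair runs two images through an extractor and a matcher before stripping the batch dimension with rbd and moving the results with batch_to_device. Below is a minimal sketch of that flow; it assumes the package exports a SuperPoint extractor and a LightGlue matcher, and the constructor arguments and the "matches" output key are assumptions rather than code shown in the listings above.

import torch

from lightglue import LightGlue, SuperPoint  # assumed package-level exports
from lightglue.utils import load_image, match_pair

device = "cuda" if torch.cuda.is_available() else "cpu"

extractor = SuperPoint().eval().to(device)  # an Extractor subclass, so .extract() is available
matcher = LightGlue(features="superpoint").eval().to(device)  # assumed configuration flag

# load_image returns a CxHxW float tensor in [0, 1]; keep it on the same device as the models.
image0 = load_image("assets/sacre_coeur1.jpg").to(device)
image1 = load_image("assets/sacre_coeur2.jpg").to(device)

# Extract features for both images, match them, remove the batch dimension (rbd),
# and move everything to `device`.
feats0, feats1, matches01 = match_pair(
    extractor, matcher, image0, image1, device=device
)

matches = matches01["matches"]  # assumed (K, 2) tensor of index pairs from the matcher
points0 = feats0["keypoints"][matches[:, 0]]  # (K, 2) coordinates in image0
points1 = feats1["keypoints"][matches[:, 1]]  # (K, 2) coordinates in image1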
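lightglue/viz2d.py then covers the steps listed in its module docstring: draw the images with plot_images, overlay correspondences with plot_matches (or raw detections with plot_keypoints), annotate with add_text, and write the figure with save_plot. Continuing the sketch above, with a hypothetical output filename:

from lightglue import viz2d

# One axis per image; width ratios adapt to the image aspect ratios.
viz2d.plot_images([image0, image1])
# A single color string is broadcast to every match line; tensors are moved to CPU internally.
viz2d.plot_matches(points0, points1, color="lime", lw=0.2)
viz2d.add_text(0, f"{len(points0)} matches", fs=20)
viz2d.save_plot("matches.png")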