├── .flake8
├── .gitattributes
├── .github
│   └── workflows
│       └── code-quality.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── assets
│   ├── DSC_0410.JPG
│   ├── DSC_0411.JPG
│   ├── architecture.svg
│   ├── benchmark.png
│   ├── benchmark_cpu.png
│   ├── easy_hard.jpg
│   ├── sacre_coeur1.jpg
│   ├── sacre_coeur2.jpg
│   ├── teaser.svg
│   ├── wf_base.png
│   ├── wf_motionbrush.png
│   └── wf_video.png
├── benchmark.py
├── demo.ipynb
├── lightglue
│   ├── __init__.py
│   ├── aliked.py
│   ├── disk.py
│   ├── lightglue.py
│   ├── sift.py
│   ├── superpoint.py
│   ├── utils.py
│   └── viz2d.py
├── nodes.py
├── pyproject.toml
├── requirements.txt
├── tools
│   └── draw.html
├── workflow.json
├── workflow_lightgluemotionbrush.json
└── workflow_video.json
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | extend-ignore = E203
4 | exclude = .git,__pycache__,build,.venv/
5 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-documentation
--------------------------------------------------------------------------------
/.github/workflows/code-quality.yml:
--------------------------------------------------------------------------------
1 | name: Format and Lint Checks
2 | on:
3 | push:
4 | branches:
5 | - main
6 | paths:
7 | - '*.py'
8 | pull_request:
9 | types: [ assigned, opened, synchronize, reopened ]
10 | jobs:
11 | check:
12 | name: Format and Lint Checks
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | - uses: actions/setup-python@v4
17 | with:
18 | python-version: '3.10'
19 | cache: 'pip'
20 | - run: python -m pip install --upgrade pip
21 | - run: python -m pip install .[dev]
22 | - run: python -m flake8 .
23 | - run: python -m isort . --check-only --diff
24 | - run: python -m black . --check --diff
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /outputs/
3 | /lightglue/weights/
4 | *-checkpoint.ipynb
5 | *.pth
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/#use-with-ide
116 | .pdm.toml
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | .idea/
167 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 ETH Zurich
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A ComfyUI implementation of LightGlue for generating a `MotionBrush` for DragNUWA
2 |
3 | Built on [LightGlue](https://github.com/cvg/LightGlue).
4 |
5 | ## Install
6 |
7 | 1. Clone this repo into the `custom_nodes` directory of your ComfyUI installation.
8 |
9 | 2. Run `pip install -r requirements.txt` inside the cloned directory.
10 |
11 | ## Examples
12 |
13 | Video workflow
14 |
15 |
16 |
17 | https://github.com/chaojie/ComfyUI-LightGlue/blob/main/workflow_video.json
18 |
19 | LightGlueMotionBrush & DragNUWA
20 |
21 |
22 |
23 | https://github.com/chaojie/ComfyUI-LightGlue/blob/main/workflow_lightgluemotionbrush.json
24 |
25 | ## Tools
26 |
27 | [Motion Brush Visualization Tool](https://chaojie.github.io/ComfyUI-LightGlue/tools/draw.html)
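28 |
29 | ## Usage (standalone LightGlue)
30 |
31 | The bundled `lightglue` package can also be used outside ComfyUI. The snippet below is a minimal sketch that follows the usage in `benchmark.py`; the image pair and the keypoint budget are illustrative placeholders, not fixed requirements.
32 |
33 | ```python
34 | from pathlib import Path
35 |
36 | import torch
37 |
38 | from lightglue import LightGlue, SuperPoint
39 | from lightglue.utils import load_image
40 |
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 |
43 | # Feature extractor and matcher, both in eval mode on the chosen device.
44 | extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
45 | matcher = LightGlue(features="superpoint").eval().to(device)
46 |
47 | # Example images shipped in assets/ (any image pair works).
48 | images = Path("assets")
49 | image0 = load_image(images / "sacre_coeur1.jpg").to(device)
50 | image1 = load_image(images / "sacre_coeur2.jpg").to(device)
51 |
52 | # Extract keypoints and descriptors, then match them with LightGlue.
53 | feats0 = extractor.extract(image0)
54 | feats1 = extractor.extract(image1)
55 | matches01 = matcher({"image0": feats0, "image1": feats1})
56 | ```
57 |
58 | This extract-then-match pattern is the same one `benchmark.py` times for each keypoint budget.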
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS
2 |
3 | __all__ = ['NODE_CLASS_MAPPINGS']
4 |
5 |
--------------------------------------------------------------------------------
/assets/DSC_0410.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/DSC_0410.JPG
--------------------------------------------------------------------------------
/assets/DSC_0411.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/DSC_0411.JPG
--------------------------------------------------------------------------------
/assets/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/benchmark.png
--------------------------------------------------------------------------------
/assets/benchmark_cpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/benchmark_cpu.png
--------------------------------------------------------------------------------
/assets/easy_hard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/easy_hard.jpg
--------------------------------------------------------------------------------
/assets/sacre_coeur1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/sacre_coeur1.jpg
--------------------------------------------------------------------------------
/assets/sacre_coeur2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/sacre_coeur2.jpg
--------------------------------------------------------------------------------
/assets/wf_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/wf_base.png
--------------------------------------------------------------------------------
/assets/wf_motionbrush.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/wf_motionbrush.png
--------------------------------------------------------------------------------
/assets/wf_video.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-LightGlue/5ad80dedfda366dfb9830ae6b358fbcf014f0b94/assets/wf_video.png
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | # Benchmark script for LightGlue on real images
2 | import argparse
3 | import time
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 | import torch._dynamo
11 |
12 | from lightglue import LightGlue, SuperPoint
13 | from lightglue.utils import load_image
14 |
15 | torch.set_grad_enabled(False)
16 |
17 |
18 | def measure(matcher, data, device=torch.device("cuda"), r=100):
19 | timings = np.zeros((r, 1))
20 | if device.type == "cuda":
21 | starter = torch.cuda.Event(enable_timing=True)
22 | ender = torch.cuda.Event(enable_timing=True)
23 | # warmup
24 | for _ in range(10):
25 | _ = matcher(data)
26 | # measurements
27 | with torch.no_grad():
28 | for rep in range(r):
29 | if device.type == "cuda":
30 | starter.record()
31 | _ = matcher(data)
32 | ender.record()
33 | # sync gpu
34 | torch.cuda.synchronize()
35 | curr_time = starter.elapsed_time(ender)
36 | else:
37 | start = time.perf_counter()
38 | _ = matcher(data)
39 | curr_time = (time.perf_counter() - start) * 1e3
40 | timings[rep] = curr_time
41 | mean_syn = np.sum(timings) / r
42 | std_syn = np.std(timings)
43 | return {"mean": mean_syn, "std": std_syn}
44 |
45 |
46 | def print_as_table(d, title, cnames):
47 | print()
48 | header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
49 | print(header)
50 | print("-" * len(header))
51 | for k, row in d.items():
52 | print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in row]))
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
57 | parser.add_argument(
58 | "--device",
59 | choices=["auto", "cuda", "cpu", "mps"],
60 | default="auto",
61 | help="device to benchmark on",
62 | )
63 | parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
64 | parser.add_argument(
65 | "--no_flash", action="store_true", help="disable FlashAttention"
66 | )
67 | parser.add_argument(
68 | "--no_prune_thresholds",
69 | action="store_true",
70 | help="disable pruning thresholds (i.e. always do pruning)",
71 | )
72 | parser.add_argument(
73 | "--add_superglue",
74 | action="store_true",
75 | help="add SuperGlue to the benchmark (requires hloc)",
76 | )
77 | parser.add_argument(
78 | "--measure", default="time", choices=["time", "log-time", "throughput"]
79 | )
80 | parser.add_argument(
81 | "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
82 | )
83 | parser.add_argument(
84 | "--num_keypoints",
85 | nargs="+",
86 | type=int,
87 | default=[256, 512, 1024, 2048, 4096],
88 | help="number of keypoints (list separated by spaces)",
89 | )
90 | parser.add_argument(
91 | "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
92 | )
93 | parser.add_argument(
94 | "--save", default=None, type=str, help="path where figure should be saved"
95 | )
96 | args = parser.parse_intermixed_args()
97 |
98 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99 | if args.device != "auto":
100 | device = torch.device(args.device)
101 |
102 | print("Running benchmark on device:", device)
103 |
104 | images = Path("assets")
105 | inputs = {
106 | "easy": (
107 | load_image(images / "DSC_0411.JPG"),
108 | load_image(images / "DSC_0410.JPG"),
109 | ),
110 | "difficult": (
111 | load_image(images / "sacre_coeur1.jpg"),
112 | load_image(images / "sacre_coeur2.jpg"),
113 | ),
114 | }
115 |
116 | configs = {
117 | "LightGlue-full": {
118 | "depth_confidence": -1,
119 | "width_confidence": -1,
120 | },
121 | # 'LG-prune': {
122 | # 'width_confidence': -1,
123 | # },
124 | # 'LG-depth': {
125 | # 'depth_confidence': -1,
126 | # },
127 | "LightGlue-adaptive": {},
128 | }
129 |
130 | if args.compile:
131 | configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
132 |
133 | sg_configs = {
134 | # 'SuperGlue': {},
135 | "SuperGlue-fast": {"sinkhorn_iterations": 5}
136 | }
137 |
138 | torch.set_float32_matmul_precision(args.matmul_precision)
139 |
140 | results = {k: defaultdict(list) for k, v in inputs.items()}
141 |
142 | extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
143 | extractor = extractor.eval().to(device)
144 | figsize = (len(inputs) * 4.5, 4.5)
145 | fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
146 | axes = axes if len(inputs) > 1 else [axes]
147 | fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
148 |
149 | for title, ax in zip(inputs.keys(), axes):
150 | ax.set_xscale("log", base=2)
151 | bases = [2**x for x in range(7, 16)]
152 | ax.set_xticks(bases, bases)
153 | ax.grid(which="major")
154 | if args.measure == "log-time":
155 | ax.set_yscale("log")
156 | yticks = [10**x for x in range(6)]
157 | ax.set_yticks(yticks, yticks)
158 | mpos = [10**x * i for x in range(6) for i in range(2, 10)]
159 | mlabel = [
160 | 10**x * i if i in [2, 5] else None
161 | for x in range(6)
162 | for i in range(2, 10)
163 | ]
164 | ax.set_yticks(mpos, mlabel, minor=True)
165 | ax.grid(which="minor", linewidth=0.2)
166 | ax.set_title(title)
167 |
168 | ax.set_xlabel("# keypoints")
169 | if args.measure == "throughput":
170 | ax.set_ylabel("Throughput [pairs/s]")
171 | else:
172 | ax.set_ylabel("Latency [ms]")
173 |
174 | for name, conf in configs.items():
175 | print("Run benchmark for:", name)
176 | torch.cuda.empty_cache()
177 | matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
178 | if args.no_prune_thresholds:
179 | matcher.pruning_keypoint_thresholds = {
180 | k: -1 for k in matcher.pruning_keypoint_thresholds
181 | }
182 | matcher = matcher.eval().to(device)
183 | if name.endswith("compile"):
184 | import torch._dynamo
185 |
186 | torch._dynamo.reset() # avoid buffer overflow
187 | matcher.compile()
188 | for pair_name, ax in zip(inputs.keys(), axes):
189 | image0, image1 = [x.to(device) for x in inputs[pair_name]]
190 | runtimes = []
191 | for num_kpts in args.num_keypoints:
192 | extractor.conf.max_num_keypoints = num_kpts
193 | feats0 = extractor.extract(image0)
194 | feats1 = extractor.extract(image1)
195 | runtime = measure(
196 | matcher,
197 | {"image0": feats0, "image1": feats1},
198 | device=device,
199 | r=args.repeat,
200 | )["mean"]
201 | results[pair_name][name].append(
202 | 1000 / runtime if args.measure == "throughput" else runtime
203 | )
204 | ax.plot(
205 | args.num_keypoints, results[pair_name][name], label=name, marker="o"
206 | )
207 | del matcher, feats0, feats1
208 |
209 | if args.add_superglue:
210 | from hloc.matchers.superglue import SuperGlue
211 |
212 | for name, conf in sg_configs.items():
213 | print("Run benchmark for:", name)
214 | matcher = SuperGlue(conf)
215 | matcher = matcher.eval().to(device)
216 | for pair_name, ax in zip(inputs.keys(), axes):
217 | image0, image1 = [x.to(device) for x in inputs[pair_name]]
218 | runtimes = []
219 | for num_kpts in args.num_keypoints:
220 | extractor.conf.max_num_keypoints = num_kpts
221 | feats0 = extractor.extract(image0)
222 | feats1 = extractor.extract(image1)
223 | data = {
224 | "image0": image0[None],
225 | "image1": image1[None],
226 | **{k + "0": v for k, v in feats0.items()},
227 | **{k + "1": v for k, v in feats1.items()},
228 | }
229 | data["scores0"] = data["keypoint_scores0"]
230 | data["scores1"] = data["keypoint_scores1"]
231 | data["descriptors0"] = (
232 | data["descriptors0"].transpose(-1, -2).contiguous()
233 | )
234 | data["descriptors1"] = (
235 | data["descriptors1"].transpose(-1, -2).contiguous()
236 | )
237 | runtime = measure(matcher, data, device=device, r=args.repeat)[
238 | "mean"
239 | ]
240 | results[pair_name][name].append(
241 | 1000 / runtime if args.measure == "throughput" else runtime
242 | )
243 | ax.plot(
244 | args.num_keypoints, results[pair_name][name], label=name, marker="o"
245 | )
246 | del matcher, data, image0, image1, feats0, feats1
247 |
248 | for name, runtimes in results.items():
249 | print_as_table(runtimes, name, args.num_keypoints)
250 |
251 | axes[0].legend()
252 | fig.tight_layout()
253 | if args.save:
254 | plt.savefig(args.save, dpi=fig.dpi)
255 | plt.show()
256 |
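257 | # Example invocation (illustrative; flags are defined by the argparse options above):
258 | #   python benchmark.py --device cuda --save benchmark_plot.png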
--------------------------------------------------------------------------------
/lightglue/__init__.py:
--------------------------------------------------------------------------------
1 | from .aliked import ALIKED # noqa
2 | from .disk import DISK # noqa
3 | from .lightglue import LightGlue # noqa
4 | from .sift import SIFT # noqa
5 | from .superpoint import SuperPoint # noqa
6 | from .utils import match_pair # noqa
7 |
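8 | # Public API of the bundled package: the ALIKED, DISK, SIFT, and SuperPoint
9 | # feature extractors, the LightGlue matcher, and the match_pair convenience
10 | # helper re-exported from .utils.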
--------------------------------------------------------------------------------
/lightglue/aliked.py:
--------------------------------------------------------------------------------
1 | # BSD 3-Clause License
2 |
3 | # Copyright (c) 2022, Zhao Xiaoming
4 | # All rights reserved.
5 |
6 | # Redistribution and use in source and binary forms, with or without
7 | # modification, are permitted provided that the following conditions are met:
8 |
9 | # 1. Redistributions of source code must retain the above copyright notice, this
10 | # list of conditions and the following disclaimer.
11 |
12 | # 2. Redistributions in binary form must reproduce the above copyright notice,
13 | # this list of conditions and the following disclaimer in the documentation
14 | # and/or other materials provided with the distribution.
15 |
16 | # 3. Neither the name of the copyright holder nor the names of its
17 | # contributors may be used to endorse or promote products derived from
18 | # this software without specific prior written permission.
19 |
20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
31 | # Authors:
32 | # Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
33 | # Code from https://github.com/Shiaoming/ALIKED
34 |
35 | from typing import Callable, Optional
36 |
37 | import torch
38 | import torch.nn.functional as F
39 | import torchvision
40 | from kornia.color import grayscale_to_rgb
41 | from torch import nn
42 | from torch.nn.modules.utils import _pair
43 | from torchvision.models import resnet
44 |
45 | from .utils import Extractor
46 |
47 |
48 | def get_patches(
49 | tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
50 | ) -> torch.Tensor:
51 | c, h, w = tensor.shape
52 | corner = (required_corners - ps / 2 + 1).long()
53 | corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
54 | corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
55 | offset = torch.arange(0, ps)
56 |
57 | kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
58 | x, y = torch.meshgrid(offset, offset, **kw)
59 | patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
60 | patches = patches.to(corner) + corner[None, None]
61 | pts = patches.reshape(-1, 2)
62 | sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
63 | sampled = sampled.reshape(ps, ps, -1, c)
64 | assert sampled.shape[:3] == patches.shape[:3]
65 | return sampled.permute(2, 3, 0, 1)
66 |
67 |
68 | def simple_nms(scores: torch.Tensor, nms_radius: int):
69 | """Fast Non-maximum suppression to remove nearby points"""
70 |
71 | zeros = torch.zeros_like(scores)
72 | max_mask = scores == torch.nn.functional.max_pool2d(
73 | scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
74 | )
75 |
76 | for _ in range(2):
77 | supp_mask = (
78 | torch.nn.functional.max_pool2d(
79 | max_mask.float(),
80 | kernel_size=nms_radius * 2 + 1,
81 | stride=1,
82 | padding=nms_radius,
83 | )
84 | > 0
85 | )
86 | supp_scores = torch.where(supp_mask, zeros, scores)
87 | new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
88 | supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
89 | )
90 | max_mask = max_mask | (new_max_mask & (~supp_mask))
91 | return torch.where(max_mask, scores, zeros)
92 |
93 |
94 | class DKD(nn.Module):
95 | def __init__(
96 | self,
97 | radius: int = 2,
98 | top_k: int = 0,
99 | scores_th: float = 0.2,
100 | n_limit: int = 20000,
101 | ):
102 | """
103 | Args:
104 | radius: soft detection radius, kernel size is (2 * radius + 1)
105 | top_k: if top_k > 0, return the top_k highest-scoring keypoints
106 | scores_th: threshold used when top_k <= 0:
107 | scores_th > 0: return keypoints with scores > scores_th
108 | else: return keypoints with scores > scores.mean()
109 | n_limit: max number of keypoints in threshold mode
110 | """
111 | super().__init__()
112 | self.radius = radius
113 | self.top_k = top_k
114 | self.scores_th = scores_th
115 | self.n_limit = n_limit
116 | self.kernel_size = 2 * self.radius + 1
117 | self.temperature = 0.1 # tuned temperature
118 | self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
119 | # local xy grid
120 | x = torch.linspace(-self.radius, self.radius, self.kernel_size)
121 | # (kernel_size*kernel_size) x 2 : (w,h)
122 | kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
123 | self.hw_grid = (
124 | torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
125 | )
126 |
127 | def forward(
128 | self,
129 | scores_map: torch.Tensor,
130 | sub_pixel: bool = True,
131 | image_size: Optional[torch.Tensor] = None,
132 | ):
133 | """
134 | :param scores_map: Bx1xHxW
135 | :param sub_pixel: whether to use sub-pixel keypoint refinement
136 | :param image_size: optional Bx2 tensor of per-image (w, h), used to zero out scores near padded borders
137 | :return: keypoints: list[Nx2,...] normalised to -1~1; scoredispersitys: list[N,...]; kptscores: list[N,...]
138 | """
139 | b, c, h, w = scores_map.shape
140 | scores_nograd = scores_map.detach()
141 | nms_scores = simple_nms(scores_nograd, self.radius)
142 |
143 | # remove border
144 | nms_scores[:, :, : self.radius, :] = 0
145 | nms_scores[:, :, :, : self.radius] = 0
146 | if image_size is not None:
147 | for i in range(scores_map.shape[0]):
148 | w, h = image_size[i].long()
149 | nms_scores[i, :, h.item() - self.radius :, :] = 0
150 | nms_scores[i, :, :, w.item() - self.radius :] = 0
151 | else:
152 | nms_scores[:, :, -self.radius :, :] = 0
153 | nms_scores[:, :, :, -self.radius :] = 0
154 |
155 | # detect keypoints without grad
156 | if self.top_k > 0:
157 | topk = torch.topk(nms_scores.view(b, -1), self.top_k)
158 | indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
159 | else:
160 | if self.scores_th > 0:
161 | masks = nms_scores > self.scores_th
162 | if masks.sum() == 0:
163 | th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
164 | masks = nms_scores > th.reshape(b, 1, 1, 1)
165 | else:
166 | th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
167 | masks = nms_scores > th.reshape(b, 1, 1, 1)
168 | masks = masks.reshape(b, -1)
169 |
170 | indices_keypoints = [] # list, B x (any size)
171 | scores_view = scores_nograd.reshape(b, -1)
172 | for mask, scores in zip(masks, scores_view):
173 | indices = mask.nonzero()[:, 0]
174 | if len(indices) > self.n_limit:
175 | kpts_sc = scores[indices]
176 | sort_idx = kpts_sc.sort(descending=True)[1]
177 | sel_idx = sort_idx[: self.n_limit]
178 | indices = indices[sel_idx]
179 | indices_keypoints.append(indices)
180 |
181 | wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
182 |
183 | keypoints = []
184 | scoredispersitys = []
185 | kptscores = []
186 | if sub_pixel:
187 | # detect soft keypoints with grad backpropagation
188 | patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
189 | self.hw_grid = self.hw_grid.to(scores_map) # to device
190 | for b_idx in range(b):
191 | patch = patches[b_idx].t() # (H*W) x (kernel**2)
192 | indices_kpt = indices_keypoints[
193 | b_idx
194 | ] # one dimension vector, say its size is M
195 | patch_scores = patch[indices_kpt] # M x (kernel**2)
196 | keypoints_xy_nms = torch.stack(
197 | [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
198 | dim=1,
199 | ) # Mx2
200 |
201 | # max is detached to prevent undesired backprop loops in the graph
202 | max_v = patch_scores.max(dim=1).values.detach()[:, None]
203 | x_exp = (
204 | (patch_scores - max_v) / self.temperature
205 | ).exp() # M * (kernel**2), in [0, 1]
206 |
207 | # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
208 | xy_residual = (
209 | x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
210 | ) # Soft-argmax, Mx2
211 |
212 | hw_grid_dist2 = (
213 | torch.norm(
214 | (self.hw_grid[None, :, :] - xy_residual[:, None, :])
215 | / self.radius,
216 | dim=-1,
217 | )
218 | ** 2
219 | )
220 | scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
221 |
222 | # compute result keypoints
223 | keypoints_xy = keypoints_xy_nms + xy_residual
224 | keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
225 |
226 | kptscore = torch.nn.functional.grid_sample(
227 | scores_map[b_idx].unsqueeze(0),
228 | keypoints_xy.view(1, 1, -1, 2),
229 | mode="bilinear",
230 | align_corners=True,
231 | )[
232 | 0, 0, 0, :
233 | ] # CxN
234 |
235 | keypoints.append(keypoints_xy)
236 | scoredispersitys.append(scoredispersity)
237 | kptscores.append(kptscore)
238 | else:
239 | for b_idx in range(b):
240 | indices_kpt = indices_keypoints[
241 | b_idx
242 | ] # one dimension vector, say its size is M
243 | # To avoid warning: UserWarning: __floordiv__ is deprecated
244 | keypoints_xy_nms = torch.stack(
245 | [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
246 | dim=1,
247 | ) # Mx2
248 | keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
249 | kptscore = torch.nn.functional.grid_sample(
250 | scores_map[b_idx].unsqueeze(0),
251 | keypoints_xy.view(1, 1, -1, 2),
252 | mode="bilinear",
253 | align_corners=True,
254 | )[
255 | 0, 0, 0, :
256 | ] # CxN
257 | keypoints.append(keypoints_xy)
258 | scoredispersitys.append(kptscore)  # for jit.script compatibility
259 | kptscores.append(kptscore)
260 |
261 | return keypoints, scoredispersitys, kptscores
262 |
263 |
264 | class InputPadder(object):
265 | """Pads images such that dimensions are divisible by 8"""
266 |
267 | def __init__(self, h: int, w: int, divis_by: int = 8):
268 | self.ht = h
269 | self.wd = w
270 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
271 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
272 | self._pad = [
273 | pad_wd // 2,
274 | pad_wd - pad_wd // 2,
275 | pad_ht // 2,
276 | pad_ht - pad_ht // 2,
277 | ]
278 |
279 | def pad(self, x: torch.Tensor):
280 | assert x.ndim == 4
281 | return F.pad(x, self._pad, mode="replicate")
282 |
283 | def unpad(self, x: torch.Tensor):
284 | assert x.ndim == 4
285 | ht = x.shape[-2]
286 | wd = x.shape[-1]
287 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
288 | return x[..., c[0] : c[1], c[2] : c[3]]
289 |
290 |
291 | class DeformableConv2d(nn.Module):
292 | def __init__(
293 | self,
294 | in_channels,
295 | out_channels,
296 | kernel_size=3,
297 | stride=1,
298 | padding=1,
299 | bias=False,
300 | mask=False,
301 | ):
302 | super(DeformableConv2d, self).__init__()
303 |
304 | self.padding = padding
305 | self.mask = mask
306 |
307 | self.channel_num = (
308 | 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
309 | )
310 | self.offset_conv = nn.Conv2d(
311 | in_channels,
312 | self.channel_num,
313 | kernel_size=kernel_size,
314 | stride=stride,
315 | padding=self.padding,
316 | bias=True,
317 | )
318 |
319 | self.regular_conv = nn.Conv2d(
320 | in_channels=in_channels,
321 | out_channels=out_channels,
322 | kernel_size=kernel_size,
323 | stride=stride,
324 | padding=self.padding,
325 | bias=bias,
326 | )
327 |
328 | def forward(self, x):
329 | h, w = x.shape[2:]
330 | max_offset = max(h, w) / 4.0
331 |
332 | out = self.offset_conv(x)
333 | if self.mask:
334 | o1, o2, mask = torch.chunk(out, 3, dim=1)
335 | offset = torch.cat((o1, o2), dim=1)
336 | mask = torch.sigmoid(mask)
337 | else:
338 | offset = out
339 | mask = None
340 | offset = offset.clamp(-max_offset, max_offset)
341 | x = torchvision.ops.deform_conv2d(
342 | input=x,
343 | offset=offset,
344 | weight=self.regular_conv.weight,
345 | bias=self.regular_conv.bias,
346 | padding=self.padding,
347 | mask=mask,
348 | )
349 | return x
350 |
351 |
352 | def get_conv(
353 | inplanes,
354 | planes,
355 | kernel_size=3,
356 | stride=1,
357 | padding=1,
358 | bias=False,
359 | conv_type="conv",
360 | mask=False,
361 | ):
362 | if conv_type == "conv":
363 | conv = nn.Conv2d(
364 | inplanes,
365 | planes,
366 | kernel_size=kernel_size,
367 | stride=stride,
368 | padding=padding,
369 | bias=bias,
370 | )
371 | elif conv_type == "dcn":
372 | conv = DeformableConv2d(
373 | inplanes,
374 | planes,
375 | kernel_size=kernel_size,
376 | stride=stride,
377 | padding=_pair(padding),
378 | bias=bias,
379 | mask=mask,
380 | )
381 | else:
382 | raise TypeError
383 | return conv
384 |
385 |
386 | class ConvBlock(nn.Module):
387 | def __init__(
388 | self,
389 | in_channels,
390 | out_channels,
391 | gate: Optional[Callable[..., nn.Module]] = None,
392 | norm_layer: Optional[Callable[..., nn.Module]] = None,
393 | conv_type: str = "conv",
394 | mask: bool = False,
395 | ):
396 | super().__init__()
397 | if gate is None:
398 | self.gate = nn.ReLU(inplace=True)
399 | else:
400 | self.gate = gate
401 | if norm_layer is None:
402 | norm_layer = nn.BatchNorm2d
403 | self.conv1 = get_conv(
404 | in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
405 | )
406 | self.bn1 = norm_layer(out_channels)
407 | self.conv2 = get_conv(
408 | out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
409 | )
410 | self.bn2 = norm_layer(out_channels)
411 |
412 | def forward(self, x):
413 | x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
414 | x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
415 | return x
416 |
417 |
418 | # modified based on torchvision\models\resnet.py#27->BasicBlock
419 | class ResBlock(nn.Module):
420 | expansion: int = 1
421 |
422 | def __init__(
423 | self,
424 | inplanes: int,
425 | planes: int,
426 | stride: int = 1,
427 | downsample: Optional[nn.Module] = None,
428 | groups: int = 1,
429 | base_width: int = 64,
430 | dilation: int = 1,
431 | gate: Optional[Callable[..., nn.Module]] = None,
432 | norm_layer: Optional[Callable[..., nn.Module]] = None,
433 | conv_type: str = "conv",
434 | mask: bool = False,
435 | ) -> None:
436 | super(ResBlock, self).__init__()
437 | if gate is None:
438 | self.gate = nn.ReLU(inplace=True)
439 | else:
440 | self.gate = gate
441 | if norm_layer is None:
442 | norm_layer = nn.BatchNorm2d
443 | if groups != 1 or base_width != 64:
444 | raise ValueError("ResBlock only supports groups=1 and base_width=64")
445 | if dilation > 1:
446 | raise NotImplementedError("Dilation > 1 not supported in ResBlock")
447 | # Both self.conv1 and self.downsample layers
448 | # downsample the input when stride != 1
449 | self.conv1 = get_conv(
450 | inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
451 | )
452 | self.bn1 = norm_layer(planes)
453 | self.conv2 = get_conv(
454 | planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
455 | )
456 | self.bn2 = norm_layer(planes)
457 | self.downsample = downsample
458 | self.stride = stride
459 |
460 | def forward(self, x: torch.Tensor) -> torch.Tensor:
461 | identity = x
462 |
463 | out = self.conv1(x)
464 | out = self.bn1(out)
465 | out = self.gate(out)
466 |
467 | out = self.conv2(out)
468 | out = self.bn2(out)
469 |
470 | if self.downsample is not None:
471 | identity = self.downsample(x)
472 |
473 | out += identity
474 | out = self.gate(out)
475 |
476 | return out
477 |
478 |
479 | class SDDH(nn.Module):
480 | def __init__(
481 | self,
482 | dims: int,
483 | kernel_size: int = 3,
484 | n_pos: int = 8,
485 | gate=nn.ReLU(),
486 | conv2D=False,
487 | mask=False,
488 | ):
489 | super(SDDH, self).__init__()
490 | self.kernel_size = kernel_size
491 | self.n_pos = n_pos
492 | self.conv2D = conv2D
493 | self.mask = mask
494 |
495 | self.get_patches_func = get_patches
496 |
497 | # estimate offsets
498 | self.channel_num = 3 * n_pos if mask else 2 * n_pos
499 | self.offset_conv = nn.Sequential(
500 | nn.Conv2d(
501 | dims,
502 | self.channel_num,
503 | kernel_size=kernel_size,
504 | stride=1,
505 | padding=0,
506 | bias=True,
507 | ),
508 | gate,
509 | nn.Conv2d(
510 | self.channel_num,
511 | self.channel_num,
512 | kernel_size=1,
513 | stride=1,
514 | padding=0,
515 | bias=True,
516 | ),
517 | )
518 |
519 | # sampled feature conv
520 | self.sf_conv = nn.Conv2d(
521 | dims, dims, kernel_size=1, stride=1, padding=0, bias=False
522 | )
523 |
524 | # convM
525 | if not conv2D:
526 | # deformable desc weights
527 | agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
528 | self.register_parameter("agg_weights", agg_weights)
529 | else:
530 | self.convM = nn.Conv2d(
531 | dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
532 | )
533 |
534 | def forward(self, x, keypoints):
535 | # x: [B,C,H,W]
536 | # keypoints: list, [[N_kpts,2], ...] (w,h)
537 | b, c, h, w = x.shape
538 | wh = torch.tensor([[w - 1, h - 1]], device=x.device)
539 | max_offset = max(h, w) / 4.0
540 |
541 | offsets = []
542 | descriptors = []
543 | # get offsets for each keypoint
544 | for ib in range(b):
545 | xi, kptsi = x[ib], keypoints[ib]
546 | kptsi_wh = (kptsi / 2 + 0.5) * wh
547 | N_kpts = len(kptsi)
548 |
549 | if self.kernel_size > 1:
550 | patch = self.get_patches_func(
551 | xi, kptsi_wh.long(), self.kernel_size
552 | ) # [N_kpts, C, K, K]
553 | else:
554 | kptsi_wh_long = kptsi_wh.long()
555 | patch = (
556 | xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
557 | .permute(1, 0)
558 | .reshape(N_kpts, c, 1, 1)
559 | )
560 |
561 | offset = self.offset_conv(patch).clamp(
562 | -max_offset, max_offset
563 | ) # [N_kpts, 2*n_pos, 1, 1]
564 | if self.mask:
565 | offset = (
566 | offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
567 | ) # [N_kpts, n_pos, 3]
568 | offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
569 | mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
570 | else:
571 | offset = (
572 | offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
573 | ) # [N_kpts, n_pos, 2]
574 | offsets.append(offset) # for visualization
575 |
576 | # get sample positions
577 | pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
578 | pos = 2.0 * pos / wh[None] - 1
579 | pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
580 |
581 | # sample features
582 | features = F.grid_sample(
583 | xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
584 | ) # [1,C,(N_kpts*n_pos),1]
585 | features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
586 | 1, 0, 2, 3
587 | ) # [N_kpts, C, n_pos, 1]
588 | if self.mask:
589 | features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
590 |
591 | features = torch.selu_(self.sf_conv(features)).squeeze(
592 | -1
593 | ) # [N_kpts, C, n_pos]
594 | # convM
595 | if not self.conv2D:
596 | descs = torch.einsum(
597 | "ncp,pcd->nd", features, self.agg_weights
598 | ) # [N_kpts, C]
599 | else:
600 | features = features.reshape(N_kpts, -1)[
601 | :, :, None, None
602 | ] # [N_kpts, C*n_pos, 1, 1]
603 | descs = self.convM(features).squeeze() # [N_kpts, C]
604 |
605 | # normalize
606 | descs = F.normalize(descs, p=2.0, dim=1)
607 | descriptors.append(descs)
608 |
609 | return descriptors, offsets
610 |
611 |
612 | class ALIKED(Extractor):
613 | default_conf = {
614 | "model_name": "aliked-n16",
615 | "max_num_keypoints": -1,
616 | "detection_threshold": 0.2,
617 | "nms_radius": 2,
618 | }
619 |
620 | checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
621 |
622 | n_limit_max = 20000
623 |
624 | # c1, c2, c3, c4, dim, K, M
625 | cfgs = {
626 | "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
627 | "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
628 | "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
629 | "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
630 | }
631 | preprocess_conf = {
632 | "resize": 1024,
633 | }
634 |
635 | required_data_keys = ["image"]
636 |
637 | def __init__(self, **conf):
638 | super().__init__(**conf) # Update with default configuration.
639 | conf = self.conf
640 | c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
641 | conv_types = ["conv", "conv", "dcn", "dcn"]
642 | conv2D = False
643 | mask = False
644 |
645 | # build model
646 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
647 | self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
648 | self.norm = nn.BatchNorm2d
649 | self.gate = nn.SELU(inplace=True)
650 | self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
651 | self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
652 | self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
653 | self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
654 |
655 | self.conv1 = resnet.conv1x1(c1, dim // 4)
656 | self.conv2 = resnet.conv1x1(c2, dim // 4)
657 | self.conv3 = resnet.conv1x1(c3, dim // 4)
658 | self.conv4 = resnet.conv1x1(dim, dim // 4)
659 | self.upsample2 = nn.Upsample(
660 | scale_factor=2, mode="bilinear", align_corners=True
661 | )
662 | self.upsample4 = nn.Upsample(
663 | scale_factor=4, mode="bilinear", align_corners=True
664 | )
665 | self.upsample8 = nn.Upsample(
666 | scale_factor=8, mode="bilinear", align_corners=True
667 | )
668 | self.upsample32 = nn.Upsample(
669 | scale_factor=32, mode="bilinear", align_corners=True
670 | )
671 | self.score_head = nn.Sequential(
672 | resnet.conv1x1(dim, 8),
673 | self.gate,
674 | resnet.conv3x3(8, 4),
675 | self.gate,
676 | resnet.conv3x3(4, 4),
677 | self.gate,
678 | resnet.conv3x3(4, 1),
679 | )
680 | self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
681 | self.dkd = DKD(
682 | radius=conf.nms_radius,
683 | top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
684 | scores_th=conf.detection_threshold,
685 | n_limit=conf.max_num_keypoints
686 | if conf.max_num_keypoints > 0
687 | else self.n_limit_max,
688 | )
689 |
690 | state_dict = torch.hub.load_state_dict_from_url(
691 | self.checkpoint_url.format(conf.model_name), map_location="cpu"
692 | )
693 | self.load_state_dict(state_dict, strict=True)
694 |
695 | def get_resblock(self, c_in, c_out, conv_type, mask):
696 | return ResBlock(
697 | c_in,
698 | c_out,
699 | 1,
700 | nn.Conv2d(c_in, c_out, 1),
701 | gate=self.gate,
702 | norm_layer=self.norm,
703 | conv_type=conv_type,
704 | mask=mask,
705 | )
706 |
707 | def extract_dense_map(self, image):
708 | # Pads images such that dimensions are divisible by div_by (2**5 = 32)
709 | div_by = 2**5
710 | padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
711 | image = padder.pad(image)
712 |
713 | # ================================== feature encoder
714 | x1 = self.block1(image) # B x c1 x H x W
715 | x2 = self.pool2(x1)
716 | x2 = self.block2(x2) # B x c2 x H/2 x W/2
717 | x3 = self.pool4(x2)
718 | x3 = self.block3(x3) # B x c3 x H/8 x W/8
719 | x4 = self.pool4(x3)
720 | x4 = self.block4(x4) # B x dim x H/32 x W/32
721 | # ================================== feature aggregation
722 | x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
723 | x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
724 | x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
725 | x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
726 | x2_up = self.upsample2(x2) # B x dim//4 x H x W
727 | x3_up = self.upsample8(x3) # B x dim//4 x H x W
728 | x4_up = self.upsample32(x4) # B x dim//4 x H x W
729 | x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
730 | # ================================== score head
731 | score_map = torch.sigmoid(self.score_head(x1234))
732 | feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
733 |
734 | # Unpads images
735 | feature_map = padder.unpad(feature_map)
736 | score_map = padder.unpad(score_map)
737 |
738 | return feature_map, score_map
739 |
740 | def forward(self, data: dict) -> dict:
741 | image = data["image"]
742 | if image.shape[1] == 1:
743 | image = grayscale_to_rgb(image)
744 | feature_map, score_map = self.extract_dense_map(image)
745 | keypoints, scoredispersitys, kptscores = self.dkd(
746 | score_map, image_size=data.get("image_size")
747 | )
748 | descriptors, offsets = self.desc_head(feature_map, keypoints)
749 |
750 | _, _, h, w = image.shape
751 | wh = torch.tensor([w - 1, h - 1], device=image.device)
752 | # no padding required
753 | # we can set detection_threshold=-1 and conf.max_num_keypoints > 0
754 | return {
755 | "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
756 | "descriptors": torch.stack(descriptors), # B x N x D
757 | "keypoint_scores": torch.stack(kptscores), # B x N
758 | }
759 |
--------------------------------------------------------------------------------
/lightglue/disk.py:
--------------------------------------------------------------------------------
1 | import kornia
2 | import torch
3 |
4 | from .utils import Extractor
5 |
6 |
7 | class DISK(Extractor):
8 | default_conf = {
9 | "weights": "depth",
10 | "max_num_keypoints": None,
11 | "desc_dim": 128,
12 | "nms_window_size": 5,
13 | "detection_threshold": 0.0,
14 | "pad_if_not_divisible": True,
15 | }
16 |
17 | preprocess_conf = {
18 | "resize": 1024,
19 | "grayscale": False,
20 | }
21 |
22 | required_data_keys = ["image"]
23 |
24 | def __init__(self, **conf) -> None:
25 | super().__init__(**conf) # Update with default configuration.
26 | self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
27 |
28 | def forward(self, data: dict) -> dict:
29 | """Compute keypoints, scores, descriptors for image"""
30 | for key in self.required_data_keys:
31 | assert key in data, f"Missing key {key} in data"
32 | image = data["image"]
33 | if image.shape[1] == 1:
34 | image = kornia.color.grayscale_to_rgb(image)
35 | features = self.model(
36 | image,
37 | n=self.conf.max_num_keypoints,
38 | window_size=self.conf.nms_window_size,
39 | score_threshold=self.conf.detection_threshold,
40 | pad_if_not_divisible=self.conf.pad_if_not_divisible,
41 | )
42 | keypoints = [f.keypoints for f in features]
43 | scores = [f.detection_scores for f in features]
44 | descriptors = [f.descriptors for f in features]
45 | del features
46 |
47 | keypoints = torch.stack(keypoints, 0)
48 | scores = torch.stack(scores, 0)
49 | descriptors = torch.stack(descriptors, 0)
50 |
51 | return {
52 | "keypoints": keypoints.to(image).contiguous(),
53 | "keypoint_scores": scores.to(image).contiguous(),
54 | "descriptors": descriptors.to(image).contiguous(),
55 | }
56 |
--------------------------------------------------------------------------------
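DISK above is a thin wrapper around `kornia.feature.DISK` that adapts the pretrained model to the shared `Extractor.extract` interface defined in lightglue/utils.py. A minimal extraction sketch, assuming the package `__init__` re-exports DISK as in the upstream LightGlue repository; the image path is a placeholder:

import torch

from lightglue import DISK
from lightglue.utils import load_image

device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = DISK(max_num_keypoints=2048).eval().to(device)

image = load_image("path/to/image.jpg").to(device)
feats = extractor.extract(image)  # resized to 1024 px on the long side by default
# feats holds "keypoints" (1, N, 2) in original-image pixels, "keypoint_scores" (1, N),
# "descriptors" (1, N, 128) and "image_size" (1, 2).
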
/lightglue/lightglue.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 | from types import SimpleNamespace
4 | from typing import Callable, List, Optional, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn
10 |
11 | try:
12 | from flash_attn.modules.mha import FlashCrossAttention
13 | except ModuleNotFoundError:
14 | FlashCrossAttention = None
15 |
16 | if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
17 | FLASH_AVAILABLE = True
18 | else:
19 | FLASH_AVAILABLE = False
20 |
21 | torch.backends.cudnn.deterministic = True
22 |
23 |
24 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
25 | def normalize_keypoints(
26 | kpts: torch.Tensor, size: Optional[torch.Tensor] = None
27 | ) -> torch.Tensor:
28 | if size is None:
29 | size = 1 + kpts.max(-2).values - kpts.min(-2).values
30 | elif not isinstance(size, torch.Tensor):
31 | size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
32 | size = size.to(kpts)
33 | shift = size / 2
34 | scale = size.max(-1).values / 2
35 | kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
36 | return kpts
37 |
38 |
39 | def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
40 | if length <= x.shape[-2]:
41 | return x, torch.ones_like(x[..., :1], dtype=torch.bool)
42 | pad = torch.ones(
43 | *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
44 | )
45 | y = torch.cat([x, pad], dim=-2)
46 | mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
47 | mask[..., : x.shape[-2], :] = True
48 | return y, mask
49 |
50 |
51 | def rotate_half(x: torch.Tensor) -> torch.Tensor:
52 | x = x.unflatten(-1, (-1, 2))
53 | x1, x2 = x.unbind(dim=-1)
54 | return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
55 |
56 |
57 | def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
58 | return (t * freqs[0]) + (rotate_half(t) * freqs[1])
59 |
60 |
61 | class LearnableFourierPositionalEncoding(nn.Module):
62 | def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
63 | super().__init__()
64 | F_dim = F_dim if F_dim is not None else dim
65 | self.gamma = gamma
66 | self.Wr = nn.Linear(M, F_dim // 2, bias=False)
67 | nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
68 |
69 | def forward(self, x: torch.Tensor) -> torch.Tensor:
70 | """encode position vector"""
71 | projected = self.Wr(x)
72 | cosines, sines = torch.cos(projected), torch.sin(projected)
73 | emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
74 | return emb.repeat_interleave(2, dim=-1)
75 |
76 |
77 | class TokenConfidence(nn.Module):
78 | def __init__(self, dim: int) -> None:
79 | super().__init__()
80 | self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
81 |
82 | def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
83 | """get confidence tokens"""
84 | return (
85 | self.token(desc0.detach()).squeeze(-1),
86 | self.token(desc1.detach()).squeeze(-1),
87 | )
88 |
89 |
90 | class Attention(nn.Module):
91 | def __init__(self, allow_flash: bool) -> None:
92 | super().__init__()
93 | if allow_flash and not FLASH_AVAILABLE:
94 | warnings.warn(
95 | "FlashAttention is not available. For optimal speed, "
96 | "consider installing torch >= 2.0 or flash-attn.",
97 | stacklevel=2,
98 | )
99 | self.enable_flash = allow_flash and FLASH_AVAILABLE
100 | self.has_sdp = hasattr(F, "scaled_dot_product_attention")
101 | if allow_flash and FlashCrossAttention:
102 | self.flash_ = FlashCrossAttention()
103 | if self.has_sdp:
104 | torch.backends.cuda.enable_flash_sdp(allow_flash)
105 |
106 | def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
107 | if self.enable_flash and q.device.type == "cuda":
108 | # use torch 2.0 scaled_dot_product_attention with flash
109 | if self.has_sdp:
110 | args = [x.half().contiguous() for x in [q, k, v]]
111 | v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
112 | return v if mask is None else v.nan_to_num()
113 | else:
114 | assert mask is None
115 | q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
116 | m = self.flash_(q.half(), torch.stack([k, v], 2).half())
117 | return m.transpose(-2, -3).to(q.dtype).clone()
118 | elif self.has_sdp:
119 | args = [x.contiguous() for x in [q, k, v]]
120 | v = F.scaled_dot_product_attention(*args, attn_mask=mask)
121 | return v if mask is None else v.nan_to_num()
122 | else:
123 | s = q.shape[-1] ** -0.5
124 | sim = torch.einsum("...id,...jd->...ij", q, k) * s
125 | if mask is not None:
126 |                 sim = sim.masked_fill(~mask, -float("inf"))
127 | attn = F.softmax(sim, -1)
128 | return torch.einsum("...ij,...jd->...id", attn, v)
129 |
130 |
131 | class SelfBlock(nn.Module):
132 | def __init__(
133 | self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
134 | ) -> None:
135 | super().__init__()
136 | self.embed_dim = embed_dim
137 | self.num_heads = num_heads
138 | assert self.embed_dim % num_heads == 0
139 | self.head_dim = self.embed_dim // num_heads
140 | self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
141 | self.inner_attn = Attention(flash)
142 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
143 | self.ffn = nn.Sequential(
144 | nn.Linear(2 * embed_dim, 2 * embed_dim),
145 | nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
146 | nn.GELU(),
147 | nn.Linear(2 * embed_dim, embed_dim),
148 | )
149 |
150 | def forward(
151 | self,
152 | x: torch.Tensor,
153 | encoding: torch.Tensor,
154 | mask: Optional[torch.Tensor] = None,
155 | ) -> torch.Tensor:
156 | qkv = self.Wqkv(x)
157 | qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
158 | q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
159 | q = apply_cached_rotary_emb(encoding, q)
160 | k = apply_cached_rotary_emb(encoding, k)
161 | context = self.inner_attn(q, k, v, mask=mask)
162 | message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
163 | return x + self.ffn(torch.cat([x, message], -1))
164 |
165 |
166 | class CrossBlock(nn.Module):
167 | def __init__(
168 | self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
169 | ) -> None:
170 | super().__init__()
171 | self.heads = num_heads
172 | dim_head = embed_dim // num_heads
173 | self.scale = dim_head**-0.5
174 | inner_dim = dim_head * num_heads
175 | self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
176 | self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
177 | self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
178 | self.ffn = nn.Sequential(
179 | nn.Linear(2 * embed_dim, 2 * embed_dim),
180 | nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
181 | nn.GELU(),
182 | nn.Linear(2 * embed_dim, embed_dim),
183 | )
184 | if flash and FLASH_AVAILABLE:
185 | self.flash = Attention(True)
186 | else:
187 | self.flash = None
188 |
189 | def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
190 | return func(x0), func(x1)
191 |
192 | def forward(
193 | self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
194 | ) -> List[torch.Tensor]:
195 | qk0, qk1 = self.map_(self.to_qk, x0, x1)
196 | v0, v1 = self.map_(self.to_v, x0, x1)
197 | qk0, qk1, v0, v1 = map(
198 | lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
199 | (qk0, qk1, v0, v1),
200 | )
201 | if self.flash is not None and qk0.device.type == "cuda":
202 | m0 = self.flash(qk0, qk1, v1, mask)
203 | m1 = self.flash(
204 | qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
205 | )
206 | else:
207 | qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
208 | sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
209 | if mask is not None:
210 | sim = sim.masked_fill(~mask, -float("inf"))
211 | attn01 = F.softmax(sim, dim=-1)
212 | attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
213 | m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
214 | m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
215 | if mask is not None:
216 | m0, m1 = m0.nan_to_num(), m1.nan_to_num()
217 | m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
218 | m0, m1 = self.map_(self.to_out, m0, m1)
219 | x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
220 | x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
221 | return x0, x1
222 |
223 |
224 | class TransformerLayer(nn.Module):
225 | def __init__(self, *args, **kwargs):
226 | super().__init__()
227 | self.self_attn = SelfBlock(*args, **kwargs)
228 | self.cross_attn = CrossBlock(*args, **kwargs)
229 |
230 | def forward(
231 | self,
232 | desc0,
233 | desc1,
234 | encoding0,
235 | encoding1,
236 | mask0: Optional[torch.Tensor] = None,
237 | mask1: Optional[torch.Tensor] = None,
238 | ):
239 | if mask0 is not None and mask1 is not None:
240 | return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
241 | else:
242 | desc0 = self.self_attn(desc0, encoding0)
243 | desc1 = self.self_attn(desc1, encoding1)
244 | return self.cross_attn(desc0, desc1)
245 |
246 | # This part is compiled and allows padding inputs
247 | def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
248 | mask = mask0 & mask1.transpose(-1, -2)
249 | mask0 = mask0 & mask0.transpose(-1, -2)
250 | mask1 = mask1 & mask1.transpose(-1, -2)
251 | desc0 = self.self_attn(desc0, encoding0, mask0)
252 | desc1 = self.self_attn(desc1, encoding1, mask1)
253 | return self.cross_attn(desc0, desc1, mask)
254 |
255 |
256 | def sigmoid_log_double_softmax(
257 | sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
258 | ) -> torch.Tensor:
259 | """create the log assignment matrix from logits and similarity"""
260 | b, m, n = sim.shape
261 | certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
262 | scores0 = F.log_softmax(sim, 2)
263 | scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
264 | scores = sim.new_full((b, m + 1, n + 1), 0)
265 | scores[:, :m, :n] = scores0 + scores1 + certainties
266 | scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
267 | scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
268 | return scores
269 |
270 |
271 | class MatchAssignment(nn.Module):
272 | def __init__(self, dim: int) -> None:
273 | super().__init__()
274 | self.dim = dim
275 | self.matchability = nn.Linear(dim, 1, bias=True)
276 | self.final_proj = nn.Linear(dim, dim, bias=True)
277 |
278 | def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
279 | """build assignment matrix from descriptors"""
280 | mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
281 | _, _, d = mdesc0.shape
282 | mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
283 | sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
284 | z0 = self.matchability(desc0)
285 | z1 = self.matchability(desc1)
286 | scores = sigmoid_log_double_softmax(sim, z0, z1)
287 | return scores, sim
288 |
289 | def get_matchability(self, desc: torch.Tensor):
290 | return torch.sigmoid(self.matchability(desc)).squeeze(-1)
291 |
292 |
293 | def filter_matches(scores: torch.Tensor, th: float):
294 | """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
295 | max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
296 | m0, m1 = max0.indices, max1.indices
297 | indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
298 | indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
299 | mutual0 = indices0 == m1.gather(1, m0)
300 | mutual1 = indices1 == m0.gather(1, m1)
301 | max0_exp = max0.values.exp()
302 | zero = max0_exp.new_tensor(0)
303 | mscores0 = torch.where(mutual0, max0_exp, zero)
304 | mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
305 | valid0 = mutual0 & (mscores0 > th)
306 | valid1 = mutual1 & valid0.gather(1, m1)
307 | m0 = torch.where(valid0, m0, -1)
308 | m1 = torch.where(valid1, m1, -1)
309 | return m0, m1, mscores0, mscores1
310 |
311 |
312 | class LightGlue(nn.Module):
313 | default_conf = {
314 | "name": "lightglue", # just for interfacing
315 | "input_dim": 256, # input descriptor dimension (autoselected from weights)
316 | "descriptor_dim": 256,
317 | "add_scale_ori": False,
318 | "n_layers": 9,
319 | "num_heads": 4,
320 | "flash": True, # enable FlashAttention if available.
321 | "mp": False, # enable mixed precision
322 | "depth_confidence": 0.95, # early stopping, disable with -1
323 | "width_confidence": 0.99, # point pruning, disable with -1
324 | "filter_threshold": 0.1, # match threshold
325 | "weights": None,
326 | }
327 |
328 | # Point pruning involves an overhead (gather).
329 | # Therefore, we only activate it if there are enough keypoints.
330 | pruning_keypoint_thresholds = {
331 | "cpu": -1,
332 | "mps": -1,
333 | "cuda": 1024,
334 | "flash": 1536,
335 | }
336 |
337 | required_data_keys = ["image0", "image1"]
338 |
339 | version = "v0.1_arxiv"
340 | url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
341 |
342 | features = {
343 | "superpoint": {
344 | "weights": "superpoint_lightglue",
345 | "input_dim": 256,
346 | },
347 | "disk": {
348 | "weights": "disk_lightglue",
349 | "input_dim": 128,
350 | },
351 | "aliked": {
352 | "weights": "aliked_lightglue",
353 | "input_dim": 128,
354 | },
355 | "sift": {
356 | "weights": "sift_lightglue",
357 | "input_dim": 128,
358 | "add_scale_ori": True,
359 | },
360 | }
361 |
362 | def __init__(self, features="superpoint", **conf) -> None:
363 | super().__init__()
364 | self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
365 | if features is not None:
366 | if features not in self.features:
367 | raise ValueError(
368 | f"Unsupported features: {features} not in "
369 | f"{{{','.join(self.features)}}}"
370 | )
371 | for k, v in self.features[features].items():
372 | setattr(conf, k, v)
373 |
374 | if conf.input_dim != conf.descriptor_dim:
375 | self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
376 | else:
377 | self.input_proj = nn.Identity()
378 |
379 | head_dim = conf.descriptor_dim // conf.num_heads
380 | self.posenc = LearnableFourierPositionalEncoding(
381 | 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
382 | )
383 |
384 | h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
385 |
386 | self.transformers = nn.ModuleList(
387 | [TransformerLayer(d, h, conf.flash) for _ in range(n)]
388 | )
389 |
390 | self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
391 | self.token_confidence = nn.ModuleList(
392 | [TokenConfidence(d) for _ in range(n - 1)]
393 | )
394 | self.register_buffer(
395 | "confidence_thresholds",
396 | torch.Tensor(
397 | [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
398 | ),
399 | )
400 |
401 | state_dict = None
402 | if features is not None:
403 | fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
404 | state_dict = torch.hub.load_state_dict_from_url(
405 | self.url.format(self.version, features), file_name=fname
406 | )
407 | self.load_state_dict(state_dict, strict=False)
408 | elif conf.weights is not None:
409 | path = Path(__file__).parent
410 | path = path / "weights/{}.pth".format(self.conf.weights)
411 | state_dict = torch.load(str(path), map_location="cpu")
412 |
413 | if state_dict:
414 | # rename old state dict entries
415 | for i in range(self.conf.n_layers):
416 | pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
417 | state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
418 | pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
419 | state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
420 | self.load_state_dict(state_dict, strict=False)
421 |
422 | # static lengths LightGlue is compiled for (only used with torch.compile)
423 | self.static_lengths = None
424 |
425 | def compile(
426 | self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
427 | ):
428 | if self.conf.width_confidence != -1:
429 | warnings.warn(
430 | "Point pruning is partially disabled for compiled forward.",
431 | stacklevel=2,
432 | )
433 |
434 | for i in range(self.conf.n_layers):
435 | self.transformers[i].masked_forward = torch.compile(
436 | self.transformers[i].masked_forward, mode=mode, fullgraph=True
437 | )
438 |
439 | self.static_lengths = static_lengths
440 |
441 | def forward(self, data: dict) -> dict:
442 | """
443 | Match keypoints and descriptors between two images
444 |
445 | Input (dict):
446 | image0: dict
447 | keypoints: [B x M x 2]
448 | descriptors: [B x M x D]
449 | image: [B x C x H x W] or image_size: [B x 2]
450 | image1: dict
451 | keypoints: [B x N x 2]
452 | descriptors: [B x N x D]
453 | image: [B x C x H x W] or image_size: [B x 2]
454 | Output (dict):
455 | log_assignment: [B x M+1 x N+1]
456 | matches0: [B x M]
457 | matching_scores0: [B x M]
458 | matches1: [B x N]
459 | matching_scores1: [B x N]
460 | matches: List[[Si x 2]], scores: List[[Si]]
461 | """
462 | with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
463 | return self._forward(data)
464 |
465 | def _forward(self, data: dict) -> dict:
466 | for key in self.required_data_keys:
467 | assert key in data, f"Missing key {key} in data"
468 | data0, data1 = data["image0"], data["image1"]
469 | kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
470 | b, m, _ = kpts0.shape
471 | b, n, _ = kpts1.shape
472 | device = kpts0.device
473 | size0, size1 = data0.get("image_size"), data1.get("image_size")
474 | kpts0 = normalize_keypoints(kpts0, size0).clone()
475 | kpts1 = normalize_keypoints(kpts1, size1).clone()
476 |
477 | if self.conf.add_scale_ori:
478 | kpts0 = torch.cat(
479 | [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
480 | )
481 | kpts1 = torch.cat(
482 | [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
483 | )
484 | desc0 = data0["descriptors"].detach().contiguous()
485 | desc1 = data1["descriptors"].detach().contiguous()
486 |
487 | assert desc0.shape[-1] == self.conf.input_dim
488 | assert desc1.shape[-1] == self.conf.input_dim
489 |
490 | if torch.is_autocast_enabled():
491 | desc0 = desc0.half()
492 | desc1 = desc1.half()
493 |
494 | mask0, mask1 = None, None
495 | c = max(m, n)
496 | do_compile = self.static_lengths and c <= max(self.static_lengths)
497 | if do_compile:
498 | kn = min([k for k in self.static_lengths if k >= c])
499 | desc0, mask0 = pad_to_length(desc0, kn)
500 | desc1, mask1 = pad_to_length(desc1, kn)
501 | kpts0, _ = pad_to_length(kpts0, kn)
502 | kpts1, _ = pad_to_length(kpts1, kn)
503 | desc0 = self.input_proj(desc0)
504 | desc1 = self.input_proj(desc1)
505 | # cache positional embeddings
506 | encoding0 = self.posenc(kpts0)
507 | encoding1 = self.posenc(kpts1)
508 |
509 | # GNN + final_proj + assignment
510 | do_early_stop = self.conf.depth_confidence > 0
511 | do_point_pruning = self.conf.width_confidence > 0 and not do_compile
512 | pruning_th = self.pruning_min_kpts(device)
513 | if do_point_pruning:
514 | ind0 = torch.arange(0, m, device=device)[None]
515 | ind1 = torch.arange(0, n, device=device)[None]
516 | # We store the index of the layer at which pruning is detected.
517 | prune0 = torch.ones_like(ind0)
518 | prune1 = torch.ones_like(ind1)
519 | token0, token1 = None, None
520 | for i in range(self.conf.n_layers):
521 | desc0, desc1 = self.transformers[i](
522 | desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
523 | )
524 | if i == self.conf.n_layers - 1:
525 | continue # no early stopping or adaptive width at last layer
526 |
527 | if do_early_stop:
528 | token0, token1 = self.token_confidence[i](desc0, desc1)
529 | if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
530 | break
531 | if do_point_pruning and desc0.shape[-2] > pruning_th:
532 | scores0 = self.log_assignment[i].get_matchability(desc0)
533 | prunemask0 = self.get_pruning_mask(token0, scores0, i)
534 | keep0 = torch.where(prunemask0)[1]
535 | ind0 = ind0.index_select(1, keep0)
536 | desc0 = desc0.index_select(1, keep0)
537 | encoding0 = encoding0.index_select(-2, keep0)
538 | prune0[:, ind0] += 1
539 | if do_point_pruning and desc1.shape[-2] > pruning_th:
540 | scores1 = self.log_assignment[i].get_matchability(desc1)
541 | prunemask1 = self.get_pruning_mask(token1, scores1, i)
542 | keep1 = torch.where(prunemask1)[1]
543 | ind1 = ind1.index_select(1, keep1)
544 | desc1 = desc1.index_select(1, keep1)
545 | encoding1 = encoding1.index_select(-2, keep1)
546 | prune1[:, ind1] += 1
547 |
548 | desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
549 | scores, _ = self.log_assignment[i](desc0, desc1)
550 | m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
551 | matches, mscores = [], []
552 | for k in range(b):
553 | valid = m0[k] > -1
554 | m_indices_0 = torch.where(valid)[0]
555 | m_indices_1 = m0[k][valid]
556 | if do_point_pruning:
557 | m_indices_0 = ind0[k, m_indices_0]
558 | m_indices_1 = ind1[k, m_indices_1]
559 | matches.append(torch.stack([m_indices_0, m_indices_1], -1))
560 | mscores.append(mscores0[k][valid])
561 |
562 | # TODO: Remove when hloc switches to the compact format.
563 | if do_point_pruning:
564 | m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
565 | m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
566 | m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
567 | m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
568 | mscores0_ = torch.zeros((b, m), device=mscores0.device)
569 | mscores1_ = torch.zeros((b, n), device=mscores1.device)
570 | mscores0_[:, ind0] = mscores0
571 | mscores1_[:, ind1] = mscores1
572 | m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
573 | else:
574 | prune0 = torch.ones_like(mscores0) * self.conf.n_layers
575 | prune1 = torch.ones_like(mscores1) * self.conf.n_layers
576 |
577 | pred = {
578 | "matches0": m0,
579 | "matches1": m1,
580 | "matching_scores0": mscores0,
581 | "matching_scores1": mscores1,
582 | "stop": i + 1,
583 | "matches": matches,
584 | "scores": mscores,
585 | "prune0": prune0,
586 | "prune1": prune1,
587 | }
588 |
589 | return pred
590 |
591 | def confidence_threshold(self, layer_index: int) -> float:
592 | """scaled confidence threshold"""
593 | threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
594 | return np.clip(threshold, 0, 1)
595 |
596 | def get_pruning_mask(
597 | self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
598 | ) -> torch.Tensor:
599 | """mask points which should be removed"""
600 | keep = scores > (1 - self.conf.width_confidence)
601 | if confidences is not None: # Low-confidence points are never pruned.
602 | keep |= confidences <= self.confidence_thresholds[layer_index]
603 | return keep
604 |
605 | def check_if_stop(
606 | self,
607 | confidences0: torch.Tensor,
608 | confidences1: torch.Tensor,
609 | layer_index: int,
610 | num_points: int,
611 | ) -> torch.Tensor:
612 | """evaluate stopping condition"""
613 | confidences = torch.cat([confidences0, confidences1], -1)
614 | threshold = self.confidence_thresholds[layer_index]
615 | ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
616 | return ratio_confident > self.conf.depth_confidence
617 |
618 | def pruning_min_kpts(self, device: torch.device):
619 | if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
620 | return self.pruning_keypoint_thresholds["flash"]
621 | else:
622 | return self.pruning_keypoint_thresholds[device.type]
623 |
--------------------------------------------------------------------------------
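The `forward` docstring above fixes the matcher's dict interface: each of `image0`/`image1` carries `keypoints [B x M x 2]`, `descriptors [B x M x D]` and `image` or `image_size`, and the prediction exposes both a compact and a dense match format. A sketch of both read-outs, with placeholder image paths and the early-stopping/pruning options from `default_conf` disabled explicitly:

import torch

from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image, rbd

device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = SuperPoint(max_num_keypoints=1024).eval().to(device)
matcher = LightGlue(
    features="superpoint", depth_confidence=-1, width_confidence=-1
).eval().to(device)

image0 = load_image("path/to/image0.jpg").to(device)  # placeholder paths
image1 = load_image("path/to/image1.jpg").to(device)
feats0, feats1 = extractor.extract(image0), extractor.extract(image1)
pred = matcher({"image0": feats0, "image1": feats1})

# Compact format: per batch item, a (K, 2) tensor of index pairs and a (K,) score tensor.
matches, scores = pred["matches"][0], pred["scores"][0]

# Dense format: one entry per keypoint of image0, -1 where it is unmatched.
valid = pred["matches0"][0] > -1
points0 = rbd(feats0)["keypoints"][valid]
points1 = rbd(feats1)["keypoints"][pred["matches0"][0][valid]]
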
/lightglue/sift.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from kornia.color import rgb_to_grayscale
7 | from packaging import version
8 |
9 | try:
10 | import pycolmap
11 | except ImportError:
12 | pycolmap = None
13 |
14 | from .utils import Extractor
15 |
16 |
17 | def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
18 | h, w = image_shape
19 | ij = np.round(points - 0.5).astype(int).T[::-1]
20 |
21 | # Remove duplicate points (identical coordinates).
22 | # Pick highest scale or score
23 | s = scales if scores is None else scores
24 | buffer = np.zeros((h, w))
25 | np.maximum.at(buffer, tuple(ij), s)
26 | keep = np.where(buffer[tuple(ij)] == s)[0]
27 |
28 | # Pick lowest angle (arbitrary).
29 | ij = ij[:, keep]
30 | buffer[:] = np.inf
31 | o_abs = np.abs(angles[keep])
32 | np.minimum.at(buffer, tuple(ij), o_abs)
33 | mask = buffer[tuple(ij)] == o_abs
34 | ij = ij[:, mask]
35 | keep = keep[mask]
36 |
37 | if nms_radius > 0:
38 | # Apply NMS on the remaining points
39 | buffer[:] = 0
40 | buffer[tuple(ij)] = s[keep] # scores or scale
41 |
42 | local_max = torch.nn.functional.max_pool2d(
43 | torch.from_numpy(buffer).unsqueeze(0),
44 | kernel_size=nms_radius * 2 + 1,
45 | stride=1,
46 | padding=nms_radius,
47 | ).squeeze(0)
48 | is_local_max = buffer == local_max.numpy()
49 | keep = keep[is_local_max[tuple(ij)]]
50 | return keep
51 |
52 |
53 | def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
54 | x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
55 | x.clip_(min=eps).sqrt_()
56 | return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
57 |
58 |
59 | def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
60 | """
61 | Detect keypoints using OpenCV Detector.
62 | Optionally, perform description.
63 | Args:
64 | features: OpenCV based keypoints detector and descriptor
65 | image: Grayscale image of uint8 data type
66 | Returns:
67 | keypoints: 1D array of detected cv2.KeyPoint
68 | scores: 1D array of responses
69 | descriptors: 1D array of descriptors
70 | """
71 | detections, descriptors = features.detectAndCompute(image, None)
72 | points = np.array([k.pt for k in detections], dtype=np.float32)
73 | scores = np.array([k.response for k in detections], dtype=np.float32)
74 | scales = np.array([k.size for k in detections], dtype=np.float32)
75 | angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
76 | return points, scores, scales, angles, descriptors
77 |
78 |
79 | class SIFT(Extractor):
80 | default_conf = {
81 | "rootsift": True,
82 | "nms_radius": 0, # None to disable filtering entirely.
83 | "max_num_keypoints": 4096,
84 | "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
85 | "detection_threshold": 0.0066667, # from COLMAP
86 | "edge_threshold": 10,
87 | "first_octave": -1, # only used by pycolmap, the default of COLMAP
88 | "num_octaves": 4,
89 | }
90 |
91 | preprocess_conf = {
92 | "resize": 1024,
93 | }
94 |
95 | required_data_keys = ["image"]
96 |
97 | def __init__(self, **conf):
98 | super().__init__(**conf) # Update with default configuration.
99 | backend = self.conf.backend
100 | if backend.startswith("pycolmap"):
101 | if pycolmap is None:
102 | raise ImportError(
103 |                     "Cannot find module pycolmap: install it with pip "
104 | "or use backend=opencv."
105 | )
106 | options = {
107 | "peak_threshold": self.conf.detection_threshold,
108 | "edge_threshold": self.conf.edge_threshold,
109 | "first_octave": self.conf.first_octave,
110 | "num_octaves": self.conf.num_octaves,
111 | "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
112 | }
113 | device = (
114 | "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
115 | )
116 | if (
117 | backend == "pycolmap_cpu" or not pycolmap.has_cuda
118 | ) and pycolmap.__version__ < "0.5.0":
119 | warnings.warn(
120 | "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
121 | "consider upgrading pycolmap or use the CUDA version.",
122 | stacklevel=1,
123 | )
124 | else:
125 | options["max_num_features"] = self.conf.max_num_keypoints
126 | self.sift = pycolmap.Sift(options=options, device=device)
127 | elif backend == "opencv":
128 | self.sift = cv2.SIFT_create(
129 | contrastThreshold=self.conf.detection_threshold,
130 | nfeatures=self.conf.max_num_keypoints,
131 | edgeThreshold=self.conf.edge_threshold,
132 | nOctaveLayers=self.conf.num_octaves,
133 | )
134 | else:
135 | backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
136 | raise ValueError(
137 | f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
138 | )
139 |
140 | def extract_single_image(self, image: torch.Tensor):
141 | image_np = image.cpu().numpy().squeeze(0)
142 |
143 | if self.conf.backend.startswith("pycolmap"):
144 | if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
145 | detections, descriptors = self.sift.extract(image_np)
146 | scores = None # Scores are not exposed by COLMAP anymore.
147 | else:
148 | detections, scores, descriptors = self.sift.extract(image_np)
149 | keypoints = detections[:, :2] # Keep only (x, y).
150 | scales, angles = detections[:, -2:].T
151 | if scores is not None and (
152 | self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
153 | ):
154 | # Set the scores as a combination of abs. response and scale.
155 | scores = np.abs(scores) * scales
156 | elif self.conf.backend == "opencv":
157 | # TODO: Check if opencv keypoints are already in corner convention
158 | keypoints, scores, scales, angles, descriptors = run_opencv_sift(
159 | self.sift, (image_np * 255.0).astype(np.uint8)
160 | )
161 | pred = {
162 | "keypoints": keypoints,
163 | "scales": scales,
164 | "oris": angles,
165 | "descriptors": descriptors,
166 | }
167 | if scores is not None:
168 | pred["keypoint_scores"] = scores
169 |
170 | # sometimes pycolmap returns points outside the image. We remove them
171 | if self.conf.backend.startswith("pycolmap"):
172 | is_inside = (
173 | pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
174 | ).all(-1)
175 | pred = {k: v[is_inside] for k, v in pred.items()}
176 |
177 | if self.conf.nms_radius is not None:
178 | keep = filter_dog_point(
179 | pred["keypoints"],
180 | pred["scales"],
181 | pred["oris"],
182 | image_np.shape,
183 | self.conf.nms_radius,
184 | scores=pred.get("keypoint_scores"),
185 | )
186 | pred = {k: v[keep] for k, v in pred.items()}
187 |
188 | pred = {k: torch.from_numpy(v) for k, v in pred.items()}
189 | if scores is not None:
190 | # Keep the k keypoints with highest score
191 | num_points = self.conf.max_num_keypoints
192 | if num_points is not None and len(pred["keypoints"]) > num_points:
193 | indices = torch.topk(pred["keypoint_scores"], num_points).indices
194 | pred = {k: v[indices] for k, v in pred.items()}
195 |
196 | return pred
197 |
198 | def forward(self, data: dict) -> dict:
199 | image = data["image"]
200 | if image.shape[1] == 3:
201 | image = rgb_to_grayscale(image)
202 | device = image.device
203 | image = image.cpu()
204 | pred = []
205 | for k in range(len(image)):
206 | img = image[k]
207 | if "image_size" in data.keys():
208 | # avoid extracting points in padded areas
209 | w, h = data["image_size"][k]
210 | img = img[:, :h, :w]
211 | p = self.extract_single_image(img)
212 | pred.append(p)
213 | pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
214 | if self.conf.rootsift:
215 | pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
216 | return pred
217 |
--------------------------------------------------------------------------------
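`sift_to_rootsift` above applies the standard RootSIFT transform: L1-normalize each descriptor, take the element-wise square root, then L2-normalize, so that Euclidean distances between RootSIFT vectors correspond to the Hellinger kernel on the original SIFT histograms. An equivalent standalone sketch in plain torch, assuming non-negative descriptor entries as produced by SIFT:

import torch

def rootsift(desc: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # L1-normalize each descriptor, then take the square root.
    desc = desc / desc.sum(dim=-1, keepdim=True).clamp(min=eps)
    desc = desc.clamp(min=eps).sqrt()
    # Final L2 normalization, as in sift_to_rootsift above.
    return desc / desc.norm(dim=-1, keepdim=True).clamp(min=eps)
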
/lightglue/superpoint.py:
--------------------------------------------------------------------------------
1 | # %BANNER_BEGIN%
2 | # ---------------------------------------------------------------------
3 | # %COPYRIGHT_BEGIN%
4 | #
5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6 | #
7 | # Unpublished Copyright (c) 2020
8 | # Magic Leap, Inc., All Rights Reserved.
9 | #
10 | # NOTICE: All information contained herein is, and remains the property
11 | # of COMPANY. The intellectual and technical concepts contained herein
12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign
13 | # Patents, patents in process, and are protected by trade secret or
14 | # copyright law. Dissemination of this information or reproduction of
15 | # this material is strictly forbidden unless prior written permission is
16 | # obtained from COMPANY. Access to the source code contained herein is
17 | # hereby forbidden to anyone except current COMPANY employees, managers
18 | # or contractors who have executed Confidentiality and Non-disclosure
19 | # agreements explicitly covering such access.
20 | #
21 | # The copyright notice above does not evidence any actual or intended
22 | # publication or disclosure of this source code, which includes
23 | # information that is confidential and/or proprietary, and is a trade
24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32 | #
33 | # %COPYRIGHT_END%
34 | # ----------------------------------------------------------------------
35 | # %AUTHORS_BEGIN%
36 | #
37 | # Originating Authors: Paul-Edouard Sarlin
38 | #
39 | # %AUTHORS_END%
40 | # --------------------------------------------------------------------*/
41 | # %BANNER_END%
42 |
43 | # Adapted by Remi Pautrat, Philipp Lindenberger
44 |
45 | import torch
46 | from kornia.color import rgb_to_grayscale
47 | from torch import nn
48 |
49 | from .utils import Extractor
50 |
51 |
52 | def simple_nms(scores, nms_radius: int):
53 | """Fast Non-maximum suppression to remove nearby points"""
54 | assert nms_radius >= 0
55 |
56 | def max_pool(x):
57 | return torch.nn.functional.max_pool2d(
58 | x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
59 | )
60 |
61 | zeros = torch.zeros_like(scores)
62 | max_mask = scores == max_pool(scores)
63 | for _ in range(2):
64 | supp_mask = max_pool(max_mask.float()) > 0
65 | supp_scores = torch.where(supp_mask, zeros, scores)
66 | new_max_mask = supp_scores == max_pool(supp_scores)
67 | max_mask = max_mask | (new_max_mask & (~supp_mask))
68 | return torch.where(max_mask, scores, zeros)
69 |
70 |
71 | def top_k_keypoints(keypoints, scores, k):
72 | if k >= len(keypoints):
73 | return keypoints, scores
74 | scores, indices = torch.topk(scores, k, dim=0, sorted=True)
75 | return keypoints[indices], scores
76 |
77 |
78 | def sample_descriptors(keypoints, descriptors, s: int = 8):
79 | """Interpolate descriptors at keypoint locations"""
80 | b, c, h, w = descriptors.shape
81 | keypoints = keypoints - s / 2 + 0.5
82 | keypoints /= torch.tensor(
83 | [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
84 | ).to(
85 | keypoints
86 | )[None]
87 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
88 | args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
89 | descriptors = torch.nn.functional.grid_sample(
90 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
91 | )
92 | descriptors = torch.nn.functional.normalize(
93 | descriptors.reshape(b, c, -1), p=2, dim=1
94 | )
95 | return descriptors
96 |
97 |
98 | class SuperPoint(Extractor):
99 | """SuperPoint Convolutional Detector and Descriptor
100 |
101 | SuperPoint: Self-Supervised Interest Point Detection and
102 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
103 | Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
104 |
105 | """
106 |
107 | default_conf = {
108 | "descriptor_dim": 256,
109 | "nms_radius": 4,
110 | "max_num_keypoints": None,
111 | "detection_threshold": 0.0005,
112 | "remove_borders": 4,
113 | }
114 |
115 | preprocess_conf = {
116 | "resize": 1024,
117 | }
118 |
119 | required_data_keys = ["image"]
120 |
121 | def __init__(self, **conf):
122 | super().__init__(**conf) # Update with default configuration.
123 | self.relu = nn.ReLU(inplace=True)
124 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
125 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
126 |
127 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
128 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
129 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
130 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
131 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
132 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
133 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
134 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
135 |
136 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
137 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
138 |
139 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
140 | self.convDb = nn.Conv2d(
141 | c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
142 | )
143 |
144 | url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa
145 | self.load_state_dict(torch.hub.load_state_dict_from_url(url))
146 |
147 | if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
148 | raise ValueError("max_num_keypoints must be positive or None")
149 |
150 | def forward(self, data: dict) -> dict:
151 | """Compute keypoints, scores, descriptors for image"""
152 | for key in self.required_data_keys:
153 | assert key in data, f"Missing key {key} in data"
154 | image = data["image"]
155 | if image.shape[1] == 3:
156 | image = rgb_to_grayscale(image)
157 |
158 | # Shared Encoder
159 | x = self.relu(self.conv1a(image))
160 | x = self.relu(self.conv1b(x))
161 | x = self.pool(x)
162 | x = self.relu(self.conv2a(x))
163 | x = self.relu(self.conv2b(x))
164 | x = self.pool(x)
165 | x = self.relu(self.conv3a(x))
166 | x = self.relu(self.conv3b(x))
167 | x = self.pool(x)
168 | x = self.relu(self.conv4a(x))
169 | x = self.relu(self.conv4b(x))
170 |
171 | # Compute the dense keypoint scores
172 | cPa = self.relu(self.convPa(x))
173 | scores = self.convPb(cPa)
174 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
175 | b, _, h, w = scores.shape
176 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
177 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
178 | scores = simple_nms(scores, self.conf.nms_radius)
179 |
180 | # Discard keypoints near the image borders
181 | if self.conf.remove_borders:
182 | pad = self.conf.remove_borders
183 | scores[:, :pad] = -1
184 | scores[:, :, :pad] = -1
185 | scores[:, -pad:] = -1
186 | scores[:, :, -pad:] = -1
187 |
188 | # Extract keypoints
189 | best_kp = torch.where(scores > self.conf.detection_threshold)
190 | scores = scores[best_kp]
191 |
192 | # Separate into batches
193 | keypoints = [
194 | torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
195 | ]
196 | scores = [scores[best_kp[0] == i] for i in range(b)]
197 |
198 | # Keep the k keypoints with highest score
199 | if self.conf.max_num_keypoints is not None:
200 | keypoints, scores = list(
201 | zip(
202 | *[
203 | top_k_keypoints(k, s, self.conf.max_num_keypoints)
204 | for k, s in zip(keypoints, scores)
205 | ]
206 | )
207 | )
208 |
209 | # Convert (h, w) to (x, y)
210 | keypoints = [torch.flip(k, [1]).float() for k in keypoints]
211 |
212 | # Compute the dense descriptors
213 | cDa = self.relu(self.convDa(x))
214 | descriptors = self.convDb(cDa)
215 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
216 |
217 | # Extract descriptors
218 | descriptors = [
219 | sample_descriptors(k[None], d[None], 8)[0]
220 | for k, d in zip(keypoints, descriptors)
221 | ]
222 |
223 | return {
224 | "keypoints": torch.stack(keypoints, 0),
225 | "keypoint_scores": torch.stack(scores, 0),
226 | "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
227 | }
228 |
--------------------------------------------------------------------------------
/lightglue/utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc as collections
2 | from pathlib import Path
3 | from types import SimpleNamespace
4 | from typing import Callable, List, Optional, Tuple, Union
5 |
6 | import cv2
7 | import kornia
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 |
13 | class ImagePreprocessor:
14 | default_conf = {
15 | "resize": None, # target edge length, None for no resizing
16 | "side": "long",
17 | "interpolation": "bilinear",
18 | "align_corners": None,
19 | "antialias": True,
20 | }
21 |
22 | def __init__(self, **conf) -> None:
23 | super().__init__()
24 | self.conf = {**self.default_conf, **conf}
25 | self.conf = SimpleNamespace(**self.conf)
26 |
27 | def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
28 | """Resize and preprocess an image, return image and resize scale"""
29 | h, w = img.shape[-2:]
30 | if self.conf.resize is not None:
31 | img = kornia.geometry.transform.resize(
32 | img,
33 | self.conf.resize,
34 | side=self.conf.side,
35 | antialias=self.conf.antialias,
36 | align_corners=self.conf.align_corners,
37 | )
38 | scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
39 | return img, scale
40 |
41 |
42 | def map_tensor(input_, func: Callable):
43 | string_classes = (str, bytes)
44 | if isinstance(input_, string_classes):
45 | return input_
46 | elif isinstance(input_, collections.Mapping):
47 | return {k: map_tensor(sample, func) for k, sample in input_.items()}
48 | elif isinstance(input_, collections.Sequence):
49 | return [map_tensor(sample, func) for sample in input_]
50 | elif isinstance(input_, torch.Tensor):
51 | return func(input_)
52 | else:
53 | return input_
54 |
55 |
56 | def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
57 | """Move batch (dict) to device"""
58 |
59 | def _func(tensor):
60 | return tensor.to(device=device, non_blocking=non_blocking).detach()
61 |
62 | return map_tensor(batch, _func)
63 |
64 |
65 | def rbd(data: dict) -> dict:
66 | """Remove batch dimension from elements in data"""
67 | return {
68 | k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
69 | for k, v in data.items()
70 | }
71 |
72 |
73 | def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
74 | """Read an image from path as RGB or grayscale"""
75 | if not Path(path).exists():
76 | raise FileNotFoundError(f"No image at path {path}.")
77 | mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
78 | image = cv2.imread(str(path), mode)
79 | if image is None:
80 | raise IOError(f"Could not read image at {path}.")
81 | if not grayscale:
82 | image = image[..., ::-1]
83 | return image
84 |
85 |
86 | def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
87 | """Normalize the image tensor and reorder the dimensions."""
88 | if image.ndim == 3:
89 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
90 | elif image.ndim == 2:
91 | image = image[None] # add channel axis
92 | else:
93 | raise ValueError(f"Not an image: {image.shape}")
94 | return torch.tensor(image / 255.0, dtype=torch.float)
95 |
96 |
97 | def resize_image(
98 | image: np.ndarray,
99 | size: Union[List[int], int],
100 | fn: str = "max",
101 | interp: Optional[str] = "area",
102 | ) -> np.ndarray:
103 | """Resize an image to a fixed size, or according to max or min edge."""
104 | h, w = image.shape[:2]
105 |
106 | fn = {"max": max, "min": min}[fn]
107 | if isinstance(size, int):
108 | scale = size / fn(h, w)
109 | h_new, w_new = int(round(h * scale)), int(round(w * scale))
110 | scale = (w_new / w, h_new / h)
111 | elif isinstance(size, (tuple, list)):
112 | h_new, w_new = size
113 | scale = (w_new / w, h_new / h)
114 | else:
115 | raise ValueError(f"Incorrect new size: {size}")
116 | mode = {
117 | "linear": cv2.INTER_LINEAR,
118 | "cubic": cv2.INTER_CUBIC,
119 | "nearest": cv2.INTER_NEAREST,
120 | "area": cv2.INTER_AREA,
121 | }[interp]
122 | return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
123 |
124 |
125 | def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
126 | image = read_image(path)
127 | if resize is not None:
128 | image, _ = resize_image(image, resize, **kwargs)
129 | return numpy_image_to_torch(image)
130 |
131 | 
132 | def load_pilimage(img: Image.Image, resize: int = None, **kwargs) -> torch.Tensor:
133 |     """Convert an in-memory PIL image into a normalized RGB torch tensor.
134 | 
135 |     Counterpart of `load_image` above, for PIL inputs instead of file paths.
136 |     """
137 |     # Force three RGB channels regardless of the input mode (L, RGBA, ...).
138 |     image = np.array(img.convert("RGB"))
139 |     if image.size == 0:
140 |         raise IOError("Could not convert the PIL image to an array.")
141 |     if resize is not None:
142 |         image, _ = resize_image(image, resize, **kwargs)
143 |     return numpy_image_to_torch(image)
145 |
146 |
147 | class Extractor(torch.nn.Module):
148 | def __init__(self, **conf):
149 | super().__init__()
150 | self.conf = SimpleNamespace(**{**self.default_conf, **conf})
151 |
152 | @torch.no_grad()
153 | def extract(self, img: torch.Tensor, **conf) -> dict:
154 | """Perform extraction with online resizing"""
155 | if img.dim() == 3:
156 | img = img[None] # add batch dim
157 | assert img.dim() == 4 and img.shape[0] == 1
158 | shape = img.shape[-2:][::-1]
159 | img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
160 | feats = self.forward({"image": img})
161 | feats["image_size"] = torch.tensor(shape)[None].to(img).float()
162 | feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
163 | return feats
164 |
165 |
166 | def match_pair(
167 | extractor,
168 | matcher,
169 | image0: torch.Tensor,
170 | image1: torch.Tensor,
171 | device: str = "cpu",
172 | **preprocess,
173 | ):
174 | """Match a pair of images (image0, image1) with an extractor and matcher"""
175 | feats0 = extractor.extract(image0, **preprocess)
176 | feats1 = extractor.extract(image1, **preprocess)
177 | matches01 = matcher({"image0": feats0, "image1": feats1})
178 | data = [feats0, feats1, matches01]
179 | # remove batch dim and move to target device
180 | feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
181 | return feats0, feats1, matches01
182 |
--------------------------------------------------------------------------------
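`match_pair` above wraps the extract-and-match loop into a single call and already strips the batch dimension with `rbd`. A usage sketch with SuperPoint and LightGlue constructed as in nodes.py; the image paths are placeholders:

import torch

from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image, match_pair

device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
matcher = LightGlue(features="superpoint").eval().to(device)

image0 = load_image("path/to/image0.jpg").to(device)
image1 = load_image("path/to/image1.jpg").to(device)
feats0, feats1, matches01 = match_pair(
    extractor, matcher, image0, image1, device=device
)

matches = matches01["matches"]                  # (K, 2), batch dimension already removed
points0 = feats0["keypoints"][matches[..., 0]]  # (K, 2) pixel coordinates in image0
points1 = feats1["keypoints"][matches[..., 1]]  # (K, 2) pixel coordinates in image1
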
/lightglue/viz2d.py:
--------------------------------------------------------------------------------
1 | """
2 | 2D visualization primitives based on Matplotlib.
3 | 1) Plot images with `plot_images`.
4 | 2) Call `plot_keypoints` or `plot_matches` any number of times.
5 | 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
6 | """
7 |
8 | import matplotlib
9 | import matplotlib.patheffects as path_effects
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import torch
13 |
14 |
15 | def cm_RdGn(x):
16 | """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
17 | x = np.clip(x, 0, 1)[..., None] * 2
18 | c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
19 | return np.clip(c, 0, 1)
20 |
21 |
22 | def cm_BlRdGn(x_):
23 | """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
24 | x = np.clip(x_, 0, 1)[..., None] * 2
25 | c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
26 |
27 | xn = -np.clip(x_, -1, 0)[..., None] * 2
28 | cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
29 | out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
30 | return out
31 |
32 |
33 | def cm_prune(x_):
34 | """Custom colormap to visualize pruning"""
35 | if isinstance(x_, torch.Tensor):
36 | x_ = x_.cpu().numpy()
37 | max_i = max(x_)
38 | norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
39 | return cm_BlRdGn(norm_x)
40 |
41 |
42 | def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
43 | """Plot a set of images horizontally.
44 | Args:
45 | imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
46 | titles: a list of strings, as titles for each image.
47 | cmaps: colormaps for monochrome images.
48 | adaptive: whether the figure size should fit the image aspect ratios.
49 | """
50 | # conversion to (H, W, 3) for torch.Tensor
51 | imgs = [
52 | img.permute(1, 2, 0).cpu().numpy()
53 | if (isinstance(img, torch.Tensor) and img.dim() == 3)
54 | else img
55 | for img in imgs
56 | ]
57 |
58 | n = len(imgs)
59 | if not isinstance(cmaps, (list, tuple)):
60 | cmaps = [cmaps] * n
61 |
62 | if adaptive:
63 | ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
64 | else:
65 | ratios = [4 / 3] * n
66 | figsize = [sum(ratios) * 4.5, 4.5]
67 | fig, ax = plt.subplots(
68 | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
69 | )
70 | if n == 1:
71 | ax = [ax]
72 | for i in range(n):
73 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
74 | ax[i].get_yaxis().set_ticks([])
75 | ax[i].get_xaxis().set_ticks([])
76 | ax[i].set_axis_off()
77 | for spine in ax[i].spines.values(): # remove frame
78 | spine.set_visible(False)
79 | if titles:
80 | ax[i].set_title(titles[i])
81 | fig.tight_layout(pad=pad)
82 |
83 |
84 | def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
85 | """Plot keypoints for existing images.
86 | Args:
87 | kpts: list of ndarrays of size (N, 2).
88 |         colors: string, or list of lists of tuples (one per keypoint).
89 | ps: size of the keypoints as float.
90 | """
91 | if not isinstance(colors, list):
92 | colors = [colors] * len(kpts)
93 | if not isinstance(a, list):
94 | a = [a] * len(kpts)
95 | if axes is None:
96 | axes = plt.gcf().axes
97 | for ax, k, c, alpha in zip(axes, kpts, colors, a):
98 | if isinstance(k, torch.Tensor):
99 | k = k.cpu().numpy()
100 | ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
101 |
102 |
103 | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
104 | """Plot matches for a pair of existing images.
105 | Args:
106 | kpts0, kpts1: corresponding keypoints of size (N, 2).
107 | color: color of each match, string or RGB tuple. Random if not given.
108 | lw: width of the lines.
109 | ps: size of the end points (no endpoint if ps=0)
110 |         axes: the two axes to draw on (defaults to the first two axes of the figure).
111 | a: alpha opacity of the match lines.
112 | """
113 | fig = plt.gcf()
114 | if axes is None:
115 | ax = fig.axes
116 | ax0, ax1 = ax[0], ax[1]
117 | else:
118 | ax0, ax1 = axes
119 | if isinstance(kpts0, torch.Tensor):
120 | kpts0 = kpts0.cpu().numpy()
121 | if isinstance(kpts1, torch.Tensor):
122 | kpts1 = kpts1.cpu().numpy()
123 | assert len(kpts0) == len(kpts1)
124 | if color is None:
125 | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
126 | elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
127 | color = [color] * len(kpts0)
128 |
129 | if lw > 0:
130 | for i in range(len(kpts0)):
131 | line = matplotlib.patches.ConnectionPatch(
132 | xyA=(kpts0[i, 0], kpts0[i, 1]),
133 | xyB=(kpts1[i, 0], kpts1[i, 1]),
134 | coordsA=ax0.transData,
135 | coordsB=ax1.transData,
136 | axesA=ax0,
137 | axesB=ax1,
138 | zorder=1,
139 | color=color[i],
140 | linewidth=lw,
141 | clip_on=True,
142 | alpha=a,
143 | label=None if labels is None else labels[i],
144 | picker=5.0,
145 | )
146 | line.set_annotation_clip(True)
147 | fig.add_artist(line)
148 |
149 | # freeze the axes to prevent the transform to change
150 | ax0.autoscale(enable=False)
151 | ax1.autoscale(enable=False)
152 |
153 | if ps > 0:
154 | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
155 | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
156 |
157 |
158 | def add_text(
159 | idx,
160 | text,
161 | pos=(0.01, 0.99),
162 | fs=15,
163 | color="w",
164 | lcolor="k",
165 | lwidth=2,
166 | ha="left",
167 | va="top",
168 | ):
169 | ax = plt.gcf().axes[idx]
170 | t = ax.text(
171 | *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
172 | )
173 | if lcolor is not None:
174 | t.set_path_effects(
175 | [
176 | path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
177 | path_effects.Normal(),
178 | ]
179 | )
180 |
181 |
182 | def save_plot(path, **kw):
183 | """Save the current figure without any white margin."""
184 | plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
185 |
--------------------------------------------------------------------------------
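The module docstring above spells out the intended calling order: `plot_images` first, then `plot_keypoints` and/or `plot_matches`, then `save_plot`. A short sketch following that order, assuming `image0`/`image1` are RGB arrays or tensors and `kpts0`/`kpts1`, `m_kpts0`/`m_kpts1` are (N, 2) keypoint and matched-keypoint coordinates obtained from the extractors and matcher above:

from lightglue import viz2d

viz2d.plot_images([image0, image1])                             # one panel per image
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)      # lines between matched points
viz2d.plot_keypoints([kpts0, kpts1], colors="royalblue", ps=6)  # overlay all detections
viz2d.save_plot("matches.png")                                  # saved without white margins
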
/nodes.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import json
4 | import torch
5 |
6 | DEVICE = ["cuda","cpu"]
7 |
8 | class LightGlueLoader:
9 | @classmethod
10 | def INPUT_TYPES(cls):
11 | return {
12 | "required": {
13 | "device": (DEVICE, {"default":"cuda"}),
14 | "max_num_keypoints":("INT", {"default": 2048}),
15 | "filter_threshold": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
16 | }
17 | }
18 |
19 | RETURN_TYPES = ("SuperPoint","LightGlue",)
20 | RETURN_NAMES = ("extractor","matcher",)
21 | FUNCTION = "load_checkpoint"
22 | CATEGORY = "LightGlue"
23 |
24 | def load_checkpoint(self,device,max_num_keypoints,filter_threshold):
25 | from .lightglue import LightGlue, SuperPoint
26 |
27 | # SuperPoint+LightGlue
28 | extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(device) # load the extractor
29 | matcher = LightGlue(features='superpoint',filter_threshold=filter_threshold).eval().to(device) # load the matcher
30 |
31 | return (extractor,matcher,)
32 |
33 | class LightGlueSimple:
34 | @classmethod
35 | def INPUT_TYPES(cls):
36 | return {
37 | "required": {
38 | "extractor": ("SuperPoint",),
39 | "matcher": ("LightGlue",),
40 | "image0": ("IMAGE",),
41 | "image1": ("IMAGE",),
42 | "device": (DEVICE, {"default":"cuda"}),
43 | }
44 | }
45 |
46 | RETURN_TYPES = ("STRING",{},{},{},)
47 | RETURN_NAMES = ("motionbrush","matches","points0","points1",)
48 | FUNCTION = "run_inference"
49 | CATEGORY = "LightGlue"
50 |
51 | def run_inference(self,extractor,matcher,image0,image1,device):
52 | from .lightglue.utils import load_image,load_pilimage, rbd
53 |
54 | image0 = 255.0 * image0[0].cpu().numpy()
55 | image0 = Image.fromarray(np.clip(image0, 0, 255).astype(np.uint8))
56 |
57 | image1 = 255.0 * image1[0].cpu().numpy()
58 | image1 = Image.fromarray(np.clip(image1, 0, 255).astype(np.uint8))
59 |
60 | image0=load_pilimage(image0).to(device)
61 | image1=load_pilimage(image1).to(device)
62 |
63 | feats0 = extractor.extract(image0) # auto-resize the image, disable with resize=None
64 | feats1 = extractor.extract(image1)
65 |
66 |         # (The two images could also be loaded from disk with load_image(path).)
67 | 
68 |         # Match the features of the two images.
69 | matches01 = matcher({'image0': feats0, 'image1': feats1})
70 | feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension
71 | matches = matches01['matches'] # indices with shape (K,2)
72 | points0 = feats0['keypoints'][matches[..., 0]] # coordinates in image #0, shape (K,2)
73 | points1 = feats1['keypoints'][matches[..., 1]] # coordinates in image #1, shape (K,2)
74 |         # Stack the matched coordinates of both images into one trajectory per
75 |         # match: trajs has shape (K, 2, 2), i.e. (match, frame, x/y), and is
76 |         # serialized to JSON below as the motion-brush string.
77 |         trajs = torch.stack([points0, points1], 1)
78 |
79 | return (json.dumps(trajs.tolist()),matches,points0,points1,)
80 |
81 | class LightGlueSimpleMulti:
82 | @classmethod
83 | def INPUT_TYPES(cls):
84 | return {
85 | "required": {
86 | "extractor": ("SuperPoint",),
87 | "matcher": ("LightGlue",),
88 | "images": ("IMAGE",),
89 | "device": (DEVICE, {"default":"cuda"}),
90 | "scale": ("INT", {"default": 4}),
91 | }
92 | }
93 |
94 | RETURN_TYPES = ("STRING",)
95 | RETURN_NAMES = ("motionbrush",)
96 | FUNCTION = "run_inference"
97 | CATEGORY = "LightGlue"
98 |
99 | def run_inference(self,extractor,matcher,images,device,scale):
100 | from .lightglue.utils import load_image,load_pilimage, rbd
101 |
102 | featses=[]
103 | matcheses=[]
104 | trajs=[]
105 | points0s=[]
106 | points1s=[]
107 | pointsnotuse=[]
108 | pointsuse=[]
109 |
110 | for image in images:
111 | image = 255.0 * image.cpu().numpy()
112 | image = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8))
113 | image=load_pilimage(image).to(device)
114 | feats = extractor.extract(image)
115 | featses.append(feats)
116 | if len(featses)>1:
117 | matches = matcher({'image0': featses[0], 'image1': feats})
118 | feats0, feats1, matches = [rbd(x) for x in [featses[0], feats, matches]] # remove batch dimension
119 | matches = matches['matches']
120 |                 # Every frame is matched against the first frame's features:
121 |                 # matches[:, 0] -> frame-0 keypoints, matches[:, 1] -> current frame.
122 | points0 = feats0['keypoints'][matches[..., 0]]
123 | points1 = feats1['keypoints'][matches[..., 1]]
124 |
125 |
126 | if len(featses)==2:
127 | pointsuse=(matches[:,0].detach().cpu().numpy()/scale).astype('int').tolist()
128 | else:
129 | pu=False
130 | mlist=(matches[:,0].detach().cpu().numpy()/scale).astype('int').tolist()
131 | for puitem in pointsuse:
132 | if puitem not in mlist:
133 | pointsnotuse.append(puitem)
134 |
135 | points0s.append(points0.tolist())
136 | points1s.append(points1.tolist())
137 | matcheses.append(matches.tolist())
138 |
139 |
140 |
141 | muses=[]
142 | for pitem in pointsuse:
143 | if pitem not in pointsnotuse:
144 | muses.append(pitem)
145 |
146 | for muse in muses:
147 | trajs.append([])
148 | #print(f'{muses}')
149 | for imlist in range(len(matcheses)):
150 | mlist=matcheses[imlist]
151 | for imitem in range(len(mlist)):
152 | mitem=mlist[imitem]
153 | mitem0=(np.array(mitem[0])/scale).astype('int').tolist()
154 | #print(f'{mitem}')
155 | if mitem0 in muses:
156 | if imlist==0:
157 | trajs[muses.index(mitem0)].append(points0s[imlist][imitem])
158 | trajs[muses.index(mitem0)].append(points1s[imlist][imitem])
159 | else:
160 | trajs[muses.index(mitem0)].append(points1s[imlist][imitem])
161 |
162 | ret=[]
163 | for traj in trajs:
164 | if len(traj)>1:
165 | ret.append(traj)
166 | return (json.dumps(ret),)
167 |
168 |
169 | NODE_CLASS_MAPPINGS = {
170 | "LightGlue Loader":LightGlueLoader,
171 | "LightGlue Simple":LightGlueSimple,
172 | "LightGlue Simple Multi":LightGlueSimpleMulti,
173 | }
174 |
175 |
--------------------------------------------------------------------------------
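The `motionbrush` string returned by both matching nodes is plain JSON: LightGlueSimple emits one `[[x0, y0], [x1, y1]]` pair per match, while LightGlueSimpleMulti emits, for every keypoint tracked across the frames, the list of its per-frame `[x, y]` coordinates. A small consumption sketch, with `motionbrush` assumed to hold the node's string output:

import json

trajectories = json.loads(motionbrush)
for traj in trajectories:
    # traj is a list of [x, y] coordinates, one per frame in which the point
    # was tracked (exactly two entries for LightGlueSimple).
    start, end = traj[0], traj[-1]
    print(f"track of length {len(traj)}: {start} -> {end}")
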
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "lightglue"
3 | description = "LightGlue: Local Feature Matching at Light Speed"
4 | version = "0.0"
5 | authors = [
6 | {name = "Philipp Lindenberger"},
7 | {name = "Paul-Edouard Sarlin"},
8 | ]
9 | readme = "README.md"
10 | requires-python = ">=3.6"
11 | license = {file = "LICENSE"}
12 | classifiers = [
13 | "Programming Language :: Python :: 3",
14 | "License :: OSI Approved :: Apache Software License",
15 | "Operating System :: OS Independent",
16 | ]
17 | urls = {Repository = "https://github.com/cvg/LightGlue/"}
18 | dynamic = ["dependencies"]
19 |
20 | [project.optional-dependencies]
21 | dev = ["black", "flake8", "isort"]
22 |
23 | [tool.setuptools]
24 | packages = ["lightglue"]
25 |
26 | [tool.setuptools.dynamic]
27 | dependencies = {file = ["requirements.txt"]}
28 |
29 | [tool.isort]
30 | profile = "black"
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | matplotlib
4 | kornia>=0.6.11
--------------------------------------------------------------------------------
/tools/draw.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |