├── .gitignore ├── LICENSE ├── README.md ├── gluestick ├── __init__.py ├── drawing.py ├── geometry.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── gluestick.py │ ├── superpoint.py │ ├── two_view_pipeline.py │ └── wireframe.py └── run.py ├── gluestick_matching_demo.ipynb ├── pyproject.toml ├── requirements.txt └── resources ├── demo_seq1.gif ├── img1.jpg ├── img2.jpg └── weights └── superpoint_v1.pth /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/* 131 | *events.out.tfevents.* 132 | /outputs 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Computer Vision and Geometry Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GlueStick 2 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/GlueStick/blob/main/gluestick_matching_demo.ipynb) [![arXiv](https://img.shields.io/badge/arXiv-2304.02008-b31b1b.svg?style=flat)](https://arxiv.org/abs/2304.02008) [![Project Page](https://badgen.net/badge/color/project/green?icon=awesome&label)](https://iago-suarez.com/gluestick) 3 | 4 | Joint deep matcher for points and lines 🖼️💥🖼️ 5 | 6 | **Update: we are pleased to announce that the training code has been released within our new training framework, [GlueFactory](https://github.com/cvg/glue-factory).** 7 | 8 | ![Visualization of point and line matches](resources/demo_seq1.gif) 9 | 10 | This repository contains the official implementation of 11 | [GlueStick: Robust Image Matching by Sticking Points and Lines Together](https://arxiv.org/abs/2304.02008), accepted at ICCV 2023. 12 | 13 | ## Install 🛠️ 14 | 15 | To install the software in Ubuntu 22.04 follow these instructions: 16 | ```bash 17 | sudo apt-get install build-essential cmake libopencv-dev libopencv-contrib-dev 18 | git clone --recursive https://github.com/cvg/GlueStick.git 19 | cd GlueStick 20 | # Create and activate a virtual environment 21 | python -m venv venv 22 | source venv/bin/activate 23 | pip install -r requirements.txt 24 | pip install . 
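# (Optional) Illustrative sanity check that the package is importable — not part of
# the original instructions:
python -c "import gluestick; print(gluestick.GLUESTICK_ROOT)"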
25 | ``` 26 | 27 | ## Running GlueStick 🏃 28 | Download the weights of the model: 29 | ``` 30 | wget https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar -P resources/weights 31 | ``` 32 | 33 | You can then run inference with: 34 | ``` 35 | python -m gluestick.run -img1 resources/img1.jpg -img2 resources/img2.jpg 36 | ``` 37 | 38 | ## Training 🏋️ 39 | The training code is available in a separate repository, [GlueFactory](https://github.com/cvg/glue-factory). Within GlueFactory, you can not only train GlueStick, but also other deep matchers such as [LightGlue](https://github.com/cvg/LightGlue), use multiple feature extractors, line extractors, robust estimators, as well as run evaluations on multiple benchmarks. 40 | 41 | ## Licence 📜 42 | Our code is licensed under the [MIT licence](https://github.com/cvg/GlueStick/blob/main/LICENSE). 43 | However, bear in mind that it uses a SuperPoint backbone that has a 44 | [non-commercial licence](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE). 45 | Therefore, the overall system is non-commercial 😞. We are working on an analogous version based on 46 | [DISK](https://github.com/cvlab-epfl/disk) to avoid this problem. 47 | 48 | ## Citation 📝 49 | If you use this code in your project, please consider citing the following paper: 50 | ```bibtex 51 | @InProceedings{pautrat_suarez_2023_gluestick, 52 | title={{GlueStick}: Robust Image Matching by Sticking Points and Lines Together}, 53 | author={Pautrat, R{\'e}mi* and Su{\'a}rez, Iago* and Yu, Yifan and Pollefeys, Marc and Larsson, Viktor}, 54 | booktitle={International Conference on Computer Vision (ICCV)}, 55 | year={2023} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /gluestick/__init__.py: -------------------------------------------------------------------------------- 1 | import collections.abc as collections 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | GLUESTICK_ROOT = Path(__file__).parent.parent 7 | 8 | 9 | def get_class(mod_name, base_path, BaseClass): 10 | """Get the class object which inherits from BaseClass and is defined in 11 | the module named mod_name, child of base_path. 12 | """ 13 | import inspect 14 | mod_path = '{}.{}'.format(base_path, mod_name) 15 | mod = __import__(mod_path, fromlist=['']) 16 | classes = inspect.getmembers(mod, inspect.isclass) 17 | # Filter classes defined in the module 18 | classes = [c for c in classes if c[1].__module__ == mod_path] 19 | # Filter classes inherited from BaseClass 20 | classes = [c for c in classes if issubclass(c[1], BaseClass)] 21 | assert len(classes) == 1, classes 22 | return classes[0][1] 23 | 24 | 25 | def get_model(name): 26 | from .models.base_model import BaseModel 27 | return get_class('models.'
+ name, __name__, BaseModel) 28 | 29 | 30 | def numpy_image_to_torch(image): 31 | """Normalize the image tensor and reorder the dimensions.""" 32 | if image.ndim == 3: 33 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 34 | elif image.ndim == 2: 35 | image = image[None] # add channel axis 36 | else: 37 | raise ValueError(f'Not an image: {image.shape}') 38 | return torch.from_numpy(image / 255.).float() 39 | 40 | 41 | def map_tensor(input_, func): 42 | if isinstance(input_, (str, bytes)): 43 | return input_ 44 | elif isinstance(input_, collections.Mapping): 45 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 46 | elif isinstance(input_, collections.Sequence): 47 | return [map_tensor(sample, func) for sample in input_] 48 | else: 49 | return func(input_) 50 | 51 | 52 | def batch_to_np(batch): 53 | return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0]) 54 | -------------------------------------------------------------------------------- /gluestick/drawing.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | 6 | 7 | def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, 8 | adaptive=True): 9 | """Plot a set of images horizontally. 10 | Args: 11 | imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). 12 | titles: a list of strings, as titles for each image. 13 | cmaps: colormaps for monochrome images. 14 | adaptive: whether the figure size should fit the image aspect ratios. 15 | """ 16 | n = len(imgs) 17 | if not isinstance(cmaps, (list, tuple)): 18 | cmaps = [cmaps] * n 19 | 20 | if adaptive: 21 | ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H 22 | else: 23 | ratios = [4 / 3] * n 24 | figsize = [sum(ratios) * 4.5, 4.5] 25 | fig, ax = plt.subplots( 26 | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios}) 27 | if n == 1: 28 | ax = [ax] 29 | for i in range(n): 30 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) 31 | ax[i].get_yaxis().set_ticks([]) 32 | ax[i].get_xaxis().set_ticks([]) 33 | ax[i].set_axis_off() 34 | for spine in ax[i].spines.values(): # remove frame 35 | spine.set_visible(False) 36 | if titles: 37 | ax[i].set_title(titles[i]) 38 | fig.tight_layout(pad=pad) 39 | return ax 40 | 41 | 42 | def plot_keypoints(kpts, colors='lime', ps=4, alpha=1): 43 | """Plot keypoints for existing images. 44 | Args: 45 | kpts: list of ndarrays of size (N, 2). 46 | colors: string, or list of list of tuples (one for each keypoints). 47 | ps: size of the keypoints as float. 48 | """ 49 | if not isinstance(colors, list): 50 | colors = [colors] * len(kpts) 51 | axes = plt.gcf().axes 52 | for a, k, c in zip(axes, kpts, colors): 53 | a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0) 54 | 55 | 56 | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): 57 | """Plot matches for a pair of existing images. 58 | Args: 59 | kpts0, kpts1: corresponding keypoints of size (N, 2). 60 | color: color of each match, string or RGB tuple. Random if not given. 61 | lw: width of the lines. 62 | ps: size of the end points (no endpoint if ps=0) 63 | indices: indices of the images to draw the matches on. 64 | a: alpha opacity of the match lines. 
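        Example (illustrative, assuming images `img0`, `img1` and matched keypoint
        arrays `mkpts0`, `mkpts1` of shape (N, 2); these names are placeholders,
        not part of the original code):
            plot_images([img0, img1])
            plot_matches(mkpts0, mkpts1, color='lime')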
65 | """ 66 | fig = plt.gcf() 67 | ax = fig.axes 68 | assert len(ax) > max(indices) 69 | ax0, ax1 = ax[indices[0]], ax[indices[1]] 70 | fig.canvas.draw() 71 | 72 | assert len(kpts0) == len(kpts1) 73 | if color is None: 74 | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() 75 | elif len(color) > 0 and not isinstance(color[0], (tuple, list)): 76 | color = [color] * len(kpts0) 77 | 78 | if lw > 0: 79 | # transform the points into the figure coordinate system 80 | transFigure = fig.transFigure.inverted() 81 | fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) 82 | fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) 83 | fig.lines += [matplotlib.lines.Line2D( 84 | (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), 85 | zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, 86 | alpha=a) 87 | for i in range(len(kpts0))] 88 | 89 | # freeze the axes to prevent the transform to change 90 | ax0.autoscale(enable=False) 91 | ax1.autoscale(enable=False) 92 | 93 | if ps > 0: 94 | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 95 | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 96 | 97 | 98 | def plot_lines(lines, line_colors='orange', point_colors='cyan', 99 | ps=4, lw=2, alpha=1., indices=(0, 1)): 100 | """ Plot lines and endpoints for existing images. 101 | Args: 102 | lines: list of ndarrays of size (N, 2, 2). 103 | colors: string, or list of list of tuples (one for each keypoints). 104 | ps: size of the keypoints as float pixels. 105 | lw: line width as float pixels. 106 | alpha: transparency of the points and lines. 107 | indices: indices of the images to draw the matches on. 108 | """ 109 | if not isinstance(line_colors, list): 110 | line_colors = [line_colors] * len(lines) 111 | if not isinstance(point_colors, list): 112 | point_colors = [point_colors] * len(lines) 113 | 114 | fig = plt.gcf() 115 | ax = fig.axes 116 | assert len(ax) > max(indices) 117 | axes = [ax[i] for i in indices] 118 | fig.canvas.draw() 119 | 120 | # Plot the lines and junctions 121 | for a, l, lc, pc in zip(axes, lines, line_colors, point_colors): 122 | for i in range(len(l)): 123 | line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]), 124 | (l[i, 0, 1], l[i, 1, 1]), 125 | zorder=1, c=lc, linewidth=lw, 126 | alpha=alpha) 127 | a.add_line(line) 128 | pts = l.reshape(-1, 2) 129 | a.scatter(pts[:, 0], pts[:, 1], 130 | c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha) 131 | 132 | 133 | def plot_color_line_matches(lines, correct_matches=None, 134 | lw=2, indices=(0, 1)): 135 | """Plot line matches for existing images with multiple colors. 136 | Args: 137 | lines: list of ndarrays of size (N, 2, 2). 138 | correct_matches: bool array of size (N,) indicating correct matches. 139 | lw: line width as float pixels. 140 | indices: indices of the images to draw the matches on. 
141 | """ 142 | n_lines = len(lines[0]) 143 | colors = sns.color_palette('husl', n_colors=n_lines) 144 | np.random.shuffle(colors) 145 | alphas = np.ones(n_lines) 146 | # If correct_matches is not None, display wrong matches with a low alpha 147 | if correct_matches is not None: 148 | alphas[~np.array(correct_matches)] = 0.2 149 | 150 | fig = plt.gcf() 151 | ax = fig.axes 152 | assert len(ax) > max(indices) 153 | axes = [ax[i] for i in indices] 154 | fig.canvas.draw() 155 | 156 | # Plot the lines 157 | for a, l in zip(axes, lines): 158 | # Transform the points into the figure coordinate system 159 | transFigure = fig.transFigure.inverted() 160 | endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) 161 | endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) 162 | fig.lines += [matplotlib.lines.Line2D( 163 | (endpoint0[i, 0], endpoint1[i, 0]), 164 | (endpoint0[i, 1], endpoint1[i, 1]), 165 | zorder=1, transform=fig.transFigure, c=colors[i], 166 | alpha=alphas[i], linewidth=lw) for i in range(n_lines)] 167 | -------------------------------------------------------------------------------- /gluestick/geometry.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def to_homogeneous(points): 8 | """Convert N-dimensional points to homogeneous coordinates. 9 | Args: 10 | points: torch.Tensor or numpy.ndarray with size (..., N). 11 | Returns: 12 | A torch.Tensor or numpy.ndarray with size (..., N+1). 13 | """ 14 | if isinstance(points, torch.Tensor): 15 | pad = points.new_ones(points.shape[:-1] + (1,)) 16 | return torch.cat([points, pad], dim=-1) 17 | elif isinstance(points, np.ndarray): 18 | pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype) 19 | return np.concatenate([points, pad], axis=-1) 20 | else: 21 | raise ValueError 22 | 23 | 24 | def from_homogeneous(points, eps=0.): 25 | """Remove the homogeneous dimension of N-dimensional points. 26 | Args: 27 | points: torch.Tensor or numpy.ndarray with size (..., N+1). 28 | Returns: 29 | A torch.Tensor or numpy ndarray with size (..., N). 30 | """ 31 | return points[..., :-1] / (points[..., -1:] + eps) 32 | 33 | 34 | def skew_symmetric(v): 35 | """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). 36 | """ 37 | z = torch.zeros_like(v[..., 0]) 38 | M = torch.stack([ 39 | z, -v[..., 2], v[..., 1], 40 | v[..., 2], z, -v[..., 0], 41 | -v[..., 1], v[..., 0], z, 42 | ], dim=-1).reshape(v.shape[:-1] + (3, 3)) 43 | return M 44 | 45 | 46 | def T_to_E(T): 47 | """Convert batched poses (..., 4, 4) to batched essential matrices.""" 48 | return skew_symmetric(T[..., :3, 3]) @ T[..., :3, :3] 49 | 50 | 51 | def warp_points_torch(points, H, inverse=True): 52 | """ 53 | Warp a list of points with the INVERSE of the given homography. 54 | The inverse is used to be coherent with tf.contrib.image.transform 55 | Arguments: 56 | points: batched list of N points, shape (B, N, 2). 57 | homography: batched or not (shapes (B, 8) and (8,) respectively). 58 | Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warped points. 
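        Example (illustrative, assuming `H` is a flattened 8-parameter homography
        batch of shape (1, 8)):
            warped = warp_points_torch(torch.rand(1, 10, 2), H)  # -> shape (1, 10, 2)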
59 | """ 60 | # H = np.expand_dims(homography, axis=0) if len(homography.shape) == 1 else homography 61 | 62 | # Get the points to the homogeneous format 63 | points = to_homogeneous(points) 64 | 65 | # Apply the homography 66 | out_shape = tuple(list(H.shape[:-1]) + [3, 3]) 67 | H_mat = torch.cat([H, torch.ones_like(H[..., :1])], axis=-1).reshape(out_shape) 68 | if inverse: 69 | H_mat = torch.inverse(H_mat) 70 | warped_points = torch.einsum('...nj,...ji->...ni', points, H_mat.transpose(-2, -1)) 71 | 72 | warped_points = from_homogeneous(warped_points, eps=1e-5) 73 | 74 | return warped_points 75 | 76 | 77 | def seg_equation(segs): 78 | # calculate list of start, end and midpoints points from both lists 79 | start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(segs[..., 1, :]) 80 | # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1 81 | lines = torch.cross(start_points, end_points, dim=-1) 82 | lines_norm = (torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]) 83 | assert torch.all(lines_norm > 0), 'Error: trying to compute the equation of a line with a single point' 84 | lines = lines / lines_norm 85 | return lines 86 | 87 | 88 | def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]): 89 | h, w = img_shape 90 | return (pts >= 0).all(dim=-1) & (pts[..., 0] < w) & (pts[..., 1] < h) & (~torch.isinf(pts).any(dim=-1)) 91 | 92 | 93 | def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor: 94 | """ 95 | Shrink an array of segments to fit inside the image. 96 | :param segs: The tensor of segments with shape (N, 2, 2) 97 | :param img_shape: The image shape in format (H, W) 98 | """ 99 | EPS = 1e-4 100 | device = segs.device 101 | w, h = img_shape[1], img_shape[0] 102 | # Project the segments to the reference image 103 | segs = segs.clone() 104 | eqs = seg_equation(segs) 105 | x0, y0 = torch.tensor([1., 0, 0.], device=device), torch.tensor([0., 1, 0], device=device) 106 | x0 = x0.repeat(eqs.shape[:-1] + (1,)) 107 | y0 = y0.repeat(eqs.shape[:-1] + (1,)) 108 | pt_x0s = torch.cross(eqs, x0, dim=-1) 109 | pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1] 110 | pt_x0s_valid = is_inside_img(pt_x0s, img_shape) 111 | pt_y0s = torch.cross(eqs, y0, dim=-1) 112 | pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1] 113 | pt_y0s_valid = is_inside_img(pt_y0s, img_shape) 114 | 115 | xW, yH = torch.tensor([1., 0, EPS - w], device=device), torch.tensor([0., 1, EPS - h], device=device) 116 | xW = xW.repeat(eqs.shape[:-1] + (1,)) 117 | yH = yH.repeat(eqs.shape[:-1] + (1,)) 118 | pt_xWs = torch.cross(eqs, xW, dim=-1) 119 | pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1] 120 | pt_xWs_valid = is_inside_img(pt_xWs, img_shape) 121 | pt_yHs = torch.cross(eqs, yH, dim=-1) 122 | pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1] 123 | pt_yHs_valid = is_inside_img(pt_yHs, img_shape) 124 | 125 | # If the X coordinate of the first endpoint is out 126 | mask = (segs[..., 0, 0] < 0) & pt_x0s_valid 127 | segs[mask, 0, :] = pt_x0s[mask] 128 | mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid 129 | segs[mask, 0, :] = pt_xWs[mask] 130 | # If the X coordinate of the second endpoint is out 131 | mask = (segs[..., 1, 0] < 0) & pt_x0s_valid 132 | segs[mask, 1, :] = pt_x0s[mask] 133 | mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid 134 | segs[mask, 1, :] = pt_xWs[mask] 135 | # If the Y coordinate of the first endpoint is out 136 | mask = (segs[..., 0, 1] < 0) & pt_y0s_valid 137 | segs[mask, 0, :] = pt_y0s[mask] 138 | mask = (segs[..., 0, 1] > (h - 
1)) & pt_yHs_valid 139 | segs[mask, 0, :] = pt_yHs[mask] 140 | # If the Y coordinate of the second endpoint is out 141 | mask = (segs[..., 1, 1] < 0) & pt_y0s_valid 142 | segs[mask, 1, :] = pt_y0s[mask] 143 | mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid 144 | segs[mask, 1, :] = pt_yHs[mask] 145 | 146 | assert torch.all(segs >= 0) and torch.all(segs[..., 0] < w) and torch.all(segs[..., 1] < h) 147 | return segs 148 | 149 | 150 | def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 151 | """ 152 | :param lines: A tensor of shape (B, N, 2, 2) where B is the batch size, N the number of lines. 153 | :param H: The homography used to convert the lines. batched or not (shapes (B, 8) and (8,) respectively). 154 | :param inverse: Whether to apply H or the inverse of H 155 | :param dst_shape:If provided, lines are trimmed to be inside the image 156 | """ 157 | device = lines.device 158 | batch_size, n = lines.shape[:2] 159 | lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(lines.shape) 160 | 161 | if dst_shape is None: 162 | return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device) 163 | 164 | out_img = torch.any((lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1) 165 | valid = ~out_img.all(-1) 166 | any_out_of_img = out_img.any(-1) 167 | lines_to_trim = valid & any_out_of_img 168 | 169 | for b in range(batch_size): 170 | lines_to_trim_mask_b = lines_to_trim[b] 171 | lines_to_trim_b = lines[b][lines_to_trim_mask_b] 172 | corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape) 173 | lines[b][lines_to_trim_mask_b] = corrected_lines 174 | 175 | return lines, valid 176 | -------------------------------------------------------------------------------- /gluestick/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/GlueStick/7d816730ef939caa1c61e2564eceda77304874fa/gluestick/models/__init__.py -------------------------------------------------------------------------------- /gluestick/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for trainable models. 3 | """ 4 | 5 | from abc import ABCMeta, abstractmethod 6 | import omegaconf 7 | from omegaconf import OmegaConf 8 | from torch import nn 9 | from copy import copy 10 | 11 | 12 | class MetaModel(ABCMeta): 13 | def __prepare__(name, bases, **kwds): 14 | total_conf = OmegaConf.create() 15 | for base in bases: 16 | for key in ('base_default_conf', 'default_conf'): 17 | update = getattr(base, key, {}) 18 | if isinstance(update, dict): 19 | update = OmegaConf.create(update) 20 | total_conf = OmegaConf.merge(total_conf, update) 21 | return dict(base_default_conf=total_conf) 22 | 23 | 24 | class BaseModel(nn.Module, metaclass=MetaModel): 25 | """ 26 | What the child model is expect to declare: 27 | default_conf: dictionary of the default configuration of the model. 28 | It recursively updates the default_conf of all parent classes, and 29 | it is updated by the user-provided configuration passed to __init__. 30 | Configurations can be nested. 31 | 32 | required_data_keys: list of expected keys in the input data dictionary. 33 | 34 | strict_conf (optional): boolean. If false, BaseModel does not raise 35 | an error when the user provides an unknown configuration entry. 
36 | 37 | _init(self, conf): initialization method, where conf is the final 38 | configuration object (also accessible with `self.conf`). Accessing 39 | unknown configuration entries will raise an error. 40 | 41 | _forward(self, data): method that returns a dictionary of batched 42 | prediction tensors based on a dictionary of batched input data tensors. 43 | 44 | loss(self, pred, data): method that returns a dictionary of losses, 45 | computed from model predictions and input data. Each loss is a batch 46 | of scalars, i.e. a torch.Tensor of shape (B,). 47 | The total loss to be optimized has the key `'total'`. 48 | 49 | metrics(self, pred, data): method that returns a dictionary of metrics, 50 | each as a batch of scalars. 51 | """ 52 | default_conf = { 53 | 'name': None, 54 | 'trainable': True, # if false: do not optimize this model parameters 55 | 'freeze_batch_normalization': False, # use test-time statistics 56 | } 57 | required_data_keys = [] 58 | strict_conf = True 59 | 60 | def __init__(self, conf): 61 | """Perform some logic and call the _init method of the child model.""" 62 | super().__init__() 63 | default_conf = OmegaConf.merge( 64 | self.base_default_conf, OmegaConf.create(self.default_conf)) 65 | if self.strict_conf: 66 | OmegaConf.set_struct(default_conf, True) 67 | 68 | # fixme: backward compatibility 69 | if 'pad' in conf and 'pad' not in default_conf: # backward compat. 70 | with omegaconf.read_write(conf): 71 | with omegaconf.open_dict(conf): 72 | conf['interpolation'] = {'pad': conf.pop('pad')} 73 | 74 | if isinstance(conf, dict): 75 | conf = OmegaConf.create(conf) 76 | self.conf = conf = OmegaConf.merge(default_conf, conf) 77 | OmegaConf.set_readonly(conf, True) 78 | OmegaConf.set_struct(conf, True) 79 | self.required_data_keys = copy(self.required_data_keys) 80 | self._init(conf) 81 | 82 | if not conf.trainable: 83 | for p in self.parameters(): 84 | p.requires_grad = False 85 | 86 | def train(self, mode=True): 87 | super().train(mode) 88 | 89 | def freeze_bn(module): 90 | if isinstance(module, nn.modules.batchnorm._BatchNorm): 91 | module.eval() 92 | if self.conf.freeze_batch_normalization: 93 | self.apply(freeze_bn) 94 | 95 | return self 96 | 97 | def forward(self, data): 98 | """Check the data and call the _forward method of the child model.""" 99 | def recursive_key_check(expected, given): 100 | for key in expected: 101 | assert key in given, f'Missing key {key} in data' 102 | if isinstance(expected, dict): 103 | recursive_key_check(expected[key], given[key]) 104 | 105 | recursive_key_check(self.required_data_keys, data) 106 | return self._forward(data) 107 | 108 | @abstractmethod 109 | def _init(self, conf): 110 | """To be implemented by the child class.""" 111 | raise NotImplementedError 112 | 113 | @abstractmethod 114 | def _forward(self, data): 115 | """To be implemented by the child class.""" 116 | raise NotImplementedError 117 | 118 | @abstractmethod 119 | def loss(self, pred, data): 120 | """To be implemented by the child class.""" 121 | raise NotImplementedError 122 | 123 | @abstractmethod 124 | def metrics(self, pred, data): 125 | """To be implemented by the child class.""" 126 | raise NotImplementedError 127 | -------------------------------------------------------------------------------- /gluestick/models/gluestick.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import warnings 3 | from copy import deepcopy 4 | 5 | warnings.filterwarnings("ignore", category=UserWarning) 6 | import torch 7 | 
import torch.utils.checkpoint 8 | from torch import nn 9 | from .base_model import BaseModel 10 | 11 | ETH_EPS = 1e-8 12 | 13 | 14 | class GlueStick(BaseModel): 15 | default_conf = { 16 | 'input_dim': 256, 17 | 'descriptor_dim': 256, 18 | 'bottleneck_dim': None, 19 | 'weights': None, 20 | 'keypoint_encoder': [32, 64, 128, 256], 21 | 'GNN_layers': ['self', 'cross'] * 9, 22 | 'num_line_iterations': 1, 23 | 'line_attention': False, 24 | 'filter_threshold': 0.2, 25 | 'checkpointed': False, 26 | 'skip_init': False, 27 | 'inter_supervision': None, 28 | 'loss': { 29 | 'nll_weight': 1., 30 | 'nll_balancing': 0.5, 31 | 'reward_weight': 0., 32 | 'bottleneck_l2_weight': 0., 33 | 'dense_nll_weight': 0., 34 | 'inter_supervision': [0.3, 0.6], 35 | }, 36 | } 37 | required_data_keys = [ 38 | 'keypoints0', 'keypoints1', 39 | 'descriptors0', 'descriptors1', 40 | 'keypoint_scores0', 'keypoint_scores1'] 41 | 42 | DEFAULT_LOSS_CONF = {'nll_weight': 1., 'nll_balancing': 0.5, 'reward_weight': 0., 'bottleneck_l2_weight': 0.} 43 | 44 | def _init(self, conf): 45 | if conf.bottleneck_dim is not None: 46 | self.bottleneck_down = nn.Conv1d( 47 | conf.input_dim, conf.bottleneck_dim, kernel_size=1) 48 | self.bottleneck_up = nn.Conv1d( 49 | conf.bottleneck_dim, conf.input_dim, kernel_size=1) 50 | nn.init.constant_(self.bottleneck_down.bias, 0.0) 51 | nn.init.constant_(self.bottleneck_up.bias, 0.0) 52 | 53 | if conf.input_dim != conf.descriptor_dim: 54 | self.input_proj = nn.Conv1d( 55 | conf.input_dim, conf.descriptor_dim, kernel_size=1) 56 | nn.init.constant_(self.input_proj.bias, 0.0) 57 | 58 | self.kenc = KeypointEncoder(conf.descriptor_dim, 59 | conf.keypoint_encoder) 60 | self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder) 61 | self.gnn = AttentionalGNN(conf.descriptor_dim, conf.GNN_layers, 62 | checkpointed=conf.checkpointed, 63 | inter_supervision=conf.inter_supervision, 64 | num_line_iterations=conf.num_line_iterations, 65 | line_attention=conf.line_attention) 66 | self.final_proj = nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, 67 | kernel_size=1) 68 | nn.init.constant_(self.final_proj.bias, 0.0) 69 | nn.init.orthogonal_(self.final_proj.weight, gain=1) 70 | self.final_line_proj = nn.Conv1d( 71 | conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) 72 | nn.init.constant_(self.final_line_proj.bias, 0.0) 73 | nn.init.orthogonal_(self.final_line_proj.weight, gain=1) 74 | if conf.inter_supervision is not None: 75 | self.inter_line_proj = nn.ModuleList( 76 | [nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) 77 | for _ in conf.inter_supervision]) 78 | self.layer2idx = {} 79 | for i, l in enumerate(conf.inter_supervision): 80 | nn.init.constant_(self.inter_line_proj[i].bias, 0.0) 81 | nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1) 82 | self.layer2idx[l] = i 83 | 84 | bin_score = torch.nn.Parameter(torch.tensor(1.)) 85 | self.register_parameter('bin_score', bin_score) 86 | line_bin_score = torch.nn.Parameter(torch.tensor(1.)) 87 | self.register_parameter('line_bin_score', line_bin_score) 88 | 89 | if conf.weights: 90 | assert isinstance(conf.weights, str) 91 | if os.path.exists(conf.weights): 92 | state_dict = torch.load(conf.weights, map_location='cpu') 93 | else: 94 | weights_url = "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar" 95 | state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') 96 | if 'model' in state_dict: 97 | state_dict = {k.replace('matcher.', ''): v for k, v in 
state_dict['model'].items() if 'matcher.' in k} 98 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 99 | self.load_state_dict(state_dict) 100 | 101 | def _forward(self, data): 102 | device = data['keypoints0'].device 103 | b_size = len(data['keypoints0']) 104 | image_size0 = (data['image_size0'] if 'image_size0' in data 105 | else data['image0'].shape) 106 | image_size1 = (data['image_size1'] if 'image_size1' in data 107 | else data['image1'].shape) 108 | 109 | pred = {} 110 | desc0, desc1 = data['descriptors0'], data['descriptors1'] 111 | kpts0, kpts1 = data['keypoints0'], data['keypoints1'] 112 | 113 | n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1] 114 | n_lines0, n_lines1 = data['lines0'].shape[1], data['lines1'].shape[1] 115 | if n_kpts0 == 0 or n_kpts1 == 0: 116 | # No detected keypoints nor lines 117 | pred['log_assignment'] = torch.zeros( 118 | b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device) 119 | pred['matches0'] = torch.full( 120 | (b_size, n_kpts0), -1, device=device, dtype=torch.int64) 121 | pred['matches1'] = torch.full( 122 | (b_size, n_kpts1), -1, device=device, dtype=torch.int64) 123 | pred['match_scores0'] = torch.zeros( 124 | (b_size, n_kpts0), device=device, dtype=torch.float32) 125 | pred['match_scores1'] = torch.zeros( 126 | (b_size, n_kpts1), device=device, dtype=torch.float32) 127 | pred['line_log_assignment'] = torch.zeros(b_size, n_lines0, n_lines1, 128 | dtype=torch.float, device=device) 129 | pred['line_matches0'] = torch.full((b_size, n_lines0), -1, 130 | device=device, dtype=torch.int64) 131 | pred['line_matches1'] = torch.full((b_size, n_lines1), -1, 132 | device=device, dtype=torch.int64) 133 | pred['line_match_scores0'] = torch.zeros( 134 | (b_size, n_lines0), device=device, dtype=torch.float32) 135 | pred['line_match_scores1'] = torch.zeros( 136 | (b_size, n_kpts1), device=device, dtype=torch.float32) 137 | return pred 138 | 139 | lines0 = data['lines0'].flatten(1, 2) 140 | lines1 = data['lines1'].flatten(1, 2) 141 | lines_junc_idx0 = data['lines_junc_idx0'].flatten(1, 2) # [b_size, num_lines * 2] 142 | lines_junc_idx1 = data['lines_junc_idx1'].flatten(1, 2) 143 | 144 | if self.conf.bottleneck_dim is not None: 145 | pred['down_descriptors0'] = desc0 = self.bottleneck_down(desc0) 146 | pred['down_descriptors1'] = desc1 = self.bottleneck_down(desc1) 147 | desc0 = self.bottleneck_up(desc0) 148 | desc1 = self.bottleneck_up(desc1) 149 | desc0 = nn.functional.normalize(desc0, p=2, dim=1) 150 | desc1 = nn.functional.normalize(desc1, p=2, dim=1) 151 | pred['bottleneck_descriptors0'] = desc0 152 | pred['bottleneck_descriptors1'] = desc1 153 | if self.conf.loss.nll_weight == 0: 154 | desc0 = desc0.detach() 155 | desc1 = desc1.detach() 156 | 157 | if self.conf.input_dim != self.conf.descriptor_dim: 158 | desc0 = self.input_proj(desc0) 159 | desc1 = self.input_proj(desc1) 160 | 161 | kpts0 = normalize_keypoints(kpts0, image_size0) 162 | kpts1 = normalize_keypoints(kpts1, image_size1) 163 | 164 | assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1) 165 | assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1) 166 | desc0 = desc0 + self.kenc(kpts0, data['keypoint_scores0']) 167 | desc1 = desc1 + self.kenc(kpts1, data['keypoint_scores1']) 168 | 169 | if n_lines0 != 0 and n_lines1 != 0: 170 | # Pre-compute the line encodings 171 | lines0 = normalize_keypoints(lines0, image_size0).reshape( 172 | b_size, n_lines0, 2, 2) 173 | lines1 = normalize_keypoints(lines1, image_size1).reshape( 174 | b_size, n_lines1, 2, 2) 175 | line_enc0 = 
self.lenc(lines0, data['line_scores0']) 176 | line_enc1 = self.lenc(lines1, data['line_scores1']) 177 | else: 178 | line_enc0 = torch.zeros( 179 | b_size, self.conf.descriptor_dim, n_lines0 * 2, 180 | dtype=torch.float, device=device) 181 | line_enc1 = torch.zeros( 182 | b_size, self.conf.descriptor_dim, n_lines1 * 2, 183 | dtype=torch.float, device=device) 184 | 185 | desc0, desc1 = self.gnn(desc0, desc1, line_enc0, line_enc1, 186 | lines_junc_idx0, lines_junc_idx1) 187 | 188 | # Match all points (KP and line junctions) 189 | mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) 190 | 191 | kp_scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) 192 | kp_scores = kp_scores / self.conf.descriptor_dim ** .5 193 | kp_scores = log_double_softmax(kp_scores, self.bin_score) 194 | m0, m1, mscores0, mscores1 = self._get_matches(kp_scores) 195 | pred['log_assignment'] = kp_scores 196 | pred['matches0'] = m0 197 | pred['matches1'] = m1 198 | pred['match_scores0'] = mscores0 199 | pred['match_scores1'] = mscores1 200 | 201 | # Match the lines 202 | if n_lines0 > 0 and n_lines1 > 0: 203 | (line_scores, m0_lines, m1_lines, mscores0_lines, 204 | mscores1_lines, raw_line_scores) = self._get_line_matches( 205 | desc0[:, :, :2 * n_lines0], desc1[:, :, :2 * n_lines1], 206 | lines_junc_idx0, lines_junc_idx1, self.final_line_proj) 207 | if self.conf.inter_supervision: 208 | for l in self.conf.inter_supervision: 209 | (line_scores_i, m0_lines_i, m1_lines_i, mscores0_lines_i, 210 | mscores1_lines_i) = self._get_line_matches( 211 | self.gnn.inter_layers[l][0][:, :, :2 * n_lines0], 212 | self.gnn.inter_layers[l][1][:, :, :2 * n_lines1], 213 | lines_junc_idx0, lines_junc_idx1, 214 | self.inter_line_proj[self.layer2idx[l]]) 215 | pred[f'line_{l}_log_assignment'] = line_scores_i 216 | pred[f'line_{l}_matches0'] = m0_lines_i 217 | pred[f'line_{l}_matches1'] = m1_lines_i 218 | pred[f'line_{l}_match_scores0'] = mscores0_lines_i 219 | pred[f'line_{l}_match_scores1'] = mscores1_lines_i 220 | else: 221 | line_scores = torch.zeros(b_size, n_lines0, n_lines1, 222 | dtype=torch.float, device=device) 223 | m0_lines = torch.full((b_size, n_lines0), -1, 224 | device=device, dtype=torch.int64) 225 | m1_lines = torch.full((b_size, n_lines1), -1, 226 | device=device, dtype=torch.int64) 227 | mscores0_lines = torch.zeros( 228 | (b_size, n_lines0), device=device, dtype=torch.float32) 229 | mscores1_lines = torch.zeros( 230 | (b_size, n_lines1), device=device, dtype=torch.float32) 231 | raw_line_scores = torch.zeros(b_size, n_lines0, n_lines1, 232 | dtype=torch.float, device=device) 233 | pred['line_log_assignment'] = line_scores 234 | pred['line_matches0'] = m0_lines 235 | pred['line_matches1'] = m1_lines 236 | pred['line_match_scores0'] = mscores0_lines 237 | pred['line_match_scores1'] = mscores1_lines 238 | pred['raw_line_scores'] = raw_line_scores 239 | 240 | return pred 241 | 242 | def _get_matches(self, scores_mat): 243 | max0 = scores_mat[:, :-1, :-1].max(2) 244 | max1 = scores_mat[:, :-1, :-1].max(1) 245 | m0, m1 = max0.indices, max1.indices 246 | mutual0 = arange_like(m0, 1)[None] == m1.gather(1, m0) 247 | mutual1 = arange_like(m1, 1)[None] == m0.gather(1, m1) 248 | zero = scores_mat.new_tensor(0) 249 | mscores0 = torch.where(mutual0, max0.values.exp(), zero) 250 | mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) 251 | valid0 = mutual0 & (mscores0 > self.conf.filter_threshold) 252 | valid1 = mutual1 & valid0.gather(1, m1) 253 | m0 = torch.where(valid0, m0, m0.new_tensor(-1)) 254 | m1 = 
torch.where(valid1, m1, m1.new_tensor(-1)) 255 | return m0, m1, mscores0, mscores1 256 | 257 | def _get_line_matches(self, ldesc0, ldesc1, lines_junc_idx0, 258 | lines_junc_idx1, final_proj): 259 | mldesc0 = final_proj(ldesc0) 260 | mldesc1 = final_proj(ldesc1) 261 | 262 | line_scores = torch.einsum('bdn,bdm->bnm', mldesc0, mldesc1) 263 | line_scores = line_scores / self.conf.descriptor_dim ** .5 264 | 265 | # Get the line representation from the junction descriptors 266 | n2_lines0 = lines_junc_idx0.shape[1] 267 | n2_lines1 = lines_junc_idx1.shape[1] 268 | line_scores = torch.gather( 269 | line_scores, dim=2, 270 | index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1)) 271 | line_scores = torch.gather( 272 | line_scores, dim=1, 273 | index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1)) 274 | line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, 275 | n2_lines1 // 2, 2)) 276 | 277 | # Match either in one direction or the other 278 | raw_line_scores = 0.5 * torch.maximum( 279 | line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1], 280 | line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0]) 281 | line_scores = log_double_softmax(raw_line_scores, self.line_bin_score) 282 | m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches( 283 | line_scores) 284 | return (line_scores, m0_lines, m1_lines, mscores0_lines, 285 | mscores1_lines, raw_line_scores) 286 | 287 | def loss(self, pred, data): 288 | raise NotImplementedError() 289 | 290 | def metrics(self, pred, data): 291 | raise NotImplementedError() 292 | 293 | 294 | def MLP(channels, do_bn=True): 295 | n = len(channels) 296 | layers = [] 297 | for i in range(1, n): 298 | layers.append( 299 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 300 | if i < (n - 1): 301 | if do_bn: 302 | layers.append(nn.BatchNorm1d(channels[i])) 303 | layers.append(nn.ReLU()) 304 | return nn.Sequential(*layers) 305 | 306 | 307 | def normalize_keypoints(kpts, shape_or_size): 308 | if isinstance(shape_or_size, (tuple, list)): 309 | # it's a shape 310 | h, w = shape_or_size[-2:] 311 | size = kpts.new_tensor([[w, h]]) 312 | else: 313 | # it's a size 314 | assert isinstance(shape_or_size, torch.Tensor) 315 | size = shape_or_size.to(kpts) 316 | c = size / 2 317 | f = size.max(1, keepdim=True).values * 0.7 # somehow we used 0.7 for SG 318 | return (kpts - c[:, None, :]) / f[:, None, :] 319 | 320 | 321 | class KeypointEncoder(nn.Module): 322 | def __init__(self, feature_dim, layers): 323 | super().__init__() 324 | self.encoder = MLP([3] + list(layers) + [feature_dim], do_bn=True) 325 | nn.init.constant_(self.encoder[-1].bias, 0.0) 326 | 327 | def forward(self, kpts, scores): 328 | inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] 329 | return self.encoder(torch.cat(inputs, dim=1)) 330 | 331 | 332 | class EndPtEncoder(nn.Module): 333 | def __init__(self, feature_dim, layers): 334 | super().__init__() 335 | self.encoder = MLP([5] + list(layers) + [feature_dim], do_bn=True) 336 | nn.init.constant_(self.encoder[-1].bias, 0.0) 337 | 338 | def forward(self, endpoints, scores): 339 | # endpoints should be [B, N, 2, 2] 340 | # output is [B, feature_dim, N * 2] 341 | b_size, n_pts, _, _ = endpoints.shape 342 | assert tuple(endpoints.shape[-2:]) == (2, 2) 343 | endpt_offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2) 344 | endpt_offset = torch.cat([endpt_offset, -endpt_offset], dim=2) 345 | endpt_offset = endpt_offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2) 346 | inputs = [endpoints.flatten(1, 
2).transpose(1, 2), 347 | endpt_offset, scores.repeat(1, 2).unsqueeze(1)] 348 | return self.encoder(torch.cat(inputs, dim=1)) 349 | 350 | 351 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 352 | def attention(query, key, value): 353 | dim = query.shape[1] 354 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5 355 | prob = torch.nn.functional.softmax(scores, dim=-1) 356 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob 357 | 358 | 359 | class MultiHeadedAttention(nn.Module): 360 | def __init__(self, h, d_model): 361 | super().__init__() 362 | assert d_model % h == 0 363 | self.dim = d_model // h 364 | self.h = h 365 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 366 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 367 | # self.prob = [] 368 | 369 | def forward(self, query, key, value): 370 | b = query.size(0) 371 | query, key, value = [l(x).view(b, self.dim, self.h, -1) 372 | for l, x in zip(self.proj, (query, key, value))] 373 | x, prob = attention(query, key, value) 374 | # self.prob.append(prob.mean(dim=1)) 375 | return self.merge(x.contiguous().view(b, self.dim * self.h, -1)) 376 | 377 | 378 | class AttentionalPropagation(nn.Module): 379 | def __init__(self, num_dim, num_heads, skip_init=False): 380 | super().__init__() 381 | self.attn = MultiHeadedAttention(num_heads, num_dim) 382 | self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True) 383 | nn.init.constant_(self.mlp[-1].bias, 0.0) 384 | if skip_init: 385 | self.register_parameter('scaling', nn.Parameter(torch.tensor(0.))) 386 | else: 387 | self.scaling = 1. 388 | 389 | def forward(self, x, source): 390 | message = self.attn(x, source, source) 391 | return self.mlp(torch.cat([x, message], dim=1)) * self.scaling 392 | 393 | 394 | class GNNLayer(nn.Module): 395 | def __init__(self, feature_dim, layer_type, skip_init): 396 | super().__init__() 397 | assert layer_type in ['cross', 'self'] 398 | self.type = layer_type 399 | self.update = AttentionalPropagation(feature_dim, 4, skip_init) 400 | 401 | def forward(self, desc0, desc1): 402 | if self.type == 'cross': 403 | src0, src1 = desc1, desc0 404 | elif self.type == 'self': 405 | src0, src1 = desc0, desc1 406 | else: 407 | raise ValueError("Unknown layer type: " + self.type) 408 | # self.update.attn.prob = [] 409 | delta0, delta1 = self.update(desc0, src0), self.update(desc1, src1) 410 | desc0, desc1 = (desc0 + delta0), (desc1 + delta1) 411 | return desc0, desc1 412 | 413 | 414 | class LineLayer(nn.Module): 415 | def __init__(self, feature_dim, line_attention=False): 416 | super().__init__() 417 | self.dim = feature_dim 418 | self.mlp = MLP([self.dim * 3, self.dim * 2, self.dim], do_bn=True) 419 | self.line_attention = line_attention 420 | if line_attention: 421 | self.proj_node = nn.Conv1d(self.dim, self.dim, kernel_size=1) 422 | self.proj_neigh = nn.Conv1d(2 * self.dim, self.dim, kernel_size=1) 423 | 424 | def get_endpoint_update(self, ldesc, line_enc, lines_junc_idx): 425 | # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2] 426 | # and lines_junc_idx [bs, n_lines * 2] 427 | # Create one message per line endpoint 428 | b_size = lines_junc_idx.shape[0] 429 | line_desc = torch.gather( 430 | ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)) 431 | message = torch.cat([ 432 | line_desc, 433 | line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(), 434 | line_enc], dim=1) 435 | return self.mlp(message) # [b_size, D, n_lines * 2] 436 | 437 | def get_endpoint_attention(self, ldesc, 
line_enc, lines_junc_idx): 438 | # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2] 439 | # and lines_junc_idx [bs, n_lines * 2] 440 | b_size = lines_junc_idx.shape[0] 441 | expanded_lines_junc_idx = lines_junc_idx[:, None].repeat(1, self.dim, 1) 442 | 443 | # Query: desc of the current node 444 | query = self.proj_node(ldesc) # [b_size, D, n_junc] 445 | query = torch.gather(query, 2, expanded_lines_junc_idx) 446 | # query is [b_size, D, n_lines * 2] 447 | 448 | # Key: combination of neighboring desc and line encodings 449 | line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx) 450 | key = self.proj_neigh(torch.cat([ 451 | line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(), 452 | line_enc], dim=1)) # [b_size, D, n_lines * 2] 453 | 454 | # Compute the attention weights with a custom softmax per junction 455 | prob = (query * key).sum(dim=1) / self.dim ** .5 # [b_size, n_lines * 2] 456 | prob = torch.exp(prob - prob.max()) 457 | denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_( 458 | dim=1, index=lines_junc_idx, 459 | src=prob, reduce='sum', include_self=False) # [b_size, n_junc] 460 | denom = torch.gather(denom, 1, lines_junc_idx) # [b_size, n_lines * 2] 461 | prob = prob / (denom + ETH_EPS) 462 | return prob # [b_size, n_lines * 2] 463 | 464 | def forward(self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, 465 | lines_junc_idx1): 466 | # Gather the endpoint updates 467 | lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0) 468 | lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1) 469 | 470 | update0, update1 = torch.zeros_like(ldesc0), torch.zeros_like(ldesc1) 471 | dim = ldesc0.shape[1] 472 | if self.line_attention: 473 | # Compute an attention for each neighbor and do a weighted average 474 | prob0 = self.get_endpoint_attention(ldesc0, line_enc0, 475 | lines_junc_idx0) 476 | lupdate0 = lupdate0 * prob0[:, None] 477 | update0 = update0.scatter_reduce_( 478 | dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1), 479 | src=lupdate0, reduce='sum', include_self=False) 480 | prob1 = self.get_endpoint_attention(ldesc1, line_enc1, 481 | lines_junc_idx1) 482 | lupdate1 = lupdate1 * prob1[:, None] 483 | update1 = update1.scatter_reduce_( 484 | dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1), 485 | src=lupdate1, reduce='sum', include_self=False) 486 | else: 487 | # Average the updates for each junction (requires torch > 1.12) 488 | update0 = update0.scatter_reduce_( 489 | dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1), 490 | src=lupdate0, reduce='mean', include_self=False) 491 | update1 = update1.scatter_reduce_( 492 | dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1), 493 | src=lupdate1, reduce='mean', include_self=False) 494 | 495 | # Update 496 | ldesc0 = ldesc0 + update0 497 | ldesc1 = ldesc1 + update1 498 | 499 | return ldesc0, ldesc1 500 | 501 | 502 | class AttentionalGNN(nn.Module): 503 | def __init__(self, feature_dim, layer_types, checkpointed=False, 504 | skip=False, inter_supervision=None, num_line_iterations=1, 505 | line_attention=False): 506 | super().__init__() 507 | self.checkpointed = checkpointed 508 | self.inter_supervision = inter_supervision 509 | self.num_line_iterations = num_line_iterations 510 | self.inter_layers = {} 511 | self.layers = nn.ModuleList([ 512 | GNNLayer(feature_dim, layer_type, skip) 513 | for layer_type in layer_types]) 514 | self.line_layers = nn.ModuleList( 515 | [LineLayer(feature_dim, line_attention) 516 | for _ in range(len(layer_types) 
// 2)]) 517 | 518 | def forward(self, desc0, desc1, line_enc0, line_enc1, 519 | lines_junc_idx0, lines_junc_idx1): 520 | for i, layer in enumerate(self.layers): 521 | if self.checkpointed: 522 | desc0, desc1 = torch.utils.checkpoint.checkpoint( 523 | layer, desc0, desc1, preserve_rng_state=False) 524 | else: 525 | desc0, desc1 = layer(desc0, desc1) 526 | if (layer.type == 'self' and lines_junc_idx0.shape[1] > 0 527 | and lines_junc_idx1.shape[1] > 0): 528 | # Add line self attention layers after every self layer 529 | for _ in range(self.num_line_iterations): 530 | if self.checkpointed: 531 | desc0, desc1 = torch.utils.checkpoint.checkpoint( 532 | self.line_layers[i // 2], desc0, desc1, line_enc0, 533 | line_enc1, lines_junc_idx0, lines_junc_idx1, 534 | preserve_rng_state=False) 535 | else: 536 | desc0, desc1 = self.line_layers[i // 2]( 537 | desc0, desc1, line_enc0, line_enc1, 538 | lines_junc_idx0, lines_junc_idx1) 539 | 540 | # Optionally store the line descriptor at intermediate layers 541 | if (self.inter_supervision is not None 542 | and (i // 2) in self.inter_supervision 543 | and layer.type == 'cross'): 544 | self.inter_layers[i // 2] = (desc0.clone(), desc1.clone()) 545 | return desc0, desc1 546 | 547 | 548 | def log_double_softmax(scores, bin_score): 549 | b, m, n = scores.shape 550 | bin_ = bin_score[None, None, None] 551 | scores0 = torch.cat([scores, bin_.expand(b, m, 1)], 2) 552 | scores1 = torch.cat([scores, bin_.expand(b, 1, n)], 1) 553 | scores0 = torch.nn.functional.log_softmax(scores0, 2) 554 | scores1 = torch.nn.functional.log_softmax(scores1, 1) 555 | scores = scores.new_full((b, m + 1, n + 1), 0) 556 | scores[:, :m, :n] = (scores0[:, :, :n] + scores1[:, :m, :]) / 2 557 | scores[:, :-1, -1] = scores0[:, :, -1] 558 | scores[:, -1, :-1] = scores1[:, -1, :] 559 | return scores 560 | 561 | 562 | def arange_like(x, dim): 563 | return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 564 | -------------------------------------------------------------------------------- /gluestick/models/superpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference model of SuperPoint, a feature detector and descriptor. 3 | 4 | Described in: 5 | SuperPoint: Self-Supervised Interest Point Detection and Description, 6 | Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018. 7 | 8 | Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork 9 | """ 10 | 11 | import torch 12 | from torch import nn 13 | 14 | from .. import GLUESTICK_ROOT 15 | from ..models.base_model import BaseModel 16 | 17 | 18 | def simple_nms(scores, radius): 19 | """Perform non maximum suppression on the heatmap using max-pooling. 20 | This method does not suppress contiguous points that have the same score. 21 | Args: 22 | scores: the score heatmap of size `(B, H, W)`. 23 | size: an interger scalar, the radius of the NMS window. 
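        Example (illustrative): `simple_nms(torch.rand(1, 480, 640), 4)` keeps only
        the local maxima within each 9x9 window and sets all other scores to zero.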
24 | """ 25 | 26 | def max_pool(x): 27 | return torch.nn.functional.max_pool2d( 28 | x, kernel_size=radius * 2 + 1, stride=1, padding=radius) 29 | 30 | zeros = torch.zeros_like(scores) 31 | max_mask = scores == max_pool(scores) 32 | for _ in range(2): 33 | supp_mask = max_pool(max_mask.float()) > 0 34 | supp_scores = torch.where(supp_mask, zeros, scores) 35 | new_max_mask = supp_scores == max_pool(supp_scores) 36 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 37 | return torch.where(max_mask, scores, zeros) 38 | 39 | 40 | def remove_borders(keypoints, scores, b, h, w): 41 | mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b)) 42 | mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b)) 43 | mask = mask_h & mask_w 44 | return keypoints[mask], scores[mask] 45 | 46 | 47 | def top_k_keypoints(keypoints, scores, k): 48 | if k >= len(keypoints): 49 | return keypoints, scores 50 | scores, indices = torch.topk(scores, k, dim=0, sorted=True) 51 | return keypoints[indices], scores 52 | 53 | 54 | def sample_descriptors(keypoints, descriptors, s): 55 | b, c, h, w = descriptors.shape 56 | keypoints = keypoints - s / 2 + 0.5 57 | keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], 58 | ).to(keypoints)[None] 59 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 60 | args = {'align_corners': True} if torch.__version__ >= '1.3' else {} 61 | descriptors = torch.nn.functional.grid_sample( 62 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 63 | descriptors = torch.nn.functional.normalize( 64 | descriptors.reshape(b, c, -1), p=2, dim=1) 65 | return descriptors 66 | 67 | 68 | class SuperPoint(BaseModel): 69 | default_conf = { 70 | 'has_detector': True, 71 | 'has_descriptor': True, 72 | 'descriptor_dim': 256, 73 | 74 | # Inference 75 | 'return_all': False, 76 | 'sparse_outputs': True, 77 | 'nms_radius': 4, 78 | 'detection_threshold': 0.005, 79 | 'max_num_keypoints': -1, 80 | 'force_num_keypoints': False, 81 | 'remove_borders': 4, 82 | } 83 | required_data_keys = ['image'] 84 | 85 | def _init(self, conf): 86 | self.relu = nn.ReLU(inplace=True) 87 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 88 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 89 | 90 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 91 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 92 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 93 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 94 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 95 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 96 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 97 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 98 | 99 | if conf.has_detector: 100 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 101 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 102 | 103 | if conf.has_descriptor: 104 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 105 | self.convDb = nn.Conv2d( 106 | c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0) 107 | 108 | path = GLUESTICK_ROOT / 'resources' / 'weights' / 'superpoint_v1.pth' 109 | if path.exists(): 110 | weights = torch.load(str(path), map_location='cpu') 111 | else: 112 | weights_url = "https://github.com/cvg/GlueStick/raw/main/resources/weights/superpoint_v1.pth" 113 | weights = torch.hub.load_state_dict_from_url(weights_url, 
map_location='cpu') 114 | self.load_state_dict(weights, strict=False) 115 | 116 | def _forward(self, data): 117 | image = data['image'] 118 | if image.shape[1] == 3: # RGB 119 | scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) 120 | image = (image * scale).sum(1, keepdim=True) 121 | 122 | # Shared Encoder 123 | x = self.relu(self.conv1a(image)) 124 | x = self.relu(self.conv1b(x)) 125 | x = self.pool(x) 126 | x = self.relu(self.conv2a(x)) 127 | x = self.relu(self.conv2b(x)) 128 | x = self.pool(x) 129 | x = self.relu(self.conv3a(x)) 130 | x = self.relu(self.conv3b(x)) 131 | x = self.pool(x) 132 | x = self.relu(self.conv4a(x)) 133 | x = self.relu(self.conv4b(x)) 134 | 135 | pred = {} 136 | if self.conf.has_detector and self.conf.max_num_keypoints != 0: 137 | # Compute the dense keypoint scores 138 | cPa = self.relu(self.convPa(x)) 139 | scores = self.convPb(cPa) 140 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 141 | b, c, h, w = scores.shape 142 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 143 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) 144 | pred['keypoint_scores'] = dense_scores = scores 145 | if self.conf.has_descriptor: 146 | # Compute the dense descriptors 147 | cDa = self.relu(self.convDa(x)) 148 | all_desc = self.convDb(cDa) 149 | all_desc = torch.nn.functional.normalize(all_desc, p=2, dim=1) 150 | pred['descriptors'] = all_desc 151 | 152 | if self.conf.max_num_keypoints == 0: # Predict dense descriptors only 153 | b_size = len(image) 154 | device = image.device 155 | return { 156 | 'keypoints': torch.empty(b_size, 0, 2, device=device), 157 | 'keypoint_scores': torch.empty(b_size, 0, device=device), 158 | 'descriptors': torch.empty(b_size, self.conf.descriptor_dim, 0, device=device), 159 | 'all_descriptors': all_desc 160 | } 161 | 162 | if self.conf.sparse_outputs: 163 | assert self.conf.has_detector and self.conf.has_descriptor 164 | 165 | scores = simple_nms(scores, self.conf.nms_radius) 166 | 167 | # Extract keypoints 168 | keypoints = [ 169 | torch.nonzero(s > self.conf.detection_threshold) 170 | for s in scores] 171 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 172 | 173 | # Discard keypoints near the image borders 174 | keypoints, scores = list(zip(*[ 175 | remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8) 176 | for k, s in zip(keypoints, scores)])) 177 | 178 | # Keep the k keypoints with highest score 179 | if self.conf.max_num_keypoints > 0: 180 | keypoints, scores = list(zip(*[ 181 | top_k_keypoints(k, s, self.conf.max_num_keypoints) 182 | for k, s in zip(keypoints, scores)])) 183 | 184 | # Convert (h, w) to (x, y) 185 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 186 | 187 | if self.conf.force_num_keypoints: 188 | _, _, h, w = data['image'].shape 189 | assert self.conf.max_num_keypoints > 0 190 | scores = list(scores) 191 | for i in range(len(keypoints)): 192 | k, s = keypoints[i], scores[i] 193 | missing = self.conf.max_num_keypoints - len(k) 194 | if missing > 0: 195 | new_k = torch.rand(missing, 2).to(k) 196 | new_k = new_k * k.new_tensor([[w - 1, h - 1]]) 197 | new_s = torch.zeros(missing).to(s) 198 | keypoints[i] = torch.cat([k, new_k], 0) 199 | scores[i] = torch.cat([s, new_s], 0) 200 | 201 | # Extract descriptors 202 | desc = [sample_descriptors(k[None], d[None], 8)[0] 203 | for k, d in zip(keypoints, all_desc)] 204 | 205 | if (len(keypoints) == 1) or self.conf.force_num_keypoints: 206 | keypoints = torch.stack(keypoints, 0) 207 | scores = torch.stack(scores, 
0) 208 | desc = torch.stack(desc, 0) 209 | 210 | pred = { 211 | 'keypoints': keypoints, 212 | 'keypoint_scores': scores, 213 | 'descriptors': desc, 214 | } 215 | 216 | if self.conf.return_all: 217 | pred['all_descriptors'] = all_desc 218 | pred['dense_score'] = dense_scores 219 | else: 220 | del all_desc 221 | torch.cuda.empty_cache() 222 | 223 | return pred 224 | 225 | def loss(self, pred, data): 226 | raise NotImplementedError 227 | 228 | def metrics(self, pred, data): 229 | raise NotImplementedError 230 | -------------------------------------------------------------------------------- /gluestick/models/two_view_pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | A two-view sparse feature matching pipeline. 3 | 4 | This model contains sub-models for each step: 5 | feature extraction, feature matching, outlier filtering, pose estimation. 6 | Each step is optional, and the features or matches can be provided as input. 7 | Default: SuperPoint with nearest neighbor matching. 8 | 9 | Convention for the matches: m0[i] is the index of the keypoint in image 1 10 | that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched. 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from .. import get_model 17 | from .base_model import BaseModel 18 | 19 | 20 | def keep_quadrant_kp_subset(keypoints, scores, descs, h, w): 21 | """Keep only keypoints in one of the four quadrant of the image.""" 22 | h2, w2 = h // 2, w // 2 23 | w_x = np.random.choice([0, w2]) 24 | w_y = np.random.choice([0, h2]) 25 | valid_mask = ((keypoints[..., 0] >= w_x) 26 | & (keypoints[..., 0] < w_x + w2) 27 | & (keypoints[..., 1] >= w_y) 28 | & (keypoints[..., 1] < w_y + h2)) 29 | keypoints = keypoints[valid_mask][None] 30 | scores = scores[valid_mask][None] 31 | descs = descs.permute(0, 2, 1)[valid_mask].t()[None] 32 | return keypoints, scores, descs 33 | 34 | 35 | def keep_random_kp_subset(keypoints, scores, descs, num_selected): 36 | """Keep a random subset of keypoints.""" 37 | num_kp = keypoints.shape[1] 38 | selected_kp = torch.randperm(num_kp)[:num_selected] 39 | keypoints = keypoints[:, selected_kp] 40 | scores = scores[:, selected_kp] 41 | descs = descs[:, :, selected_kp] 42 | return keypoints, scores, descs 43 | 44 | 45 | def keep_best_kp_subset(keypoints, scores, descs, num_selected): 46 | """Keep the top num_selected best keypoints.""" 47 | sorted_indices = torch.sort(scores, dim=1)[1] 48 | selected_kp = sorted_indices[:, -num_selected:] 49 | keypoints = torch.gather(keypoints, 1, 50 | selected_kp[:, :, None].repeat(1, 1, 2)) 51 | scores = torch.gather(scores, 1, selected_kp) 52 | descs = torch.gather(descs, 2, 53 | selected_kp[:, None].repeat(1, descs.shape[1], 1)) 54 | return keypoints, scores, descs 55 | 56 | 57 | class TwoViewPipeline(BaseModel): 58 | default_conf = { 59 | 'extractor': { 60 | 'name': 'superpoint', 61 | 'trainable': False, 62 | }, 63 | 'use_lines': False, 64 | 'use_points': True, 65 | 'randomize_num_kp': False, 66 | 'detector': {'name': None}, 67 | 'descriptor': {'name': None}, 68 | 'matcher': {'name': 'nearest_neighbor_matcher'}, 69 | 'filter': {'name': None}, 70 | 'solver': {'name': None}, 71 | 'ground_truth': { 72 | 'from_pose_depth': False, 73 | 'from_homography': False, 74 | 'th_positive': 3, 75 | 'th_negative': 5, 76 | 'reward_positive': 1, 77 | 'reward_negative': -0.25, 78 | 'is_likelihood_soft': True, 79 | 'p_random_occluders': 0, 80 | 'n_line_sampled_pts': 50, 81 | 'line_perp_dist_th': 5, 82 | 'overlap_th': 
0.2, 83 | 'min_visibility_th': 0.5 84 | }, 85 | } 86 | required_data_keys = ['image0', 'image1'] 87 | strict_conf = False # need to pass new confs to children models 88 | components = [ 89 | 'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver'] 90 | 91 | def _init(self, conf): 92 | if conf.extractor.name: 93 | self.extractor = get_model(conf.extractor.name)(conf.extractor) 94 | else: 95 | if self.conf.detector.name: 96 | self.detector = get_model(conf.detector.name)(conf.detector) 97 | else: 98 | self.required_data_keys += ['keypoints0', 'keypoints1'] 99 | if self.conf.descriptor.name: 100 | self.descriptor = get_model(conf.descriptor.name)( 101 | conf.descriptor) 102 | else: 103 | self.required_data_keys += ['descriptors0', 'descriptors1'] 104 | 105 | if conf.matcher.name: 106 | self.matcher = get_model(conf.matcher.name)(conf.matcher) 107 | else: 108 | self.required_data_keys += ['matches0'] 109 | 110 | if conf.filter.name: 111 | self.filter = get_model(conf.filter.name)(conf.filter) 112 | 113 | if conf.solver.name: 114 | self.solver = get_model(conf.solver.name)(conf.solver) 115 | 116 | def _forward(self, data): 117 | 118 | def process_siamese(data, i): 119 | data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i} 120 | if self.conf.extractor.name: 121 | pred_i = self.extractor(data_i) 122 | else: 123 | pred_i = {} 124 | if self.conf.detector.name: 125 | pred_i = self.detector(data_i) 126 | else: 127 | for k in ['keypoints', 'keypoint_scores', 'descriptors', 128 | 'lines', 'line_scores', 'line_descriptors', 129 | 'valid_lines']: 130 | if k in data_i: 131 | pred_i[k] = data_i[k] 132 | if self.conf.descriptor.name: 133 | pred_i = { 134 | **pred_i, **self.descriptor({**data_i, **pred_i})} 135 | return pred_i 136 | 137 | pred0 = process_siamese(data, '0') 138 | pred1 = process_siamese(data, '1') 139 | 140 | pred = {**{k + '0': v for k, v in pred0.items()}, 141 | **{k + '1': v for k, v in pred1.items()}} 142 | 143 | if self.conf.matcher.name: 144 | pred = {**pred, **self.matcher({**data, **pred})} 145 | 146 | if self.conf.filter.name: 147 | pred = {**pred, **self.filter({**data, **pred})} 148 | 149 | if self.conf.solver.name: 150 | pred = {**pred, **self.solver({**data, **pred})} 151 | 152 | return pred 153 | 154 | def loss(self, pred, data): 155 | losses = {} 156 | total = 0 157 | for k in self.components: 158 | if self.conf[k].name: 159 | try: 160 | losses_ = getattr(self, k).loss(pred, {**pred, **data}) 161 | except NotImplementedError: 162 | continue 163 | losses = {**losses, **losses_} 164 | total = losses_['total'] + total 165 | return {**losses, 'total': total} 166 | 167 | def metrics(self, pred, data): 168 | metrics = {} 169 | for k in self.components: 170 | if self.conf[k].name: 171 | try: 172 | metrics_ = getattr(self, k).metrics(pred, {**pred, **data}) 173 | except NotImplementedError: 174 | continue 175 | metrics = {**metrics, **metrics_} 176 | return metrics 177 | -------------------------------------------------------------------------------- /gluestick/models/wireframe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytlsd import lsd 4 | from sklearn.cluster import DBSCAN 5 | 6 | from .base_model import BaseModel 7 | from .superpoint import SuperPoint, sample_descriptors 8 | from ..geometry import warp_lines_torch 9 | 10 | 11 | def lines_to_wireframe(lines, line_scores, all_descs, conf): 12 | """ Given a set of lines, their score and dense descriptors, 13 | merge close-by 
endpoints and compute a wireframe defined by 14 | its junctions and connectivity. 15 | Returns: 16 | junctions: list of [num_junc, 2] tensors listing all wireframe junctions 17 | junc_scores: list of [num_junc] tensors with the junction score 18 | junc_descs: list of [dim, num_junc] tensors with the junction descriptors 19 | connectivity: list of [num_junc, num_junc] bool arrays with True when 2 junctions are connected 20 | new_lines: the new set of [b_size, num_lines, 2, 2] lines 21 | lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the junctions of each endpoint 22 | num_true_junctions: a list of the number of valid junctions for each image in the batch, 23 | i.e. before filling with random ones 24 | """ 25 | b_size, _, _, _ = all_descs.shape 26 | device = lines.device 27 | endpoints = lines.reshape(b_size, -1, 2) 28 | 29 | (junctions, junc_scores, junc_descs, connectivity, new_lines, 30 | lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], [] 31 | for bs in range(b_size): 32 | # Cluster the junctions that are close-by 33 | db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit( 34 | endpoints[bs].cpu().numpy()) 35 | clusters = db.labels_ 36 | n_clusters = len(set(clusters)) 37 | num_true_junctions.append(n_clusters) 38 | 39 | # Compute the average junction and score for each cluster 40 | clusters = torch.tensor(clusters, dtype=torch.long, 41 | device=device) 42 | new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, 43 | device=device) 44 | new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2), 45 | endpoints[bs], reduce='mean', 46 | include_self=False) 47 | junctions.append(new_junc) 48 | new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device) 49 | new_scores.scatter_reduce_( 50 | 0, clusters, torch.repeat_interleave(line_scores[bs], 2), 51 | reduce='mean', include_self=False) 52 | junc_scores.append(new_scores) 53 | 54 | # Compute the new lines 55 | new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2)) 56 | lines_junc_idx.append(clusters.reshape(-1, 2)) 57 | 58 | # Compute the junction connectivity 59 | junc_connect = torch.eye(n_clusters, dtype=torch.bool, 60 | device=device) 61 | pairs = clusters.reshape(-1, 2) # these pairs are connected by a line 62 | junc_connect[pairs[:, 0], pairs[:, 1]] = True 63 | junc_connect[pairs[:, 1], pairs[:, 0]] = True 64 | connectivity.append(junc_connect) 65 | 66 | # Interpolate the new junction descriptors 67 | junc_descs.append(sample_descriptors( 68 | junctions[-1][None], all_descs[bs:(bs + 1)], 8)[0]) 69 | 70 | new_lines = torch.stack(new_lines, dim=0) 71 | lines_junc_idx = torch.stack(lines_junc_idx, dim=0) 72 | return (junctions, junc_scores, junc_descs, connectivity, 73 | new_lines, lines_junc_idx, num_true_junctions) 74 | 75 | 76 | class SPWireframeDescriptor(BaseModel): 77 | default_conf = { 78 | 'sp_params': { 79 | 'has_detector': True, 80 | 'has_descriptor': True, 81 | 'descriptor_dim': 256, 82 | 'trainable': False, 83 | 84 | # Inference 85 | 'return_all': True, 86 | 'sparse_outputs': True, 87 | 'nms_radius': 4, 88 | 'detection_threshold': 0.005, 89 | 'max_num_keypoints': 1000, 90 | 'force_num_keypoints': True, 91 | 'remove_borders': 4, 92 | }, 93 | 'wireframe_params': { 94 | 'merge_points': True, 95 | 'merge_line_endpoints': True, 96 | 'nms_radius': 3, 97 | 'max_n_junctions': 500, 98 | }, 99 | 'max_n_lines': 250, 100 | 'min_length': 15, 101 | } 102 | required_data_keys = ['image'] 103 | 104 | def _init(self, conf): 105 | self.conf = conf 106 | self.sp = SuperPoint(conf.sp_params) 107 | 
108 | def detect_lsd_lines(self, x, max_n_lines=None): 109 | if max_n_lines is None: 110 | max_n_lines = self.conf.max_n_lines 111 | lines, scores, valid_lines = [], [], [] 112 | for b in range(len(x)): 113 | # For each image in the batch 114 | img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8) 115 | if max_n_lines is None: 116 | b_segs = lsd(img) 117 | else: 118 | for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]: 119 | b_segs = lsd(img, scale=s) 120 | if len(b_segs) >= max_n_lines: 121 | break 122 | 123 | segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1) 124 | # Remove short lines 125 | b_segs = b_segs[segs_length >= self.conf.min_length] 126 | segs_length = segs_length[segs_length >= self.conf.min_length] 127 | b_scores = b_segs[:, -1] * np.sqrt(segs_length) 128 | # Take the most relevant segments, i.e. those with the highest scores 129 | indices = np.argsort(-b_scores) 130 | if max_n_lines is not None: 131 | indices = indices[:max_n_lines] 132 | lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2))) 133 | scores.append(torch.from_numpy(b_scores[indices])) 134 | valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool)) 135 | 136 | lines = torch.stack(lines).to(x) 137 | scores = torch.stack(scores).to(x) 138 | valid_lines = torch.stack(valid_lines).to(x.device) 139 | return lines, scores, valid_lines 140 | 141 | def _forward(self, data): 142 | b_size, _, h, w = data['image'].shape 143 | device = data['image'].device 144 | 145 | if not self.conf.sp_params.force_num_keypoints: 146 | assert b_size == 1, "Only batch size of 1 accepted for non padded inputs" 147 | 148 | # Line detection 149 | if 'lines' not in data or 'line_scores' not in data: 150 | if 'original_img' in data: 151 | # Detect more lines, because when projecting them to the image most of them will be discarded 152 | lines, line_scores, valid_lines = self.detect_lsd_lines( 153 | data['original_img'], self.conf.max_n_lines * 3) 154 | # Apply the same transformation that is applied in homography_adaptation 155 | lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:]) 156 | valid_lines = valid_lines & valid_lines2 157 | lines[~valid_lines] = -1 158 | line_scores[~valid_lines] = 0 159 | # Re-sort the line segments to pick the ones that are inside the image and have a higher score 160 | sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True) 161 | line_scores = sorted_scores[:, :self.conf.max_n_lines] 162 | sorting_indices = sorting_indices[:, :self.conf.max_n_lines] 163 | lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1) 164 | valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1) 165 | else: 166 | lines, line_scores, valid_lines = self.detect_lsd_lines(data['image']) 167 | 168 | else: 169 | lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines'] 170 | if line_scores.shape[-1] != 0: 171 | line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None]) 172 | 173 | # SuperPoint prediction 174 | pred = self.sp(data) 175 | 176 | # Remove keypoints that are too close to line endpoints 177 | if self.conf.wireframe_params.merge_points: 178 | kp = pred['keypoints'] 179 | line_endpts = lines.reshape(b_size, -1, 2) 180 | dist_pt_lines = torch.norm( 181 | kp[:, :, None] - line_endpts[:, None], dim=-1) 182 | # For each keypoint, mark it as valid or to remove 183 | pts_to_remove = torch.any( 184 | dist_pt_lines < self.conf.sp_params.nms_radius, dim=2) 185 | # Simply remove them (we assume 
batch_size = 1 here) 186 | assert len(kp) == 1 187 | pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None] 188 | pred['keypoint_scores'] = pred['keypoint_scores'][0][~pts_to_remove[0]][None] 189 | pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None] 190 | 191 | # Connect the lines together to form a wireframe 192 | orig_lines = lines.clone() 193 | if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0: 194 | # Merge first close-by endpoints to connect lines 195 | (line_points, line_pts_scores, line_descs, line_association, 196 | lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe( 197 | lines, line_scores, pred['all_descriptors'], 198 | conf=self.conf.wireframe_params) 199 | 200 | # Add the keypoints to the junctions and fill the rest with random keypoints 201 | (all_points, all_scores, all_descs, 202 | pl_associativity) = [], [], [], [] 203 | for bs in range(b_size): 204 | all_points.append(torch.cat( 205 | [line_points[bs], pred['keypoints'][bs]], dim=0)) 206 | all_scores.append(torch.cat( 207 | [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0)) 208 | all_descs.append(torch.cat( 209 | [line_descs[bs], pred['descriptors'][bs]], dim=1)) 210 | 211 | associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device) 212 | associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \ 213 | line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]] 214 | pl_associativity.append(associativity) 215 | 216 | all_points = torch.stack(all_points, dim=0) 217 | all_scores = torch.stack(all_scores, dim=0) 218 | all_descs = torch.stack(all_descs, dim=0) 219 | pl_associativity = torch.stack(pl_associativity, dim=0) 220 | else: 221 | # Lines are independent 222 | all_points = torch.cat([lines.reshape(b_size, -1, 2), 223 | pred['keypoints']], dim=1) 224 | n_pts = all_points.shape[1] 225 | num_lines = lines.shape[1] 226 | num_true_junctions = [num_lines * 2] * b_size 227 | all_scores = torch.cat([ 228 | torch.repeat_interleave(line_scores, 2, dim=1), 229 | pred['keypoint_scores']], dim=1) 230 | pred['line_descriptors'] = self.endpoints_pooling( 231 | lines, pred['all_descriptors'], (h, w)) 232 | all_descs = torch.cat([ 233 | pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1), 234 | pred['descriptors']], dim=2) 235 | pl_associativity = torch.eye( 236 | n_pts, dtype=torch.bool, 237 | device=device)[None].repeat(b_size, 1, 1) 238 | lines_junc_idx = torch.arange( 239 | num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1) 240 | 241 | del pred['all_descriptors'] # Remove dense descriptors to save memory 242 | torch.cuda.empty_cache() 243 | 244 | return {'keypoints': all_points, 245 | 'keypoint_scores': all_scores, 246 | 'descriptors': all_descs, 247 | 'pl_associativity': pl_associativity, 248 | 'num_junctions': torch.tensor(num_true_junctions), 249 | 'lines': lines, 250 | 'orig_lines': orig_lines, 251 | 'lines_junc_idx': lines_junc_idx, 252 | 'line_scores': line_scores, 253 | 'valid_lines': valid_lines} 254 | 255 | @staticmethod 256 | def endpoints_pooling(segs, all_descriptors, img_shape): 257 | assert segs.ndim == 4 and segs.shape[-2:] == (2, 2) 258 | filter_shape = all_descriptors.shape[-2:] 259 | scale_x = filter_shape[1] / img_shape[1] 260 | scale_y = filter_shape[0] / img_shape[0] 261 | 262 | scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long() 263 | scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1) 264 
| scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1) 265 | line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])] 266 | for b, b_segs in enumerate(scaled_segs)] 267 | line_descriptors = torch.cat(line_descriptors) 268 | return line_descriptors # Shape (1, 256, 308, 2) 269 | 270 | def loss(self, pred, data): 271 | raise NotImplementedError 272 | 273 | def metrics(self, pred, data): 274 | return {} 275 | -------------------------------------------------------------------------------- /gluestick/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os.path import join 4 | 5 | import cv2 6 | import torch 7 | from matplotlib import pyplot as plt 8 | 9 | from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT 10 | from .drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches 11 | from .models.two_view_pipeline import TwoViewPipeline 12 | 13 | 14 | def main(): 15 | # Parse input parameters 16 | parser = argparse.ArgumentParser( 17 | prog='GlueStick Demo', 18 | description='Demo app to show the point and line matches obtained by GlueStick') 19 | parser.add_argument('-img1', default=join('resources' + os.path.sep + 'img1.jpg')) 20 | parser.add_argument('-img2', default=join('resources' + os.path.sep + 'img2.jpg')) 21 | parser.add_argument('--max_pts', type=int, default=1000) 22 | parser.add_argument('--max_lines', type=int, default=300) 23 | parser.add_argument('--skip-imshow', default=False, action='store_true') 24 | args = parser.parse_args() 25 | 26 | # Evaluation config 27 | conf = { 28 | 'name': 'two_view_pipeline', 29 | 'use_lines': True, 30 | 'extractor': { 31 | 'name': 'wireframe', 32 | 'sp_params': { 33 | 'force_num_keypoints': False, 34 | 'max_num_keypoints': args.max_pts, 35 | }, 36 | 'wireframe_params': { 37 | 'merge_points': True, 38 | 'merge_line_endpoints': True, 39 | }, 40 | 'max_n_lines': args.max_lines, 41 | }, 42 | 'matcher': { 43 | 'name': 'gluestick', 44 | 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'), 45 | 'trainable': False, 46 | }, 47 | 'ground_truth': { 48 | 'from_pose_depth': False, 49 | } 50 | } 51 | 52 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 53 | 54 | pipeline_model = TwoViewPipeline(conf).to(device).eval() 55 | 56 | gray0 = cv2.imread(args.img1, 0) 57 | gray1 = cv2.imread(args.img2, 0) 58 | 59 | torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1) 60 | torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None] 61 | x = {'image0': torch_gray0, 'image1': torch_gray1} 62 | pred = pipeline_model(x) 63 | 64 | pred = batch_to_np(pred) 65 | kp0, kp1 = pred["keypoints0"], pred["keypoints1"] 66 | m0 = pred["matches0"] 67 | 68 | line_seg0, line_seg1 = pred["lines0"], pred["lines1"] 69 | line_matches = pred["line_matches0"] 70 | 71 | valid_matches = m0 != -1 72 | match_indices = m0[valid_matches] 73 | matched_kps0 = kp0[valid_matches] 74 | matched_kps1 = kp1[match_indices] 75 | 76 | valid_matches = line_matches != -1 77 | match_indices = line_matches[valid_matches] 78 | matched_lines0 = line_seg0[valid_matches] 79 | matched_lines1 = line_seg1[match_indices] 80 | 81 | # Plot the matches 82 | img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR) 83 | plot_images([img0, img1], ['Image 1 - detected lines', 'Image 2 - 
detected lines'], dpi=200, pad=2.0) 84 | plot_lines([line_seg0, line_seg1], ps=4, lw=2) 85 | plt.gcf().canvas.manager.set_window_title('Detected Lines') 86 | plt.savefig('detected_lines.png') 87 | 88 | plot_images([img0, img1], ['Image 1 - detected points', 'Image 2 - detected points'], dpi=200, pad=2.0) 89 | plot_keypoints([kp0, kp1], colors='c') 90 | plt.gcf().canvas.manager.set_window_title('Detected Points') 91 | plt.savefig('detected_points.png') 92 | 93 | plot_images([img0, img1], ['Image 1 - line matches', 'Image 2 - line matches'], dpi=200, pad=2.0) 94 | plot_color_line_matches([matched_lines0, matched_lines1], lw=2) 95 | plt.gcf().canvas.manager.set_window_title('Line Matches') 96 | plt.savefig('line_matches.png') 97 | 98 | plot_images([img0, img1], ['Image 1 - point matches', 'Image 2 - point matches'], dpi=200, pad=2.0) 99 | plot_matches(matched_kps0, matched_kps1, 'green', lw=1, ps=0) 100 | plt.gcf().canvas.manager.set_window_title('Point Matches') 101 | plt.savefig('point_matches.png') 102 | if not args.skip_imshow: 103 | plt.show() 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "gluestick" 3 | description = "GlueStick: Robust Image Matching by Sticking Points and Lines Together" 4 | version = "0.0.0" 5 | authors = [ 6 | { name = "Rémi Pautrat" }, 7 | { name = "Iago Suárez" }, 8 | ] 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | 12 | 13 | license = { file = "LICENSE" } 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "Operating System :: OS Independent", 17 | ] 18 | 19 | urls = { Repository = "https://github.com/cvg/GlueStick/" } 20 | dynamic = ["dependencies"] 21 | 22 | [build-system] 23 | build-backend = "setuptools.build_meta" 24 | requires = ["setuptools"] 25 | 26 | [tool.setuptools.packages.find] 27 | include = ["gluestick", "gluestick.*"] 28 | 29 | [tool.setuptools.dynamic] 30 | dependencies = { file = ["requirements.txt"] } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | build 2 | numpy 3 | matplotlib 4 | scipy 5 | scikit_learn 6 | seaborn 7 | omegaconf==2.2.* 8 | opencv-python==4.7.0.* 9 | torch>=1.12 10 | torchvision>=0.13 11 | setuptools 12 | tqdm 13 | pytlsd@git+https://github.com/iago-suarez/pytlsd.git@4180ab8 14 | -------------------------------------------------------------------------------- /resources/demo_seq1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/GlueStick/7d816730ef939caa1c61e2564eceda77304874fa/resources/demo_seq1.gif -------------------------------------------------------------------------------- /resources/img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/GlueStick/7d816730ef939caa1c61e2564eceda77304874fa/resources/img1.jpg -------------------------------------------------------------------------------- /resources/img2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/GlueStick/7d816730ef939caa1c61e2564eceda77304874fa/resources/img2.jpg -------------------------------------------------------------------------------- 
/resources/weights/superpoint_v1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/GlueStick/7d816730ef939caa1c61e2564eceda77304874fa/resources/weights/superpoint_v1.pth --------------------------------------------------------------------------------
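The detector head in `gluestick/models/superpoint.py` outputs 65 channels per 8x8 image cell (64 candidate positions plus a "no keypoint" dustbin); the softmax/permute/reshape sequence in `_forward` unfolds these per-cell scores back into a full-resolution score map. A minimal sketch of that decoding, with made-up tensor sizes:

```python
import torch

# Assumed toy size: a 32x40 image gives a (b, 65, 4, 5) detector output,
# i.e. one 65-channel vector per 8x8 cell (the 65th channel is the dustbin).
b, hc, wc = 1, 4, 5
raw = torch.randn(b, 65, hc, wc)

scores = torch.nn.functional.softmax(raw, 1)[:, :-1]            # drop the dustbin -> (b, 64, hc, wc)
scores = scores.permute(0, 2, 3, 1).reshape(b, hc, wc, 8, 8)     # 64 scores -> an 8x8 block per cell
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, hc * 8, wc * 8)
print(scores.shape)                                              # torch.Size([1, 32, 40])
```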
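`lines_to_wireframe` in `gluestick/models/wireframe.py` merges close-by line endpoints into shared junctions by clustering them with DBSCAN and averaging each cluster. The standalone sketch below reproduces just that step on toy coordinates (two segments sharing a corner, `eps` set to the default `wireframe_params.nms_radius` of 3); the batch-wise model code does the same with `scatter_reduce_`:

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Toy endpoints of two line segments that share a corner (within eps = 3):
# segment 0: (0, 0) -> (10, 0), segment 1: (10, 1) -> (10, 20).
endpoints = np.array([[0., 0.], [10., 0.], [10., 1.], [10., 20.]])

clusters = DBSCAN(eps=3, min_samples=1).fit(endpoints).labels_   # e.g. [0, 1, 1, 2]
n_junc = len(set(clusters))

# Average the endpoints of each cluster to obtain the merged junctions.
junctions = np.stack([endpoints[clusters == c].mean(0) for c in range(n_junc)])

lines_junc_idx = clusters.reshape(-1, 2)          # junction index of each line endpoint
connectivity = np.eye(n_junc, dtype=bool)         # True where two junctions share a line
connectivity[lines_junc_idx[:, 0], lines_junc_idx[:, 1]] = True
connectivity[lines_junc_idx[:, 1], lines_junc_idx[:, 0]] = True

print(junctions)        # e.g. [[ 0.   0. ] [10.   0.5] [10.  20. ]]
print(lines_junc_idx)   # e.g. [[0 1] [1 2]]
```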
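As documented in `gluestick/models/two_view_pipeline.py` and used in `gluestick/run.py`, `matches0[i]` holds the index of the keypoint in image 1 matched to keypoint `i` in image 0, with -1 marking unmatched keypoints; line matches follow the same convention. A small sketch with hypothetical arrays standing in for the real `batch_to_np` outputs:

```python
import numpy as np

# Hypothetical stand-ins for pred["keypoints0"], pred["keypoints1"] and
# pred["matches0"]; shapes (N0, 2), (N1, 2) and (N0,).
kp0 = np.array([[10.0, 12.0], [50.0, 40.0], [80.0, 9.0]])
kp1 = np.array([[11.0, 13.0], [79.5, 10.0]])
m0 = np.array([0, -1, 1])

valid = m0 != -1                  # keypoints of image 0 that found a match
matched_kps0 = kp0[valid]         # (M, 2) points in image 0
matched_kps1 = kp1[m0[valid]]     # (M, 2) corresponding points in image 1

for (x0, y0), (x1, y1) in zip(matched_kps0, matched_kps1):
    print(f"({x0:.1f}, {y0:.1f}) -> ({x1:.1f}, {y1:.1f})")
# (10.0, 12.0) -> (11.0, 13.0)
# (80.0, 9.0) -> (79.5, 10.0)
```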