├── .gitignore
├── README.md
├── __pycache__
│   └── liftfeat_wrapper.cpython-38.pyc
├── assert
│   ├── achitecture.png
│   ├── demo_liftfeat.gif
│   ├── demo_sp.gif
│   ├── keypoints_liftfeat.gif
│   ├── query.jpg
│   ├── ref.jpg
│   └── trajectory_liftfeat.gif
├── data
│   └── megadepth_1500.json
├── dataset
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── coco_augmentor.cpython-38.pyc
│   │   ├── coco_wrapper.cpython-38.pyc
│   │   ├── dataset_utils.cpython-38.pyc
│   │   ├── megadepth.cpython-38.pyc
│   │   └── megadepth_wrapper.cpython-38.pyc
│   ├── coco_augmentor.py
│   ├── coco_wrapper.py
│   ├── dataset_utils.py
│   ├── megadepth.py
│   └── megadepth_wrapper.py
├── demo.py
├── evaluation
│   ├── HPatch_evaluation.py
│   ├── MegaDepth1500_evaluation.py
│   ├── __pycache__
│   │   └── eval_utils.cpython-38.pyc
│   └── eval_utils.py
├── loss
│   ├── __pycache__
│   │   └── loss.cpython-38.pyc
│   └── loss.py
├── models
│   ├── __pycache__
│   │   ├── interpolator.cpython-310.pyc
│   │   ├── interpolator.cpython-38.pyc
│   │   ├── liftfeat_wrapper.cpython-310.pyc
│   │   ├── liftfeat_wrapper.cpython-38.pyc
│   │   ├── model.cpython-310.pyc
│   │   └── model.cpython-38.pyc
│   ├── interpolator.py
│   ├── liftfeat_wrapper.py
│   └── model.py
├── requirements.txt
├── tools
│   ├── demo_match_video.py
│   └── demo_vo.py
├── train.py
├── train.sh
├── utils
│   ├── VisualOdometry.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── VisualOdometry.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── alike_wrapper.cpython-38.pyc
│   │   ├── config.cpython-310.pyc
│   │   ├── config.cpython-38.pyc
│   │   ├── depth_anything_wrapper.cpython-38.pyc
│   │   ├── featurebooster.cpython-310.pyc
│   │   ├── featurebooster.cpython-38.pyc
│   │   └── post_process.cpython-38.pyc
│   ├── alike_wrapper.py
│   ├── config.py
│   ├── depth_anything_wrapper.py
│   ├── featurebooster.py
│   └── post_process.py
└── weights
    └── LiftFeat.pth
/.gitignore:
--------------------------------------------------------------------------------
1 | visualize
2 | trained_weights
3 | data/HPatch
4 | data/megadepth_test_1500
5 | output
6 | issues
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## LiftFeat: 3D Geometry-Aware Local Feature Matching
2 |
3 |
4 | *(demo GIFs: assert/demo_sp.gif, assert/demo_liftfeat.gif)*
5 |
6 |
7 |
8 | Real-time SuperPoint demonstration (left) compared to LiftFeat (right) on a textureless scene.
9 |
10 |
11 |
12 | - 🎉 **New!** Training code is now available 🚀
13 | - 🎉 **New!** The test code and pretrained model have been released. 🚀
14 |
15 | ## Table of Contents
16 | - [Introduction](#introduction)
17 | - [Installation](#installation)
18 | - [Usage](#usage)
19 |   - [Inference](#inference)
20 | - [Training](#training)
21 | - [Evaluation](#evaluation)
22 | - [Citation](#citation)
23 | - [License](#license)
24 |
25 | ## Introduction
26 | This repository contains the official implementation of the paper:
27 | **[LiftFeat: 3D Geometry-Aware Local Feature Matching](https://www.arxiv.org/abs/2505.03422)**, to be presented at *ICRA 2025*.
28 |
29 | **Overview of LiftFeat's architecture**
30 |
31 | *(architecture diagram: assert/achitecture.png)*
32 |
33 |
34 | LiftFeat is a lightweight and robust local feature matching network designed to handle challenging scenarios such as drastic lighting changes, low-texture regions, and repetitive patterns. By incorporating 3D geometric cues through surface normals predicted from monocular depth, LiftFeat enhances the discriminative power of 2D descriptors. Our proposed 3D geometry-aware feature lifting module effectively fuses these cues, leading to significant improvements in tasks like relative pose estimation, homography estimation, and visual localization.
35 |
36 | ## Installation
37 | If you use conda as your virtual environment manager, you can create a new environment with:
38 | ```bash
39 | git clone https://github.com/lyp-deeplearning/LiftFeat.git
40 | cd LiftFeat
41 | conda create -n LiftFeat python=3.8
42 | conda activate LiftFeat
43 | pip install -r requirements.txt
44 | ```
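
A quick way to verify the environment is to import the main libraries the code relies on (PyTorch, OpenCV, Kornia). This is just a convenience check, not part of the official setup:

```python
# Minimal post-install sanity check: these packages are imported throughout the repository.
import torch
import cv2
import kornia

print("torch:", torch.__version__)
print("opencv:", cv2.__version__)
print("kornia:", kornia.__version__)
print("CUDA available:", torch.cuda.is_available())
```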
45 |
46 | ## Usage
47 | ### Inference with image pair
48 | To run LiftFeat on an image pair, simply run:
49 | ```bash
50 | python demo.py --img1=<path/to/reference_image> --img2=<path/to/query_image>
51 | ```
52 |
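You can also call LiftFeat from your own Python code. The sketch below follows `demo.py`; it assumes you run it from the repository root (so `models.liftfeat_wrapper` is importable) and uses the bundled weights exposed as `MODEL_PATH` (`weights/LiftFeat.pth`):

```python
import cv2
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

# Load the pretrained model, as done in demo.py.
liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05)

img1 = cv2.imread('./assert/ref.jpg')    # reference image
img2 = cv2.imread('./assert/query.jpg')  # query image

# Option 1: end-to-end matching, returns matched keypoint coordinates in each image.
mkpts1, mkpts2 = liftfeat.match_liftfeat(img1, img2)
print(f'{len(mkpts1)} matches')

# Option 2: extract keypoints and descriptors and match them yourself (e.g. with OpenCV's BFMatcher).
data1 = liftfeat.extract(img1)
kpts1 = data1['keypoints'].cpu().numpy()
descs1 = data1['descriptors'].cpu().numpy()
```
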
53 | ### Run with video
54 | We provide a simple real-time demo that matches a template image to each frame of a video stream using our LiftFeat method.
55 |
56 | You can run the demo with the following command:
57 | ```bash
58 | python tools/demo_match_video.py --img your_template.png --video your.mp4
59 | ```
60 |
61 | We also provide a [sample template image and video with lighting variation](https://drive.google.com/drive/folders/1b-t-f2Bt47KU674bPI09bGtJ9BHx05Yu?usp=drive_link) for demonstration purposes.
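For reference, the core of such a demo can be sketched with the same wrapper API as `demo.py`; this is an illustrative sketch only (the drawing, homography filtering and timing found in `tools/demo_match_video.py` are omitted):

```python
import cv2
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05)
template = cv2.imread('your_template.png')

cap = cv2.VideoCapture('your.mp4')
while True:
    ok, frame = cap.read()
    if not ok:
        break
    # Match the fixed template against the current frame.
    mkpts_t, mkpts_f = liftfeat.match_liftfeat(template, frame)
    print(f'{len(mkpts_t)} matches on this frame')
cap.release()
```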
62 |
63 | ### Visual Odometry Demo
64 | We have added a new application to evaluate LiftFeat on visual odometry (VO) tasks.
65 |
66 | We use sequences from the KITTI dataset to demonstrate frame-to-frame motion estimation. Running the script below will generate the estimated camera trajectory and the error curve:
67 |
68 | ```bash
69 | python tools/demo_vo.py --path1 /path/to/gray/images --path2 /path/to/color/images --id 03
70 | ```
71 |
72 | We also provide a sample [KITTI sequence](https://drive.google.com/drive/folders/1b-t-f2Bt47KU674bPI09bGtJ9BHx05Yu?usp=drive_link) for quick testing.
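
To give an idea of how matched keypoints turn into frame-to-frame motion, the sketch below estimates the relative pose between two consecutive frames with standard two-view geometry (essential matrix plus `cv2.recoverPose`). It is a simplified illustration, not the exact pipeline in `utils/VisualOdometry.py`; the frame paths are placeholders and `K` must be replaced with the calibration of your sequence:

```python
import cv2
import numpy as np
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05)
frame0 = cv2.imread('frame_000000.png')  # two consecutive frames (placeholder paths)
frame1 = cv2.imread('frame_000001.png')

# Example pinhole intrinsics (values often quoted for KITTI grayscale sequences 00-02);
# replace with the calibration of your own sequence.
K = np.array([[718.856, 0.0, 607.1928],
              [0.0, 718.856, 185.2157],
              [0.0, 0.0, 1.0]])

mkpts0, mkpts1 = liftfeat.match_liftfeat(frame0, frame1)

# Essential matrix with RANSAC, then cheirality check to recover R and t (t is up to scale).
E, inliers = cv2.findEssentialMat(mkpts0, mkpts1, K, method=cv2.RANSAC, prob=0.999, threshold=1.0)
_, R, t, _ = cv2.recoverPose(E, mkpts0, mkpts1, K, mask=inliers)
print('R:\n', R, '\nt (unit norm):\n', t.ravel())
```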
73 |
74 |
75 | *(VO demo GIFs: assert/keypoints_liftfeat.gif, assert/trajectory_liftfeat.gif)*
76 |
77 |
78 |
79 |
80 | ## Training
81 | To train LiftFeat as described in the paper, you will need the MegaDepth dataset and the COCO_20k subset of COCO2017, as described in *[XFeat: Accelerated Features for Lightweight Image Matching](https://arxiv.org/abs/2404.19174)*.
82 | You can obtain the full COCO2017 train data at https://cocodataset.org/.
83 | However, we [make available](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) a subset of COCO for convenience. We simply selected a subset of 20k images according to image resolution. Please check COCO [terms of use](https://cocodataset.org/#termsofuse) before using the data.
84 |
85 | To reproduce the training setup from the paper, please follow these steps:
86 | 1. Download [COCO_20k](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) containing a subset of COCO2017;
87 | 2. Download the MegaDepth dataset. You can follow the [LoFTR instructions](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets); we use the same setup as LoFTR. Then put the MegaDepth indices inside the MegaDepth root folder, following the layout below:
88 | ```bash
89 | {megadepth_root_path}/train_data/megadepth_indices #indices
90 | {megadepth_root_path}/MegaDepth_v1 #images & depth maps & poses
91 | ```
92 | 3. Finally, you can launch training:
93 | ```bash
94 | python train.py --megadepth_root_path /MegaDepth --synthetic_root_path /coco_20k --ckpt_save_path /path/to/ckpts
95 | ```
96 |
97 | ## Evaluation
98 | All evaluation code is in the *evaluation* directory. You can download the **HPatches** dataset following [D2-Net](https://github.com/mihaidusmanu/d2-net/tree/master) and the **MegaDepth** test set following [LoFTR](https://github.com/zju3dv/LoFTR/tree/master).
99 |
100 | **Download and process HPatch**
101 | ```bash
102 | cd /data
103 |
104 | # Download the dataset
105 | wget https://huggingface.co/datasets/vbalnt/hpatches/resolve/main/hpatches-sequences-release.zip
106 |
107 | # Extract the dataset
108 | unzip hpatches-sequences-release.zip
109 |
110 | # Remove the high-resolution sequences
111 | cd hpatches-sequences-release
112 | rm -rf i_contruction i_crownnight i_dc i_pencils i_whitebuilding v_artisans v_astronautis v_talent
113 |
114 | cd /data
115 |
116 | ln -s /data/hpatches-sequences-release ./HPatch
117 | ```
118 |
119 | **Download and process MegaDepth1500**
120 | We provide a download link for [megadepth_test_1500](https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc).
121 | ```bash
122 | tar xvf <path/to/megadepth_test_1500.tar>
123 |
124 | cd /data
125 |
126 | ln -s <path/to/extracted/megadepth_test_1500> ./megadepth_test_1500
127 | ```
128 |
129 |
130 | **Homography Estimation**
131 | ```bash
132 | python evaluation/HPatch_evaluation.py
133 | ```
134 |
135 | **Relative Pose Estimation**
136 |
137 | For the *MegaDepth1500* dataset:
138 | ```bash
139 | python evaluation/MegaDepth1500_evaluation.py
140 | ```
141 |
142 |
143 | ## Citation
144 | If you find this code useful for your research, please cite the paper:
145 | ```bibtex
146 | @misc{liu2025liftfeat3dgeometryawarelocal,
147 | title={LiftFeat: 3D Geometry-Aware Local Feature Matching},
148 | author={Yepeng Liu and Wenpeng Lai and Zhou Zhao and Yuxuan Xiong and Jinchi Zhu and Jun Cheng and Yongchao Xu},
149 | year={2025},
150 | eprint={2505.03422},
151 | archivePrefix={arXiv},
152 | primaryClass={cs.CV},
153 | url={https://arxiv.org/abs/2505.03422},
154 | }
155 | ```
156 |
157 | ## License
158 | [License](LICENSE)
159 |
160 |
161 | ## Acknowledgements
162 | We would like to thank the authors of the following open-source repositories for their valuable contributions, which have inspired or supported this work:
163 |
164 | - [verlab/accelerated_features](https://github.com/verlab/accelerated_features)
165 | - [zju3dv/LoFTR](https://github.com/zju3dv/LoFTR)
166 | - [rpautrat/SuperPoint](https://github.com/rpautrat/SuperPoint)
167 | - [Depth-Anything-V2](https://github.com/DepthAnything/Depth-Anything-V2)
168 | - [Python-VO](https://github.com/Shiaoming/Python-VO)
169 |
170 | We deeply appreciate the efforts of the research community in releasing high-quality codebases.
171 |
--------------------------------------------------------------------------------
/__pycache__/liftfeat_wrapper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/__pycache__/liftfeat_wrapper.cpython-38.pyc
--------------------------------------------------------------------------------
/assert/achitecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/achitecture.png
--------------------------------------------------------------------------------
/assert/demo_liftfeat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/demo_liftfeat.gif
--------------------------------------------------------------------------------
/assert/demo_sp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/demo_sp.gif
--------------------------------------------------------------------------------
/assert/keypoints_liftfeat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/keypoints_liftfeat.gif
--------------------------------------------------------------------------------
/assert/query.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/query.jpg
--------------------------------------------------------------------------------
/assert/ref.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/ref.jpg
--------------------------------------------------------------------------------
/assert/trajectory_liftfeat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/trajectory_liftfeat.gif
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/coco_augmentor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/coco_augmentor.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/coco_wrapper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/coco_wrapper.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/dataset_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/megadepth.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/megadepth.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/megadepth_wrapper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/megadepth_wrapper.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/coco_augmentor.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 | COCO_20k image augmentor
4 | """
5 |
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import Dataset
9 | import torch.utils.data as data
10 | from torchvision import transforms
11 | import torch.nn.functional as F
12 |
13 | import cv2
14 | import kornia
15 | import kornia.augmentation as K
16 | from kornia.geometry.transform import get_tps_transform as findTPS
17 | from kornia.geometry.transform import warp_points_tps, warp_image_tps
18 |
19 | import glob
20 | import random
21 | import tqdm
22 |
23 | import numpy as np
24 | import pdb
25 | import time
26 |
27 | random.seed(0)
28 | torch.manual_seed(0)
29 |
30 | def generateRandomTPS(shape,grid=(8,6),GLOBAL_MULTIPLIER=0.3,prob=0.5):
31 |
32 | h, w = shape
33 | sh, sw = h/grid[0], w/grid[1]
34 | src = torch.dstack(torch.meshgrid(torch.arange(0, h + sh , sh), torch.arange(0, w + sw , sw), indexing='ij'))
35 |
36 | offsets = torch.rand(grid[0]+1, grid[1]+1, 2) - 0.5
37 | offsets *= torch.tensor([ sh/2, sw/2 ]).view(1, 1, 2) * min(0.97, 2.0 * GLOBAL_MULTIPLIER)
38 | dst = src + offsets if np.random.uniform() < prob else src
39 |
40 | src, dst = src.view(1, -1, 2), dst.view(1, -1, 2)
41 | src = (src / torch.tensor([h,w]).view(1,1,2) ) * 2 - 1.
42 | dst = (dst / torch.tensor([h,w]).view(1,1,2) ) * 2 - 1.
43 | weights, A = findTPS(dst, src)
44 |
45 | return src, weights, A
46 |
47 |
48 | def generateRandomHomography(shape,GLOBAL_MULTIPLIER=0.3):
49 | #Generate random in-plane rotation [-theta,+theta]
50 | theta = np.radians(np.random.uniform(-30, 30))
51 |
52 | #Generate random scale in both x and y
53 | scale_x, scale_y = np.random.uniform(0.35, 1.2, 2)
54 |
55 | #Generate random translation shift
56 | tx , ty = -shape[1]/2.0 , -shape[0]/2.0
57 | txn, tyn = np.random.normal(0, 120.0*GLOBAL_MULTIPLIER, 2)
58 |
59 | c, s = np.cos(theta), np.sin(theta)
60 |
61 | #Affine coeffs
62 | sx , sy = np.random.normal(0,0.6*GLOBAL_MULTIPLIER,2)
63 |
64 | #Projective coeffs
65 | p1 , p2 = np.random.normal(0,0.006*GLOBAL_MULTIPLIER,2)
66 |
67 |
68 | # Build Homography from parmeterizations
69 | H_t = np.array(((1,0, tx), (0, 1, ty), (0,0,1))) #t
70 | H_r = np.array(((c,-s, 0), (s, c, 0), (0,0,1))) #rotation,
71 | H_a = np.array(((1,sy, 0), (sx, 1, 0), (0,0,1))) # affine
72 | H_p = np.array(((1, 0, 0), (0 , 1, 0), (p1,p2,1))) # projective
73 | H_s = np.array(((scale_x,0, 0), (0, scale_y, 0), (0,0,1))) #scale
74 | H_b = np.array(((1.0,0,-tx +txn), (0, 1, -ty + tyn), (0,0,1))) #t_back,
75 |
76 | #H = H_b * H_s * H_p * H_a * H_r * H_t
77 | H = np.dot(np.dot(np.dot(np.dot(np.dot(H_b,H_s),H_p),H_a),H_r),H_t)
78 |
79 | return H
80 |
81 |
82 | class COCOAugmentor(nn.Module):
83 |
84 | def __init__(self,device,load_dataset=True,
85 | img_dir="/home/yepeng_liu/code_python/dataset/coco_20k",
86 | warp_resolution=(1200, 900),
87 | out_resolution=(400, 300),
88 | sides_crop=0.2,
89 | max_num_imgs=50,
90 | num_test_imgs=10,
91 | batch_size=1,
92 | photometric=True,
93 | geometric=True,
94 | reload_step=1_000
95 | ):
96 | super(COCOAugmentor,self).__init__()
97 | self.half=16
98 | self.device=device
99 |
100 | self.dims=warp_resolution
101 | self.batch_size=batch_size
102 | self.out_resolution=out_resolution
103 | self.sides_crop=sides_crop
104 | self.max_num_imgs=max_num_imgs
105 | self.num_test_imgs=num_test_imgs
106 | self.dims_t=torch.tensor([int(self.dims[0]*(1. - self.sides_crop)) - int(self.dims[0]*self.sides_crop) -1,
107 | int(self.dims[1]*(1. - self.sides_crop)) - int(self.dims[1]*self.sides_crop) -1]).float().to(device).view(1,1,2)
108 | self.dims_s=torch.tensor([self.dims_t[0,0,0] / out_resolution[0],
109 | self.dims_t[0,0,1] / out_resolution[1]]).float().to(device).view(1,1,2)
110 |
111 | self.all_imgs=glob.glob(img_dir+'/*.jpg')+glob.glob(img_dir+'/*.png')
112 |
113 | self.photometric=photometric
114 | self.geometric=geometric
115 | self.cnt=1
116 | self.reload_step=reload_step
117 |
118 | list_augmentation=[
119 | kornia.augmentation.ColorJitter(0.15,0.15,0.15,0.15,p=1.),
120 | kornia.augmentation.RandomEqualize(p=0.4),
121 | kornia.augmentation.RandomGaussianBlur(p=0.3,sigma=(2.0,2.0),kernel_size=(7,7))
122 | ]
123 |
124 | if photometric is False:
125 | list_augmentation = []
126 |
127 | self.aug_list=kornia.augmentation.ImageSequential(*list_augmentation)
128 |
129 | if len(self.all_imgs)<10:
130 | raise RuntimeError(f"Couldn't find enough images to train. Please check the path: {img_dir}")
131 |
132 | if load_dataset:
133 | print('[COCO]: ',len(self.all_imgs),' images for training..')
134 | if len(self.all_imgs) - num_test_imgs < max_num_imgs:
135 | raise RuntimeError('Error: test set overlaps with training set! Decrease number of test imgs')
136 |
137 | self.load_imgs()
138 |
139 | self.TPS = True
140 |
141 |
142 | def load_imgs(self):
143 | random.shuffle(self.all_imgs)
144 | train = []
145 | for p in tqdm.tqdm(self.all_imgs[:self.max_num_imgs],desc='loading train'):
146 | im=cv2.imread(p)
147 | halfH,halfW=im.shape[0]//2,im.shape[1]//2
148 | if halfH>halfW:
149 | im=np.rot90(im)
150 | halfH,halfW=halfW,halfH
151 |
152 | if im.shape[0]!=self.dims[1] or im.shape[1]!=self.dims[0]:
153 | im = cv2.resize(im, self.dims)
154 |
155 | train.append(np.copy(im))
156 |
157 | self.train=train
158 | self.test=[
159 | cv2.resize(cv2.imread(p),self.dims)
160 | for p in tqdm.tqdm(self.all_imgs[-self.num_test_imgs:],desc='loading test')
161 | ]
162 |
163 | def norm_pts_grid(self, x):
164 | if len(x.size()) == 2:
165 | return (x.view(1,-1,2) * self.dims_s / self.dims_t) * 2. - 1
166 | return (x * self.dims_s / self.dims_t) * 2. - 1
167 |
168 | def denorm_pts_grid(self, x):
169 | if len(x.size()) == 2:
170 | return ((x.view(1,-1,2) + 1) / 2.) / self.dims_s * self.dims_t
171 | return ((x+1) / 2.) / self.dims_s * self.dims_t
172 |
173 | def rnd_kps(self, shape, n = 256):
174 | h, w = shape
175 | kps = torch.rand(size = (3,n)).to(self.device)
176 | kps[0,:]*=w
177 | kps[1,:]*=h
178 | kps[2,:] = 1.0
179 |
180 | return kps
181 |
182 | def warp_points(self, H, pts):
183 | scale = self.dims_s.view(-1,2)
184 | offset = torch.tensor([int(self.dims[0]*self.sides_crop), int(self.dims[1]*self.sides_crop)], device = pts.device).float()
185 | pts = pts*scale + offset
186 | pts = torch.vstack( [pts.t(), torch.ones(1, pts.shape[0], device = pts.device)])
187 | warped = torch.matmul(H, pts)
188 | warped = warped / warped[2,...]
189 | warped = warped.t()[:, :2]
190 | return (warped - offset) / scale
191 |
192 | @torch.inference_mode()
193 | def forward(self, x, difficulty = 0.3, TPS = False, prob_deformation = 0.5, test = False):
194 | """
195 | Perform augmentation to a batch of images.
196 |
197 | input:
198 | x -> torch.Tensor(B, C, H, W): rgb images
199 | difficulty -> float: level of difficulty, 0.1 is medium, 0.3 is already pretty hard
200 | TPS -> bool: Whether to apply non-rigid deformations to the images
201 | prob_deformation -> float: probability to apply a deformation
202 |
203 | return:
204 | 'output' -> torch.Tensor(B, C, H, W): rgb images
205 | Tuple:
206 | 'H' -> torch.Tensor(3,3): homography matrix
207 | 'mask' -> torch.Tensor(B, H, W): mask of valid pixels after warp
208 | (deformation only)
209 | src, weights, A are parameters from a TPS warp (all torch.Tensors)
210 |
211 | """
212 |
213 | if self.cnt % self.reload_step == 0:
214 | self.load_imgs()
215 |
216 | if self.geometric is False:
217 | difficulty = 0.
218 |
219 | with torch.no_grad():
220 | x = (x/255.).to(self.device)
221 | b, c, h, w = x.shape
222 | shape = (h, w)
223 |
224 | ######## Geometric Transformations
225 |
226 | H = torch.tensor(np.array([generateRandomHomography(shape,difficulty) for b in range(self.batch_size)]),dtype=torch.float32).to(self.device)
227 |
228 | output = kornia.geometry.transform.warp_perspective(x,H,dsize=shape,padding_mode='zeros')
229 |
230 | #crop % of image boundaries each side to reduce invalid pixels after warps
231 | low_h = int(h * self.sides_crop); low_w = int(w*self.sides_crop)
232 | high_h = int(h*(1. - self.sides_crop)); high_w= int(w * (1. - self.sides_crop))
233 | output = output[..., low_h:high_h, low_w:high_w]
234 | x = x[..., low_h:high_h, low_w:high_w]
235 |
236 | #apply TPS if desired:
237 | if TPS:
238 | src, weights, A = None, None, None
239 | for b in range(self.batch_size):
240 | b_src, b_weights, b_A = generateRandomTPS(shape, (8,6), difficulty, prob = prob_deformation)
241 | b_src, b_weights, b_A = b_src.to(self.device), b_weights.to(self.device), b_A.to(self.device)
242 |
243 | if src is None:
244 | src, weights, A = b_src, b_weights, b_A
245 | else:
246 | src = torch.cat((b_src, src))
247 | weights = torch.cat((b_weights, weights))
248 | A = torch.cat((b_A, A))
249 |
250 | output = warp_image_tps(output, src, weights, A)
251 |
252 | output = F.interpolate(output, self.out_resolution[::-1], mode = 'nearest')
253 | x = F.interpolate(x, self.out_resolution[::-1], mode = 'nearest')
254 |
255 | mask = ~torch.all(output == 0, dim=1, keepdim=True)
256 | mask = mask.expand(-1,3,-1,-1)
257 |
258 | # Make-up invalid regions with texture from the batch
259 | rv = 1 if not TPS else 2
260 | output_shifted = torch.roll(x, rv, 0)
261 | output[~mask] = output_shifted[~mask]
262 | mask = mask[:, 0, :, :]
263 |
264 | ######## Photometric Transformations
265 | output = self.aug_list(output)
266 |
267 | b, c, h, w = output.shape
268 | #Correlated Gaussian Noise
269 | if np.random.uniform() > 0.5 and self.photometric:
270 | noise = F.interpolate(torch.randn_like(output)*(10/255), (h//2, w//2))
271 | noise = F.interpolate(noise, (h, w), mode = 'bicubic')
272 | output = torch.clip( output + noise, 0., 1.)
273 |
274 | #Random shadows
275 | if np.random.uniform() > 0.6 and self.photometric:
276 | noise = torch.rand((b, 1, h//64, w//64), device = self.device) * 1.3
277 | noise = torch.clip(noise, 0.25, 1.0)
278 | noise = F.interpolate(noise, (h, w), mode = 'bicubic')
279 | noise = noise.expand(-1, 3, -1, -1)
280 | output *= noise
281 | output = torch.clip( output, 0., 1.)
282 |
283 | self.cnt+=1
284 |
285 | if TPS:
286 | return output, (H, src, weights, A, mask)
287 | else:
288 | return output, (H, mask)
289 |
290 | def get_correspondences(self, kps_target, T):
291 | H, H2, src, W, A = T
292 | undeformed = self.denorm_pts_grid(
293 | warp_points_tps(self.norm_pts_grid(kps_target),
294 | src, W, A) ).view(-1,2)
295 |
296 | warped_to_src = self.warp_points(H@torch.inverse(H2), undeformed)
297 |
298 | return warped_to_src
--------------------------------------------------------------------------------
/dataset/coco_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pdb
4 |
5 | debug_cnt = -1
6 |
7 | def make_batch(augmentor, difficulty = 0.3, train = True):
8 | Hs = []
9 | img_list = augmentor.train if train else augmentor.test
10 | dev = augmentor.device
11 | batch_images = []
12 |
13 | with torch.no_grad(): # we don't require grads in the augmentation
14 | for b in range(augmentor.batch_size):
15 | rdidx = np.random.randint(len(img_list))
16 | img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2,0,1).to(augmentor.device).unsqueeze(0)
17 | batch_images.append(img)
18 |
19 | batch_images = torch.cat(batch_images)
20 |
21 | p1, H1 = augmentor(batch_images, difficulty)
22 | p2, H2 = augmentor(batch_images, difficulty, TPS = True, prob_deformation = 0.7)
23 | # p2, H2 = augmentor(batch_images, difficulty, TPS = False, prob_deformation = 0.7)
24 |
25 | return p1, p2, H1, H2
26 |
27 |
28 | def plot_corrs(p1, p2, src_pts, tgt_pts):
29 | import matplotlib.pyplot as plt
30 | p1 = p1.cpu()
31 | p2 = p2.cpu()
32 | src_pts = src_pts.cpu() ; tgt_pts = tgt_pts.cpu()
33 | rnd_idx = np.random.randint(len(src_pts), size=200)
34 | src_pts = src_pts[rnd_idx, ...]
35 | tgt_pts = tgt_pts[rnd_idx, ...]
36 |
37 | #Plot ground-truth correspondences
38 | fig, ax = plt.subplots(1,2,figsize=(18, 12))
39 | colors = np.random.uniform(size=(len(tgt_pts),3))
40 | #Src image
41 | img = p1
42 | for i, p in enumerate(src_pts):
43 | ax[0].scatter(p[0],p[1],color=colors[i])
44 | ax[0].imshow(img.permute(1,2,0).numpy()[...,::-1])
45 |
46 | #Target img
47 | img2 = p2
48 | for i, p in enumerate(tgt_pts):
49 | ax[1].scatter(p[0],p[1],color=colors[i])
50 | ax[1].imshow(img2.permute(1,2,0).numpy()[...,::-1])
51 | plt.show()
52 |
53 |
54 | def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop = None):
55 | '''
56 | Get dense corresponding points
57 | '''
58 | global debug_cnt
59 | negatives, positives = [], []
60 |
61 | with torch.no_grad():
62 | #real input res of samples
63 | rh, rw = p1.shape[-2:]
64 | ratio = torch.tensor([rw/w, rh/h], device = p1.device)
65 |
66 | (H, mask1) = H
67 | (H2, src, W, A, mask2) = H2
68 |
69 | #Generate meshgrid of target pts
70 | x, y = torch.meshgrid(torch.arange(w, device=p1.device), torch.arange(h, device=p1.device), indexing ='xy')
71 | mesh = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1)
72 | target_pts = mesh.view(-1, 2) * ratio
73 |
74 | #Pack all transformations into T
75 | for batch_idx in range(len(p1)):
76 | with torch.no_grad():
77 | T = (H[batch_idx], H2[batch_idx],
78 | src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0))
79 | #We now warp the target points to src image
80 | src_pts = (augmentor.get_correspondences(target_pts, T) ) #target to src
81 | tgt_pts = (target_pts)
82 |
83 | #Check out of bounds points
84 | mask_valid = (src_pts[:, 0] >=0) & (src_pts[:, 1] >=0) & \
85 | (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh)
86 |
87 | negatives.append( tgt_pts[~mask_valid] )
88 | tgt_pts = tgt_pts[mask_valid]
89 | src_pts = src_pts[mask_valid]
90 |
91 |
92 | #Remove invalid pixels
93 | mask_valid = mask1[batch_idx, src_pts[:,1].long(), src_pts[:,0].long()] & \
94 | mask2[batch_idx, tgt_pts[:,1].long(), tgt_pts[:,0].long()]
95 | tgt_pts = tgt_pts[mask_valid]
96 | src_pts = src_pts[mask_valid]
97 |
98 | # limit nb of matches if desired
99 | if crop is not None:
100 | rnd_idx = torch.randperm(len(src_pts), device=src_pts.device)[:crop]
101 | src_pts = src_pts[rnd_idx]
102 | tgt_pts = tgt_pts[rnd_idx]
103 |
104 | if debug_cnt >=0 and debug_cnt < 4:
105 | plot_corrs(p1[batch_idx], p2[batch_idx], src_pts , tgt_pts )
106 | debug_cnt +=1
107 |
108 | src_pts = (src_pts / ratio)
109 | tgt_pts = (tgt_pts / ratio)
110 |
111 | #Check out of bounds points
112 | padto = 10 if crop is not None else 2
113 | mask_valid1 = (src_pts[:, 0] >= (0 + padto)) & (src_pts[:, 1] >= (0 + padto)) & \
114 | (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto))
115 | mask_valid2 = (tgt_pts[:, 0] >= (0 + padto)) & (tgt_pts[:, 1] >= (0 + padto)) & \
116 | (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto))
117 | mask_valid = mask_valid1 & mask_valid2
118 | tgt_pts = tgt_pts[mask_valid]
119 | src_pts = src_pts[mask_valid]
120 |
121 | #Remove repeated correspondences
122 | lut_mat = torch.ones((h, w, 4), device = src_pts.device, dtype = src_pts.dtype) * -1
123 | # src_pts_np = src_pts.cpu().numpy()
124 | # tgt_pts_np = tgt_pts.cpu().numpy()
125 | try:
126 | lut_mat[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
127 | mask_valid = torch.all(lut_mat >= 0, dim=-1)
128 | points = lut_mat[mask_valid]
129 | positives.append(points)
130 | except:
131 | pdb.set_trace()
132 | print('..')
133 |
134 | return negatives, positives
135 |
136 |
137 | def crop_patches(tensor, coords, size = 7):
138 | '''
139 | Crop [size x size] patches around 2D coordinates from a tensor.
140 | '''
141 | B, C, H, W = tensor.shape
142 |
143 | x, y = coords[:, 0], coords[:, 1]
144 | y = y.view(-1, 1, 1)
145 | x = x.view(-1, 1, 1)
146 | halfsize = size // 2
147 | # Create meshgrid for indexing
148 | x_offset, y_offset = torch.meshgrid(torch.arange(-halfsize, halfsize+1), torch.arange(-halfsize, halfsize+1), indexing='xy')
149 | y_offset = y_offset.to(tensor.device)
150 | x_offset = x_offset.to(tensor.device)
151 |
152 | # Compute indices around each coordinate
153 | y_indices = (y + y_offset.view(1, size, size)).squeeze(0) + halfsize
154 | x_indices = (x + x_offset.view(1, size, size)).squeeze(0) + halfsize
155 |
156 | # Handle out-of-boundary indices with padding
157 | tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant')
158 |
159 | # Index tensor to get patches
160 | patches = tensor_padded[:, :, y_indices, x_indices] # [B, C, N, H, W]
161 | return patches
162 |
163 | def subpix_softmax2d(heatmaps, temp = 0.25):
164 | N, H, W = heatmaps.shape
165 | heatmaps = torch.softmax(temp * heatmaps.view(-1, H*W), -1).view(-1, H, W)
166 | x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device ), torch.arange(H, device = heatmaps.device ), indexing = 'xy')
167 | x = x - (W//2)
168 | y = y - (H//2)
169 | #pdb.set_trace()
170 | coords_x = (x[None, ...] * heatmaps)
171 | coords_y = (y[None, ...] * heatmaps)
172 | coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H*W, 2)
173 | coords = coords.sum(1)
174 |
175 | return coords
176 |
--------------------------------------------------------------------------------
/dataset/dataset_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 |
4 | MegaDepth data handling was adapted from
5 | LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
6 | """
7 |
8 | import io
9 | import cv2
10 | import numpy as np
11 | import h5py
12 | import torch
13 | from numpy.linalg import inv
14 |
15 |
16 | try:
17 | # for internal use only
18 | from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
19 | except Exception:
20 | MEGADEPTH_CLIENT = SCANNET_CLIENT = None
21 |
22 | # --- DATA IO ---
23 |
24 | def load_array_from_s3(
25 | path, client, cv_type,
26 | use_h5py=False,
27 | ):
28 | byte_str = client.Get(path)
29 | try:
30 | if not use_h5py:
31 | raw_array = np.frombuffer(byte_str, np.uint8)
32 | data = cv2.imdecode(raw_array, cv_type)
33 | else:
34 | f = io.BytesIO(byte_str)
35 | data = np.array(h5py.File(f, 'r')['/depth'])
36 | except Exception as ex:
37 | print(f"==> Data loading failure: {path}")
38 | raise ex
39 |
40 | assert data is not None
41 | return data
42 |
43 |
44 | def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
45 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
46 | else cv2.IMREAD_COLOR
47 | if str(path).startswith('s3://'):
48 | image = load_array_from_s3(str(path), client, cv_type)
49 | else:
50 | image = cv2.imread(str(path), 1)
51 |
52 | if augment_fn is not None:
53 | image = cv2.imread(str(path), cv2.IMREAD_COLOR)
54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55 | image = augment_fn(image)
56 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
57 | return image # (h, w)
58 |
59 |
60 | def get_resized_wh(w, h, resize=None):
61 | if resize is not None: # resize the longer edge
62 | scale = resize / max(h, w)
63 | w_new, h_new = int(round(w*scale)), int(round(h*scale))
64 | else:
65 | w_new, h_new = w, h
66 | return w_new, h_new
67 |
68 |
69 | def get_divisible_wh(w, h, df=None):
70 | if df is not None:
71 | w_new, h_new = map(lambda x: int(x // df * df), [w, h])
72 | else:
73 | w_new, h_new = w, h
74 | return w_new, h_new
75 |
76 |
77 | def pad_bottom_right(inp, pad_size, ret_mask=False):
78 | assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
79 | mask = None
80 | if inp.ndim == 2:
81 | padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
82 | padded[:inp.shape[0], :inp.shape[1]] = inp
83 | if ret_mask:
84 | mask = np.zeros((pad_size, pad_size), dtype=bool)
85 | mask[:inp.shape[0], :inp.shape[1]] = True
86 | elif inp.ndim == 3:
87 | padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
88 | padded[:, :inp.shape[1], :inp.shape[2]] = inp
89 | if ret_mask:
90 | mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
91 | mask[:, :inp.shape[1], :inp.shape[2]] = True
92 | else:
93 | raise NotImplementedError()
94 | return padded, mask
95 |
96 |
97 | # --- MEGADEPTH ---
98 |
99 | def fix_path_from_d2net(path):
100 | if not path:
101 | return None
102 |
103 | path = path.replace('Undistorted_SfM/', '')
104 | path = path.replace('images', 'dense0/imgs')
105 | path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '')
106 |
107 | return path
108 |
109 | def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
110 | """
111 | Args:
112 | resize (int, optional): the longer edge of resized images. None for no resize.
113 | padding (bool): If set to 'True', zero-pad resized images to squared size.
114 | augment_fn (callable, optional): augments images with pre-defined visual effects
115 | Returns:
116 | image (torch.tensor): (1, h, w)
117 | mask (torch.tensor): (h, w)
118 | scale (torch.tensor): [w/w_new, h/h_new]
119 | """
120 | # read image
121 | image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
122 |
123 | # resize image
124 | w, h = image.shape[1], image.shape[0]
125 |
126 | if resize is not None:
127 | if len(resize) == 2:
128 | w_new, h_new = resize
129 | else:
130 | resize = resize[0]
131 | w_new, h_new = get_resized_wh(w, h, resize)
132 | w_new, h_new = get_divisible_wh(w_new, h_new, df)
133 |
134 |
135 | image = cv2.resize(image, (w_new, h_new))
136 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
137 |
138 | if padding: # padding
139 | pad_to = max(h_new, w_new)
140 | image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
141 | else:
142 | mask = None
143 | else:
144 | scale=torch.tensor([1.0,1.0],dtype=torch.float)
145 |
146 | if padding:
147 | pad_to=max(w,h)
148 | image,mask=pad_bottom_right(image,pad_to,ret_mask=True)
149 | else:
150 | mask=None
151 |
152 | #image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
153 | image_t = torch.from_numpy(image).float().permute(2,0,1) / 255 # (h, w, 3) -> (3, h, w), normalized to [0, 1]
154 | mask = torch.from_numpy(mask) if mask is not None else None
155 |
156 | return image, image_t, mask, scale
157 |
158 |
159 | def read_megadepth_depth(path, pad_to=None):
160 |
161 | if str(path).startswith('s3://'):
162 | depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
163 | else:
164 | depth = np.array(h5py.File(path, 'r')['depth'])
165 | if pad_to is not None:
166 | depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
167 | depth = torch.from_numpy(depth).float() # (h, w)
168 | return depth
169 |
170 |
171 | def imread_bgr(path, augment_fn=None, client=SCANNET_CLIENT):
172 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
173 | if str(path).startswith('s3://'):
174 | image = load_array_from_s3(str(path), client, cv_type)
175 | else:
176 | image = cv2.imread(str(path), 1)
177 |
178 | if augment_fn is not None:
179 | image = cv2.imread(str(path), cv2.IMREAD_COLOR)
180 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
181 | image = augment_fn(image)
182 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
183 | return image # (h, w)
184 |
--------------------------------------------------------------------------------
/dataset/megadepth.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 |
4 | MegaDepth data handling was adapted from
5 | LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
6 | """
7 |
8 | import os.path as osp
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.data import Dataset
13 | import glob
14 | import numpy.random as rnd
15 |
16 | import os
17 | import sys
18 | sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
19 | from dataset.dataset_utils import read_megadepth_gray, read_megadepth_depth, fix_path_from_d2net
20 |
21 | import pdb, tqdm, os
22 |
23 |
24 | class MegaDepthDataset(Dataset):
25 | def __init__(self,
26 | root_dir,
27 | npz_path,
28 | mode='train',
29 | min_overlap_score = 0.3, #0.3,
30 | max_overlap_score = 1.0, #1,
31 | load_depth = True,
32 | img_resize = (800,608), #or None
33 | df=32,
34 | img_padding=False,
35 | depth_padding=True,
36 | augment_fn=None,
37 | **kwargs):
38 | """
39 | Manage one scene(npz_path) of MegaDepth dataset.
40 |
41 | Args:
42 | root_dir (str): megadepth root directory that has `phoenix`.
43 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
44 | mode (str): options are ['train', 'val', 'test']
45 | min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
46 | img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
47 | This is useful during training with batches and testing with memory intensive algorithms.
48 | df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
49 | img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
50 | depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
51 | augment_fn (callable, optional): augments images with pre-defined visual effects.
52 | """
53 | super().__init__()
54 | self.root_dir = root_dir
55 | self.mode = mode
56 | self.scene_id = npz_path.split('.')[0]
57 | self.load_depth = load_depth
58 | # prepare scene_info and pair_info
59 | if mode == 'test' and min_overlap_score != 0:
60 | min_overlap_score = 0
61 | self.scene_info = np.load(npz_path, allow_pickle=True)
62 | self.pair_infos = self.scene_info['pair_infos'].copy()
63 | del self.scene_info['pair_infos']
64 | self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score]
65 |
66 | # parameters for image resizing, padding and depthmap padding
67 | if mode == 'train':
68 | assert img_resize is not None #and img_padding and depth_padding
69 |
70 | self.img_resize = img_resize
71 | self.df = df
72 | self.img_padding = img_padding
73 | self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth.
74 |
75 | # for training LoFTR
76 | self.augment_fn = augment_fn if mode == 'train' else None
77 | self.coarse_scale = kwargs.get('coarse_scale', 0.125)
78 | #pdb.set_trace()
79 | for idx in range(len(self.scene_info['image_paths'])):
80 | self.scene_info['image_paths'][idx] = fix_path_from_d2net(self.scene_info['image_paths'][idx])
81 |
82 | for idx in range(len(self.scene_info['depth_paths'])):
83 | self.scene_info['depth_paths'][idx] = fix_path_from_d2net(self.scene_info['depth_paths'][idx])
84 |
85 |
86 | def __len__(self):
87 | return len(self.pair_infos)
88 |
89 | def __getitem__(self, idx):
90 | (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)]
91 |
92 | # read grayscale image and mask. (1, h, w) and (h, w)
93 | img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
94 | img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
95 |
96 | # TODO: Support augmentation & handle seeds for each worker correctly.
97 | image0, image0_t, mask0, scale0 = read_megadepth_gray(img_name0, self.img_resize, self.df, self.img_padding, None)
98 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
99 | image1, image1_t, mask1, scale1 = read_megadepth_gray(img_name1, self.img_resize, self.df, self.img_padding, None)
100 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
101 |
102 | if self.load_depth:
103 | # read depth. shape: (h, w)
104 | if self.mode in ['train', 'val']:
105 | depth0 = read_megadepth_depth(
106 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
107 | depth1 = read_megadepth_depth(
108 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
109 | else:
110 | depth0 = depth1 = torch.tensor([])
111 |
112 | # read intrinsics of original size
113 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
114 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
115 |
116 | # read and compute relative poses
117 | T0 = self.scene_info['poses'][idx0]
118 | T1 = self.scene_info['poses'][idx1]
119 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
120 | T_1to0 = T_0to1.inverse()
121 |
122 | data = {
123 | 'image0': image0_t, # (1, h, w)
124 | 'image0_np': image0,
125 | 'depth0': depth0, # (h, w)
126 | 'image1': image1_t,
127 | 'image1_np': image1,
128 | 'depth1': depth1,
129 | 'T_0to1': T_0to1, # (4, 4)
130 | 'T_1to0': T_1to0,
131 | 'K0': K_0, # (3, 3)
132 | 'K1': K_1,
133 | 'scale0': scale0, # [scale_w, scale_h]
134 | 'scale1': scale1,
135 | 'dataset_name': 'MegaDepth',
136 | 'scene_id': self.scene_id,
137 | 'pair_id': idx,
138 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
139 | }
140 |
141 | # for LoFTR training
142 | if mask0 is not None: # img_padding is True
143 | if self.coarse_scale:
144 | [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
145 | scale_factor=self.coarse_scale,
146 | mode='nearest',
147 | recompute_scale_factor=False)[0].bool()
148 | data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
149 |
150 | else:
151 |
152 | # read intrinsics of original size
153 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
154 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
155 |
156 | # read and compute relative poses
157 | T0 = self.scene_info['poses'][idx0]
158 | T1 = self.scene_info['poses'][idx1]
159 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
160 | T_1to0 = T_0to1.inverse()
161 |
162 | data = {
163 | 'image0': image0, # (1, h, w)
164 | 'image1': image1,
165 | 'T_0to1': T_0to1, # (4, 4)
166 | 'T_1to0': T_1to0,
167 | 'K0': K_0, # (3, 3)
168 | 'K1': K_1,
169 | 'scale0': scale0, # [scale_w, scale_h]
170 | 'scale1': scale1,
171 | 'dataset_name': 'MegaDepth',
172 | 'scene_id': self.scene_id,
173 | 'pair_id': idx,
174 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
175 | }
176 |
177 | return data
--------------------------------------------------------------------------------
/dataset/megadepth_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 |
4 | MegaDepth data handling was adapted from
5 | LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
6 | """
7 |
8 | import torch
9 | from kornia.utils import create_meshgrid
10 | import matplotlib.pyplot as plt
11 | import pdb
12 | import cv2
13 |
14 | @torch.no_grad()
15 | def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
16 | """ Warp kpts0 from I0 to I1 with depth, K and Rt
17 | Also check covisibility and depth consistency.
18 | Depth is consistent if relative error < 0.2 (hard-coded).
19 |
20 | Args:
21 | kpts0 (torch.Tensor): [N, L, 2] - ,
22 | depth0 (torch.Tensor): [N, H, W],
23 | depth1 (torch.Tensor): [N, H, W],
24 | T_0to1 (torch.Tensor): [N, 3, 4],
25 | K0 (torch.Tensor): [N, 3, 3],
26 | K1 (torch.Tensor): [N, 3, 3],
27 | Returns:
28 | calculable_mask (torch.Tensor): [N, L]
29 | warped_keypoints0 (torch.Tensor): [N, L, 2]
30 | """
31 | kpts0_long = kpts0.round().long().clip(0, 2000-1)
32 |
33 | depth0[:, 0, :] = 0 ; depth1[:, 0, :] = 0
34 | depth0[:, :, 0] = 0 ; depth1[:, :, 0] = 0
35 |
36 | # Sample depth, get calculable_mask on depth != 0
37 | kpts0_depth = torch.stack(
38 | [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
39 | ) # (N, L)
40 | nonzero_mask = kpts0_depth > 0
41 |
42 | # Draw cross marks on the image for each keypoint
43 | # for b in range(len(kpts0)):
44 | # fig, ax = plt.subplots(1,2)
45 | # depth_np = depth0.numpy()[b]
46 | # depth_np_plot = depth_np.copy()
47 | # for x, y in kpts0_long[b, nonzero_mask[b], :].numpy():
48 | # cv2.drawMarker(depth_np_plot, (x, y), (255), cv2.MARKER_CROSS, markerSize=10, thickness=2)
49 | # ax[0].imshow(depth_np)
50 | # ax[1].imshow(depth_np_plot)
51 |
52 | # Unproject
53 | kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
54 | kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
55 |
56 | # Rigid Transform
57 | w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
58 | w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
59 |
60 | # Project
61 | w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
62 | w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-5) # (N, L, 2), +1e-5 to avoid zero depth
63 |
64 | # Covisible Check
65 | # h, w = depth1.shape[1:3]
66 | # covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
67 | # (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
68 | # w_kpts0_long = w_kpts0.long()
69 | # w_kpts0_long[~covisible_mask, :] = 0
70 |
71 | # w_kpts0_depth = torch.stack(
72 | # [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
73 | # ) # (N, L)
74 | # consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
75 |
76 |
77 | valid_mask = nonzero_mask #* consistent_mask* covisible_mask
78 |
79 | return valid_mask, w_kpts0
80 |
81 |
82 | @torch.no_grad()
83 | def spvs_coarse(data, scale = 8):
84 | """
85 | Supervise corresp with dense depth & camera poses
86 | """
87 |
88 | # 1. misc
89 | device = data['image0'].device
90 | N, _, H0, W0 = data['image0'].shape
91 | _, _, H1, W1 = data['image1'].shape
92 | #scale = 8
93 | scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
94 | scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
95 | h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
96 |
97 | # 2. warp grids
98 | # create kpts in meshgrid and resize them to image resolution
99 | grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) # [N, hw, 2]
100 | grid_pt1_i = scale1 * grid_pt1_c
101 |
102 | # warp kpts bi-directionally and check reproj error
103 | nonzero_m1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
104 | nonzero_m2, w_pt1_og = warp_kpts( w_pt1_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
105 |
106 |
107 | dist = torch.linalg.norm( grid_pt1_i - w_pt1_og, dim=-1)
108 | mask_mutual = (dist < 1.5) & nonzero_m1 & nonzero_m2
109 |
110 | #_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
111 | batched_corrs = [ torch.cat([w_pt1_i[i, mask_mutual[i]] / data['scale0'][i],
112 | grid_pt1_i[i, mask_mutual[i]] / data['scale1'][i]],dim=-1) for i in range(len(mask_mutual))]
113 |
114 |
115 | #Remove repeated correspondences - this is important for network convergence
116 | corrs = []
117 | for pts in batched_corrs:
118 | lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1
119 | lut_mat21 = torch.clone(lut_mat12)
120 | src_pts = pts[:, :2] / scale
121 | tgt_pts = pts[:, 2:] / scale
122 | try:
123 | lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
124 | mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1)
125 | points = lut_mat12[mask_valid12]
126 |
127 | #Target-src check
128 | src_pts, tgt_pts = points[:, :2], points[:, 2:]
129 | lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
130 | mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1)
131 | points = lut_mat21[mask_valid21]
132 |
133 | corrs.append(points)
134 | except:
135 | pdb.set_trace()
136 | print('..')
137 |
138 | #Plot for debug purposes
139 | # for i in range(len(corrs)):
140 | # plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8)
141 |
142 | return corrs
143 |
144 | @torch.no_grad()
145 | def get_correspondences(pts2, data, idx):
146 | device = data['image0'].device
147 | N, _, H0, W0 = data['image0'].shape
148 | _, _, H1, W1 = data['image1'].shape
149 |
150 | pts2 = pts2[None, ...]
151 |
152 | scale0 = data['scale0'][idx, None][None, ...] if 'scale0' in data else 1
153 | scale1 = data['scale1'][idx, None][None, ...] if 'scale1' in data else 1
154 |
155 | pts2 = scale1 * pts2 * 8
156 |
157 | # warp kpts bi-directionally and check reproj error
158 | nonzero_m1, pts1 = warp_kpts(pts2, data['depth1'][idx][None, ...], data['depth0'][idx][None, ...], data['T_1to0'][idx][None, ...],
159 | data['K1'][idx][None, ...], data['K0'][idx][None, ...])
160 |
161 | corrs = torch.cat([pts1[0, :] / data['scale0'][idx],
162 | pts2[0, :] / data['scale1'][idx]],dim=-1)
163 |
164 | #plot_corrs(data['image0'][idx], data['image1'][idx], corrs[:, :2], corrs[:, 2:])
165 |
166 | return corrs
167 |
168 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import numpy as np
5 | import math
6 | import cv2
7 |
8 | from models.liftfeat_wrapper import LiftFeat,MODEL_PATH
9 |
10 | import argparse
11 |
12 | parser=argparse.ArgumentParser(description='LiftFeat image matching demo script')
13 | parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
14 | parser.add_argument('--img1',type=str,default='./assert/ref.jpg',help='reference image path')
15 | parser.add_argument('--img2',type=str,default='./assert/query.jpg',help='query image path')
16 | parser.add_argument('--size',type=str,default=None,help='Resize images to w,h, None means disable resize')
17 | parser.add_argument('--use_opencv_match',action='store_true',help='Enable OpenCV match function')
18 | parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
19 | args=parser.parse_args()
20 |
21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
22 |
23 |
24 | def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
25 | # Calculate the Homography matrix
26 | H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
27 | mask = mask.flatten()
28 |
29 | # Get corners of the first image (image1)
30 | h, w = img1.shape[:2]
31 | corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)
32 |
33 | # Warp corners to the second image (image2) space
34 | warped_corners = cv2.perspectiveTransform(corners_img1, H)
35 |
36 | # Draw the warped corners in image2
37 | img2_with_corners = img2.copy()
38 |
39 | # Prepare keypoints and matches for drawMatches function
40 | keypoints1 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in ref_points]
41 | keypoints2 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in dst_points]
42 | matches = [cv2.DMatch(i,i,0) for i in range(len(mask)) if mask[i]]
43 |
44 | # Draw inlier matches
45 | img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
46 | matchColor=(0, 255, 0), flags=2)
47 |
48 | return img_matches
49 |
50 |
51 | def opencv_knn_match(descs1,descs2,kpts1,kpts2):
52 | bf = cv2.BFMatcher()
53 |
54 | matches = bf.knnMatch(descs1,descs2,k=2)
55 |
56 | good_matches = []
57 | for m, n in matches:
58 | if m.distance < 0.9 * n.distance:
59 | good_matches.append(m)
60 |
61 | mkpts1 = [];mkpts2 = []
62 |
63 | for m in good_matches:
64 | mkpt1=kpts1[m.queryIdx];mkpt2=kpts2[m.trainIdx]
65 | mkpts1.append(mkpt1);mkpts2.append(mkpt2)
66 |
67 | mkpts1 = np.array(mkpts1)
68 | mkpts2 = np.array(mkpts2)
69 |
70 | return mkpts1,mkpts2
71 |
72 |
73 | if __name__=="__main__":
74 | if args.size:
75 | print(f'resize images to {args.size}')
76 | w=int(args.size.split(',')[0])
77 | h=int(args.size.split(',')[1])
78 | dst_size=(w,h)
79 | else:
80 | print(f'disable resize')
81 |
82 | if args.use_opencv_match:
83 | print(f'Use OpenCV knnMatch')
84 | else:
85 | print(f'Use original match function')
86 |
87 | liftfeat=LiftFeat(weight=MODEL_PATH,detect_threshold=0.05)
88 |
89 | img1=cv2.imread(args.img1)
90 | img2=cv2.imread(args.img2)
91 |
92 | if args.size:
93 | img1=cv2.resize(img1,dst_size)
94 | img2=cv2.resize(img2,dst_size)
95 |
96 | if args.use_opencv_match:
97 | data1 = liftfeat.extract(img1)
98 | data2 = liftfeat.extract(img2)
99 | kpts1,descs1=data1['keypoints'].cpu().numpy(),data1['descriptors'].cpu().numpy()
100 | kpts2,descs2=data2['keypoints'].cpu().numpy(),data2['descriptors'].cpu().numpy()
101 |
102 | mkpts1,mkpts2 = opencv_knn_match(descs1,descs2,kpts1,kpts2)
103 | else:
104 | mkpts1,mkpts2=liftfeat.match_liftfeat(img1,img2)
105 |
106 |
107 | canvas=warp_corners_and_draw_matches(mkpts1,mkpts2,img1,img2)
108 |
109 | import matplotlib.pyplot as plt
110 | plt.figure(figsize=[12,12])
111 | plt.imshow(canvas[...,::-1])
112 |
113 | plt.savefig(os.path.join(os.path.dirname(__file__),'match.jpg'), dpi=300, bbox_inches='tight')
114 |
115 | plt.show()
116 |
--------------------------------------------------------------------------------
/evaluation/HPatch_evaluation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from tqdm import tqdm
4 | import torch
5 | import numpy as np
6 | import sys
7 | import poselib
8 |
9 | sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
10 |
11 | import argparse
12 | import datetime
13 |
14 | parser=argparse.ArgumentParser(description='HPatch dataset evaluation script')
15 | parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
16 | parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
17 | args=parser.parse_args()
18 |
19 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
20 |
21 | use_cuda = torch.cuda.is_available()
22 | device = torch.device("cuda" if use_cuda else "cpu")
23 |
24 | top_k = None
25 | n_i = 52
26 | n_v = 56
27 |
28 | DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/HPatch')
29 |
30 | from evaluation.eval_utils import *
31 | from models.liftfeat_wrapper import LiftFeat
32 |
33 |
34 | poselib_config = {"ransac_th": 3.0, "options": {}}
35 |
36 | class PoseLibHomographyEstimator:
37 | def __init__(self, conf):
38 | self.conf = conf
39 |
40 | def estimate(self, mkpts0,mkpts1):
41 | M, info = poselib.estimate_homography(
42 | mkpts0,
43 | mkpts1,
44 | {
45 | "max_reproj_error": self.conf["ransac_th"],
46 | **self.conf["options"],
47 | },
48 | )
49 | success = M is not None
50 | if not success:
51 | M = np.eye(3,dtype=np.float32)
52 | inl = np.zeros(mkpts0.shape[0],dtype=np.bool_)
53 | else:
54 | inl = info["inliers"]
55 |
56 | estimation = {
57 | "success": success,
58 | "M_0to1": M,
59 | "inliers": inl,
60 | }
61 |
62 | return estimation
63 |
64 |
65 | estimator=PoseLibHomographyEstimator(poselib_config)
66 |
67 |
68 | def poselib_homography_estimate(mkpts0,mkpts1):
69 | data=estimator.estimate(mkpts0,mkpts1)
70 | return data
71 |
72 |
73 | def generate_standard_image(img,target_size=(1920,1080)):
74 | sh,sw=img.shape[0],img.shape[1]
75 | rh,rw=float(target_size[1])/float(sh),float(target_size[0])/float(sw)
76 | ratio=min(rh,rw)
77 | nh,nw=int(ratio*sh),int(ratio*sw)
78 | ph,pw=target_size[1]-nh,target_size[0]-nw
79 | nimg=cv2.resize(img,(nw,nh))
80 | nimg=cv2.copyMakeBorder(nimg,0,ph,0,pw,cv2.BORDER_CONSTANT,value=(0,0,0))
81 |
82 | return nimg,ratio,ph,pw
83 |
84 |
85 | def benchmark_features(match_fn):
86 | lim = [1, 9]
87 | rng = np.arange(lim[0], lim[1] + 1)
88 |
89 | seq_names = sorted(os.listdir(DATASET_ROOT))
90 |
91 | n_feats = []
92 | n_matches = []
93 | seq_type = []
94 | i_err = {thr: 0 for thr in rng}
95 | v_err = {thr: 0 for thr in rng}
96 |
97 | i_err_homo = {thr: 0 for thr in rng}
98 | v_err_homo = {thr: 0 for thr in rng}
99 |
100 | for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)):
101 | # load reference image
102 | ref_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, "1.ppm"))
103 | ref_img_shape=ref_img.shape
104 |
105 | # load query images
106 | for im_idx in range(2, 7):
107 | # read ground-truth homography
108 | homography = np.loadtxt(os.path.join(DATASET_ROOT, seq_name, "H_1_" + str(im_idx)))
109 | query_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, f"{im_idx}.ppm"))
110 |
111 | mkpts_a,mkpts_b=match_fn(ref_img,query_img)
112 |
113 | pos_a = mkpts_a
114 | pos_a_h = np.concatenate([pos_a, np.ones([pos_a.shape[0], 1])], axis=1)
115 | pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
116 | pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
117 |
118 | pos_b = mkpts_b
119 |
120 | dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
121 |
122 | n_matches.append(pos_a.shape[0])
123 | seq_type.append(seq_name[0])
124 |
125 | if dist.shape[0] == 0:
126 | dist = np.array([float("inf")])
127 |
128 | for thr in rng:
129 | if seq_name[0] == "i":
130 | i_err[thr] += np.mean(dist <= thr)
131 | else:
132 | v_err[thr] += np.mean(dist <= thr)
133 |
134 | # estimate homography
135 | gt_homo = homography
136 | pred_homo, _ = cv2.findHomography(mkpts_a,mkpts_b,cv2.USAC_MAGSAC)
137 | if pred_homo is None:
138 | homo_dist = np.array([float("inf")])
139 | else:
140 | corners = np.array(
141 | [
142 | [0, 0],
143 | [ref_img_shape[1] - 1, 0],
144 | [0, ref_img_shape[0] - 1],
145 | [ref_img_shape[1] - 1, ref_img_shape[0] - 1],
146 | ]
147 | )
148 | real_warped_corners = homo_trans(corners, gt_homo)
149 | warped_corners = homo_trans(corners, pred_homo)
150 | homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
151 |
152 | for thr in rng:
153 | if seq_name[0] == "i":
154 | i_err_homo[thr] += np.mean(homo_dist <= thr)
155 | else:
156 | v_err_homo[thr] += np.mean(homo_dist <= thr)
157 |
158 | seq_type = np.array(seq_type)
159 | n_feats = np.array(n_feats)
160 | n_matches = np.array(n_matches)
161 |
162 | return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
163 |
164 |
165 | if __name__ == "__main__":
166 | errors = {}
167 |
168 | weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
169 | liftfeat=LiftFeat(weight=weights)
170 |
171 | errors = benchmark_features(liftfeat.match_liftfeat)
172 |
173 | i_err, v_err, i_err_hom, v_err_hom, _ = errors
174 |
175 | cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
176 |
177 | print(f'\n==={cur_time}==={args.name}===')
178 | print(f"MHA@3 MHA@5 MHA@7")
179 | for thr in [3, 5, 7]:
180 | ill_err_hom = i_err_hom[thr] / (n_i * 5)
181 | view_err_hom = v_err_hom[thr] / (n_v * 5)
182 | print(f"{ill_err_hom * 100:.2f}%-{view_err_hom * 100:.2f}%")
183 |
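184 | 
185 | # Example (illustrative): run the HPatches homography benchmark with the released
186 | # weights; the sequences are expected under data/HPatch (see DATASET_ROOT above).
187 | #   python evaluation/HPatch_evaluation.py --name LiftFeat --gpu 0
188 | # MHA@{3,5,7} is reported for illumination ('i_*') and viewpoint ('v_*') sequences.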
--------------------------------------------------------------------------------
/evaluation/MegaDepth1500_evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | from pathlib import Path
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | import tqdm
9 | from copy import deepcopy
10 | from torchvision.transforms import ToTensor
11 | import torch.nn.functional as F
12 | import json
13 |
14 | import scipy.io as scio
15 | import poselib
16 |
17 | import argparse
18 | import datetime
19 |
20 | parser=argparse.ArgumentParser(description='MegaDepth dataset evaluation script')
21 | parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
22 | parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
23 | args=parser.parse_args()
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
26 |
27 | sys.path.append(os.path.join(os.path.dirname(__file__),'../'))
28 | from models.liftfeat_wrapper import LiftFeat
29 | from evaluation.eval_utils import *
30 |
31 | from torch.utils.data import Dataset,DataLoader
32 |
33 | use_cuda = torch.cuda.is_available()
34 | device = "cuda" if use_cuda else "cpu"
35 |
36 | DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/megadepth_test_1500')
37 | DATASET_JSON = os.path.join(os.path.dirname(__file__),'../data/megadepth_1500.json')
38 |
39 | class MegaDepth1500(Dataset):
40 | """
41 | Streamlined MegaDepth-1500 dataloader. The camera poses & metadata are stored in a preformatted JSON file
42 | to simplify downloading the dataset and to keep the setup as simple as possible.
43 | """
44 | def __init__(self, json_file, root_dir):
45 | # Load the info & calibration from the JSON
46 | with open(json_file, 'r') as f:
47 | self.data = json.load(f)
48 |
49 | self.root_dir = root_dir
50 |
51 | if not os.path.exists(self.root_dir):
52 | raise RuntimeError(
53 | f"Dataset {self.root_dir} does not exist! \n \
54 | > If you didn't download the dataset, use the downloader tool: python3 -m modules.dataset.download -h")
55 |
56 | def __len__(self):
57 | return len(self.data)
58 |
59 | def __getitem__(self, idx):
60 | data = deepcopy(self.data[idx])
61 |
62 | h1, w1 = data['size0_hw']
63 | h2, w2 = data['size1_hw']
64 |
65 | # Resize the images to max_dim = 1200, as described in the paper, and adjust each side to be divisible by 32,
66 | # following the protocol of LoFTR's dataloader (the intrinsics are corrected accordingly).
67 | # To adapt this to a different resolution, you would need to re-scale the intrinsics below.
68 | image0 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][0]}"),(w1, h1))
69 |
70 | image1 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][1]}"),(w2, h2))
71 |
72 | data['image0'] = torch.tensor(image0.astype(np.float32)/255).permute(2,0,1)
73 | data['image1'] = torch.tensor(image1.astype(np.float32)/255).permute(2,0,1)
74 |
75 | for k,v in data.items():
76 | if k not in ('dataset_name', 'scene_id', 'pair_id', 'pair_names', 'size0_hw', 'size1_hw', 'image0', 'image1'):
77 | data[k] = torch.tensor(np.array(v, dtype=np.float32))
78 |
79 | return data
80 |
81 | if __name__ == "__main__":
82 | weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
83 | liftfeat=LiftFeat(weight=weights)
84 |
85 | dataset = MegaDepth1500(json_file = DATASET_JSON, root_dir = DATASET_ROOT)
86 |
87 | loader = DataLoader(dataset, batch_size=1, shuffle=False)
88 |
89 | metrics = {}
90 | R_errs = []
91 | t_errs = []
92 | inliers = []
93 |
94 | results=[]
95 |
96 | cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
97 |
98 | for d in tqdm.tqdm(loader, desc="processing"):
99 | error_infos = compute_pose_error(liftfeat.match_liftfeat,d)
100 | results.append(error_infos)
101 |
102 | print(f'\n==={cur_time}==={args.name}===')
103 | d_err_auc,errors=compute_maa(results)
104 | for s_k,s_v in d_err_auc.items():
105 | print(f'{s_k}: {s_v*100}')
106 |
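107 | 
108 | # Example (illustrative): evaluate relative pose on MegaDepth-1500; the images are
109 | # expected under data/megadepth_test_1500 and the pair metadata in
110 | # data/megadepth_1500.json (see DATASET_ROOT / DATASET_JSON above).
111 | #   python evaluation/MegaDepth1500_evaluation.py --name LiftFeat --gpu 0
112 | # Pose AUC at 5/10/20 degrees is printed via compute_maa.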
--------------------------------------------------------------------------------
/evaluation/__pycache__/eval_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/evaluation/__pycache__/eval_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/evaluation/eval_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import poselib
4 |
5 |
6 | def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
7 | # angle error between 2 vectors
8 | t_gt = T_0to1[:3, 3]
9 | n = np.linalg.norm(t) * np.linalg.norm(t_gt)
10 | t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
11 | t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
12 | if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
13 | t_err = 0
14 |
15 | # angle error between 2 rotation matrices
16 | R_gt = T_0to1[:3, :3]
17 | cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
18 | cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
19 | R_err = np.rad2deg(np.abs(np.arccos(cos)))
20 |
21 | return t_err, R_err
22 |
23 | def intrinsics_to_camera(K):
24 | px, py = K[0, 2], K[1, 2]
25 | fx, fy = K[0, 0], K[1, 1]
26 | return {
27 | "model": "PINHOLE",
28 | "width": int(2 * px),
29 | "height": int(2 * py),
30 | "params": [fx, fy, px, py],
31 | }
32 |
33 |
34 | def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
35 | M, info = poselib.estimate_relative_pose(
36 | kpts0, kpts1,
37 | intrinsics_to_camera(K0),
38 | intrinsics_to_camera(K1),
39 | {"max_epipolar_error": thresh,
40 | "success_prob": conf,
41 | "min_iterations": 20,
42 | "max_iterations": 1_000},
43 | )
44 |
45 | R, t, inl = M.R, M.t, info["inliers"]
46 | inl = np.array(inl)
47 | ret = (R, t, inl)
48 |
49 | return ret
50 |
51 | def tensor2bgr(t):
52 | return (t.cpu()[0].permute(1,2,0).numpy()*255).astype(np.uint8)
53 |
54 | def compute_pose_error(match_fn,data):
55 | result = {}
56 |
57 | with torch.no_grad():
58 | mkpts0,mkpts1=match_fn(tensor2bgr(data["image0"]),tensor2bgr(data["image1"]))
59 |
60 | mkpts0=mkpts0 * data["scale0"].numpy()
61 | mkpts1=mkpts1 * data["scale1"].numpy()
62 |
63 | K0, K1 = data["K0"][0].numpy(), data["K1"][0].numpy()
64 | T_0to1 = data["T_0to1"][0].numpy()
65 | T_1to0 = data["T_1to0"][0].numpy()
66 |
67 | result={}
68 | conf = 0.99999
69 |
70 | ret = estimate_pose(mkpts0,mkpts1,K0,K1,4.0,conf)
71 | if ret is not None:
72 | R, t, inliers = ret
73 | t_err, R_err = relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0)
74 | result['R_err'] = R_err
75 | result['t_err'] = t_err
76 |
77 | return result
78 |
79 |
80 | def error_auc(errors, thresholds=[5, 10, 20]):
81 | """
82 | Args:
83 | errors (list): [N,]
84 | thresholds (list)
85 | """
86 | errors = [0] + sorted(list(errors))
87 | recall = list(np.linspace(0, 1, len(errors)))
88 |
89 | aucs = []
90 |
91 | for thr in thresholds:
92 | last_index = np.searchsorted(errors, thr)
93 | y = recall[:last_index] + [recall[last_index-1]]
94 | x = errors[:last_index] + [thr]
95 | aucs.append(np.trapz(y, x) / thr)
96 |
97 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
98 |
99 | def compute_maa(pairs, thresholds=[5, 10, 20]):
100 | # print("auc / mAcc on %d pairs" % (len(pairs)))
101 | errors = []
102 |
103 | for p in pairs:
104 | et = p['t_err']
105 | er = p['R_err']
106 | errors.append(max(et, er))
107 |
108 | d_err_auc = error_auc(errors)
109 |
110 | # for k,v in d_err_auc.items():
111 | # print(k, ': ', '%.1f'%(v*100))
112 |
113 | errors = np.array(errors)
114 |
115 | for t in thresholds:
116 | acc = (errors <= t).sum() / len(errors)
117 | # print("mAcc@%d: %.1f "%(t, acc*100))
118 |
119 | return d_err_auc,errors
120 |
121 | def homo_trans(coord, H):
122 | kpt_num = coord.shape[0]
123 | homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1)
124 | proj_coord = np.matmul(H, homo_coord.T).T
125 | proj_coord = proj_coord / proj_coord[:, 2][..., None]
126 | proj_coord = proj_coord[:, 0:2]
127 | return proj_coord
128 |
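129 | 
130 | if __name__ == "__main__":
131 |     # Minimal sanity check (illustrative only): compute_maa reduces each pair to
132 |     # max(R_err, t_err) in degrees and error_auc reports the recall-vs-threshold
133 |     # AUC at 5/10/20 degrees.
134 |     fake_pairs = [{'R_err': r, 't_err': t} for r, t in [(1.0, 2.0), (3.0, 8.0), (15.0, 4.0)]]
135 |     d_err_auc, errors = compute_maa(fake_pairs)
136 |     print(d_err_auc)  # {'auc@5': ..., 'auc@10': ..., 'auc@20': ...}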
--------------------------------------------------------------------------------
/loss/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/loss/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import time
7 |
8 |
9 | def dual_softmax_loss(X, Y, temp = 0.2):
10 | if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
11 | raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')
12 |
13 | dist_mat = (X @ Y.t()) * temp
14 | conf_matrix12 = F.log_softmax(dist_mat, dim=1)
15 | conf_matrix21 = F.log_softmax(dist_mat.t(), dim=1)
16 |
17 | with torch.no_grad():
18 | conf12 = torch.exp( conf_matrix12 ).max(dim=-1)[0]
19 | conf21 = torch.exp( conf_matrix21 ).max(dim=-1)[0]
20 | conf = conf12 * conf21
21 |
22 | target = torch.arange(len(X), device = X.device)
23 |
24 | loss = F.nll_loss(conf_matrix12, target) + \
25 | F.nll_loss(conf_matrix21, target)
26 |
27 | return loss, conf
28 |
29 |
30 | class LiftFeatLoss(nn.Module):
31 | def __init__(self,device,lam_descs=1,lam_fb_descs=1,lam_kpts=1,lam_heatmap=1,lam_normals=1,lam_coordinates=1,lam_fb_coordinates=1,depth_spvs=False):
32 | super().__init__()
33 |
34 | # loss parameters
35 | self.lam_descs=lam_descs
36 | self.lam_fb_descs=lam_fb_descs
37 | self.lam_kpts=lam_kpts
38 | self.lam_heatmap=lam_heatmap
39 | self.lam_normals=lam_normals
40 | self.lam_coordinates=lam_coordinates
41 | self.lam_fb_coordinates=lam_fb_coordinates
42 | self.depth_spvs=depth_spvs
43 | self.running_descs_loss=0
44 | self.running_kpts_loss=0
45 | self.running_heatmaps_loss=0
46 | self.loss_descs=0
47 | self.loss_fb_descs=0
48 | self.loss_kpts=0
49 | self.loss_heatmaps=0
50 | self.loss_normals=0
51 | self.loss_coordinates=0
52 | self.loss_fb_coordinates=0
53 | self.acc_coarse=0
54 | self.acc_fb_coarse=0
55 | self.acc_kpt=0
56 | self.acc_coordinates=0
57 | self.acc_fb_coordinates=0
58 |
59 | # device
60 | self.dev=device
61 |
62 |
63 | def check_accuracy(self,m1,m2,pts1=None,pts2=None,plot=False):
64 | with torch.no_grad():
65 | #dist_mat = torch.cdist(X,Y)
66 | dist_mat = m1 @ m2.t()
67 | nn = torch.argmax(dist_mat, dim=1)
68 | #nn = torch.argmin(dist_mat, dim=1)
69 | correct = nn == torch.arange(len(m1), device = m1.device)
70 |
71 | if pts1 is not None and plot:
72 | import matplotlib.pyplot as plt
73 | canvas = torch.zeros((60, 80),device=m1.device)
74 | pts1 = pts1[~correct]
75 | canvas[pts1[:,1].long(), pts1[:,0].long()] = 1
76 | canvas = canvas.cpu().numpy()
77 | plt.imshow(canvas), plt.show()
78 |
79 | acc = correct.sum().item() / len(m1)
80 | return acc
81 |
82 | def compute_descriptors_loss(self,descs1,descs2,pts):
83 | loss=[]
84 | acc=0
85 | B,_,H,W=descs1.shape
86 | conf_list=[]
87 |
88 | for b in range(B):
89 | pts1,pts2=pts[b][:,:2],pts[b][:,2:]
90 | m1=descs1[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0)
91 | m2=descs2[b,:,pts2[:,1].long(),pts2[:,0].long()].permute(1,0)
92 |
93 | loss_per,conf_per=dual_softmax_loss(m1,m2)
94 | loss.append(loss_per.unsqueeze(0))
95 | conf_list.append(conf_per)
96 |
97 | acc_coarse_per=self.check_accuracy(m1,m2)
98 | acc += acc_coarse_per
99 |
100 | loss=torch.cat(loss,dim=-1).mean()
101 | acc /= B
102 | return loss,acc,conf_list
103 |
104 |
105 | def alike_distill_loss(self,kpts,alike_kpts):
106 | C, H, W = kpts.shape
107 | kpts = kpts.permute(1,2,0)
108 | # get ALike keypoints
109 | with torch.no_grad():
110 | labels = torch.ones((H, W), dtype = torch.long, device = kpts.device) * 64 # -> Default is non-keypoint (bin 64)
111 | offsets = (((alike_kpts/8) - (alike_kpts/8).long())*8).long()
112 | offsets = offsets[:, 0] + 8*offsets[:, 1] # Linear IDX
113 | labels[(alike_kpts[:,1]/8).long(), (alike_kpts[:,0]/8).long()] = offsets
114 |
115 | kpts = kpts.view(-1,C)
116 | labels = labels.view(-1)
117 |
118 | mask = labels < 64
119 | idxs_pos = mask.nonzero().flatten()
120 | idxs_neg = (~mask).nonzero().flatten()
121 | perm = torch.randperm(idxs_neg.size(0))[:len(idxs_pos)//32]
122 | idxs_neg = idxs_neg[perm]
123 | idxs = torch.cat([idxs_pos, idxs_neg])
124 |
125 | kpts = kpts[idxs]
126 | labels = labels[idxs]
127 |
128 | with torch.no_grad():
129 | predicted = kpts.max(dim=-1)[1]
130 | acc = (labels == predicted)
131 | acc = acc.sum() / len(acc)
132 |
133 | kpts = F.log_softmax(kpts,dim=-1)
134 | loss = F.nll_loss(kpts, labels, reduction = 'mean')
135 |
136 | return loss, acc
137 |
138 |
139 | def compute_keypoints_loss(self,kpts1,kpts2,alike_kpts1,alike_kpts2):
140 | loss=[]
141 | acc=0
142 | B,_,H,W=kpts1.shape
143 |
144 | for b in range(B):
145 | loss_per1,acc_per1=self.alike_distill_loss(kpts1[b],alike_kpts1[b])
146 | loss_per2,acc_per2=self.alike_distill_loss(kpts2[b],alike_kpts2[b])
147 | loss_per=(loss_per1+loss_per2)
148 | acc_per=(acc_per1+acc_per2)/2
149 | loss.append(loss_per.unsqueeze(0))
150 | acc += acc_per
151 |
152 | loss=torch.cat(loss,dim=-1).mean()
153 | acc /= B
154 | return loss,acc
155 |
156 |
157 | def compute_heatmaps_loss(self,heatmaps1,heatmaps2,pts,conf_list):
158 | loss=[]
159 | B,_,H,W=heatmaps1.shape
160 |
161 | for b in range(B):
162 | pts1,pts2=pts[b][:,:2],pts[b][:,2:]
163 | h1=heatmaps1[b,0,pts1[:,1].long(),pts1[:,0].long()]
164 | h2=heatmaps2[b,0,pts2[:,1].long(),pts2[:,0].long()]
165 |
166 | conf=conf_list[b]
167 | loss_per1=F.l1_loss(h1,conf)
168 | loss_per2=F.l1_loss(h2,conf)
169 | loss_per=(loss_per1+loss_per2)
170 | loss.append(loss_per.unsqueeze(0))
171 |
172 | loss=torch.cat(loss,dim=-1).mean()
173 | return loss
174 |
175 |
176 | def normal_loss(self,normal,target_normal):
177 | # import pdb;pdb.set_trace()
178 | normal = normal.permute(1, 2, 0)
179 | target_normal = target_normal.permute(1,2,0)
180 | # loss = F.l1_loss(d_feat, depth_anything_normal_feat)
181 | dot = torch.cosine_similarity(normal, target_normal, dim=2)
182 | valid_mask = target_normal[:, :, 0].float() \
183 | * (dot.detach() < 0.999).float() \
184 | * (dot.detach() > -0.999).float()
185 | valid_mask = valid_mask > 0.0
186 | al = torch.acos(dot[valid_mask])
187 | loss = torch.mean(al)
188 | return loss
189 |
190 |
191 | def compute_normals_loss(self,normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size):
192 | loss=[]
193 |
194 | # import pdb;pdb.set_trace()
195 |
196 | # only MegaDepth image need depth-normal
197 | normals1=normals1[coco_batch_size:,...]
198 | normals2=normals2[coco_batch_size:,...]
199 | for b in range(len(DA_normals1)):
200 | normal1,normal2=normals1[b],normals2[b]
201 | loss_per1=self.normal_loss(normal1,DA_normals1[b].permute(2,0,1))
202 | loss_per2=self.normal_loss(normal2,DA_normals2[b].permute(2,0,1))
203 | loss_per=(loss_per1+loss_per2)
204 | loss.append(loss_per.unsqueeze(0))
205 |
206 | loss=torch.cat(loss,dim=-1).mean()
207 | return loss
208 |
209 |
210 | def coordinate_loss(self,coordinate,conf,pts1):
211 | with torch.no_grad():
212 | coordinate_detached = pts1 * 8
213 | offset_detached = (coordinate_detached/8) - (coordinate_detached/8).long()
214 | offset_detached = (offset_detached * 8).long()
215 | label = offset_detached[:, 0] + 8*offset_detached[:, 1]
216 |
217 | #pdb.set_trace()
218 | coordinate_log = F.log_softmax(coordinate, dim=-1)
219 |
220 | predicted = coordinate.max(dim=-1)[1]
221 | acc = (label == predicted)
222 | acc = acc[conf > 0.1]
223 | acc = acc.sum() / len(acc)
224 |
225 | loss = F.nll_loss(coordinate_log, label, reduction = 'none')
226 |
227 | #Weight loss by confidence, giving more emphasis on reliable matches
228 | conf = conf / conf.sum()
229 | loss = (loss * conf).sum()
230 |
231 | return loss*2., acc
232 |
233 | def compute_coordinates_loss(self,coordinates,pts,conf_list):
234 | loss=[]
235 | acc=0
236 | B,_,H,W=coordinates.shape
237 |
238 | for b in range(B):
239 | pts1,pts2=pts[b][:,:2],pts[b][:,2:]
240 | coordinate=coordinates[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0)
241 | conf=conf_list[b]
242 |
243 | loss_per,acc_per=self.coordinate_loss(coordinate,conf,pts1)
244 | loss.append(loss_per.unsqueeze(0))
245 | acc += acc_per
246 |
247 | loss=torch.cat(loss,dim=-1).mean()
248 | acc /= B
249 |
250 | return loss,acc
251 |
252 |
253 | def forward(self,
254 | descs1,fb_descs1,kpts1,normals1,
255 | descs2,fb_descs2,kpts2,normals2,
256 | pts,coordinates,fb_coordinates,
257 | alike_kpts1,alike_kpts2,
258 | DA_normals1,DA_normals2,
259 | megadepth_batch_size,coco_batch_size
260 | ):
261 | # import pdb;pdb.set_trace()
262 | self.loss_descs,self.acc_coarse,conf_list=self.compute_descriptors_loss(descs1,descs2,pts)
263 | self.loss_fb_descs,self.acc_fb_coarse,fb_conf_list=self.compute_descriptors_loss(fb_descs1,fb_descs2,pts)
264 |
265 | # start=time.perf_counter()
266 | self.loss_kpts,self.acc_kpt=self.compute_keypoints_loss(kpts1,kpts2,alike_kpts1,alike_kpts2)
267 | # end=time.perf_counter()
268 | # print(f"kpts loss cost {end-start} seconds")
269 |
270 | # start=time.perf_counter()
271 | self.loss_normals=self.compute_normals_loss(normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size)
272 | # end=time.perf_counter()
273 | # print(f"normal loss cost {end-start} seconds")
274 |
275 | self.loss_coordinates,self.acc_coordinates=self.compute_coordinates_loss(coordinates,pts,conf_list)
276 | self.loss_fb_coordinates,self.acc_fb_coordinates=self.compute_coordinates_loss(fb_coordinates,pts,fb_conf_list)
277 |
278 | return {
279 | 'loss_descs':self.lam_descs*self.loss_descs,
280 | 'acc_coarse':self.acc_coarse,
281 | 'loss_coordinates':self.lam_coordinates*self.loss_coordinates,
282 | 'acc_coordinates':self.acc_coordinates,
283 | 'loss_fb_descs':self.lam_fb_descs*self.loss_fb_descs,
284 | 'acc_fb_coarse':self.acc_fb_coarse,
285 | 'loss_fb_coordinates':self.lam_fb_coordinates*self.loss_fb_coordinates,
286 | 'acc_fb_coordinates':self.acc_fb_coordinates,
287 | 'loss_kpts':self.lam_kpts*self.loss_kpts,
288 | 'acc_kpt':self.acc_kpt,
289 | 'loss_normals':self.lam_normals*self.loss_normals,
290 | }
291 |
292 |
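293 | 
294 | if __name__ == "__main__":
295 |     # Tiny sanity check (illustrative only): run dual_softmax_loss on a set of
296 |     # L2-normalized descriptors matched to themselves and print the symmetric
297 |     # NLL loss together with the mean mutual soft-assignment confidence.
298 |     X = F.normalize(torch.randn(128, 64), dim=-1)
299 |     loss, conf = dual_softmax_loss(X, X)
300 |     print(f"loss={loss.item():.3f} mean_conf={conf.mean().item():.4f}")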
--------------------------------------------------------------------------------
/models/__pycache__/interpolator.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/interpolator.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/interpolator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/interpolator.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/liftfeat_wrapper.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/liftfeat_wrapper.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/liftfeat_wrapper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/liftfeat_wrapper.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/models/interpolator.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 |
4 | This script interpolates coarse LiftFeat descriptors at sparse keypoint locations
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | class InterpolateSparse2d(nn.Module):
12 | """ Efficiently interpolate tensor at given sparse 2D positions. """
13 | def __init__(self, mode = 'bicubic', align_corners = False):
14 | super().__init__()
15 | self.mode = mode
16 | self.align_corners = align_corners
17 |
18 | def normgrid(self, x, H, W):
19 | """ Normalize coords to [-1,1]. """
20 | return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
21 |
22 | def forward(self, x, pos, H, W):
23 | """
24 | Input
25 | x: [B, C, H, W] feature tensor
26 | pos: [B, N, 2] tensor of positions
27 | H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
28 |
29 | Returns
30 | [B, N, C] sampled channels at 2d positions
31 | """
32 | grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
33 | x = F.grid_sample(x, grid, mode = self.mode , align_corners = False)
34 | return x.permute(0,2,3,1).squeeze(-2)
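35 | 
36 | 
37 | if __name__ == "__main__":
38 |     # Minimal usage sketch (illustrative only): sample a dense 1/8-resolution
39 |     # descriptor map at two keypoint locations given in original image coordinates.
40 |     H, W = 480, 640                                        # original image resolution
41 |     feats = torch.randn(1, 64, H // 8, W // 8)             # [B, C, H/8, W/8] dense map
42 |     kpts = torch.tensor([[[32.0, 48.0], [300.0, 200.0]]])  # [B, N, 2] (x, y) positions
43 |     sampler = InterpolateSparse2d('bicubic')
44 |     print(sampler(feats, kpts, H, W).shape)                # torch.Size([1, 2, 64])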
--------------------------------------------------------------------------------
/models/liftfeat_wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import numpy as np
5 | import math
6 | import cv2
7 |
8 | from models.model import LiftFeatSPModel
9 | from models.interpolator import InterpolateSparse2d
10 | from utils.config import featureboost_config
11 |
12 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
13 |
14 | MODEL_PATH = os.path.join(os.path.dirname(__file__), "../weights/LiftFeat.pth")
15 |
16 |
17 | class NonMaxSuppression(torch.nn.Module):
18 | def __init__(self, rep_thr=0.1, top_k=4096):
19 | super(NonMaxSuppression, self).__init__()
20 | self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
21 | self.rep_thr = rep_thr
22 | self.top_k = top_k
23 |
24 | def NMS(self, x, threshold=0.05, kernel_size=5):
25 | B, _, H, W = x.shape
26 | pad = kernel_size // 2
27 | local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
28 | pos = (x == local_max) & (x > threshold)
29 | pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
30 |
31 | pad_val = max([len(x) for x in pos_batched])
32 | pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
33 |
34 | # Pad kpts and build (B, N, 2) tensor
35 | for b in range(len(pos_batched)):
36 | pos[b, : len(pos_batched[b]), :] = pos_batched[b]
37 |
38 | return pos
39 |
40 | def forward(self, score):
41 | pos = self.NMS(score, self.rep_thr)
42 |
43 | return pos
44 |
45 |
46 | def load_model(model, weight_path):
47 | pretrained_weights = torch.load(weight_path, map_location="cpu")
48 |
49 | model_keys = set(model.state_dict().keys())
50 | pretrained_keys = set(pretrained_weights.keys())
51 |
52 | missing_keys = model_keys - pretrained_keys
53 | unexpected_keys = pretrained_keys - model_keys
54 |
55 | # if missing_keys:
56 | # print("Missing keys in pretrained weights:", missing_keys)
57 | # else:
58 | # print("No missing keys in pretrained weights.")
59 |
60 | # if unexpected_keys:
61 | # print("Unexpected keys in pretrained weights:", unexpected_keys)
62 | # else:
63 | # print("No unexpected keys in pretrained weights.")
64 |
65 | if not missing_keys and not unexpected_keys:
66 | model.load_state_dict(pretrained_weights)
67 | print("load weight successfully.")
68 | else:
69 | model.load_state_dict(pretrained_weights, strict=False)
70 | # print("There were issues with the keys.")
71 | return model
72 |
73 |
74 | import torch.nn as nn
75 |
76 |
77 | class LiftFeat(nn.Module):
78 | def __init__(self, weight=MODEL_PATH, top_k=4096, detect_threshold=0.1):
79 | super().__init__()
80 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81 | self.net = LiftFeatSPModel(featureboost_config).to(self.device).eval()
82 | self.top_k = top_k
83 | self.sampler = InterpolateSparse2d("bicubic")
84 | self.net = load_model(self.net, weight)
85 | self.detector = NonMaxSuppression(rep_thr=detect_threshold)
86 | self.net = self.net.to(self.device)
87 | self.detector = self.detector.to(self.device)
88 | self.sampler = self.sampler.to(self.device)
89 |
90 | def image_preprocess(self, image: np.ndarray):
91 | H, W, C = image.shape[0], image.shape[1], image.shape[2]
92 |
93 | _H = math.ceil(H / 32) * 32
94 | _W = math.ceil(W / 32) * 32
95 |
96 | pad_h = _H - H
97 | pad_w = _W - W
98 |
99 | image = cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, None, (0, 0, 0))
100 |
101 | pad_info = [0, pad_h, 0, pad_w]
102 |
103 | if len(image.shape) == 3:
104 | image = image[None, ...]
105 |
106 | image = torch.tensor(image).permute(0, 3, 1, 2) / 255
107 | image = image.to(device)
108 |
109 | return image, pad_info
110 |
111 | @torch.inference_mode()
112 | def extract(self, image: np.ndarray):
113 | image, pad_info = self.image_preprocess(image)
114 | B, _, _H1, _W1 = image.shape
115 |
116 | M1, K1, D1 = self.net.forward1(image)
117 | refine_M = self.net.forward2(M1, K1, D1)
118 |
119 | refine_M = refine_M.reshape(M1.shape[0], M1.shape[2], M1.shape[3], -1).permute(0, 3, 1, 2)
120 | refine_M = torch.nn.functional.normalize(refine_M, 2, dim=1)
121 |
122 | descs_map = refine_M
123 |
124 | scores = torch.softmax(K1, dim=1)[:, :64]
125 | heatmap = scores.permute(0, 2, 3, 1).reshape(scores.shape[0], scores.shape[2], scores.shape[3], 8, 8)
126 | heatmap = heatmap.permute(0, 1, 3, 2, 4).reshape(scores.shape[0], 1, scores.shape[2] * 8, scores.shape[3] * 8)
127 |
128 | pos = self.detector(heatmap)
129 | kpts = pos.squeeze(0)
130 | mask_w = kpts[..., 0] < (_W1 - pad_info[-1])
131 | kpts = kpts[mask_w]
132 | mask_h = kpts[..., 1] < (_H1 - pad_info[1])
133 | kpts = kpts[mask_h]
134 |
135 | scores = self.sampler(heatmap, kpts.unsqueeze(0), _H1, _W1)
136 | scores = scores.squeeze(0).reshape(-1)
137 | descs = self.sampler(descs_map, kpts.unsqueeze(0), _H1, _W1)
138 | descs = torch.nn.functional.normalize(descs, p=2, dim=1)
139 | descs = descs.squeeze(0)
140 |
141 | return {"descriptors": descs, "keypoints": kpts, "scores": scores}
142 |
143 | def match_liftfeat(self, img1, img2, min_cossim=-1):
144 | # import pdb;pdb.set_trace()
145 | data1 = self.extract(img1)
146 | data2 = self.extract(img2)
147 |
148 | kpts1, feats1 = data1["keypoints"], data1["descriptors"]
149 | kpts2, feats2 = data2["keypoints"], data2["descriptors"]
150 |
151 | cossim = feats1 @ feats2.t()
152 | cossim_t = feats2 @ feats1.t()
153 |
154 | _, match12 = cossim.max(dim=1)
155 | _, match21 = cossim_t.max(dim=1)
156 |
157 | idx0 = torch.arange(len(match12), device=match12.device)
158 | mutual = match21[match12] == idx0
159 |
160 | if min_cossim > 0:
161 | cossim, _ = cossim.max(dim=1)
162 | good = cossim > min_cossim
163 | idx0 = idx0[mutual & good]
164 | idx1 = match12[mutual & good]
165 | else:
166 | idx0 = idx0[mutual]
167 | idx1 = match12[mutual]
168 |
169 | mkpts1, mkpts2 = kpts1[idx0], kpts2[idx1]
170 | mkpts1, mkpts2 = mkpts1.cpu().numpy(), mkpts2.cpu().numpy()
171 |
172 | return mkpts1, mkpts2
173 |
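174 | 
175 | if __name__ == "__main__":
176 |     # Minimal usage sketch (illustrative only); run from the repository root so
177 |     # that the relative imports and the sample images under assert/ resolve.
178 |     img_a = cv2.imread(os.path.join(os.path.dirname(__file__), "../assert/ref.jpg"))
179 |     img_b = cv2.imread(os.path.join(os.path.dirname(__file__), "../assert/query.jpg"))
180 |     liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05)
181 |     feats = liftfeat.extract(img_a)
182 |     print(feats["keypoints"].shape, feats["descriptors"].shape, feats["scores"].shape)
183 |     mkpts_a, mkpts_b = liftfeat.match_liftfeat(img_a, img_b)
184 |     print(f"{len(mkpts_a)} mutual nearest-neighbour matches")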
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
4 | """
5 |
6 | import numpy as np
7 | import os
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | import tqdm
13 | import math
14 | import cv2
15 |
16 | import sys
17 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18 | from utils.featurebooster import FeatureBooster
19 | from utils.config import featureboost_config
20 |
21 | # from models.model_dfb import LiftFeatModel
22 | # from models.interpolator import InterpolateSparse2d
23 | # from third_party.config import featureboost_config
24 |
25 | """
26 | foundational functions
27 | """
28 | def simple_nms(scores, radius):
29 | """Perform non maximum suppression on the heatmap using max-pooling.
30 | This method does not suppress contiguous points that have the same score.
31 | Args:
32 | scores: the score heatmap of size `(B, H, W)`.
33 | radius: an integer scalar, the radius of the NMS window.
34 | """
35 |
36 | def max_pool(x):
37 | return torch.nn.functional.max_pool2d(
38 | x, kernel_size=radius * 2 + 1, stride=1, padding=radius
39 | )
40 |
41 | zeros = torch.zeros_like(scores)
42 | max_mask = scores == max_pool(scores)
43 | for _ in range(2):
44 | supp_mask = max_pool(max_mask.float()) > 0
45 | supp_scores = torch.where(supp_mask, zeros, scores)
46 | new_max_mask = supp_scores == max_pool(supp_scores)
47 | max_mask = max_mask | (new_max_mask & (~supp_mask))
48 | return torch.where(max_mask, scores, zeros)
49 |
50 |
51 | def top_k_keypoints(keypoints, scores, k):
52 | if k >= len(keypoints):
53 | return keypoints, scores
54 | scores, indices = torch.topk(scores, k, dim=0, sorted=True)
55 | return keypoints[indices], scores
56 |
57 |
58 | def sample_k_keypoints(keypoints, scores, k):
59 | if k >= len(keypoints):
60 | return keypoints, scores
61 | indices = torch.multinomial(scores, k, replacement=False)
62 | return keypoints[indices], scores[indices]
63 |
64 |
65 | def soft_argmax_refinement(keypoints, scores, radius: int):
66 | width = 2 * radius + 1
67 | sum_ = torch.nn.functional.avg_pool2d(
68 | scores[:, None], width, 1, radius, divisor_override=1
69 | )
70 | ar = torch.arange(-radius, radius + 1).to(scores)
71 | kernel_x = ar[None].expand(width, -1)[None, None]
72 | dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius)
73 | dy = torch.nn.functional.conv2d(
74 | scores[:, None], kernel_x.transpose(2, 3), padding=radius
75 | )
76 | dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None]
77 | refined_keypoints = []
78 | for i, kpts in enumerate(keypoints):
79 | delta = dydx[i][tuple(kpts.t())]
80 | refined_keypoints.append(kpts.float() + delta)
81 | return refined_keypoints
82 |
83 |
84 | # Legacy (broken) sampling of the descriptors
85 | def sample_descriptors(keypoints, descriptors, s):
86 | b, c, h, w = descriptors.shape
87 | keypoints = keypoints - s / 2 + 0.5
88 | keypoints /= torch.tensor(
89 | [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
90 | ).to(
91 | keypoints
92 | )[None]
93 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
94 | args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
95 | descriptors = torch.nn.functional.grid_sample(
96 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
97 | )
98 | descriptors = torch.nn.functional.normalize(
99 | descriptors.reshape(b, c, -1), p=2, dim=1
100 | )
101 | return descriptors
102 |
103 |
104 | # The original keypoint sampling is incorrect. We patch it here but
105 | # keep the original one above for legacy.
106 | def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
107 | """Interpolate descriptors at keypoint locations"""
108 | b, c, h, w = descriptors.shape
109 | keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
110 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
111 | descriptors = torch.nn.functional.grid_sample(
112 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
113 | )
114 | descriptors = torch.nn.functional.normalize(
115 | descriptors.reshape(b, c, -1), p=2, dim=1
116 | )
117 | return descriptors
118 |
119 |
120 | class UpsampleLayer(nn.Module):
121 | def __init__(self, in_channels):
122 | super().__init__()
123 | # Feature extraction layer: halve the channel count while refining the features
124 | self.conv = nn.Conv2d(in_channels, in_channels//2, kernel_size=3, stride=1, padding=1)
125 | # BatchNorm layer
126 | self.bn = nn.BatchNorm2d(in_channels//2)
127 | # LeakyReLU activation
128 | self.leaky_relu = nn.LeakyReLU(0.1)
129 |
130 | def forward(self, x):
131 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
132 | x = self.leaky_relu(self.bn(self.conv(x)))
133 |
134 | return x
135 |
136 |
137 | class KeypointHead(nn.Module):
138 | def __init__(self,in_channels,out_channels):
139 | super().__init__()
140 | self.layer1=BaseLayer(in_channels,32)
141 | self.layer2=BaseLayer(32,32)
142 | self.layer3=BaseLayer(32,64)
143 | self.layer4=BaseLayer(64,64)
144 | self.layer5=BaseLayer(64,128)
145 |
146 | self.conv=nn.Conv2d(128,out_channels,kernel_size=3,stride=1,padding=1)
147 | self.bn=nn.BatchNorm2d(out_channels)
148 |
149 | def forward(self,x):
150 | x=self.layer1(x)
151 | x=self.layer2(x)
152 | x=self.layer3(x)
153 | x=self.layer4(x)
154 | x=self.layer5(x)
155 | x=self.bn(self.conv(x))
156 | return x
157 |
158 |
159 | class DescriptorHead(nn.Module):
160 | def __init__(self,in_channels,out_channels):
161 | super().__init__()
162 | self.layer=nn.Sequential(
163 | BaseLayer(in_channels,32),
164 | BaseLayer(32,32,activation=False),
165 | BaseLayer(32,64,activation=False),
166 | BaseLayer(64,out_channels,activation=False)
167 | )
168 |
169 | def forward(self,x):
170 | x=self.layer(x)
171 | # x=nn.functional.softmax(x,dim=1)
172 | return x
173 |
174 |
175 | class HeatmapHead(nn.Module):
176 | def __init__(self,in_channels,mid_channels,out_channels):
177 | super().__init__()
178 | self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
179 | self.bnHa = nn.BatchNorm2d(mid_channels)
180 | self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
181 | self.bnHb = nn.BatchNorm2d(out_channels)
182 | self.leaky_relu = nn.LeakyReLU(0.1)
183 |
184 | def forward(self,x):
185 | x = self.leaky_relu(self.bnHa(self.convHa(x)))
186 | x = self.leaky_relu(self.bnHb(self.convHb(x)))
187 |
188 | x = torch.sigmoid(x)
189 | return x
190 |
191 |
192 | class DepthHead(nn.Module):
193 | def __init__(self, in_channels):
194 | super().__init__()
195 | self.upsampleDa = UpsampleLayer(in_channels)
196 | self.upsampleDb = UpsampleLayer(in_channels//2)
197 | self.upsampleDc = UpsampleLayer(in_channels//4)
198 |
199 | self.convDepa = nn.Conv2d(in_channels//2+in_channels, in_channels//2, kernel_size=3, stride=1, padding=1)
200 | self.bnDepa = nn.BatchNorm2d(in_channels//2)
201 | self.convDepb = nn.Conv2d(in_channels//4+in_channels//2, in_channels//4, kernel_size=3, stride=1, padding=1)
202 | self.bnDepb = nn.BatchNorm2d(in_channels//4)
203 | self.convDepc = nn.Conv2d(in_channels//8+in_channels//4, 3, kernel_size=3, stride=1, padding=1)
204 | self.bnDepc = nn.BatchNorm2d(3)
205 |
206 | self.leaky_relu = nn.LeakyReLU(0.1)
207 |
208 | def forward(self, x):
209 | x0 = F.interpolate(x, scale_factor=2,mode='bilinear',align_corners=False)
210 | x1 = self.upsampleDa(x)
211 | x1 = torch.cat([x0,x1],dim=1)
212 | x1 = self.leaky_relu(self.bnDepa(self.convDepa(x1)))
213 |
214 | x1_0 = F.interpolate(x1,scale_factor=2,mode='bilinear',align_corners=False)
215 | x2 = self.upsampleDb(x1)
216 | x2 = torch.cat([x1_0,x2],dim=1)
217 | x2 = self.leaky_relu(self.bnDepb(self.convDepb(x2)))
218 |
219 | x2_0 = F.interpolate(x2,scale_factor=2,mode='bilinear',align_corners=False)
220 | x3 = self.upsampleDc(x2)
221 | x3 = torch.cat([x2_0,x3],dim=1)
222 | x = self.leaky_relu(self.bnDepc(self.convDepc(x3)))
223 |
224 | x = F.normalize(x,p=2,dim=1)
225 | return x
226 |
227 |
228 | class BaseLayer(nn.Module):
229 | def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False,activation=True):
230 | super().__init__()
231 | if activation:
232 | self.layer=nn.Sequential(
233 | nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
234 | nn.BatchNorm2d(out_channels,affine=False),
235 | nn.ReLU(inplace=True)
236 | )
237 | else:
238 | self.layer=nn.Sequential(
239 | nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
240 | nn.BatchNorm2d(out_channels,affine=False)
241 | )
242 |
243 | def forward(self,x):
244 | return self.layer(x)
245 |
246 |
247 | class LiftFeatSPModel(nn.Module):
248 | default_conf = {
249 | "has_detector": True,
250 | "has_descriptor": True,
251 | "descriptor_dim": 64,
252 | # Inference
253 | "sparse_outputs": True,
254 | "dense_outputs": False,
255 | "nms_radius": 4,
256 | "refinement_radius": 0,
257 | "detection_threshold": 0.005,
258 | "max_num_keypoints": -1,
259 | "max_num_keypoints_val": None,
260 | "force_num_keypoints": False,
261 | "randomize_keypoints_training": False,
262 | "remove_borders": 4,
263 | "legacy_sampling": True, # True to use the old broken sampling
264 | }
265 |
266 | def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True):
267 | super().__init__()
268 | self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
269 | self.descriptor_dim = 64
270 |
271 | self.norm = nn.InstanceNorm2d(1)
272 |
273 | self.relu = nn.ReLU(inplace=True)
274 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
275 | c1,c2,c3,c4,c5 = 24,24,64,64,128
276 |
277 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
278 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
279 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
280 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
281 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
282 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
283 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
284 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
285 | self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
286 | self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1)
287 |
288 | self.upsample4 = UpsampleLayer(c4)
289 | self.upsample5 = UpsampleLayer(c5)
290 | self.conv_fusion45 = nn.Conv2d(c5//2+c4,c4,kernel_size=3,stride=1,padding=1)
291 | self.conv_fusion34 = nn.Conv2d(c4//2+c3,c3,kernel_size=3,stride=1,padding=1)
292 |
293 | # detector
294 | self.keypoint_head = KeypointHead(in_channels=c3,out_channels=65)
295 | # descriptor
296 | self.descriptor_head = DescriptorHead(in_channels=c3,out_channels=self.descriptor_dim)
297 | # # heatmap
298 | # self.heatmap_head = HeatmapHead(in_channels=c3,mid_channels=c3,out_channels=1)
299 | # depth
300 | self.depth_head = DepthHead(c3)
301 |
302 | self.fine_matcher = nn.Sequential(
303 | nn.Linear(128, 512),
304 | nn.BatchNorm1d(512, affine=False),
305 | nn.ReLU(inplace = True),
306 | nn.Linear(512, 512),
307 | nn.BatchNorm1d(512, affine=False),
308 | nn.ReLU(inplace = True),
309 | nn.Linear(512, 512),
310 | nn.BatchNorm1d(512, affine=False),
311 | nn.ReLU(inplace = True),
312 | nn.Linear(512, 512),
313 | nn.BatchNorm1d(512, affine=False),
314 | nn.ReLU(inplace = True),
315 | nn.Linear(512, 64),
316 | )
317 |
318 | # feature_booster
319 | self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal)
320 |
321 | def feature_extract(self, x):
322 | x1 = self.relu(self.conv1a(x))
323 | x1 = self.relu(self.conv1b(x1))
324 | x1 = self.pool(x1)
325 | x2 = self.relu(self.conv2a(x1))
326 | x2 = self.relu(self.conv2b(x2))
327 | x2 = self.pool(x2)
328 | x3 = self.relu(self.conv3a(x2))
329 | x3 = self.relu(self.conv3b(x3))
330 | x3 = self.pool(x3)
331 | x4 = self.relu(self.conv4a(x3))
332 | x4 = self.relu(self.conv4b(x4))
333 | x4 = self.pool(x4)
334 | x5 = self.relu(self.conv5a(x4))
335 | x5 = self.relu(self.conv5b(x5))
336 | x5 = self.pool(x5)
337 | return x3,x4,x5
338 |
339 | def fuse_multi_features(self,x3,x4,x5):
340 | # upsample x5 feature
341 | x5 = self.upsample5(x5)
342 | x4 = torch.cat([x4,x5],dim=1)
343 | x4 = self.conv_fusion45(x4)
344 |
345 | # upsample x4 feature
346 | x4 = self.upsample4(x4)
347 | x3 = torch.cat([x3,x4],dim=1)
348 | x = self.conv_fusion34(x3)
349 | return x
350 |
351 | def _unfold2d(self, x, ws = 2):
352 | """
353 | Unfolds tensor in 2D with desired ws (window size) and concat the channels
354 | """
355 | B, C, H, W = x.shape
356 | x = x.unfold(2, ws , ws).unfold(3, ws,ws).reshape(B, C, H//ws, W//ws, ws**2)
357 | return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws)
358 |
359 |
360 | def forward1(self, x):
361 | """
362 | input:
363 | x -> torch.Tensor(B, C, H, W) grayscale or rgb images
364 | return:
365 | des_map -> torch.Tensor(B, 64, H/8, W/8) dense local descriptors
366 | keypoint_map -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map (64 cell bins + dustbin)
367 | d_feats -> torch.Tensor(B, 3, H, W) unit surface-normal map predicted by the depth head
368 |
369 | """
370 | with torch.no_grad():
371 | x = x.mean(dim=1, keepdim = True)
372 | x = self.norm(x)
373 |
374 | x3,x4,x5 = self.feature_extract(x)
375 |
376 | # features fusion
377 | x = self.fuse_multi_features(x3,x4,x5)
378 |
379 | # keypoint
380 | keypoint_map = self.keypoint_head(x)
381 | # descriptor
382 | des_map = self.descriptor_head(x)
383 | # # heatmap
384 | # heatmap = self.heatmap_head(x)
385 |
386 | # import pdb;pdb.set_trace()
387 | # depth
388 | d_feats = self.depth_head(x)
389 |
390 | return des_map, keypoint_map, d_feats
391 | # return des_map, keypoint_map, heatmap, d_feats
392 |
393 | def forward2(self, descs, kpts, normals):
394 | # import pdb;pdb.set_trace()
395 | normals_feat=self._unfold2d(normals, ws=8)
396 | normals_v=normals_feat.squeeze(0).permute(1,2,0).reshape(-1,normals_feat.shape[1])
397 | descs_v=descs.squeeze(0).permute(1,2,0).reshape(-1,descs.shape[1])
398 | kpts_v=kpts.squeeze(0).permute(1,2,0).reshape(-1,kpts.shape[1])
399 | descs_refine = self.feature_boost(descs_v, kpts_v, normals_v)
400 | return descs_refine
401 |
402 | def forward(self,x):
403 | M1,K1,D1=self.forward1(x)
404 | descs_refine=self.forward2(M1,K1,D1)
405 | return descs_refine,M1,K1,D1
406 |
407 |
408 | if __name__ == "__main__":
409 | img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg')
410 | img=cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
411 | img=cv2.resize(img,(800,608))
412 | # import pdb;pdb.set_trace()
413 | img=torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()/255.0
414 | img=img.cuda() if torch.cuda.is_available() else img
415 | liftfeat_sp=LiftFeatSPModel(featureboost_config).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
416 | des_map, keypoint_map, d_feats=liftfeat_sp.forward1(img)
417 | des_fine=liftfeat_sp.forward2(des_map,keypoint_map,d_feats)
418 | print(des_map.shape)
419 |
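420 | # Expected shapes for the 800x608 example above (a rough guide; the refined
421 | # descriptor dimension D depends on featureboost_config):
422 | #   des_map      -> (1, 64, 76, 100)  dense descriptors at 1/8 resolution
423 | #   keypoint_map -> (1, 65, 76, 100)  8x8-cell keypoint logits + dustbin
424 | #   d_feats      -> (1, 3, 608, 800)  unit surface-normal map at full resolution
425 | #   des_fine     -> (7600, D)         refined descriptors, one per 8x8 cell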
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | torchvision==0.14.1
3 | einops==0.8.0
4 | kornia==0.7.3
5 | timm==1.0.15
6 | albumentations==1.4.12
7 | imgaug==0.4.0
8 | opencv-python==4.10.0.84
9 | matplotlib==3.7.5
10 | numpy==1.24.4
11 | scikit-image==0.21.0
12 | scipy==1.10.1
13 | pillow==10.3.0
14 | tensorboard==2.14.0
15 | tqdm==4.66.4
16 | omegaconf==2.3.0
17 | thop==0.1.1.post2209072238
18 | poselib
19 |
--------------------------------------------------------------------------------
/tools/demo_match_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import yaml
6 | import matplotlib.cm as cm
7 | import argparse
8 |
9 | import sys
10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
11 | from models.liftfeat_wrapper import LiftFeat,MODEL_PATH
12 | from utils.post_process import match_features
13 | os.environ['CUDA_VISIBLE_DEVICES']='0'
14 |
15 | use_cuda = torch.cuda.is_available()
16 | device = torch.device("cuda" if use_cuda else "cpu")
17 |
18 | class VideoHandler:
19 | def __init__(self,video_path,size=[640,360]):
20 | self.video_path=video_path
21 | self.size=size
22 | self.cap=cv2.VideoCapture(video_path)
23 |
24 | def get_frame(self):
25 | ret,frame=self.cap.read()
26 | if ret==True:
27 | frame=cv2.resize(frame,(int(self.size[0]),int(self.size[1])))
28 | return ret,frame
29 |
30 | def draw_video_match(img0,img1,kpts0,kpts1,mkpts0,mkpts1,match_scores,mask,max_match_num=512,margin=15):
31 | H0, W0, c = img0.shape
32 | H1, W1, c = img1.shape
33 | H, W = max(H0, H1), W0 + W1 + margin
34 |
35 | # Build the canvas by placing the two images side by side
36 | out = 255*np.ones((H, W, 3), np.uint8)
37 | out[:H0, :W0, :] = img0
38 | out[:H1, W0+margin:, :] = img1
39 | #out = np.stack([out]*3, -1)
40 |
41 | kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
42 |
43 | mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
44 | mkpts0_correct,mkpts1_correct=mkpts0[mask],mkpts1[mask]
45 | mkpts0_wrong,mkpts1_wrong=mkpts0[~mask],mkpts1[~mask]
46 | match_s=match_scores[mask]
47 |
48 | print(f"correct: {mkpts0_correct.shape[0]} wrong: {mkpts0_wrong.shape[0]}")
49 |
50 | if mkpts0_correct.shape[0] > max_match_num:
51 | # perm=np.random.randint(low=0,high=mkpts0_correct.shape[0],size=max_match_num)
52 | # mkpts0_show,mkpts1_show=mkpts0_correct[perm],mkpts1_correct[perm]
53 | mkpts0_show,mkpts1_show=mkpts0_correct,mkpts1_correct
54 | else:
55 | mkpts0_show,mkpts1_show=mkpts0_correct,mkpts1_correct
56 |
57 | # Plain keypoint markers
58 | vis_normal_point = True
59 | if (vis_normal_point):
60 | for x, y in mkpts0_show:
61 | cv2.circle(out, (x, y), 2, (47,132,250), -1, lineType=cv2.LINE_AA)
62 | for x, y in mkpts1_show:
63 | cv2.circle(out, (x + margin + W0, y), 2, (47,132,250), -1,lineType=cv2.LINE_AA)
64 |
65 | vis_match_line = True
66 | if (vis_match_line):
67 | for pt0, pt1,score in zip(mkpts0_show, mkpts1_show,match_s):
68 | color_cm = cm.jet(1.0 - score, alpha=0)
69 | color = (int(color_cm[0] * 255), int(color_cm[1] * 255), int(color_cm[2] * 255))
70 | cv2.line(out, pt0, (W0 + margin + pt1[0], pt1[1]), color, 1)
71 |
72 | return out
73 |
74 | def run_video_demo(std_img_path,video_path):
75 |
76 |
77 | liftfeat=LiftFeat(weight=MODEL_PATH,detect_threshold=0.15)
78 |
79 | std_img=cv2.imread(std_img_path)
80 | std_img=cv2.resize(std_img,(640,360))
81 |
82 | handler=VideoHandler(video_path)
83 |
84 | # Define the codec and create the VideoWriter object
85 | if not os.path.exists('./output'):
86 | os.makedirs('./output')
87 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or use 'XVID'
88 | out = cv2.VideoWriter('./output/video_demo.mp4', fourcc, 20.0, (1300, 360))
89 | K=[[1084.8,0,640.24],[0,1085,354.87],[0,0,1]]
90 | K=np.array(K)
91 | data_std=liftfeat.extract(std_img)
92 |
93 | while True:
94 | ret,frame=handler.get_frame()
95 | if ret==False:
96 | break
97 |
98 | if frame is not None:
99 | data=liftfeat.extract(frame)
100 | idx0, idx1, match_scores=match_features(data_std["descriptors"],data["descriptors"],-1)
101 | mkpts0=data_std["keypoints"][idx0]
102 | mkpts1=data["keypoints"][idx1]
103 | mkpts0_np=mkpts0.cpu().numpy()
104 | mkpts1_np=mkpts1.cpu().numpy()
105 | match_scores_np=match_scores.detach().cpu().numpy()
106 | kpts0 = (mkpts0_np - K[[0, 1], [2, 2]][None]) / K[[0, 1], [0, 1]][None]
107 | kpts1 = (mkpts1_np - K[[0, 1], [2, 2]][None]) / K[[0, 1], [0, 1]][None]
108 |
109 | # normalize ransac threshold
110 | ransac_thr = 0.5 / np.mean([K[0, 0], K[1, 1], K[0, 0], K[1, 1]])
111 |
112 | if mkpts0_np.shape[0] < 5:
113 | print(f"mkpts size less then 5")
114 | else:
115 | # compute pose with cv2
116 |
117 | E, mask = cv2.findEssentialMat(kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=0.999, method=cv2.RANSAC)
118 | if E is None:
119 | print("\nE is None while trying to recover pose.\n")
120 | continue
121 | match_mask=mask.squeeze(axis=1)>0
122 | show_kpts0,show_kpts1=mkpts0_np[match_mask],mkpts1_np[match_mask]
123 | show_match_scores=match_scores_np[match_mask]
124 | show_mask=np.ones(show_kpts0.shape[0])>0
125 | match_img=draw_video_match(std_img,frame,show_kpts0,show_kpts1,show_kpts0,show_kpts1,show_match_scores,show_mask,margin=20)
126 | kpts0_num,kpts1_num=data_std["keypoints"].shape[0],data["keypoints"].shape[0]
127 | cv2.putText(match_img,f"LiftFeat",(10,20),cv2.FONT_HERSHEY_TRIPLEX,0.5,(0,0,241))
128 | cv2.putText(match_img,f"Keypoints: {kpts0_num}:{kpts1_num}",(10,40),cv2.FONT_HERSHEY_TRIPLEX,0.5,(0,0,255))
129 | cv2.putText(match_img,f"Matches: {show_kpts0.shape[0]}",(10,60),cv2.FONT_HERSHEY_TRIPLEX,0.5,(0,0,255))
130 | out.write(match_img)
131 |
132 |
133 | out.release()
134 |
135 |
136 |
137 |
138 | if __name__=="__main__":
139 | parser = argparse.ArgumentParser(description="Run LiftFeat video matching demo.")
140 | parser.add_argument('--img', type=str, required=True, help='Path to the template image')
141 | parser.add_argument('--video', type=str, required=True, help='Path to the input video')
142 |
143 | args = parser.parse_args()
144 |
145 | run_video_demo(args.img, args.video)
146 |
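147 | 
148 | # Example (illustrative; the video path is a placeholder):
149 | #   python tools/demo_match_video.py --img ./assert/ref.jpg --video /path/to/video.mp4
150 | # The annotated matching video is written to ./output/video_demo.mp4.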
--------------------------------------------------------------------------------
/tools/demo_vo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import argparse
4 | import yaml
5 | import logging
6 | import os
7 | import sys
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9 | from utils.VisualOdometry import VisualOdometry, AbosluteScaleComputer, create_dataloader, \
10 | plot_keypoints, create_detector, create_matcher
11 | from models.liftfeat_wrapper import LiftFeat,MODEL_PATH
12 |
13 |
14 | vo_config = {
15 | 'dataset': {
16 | 'name': 'KITTILoader',
17 | 'root_path': '/home/yepeng_liu/code_python/dataset/visual_odometry/kitty/gray',
18 | 'sequence': '10',
19 | 'start': 0
20 | },
21 | 'detector': {
22 | 'name': 'LiftFeatDetector',
23 | 'descriptor_dim': 64,
24 | 'nms_radius': 5,
25 | 'keypoint_threshold': 0.005,
26 | 'max_keypoints': 4096,
27 | 'remove_borders': 4,
28 | 'cuda': 1
29 | },
30 | 'matcher': {
31 | 'name': 'FrameByFrameMatcher',
32 | 'type': 'FLANN',
33 | 'FLANN': {
34 | 'kdTrees': 5,
35 | 'searchChecks': 50
36 | },
37 | 'distance_ratio': 0.75
38 | }
39 | }
40 |
41 | # Visualize the keypoints of the current frame
42 | def keypoints_plot(img, vo, img_id, path2):
43 | img_ = cv2.imread(path2+str(img_id-1).zfill(6)+".png")
44 |
45 | if not vo.match_kps:
46 | img_ = plot_keypoints(img_, vo.kptdescs["cur"]["keypoints"])
47 | else:
48 | for index in range(vo.match_kps["ref"].shape[0]):
49 | ref_point = tuple(map(int, vo.match_kps['ref'][index,:]))  # convert the keypoint to an integer tuple
50 | cur_point = tuple(map(int, vo.match_kps['cur'][index,:]))
51 | cv2.line(img_, ref_point, cur_point, (0, 255, 0), 2) # Draw green line
52 | cv2.circle(img_, cur_point, 3, (0, 0, 255), -1) # Draw red circle at current keypoint
53 |
54 | return img_
55 |
56 | # Draws the camera trajectory and computes the error between the estimated and ground-truth trajectories.
57 | class TrajPlotter(object):
58 | def __init__(self):
59 | self.errors = []
60 | self.traj = np.zeros((800, 800, 3), dtype=np.uint8)
61 | pass
62 |
63 | def update(self, est_xyz, gt_xyz):
64 | x, z = est_xyz[0], est_xyz[2]
65 | gt_x, gt_z = gt_xyz[0], gt_xyz[2]
66 | est = np.array([x, z]).reshape(2)
67 | gt = np.array([gt_x, gt_z]).reshape(2)
68 | error = np.linalg.norm(est - gt)
69 | self.errors.append(error)
70 | avg_error = np.mean(np.array(self.errors))
71 | # === drawer ==================================
72 | # each point
73 | draw_x, draw_y = int(x) + 80, int(z) + 230
74 | true_x, true_y = int(gt_x) + 80, int(gt_z) + 230
75 |
76 | # draw trajectory
77 | cv2.circle(self.traj, (draw_x, draw_y), 1, (0, 0, 255), 1)
78 | cv2.circle(self.traj, (true_x, true_y), 1, (0, 255, 0), 2)
79 | cv2.rectangle(self.traj, (10, 5), (450, 120), (0, 0, 0), -1)
80 |
81 | # draw text
82 | text = "[AvgError] %2.4fm" % (avg_error)
83 | print(text)
84 | cv2.putText(self.traj, text, (20, 40),
85 | cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
86 | note = "Green: GT, Red: Predict"
87 | cv2.putText(self.traj, note, (20, 80),
88 | cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
89 |
90 | return self.traj
91 |
92 | def run_video(args):
93 | # create dataloader
94 | vo_config["dataset"]['root_path'] = args.path1
95 | vo_config["dataset"]['sequence'] = args.id
96 | loader = create_dataloader(vo_config["dataset"])
97 | # create detector
98 | liftfeat=LiftFeat(weight=MODEL_PATH, detect_threshold=0.25)
99 | # create matcher
100 | matcher = create_matcher(vo_config["matcher"])
101 |
102 | absscale = AbosluteScaleComputer()
103 | traj_plotter = TrajPlotter()
104 |
105 |
106 | if not os.path.exists('./output'):
107 | os.makedirs('./output')
108 | fname = "kitti_liftfeat_flannmatch"
109 | log_fopen = open("output/" + fname + ".txt", mode='a')
110 |
111 | vo = VisualOdometry(liftfeat, matcher, loader.cam)
112 |
113 | # Initialize video writer for keypoints and trajectory videos
114 | keypoints_video_path = "output/" + fname + "_keypoints_liftfeat.avi"
115 | trajectory_video_path = "output/" + fname + "_trajectory_liftfeat.avi"
116 |
117 | # Set up video writer: choose codec and set FPS and frame size
118 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
119 | fps = 10 # Adjust the FPS according to your input data
120 |     frame_size = (1200, 400) # fixed output frame size; keypoint frames are resized to this below
121 |
122 | # Video writers for keypoints and trajectory
123 | keypoints_writer = cv2.VideoWriter(keypoints_video_path, fourcc, fps, frame_size)
124 | trajectory_writer = cv2.VideoWriter(trajectory_video_path, fourcc, fps, (800, 800))
125 |
126 | for i, img in enumerate(loader):
127 | img_id = loader.img_id
128 | gt_pose = loader.get_cur_pose()
129 |
130 | R, t = vo.update(img, absscale.update(gt_pose))
131 |
132 | # === log writer ==============================
133 | print(i, t[0, 0], t[1, 0], t[2, 0], gt_pose[0, 3], gt_pose[1, 3], gt_pose[2, 3], file=log_fopen)
134 |
135 | # === drawer ==================================
136 | img1 = keypoints_plot(img, vo, img_id, args.path2)
137 | img1 = cv2.resize(img1, (1200, 400))
138 | img2 = traj_plotter.update(t, gt_pose[:, 3])
139 |
140 | # Write frames to videos
141 | keypoints_writer.write(img1)
142 | trajectory_writer.write(img2)
143 |
144 | # Release the video writers
145 | keypoints_writer.release()
146 | trajectory_writer.release()
147 | print(f"Videos saved as {keypoints_video_path} and {trajectory_video_path}")
148 |
149 |
150 |
151 | if __name__ == "__main__":
152 | parser = argparse.ArgumentParser(description='python_vo')
153 | parser.add_argument('--path1', type=str, default='/home/yepeng_liu/code_python/dataset/visual_odometry/kitty/gray',
154 |                         help='root path of the KITTI grayscale odometry dataset (sequences + poses)')
155 | parser.add_argument('--path2', type=str, default="/home/yepeng_liu/code_python/dataset/visual_odometry/kitty/color/sequences/03/image_2/",
156 |                         help='path to the color image sequence used for keypoint visualization')
157 | parser.add_argument('--id', type=str, default="03",
158 |                         help='KITTI sequence id')
159 |
160 |
161 | args = parser.parse_args()
162 |
163 | run_video(args)
164 |
--------------------------------------------------------------------------------
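`TrajPlotter` only needs estimated and ground-truth XYZ positions, so it can be exercised without running the full odometry pipeline. The snippet below is a minimal sketch, not part of the repository: the circular path, the noise level, and the output filename are invented for illustration, and it assumes the repository root is on `PYTHONPATH` with the usual dependencies (torch, OpenCV, PyYAML) installed.

```python
import numpy as np
import cv2

from tools.demo_vo import TrajPlotter  # repo root assumed to be on PYTHONPATH

plotter = TrajPlotter()
for k in range(200):
    # synthetic ground-truth circle in the x-z plane plus a noisy estimate
    angle = 2 * np.pi * k / 200
    gt = np.array([100 * np.cos(angle), 0.0, 100 * np.sin(angle)])
    est = gt + np.random.normal(scale=1.0, size=3)
    canvas = plotter.update(est, gt)   # prints the running average error each call

cv2.imwrite("toy_trajectory.png", canvas)
```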
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 | training script
4 | """
5 |
6 | import argparse
7 | import os
8 | import time
9 | import sys
10 | sys.path.append(os.path.dirname(__file__))
11 |
12 | def parse_arguments():
13 | parser = argparse.ArgumentParser(description="LiftFeat training script.")
14 | parser.add_argument('--name',type=str,default='LiftFeat',help='set process name')
15 |
16 | # MegaDepth dataset setting
17 | parser.add_argument('--use_megadepth',action='store_true')
18 | parser.add_argument('--megadepth_root_path', type=str,
19 | default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548',
20 | help='Path to the MegaDepth dataset root directory.')
21 | parser.add_argument('--megadepth_batch_size', type=int, default=6)
22 |
23 | # COCO20k dataset setting
24 | parser.add_argument('--use_coco',action='store_true')
25 | parser.add_argument('--coco_root_path', type=str, default='/home/yepeng_liu/code_python/dataset/coco_20k',
26 | help='Path to the COCO20k dataset root directory.')
27 | parser.add_argument('--coco_batch_size',type=int,default=4)
28 |
29 | parser.add_argument('--ckpt_save_path', type=str, default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test',
30 | help='Path to save the checkpoints.')
31 | parser.add_argument('--n_steps', type=int, default=160_000,
32 | help='Number of training steps. Default is 160000.')
33 | parser.add_argument('--lr', type=float, default=3e-4,
34 | help='Learning rate. Default is 0.0003.')
35 | parser.add_argument('--gamma_steplr', type=float, default=0.5,
36 | help='Gamma value for StepLR scheduler. Default is 0.5.')
37 | parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
38 | default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
39 | parser.add_argument('--device_num', type=str, default='0',
40 | help='Device number to use for training. Default is "0".')
41 | parser.add_argument('--dry_run', action='store_true',
42 | help='If set, perform a dry run training with a mini-batch for sanity check.')
43 | parser.add_argument('--save_ckpt_every', type=int, default=500,
44 | help='Save checkpoints every N steps. Default is 500.')
45 | parser.add_argument('--use_coord_loss',action='store_true',help='Enable coordinate loss')
46 |
47 | args = parser.parse_args()
48 |
49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num
50 |
51 | return args
52 |
53 | args = parse_arguments()
54 |
55 | import torch
56 | from torch import nn
57 | from torch import optim
58 | import torch.nn.functional as F
59 | from torch.utils.tensorboard import SummaryWriter
60 | from torch.utils.data import Dataset, DataLoader
61 |
62 | import numpy as np
63 | import tqdm
64 | import glob
65 |
66 | from models.model import LiftFeatSPModel
67 | from loss.loss import LiftFeatLoss
68 | from utils.config import featureboost_config
69 | from models.interpolator import InterpolateSparse2d
70 | from utils.depth_anything_wrapper import DepthAnythingExtractor
71 | from utils.alike_wrapper import ALikeExtractor
72 |
73 | from dataset import megadepth_wrapper
74 | from dataset import coco_wrapper
75 | from dataset.megadepth import MegaDepthDataset
76 | from dataset.coco_augmentor import COCOAugmentor
77 |
78 | import setproctitle
79 |
80 |
81 | class Trainer():
82 | def __init__(self, megadepth_root_path,use_megadepth,megadepth_batch_size,
83 | coco_root_path,use_coco,coco_batch_size,
84 | ckpt_save_path,
85 | model_name = 'LiftFeat',
86 | n_steps = 160_000, lr= 3e-4, gamma_steplr=0.5,
87 | training_res = (800, 608), device_num="0", dry_run = False,
88 | save_ckpt_every = 500, use_coord_loss = False):
89 |         print(f'MegaDepth: {use_megadepth}-{megadepth_batch_size}')
90 | print(f'COCO20k: {use_coco}-{coco_batch_size}')
91 | print(f'Coordinate loss: {use_coord_loss}')
92 | self.dev = torch.device ('cuda' if torch.cuda.is_available() else 'cpu')
93 |
94 | # training model
95 | self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True).to(self.dev)
96 | self.loss_fn=LiftFeatLoss(self.dev,lam_descs=1,lam_kpts=2,lam_heatmap=1)
97 |
98 | # depth-anything model
99 | self.depth_net=DepthAnythingExtractor('vits',self.dev,256)
100 |
101 | # alike model
102 | self.alike_net=ALikeExtractor('alike-t',self.dev)
103 |
104 | #Setup optimizer
105 | self.steps = n_steps
106 | self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()) , lr = lr)
107 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr)
108 |
109 | ##################### COCO INIT ##########################
110 | self.use_coco=use_coco
111 | self.coco_batch_size=coco_batch_size
112 | if self.use_coco:
113 | self.augmentor=COCOAugmentor(
114 | img_dir=coco_root_path,
115 | device=self.dev,load_dataset=True,
116 | batch_size=self.coco_batch_size,
117 | out_resolution=training_res,
118 | warp_resolution=training_res,
119 | sides_crop=0.1,
120 | max_num_imgs=3000,
121 | num_test_imgs=5,
122 | photometric=True,
123 | geometric=True,
124 | reload_step=4000
125 | )
126 | ##################### COCO END #######################
127 |
128 |
129 | ##################### MEGADEPTH INIT ##########################
130 | self.use_megadepth=use_megadepth
131 | self.megadepth_batch_size=megadepth_batch_size
132 | if self.use_megadepth:
133 | TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
134 | TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"
135 |
136 | TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
137 |
138 | npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:]
139 | megadepth_dataset = torch.utils.data.ConcatDataset( [MegaDepthDataset(root_dir = TRAINVAL_DATA_SOURCE,
140 | npz_path = path) for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")] )
141 |
142 | self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size, shuffle=True)
143 | self.megadepth_data_iter = iter(self.megadepth_dataloader)
144 | ##################### MEGADEPTH INIT END #######################
145 |
146 | os.makedirs(ckpt_save_path, exist_ok=True)
147 | os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)
148 |
149 | self.dry_run = dry_run
150 | self.save_ckpt_every = save_ckpt_every
151 | self.ckpt_save_path = ckpt_save_path
152 | self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
153 | self.model_name = model_name
154 | self.use_coord_loss = use_coord_loss
155 |
156 |
157 | def generate_train_data(self):
158 | imgs1_t,imgs2_t=[],[]
159 | imgs1_np,imgs2_np=[],[]
160 | # norms0,norms1=[],[]
161 | positives_coarse=[]
162 |
163 | if self.use_coco:
164 | coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1)
165 | h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8
166 | _ , positives_coco_coarse = coco_wrapper.get_corresponding_pts(coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse)
167 | coco_imgs1=coco_imgs1.mean(1,keepdim=True);coco_imgs2=coco_imgs2.mean(1,keepdim=True)
168 | imgs1_t.append(coco_imgs1);imgs2_t.append(coco_imgs2)
169 | positives_coarse += positives_coco_coarse
170 |
171 | if self.use_megadepth:
172 | try:
173 | megadepth_data=next(self.megadepth_data_iter)
174 | except StopIteration:
175 | print('End of MD DATASET')
176 | self.megadepth_data_iter=iter(self.megadepth_dataloader)
177 | megadepth_data=next(self.megadepth_data_iter)
178 | if megadepth_data is not None:
179 | for k in megadepth_data.keys():
180 | if isinstance(megadepth_data[k],torch.Tensor):
181 | megadepth_data[k]=megadepth_data[k].to(self.dev)
182 | megadepth_imgs1_t,megadepth_imgs2_t=megadepth_data['image0'],megadepth_data['image1']
183 | megadepth_imgs1_t=megadepth_imgs1_t.mean(1,keepdim=True);megadepth_imgs2_t=megadepth_imgs2_t.mean(1,keepdim=True)
184 | imgs1_t.append(megadepth_imgs1_t);imgs2_t.append(megadepth_imgs2_t)
185 | megadepth_imgs1_np,megadepth_imgs2_np=megadepth_data['image0_np'],megadepth_data['image1_np']
186 | for np_idx in range(megadepth_imgs1_np.shape[0]):
187 | img1_np,img2_np=megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy(),megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy()
188 | imgs1_np.append(img1_np);imgs2_np.append(img2_np)
189 | positives_megadepth_coarse=megadepth_wrapper.spvs_coarse(megadepth_data,8)
190 | positives_coarse += positives_megadepth_coarse
191 |
192 | with torch.no_grad():
193 | imgs1_t=torch.cat(imgs1_t,dim=0)
194 | imgs2_t=torch.cat(imgs2_t,dim=0)
195 |
196 | return imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse
197 |
198 |
199 | def train(self):
200 | self.net.train()
201 |
202 | with tqdm.tqdm(total=self.steps) as pbar:
203 | for i in range(self.steps):
204 | # import pdb;pdb.set_trace()
205 | imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse=self.generate_train_data()
206 |
207 | #Check if batch is corrupted with too few correspondences
208 | is_corrupted = False
209 | for p in positives_coarse:
210 | if len(p) < 30:
211 | is_corrupted = True
212 |
213 | if is_corrupted:
214 | continue
215 |
216 | # import pdb;pdb.set_trace()
217 | #Forward pass
218 | # start=time.perf_counter()
219 | feats1,kpts1,normals1 = self.net.forward1(imgs1_t)
220 | feats2,kpts2,normals2 = self.net.forward1(imgs2_t)
221 |
222 | coordinates,fb_coordinates=[],[]
223 | alike_kpts1,alike_kpts2=[],[]
224 | DA_normals1,DA_normals2=[],[]
225 |
226 | # import pdb;pdb.set_trace()
227 |
228 | fb_feats1,fb_feats2=[],[]
229 | for b in range(feats1.shape[0]):
230 | feat1=feats1[b].permute(1,2,0).reshape(-1,feats1.shape[1])
231 | feat2=feats2[b].permute(1,2,0).reshape(-1,feats2.shape[1])
232 |
233 | coordinate=self.net.fine_matcher(torch.cat([feat1,feat2],dim=-1))
234 | coordinates.append(coordinate)
235 |
236 | fb_feat1=self.net.forward2(feats1[b].unsqueeze(0),kpts1[b].unsqueeze(0),normals1[b].unsqueeze(0))
237 | fb_feat2=self.net.forward2(feats2[b].unsqueeze(0),kpts2[b].unsqueeze(0),normals2[b].unsqueeze(0))
238 |
239 | fb_coordinate=self.net.fine_matcher(torch.cat([fb_feat1,fb_feat2],dim=-1))
240 | fb_coordinates.append(fb_coordinate)
241 |
242 | fb_feats1.append(fb_feat1.unsqueeze(0));fb_feats2.append(fb_feat2.unsqueeze(0))
243 |
244 | img1,img2=imgs1_t[b],imgs2_t[b]
245 | img1=img1.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255
246 | img2=img2.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255
247 | alike_kpt1=torch.tensor(self.alike_net.extract_alike_kpts(img1),device=self.dev)
248 | alike_kpt2=torch.tensor(self.alike_net.extract_alike_kpts(img2),device=self.dev)
249 | alike_kpts1.append(alike_kpt1);alike_kpts2.append(alike_kpt2)
250 |
251 | # import pdb;pdb.set_trace()
252 | for b in range(len(imgs1_np)):
253 | megadepth_depth1,megadepth_norm1=self.depth_net.extract(imgs1_np[b])
254 | megadepth_depth2,megadepth_norm2=self.depth_net.extract(imgs2_np[b])
255 | DA_normals1.append(megadepth_norm1);DA_normals2.append(megadepth_norm2)
256 |
257 | # import pdb;pdb.set_trace()
258 | fb_feats1=torch.cat(fb_feats1,dim=0)
259 | fb_feats2=torch.cat(fb_feats2,dim=0)
260 | fb_feats1=fb_feats1.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
261 | fb_feats2=fb_feats2.reshape(feats2.shape[0],feats2.shape[2],feats2.shape[3],-1).permute(0,3,1,2)
262 |
263 | coordinates=torch.cat(coordinates,dim=0)
264 | coordinates=coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
265 |
266 | fb_coordinates=torch.cat(fb_coordinates,dim=0)
267 | fb_coordinates=fb_coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
268 |
269 | # end=time.perf_counter()
270 | # print(f"forward1 cost {end-start} seconds")
271 |
272 | loss_items = []
273 |
274 | # import pdb;pdb.set_trace()
275 | loss_info=self.loss_fn(
276 | feats1,fb_feats1,kpts1,normals1,
277 | feats2,fb_feats2,kpts2,normals2,
278 | positives_coarse,
279 | coordinates,fb_coordinates,
280 | alike_kpts1,alike_kpts2,
281 | DA_normals1,DA_normals2,
282 | self.megadepth_batch_size,self.coco_batch_size)
283 |
284 | loss_descs,acc_coarse=loss_info['loss_descs'],loss_info['acc_coarse']
285 | loss_coordinates,acc_coordinates=loss_info['loss_coordinates'],loss_info['acc_coordinates']
286 | loss_fb_descs,acc_fb_coarse=loss_info['loss_fb_descs'],loss_info['acc_fb_coarse']
287 | loss_fb_coordinates,acc_fb_coordinates=loss_info['loss_fb_coordinates'],loss_info['acc_fb_coordinates']
288 | loss_kpts,acc_kpt=loss_info['loss_kpts'],loss_info['acc_kpt']
289 | loss_normals=loss_info['loss_normals']
290 |
291 | loss_items.append(loss_fb_descs.unsqueeze(0))
292 | loss_items.append(loss_kpts.unsqueeze(0))
293 | loss_items.append(loss_normals.unsqueeze(0))
294 |
295 | if self.use_coord_loss:
296 | loss_items.append(loss_fb_coordinates.unsqueeze(0))
297 |
298 | # nb_coarse = len(m1)
299 | # nb_coarse = len(fb_m1)
300 | loss = torch.cat(loss_items, -1).mean()
301 |
302 | # Compute Backward Pass
303 | loss.backward()
304 | torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
305 | self.opt.step()
306 | self.opt.zero_grad()
307 | self.scheduler.step()
308 |
309 | # import pdb;pdb.set_trace()
310 | if (i+1) % self.save_ckpt_every == 0:
311 | print('saving iter ', i+1)
312 | torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')
313 |
314 | pbar.set_description(
315 | 'Loss: {:.4f} \
316 | loss_descs: {:.3f} acc_coarse: {:.3f} \
317 | loss_coordinates: {:.3f} acc_coordinates: {:.3f} \
318 | loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} \
319 | loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} \
320 | loss_kpts: {:.3f} acc_kpts: {:.3f} \
321 | loss_normals: {:.3f}'.format( \
322 | loss.item(), \
323 | loss_descs.item(), acc_coarse, \
324 | loss_coordinates.item(), acc_coordinates, \
325 | loss_fb_descs.item(), acc_fb_coarse, \
326 | loss_fb_coordinates.item(), acc_fb_coordinates, \
327 | loss_kpts.item(), acc_kpt, \
328 | loss_normals.item()) )
329 | pbar.update(1)
330 |
331 | # Log metrics
332 | self.writer.add_scalar('Loss/total', loss.item(), i)
333 | self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i)
334 | self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i)
335 | self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i)
336 | self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i)
337 | self.writer.add_scalar('Loss/descs', loss_descs.item(), i)
338 | self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i)
339 | self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i)
340 | self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i)
341 | self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i)
342 | self.writer.add_scalar('Loss/normals', loss_normals.item(), i)
343 |
344 |
345 |
346 | if __name__ == '__main__':
347 |
348 | setproctitle.setproctitle(args.name)
349 |
350 | trainer = Trainer(
351 | megadepth_root_path=args.megadepth_root_path,
352 | use_megadepth=args.use_megadepth,
353 | megadepth_batch_size=args.megadepth_batch_size,
354 | coco_root_path=args.coco_root_path,
355 | use_coco=args.use_coco,
356 | coco_batch_size=args.coco_batch_size,
357 | ckpt_save_path=args.ckpt_save_path,
358 | n_steps=args.n_steps,
359 | lr=args.lr,
360 | gamma_steplr=args.gamma_steplr,
361 | training_res=args.training_res,
362 | device_num=args.device_num,
363 | dry_run=args.dry_run,
364 | save_ckpt_every=args.save_ckpt_every,
365 | use_coord_loss=args.use_coord_loss
366 | )
367 |
368 | #The most fun part
369 | trainer.train()
370 |
--------------------------------------------------------------------------------
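A detail that is easy to misread in `train()` is the per-image flatten/reshape around `net.fine_matcher`: dense feature maps of shape `(B, C, H, W)` are turned into `(H*W, C)` rows per image and later reassembled into dense maps. The toy round-trip check below is illustration only, with dummy tensors and made-up sizes.

```python
import torch

# Dummy dense feature maps in the (B, C, H, W) layout produced by forward1
B, C, H, W = 2, 64, 76, 100
feats = torch.randn(B, C, H, W)

# Per-image flatten used before net.fine_matcher: (C, H, W) -> (H*W, C)
rows = [feats[b].permute(1, 2, 0).reshape(-1, C) for b in range(B)]

# Reassemble into dense maps, as done for fb_feats / coordinates in train()
dense = torch.cat(rows, dim=0).reshape(B, H, W, C).permute(0, 3, 1, 2)
assert torch.allclose(dense, feats)
```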
/train.sh:
--------------------------------------------------------------------------------
1 | # default training
2 | nohup python /home/yepeng_liu/code_python/LiftFeat/train.py \
3 | --name LiftFeat_test \
4 | --ckpt_save_path /home/yepeng_liu/code_python/LiftFeat/trained_weights/test \
5 | --device_num 1 \
6 | --use_megadepth \
7 | --megadepth_batch_size 8 \
8 | --use_coco \
9 | --coco_batch_size 4 \
10 | --save_ckpt_every 1000 \
11 | > /home/yepeng_liu/code_python/LiftFeat/trained_weights/test/training.log 2>&1 &
--------------------------------------------------------------------------------
/utils/VisualOdometry.py:
--------------------------------------------------------------------------------
1 | # based on: https://github.com/uoip/monoVO-python
2 |
3 | import numpy as np
4 | import cv2
5 | import logging
6 | import glob
7 |
8 | def create_dataloader(conf):
9 | try:
10 | code_line = f"{conf['name']}(conf)"
11 | loader = eval(code_line)
12 | except NameError:
13 | raise NotImplementedError(f"{conf['name']} is not implemented yet.")
14 |
15 | return loader
16 |
17 | """
18 | Pinhole camera model: defines the camera intrinsics.
19 | fx, fy: focal lengths
20 | cx, cy: principal point
21 | k1, k2, p1, p2, k3: distortion coefficients
22 | """
23 | class PinholeCamera(object):
24 | def __init__(self, width, height, fx, fy, cx, cy,
25 | k1=0.0, k2=0.0, p1=0.0, p2=0.0, k3=0.0):
26 | self.width = width
27 | self.height = height
28 | self.fx = fx
29 | self.fy = fy
30 | self.cx = cx
31 | self.cy = cy
32 | self.distortion = (abs(k1) > 0.0000001)
33 | self.d = [k1, k2, p1, p2, k3]
34 |
35 | class KITTILoader(object):
36 | default_config = {
37 | "root_path": "../test_imgs",
38 | "sequence": "00",
39 | "start": 0
40 | }
41 |
42 | def __init__(self, config={}):
43 | self.config = self.default_config
44 | self.config = {**self.config, **config}
45 | logging.info("KITTI Dataset config: ")
46 | logging.info(self.config)
47 |
48 | if self.config["sequence"] in ["00", "01", "02"]:
49 | self.cam = PinholeCamera(1241.0, 376.0, 718.8560, 718.8560, 607.1928, 185.2157)
50 | elif self.config["sequence"] in ["03"]:
51 | self.cam = PinholeCamera(1242.0, 375.0, 721.5377, 721.5377, 609.5593, 172.854)
52 | elif self.config["sequence"] in ["04", "05", "06", "07", "08", "09", "10"]:
53 | self.cam = PinholeCamera(1226.0, 370.0, 707.0912, 707.0912, 601.8873, 183.1104)
54 | else:
55 | raise ValueError(f"Unknown sequence number: {self.config['sequence']}")
56 |
57 | # read ground truth pose
58 | self.pose_path = self.config["root_path"] + "/poses/" + self.config["sequence"] + ".txt"
59 | self.gt_poses = []
60 | with open(self.pose_path) as f:
61 | lines = f.readlines()
62 | for line in lines:
63 | ss = line.strip().split()
64 | pose = np.zeros((1, len(ss)))
65 | for i in range(len(ss)):
66 | pose[0, i] = float(ss[i])
67 |
68 | pose.resize([3, 4])
69 | self.gt_poses.append(pose)
70 |
71 | # image id
72 | self.img_id = self.config["start"]
73 | self.img_N = len(glob.glob(pathname=self.config["root_path"] + "/sequences/" \
74 | + self.config["sequence"] + "/image_0/*.png"))
75 |
76 | def get_cur_pose(self):
77 | return self.gt_poses[self.img_id - 1]
78 |
79 | def __getitem__(self, item):
80 | file_name = self.config["root_path"] + "/sequences/" + self.config["sequence"] \
81 | + "/image_0/" + str(item).zfill(6) + ".png"
82 | img = cv2.imread(file_name)
83 | return img
84 |
85 | def __iter__(self):
86 | return self
87 |
88 | def __next__(self):
89 | if self.img_id < self.img_N:
90 | file_name = self.config["root_path"] + "/sequences/" + self.config["sequence"] \
91 | + "/image_0/" + str(self.img_id).zfill(6) + ".png"
92 | img = cv2.imread(file_name)
93 |
94 | self.img_id += 1
95 |
96 | return img
97 | raise StopIteration()
98 |
99 | def __len__(self):
100 | return self.img_N - self.config["start"]
101 |
102 |
103 | def create_detector(conf):
104 | try:
105 | code_line = f"{conf['name']}(conf)"
106 | detector = eval(code_line)
107 | except NameError:
108 | raise NotImplementedError(f"{conf['name']} is not implemented yet.")
109 |
110 | return detector
111 |
112 |
113 | def create_matcher(conf):
114 | try:
115 | code_line = f"{conf['name']}(conf)"
116 | matcher = eval(code_line)
117 | except NameError:
118 | raise NotImplementedError(f"{conf['name']} is not implemented yet.")
119 |
120 | return matcher
121 |
122 | class FrameByFrameMatcher(object):
123 | default_config = {
124 | "type": "FLANN",
125 | "KNN": {
126 |             "HAMMING": True, # ORB binary descriptors can only be matched with the Hamming distance
127 |             "first_N": 300, # for Hamming matching, keep the first N minimum-distance matches
128 | },
129 | "FLANN": {
130 | "kdTrees": 5,
131 | "searchChecks": 50
132 | },
133 | "distance_ratio": 0.75
134 | }
135 |
136 | def __init__(self, config={}):
137 | self.config = self.default_config
138 | self.config = {**self.config, **config}
139 | logging.info("Frame by frame matcher config: ")
140 | logging.info(self.config)
141 |
142 | if self.config["type"] == "KNN":
143 |             logging.info("creating brute-force matcher...")
144 |             if self.config["KNN"]["HAMMING"]:
145 |                 logging.info("brute force with Hamming norm.")
146 | self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
147 | else:
148 | self.matcher = cv2.BFMatcher()
149 | elif self.config["type"] == "FLANN":
150 | logging.info("creating FLANN matcher...")
151 | # FLANN parameters
152 | FLANN_INDEX_KDTREE = 1
153 | index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=self.config["FLANN"]["kdTrees"])
154 | search_params = dict(checks=self.config["FLANN"]["searchChecks"]) # or pass empty dictionary
155 | self.matcher = cv2.FlannBasedMatcher(index_params, search_params)
156 | else:
157 |             raise ValueError(f"Unknown matcher type: {self.config['type']}")
158 |
159 | def match(self, kptdescs):
160 | self.good = []
161 | # get shape of the descriptor
162 | self.descriptor_shape = kptdescs["ref"]["descriptors"].shape[1]
163 |
164 | if self.config["type"] == "KNN" and self.config["KNN"]["HAMMING"]:
165 | logging.debug("KNN keypoints matching...")
166 | matches = self.matcher.match(kptdescs["ref"]["descriptors"], kptdescs["cur"]["descriptors"])
167 | # Sort them in the order of their distance.
168 | matches = sorted(matches, key=lambda x: x.distance)
169 | # self.good = matches[:self.config["KNN"]["first_N"]]
170 | for i in range(self.config["KNN"]["first_N"]):
171 | self.good.append([matches[i]])
172 | else:
173 | logging.debug("FLANN keypoints matching...")
174 | matches = self.matcher.knnMatch(kptdescs["ref"]["descriptors"], kptdescs["cur"]["descriptors"], k=2)
175 | # Apply ratio test
176 | for m, n in matches:
177 | if m.distance < self.config["distance_ratio"] * n.distance:
178 | self.good.append([m])
179 | # Sort them in the order of their distance.
180 | self.good = sorted(self.good, key=lambda x: x[0].distance)
181 | return self.good
182 |
183 | def get_good_keypoints(self, kptdescs):
184 | logging.debug("getting matched keypoints...")
185 | kp_ref = np.zeros([len(self.good), 2])
186 | kp_cur = np.zeros([len(self.good), 2])
187 | match_dist = np.zeros([len(self.good)])
188 | for i, m in enumerate(self.good):
189 | kp_ref[i, :] = kptdescs["ref"]["keypoints"][m[0].queryIdx]
190 | kp_cur[i, :] = kptdescs["cur"]["keypoints"][m[0].trainIdx]
191 | match_dist[i] = m[0].distance
192 |
193 | ret_dict = {
194 | "ref_keypoints": kp_ref,
195 | "cur_keypoints": kp_cur,
196 | "match_score": self.normalised_matching_scores(match_dist)
197 | }
198 | return ret_dict
199 |
200 | def __call__(self, kptdescs):
201 | self.match(kptdescs)
202 | return self.get_good_keypoints(kptdescs)
203 |
204 | def normalised_matching_scores(self, match_dist):
205 |
206 | if self.config["type"] == "KNN" and self.config["KNN"]["HAMMING"]:
207 | # ORB Hamming distance
208 | best, worst = 0, self.descriptor_shape * 8 # min and max hamming distance
209 | worst = worst / 4 # scale
210 | else:
211 | # for non-normalized descriptor
212 | if match_dist.max() > 1:
213 | best, worst = 0, self.descriptor_shape * 2 # estimated range
214 | else:
215 | best, worst = 0, 1
216 |
217 | # normalise the score!
218 | match_scores = match_dist / worst
219 | # range constraint
220 | match_scores[match_scores > 1] = 1
221 | match_scores[match_scores < 0] = 0
222 | # 1: for best match, 0: for worst match
223 | match_scores = 1 - match_scores
224 |
225 | return match_scores
226 |
227 | def draw_matched(self, img0, img1):
228 | pass
229 |
230 | # --- VISUALIZATION ---
231 | # based on: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/models/utils.py
232 | def plot_keypoints(image, kpts):
233 | kpts = np.round(kpts).astype(int)
234 | for x, y in kpts:
235 | cv2.drawMarker(image, (x, y), (0, 255, 0), cv2.MARKER_CROSS, 6)
236 |
237 | return image
238 |
239 | class VisualOdometry(object):
240 | """
241 | A simple frame by frame visual odometry
242 | """
243 |
244 | def __init__(self, detector, matcher, cam):
245 | """
246 |         :param detector: a feature detector that detects keypoints and computes their descriptors
247 |         :param matcher: a matcher that matches keypoints between two frames
248 | :param cam: camera parameters
249 | """
250 | # feature detector and keypoints matcher
251 | self.detector = detector
252 | self.matcher = matcher
253 |
254 | # camera parameters
255 | self.focal = cam.fx
256 | self.pp = (cam.cx, cam.cy)
257 |
258 | # frame index counter
259 | self.index = 0
260 |
261 | # keypoints and descriptors
262 | self.kptdescs = {}
263 |
264 | # match points
265 | self.match_kps = {}
266 |
267 | # pose of current frame
268 | self.cur_R = None
269 | self.cur_t = None
270 |
271 | def update(self, image, absolute_scale=1):
272 | """
273 | update a new image to visual odometry, and compute the pose
274 | :param image: input image
275 | :param absolute_scale: the absolute scale between current frame and last frame
276 | :return: R and t of current frame
277 | """
278 | predict_data = self.detector.extract(image)
279 | kptdesc = {
280 | "keypoints": predict_data["keypoints"].cpu().detach().numpy(),
281 | "descriptors": predict_data["descriptors"].cpu().detach().numpy()
282 | }
283 |
284 | # first frame
285 | if self.index == 0:
286 | # save keypoints and descriptors
287 | self.kptdescs["cur"] = kptdesc
288 |
289 | # start point
290 | self.cur_R = np.identity(3)
291 | self.cur_t = np.zeros((3, 1))
292 | else:
293 | # update keypoints and descriptors
294 | self.kptdescs["cur"] = kptdesc
295 |
296 | # match keypoints
297 | matches = self.matcher(self.kptdescs)
298 | self.match_kps = {"cur":matches['cur_keypoints'], "ref":matches['ref_keypoints']}
299 |
300 | # compute relative R,t between ref and cur frame
301 | E, mask = cv2.findEssentialMat(matches['cur_keypoints'], matches['ref_keypoints'],
302 | focal=self.focal, pp=self.pp,
303 | method=cv2.RANSAC, prob=0.999, threshold=1.0)
304 | _, R, t, mask = cv2.recoverPose(E, matches['cur_keypoints'], matches['ref_keypoints'],
305 | focal=self.focal, pp=self.pp)
306 |
307 | # get absolute pose based on absolute_scale
308 | if (absolute_scale > 0.1):
309 | self.cur_t = self.cur_t + absolute_scale * self.cur_R.dot(t)
310 | self.cur_R = R.dot(self.cur_R)
311 |
312 | self.kptdescs["ref"] = self.kptdescs["cur"]
313 |
314 | self.index += 1
315 | return self.cur_R, self.cur_t
316 |
317 | # Computes the absolute displacement between the current and the previous frame, used to scale the camera translation vector
318 | class AbosluteScaleComputer(object):
319 | def __init__(self):
320 | self.prev_pose = None
321 | self.cur_pose = None
322 | self.count = 0
323 |
324 | def update(self, pose):
325 | self.cur_pose = pose
326 |
327 | scale = 1.0
328 | if self.count != 0:
329 | scale = np.sqrt(
330 | (self.cur_pose[0, 3] - self.prev_pose[0, 3]) * (self.cur_pose[0, 3] - self.prev_pose[0, 3])
331 | + (self.cur_pose[1, 3] - self.prev_pose[1, 3]) * (self.cur_pose[1, 3] - self.prev_pose[1, 3])
332 | + (self.cur_pose[2, 3] - self.prev_pose[2, 3]) * (self.cur_pose[2, 3] - self.prev_pose[2, 3]))
333 |
334 | self.count += 1
335 | self.prev_pose = self.cur_pose
336 | return scale
337 |
338 |
339 |
340 |
--------------------------------------------------------------------------------
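`FrameByFrameMatcher` can be tried in isolation from the detector and the KITTI loader. The sketch below is not part of the repository: it feeds two synthetic frames whose float32 descriptors differ only by small noise, so the FLANN ratio test keeps most candidates; all shapes and values are made up, and the repository root is assumed to be on `PYTHONPATH`.

```python
import numpy as np

from utils.VisualOdometry import FrameByFrameMatcher

rng = np.random.default_rng(0)
ref_desc = rng.standard_normal((300, 64)).astype(np.float32)
cur_desc = (ref_desc + 0.01 * rng.standard_normal((300, 64))).astype(np.float32)

kptdescs = {
    "ref": {"keypoints": rng.uniform(0, 640, (300, 2)), "descriptors": ref_desc},
    "cur": {"keypoints": rng.uniform(0, 640, (300, 2)), "descriptors": cur_desc},
}

matcher = FrameByFrameMatcher({"type": "FLANN"})
result = matcher(kptdescs)  # keys: ref_keypoints, cur_keypoints, match_score (1 = best)
print(result["ref_keypoints"].shape, float(result["match_score"].mean()))
```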
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/alike_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3 | """
4 |
5 |
6 | import sys
7 | import os
8 |
9 | ALIKE_PATH = '/home/yepeng_liu/code_python/multimodal_remote/ALIKE'
10 | sys.path.append(ALIKE_PATH)
11 |
12 | import torch
13 | import torch.nn as nn
14 | from alike import ALike
15 | import cv2
16 | import numpy as np
17 |
18 | import pdb
19 |
20 | dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21 |
22 | configs = {
23 | 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2,
24 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-t.pth')},
25 | 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2,
26 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-s.pth')},
27 | 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2,
28 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-n.pth')},
29 | 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2,
30 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-l.pth')},
31 | }
32 |
33 |
34 | class ALikeExtractor(nn.Module):
35 | def __init__(self,model_type,device) -> None:
36 | super().__init__()
37 | self.net=ALike(**configs[model_type],device=device,top_k=4096,scores_th=0.1,n_limit=8000)
38 |
39 |
40 | @torch.inference_mode()
41 | def extract_alike_kpts(self,img):
42 | pred0=self.net(img,sub_pixel=True)
43 | return pred0['keypoints']
44 |
45 |
46 |
--------------------------------------------------------------------------------
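`ALikeExtractor` only supplies supervision keypoints during training and is not needed at inference time. A minimal usage sketch follows; it assumes the external ALIKE repository and its `alike-t.pth` weights are available at the hard-coded `ALIKE_PATH` above, and uses `assert/ref.jpg` from this repository as the test image.

```python
import cv2

from utils.alike_wrapper import ALikeExtractor, dev  # requires the external ALIKE repo

extractor = ALikeExtractor('alike-t', dev)
img = cv2.cvtColor(cv2.imread('assert/ref.jpg'), cv2.COLOR_BGR2RGB)  # 3-channel image
kpts = extractor.extract_alike_kpts(img)  # expected to be an (N, 2) array of keypoint coordinates
print(kpts.shape)
```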
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | featureboost_config = {
6 | "keypoint_dim": 65,
7 | "keypoint_encoder": [128, 64, 64],
8 | "normal_dim": 192,
9 | "normal_encoder": [128, 64, 64],
10 | "descriptor_encoder": [64, 64],
11 | "descriptor_dim": 64,
12 | "Attentional_layers": 3,
13 | "last_activation": None,
14 | "l2_normalization": None,
15 | "output_dim": 64,
16 | }
--------------------------------------------------------------------------------
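These dimensions line up with the inputs of `FeatureBooster` in `utils/featurebooster.py`: 64-d descriptors, a 65-d keypoint encoding (disabled via `use_kenc=False`, mirroring the flags `train.py` passes to `LiftFeatSPModel`), and a 192-d normal encoding. A quick shape check using only modules from this repository; the keypoint count of 1200 is arbitrary, and the repository root is assumed to be on `PYTHONPATH`.

```python
import torch

from utils.config import featureboost_config
from utils.featurebooster import FeatureBooster

booster = FeatureBooster(featureboost_config, use_kenc=False, use_normal=True, use_cross=True)

n = 1200  # arbitrary number of keypoints
descs = torch.randn(n, featureboost_config["descriptor_dim"])   # 64
kpts = torch.randn(n, featureboost_config["keypoint_dim"])      # 65, ignored because use_kenc=False
normals = torch.randn(n, featureboost_config["normal_dim"])     # 192

print(booster(descs, kpts, normals).shape)  # torch.Size([1200, 64])
```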
/utils/depth_anything_wrapper.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import matplotlib
5 | import numpy as np
6 | import os
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.transforms import Compose
11 | import sys
12 |
13 | sys.path.append("/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2")
14 | from depth_anything_v2.dpt_opt import DepthAnythingV2
15 | from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
16 |
17 | import time
18 |
19 | VITS_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vits.pth"
20 | VITB_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitb.pth"
21 | VITL_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth"
22 |
23 | model_configs = {
24 | "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
25 | "vitb": {
26 | "encoder": "vitb",
27 | "features": 128,
28 | "out_channels": [96, 192, 384, 768],
29 | },
30 | "vitl": {
31 | "encoder": "vitl",
32 | "features": 256,
33 | "out_channels": [256, 512, 1024, 1024],
34 | },
35 | "vitg": {
36 | "encoder": "vitg",
37 | "features": 384,
38 | "out_channels": [1536, 1536, 1536, 1536],
39 | },
40 | }
41 |
42 | class DepthAnythingExtractor(nn.Module):
43 | def __init__(self, encoder_type, device, input_size, process_size=(608,800)):
44 | super().__init__()
45 | self.net = DepthAnythingV2(**model_configs[encoder_type])
46 | self.device = device
47 | if encoder_type == "vits":
48 | print(f"loading {VITS_MODEL_PATH}")
49 | self.net.load_state_dict(torch.load(VITS_MODEL_PATH, map_location="cpu"))
50 | elif encoder_type == "vitb":
51 | print(f"loading {VITB_MODEL_PATH}")
52 | self.net.load_state_dict(torch.load(VITB_MODEL_PATH, map_location="cpu"))
53 | elif encoder_type == "vitl":
54 | print(f"loading {VITL_MODEL_PATH}")
55 | self.net.load_state_dict(torch.load(VITL_MODEL_PATH, map_location="cpu"))
56 | else:
57 |             raise RuntimeError("unsupported encoder type")
58 | self.net.to(self.device).eval()
59 |         self.transform = Compose([
60 | Resize(
61 | width=input_size,
62 | height=input_size,
63 | resize_target=False,
64 | keep_aspect_ratio=True,
65 | ensure_multiple_of=14,
66 | resize_method='lower_bound',
67 | image_interpolation_method=cv2.INTER_CUBIC,
68 | ),
69 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
70 | PrepareForNet(),
71 | ])
72 | self.process_size=process_size
73 | self.input_size=input_size
74 |
75 | @torch.inference_mode()
76 | def infer_image(self,img):
77 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
78 |
79 |         img = self.transform({'image': img})['image']
80 |
81 | img = torch.from_numpy(img).unsqueeze(0)
82 |
83 | img = img.to(self.device)
84 |
85 | with torch.no_grad():
86 | depth = self.net.forward(img)
87 |
88 | depth = F.interpolate(depth[:, None], self.process_size, mode="bilinear", align_corners=True)[0, 0]
89 |
90 | return depth.cpu().numpy()
91 |
92 | @torch.inference_mode()
93 | def compute_normal_map_torch(self, depth_map, scale=1.0):
94 | """
95 |         Compute surface normals from a depth map (PyTorch implementation).
96 | 
97 |         Args:
98 |             depth_map (torch.Tensor): depth map of shape (H, W)
99 |             scale (float): scale factor applied to the depth gradients
100 | 
101 |         Returns:
102 |             torch.Tensor: normal map of shape (H, W, 3)
103 | """
104 | if depth_map.ndim != 2:
105 |             raise ValueError("depth_map must be a 2-D tensor.")
106 |
107 |         # Compute depth-map gradients
108 |         dzdx = torch.diff(depth_map, dim=1, append=depth_map[:, -1:]) * scale
109 |         dzdy = torch.diff(depth_map, dim=0, append=depth_map[-1:, :]) * scale
110 | 
111 |         # Initialize the normal map
112 |         H, W = depth_map.shape
113 |         normal_map = torch.zeros((H, W, 3), dtype=depth_map.dtype, device=depth_map.device)
114 |         normal_map[:, :, 0] = -dzdx # x component
115 |         normal_map[:, :, 1] = -dzdy # y component
116 |         normal_map[:, :, 2] = 1.0 # z component
117 | 
118 |         # Normalize the normal vectors
119 |         norm = torch.linalg.norm(normal_map, dim=2, keepdim=True)
120 |         norm = torch.where(norm == 0, torch.tensor(1.0, device=depth_map.device), norm) # avoid division by zero
121 | normal_map /= norm
122 |
123 | return normal_map
124 |
125 | @torch.inference_mode()
126 | def extract(self, img):
127 | depth = self.infer_image(img)
128 | depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
129 | depth_t=torch.from_numpy(depth).float().to(self.device)
130 | normal_map = self.compute_normal_map_torch(depth_t,1.0)
131 | return depth_t,normal_map
132 |
133 |
134 | if __name__=="__main__":
135 | img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg')
136 | img=cv2.imread(img_path)
137 | img=cv2.resize(img,(800,608))
138 |     # import pdb;pdb.set_trace()
139 | DAExtractor=DepthAnythingExtractor('vitb',torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),256)
140 | depth_t,norm=DAExtractor.extract(img)
141 | norm=norm.cpu().numpy()
142 | norm=(norm+1)/2*255
143 | norm=norm.astype(np.uint8)
144 | cv2.imwrite(os.path.join(os.path.dirname(__file__),"norm.png"),norm)
145 | start=time.perf_counter()
146 | for i in range(20):
147 | depth_t,norm=DAExtractor.extract(img)
148 | end=time.perf_counter()
149 | print(f"cost {end-start} seconds")
150 |
--------------------------------------------------------------------------------
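The surface-normal step itself does not depend on Depth-Anything: `compute_normal_map_torch` is a finite-difference gradient of the depth map followed by normalization. The standalone sketch below re-implements that same scheme for illustration only (the tilted-plane depth map is made up) so the expected normal direction can be checked by hand.

```python
import torch

def normals_from_depth(depth: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Same finite-difference scheme as compute_normal_map_torch above (illustration only)."""
    dzdx = torch.diff(depth, dim=1, append=depth[:, -1:]) * scale
    dzdy = torch.diff(depth, dim=0, append=depth[-1:, :]) * scale
    n = torch.stack((-dzdx, -dzdy, torch.ones_like(depth)), dim=-1)  # (H, W, 3)
    return n / torch.linalg.norm(n, dim=-1, keepdim=True).clamp_min(1e-8)

# A plane tilted along x (depth grows by 0.5 per column): the normal
# should lean towards -x everywhere except the last, padded column.
depth = 0.5 * torch.arange(8, dtype=torch.float32).repeat(8, 1)
print(normals_from_depth(depth)[0, 0])  # approx. tensor([-0.4472, 0.0000, 0.8944])
```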
/utils/featurebooster.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def MLP(channels: List[int], do_bn: bool = False) -> nn.Module:
9 | """ Multi-layer perceptron """
10 | n = len(channels)
11 | layers = []
12 | for i in range(1, n):
13 | layers.append(nn.Linear(channels[i - 1], channels[i]))
14 | if i < (n-1):
15 | if do_bn:
16 | layers.append(nn.BatchNorm1d(channels[i]))
17 | layers.append(nn.ReLU())
18 | return nn.Sequential(*layers)
19 |
20 | def MLP_no_ReLU(channels: List[int], do_bn: bool = False) -> nn.Module:
21 |     """ Multi-layer perceptron without ReLU activations """
22 | n = len(channels)
23 | layers = []
24 | for i in range(1, n):
25 | layers.append(nn.Linear(channels[i - 1], channels[i]))
26 | if i < (n-1):
27 | if do_bn:
28 | layers.append(nn.BatchNorm1d(channels[i]))
29 | return nn.Sequential(*layers)
30 |
31 |
32 | class KeypointEncoder(nn.Module):
33 | """ Encoding of geometric properties using MLP """
34 | def __init__(self, keypoint_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
35 | super().__init__()
36 | self.encoder = MLP([keypoint_dim] + layers + [feature_dim])
37 | self.use_dropout = dropout
38 | self.dropout = nn.Dropout(p=p)
39 |
40 | def forward(self, kpts):
41 | if self.use_dropout:
42 | return self.dropout(self.encoder(kpts))
43 | return self.encoder(kpts)
44 |
45 | class NormalEncoder(nn.Module):
46 | """ Encoding of geometric properties using MLP """
47 | def __init__(self, normal_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
48 | super().__init__()
49 | self.encoder = MLP_no_ReLU([normal_dim] + layers + [feature_dim])
50 | self.use_dropout = dropout
51 | self.dropout = nn.Dropout(p=p)
52 |
53 | def forward(self, kpts):
54 | if self.use_dropout:
55 | return self.dropout(self.encoder(kpts))
56 | return self.encoder(kpts)
57 |
58 |
59 | class DescriptorEncoder(nn.Module):
60 | """ Encoding of visual descriptor using MLP """
61 | def __init__(self, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
62 | super().__init__()
63 | self.encoder = MLP([feature_dim] + layers + [feature_dim])
64 | self.use_dropout = dropout
65 | self.dropout = nn.Dropout(p=p)
66 |
67 | def forward(self, descs):
68 | residual = descs
69 | if self.use_dropout:
70 | return residual + self.dropout(self.encoder(descs))
71 | return residual + self.encoder(descs)
72 |
73 |
74 | class AFTAttention(nn.Module):
75 | """ Attention-free attention """
76 | def __init__(self, d_model: int, dropout: bool = False, p: float = 0.1) -> None:
77 | super().__init__()
78 | self.dim = d_model
79 | self.query = nn.Linear(d_model, d_model)
80 | self.key = nn.Linear(d_model, d_model)
81 | self.value = nn.Linear(d_model, d_model)
82 | self.proj = nn.Linear(d_model, d_model)
83 | # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
84 | self.use_dropout = dropout
85 | self.dropout = nn.Dropout(p=p)
86 |
87 | def forward(self, x: torch.Tensor) -> torch.Tensor:
88 | residual = x
89 | q = self.query(x)
90 | k = self.key(x)
91 | v = self.value(x)
92 | # q = torch.sigmoid(q)
93 | k = k.T
94 | k = torch.softmax(k, dim=-1)
95 | k = k.T
96 | kv = (k * v).sum(dim=-2, keepdim=True)
97 | x = q * kv
98 | x = self.proj(x)
99 | if self.use_dropout:
100 | x = self.dropout(x)
101 | x += residual
102 | # x = self.layer_norm(x)
103 | return x
104 |
105 |
106 | class PositionwiseFeedForward(nn.Module):
107 | def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1) -> None:
108 | super().__init__()
109 | self.mlp = MLP([feature_dim, feature_dim*2, feature_dim])
110 | # self.layer_norm = nn.LayerNorm(feature_dim, eps=1e-6)
111 | self.use_dropout = dropout
112 | self.dropout = nn.Dropout(p=p)
113 |
114 | def forward(self, x: torch.Tensor) -> torch.Tensor:
115 | residual = x
116 | x = self.mlp(x)
117 | if self.use_dropout:
118 | x = self.dropout(x)
119 | x += residual
120 | # x = self.layer_norm(x)
121 | return x
122 |
123 |
124 | class AttentionalLayer(nn.Module):
125 | def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1):
126 | super().__init__()
127 | self.attn = AFTAttention(feature_dim, dropout=dropout, p=p)
128 | self.ffn = PositionwiseFeedForward(feature_dim, dropout=dropout, p=p)
129 |
130 | def forward(self, x: torch.Tensor) -> torch.Tensor:
131 | # import pdb;pdb.set_trace()
132 | x = self.attn(x)
133 | x = self.ffn(x)
134 | return x
135 |
136 |
137 | class AttentionalNN(nn.Module):
138 | def __init__(self, feature_dim: int, layer_num: int, dropout: bool = False, p: float = 0.1) -> None:
139 | super().__init__()
140 | self.layers = nn.ModuleList([
141 | AttentionalLayer(feature_dim, dropout=dropout, p=p)
142 | for _ in range(layer_num)])
143 |
144 | def forward(self, desc: torch.Tensor) -> torch.Tensor:
145 | for layer in self.layers:
146 | desc = layer(desc)
147 | return desc
148 |
149 |
150 | class FeatureBooster(nn.Module):
151 | default_config = {
152 | 'descriptor_dim': 128,
153 | 'keypoint_encoder': [32, 64, 128],
154 | 'Attentional_layers': 3,
155 | 'last_activation': 'relu',
156 | 'l2_normalization': True,
157 | 'output_dim': 128
158 | }
159 |
160 | def __init__(self, config, dropout=False, p=0.1, use_kenc=True, use_normal=True, use_cross=True):
161 | super().__init__()
162 | self.config = {**self.default_config, **config}
163 | self.use_kenc = use_kenc
164 | self.use_cross = use_cross
165 | self.use_normal = use_normal
166 |
167 | if use_kenc:
168 | self.kenc = KeypointEncoder(self.config['keypoint_dim'], self.config['descriptor_dim'], self.config['keypoint_encoder'], dropout=dropout)
169 |
170 | if use_normal:
171 | self.nenc = NormalEncoder(self.config['normal_dim'], self.config['descriptor_dim'], self.config['normal_encoder'], dropout=dropout)
172 |
173 | if self.config.get('descriptor_encoder', False):
174 | self.denc = DescriptorEncoder(self.config['descriptor_dim'], self.config['descriptor_encoder'], dropout=dropout)
175 | else:
176 | self.denc = None
177 |
178 | if self.use_cross:
179 | self.attn_proj = AttentionalNN(feature_dim=self.config['descriptor_dim'], layer_num=self.config['Attentional_layers'], dropout=dropout)
180 |
181 | # self.final_proj = nn.Linear(self.config['descriptor_dim'], self.config['output_dim'])
182 |
183 | self.use_dropout = dropout
184 | self.dropout = nn.Dropout(p=p)
185 |
186 | # self.layer_norm = nn.LayerNorm(self.config['descriptor_dim'], eps=1e-6)
187 |
188 | if self.config.get('last_activation', False):
189 | if self.config['last_activation'].lower() == 'relu':
190 | self.last_activation = nn.ReLU()
191 | elif self.config['last_activation'].lower() == 'sigmoid':
192 | self.last_activation = nn.Sigmoid()
193 | elif self.config['last_activation'].lower() == 'tanh':
194 | self.last_activation = nn.Tanh()
195 | else:
196 | raise Exception('Not supported activation "%s".' % self.config['last_activation'])
197 | else:
198 | self.last_activation = None
199 |
200 | def forward(self, desc, kpts, normals):
201 | # import pdb;pdb.set_trace()
202 | ## Self boosting
203 | # Descriptor MLP encoder
204 | if self.denc is not None:
205 | desc = self.denc(desc)
206 | # Geometric MLP encoder
207 | if self.use_kenc:
208 | desc = desc + self.kenc(kpts)
209 | if self.use_dropout:
210 | desc = self.dropout(desc)
211 |
212 | # 法向量特征 encoder
213 | if self.use_normal:
214 | desc = desc + self.nenc(normals)
215 | if self.use_dropout:
216 | desc = self.dropout(desc)
217 |
218 | ## Cross boosting
219 | # Multi-layer Transformer network.
220 | if self.use_cross:
221 | # desc = self.attn_proj(self.layer_norm(desc))
222 | desc = self.attn_proj(desc)
223 |
224 | ## Post processing
225 | # Final MLP projection
226 | # desc = self.final_proj(desc)
227 | if self.last_activation is not None:
228 | desc = self.last_activation(desc)
229 | # L2 normalization
230 | if self.config['l2_normalization']:
231 | desc = F.normalize(desc, dim=-1)
232 |
233 | return desc
234 |
235 | if __name__ == "__main__":
236 |     from config import featureboost_config
237 |     fb_net = FeatureBooster(featureboost_config)
238 | 
239 |     descs=torch.randn([1900,64])
240 |     kpts=torch.randn([1900,65])
241 |     normals=torch.randn([1900,featureboost_config['normal_dim']])
242 | 
243 |     # import pdb;pdb.set_trace()
244 |
245 | descs_refine=fb_net(descs,kpts,normals)
246 |
247 | print(descs_refine.shape)
248 |
--------------------------------------------------------------------------------
/utils/post_process.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def match_features(feats1, feats2, min_cossim=0.82):
4 | cossim = feats1 @ feats2.t()
5 | cossim_t = feats2 @ feats1.t()
6 | _, match12 = cossim.max(dim=1)
7 | _, match21 = cossim_t.max(dim=1)
8 | idx0 = torch.arange(len(match12), device=match12.device)
9 | mutual = match21[match12] == idx0
10 | # import pdb; pdb.set_trace()
11 | if min_cossim > 0:
12 | best_sim, _ = cossim.max(dim=1)
13 | good = best_sim > min_cossim
14 | idx0 = idx0[mutual & good]
15 | idx1 = match12[mutual & good]
16 | else:
17 | idx0 = idx0[mutual]
18 | idx1 = match12[mutual]
19 |
20 | match_scores = cossim[idx0, idx1]
21 | return idx0, idx1, match_scores
--------------------------------------------------------------------------------
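`match_features` performs mutual-nearest-neighbour matching with a cosine-similarity threshold, so it expects roughly L2-normalized descriptors. The synthetic check below (descriptors, noise level, and seed are made up for illustration) perturbs and shuffles one descriptor set and verifies that the matches recover the permutation; it assumes the repository root is on `PYTHONPATH`.

```python
import torch
import torch.nn.functional as F

from utils.post_process import match_features

torch.manual_seed(0)
feats1 = F.normalize(torch.randn(500, 64), dim=-1)
perm = torch.randperm(500)
feats2 = F.normalize(feats1[perm] + 0.05 * torch.randn(500, 64), dim=-1)

idx0, idx1, scores = match_features(feats1, feats2, min_cossim=0.82)
print(len(idx0), scores.min().item())        # nearly all 500 pairs survive the 0.82 threshold
print(bool(torch.all(perm[idx1] == idx0)))   # matches recover the permutation
```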
/weights/LiftFeat.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/weights/LiftFeat.pth
--------------------------------------------------------------------------------