├── .gitignore ├── README.md ├── __pycache__ └── liftfeat_wrapper.cpython-38.pyc ├── assert ├── achitecture.png ├── demo_liftfeat.gif ├── demo_sp.gif ├── keypoints_liftfeat.gif ├── query.jpg ├── ref.jpg └── trajectory_liftfeat.gif ├── data └── megadepth_1500.json ├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── coco_augmentor.cpython-38.pyc │ ├── coco_wrapper.cpython-38.pyc │ ├── dataset_utils.cpython-38.pyc │ ├── megadepth.cpython-38.pyc │ └── megadepth_wrapper.cpython-38.pyc ├── coco_augmentor.py ├── coco_wrapper.py ├── dataset_utils.py ├── megadepth.py └── megadepth_wrapper.py ├── demo.py ├── evaluation ├── HPatch_evaluation.py ├── MegaDepth1500_evaluation.py ├── __pycache__ │ └── eval_utils.cpython-38.pyc └── eval_utils.py ├── loss ├── __pycache__ │ └── loss.cpython-38.pyc └── loss.py ├── models ├── __pycache__ │ ├── interpolator.cpython-310.pyc │ ├── interpolator.cpython-38.pyc │ ├── liftfeat_wrapper.cpython-310.pyc │ ├── liftfeat_wrapper.cpython-38.pyc │ ├── model.cpython-310.pyc │ └── model.cpython-38.pyc ├── interpolator.py ├── liftfeat_wrapper.py └── model.py ├── requirements.txt ├── tools ├── demo_match_video.py └── demo_vo.py ├── train.py ├── train.sh ├── utils ├── VisualOdometry.py ├── __init__.py ├── __pycache__ │ ├── VisualOdometry.cpython-38.pyc │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── alike_wrapper.cpython-38.pyc │ ├── config.cpython-310.pyc │ ├── config.cpython-38.pyc │ ├── depth_anything_wrapper.cpython-38.pyc │ ├── featurebooster.cpython-310.pyc │ ├── featurebooster.cpython-38.pyc │ └── post_process.cpython-38.pyc ├── alike_wrapper.py ├── config.py ├── depth_anything_wrapper.py ├── featurebooster.py └── post_process.py └── weights └── LiftFeat.pth /.gitignore: -------------------------------------------------------------------------------- 1 | visualize 2 | trained_weights 3 | data/HPatch 4 | data/megadepth_test_1500 5 | output 6 | issues -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## LiftFeat: 3D Geometry-Aware Local Feature Matching 2 |
3 |
4 | 5 | 6 |
7 | 8 | Real-time SuperPoint demonstration (left) compared to LiftFeat (right) on a textureless scene. 9 | 10 |
11 | 12 | - 🎉 **New!** Training code is now available 🚀 13 | - 🎉 **New!** The test code and pretrained model have been released. 🚀 14 | 15 | ## Table of Contents 16 | - [Introduction](#introduction) 17 | - [Installation](#installation) 18 | - [Usage](#usage) 19 | - [Inference](#inference) 20 | - [Training](#training) 21 | - [Evaluation](#evaluation) 22 | - [Citation](#citation) 23 | - [License](#license) 24 | 25 | ## Introduction 26 | This repository contains the official implementation of the paper: 27 | **[LiftFeat: 3D Geometry-Aware Local Feature Matching](https://www.arxiv.org/abs/2505.03422)**, to be presented at *ICRA 2025*. 28 | 29 | **Overview of LiftFeat's architecture** 30 |
31 | 32 |
33 | 34 | LiftFeat is a lightweight and robust local feature matching network designed to handle challenging scenarios such as drastic lighting changes, low-texture regions, and repetitive patterns. By incorporating 3D geometric cues through surface normals predicted from monocular depth, LiftFeat enhances the discriminative power of 2D descriptors. Our proposed 3D geometry-aware feature lifting module effectively fuses these cues, leading to significant improvements in tasks like relative pose estimation, homography estimation, and visual localization. 35 | 36 | ## Installation 37 | If you use conda as virtual environment,you can create a new env with: 38 | ```bash 39 | git clone https://github.com/lyp-deeplearning/LiftFeat.git 40 | cd LiftFeat 41 | conda create -n LiftFeat python=3.8 42 | conda activate LiftFeat 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | ## Usage 47 | ### Inference with image pair 48 | To run LiftFeat on an image,you can simply run with: 49 | ```bash 50 | python demo.py --img1= --img2= 51 | ``` 52 | 53 | ### Run with video 54 | We provide a simple real-time demo that matches a template image to each frame of a video stream using our LiftFeat method. 55 | 56 | You can run the demo with the following command: 57 | ```bash 58 | python tools/demo_match_video.py --img your_template.png --video your.mp4 59 | ``` 60 | 61 | We also provide a [sample template image and video with lighting variation](https://drive.google.com/drive/folders/1b-t-f2Bt47KU674bPI09bGtJ9BHx05Yu?usp=drive_link) for demonstration purposes. 62 | 63 | ### Visual Odometry Demo 64 | We have added a new application to evaluate LiftFeat on visual odometry (VO) tasks. 65 | 66 | We use sequences from the KITTI dataset to demonstrate frame-to-frame motion estimation. Running the script below will generate the estimated camera trajectory and the error curve: 67 | 68 | ```bash 69 | python tools/demo_vo.py --path1 /path/to/gray/images --path2 /path/to/color/images --id 03 70 | ``` 71 | 72 | We also provide a sample [KITTI sequence](https://drive.google.com/drive/folders/1b-t-f2Bt47KU674bPI09bGtJ9BHx05Yu?usp=drive_link) for quick testing. 73 | 74 |
75 | 76 | 77 |
78 | 79 | 80 | ## Training 81 | To train LiftFeat as described in the paper, you will need MegaDepth & COCO_20k subset of COCO2017 dataset as described in the paper *[XFeat: Accelerated Features for Lightweight Image Matching](https://arxiv.org/abs/2404.19174)* 82 | You can obtain the full COCO2017 train data at https://cocodataset.org/. 83 | However, we [make available](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) a subset of COCO for convenience. We simply selected a subset of 20k images according to image resolution. Please check COCO [terms of use](https://cocodataset.org/#termsofuse) before using the data. 84 | 85 | To reproduce the training setup from the paper, please follow the steps: 86 | 1. Download [COCO_20k](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) containing a subset of COCO2017; 87 | 2. Download MegaDepth dataset. You can follow [LoFTR instructions](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets), we use the same standard as LoFTR. Then put the megadepth indices inside the MegaDepth root folder following the standard below: 88 | ```bash 89 | {megadepth_root_path}/train_data/megadepth_indices #indices 90 | {megadepth_root_path}/MegaDepth_v1 #images & depth maps & poses 91 | ``` 92 | 3. Finally you can call training 93 | ```bash 94 | python train.py --megadepth_root_path /MegaDepth --synthetic_root_path /coco_20k --ckpt_save_path /path/to/ckpts 95 | ``` 96 | 97 | ### Evaluation 98 | All evaluation code are in *evaluation*, you can download **HPatch** dataset following [D2-Net](https://github.com/mihaidusmanu/d2-net/tree/master) and download **MegaDepth** test dataset following [LoFTR](https://github.com/zju3dv/LoFTR/tree/master). 
99 | 100 | **Download and process HPatch** 101 | ```bash 102 | cd /data 103 | 104 | # Download the dataset 105 | wget https://huggingface.co/datasets/vbalnt/hpatches/resolve/main/hpatches-sequences-release.zip 106 | 107 | # Extract the dataset 108 | unzip hpatches-sequences-release.zip 109 | 110 | # Remove the high-resolution sequences 111 | cd hpatches-sequences-release 112 | rm -rf i_contruction i_crownnight i_dc i_pencils i_whitebuilding v_artisans v_astronautis v_talent 113 | 114 | cd /data 115 | 116 | ln -s /data/hpatches-sequences-release ./HPatch 117 | ``` 118 | 119 | **Download and process MegaDepth1500** 120 | We provide download link to [megadepth_test_1500](https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc) 121 | ```bash 122 | tar xvf 123 | 124 | cd /data 125 | 126 | ln -s ./megadepth_test_1500 127 | ``` 128 | 129 | 130 | **Homography Estimation** 131 | ```bash 132 | python evaluation/HPatch_evaluation.py 133 | ``` 134 | 135 | **Relative Pose Estimation** 136 | 137 | For *Megadepth1500* dataset: 138 | ```bash 139 | python evaluation/MegaDepth1500_evaluation.py 140 | ``` 141 | 142 | 143 | ## Citation 144 | If you find this code useful for your research, please cite the paper: 145 | ```bibtex 146 | @misc{liu2025liftfeat3dgeometryawarelocal, 147 | title={LiftFeat: 3D Geometry-Aware Local Feature Matching}, 148 | author={Yepeng Liu and Wenpeng Lai and Zhou Zhao and Yuxuan Xiong and Jinchi Zhu and Jun Cheng and Yongchao Xu}, 149 | year={2025}, 150 | eprint={2505.03422}, 151 | archivePrefix={arXiv}, 152 | primaryClass={cs.CV}, 153 | url={https://arxiv.org/abs/2505.03422}, 154 | } 155 | ``` 156 | 157 | ## License 158 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) 159 | 160 | 161 | ## Acknowledgements 162 | We would like to thank the authors of the following open-source repositories for their valuable contributions, which have inspired or supported this work: 163 | 164 | - [verlab/accelerated_features](https://github.com/verlab/accelerated_features) 165 | - [zju3dv/LoFTR](https://github.com/zju3dv/LoFTR) 166 | - [rpautrat/SuperPoint](https://github.com/rpautrat/SuperPoint) 167 | - [Depth-Anything-V2](https://github.com/DepthAnything/Depth-Anything-V2) 168 | - [Python-VO](https://github.com/Shiaoming/Python-VO) 169 | 170 | We deeply appreciate the efforts of the research community in releasing high-quality codebases. 
171 | -------------------------------------------------------------------------------- /__pycache__/liftfeat_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/__pycache__/liftfeat_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /assert/achitecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/achitecture.png -------------------------------------------------------------------------------- /assert/demo_liftfeat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/demo_liftfeat.gif -------------------------------------------------------------------------------- /assert/demo_sp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/demo_sp.gif -------------------------------------------------------------------------------- /assert/keypoints_liftfeat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/keypoints_liftfeat.gif -------------------------------------------------------------------------------- /assert/query.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/query.jpg -------------------------------------------------------------------------------- /assert/ref.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/ref.jpg -------------------------------------------------------------------------------- /assert/trajectory_liftfeat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/assert/trajectory_liftfeat.gif -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/coco_augmentor.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/coco_augmentor.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/coco_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/coco_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/dataset_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/dataset_utils.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/megadepth.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/megadepth.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/megadepth_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/dataset/__pycache__/megadepth_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/coco_augmentor.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | COCO_20k image augmentor 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import Dataset 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | import torch.nn.functional as F 12 | 13 | import cv2 14 | import kornia 15 | import kornia.augmentation as K 16 | from kornia.geometry.transform import get_tps_transform as findTPS 17 | from kornia.geometry.transform import warp_points_tps, warp_image_tps 18 | 19 | import glob 20 | import random 21 | import tqdm 22 | 23 | import numpy as np 24 | import pdb 25 | import time 26 | 27 | random.seed(0) 28 | torch.manual_seed(0) 29 | 30 | def generateRandomTPS(shape,grid=(8,6),GLOBAL_MULTIPLIER=0.3,prob=0.5): 31 | 32 | h, w = shape 33 | sh, sw = h/grid[0], w/grid[1] 34 | src = torch.dstack(torch.meshgrid(torch.arange(0, h + sh , sh), torch.arange(0, w + sw , sw), indexing='ij')) 35 | 36 | offsets = torch.rand(grid[0]+1, grid[1]+1, 2) - 0.5 37 | offsets *= torch.tensor([ sh/2, sw/2 ]).view(1, 1, 2) * min(0.97, 2.0 * GLOBAL_MULTIPLIER) 38 | dst = src + offsets if np.random.uniform() < prob else src 39 | 40 | src, dst = src.view(1, -1, 2), dst.view(1, -1, 2) 41 | src = (src / torch.tensor([h,w]).view(1,1,2) ) * 2 - 1. 42 | dst = (dst / torch.tensor([h,w]).view(1,1,2) ) * 2 - 1. 
43 | weights, A = findTPS(dst, src) 44 | 45 | return src, weights, A 46 | 47 | 48 | def generateRandomHomography(shape,GLOBAL_MULTIPLIER=0.3): 49 | #Generate random in-plane rotation [-theta,+theta] 50 | theta = np.radians(np.random.uniform(-30, 30)) 51 | 52 | #Generate random scale in both x and y 53 | scale_x, scale_y = np.random.uniform(0.35, 1.2, 2) 54 | 55 | #Generate random translation shift 56 | tx , ty = -shape[1]/2.0 , -shape[0]/2.0 57 | txn, tyn = np.random.normal(0, 120.0*GLOBAL_MULTIPLIER, 2) 58 | 59 | c, s = np.cos(theta), np.sin(theta) 60 | 61 | #Affine coeffs 62 | sx , sy = np.random.normal(0,0.6*GLOBAL_MULTIPLIER,2) 63 | 64 | #Projective coeffs 65 | p1 , p2 = np.random.normal(0,0.006*GLOBAL_MULTIPLIER,2) 66 | 67 | 68 | # Build Homography from parmeterizations 69 | H_t = np.array(((1,0, tx), (0, 1, ty), (0,0,1))) #t 70 | H_r = np.array(((c,-s, 0), (s, c, 0), (0,0,1))) #rotation, 71 | H_a = np.array(((1,sy, 0), (sx, 1, 0), (0,0,1))) # affine 72 | H_p = np.array(((1, 0, 0), (0 , 1, 0), (p1,p2,1))) # projective 73 | H_s = np.array(((scale_x,0, 0), (0, scale_y, 0), (0,0,1))) #scale 74 | H_b = np.array(((1.0,0,-tx +txn), (0, 1, -ty + tyn), (0,0,1))) #t_back, 75 | 76 | #H = H_e * H_s * H_a * H_p 77 | H = np.dot(np.dot(np.dot(np.dot(np.dot(H_b,H_s),H_p),H_a),H_r),H_t) 78 | 79 | return H 80 | 81 | 82 | class COCOAugmentor(nn.Module): 83 | 84 | def __init__(self,device,load_dataset=True, 85 | img_dir="/home/yepeng_liu/code_python/dataset/coco_20k", 86 | warp_resolution=(1200, 900), 87 | out_resolution=(400, 300), 88 | sides_crop=0.2, 89 | max_num_imgs=50, 90 | num_test_imgs=10, 91 | batch_size=1, 92 | photometric=True, 93 | geometric=True, 94 | reload_step=1_000 95 | ): 96 | super(COCOAugmentor,self).__init__() 97 | self.half=16 98 | self.device=device 99 | 100 | self.dims=warp_resolution 101 | self.batch_size=batch_size 102 | self.out_resolution=out_resolution 103 | self.sides_crop=sides_crop 104 | self.max_num_imgs=max_num_imgs 105 | self.num_test_imgs=num_test_imgs 106 | self.dims_t=torch.tensor([int(self.dims[0]*(1. - self.sides_crop)) - int(self.dims[0]*self.sides_crop) -1, 107 | int(self.dims[1]*(1. - self.sides_crop)) - int(self.dims[1]*self.sides_crop) -1]).float().to(device).view(1,1,2) 108 | self.dims_s=torch.tensor([self.dims_t[0,0,0] / out_resolution[0], 109 | self.dims_t[0,0,1] / out_resolution[1]]).float().to(device).view(1,1,2) 110 | 111 | self.all_imgs=glob.glob(img_dir+'/*.jpg')+glob.glob(img_dir+'/*.png') 112 | 113 | self.photometric=photometric 114 | self.geometric=geometric 115 | self.cnt=1 116 | self.reload_step=reload_step 117 | 118 | list_augmentation=[ 119 | kornia.augmentation.ColorJitter(0.15,0.15,0.15,0.15,p=1.), 120 | kornia.augmentation.RandomEqualize(p=0.4), 121 | kornia.augmentation.RandomGaussianBlur(p=0.3,sigma=(2.0,2.0),kernel_size=(7,7)) 122 | ] 123 | 124 | if photometric is False: 125 | list_augmentation = [] 126 | 127 | self.aug_list=kornia.augmentation.ImageSequential(*list_augmentation) 128 | 129 | if len(self.all_imgs)<10: 130 | raise RuntimeError('Couldnt find enough images to train. Please check the path: ',img_dir) 131 | 132 | if load_dataset: 133 | print('[COCO]: ',len(self.all_imgs),' images for training..') 134 | if len(self.all_imgs) - num_test_imgs < max_num_imgs: 135 | raise RuntimeError('Error: test set overlaps with training set! 
Decrease number of test imgs') 136 | 137 | self.load_imgs() 138 | 139 | self.TPS = True 140 | 141 | 142 | def load_imgs(self): 143 | random.shuffle(self.all_imgs) 144 | train = [] 145 | for p in tqdm.tqdm(self.all_imgs[:self.max_num_imgs],desc='loading train'): 146 | im=cv2.imread(p) 147 | halfH,halfW=im.shape[0]//2,im.shape[1]//2 148 | if halfH>halfW: 149 | im=np.rot90(im) 150 | halfH,halfW=halfW,halfH 151 | 152 | if im.shape[0]!=self.dims[1] or im.shape[1]!=self.dims[0]: 153 | im = cv2.resize(im, self.dims) 154 | 155 | train.append(np.copy(im)) 156 | 157 | self.train=train 158 | self.test=[ 159 | cv2.resize(cv2.imread(p),self.dims) 160 | for p in tqdm.tqdm(self.all_imgs[-self.num_test_imgs:],desc='loading test') 161 | ] 162 | 163 | def norm_pts_grid(self, x): 164 | if len(x.size()) == 2: 165 | return (x.view(1,-1,2) * self.dims_s / self.dims_t) * 2. - 1 166 | return (x * self.dims_s / self.dims_t) * 2. - 1 167 | 168 | def denorm_pts_grid(self, x): 169 | if len(x.size()) == 2: 170 | return ((x.view(1,-1,2) + 1) / 2.) / self.dims_s * self.dims_t 171 | return ((x+1) / 2.) / self.dims_s * self.dims_t 172 | 173 | def rnd_kps(self, shape, n = 256): 174 | h, w = shape 175 | kps = torch.rand(size = (3,n)).to(self.device) 176 | kps[0,:]*=w 177 | kps[1,:]*=h 178 | kps[2,:] = 1.0 179 | 180 | return kps 181 | 182 | def warp_points(self, H, pts): 183 | scale = self.dims_s.view(-1,2) 184 | offset = torch.tensor([int(self.dims[0]*self.sides_crop), int(self.dims[1]*self.sides_crop)], device = pts.device).float() 185 | pts = pts*scale + offset 186 | pts = torch.vstack( [pts.t(), torch.ones(1, pts.shape[0], device = pts.device)]) 187 | warped = torch.matmul(H, pts) 188 | warped = warped / warped[2,...] 189 | warped = warped.t()[:, :2] 190 | return (warped - offset) / scale 191 | 192 | @torch.inference_mode() 193 | def forward(self, x, difficulty = 0.3, TPS = False, prob_deformation = 0.5, test = False): 194 | """ 195 | Perform augmentation to a batch of images. 196 | 197 | input: 198 | x -> torch.Tensor(B, C, H, W): rgb images 199 | difficulty -> float: level of difficulty, 0.1 is medium, 0.3 is already pretty hard 200 | tps -> bool: Wether to apply non-rigid deformations in images 201 | prob_deformation -> float: probability to apply a deformation 202 | 203 | return: 204 | 'output' -> torch.Tensor(B, C, H, W): rgb images 205 | Tuple: 206 | 'H' -> torch.Tensor(3,3): homography matrix 207 | 'mask' -> torch.Tensor(B, H, W): mask of valid pixels after warp 208 | (deformation only) 209 | src, weights, A are parameters from a TPS warp (all torch.Tensors) 210 | 211 | """ 212 | 213 | if self.cnt % self.reload_step == 0: 214 | self.load_imgs() 215 | 216 | if self.geometric is False: 217 | difficulty = 0. 218 | 219 | with torch.no_grad(): 220 | x = (x/255.).to(self.device) 221 | b, c, h, w = x.shape 222 | shape = (h, w) 223 | 224 | ######## Geometric Transformations 225 | 226 | H = torch.tensor(np.array([generateRandomHomography(shape,difficulty) for b in range(self.batch_size)]),dtype=torch.float32).to(self.device) 227 | 228 | output = kornia.geometry.transform.warp_perspective(x,H,dsize=shape,padding_mode='zeros') 229 | 230 | #crop % of image boundaries each side to reduce invalid pixels after warps 231 | low_h = int(h * self.sides_crop); low_w = int(w*self.sides_crop) 232 | high_h = int(h*(1. - self.sides_crop)); high_w= int(w * (1. 
- self.sides_crop)) 233 | output = output[..., low_h:high_h, low_w:high_w] 234 | x = x[..., low_h:high_h, low_w:high_w] 235 | 236 | #apply TPS if desired: 237 | if TPS: 238 | src, weights, A = None, None, None 239 | for b in range(self.batch_size): 240 | b_src, b_weights, b_A = generateRandomTPS(shape, (8,6), difficulty, prob = prob_deformation) 241 | b_src, b_weights, b_A = b_src.to(self.device), b_weights.to(self.device), b_A.to(self.device) 242 | 243 | if src is None: 244 | src, weights, A = b_src, b_weights, b_A 245 | else: 246 | src = torch.cat((b_src, src)) 247 | weights = torch.cat((b_weights, weights)) 248 | A = torch.cat((b_A, A)) 249 | 250 | output = warp_image_tps(output, src, weights, A) 251 | 252 | output = F.interpolate(output, self.out_resolution[::-1], mode = 'nearest') 253 | x = F.interpolate(x, self.out_resolution[::-1], mode = 'nearest') 254 | 255 | mask = ~torch.all(output == 0, dim=1, keepdim=True) 256 | mask = mask.expand(-1,3,-1,-1) 257 | 258 | # Make-up invalid regions with texture from the batch 259 | rv = 1 if not TPS else 2 260 | output_shifted = torch.roll(x, rv, 0) 261 | output[~mask] = output_shifted[~mask] 262 | mask = mask[:, 0, :, :] 263 | 264 | ######## Photometric Transformations 265 | output = self.aug_list(output) 266 | 267 | b, c, h, w = output.shape 268 | #Correlated Gaussian Noise 269 | if np.random.uniform() > 0.5 and self.photometric: 270 | noise = F.interpolate(torch.randn_like(output)*(10/255), (h//2, w//2)) 271 | noise = F.interpolate(noise, (h, w), mode = 'bicubic') 272 | output = torch.clip( output + noise, 0., 1.) 273 | 274 | #Random shadows 275 | if np.random.uniform() > 0.6 and self.photometric: 276 | noise = torch.rand((b, 1, h//64, w//64), device = self.device) * 1.3 277 | noise = torch.clip(noise, 0.25, 1.0) 278 | noise = F.interpolate(noise, (h, w), mode = 'bicubic') 279 | noise = noise.expand(-1, 3, -1, -1) 280 | output *= noise 281 | output = torch.clip( output, 0., 1.) 
282 | 283 | self.cnt+=1 284 | 285 | if TPS: 286 | return output, (H, src, weights, A, mask) 287 | else: 288 | return output, (H, mask) 289 | 290 | def get_correspondences(self, kps_target, T): 291 | H, H2, src, W, A = T 292 | undeformed = self.denorm_pts_grid( 293 | warp_points_tps(self.norm_pts_grid(kps_target), 294 | src, W, A) ).view(-1,2) 295 | 296 | warped_to_src = self.warp_points(H@torch.inverse(H2), undeformed) 297 | 298 | return warped_to_src -------------------------------------------------------------------------------- /dataset/coco_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pdb 4 | 5 | debug_cnt = -1 6 | 7 | def make_batch(augmentor, difficulty = 0.3, train = True): 8 | Hs = [] 9 | img_list = augmentor.train if train else augmentor.test 10 | dev = augmentor.device 11 | batch_images = [] 12 | 13 | with torch.no_grad(): # we dont require grads in the augmentation 14 | for b in range(augmentor.batch_size): 15 | rdidx = np.random.randint(len(img_list)) 16 | img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2,0,1).to(augmentor.device).unsqueeze(0) 17 | batch_images.append(img) 18 | 19 | batch_images = torch.cat(batch_images) 20 | 21 | p1, H1 = augmentor(batch_images, difficulty) 22 | p2, H2 = augmentor(batch_images, difficulty, TPS = True, prob_deformation = 0.7) 23 | # p2, H2 = augmentor(batch_images, difficulty, TPS = False, prob_deformation = 0.7) 24 | 25 | return p1, p2, H1, H2 26 | 27 | 28 | def plot_corrs(p1, p2, src_pts, tgt_pts): 29 | import matplotlib.pyplot as plt 30 | p1 = p1.cpu() 31 | p2 = p2.cpu() 32 | src_pts = src_pts.cpu() ; tgt_pts = tgt_pts.cpu() 33 | rnd_idx = np.random.randint(len(src_pts), size=200) 34 | src_pts = src_pts[rnd_idx, ...] 35 | tgt_pts = tgt_pts[rnd_idx, ...] 
36 | 37 | #Plot ground-truth correspondences 38 | fig, ax = plt.subplots(1,2,figsize=(18, 12)) 39 | colors = np.random.uniform(size=(len(tgt_pts),3)) 40 | #Src image 41 | img = p1 42 | for i, p in enumerate(src_pts): 43 | ax[0].scatter(p[0],p[1],color=colors[i]) 44 | ax[0].imshow(img.permute(1,2,0).numpy()[...,::-1]) 45 | 46 | #Target img 47 | img2 = p2 48 | for i, p in enumerate(tgt_pts): 49 | ax[1].scatter(p[0],p[1],color=colors[i]) 50 | ax[1].imshow(img2.permute(1,2,0).numpy()[...,::-1]) 51 | plt.show() 52 | 53 | 54 | def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop = None): 55 | ''' 56 | Get dense corresponding points 57 | ''' 58 | global debug_cnt 59 | negatives, positives = [], [] 60 | 61 | with torch.no_grad(): 62 | #real input res of samples 63 | rh, rw = p1.shape[-2:] 64 | ratio = torch.tensor([rw/w, rh/h], device = p1.device) 65 | 66 | (H, mask1) = H 67 | (H2, src, W, A, mask2) = H2 68 | 69 | #Generate meshgrid of target pts 70 | x, y = torch.meshgrid(torch.arange(w, device=p1.device), torch.arange(h, device=p1.device), indexing ='xy') 71 | mesh = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1) 72 | target_pts = mesh.view(-1, 2) * ratio 73 | 74 | #Pack all transformations into T 75 | for batch_idx in range(len(p1)): 76 | with torch.no_grad(): 77 | T = (H[batch_idx], H2[batch_idx], 78 | src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0)) 79 | #We now warp the target points to src image 80 | src_pts = (augmentor.get_correspondences(target_pts, T) ) #target to src 81 | tgt_pts = (target_pts) 82 | 83 | #Check out of bounds points 84 | mask_valid = (src_pts[:, 0] >=0) & (src_pts[:, 1] >=0) & \ 85 | (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh) 86 | 87 | negatives.append( tgt_pts[~mask_valid] ) 88 | tgt_pts = tgt_pts[mask_valid] 89 | src_pts = src_pts[mask_valid] 90 | 91 | 92 | #Remove invalid pixels 93 | mask_valid = mask1[batch_idx, src_pts[:,1].long(), src_pts[:,0].long()] & \ 94 | mask2[batch_idx, tgt_pts[:,1].long(), tgt_pts[:,0].long()] 95 | tgt_pts = tgt_pts[mask_valid] 96 | src_pts = src_pts[mask_valid] 97 | 98 | # limit nb of matches if desired 99 | if crop is not None: 100 | rnd_idx = torch.randperm(len(src_pts), device=src_pts.device)[:crop] 101 | src_pts = src_pts[rnd_idx] 102 | tgt_pts = tgt_pts[rnd_idx] 103 | 104 | if debug_cnt >=0 and debug_cnt < 4: 105 | plot_corrs(p1[batch_idx], p2[batch_idx], src_pts , tgt_pts ) 106 | debug_cnt +=1 107 | 108 | src_pts = (src_pts / ratio) 109 | tgt_pts = (tgt_pts / ratio) 110 | 111 | #Check out of bounds points 112 | padto = 10 if crop is not None else 2 113 | mask_valid1 = (src_pts[:, 0] >= (0 + padto)) & (src_pts[:, 1] >= (0 + padto)) & \ 114 | (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto)) 115 | mask_valid2 = (tgt_pts[:, 0] >= (0 + padto)) & (tgt_pts[:, 1] >= (0 + padto)) & \ 116 | (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto)) 117 | mask_valid = mask_valid1 & mask_valid2 118 | tgt_pts = tgt_pts[mask_valid] 119 | src_pts = src_pts[mask_valid] 120 | 121 | #Remove repeated correspondences 122 | lut_mat = torch.ones((h, w, 4), device = src_pts.device, dtype = src_pts.dtype) * -1 123 | # src_pts_np = src_pts.cpu().numpy() 124 | # tgt_pts_np = tgt_pts.cpu().numpy() 125 | try: 126 | lut_mat[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) 127 | mask_valid = torch.all(lut_mat >= 0, dim=-1) 128 | points = lut_mat[mask_valid] 129 | positives.append(points) 130 | except: 131 | pdb.set_trace() 132 | print('..') 133 | 134 | return 
negatives, positives 135 | 136 | 137 | def crop_patches(tensor, coords, size = 7): 138 | ''' 139 | Crop [size x size] patches around 2D coordinates from a tensor. 140 | ''' 141 | B, C, H, W = tensor.shape 142 | 143 | x, y = coords[:, 0], coords[:, 1] 144 | y = y.view(-1, 1, 1) 145 | x = x.view(-1, 1, 1) 146 | halfsize = size // 2 147 | # Create meshgrid for indexing 148 | x_offset, y_offset = torch.meshgrid(torch.arange(-halfsize, halfsize+1), torch.arange(-halfsize, halfsize+1), indexing='xy') 149 | y_offset = y_offset.to(tensor.device) 150 | x_offset = x_offset.to(tensor.device) 151 | 152 | # Compute indices around each coordinate 153 | y_indices = (y + y_offset.view(1, size, size)).squeeze(0) + halfsize 154 | x_indices = (x + x_offset.view(1, size, size)).squeeze(0) + halfsize 155 | 156 | # Handle out-of-boundary indices with padding 157 | tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant') 158 | 159 | # Index tensor to get patches 160 | patches = tensor_padded[:, :, y_indices, x_indices] # [B, C, N, H, W] 161 | return patches 162 | 163 | def subpix_softmax2d(heatmaps, temp = 0.25): 164 | N, H, W = heatmaps.shape 165 | heatmaps = torch.softmax(temp * heatmaps.view(-1, H*W), -1).view(-1, H, W) 166 | x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device ), torch.arange(H, device = heatmaps.device ), indexing = 'xy') 167 | x = x - (W//2) 168 | y = y - (H//2) 169 | #pdb.set_trace() 170 | coords_x = (x[None, ...] * heatmaps) 171 | coords_y = (y[None, ...] * heatmaps) 172 | coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H*W, 2) 173 | coords = coords.sum(1) 174 | 175 | return coords 176 | -------------------------------------------------------------------------------- /dataset/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | 4 | MegaDepth data handling was adapted from 5 | LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py 6 | """ 7 | 8 | import io 9 | import cv2 10 | import numpy as np 11 | import h5py 12 | import torch 13 | from numpy.linalg import inv 14 | 15 | 16 | try: 17 | # for internel use only 18 | from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT 19 | except Exception: 20 | MEGADEPTH_CLIENT = SCANNET_CLIENT = None 21 | 22 | # --- DATA IO --- 23 | 24 | def load_array_from_s3( 25 | path, client, cv_type, 26 | use_h5py=False, 27 | ): 28 | byte_str = client.Get(path) 29 | try: 30 | if not use_h5py: 31 | raw_array = np.fromstring(byte_str, np.uint8) 32 | data = cv2.imdecode(raw_array, cv_type) 33 | else: 34 | f = io.BytesIO(byte_str) 35 | data = np.array(h5py.File(f, 'r')['/depth']) 36 | except Exception as ex: 37 | print(f"==> Data loading failure: {path}") 38 | raise ex 39 | 40 | assert data is not None 41 | return data 42 | 43 | 44 | def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): 45 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ 46 | else cv2.IMREAD_COLOR 47 | if str(path).startswith('s3://'): 48 | image = load_array_from_s3(str(path), client, cv_type) 49 | else: 50 | image = cv2.imread(str(path), 1) 51 | 52 | if augment_fn is not None: 53 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 55 | image = augment_fn(image) 56 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 57 | return image # (h, w) 58 | 59 | 60 | def get_resized_wh(w, h, resize=None): 61 | if 
resize is not None: # resize the longer edge 62 | scale = resize / max(h, w) 63 | w_new, h_new = int(round(w*scale)), int(round(h*scale)) 64 | else: 65 | w_new, h_new = w, h 66 | return w_new, h_new 67 | 68 | 69 | def get_divisible_wh(w, h, df=None): 70 | if df is not None: 71 | w_new, h_new = map(lambda x: int(x // df * df), [w, h]) 72 | else: 73 | w_new, h_new = w, h 74 | return w_new, h_new 75 | 76 | 77 | def pad_bottom_right(inp, pad_size, ret_mask=False): 78 | assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" 79 | mask = None 80 | if inp.ndim == 2: 81 | padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) 82 | padded[:inp.shape[0], :inp.shape[1]] = inp 83 | if ret_mask: 84 | mask = np.zeros((pad_size, pad_size), dtype=bool) 85 | mask[:inp.shape[0], :inp.shape[1]] = True 86 | elif inp.ndim == 3: 87 | padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) 88 | padded[:, :inp.shape[1], :inp.shape[2]] = inp 89 | if ret_mask: 90 | mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) 91 | mask[:, :inp.shape[1], :inp.shape[2]] = True 92 | else: 93 | raise NotImplementedError() 94 | return padded, mask 95 | 96 | 97 | # --- MEGADEPTH --- 98 | 99 | def fix_path_from_d2net(path): 100 | if not path: 101 | return None 102 | 103 | path = path.replace('Undistorted_SfM/', '') 104 | path = path.replace('images', 'dense0/imgs') 105 | path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '') 106 | 107 | return path 108 | 109 | def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): 110 | """ 111 | Args: 112 | resize (int, optional): the longer edge of resized images. None for no resize. 113 | padding (bool): If set to 'True', zero-pad resized images to squared size. 
114 | augment_fn (callable, optional): augments images with pre-defined visual effects 115 | Returns: 116 | image (torch.tensor): (1, h, w) 117 | mask (torch.tensor): (h, w) 118 | scale (torch.tensor): [w/w_new, h/h_new] 119 | """ 120 | # read image 121 | image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) 122 | 123 | # resize image 124 | w, h = image.shape[1], image.shape[0] 125 | 126 | if resize is not None: 127 | if len(resize) == 2: 128 | w_new, h_new = resize 129 | else: 130 | resize = resize[0] 131 | w_new, h_new = get_resized_wh(w, h, resize) 132 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 133 | 134 | 135 | image = cv2.resize(image, (w_new, h_new)) 136 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 137 | 138 | if padding: # padding 139 | pad_to = max(h_new, w_new) 140 | image, mask = pad_bottom_right(image, pad_to, ret_mask=True) 141 | else: 142 | mask = None 143 | else: 144 | scale=torch.tensor([1.0,1.0],dtype=torch.float) 145 | 146 | if padding: 147 | pad_to=max(w,h) 148 | image,mask=pad_bottom_right(image,pad_to,ret_mask=True) 149 | else: 150 | mask=None 151 | 152 | #image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized 153 | image_t = torch.from_numpy(image).float().permute(2,0,1) / 255 # (h, w) -> (1, h, w) and normalized 154 | mask = torch.from_numpy(mask) if mask is not None else None 155 | 156 | return image, image_t, mask, scale 157 | 158 | 159 | def read_megadepth_depth(path, pad_to=None): 160 | 161 | if str(path).startswith('s3://'): 162 | depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) 163 | else: 164 | depth = np.array(h5py.File(path, 'r')['depth']) 165 | if pad_to is not None: 166 | depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) 167 | depth = torch.from_numpy(depth).float() # (h, w) 168 | return depth 169 | 170 | 171 | def imread_bgr(path, augment_fn=None, client=SCANNET_CLIENT): 172 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR 173 | if str(path).startswith('s3://'): 174 | image = load_array_from_s3(str(path), client, cv_type) 175 | else: 176 | image = cv2.imread(str(path), 1) 177 | 178 | if augment_fn is not None: 179 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 180 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 181 | image = augment_fn(image) 182 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 183 | return image # (h, w) 184 | -------------------------------------------------------------------------------- /dataset/megadepth.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | 4 | MegaDepth data handling was adapted from 5 | LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py 6 | """ 7 | 8 | import os.path as osp 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset 13 | import glob 14 | import numpy.random as rnd 15 | 16 | import os 17 | import sys 18 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 19 | from dataset.dataset_utils import read_megadepth_gray, read_megadepth_depth, fix_path_from_d2net 20 | 21 | import pdb, tqdm, os 22 | 23 | 24 | class MegaDepthDataset(Dataset): 25 | def __init__(self, 26 | root_dir, 27 | npz_path, 28 | mode='train', 29 | min_overlap_score = 0.3, #0.3, 30 | max_overlap_score = 1.0, #1, 31 | load_depth = True, 32 | img_resize = (800,608), #or None 33 | df=32, 34 | 
img_padding=False, 35 | depth_padding=True, 36 | augment_fn=None, 37 | **kwargs): 38 | """ 39 | Manage one scene(npz_path) of MegaDepth dataset. 40 | 41 | Args: 42 | root_dir (str): megadepth root directory that has `phoenix`. 43 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 44 | mode (str): options are ['train', 'val', 'test'] 45 | min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. 46 | img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. 47 | This is useful during training with batches and testing with memory intensive algorithms. 48 | df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. 49 | img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. 50 | depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. 51 | augment_fn (callable, optional): augments images with pre-defined visual effects. 52 | """ 53 | super().__init__() 54 | self.root_dir = root_dir 55 | self.mode = mode 56 | self.scene_id = npz_path.split('.')[0] 57 | self.load_depth = load_depth 58 | # prepare scene_info and pair_info 59 | if mode == 'test' and min_overlap_score != 0: 60 | min_overlap_score = 0 61 | self.scene_info = np.load(npz_path, allow_pickle=True) 62 | self.pair_infos = self.scene_info['pair_infos'].copy() 63 | del self.scene_info['pair_infos'] 64 | self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score] 65 | 66 | # parameters for image resizing, padding and depthmap padding 67 | if mode == 'train': 68 | assert img_resize is not None #and img_padding and depth_padding 69 | 70 | self.img_resize = img_resize 71 | self.df = df 72 | self.img_padding = img_padding 73 | self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. 74 | 75 | # for training LoFTR 76 | self.augment_fn = augment_fn if mode == 'train' else None 77 | self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) 78 | #pdb.set_trace() 79 | for idx in range(len(self.scene_info['image_paths'])): 80 | self.scene_info['image_paths'][idx] = fix_path_from_d2net(self.scene_info['image_paths'][idx]) 81 | 82 | for idx in range(len(self.scene_info['depth_paths'])): 83 | self.scene_info['depth_paths'][idx] = fix_path_from_d2net(self.scene_info['depth_paths'][idx]) 84 | 85 | 86 | def __len__(self): 87 | return len(self.pair_infos) 88 | 89 | def __getitem__(self, idx): 90 | (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)] 91 | 92 | # read grayscale image and mask. (1, h, w) and (h, w) 93 | img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) 94 | img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) 95 | 96 | # TODO: Support augmentation & handle seeds for each worker correctly. 97 | image0, image0_t, mask0, scale0 = read_megadepth_gray(img_name0, self.img_resize, self.df, self.img_padding, None) 98 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 99 | image1, image1_t, mask1, scale1 = read_megadepth_gray(img_name1, self.img_resize, self.df, self.img_padding, None) 100 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 101 | 102 | if self.load_depth: 103 | # read depth. 
shape: (h, w) 104 | if self.mode in ['train', 'val']: 105 | depth0 = read_megadepth_depth( 106 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) 107 | depth1 = read_megadepth_depth( 108 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) 109 | else: 110 | depth0 = depth1 = torch.tensor([]) 111 | 112 | # read intrinsics of original size 113 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) 114 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) 115 | 116 | # read and compute relative poses 117 | T0 = self.scene_info['poses'][idx0] 118 | T1 = self.scene_info['poses'][idx1] 119 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) 120 | T_1to0 = T_0to1.inverse() 121 | 122 | data = { 123 | 'image0': image0_t, # (1, h, w) 124 | 'image0_np': image0, 125 | 'depth0': depth0, # (h, w) 126 | 'image1': image1_t, 127 | 'image1_np': image1, 128 | 'depth1': depth1, 129 | 'T_0to1': T_0to1, # (4, 4) 130 | 'T_1to0': T_1to0, 131 | 'K0': K_0, # (3, 3) 132 | 'K1': K_1, 133 | 'scale0': scale0, # [scale_w, scale_h] 134 | 'scale1': scale1, 135 | 'dataset_name': 'MegaDepth', 136 | 'scene_id': self.scene_id, 137 | 'pair_id': idx, 138 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), 139 | } 140 | 141 | # for LoFTR training 142 | if mask0 is not None: # img_padding is True 143 | if self.coarse_scale: 144 | [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), 145 | scale_factor=self.coarse_scale, 146 | mode='nearest', 147 | recompute_scale_factor=False)[0].bool() 148 | data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) 149 | 150 | else: 151 | 152 | # read intrinsics of original size 153 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) 154 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) 155 | 156 | # read and compute relative poses 157 | T0 = self.scene_info['poses'][idx0] 158 | T1 = self.scene_info['poses'][idx1] 159 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) 160 | T_1to0 = T_0to1.inverse() 161 | 162 | data = { 163 | 'image0': image0, # (1, h, w) 164 | 'image1': image1, 165 | 'T_0to1': T_0to1, # (4, 4) 166 | 'T_1to0': T_1to0, 167 | 'K0': K_0, # (3, 3) 168 | 'K1': K_1, 169 | 'scale0': scale0, # [scale_w, scale_h] 170 | 'scale1': scale1, 171 | 'dataset_name': 'MegaDepth', 172 | 'scene_id': self.scene_id, 173 | 'pair_id': idx, 174 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), 175 | } 176 | 177 | return data -------------------------------------------------------------------------------- /dataset/megadepth_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | 4 | MegaDepth data handling was adapted from 5 | LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py 6 | """ 7 | 8 | import torch 9 | from kornia.utils import create_meshgrid 10 | import matplotlib.pyplot as plt 11 | import pdb 12 | import cv2 13 | 14 | @torch.no_grad() 15 | def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): 16 | """ Warp kpts0 from I0 to I1 with depth, K and Rt 17 | Also check covisibility and depth 
consistency. 18 | Depth is consistent if relative error < 0.2 (hard-coded). 19 | 20 | Args: 21 | kpts0 (torch.Tensor): [N, L, 2] - , 22 | depth0 (torch.Tensor): [N, H, W], 23 | depth1 (torch.Tensor): [N, H, W], 24 | T_0to1 (torch.Tensor): [N, 3, 4], 25 | K0 (torch.Tensor): [N, 3, 3], 26 | K1 (torch.Tensor): [N, 3, 3], 27 | Returns: 28 | calculable_mask (torch.Tensor): [N, L] 29 | warped_keypoints0 (torch.Tensor): [N, L, 2] 30 | """ 31 | kpts0_long = kpts0.round().long().clip(0, 2000-1) 32 | 33 | depth0[:, 0, :] = 0 ; depth1[:, 0, :] = 0 34 | depth0[:, :, 0] = 0 ; depth1[:, :, 0] = 0 35 | 36 | # Sample depth, get calculable_mask on depth != 0 37 | kpts0_depth = torch.stack( 38 | [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 39 | ) # (N, L) 40 | nonzero_mask = kpts0_depth > 0 41 | 42 | # Draw cross marks on the image for each keypoint 43 | # for b in range(len(kpts0)): 44 | # fig, ax = plt.subplots(1,2) 45 | # depth_np = depth0.numpy()[b] 46 | # depth_np_plot = depth_np.copy() 47 | # for x, y in kpts0_long[b, nonzero_mask[b], :].numpy(): 48 | # cv2.drawMarker(depth_np_plot, (x, y), (255), cv2.MARKER_CROSS, markerSize=10, thickness=2) 49 | # ax[0].imshow(depth_np) 50 | # ax[1].imshow(depth_np_plot) 51 | 52 | # Unproject 53 | kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) 54 | kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) 55 | 56 | # Rigid Transform 57 | w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) 58 | w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] 59 | 60 | # Project 61 | w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) 62 | w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-5) # (N, L, 2), +1e-4 to avoid zero depth 63 | 64 | # Covisible Check 65 | # h, w = depth1.shape[1:3] 66 | # covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ 67 | # (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) 68 | # w_kpts0_long = w_kpts0.long() 69 | # w_kpts0_long[~covisible_mask, :] = 0 70 | 71 | # w_kpts0_depth = torch.stack( 72 | # [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 73 | # ) # (N, L) 74 | # consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 75 | 76 | 77 | valid_mask = nonzero_mask #* consistent_mask* covisible_mask 78 | 79 | return valid_mask, w_kpts0 80 | 81 | 82 | @torch.no_grad() 83 | def spvs_coarse(data, scale = 8): 84 | """ 85 | Supervise corresp with dense depth & camera poses 86 | """ 87 | 88 | # 1. misc 89 | device = data['image0'].device 90 | N, _, H0, W0 = data['image0'].shape 91 | _, _, H1, W1 = data['image1'].shape 92 | #scale = 8 93 | scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale 94 | scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale 95 | h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) 96 | 97 | # 2. 
warp grids 98 | # create kpts in meshgrid and resize them to image resolution 99 | grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) # [N, hw, 2] 100 | grid_pt1_i = scale1 * grid_pt1_c 101 | 102 | # warp kpts bi-directionally and check reproj error 103 | nonzero_m1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) 104 | nonzero_m2, w_pt1_og = warp_kpts( w_pt1_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) 105 | 106 | 107 | dist = torch.linalg.norm( grid_pt1_i - w_pt1_og, dim=-1) 108 | mask_mutual = (dist < 1.5) & nonzero_m1 & nonzero_m2 109 | 110 | #_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) 111 | batched_corrs = [ torch.cat([w_pt1_i[i, mask_mutual[i]] / data['scale0'][i], 112 | grid_pt1_i[i, mask_mutual[i]] / data['scale1'][i]],dim=-1) for i in range(len(mask_mutual))] 113 | 114 | 115 | #Remove repeated correspondences - this is important for network convergence 116 | corrs = [] 117 | for pts in batched_corrs: 118 | lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1 119 | lut_mat21 = torch.clone(lut_mat12) 120 | src_pts = pts[:, :2] / scale 121 | tgt_pts = pts[:, 2:] / scale 122 | try: 123 | lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) 124 | mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1) 125 | points = lut_mat12[mask_valid12] 126 | 127 | #Target-src check 128 | src_pts, tgt_pts = points[:, :2], points[:, 2:] 129 | lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) 130 | mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1) 131 | points = lut_mat21[mask_valid21] 132 | 133 | corrs.append(points) 134 | except: 135 | pdb.set_trace() 136 | print('..') 137 | 138 | #Plot for debug purposes 139 | # for i in range(len(corrs)): 140 | # plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8) 141 | 142 | return corrs 143 | 144 | @torch.no_grad() 145 | def get_correspondences(pts2, data, idx): 146 | device = data['image0'].device 147 | N, _, H0, W0 = data['image0'].shape 148 | _, _, H1, W1 = data['image1'].shape 149 | 150 | pts2 = pts2[None, ...] 151 | 152 | scale0 = data['scale0'][idx, None][None, ...] if 'scale0' in data else 1 153 | scale1 = data['scale1'][idx, None][None, ...] 
if 'scale1' in data else 1 154 | 155 | pts2 = scale1 * pts2 * 8 156 | 157 | # warp kpts bi-directionally and check reproj error 158 | nonzero_m1, pts1 = warp_kpts(pts2, data['depth1'][idx][None, ...], data['depth0'][idx][None, ...], data['T_1to0'][idx][None, ...], 159 | data['K1'][idx][None, ...], data['K0'][idx][None, ...]) 160 | 161 | corrs = torch.cat([pts1[0, :] / data['scale0'][idx], 162 | pts2[0, :] / data['scale1'][idx]],dim=-1) 163 | 164 | #plot_corrs(data['image0'][idx], data['image1'][idx], corrs[:, :2], corrs[:, 2:]) 165 | 166 | return corrs 167 | 168 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | import math 6 | import cv2 7 | 8 | from models.liftfeat_wrapper import LiftFeat,MODEL_PATH 9 | 10 | import argparse 11 | 12 | parser=argparse.ArgumentParser(description='HPatch dataset evaluation script') 13 | parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name') 14 | parser.add_argument('--img1',type=str,default='./assert/ref.jpg',help='reference image path') 15 | parser.add_argument('--img2',type=str,default='./assert/query.jpg',help='query image path') 16 | parser.add_argument('--size',type=str,default=None,help='Resize images to w,h, None means disable resize') 17 | parser.add_argument('--use_opencv_match',action='store_true',help='Enable OpenCV match function') 18 | parser.add_argument('--gpu',type=str,default='0',help='GPU ID') 19 | args=parser.parse_args() 20 | 21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 22 | 23 | 24 | def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2): 25 | # Calculate the Homography matrix 26 | H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999) 27 | mask = mask.flatten() 28 | 29 | # Get corners of the first image (image1) 30 | h, w = img1.shape[:2] 31 | corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2) 32 | 33 | # Warp corners to the second image (image2) space 34 | warped_corners = cv2.perspectiveTransform(corners_img1, H) 35 | 36 | # Draw the warped corners in image2 37 | img2_with_corners = img2.copy() 38 | 39 | # Prepare keypoints and matches for drawMatches function 40 | keypoints1 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in ref_points] 41 | keypoints2 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in dst_points] 42 | matches = [cv2.DMatch(i,i,0) for i in range(len(mask)) if mask[i]] 43 | 44 | # Draw inlier matches 45 | img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None, 46 | matchColor=(0, 255, 0), flags=2) 47 | 48 | return img_matches 49 | 50 | 51 | def opencv_knn_match(descs1,descs2,kpts1,kpts2): 52 | bf = cv2.BFMatcher() 53 | 54 | matches = bf.knnMatch(descs1,descs2,k=2) 55 | 56 | good_matches = [] 57 | for m, n in matches: 58 | if m.distance < 0.9 * n.distance: 59 | good_matches.append(m) 60 | 61 | mkpts1 = [];mkpts2 = [] 62 | 63 | for m in good_matches: 64 | mkpt1=kpts1[m.queryIdx];mkpt2=kpts2[m.trainIdx] 65 | mkpts1.append(mkpt1);mkpts2.append(mkpt2) 66 | 67 | mkpts1 = np.array(mkpts1) 68 | mkpts2 = np.array(mkpts2) 69 | 70 | return mkpts1,mkpts2 71 | 72 | 73 | if __name__=="__main__": 74 | if args.size: 75 | print(f'resize images to {args.size}') 76 | w=int(args.size.split(',')[0]) 77 | h=int(args.size.split(',')[1]) 78 | dst_size=(w,h) 
79 | else: 80 | print(f'disable resize') 81 | 82 | if args.use_opencv_match: 83 | print(f'Use OpenCV knnMatch') 84 | else: 85 | print(f'Use original match function') 86 | 87 | liftfeat=LiftFeat(weight=MODEL_PATH,detect_threshold=0.05) 88 | 89 | img1=cv2.imread(args.img1) 90 | img2=cv2.imread(args.img2) 91 | 92 | if args.size: 93 | img1=cv2.resize(img1,dst_size) 94 | img2=cv2.resize(img2,dst_size) 95 | 96 | if args.use_opencv_match: 97 | data1 = liftfeat.extract(img1) 98 | data2 = liftfeat.extract(img2) 99 | kpts1,descs1=data1['keypoints'].cpu().numpy(),data1['descriptors'].cpu().numpy() 100 | kpts2,descs2=data2['keypoints'].cpu().numpy(),data2['descriptors'].cpu().numpy() 101 | 102 | mkpts1,mkpts2 = opencv_knn_match(descs1,descs2,kpts1,kpts2) 103 | else: 104 | mkpts1,mkpts2=liftfeat.match_liftfeat(img1,img2) 105 | 106 | 107 | canvas=warp_corners_and_draw_matches(mkpts1,mkpts2,img1,img2) 108 | 109 | import matplotlib.pyplot as plt 110 | plt.figure(figsize=[12,12]) 111 | plt.imshow(canvas[...,::-1]) 112 | 113 | plt.savefig(os.path.join(os.path.dirname(__file__),'match.jpg'), dpi=300, bbox_inches='tight') 114 | 115 | plt.show() 116 | -------------------------------------------------------------------------------- /evaluation/HPatch_evaluation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from tqdm import tqdm 4 | import torch 5 | import numpy as np 6 | import sys 7 | import poselib 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 10 | 11 | import argparse 12 | import datetime 13 | 14 | parser=argparse.ArgumentParser(description='HPatch dataset evaluation script') 15 | parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name') 16 | parser.add_argument('--gpu',type=str,default='0',help='GPU ID') 17 | args=parser.parse_args() 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 20 | 21 | use_cuda = torch.cuda.is_available() 22 | device = torch.device("cuda" if use_cuda else "cpu") 23 | 24 | top_k = None 25 | n_i = 52 26 | n_v = 56 27 | 28 | DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/HPatch') 29 | 30 | from evaluation.eval_utils import * 31 | from models.liftfeat_wrapper import LiftFeat 32 | 33 | 34 | poselib_config = {"ransac_th": 3.0, "options": {}} 35 | 36 | class PoseLibHomographyEstimator: 37 | def __init__(self, conf): 38 | self.conf = conf 39 | 40 | def estimate(self, mkpts0,mkpts1): 41 | M, info = poselib.estimate_homography( 42 | mkpts0, 43 | mkpts1, 44 | { 45 | "max_reproj_error": self.conf["ransac_th"], 46 | **self.conf["options"], 47 | }, 48 | ) 49 | success = M is not None 50 | if not success: 51 | M = np.eye(3,dtype=np.float32) 52 | inl = np.zeros(mkpts0.shape[0],dtype=np.bool_) 53 | else: 54 | inl = info["inliers"] 55 | 56 | estimation = { 57 | "success": success, 58 | "M_0to1": M, 59 | "inliers": inl, 60 | } 61 | 62 | return estimation 63 | 64 | 65 | estimator=PoseLibHomographyEstimator(poselib_config) 66 | 67 | 68 | def poselib_homography_estimate(mkpts0,mkpts1): 69 | data=estimator.estimate(mkpts0,mkpts1) 70 | return data 71 | 72 | 73 | def generate_standard_image(img,target_size=(1920,1080)): 74 | sh,sw=img.shape[0],img.shape[1] 75 | rh,rw=float(target_size[1])/float(sh),float(target_size[0])/float(sw) 76 | ratio=min(rh,rw) 77 | nh,nw=int(ratio*sh),int(ratio*sw) 78 | ph,pw=target_size[1]-nh,target_size[0]-nw 79 | nimg=cv2.resize(img,(nw,nh)) 80 | nimg=cv2.copyMakeBorder(nimg,0,ph,0,pw,cv2.BORDER_CONSTANT,value=(0,0,0)) 81 | 82 | return 
nimg,ratio,ph,pw 83 | 84 | 85 | def benchmark_features(match_fn): 86 | lim = [1, 9] 87 | rng = np.arange(lim[0], lim[1] + 1) 88 | 89 | seq_names = sorted(os.listdir(DATASET_ROOT)) 90 | 91 | n_feats = [] 92 | n_matches = [] 93 | seq_type = [] 94 | i_err = {thr: 0 for thr in rng} 95 | v_err = {thr: 0 for thr in rng} 96 | 97 | i_err_homo = {thr: 0 for thr in rng} 98 | v_err_homo = {thr: 0 for thr in rng} 99 | 100 | for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)): 101 | # load reference image 102 | ref_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, "1.ppm")) 103 | ref_img_shape=ref_img.shape 104 | 105 | # load query images 106 | for im_idx in range(2, 7): 107 | # read ground-truth homography 108 | homography = np.loadtxt(os.path.join(DATASET_ROOT, seq_name, "H_1_" + str(im_idx))) 109 | query_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, f"{im_idx}.ppm")) 110 | 111 | mkpts_a,mkpts_b=match_fn(ref_img,query_img) 112 | 113 | pos_a = mkpts_a 114 | pos_a_h = np.concatenate([pos_a, np.ones([pos_a.shape[0], 1])], axis=1) 115 | pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h))) 116 | pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:] 117 | 118 | pos_b = mkpts_b 119 | 120 | dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1)) 121 | 122 | n_matches.append(pos_a.shape[0]) 123 | seq_type.append(seq_name[0]) 124 | 125 | if dist.shape[0] == 0: 126 | dist = np.array([float("inf")]) 127 | 128 | for thr in rng: 129 | if seq_name[0] == "i": 130 | i_err[thr] += np.mean(dist <= thr) 131 | else: 132 | v_err[thr] += np.mean(dist <= thr) 133 | 134 | # estimate homography 135 | gt_homo = homography 136 | pred_homo, _ = cv2.findHomography(mkpts_a,mkpts_b,cv2.USAC_MAGSAC) 137 | if pred_homo is None: 138 | homo_dist = np.array([float("inf")]) 139 | else: 140 | corners = np.array( 141 | [ 142 | [0, 0], 143 | [ref_img_shape[1] - 1, 0], 144 | [0, ref_img_shape[0] - 1], 145 | [ref_img_shape[1] - 1, ref_img_shape[0] - 1], 146 | ] 147 | ) 148 | real_warped_corners = homo_trans(corners, gt_homo) 149 | warped_corners = homo_trans(corners, pred_homo) 150 | homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) 151 | 152 | for thr in rng: 153 | if seq_name[0] == "i": 154 | i_err_homo[thr] += np.mean(homo_dist <= thr) 155 | else: 156 | v_err_homo[thr] += np.mean(homo_dist <= thr) 157 | 158 | seq_type = np.array(seq_type) 159 | n_feats = np.array(n_feats) 160 | n_matches = np.array(n_matches) 161 | 162 | return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches] 163 | 164 | 165 | if __name__ == "__main__": 166 | errors = {} 167 | 168 | weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth') 169 | liftfeat=LiftFeat(weight=weights) 170 | 171 | errors = benchmark_features(liftfeat.match_liftfeat) 172 | 173 | i_err, v_err, i_err_hom, v_err_hom, _ = errors 174 | 175 | cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 176 | 177 | print(f'\n==={cur_time}==={args.name}===') 178 | print(f"MHA@3 MHA@5 MHA@7") 179 | for thr in [3, 5, 7]: 180 | ill_err_hom = i_err_hom[thr] / (n_i * 5) 181 | view_err_hom = v_err_hom[thr] / (n_v * 5) 182 | print(f"{ill_err_hom * 100:.2f}%-{view_err_hom * 100:.2f}%") 183 | -------------------------------------------------------------------------------- /evaluation/MegaDepth1500_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | from pathlib import Path 5 | import numpy as np 6 | 
import torch 7 | import torch.utils.data as data 8 | import tqdm 9 | from copy import deepcopy 10 | from torchvision.transforms import ToTensor 11 | import torch.nn.functional as F 12 | import json 13 | 14 | import scipy.io as scio 15 | import poselib 16 | 17 | import argparse 18 | import datetime 19 | 20 | parser=argparse.ArgumentParser(description='MegaDepth dataset evaluation script') 21 | parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name') 22 | parser.add_argument('--gpu',type=str,default='0',help='GPU ID') 23 | args=parser.parse_args() 24 | 25 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 26 | 27 | sys.path.append(os.path.join(os.path.dirname(__file__),'../')) 28 | from models.liftfeat_wrapper import LiftFeat 29 | from evaluation.eval_utils import * 30 | 31 | from torch.utils.data import Dataset,DataLoader 32 | 33 | use_cuda = torch.cuda.is_available() 34 | device = "cuda" if use_cuda else "cpu" 35 | 36 | DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/megadepth_test_1500') 37 | DATASET_JSON = os.path.join(os.path.dirname(__file__),'../data/megadepth_1500.json') 38 | 39 | class MegaDepth1500(Dataset): 40 | """ 41 | Streamlined MegaDepth-1500 dataloader. The camera poses & metadata are stored in a formatted json for facilitating 42 | the download of the dataset and to keep the setup as simple as possible. 43 | """ 44 | def __init__(self, json_file, root_dir): 45 | # Load the info & calibration from the JSON 46 | with open(json_file, 'r') as f: 47 | self.data = json.load(f) 48 | 49 | self.root_dir = root_dir 50 | 51 | if not os.path.exists(self.root_dir): 52 | raise RuntimeError( 53 | f"Dataset {self.root_dir} does not exist! \n \ 54 | > If you didn't download the dataset, use the downloader tool: python3 -m modules.dataset.download -h") 55 | 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | def __getitem__(self, idx): 60 | data = deepcopy(self.data[idx]) 61 | 62 | h1, w1 = data['size0_hw'] 63 | h2, w2 = data['size1_hw'] 64 | 65 | # Here we resize the images to max_dim = 1200, as described in the paper, and adjust the image such that it is divisible by 32 66 | # following the protocol of the LoFTR's Dataloader (intrinsics are corrected accordingly). 67 | # For adapting this with different resolution, you would need to re-scale intrinsics below. 
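        # Example (hypothetical resize, not part of the evaluation protocol): for a different
        # target size (new_w, new_h), pinhole intrinsics scale linearly with each image axis,
        # e.g. K0[0, :] *= new_w / w1 and K0[1, :] *= new_h / h1 (and analogously for K1 with w2, h2).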
68 | image0 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][0]}"),(w1, h1)) 69 | 70 | image1 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][1]}"),(w2, h2)) 71 | 72 | data['image0'] = torch.tensor(image0.astype(np.float32)/255).permute(2,0,1) 73 | data['image1'] = torch.tensor(image1.astype(np.float32)/255).permute(2,0,1) 74 | 75 | for k,v in data.items(): 76 | if k not in ('dataset_name', 'scene_id', 'pair_id', 'pair_names', 'size0_hw', 'size1_hw', 'image0', 'image1'): 77 | data[k] = torch.tensor(np.array(v, dtype=np.float32)) 78 | 79 | return data 80 | 81 | if __name__ == "__main__": 82 | weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth') 83 | liftfeat=LiftFeat(weight=weights) 84 | 85 | dataset = MegaDepth1500(json_file = DATASET_JSON, root_dir = DATASET_ROOT) 86 | 87 | loader = DataLoader(dataset, batch_size=1, shuffle=False) 88 | 89 | metrics = {} 90 | R_errs = [] 91 | t_errs = [] 92 | inliers = [] 93 | 94 | results=[] 95 | 96 | cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 97 | 98 | for d in tqdm.tqdm(loader, desc="processing"): 99 | error_infos = compute_pose_error(liftfeat.match_liftfeat,d) 100 | results.append(error_infos) 101 | 102 | print(f'\n==={cur_time}==={args.name}===') 103 | d_err_auc,errors=compute_maa(results) 104 | for s_k,s_v in d_err_auc.items(): 105 | print(f'{s_k}: {s_v*100}') 106 | -------------------------------------------------------------------------------- /evaluation/__pycache__/eval_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/evaluation/__pycache__/eval_utils.cpython-38.pyc -------------------------------------------------------------------------------- /evaluation/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import poselib 4 | 5 | 6 | def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): 7 | # angle error between 2 vectors 8 | t_gt = T_0to1[:3, 3] 9 | n = np.linalg.norm(t) * np.linalg.norm(t_gt) 10 | t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) 11 | t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity 12 | if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging 13 | t_err = 0 14 | 15 | # angle error between 2 rotation matrices 16 | R_gt = T_0to1[:3, :3] 17 | cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 18 | cos = np.clip(cos, -1.0, 1.0) # handle numercial errors 19 | R_err = np.rad2deg(np.abs(np.arccos(cos))) 20 | 21 | return t_err, R_err 22 | 23 | def intrinsics_to_camera(K): 24 | px, py = K[0, 2], K[1, 2] 25 | fx, fy = K[0, 0], K[1, 1] 26 | return { 27 | "model": "PINHOLE", 28 | "width": int(2 * px), 29 | "height": int(2 * py), 30 | "params": [fx, fy, px, py], 31 | } 32 | 33 | 34 | def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): 35 | M, info = poselib.estimate_relative_pose( 36 | kpts0, kpts1, 37 | intrinsics_to_camera(K0), 38 | intrinsics_to_camera(K1), 39 | {"max_epipolar_error": thresh, 40 | "success_prob": conf, 41 | "min_iterations": 20, 42 | "max_iterations": 1_000}, 43 | ) 44 | 45 | R, t, inl = M.R, M.t, info["inliers"] 46 | inl = np.array(inl) 47 | ret = (R, t, inl) 48 | 49 | return ret 50 | 51 | def tensor2bgr(t): 52 | return (t.cpu()[0].permute(1,2,0).numpy()*255).astype(np.uint8) 53 | 54 | def compute_pose_error(match_fn,data): 55 | result = {} 56 | 57 | 
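    # The matcher below runs on the resized evaluation images; the matched keypoints are then
    # multiplied by the stored scale factors so they live in the resolution that K0/K1 refer to.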
with torch.no_grad(): 58 | mkpts0,mkpts1=match_fn(tensor2bgr(data["image0"]),tensor2bgr(data["image1"])) 59 | 60 | mkpts0=mkpts0 * data["scale0"].numpy() 61 | mkpts1=mkpts1 * data["scale1"].numpy() 62 | 63 | K0, K1 = data["K0"][0].numpy(), data["K1"][0].numpy() 64 | T_0to1 = data["T_0to1"][0].numpy() 65 | T_1to0 = data["T_1to0"][0].numpy() 66 | 67 | result={} 68 | conf = 0.99999 69 | 70 | ret = estimate_pose(mkpts0,mkpts1,K0,K1,4.0,conf) 71 | if ret is not None: 72 | R, t, inliers = ret 73 | t_err, R_err = relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0) 74 | result['R_err'] = R_err 75 | result['t_err'] = t_err 76 | 77 | return result 78 | 79 | 80 | def error_auc(errors, thresholds=[5, 10, 20]): 81 | """ 82 | Args: 83 | errors (list): [N,] 84 | thresholds (list) 85 | """ 86 | errors = [0] + sorted(list(errors)) 87 | recall = list(np.linspace(0, 1, len(errors))) 88 | 89 | aucs = [] 90 | 91 | for thr in thresholds: 92 | last_index = np.searchsorted(errors, thr) 93 | y = recall[:last_index] + [recall[last_index-1]] 94 | x = errors[:last_index] + [thr] 95 | aucs.append(np.trapz(y, x) / thr) 96 | 97 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} 98 | 99 | def compute_maa(pairs, thresholds=[5, 10, 20]): 100 | # print("auc / mAcc on %d pairs" % (len(pairs))) 101 | errors = [] 102 | 103 | for p in pairs: 104 | et = p['t_err'] 105 | er = p['R_err'] 106 | errors.append(max(et, er)) 107 | 108 | d_err_auc = error_auc(errors) 109 | 110 | # for k,v in d_err_auc.items(): 111 | # print(k, ': ', '%.1f'%(v*100)) 112 | 113 | errors = np.array(errors) 114 | 115 | for t in thresholds: 116 | acc = (errors <= t).sum() / len(errors) 117 | # print("mAcc@%d: %.1f "%(t, acc*100)) 118 | 119 | return d_err_auc,errors 120 | 121 | def homo_trans(coord, H): 122 | kpt_num = coord.shape[0] 123 | homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1) 124 | proj_coord = np.matmul(H, homo_coord.T).T 125 | proj_coord = proj_coord / proj_coord[:, 2][..., None] 126 | proj_coord = proj_coord[:, 0:2] 127 | return proj_coord 128 | -------------------------------------------------------------------------------- /loss/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/loss/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import time 7 | 8 | 9 | def dual_softmax_loss(X, Y, temp = 0.2): 10 | if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2: 11 | raise RuntimeError('Error: X and Y shapes must match and be 2D matrices') 12 | 13 | dist_mat = (X @ Y.t()) * temp 14 | conf_matrix12 = F.log_softmax(dist_mat, dim=1) 15 | conf_matrix21 = F.log_softmax(dist_mat.t(), dim=1) 16 | 17 | with torch.no_grad(): 18 | conf12 = torch.exp( conf_matrix12 ).max(dim=-1)[0] 19 | conf21 = torch.exp( conf_matrix21 ).max(dim=-1)[0] 20 | conf = conf12 * conf21 21 | 22 | target = torch.arange(len(X), device = X.device) 23 | 24 | loss = F.nll_loss(conf_matrix12, target) + \ 25 | F.nll_loss(conf_matrix21, target) 26 | 27 | return loss, conf 28 | 29 | 30 | class LiftFeatLoss(nn.Module): 31 | def 
__init__(self,device,lam_descs=1,lam_fb_descs=1,lam_kpts=1,lam_heatmap=1,lam_normals=1,lam_coordinates=1,lam_fb_coordinates=1,depth_spvs=False): 32 | super().__init__() 33 | 34 | # loss parameters 35 | self.lam_descs=lam_descs 36 | self.lam_fb_descs=lam_fb_descs 37 | self.lam_kpts=lam_kpts 38 | self.lam_heatmap=lam_heatmap 39 | self.lam_normals=lam_normals 40 | self.lam_coordinates=lam_coordinates 41 | self.lam_fb_coordinates=lam_fb_coordinates 42 | self.depth_spvs=depth_spvs 43 | self.running_descs_loss=0 44 | self.running_kpts_loss=0 45 | self.running_heatmaps_loss=0 46 | self.loss_descs=0 47 | self.loss_fb_descs=0 48 | self.loss_kpts=0 49 | self.loss_heatmaps=0 50 | self.loss_normals=0 51 | self.loss_coordinates=0 52 | self.loss_fb_coordinates=0 53 | self.acc_coarse=0 54 | self.acc_fb_coarse=0 55 | self.acc_kpt=0 56 | self.acc_coordinates=0 57 | self.acc_fb_coordinates=0 58 | 59 | # device 60 | self.dev=device 61 | 62 | 63 | def check_accuracy(self,m1,m2,pts1=None,pts2=None,plot=False): 64 | with torch.no_grad(): 65 | #dist_mat = torch.cdist(X,Y) 66 | dist_mat = m1 @ m2.t() 67 | nn = torch.argmax(dist_mat, dim=1) 68 | #nn = torch.argmin(dist_mat, dim=1) 69 | correct = nn == torch.arange(len(m1), device = m1.device) 70 | 71 | if pts1 is not None and plot: 72 | import matplotlib.pyplot as plt 73 | canvas = torch.zeros((60, 80),device=m1.device) 74 | pts1 = pts1[~correct] 75 | canvas[pts1[:,1].long(), pts1[:,0].long()] = 1 76 | canvas = canvas.cpu().numpy() 77 | plt.imshow(canvas), plt.show() 78 | 79 | acc = correct.sum().item() / len(m1) 80 | return acc 81 | 82 | def compute_descriptors_loss(self,descs1,descs2,pts): 83 | loss=[] 84 | acc=0 85 | B,_,H,W=descs1.shape 86 | conf_list=[] 87 | 88 | for b in range(B): 89 | pts1,pts2=pts[b][:,:2],pts[b][:,2:] 90 | m1=descs1[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0) 91 | m2=descs2[b,:,pts2[:,1].long(),pts2[:,0].long()].permute(1,0) 92 | 93 | loss_per,conf_per=dual_softmax_loss(m1,m2) 94 | loss.append(loss_per.unsqueeze(0)) 95 | conf_list.append(conf_per) 96 | 97 | acc_coarse_per=self.check_accuracy(m1,m2) 98 | acc += acc_coarse_per 99 | 100 | loss=torch.cat(loss,dim=-1).mean() 101 | acc /= B 102 | return loss,acc,conf_list 103 | 104 | 105 | def alike_distill_loss(self,kpts,alike_kpts): 106 | C, H, W = kpts.shape 107 | kpts = kpts.permute(1,2,0) 108 | # get ALike keypoints 109 | with torch.no_grad(): 110 | labels = torch.ones((H, W), dtype = torch.long, device = kpts.device) * 64 # -> Default is non-keypoint (bin 64) 111 | offsets = (((alike_kpts/8) - (alike_kpts/8).long())*8).long() 112 | offsets = offsets[:, 0] + 8*offsets[:, 1] # Linear IDX 113 | labels[(alike_kpts[:,1]/8).long(), (alike_kpts[:,0]/8).long()] = offsets 114 | 115 | kpts = kpts.view(-1,C) 116 | labels = labels.view(-1) 117 | 118 | mask = labels < 64 119 | idxs_pos = mask.nonzero().flatten() 120 | idxs_neg = (~mask).nonzero().flatten() 121 | perm = torch.randperm(idxs_neg.size(0))[:len(idxs_pos)//32] 122 | idxs_neg = idxs_neg[perm] 123 | idxs = torch.cat([idxs_pos, idxs_neg]) 124 | 125 | kpts = kpts[idxs] 126 | labels = labels[idxs] 127 | 128 | with torch.no_grad(): 129 | predicted = kpts.max(dim=-1)[1] 130 | acc = (labels == predicted) 131 | acc = acc.sum() / len(acc) 132 | 133 | kpts = F.log_softmax(kpts,dim=-1) 134 | loss = F.nll_loss(kpts, labels, reduction = 'mean') 135 | 136 | return loss, acc 137 | 138 | 139 | def compute_keypoints_loss(self,kpts1,kpts2,alike_kpts1,alike_kpts2): 140 | loss=[] 141 | acc=0 142 | B,_,H,W=kpts1.shape 143 | 144 | for b in range(B): 145 | 
loss_per1,acc_per1=self.alike_distill_loss(kpts1[b],alike_kpts1[b]) 146 | loss_per2,acc_per2=self.alike_distill_loss(kpts2[b],alike_kpts2[b]) 147 | loss_per=(loss_per1+loss_per2) 148 | acc_per=(acc_per1+acc_per2)/2 149 | loss.append(loss_per.unsqueeze(0)) 150 | acc += acc_per 151 | 152 | loss=torch.cat(loss,dim=-1).mean() 153 | acc /= B 154 | return loss,acc 155 | 156 | 157 | def compute_heatmaps_loss(self,heatmaps1,heatmaps2,pts,conf_list): 158 | loss=[] 159 | B,_,H,W=heatmaps1.shape 160 | 161 | for b in range(B): 162 | pts1,pts2=pts[b][:,:2],pts[b][:,2:] 163 | h1=heatmaps1[b,0,pts1[:,1].long(),pts1[:,0].long()] 164 | h2=heatmaps2[b,0,pts2[:,1].long(),pts2[:,0].long()] 165 | 166 | conf=conf_list[b] 167 | loss_per1=F.l1_loss(h1,conf) 168 | loss_per2=F.l1_loss(h2,conf) 169 | loss_per=(loss_per1+loss_per2) 170 | loss.append(loss_per.unsqueeze(0)) 171 | 172 | loss=torch.cat(loss,dim=-1).mean() 173 | return loss 174 | 175 | 176 | def normal_loss(self,normal,target_normal): 177 | # import pdb;pdb.set_trace() 178 | normal = normal.permute(1, 2, 0) 179 | target_normal = target_normal.permute(1,2,0) 180 | # loss = F.l1_loss(d_feat, depth_anything_normal_feat) 181 | dot = torch.cosine_similarity(normal, target_normal, dim=2) 182 | valid_mask = target_normal[:, :, 0].float() \ 183 | * (dot.detach() < 0.999).float() \ 184 | * (dot.detach() > -0.999).float() 185 | valid_mask = valid_mask > 0.0 186 | al = torch.acos(dot[valid_mask]) 187 | loss = torch.mean(al) 188 | return loss 189 | 190 | 191 | def compute_normals_loss(self,normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size): 192 | loss=[] 193 | 194 | # import pdb;pdb.set_trace() 195 | 196 | # only MegaDepth image need depth-normal 197 | normals1=normals1[coco_batch_size:,...] 198 | normals2=normals2[coco_batch_size:,...] 
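        # The slicing above assumes each batch is ordered [COCO samples | MegaDepth samples],
        # so dropping the first coco_batch_size entries keeps only the MegaDepth images,
        # which are the ones with depth-derived normal supervision.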
199 | for b in range(len(DA_normals1)): 200 | normal1,normal2=normals1[b],normals2[b] 201 | loss_per1=self.normal_loss(normal1,DA_normals1[b].permute(2,0,1)) 202 | loss_per2=self.normal_loss(normal2,DA_normals2[b].permute(2,0,1)) 203 | loss_per=(loss_per1+loss_per2) 204 | loss.append(loss_per.unsqueeze(0)) 205 | 206 | loss=torch.cat(loss,dim=-1).mean() 207 | return loss 208 | 209 | 210 | def coordinate_loss(self,coordinate,conf,pts1): 211 | with torch.no_grad(): 212 | coordinate_detached = pts1 * 8 213 | offset_detached = (coordinate_detached/8) - (coordinate_detached/8).long() 214 | offset_detached = (offset_detached * 8).long() 215 | label = offset_detached[:, 0] + 8*offset_detached[:, 1] 216 | 217 | #pdb.set_trace() 218 | coordinate_log = F.log_softmax(coordinate, dim=-1) 219 | 220 | predicted = coordinate.max(dim=-1)[1] 221 | acc = (label == predicted) 222 | acc = acc[conf > 0.1] 223 | acc = acc.sum() / len(acc) 224 | 225 | loss = F.nll_loss(coordinate_log, label, reduction = 'none') 226 | 227 | #Weight loss by confidence, giving more emphasis on reliable matches 228 | conf = conf / conf.sum() 229 | loss = (loss * conf).sum() 230 | 231 | return loss*2., acc 232 | 233 | def compute_coordinates_loss(self,coordinates,pts,conf_list): 234 | loss=[] 235 | acc=0 236 | B,_,H,W=coordinates.shape 237 | 238 | for b in range(B): 239 | pts1,pts2=pts[b][:,:2],pts[b][:,2:] 240 | coordinate=coordinates[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0) 241 | conf=conf_list[b] 242 | 243 | loss_per,acc_per=self.coordinate_loss(coordinate,conf,pts1) 244 | loss.append(loss_per.unsqueeze(0)) 245 | acc += acc_per 246 | 247 | loss=torch.cat(loss,dim=-1).mean() 248 | acc /= B 249 | 250 | return loss,acc 251 | 252 | 253 | def forward(self, 254 | descs1,fb_descs1,kpts1,normals1, 255 | descs2,fb_descs2,kpts2,normals2, 256 | pts,coordinates,fb_coordinates, 257 | alike_kpts1,alike_kpts2, 258 | DA_normals1,DA_normals2, 259 | megadepth_batch_size,coco_batch_size 260 | ): 261 | # import pdb;pdb.set_trace() 262 | self.loss_descs,self.acc_coarse,conf_list=self.compute_descriptors_loss(descs1,descs2,pts) 263 | self.loss_fb_descs,self.acc_fb_coarse,fb_conf_list=self.compute_descriptors_loss(fb_descs1,fb_descs2,pts) 264 | 265 | # start=time.perf_counter() 266 | self.loss_kpts,self.acc_kpt=self.compute_keypoints_loss(kpts1,kpts2,alike_kpts1,alike_kpts2) 267 | # end=time.perf_counter() 268 | # print(f"kpts loss cost {end-start} seconds") 269 | 270 | # start=time.perf_counter() 271 | self.loss_normals=self.compute_normals_loss(normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size) 272 | # end=time.perf_counter() 273 | # print(f"normal loss cost {end-start} seconds") 274 | 275 | self.loss_coordinates,self.acc_coordinates=self.compute_coordinates_loss(coordinates,pts,conf_list) 276 | self.loss_fb_coordinates,self.acc_fb_coordinates=self.compute_coordinates_loss(fb_coordinates,pts,fb_conf_list) 277 | 278 | return { 279 | 'loss_descs':self.lam_descs*self.loss_descs, 280 | 'acc_coarse':self.acc_coarse, 281 | 'loss_coordinates':self.lam_coordinates*self.loss_coordinates, 282 | 'acc_coordinates':self.acc_coordinates, 283 | 'loss_fb_descs':self.lam_fb_descs*self.loss_fb_descs, 284 | 'acc_fb_coarse':self.acc_fb_coarse, 285 | 'loss_fb_coordinates':self.lam_fb_coordinates*self.loss_fb_coordinates, 286 | 'acc_fb_coordinates':self.acc_fb_coordinates, 287 | 'loss_kpts':self.lam_kpts*self.loss_kpts, 288 | 'acc_kpt':self.acc_kpt, 289 | 'loss_normals':self.lam_normals*self.loss_normals, 290 | } 291 | 292 | 
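# Minimal smoke test for dual_softmax_loss (illustrative only: the descriptor dimension and the
# number of correspondences are arbitrary choices, not the values used for LiftFeat training).
# The loss treats row i of the first matrix and row i of the second matrix as a ground-truth pair.
if __name__ == "__main__":
    torch.manual_seed(0)
    descs_a = F.normalize(torch.randn(128, 64), dim=-1)                   # descriptors from image 1
    descs_b = F.normalize(descs_a + 0.05 * torch.randn(128, 64), dim=-1)  # noisy counterparts in image 2
    loss, conf = dual_softmax_loss(descs_a, descs_b)
    print(f"loss={loss.item():.4f}  mean mutual confidence={conf.mean().item():.4f}")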
-------------------------------------------------------------------------------- /models/__pycache__/interpolator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/interpolator.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/interpolator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/interpolator.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/liftfeat_wrapper.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/liftfeat_wrapper.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/liftfeat_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/liftfeat_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/models/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /models/interpolator.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | 4 | This script is used to interpolate rough descriptors from LiftFeat 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class InterpolateSparse2d(nn.Module): 12 | """ Efficiently interpolate tensor at given sparse 2D positions. """ 13 | def __init__(self, mode = 'bicubic', align_corners = False): 14 | super().__init__() 15 | self.mode = mode 16 | self.align_corners = align_corners 17 | 18 | def normgrid(self, x, H, W): 19 | """ Normalize coords to [-1,1]. """ 20 | return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1. 
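    # For example, with H = W = 10 a position (0, 0) maps to (-1, -1) and (9, 9) maps to (1, 1),
    # which is the normalized grid convention expected by F.grid_sample in forward().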
21 | 22 | def forward(self, x, pos, H, W): 23 | """ 24 | Input 25 | x: [B, C, H, W] feature tensor 26 | pos: [B, N, 2] tensor of positions 27 | H, W: int, original resolution of input 2d positions -- used in normalization [-1,1] 28 | 29 | Returns 30 | [B, N, C] sampled channels at 2d positions 31 | """ 32 | grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype) 33 | x = F.grid_sample(x, grid, mode = self.mode , align_corners = False) 34 | return x.permute(0,2,3,1).squeeze(-2) -------------------------------------------------------------------------------- /models/liftfeat_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | import math 6 | import cv2 7 | 8 | from models.model import LiftFeatSPModel 9 | from models.interpolator import InterpolateSparse2d 10 | from utils.config import featureboost_config 11 | 12 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 13 | 14 | MODEL_PATH = os.path.join(os.path.dirname(__file__), "../weights/LiftFeat.pth") 15 | 16 | 17 | class NonMaxSuppression(torch.nn.Module): 18 | def __init__(self, rep_thr=0.1, top_k=4096): 19 | super(NonMaxSuppression, self).__init__() 20 | self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 21 | self.rep_thr = rep_thr 22 | self.top_k = top_k 23 | 24 | def NMS(self, x, threshold=0.05, kernel_size=5): 25 | B, _, H, W = x.shape 26 | pad = kernel_size // 2 27 | local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x) 28 | pos = (x == local_max) & (x > threshold) 29 | pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos] 30 | 31 | pad_val = max([len(x) for x in pos_batched]) 32 | pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device) 33 | 34 | # Pad kpts and build (B, N, 2) tensor 35 | for b in range(len(pos_batched)): 36 | pos[b, : len(pos_batched[b]), :] = pos_batched[b] 37 | 38 | return pos 39 | 40 | def forward(self, score): 41 | pos = self.NMS(score, self.rep_thr) 42 | 43 | return pos 44 | 45 | 46 | def load_model(model, weight_path): 47 | pretrained_weights = torch.load(weight_path, map_location="cpu") 48 | 49 | model_keys = set(model.state_dict().keys()) 50 | pretrained_keys = set(pretrained_weights.keys()) 51 | 52 | missing_keys = model_keys - pretrained_keys 53 | unexpected_keys = pretrained_keys - model_keys 54 | 55 | # if missing_keys: 56 | # print("Missing keys in pretrained weights:", missing_keys) 57 | # else: 58 | # print("No missing keys in pretrained weights.") 59 | 60 | # if unexpected_keys: 61 | # print("Unexpected keys in pretrained weights:", unexpected_keys) 62 | # else: 63 | # print("No unexpected keys in pretrained weights.") 64 | 65 | if not missing_keys and not unexpected_keys: 66 | model.load_state_dict(pretrained_weights) 67 | print("load weight successfully.") 68 | else: 69 | model.load_state_dict(pretrained_weights, strict=False) 70 | # print("There were issues with the keys.") 71 | return model 72 | 73 | 74 | import torch.nn as nn 75 | 76 | 77 | class LiftFeat(nn.Module): 78 | def __init__(self, weight=MODEL_PATH, top_k=4096, detect_threshold=0.1): 79 | super().__init__() 80 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | self.net = LiftFeatSPModel(featureboost_config).to(self.device).eval() 82 | self.top_k = top_k 83 | self.sampler = InterpolateSparse2d("bicubic") 84 | self.net = load_model(self.net, weight) 85 | self.detector = 
NonMaxSuppression(rep_thr=detect_threshold) 86 | self.net = self.net.to(self.device) 87 | self.detector = self.detector.to(self.device) 88 | self.sampler = self.sampler.to(self.device) 89 | 90 | def image_preprocess(self, image: np.ndarray): 91 | H, W, C = image.shape[0], image.shape[1], image.shape[2] 92 | 93 | _H = math.ceil(H / 32) * 32 94 | _W = math.ceil(W / 32) * 32 95 | 96 | pad_h = _H - H 97 | pad_w = _W - W 98 | 99 | image = cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, None, (0, 0, 0)) 100 | 101 | pad_info = [0, pad_h, 0, pad_w] 102 | 103 | if len(image.shape) == 3: 104 | image = image[None, ...] 105 | 106 | image = torch.tensor(image).permute(0, 3, 1, 2) / 255 107 | image = image.to(device) 108 | 109 | return image, pad_info 110 | 111 | @torch.inference_mode() 112 | def extract(self, image: np.ndarray): 113 | image, pad_info = self.image_preprocess(image) 114 | B, _, _H1, _W1 = image.shape 115 | 116 | M1, K1, D1 = self.net.forward1(image) 117 | refine_M = self.net.forward2(M1, K1, D1) 118 | 119 | refine_M = refine_M.reshape(M1.shape[0], M1.shape[2], M1.shape[3], -1).permute(0, 3, 1, 2) 120 | refine_M = torch.nn.functional.normalize(refine_M, 2, dim=1) 121 | 122 | descs_map = refine_M 123 | 124 | scores = torch.softmax(K1, dim=1)[:, :64] 125 | heatmap = scores.permute(0, 2, 3, 1).reshape(scores.shape[0], scores.shape[2], scores.shape[3], 8, 8) 126 | heatmap = heatmap.permute(0, 1, 3, 2, 4).reshape(scores.shape[0], 1, scores.shape[2] * 8, scores.shape[3] * 8) 127 | 128 | pos = self.detector(heatmap) 129 | kpts = pos.squeeze(0) 130 | mask_w = kpts[..., 0] < (_W1 - pad_info[-1]) 131 | kpts = kpts[mask_w] 132 | mask_h = kpts[..., 1] < (_H1 - pad_info[1]) 133 | kpts = kpts[mask_h] 134 | 135 | scores = self.sampler(heatmap, kpts.unsqueeze(0), _H1, _W1) 136 | scores = scores.squeeze(0).reshape(-1) 137 | descs = self.sampler(descs_map, kpts.unsqueeze(0), _H1, _W1) 138 | descs = torch.nn.functional.normalize(descs, p=2, dim=1) 139 | descs = descs.squeeze(0) 140 | 141 | return {"descriptors": descs, "keypoints": kpts, "scores": scores} 142 | 143 | def match_liftfeat(self, img1, img2, min_cossim=-1): 144 | # import pdb;pdb.set_trace() 145 | data1 = self.extract(img1) 146 | data2 = self.extract(img2) 147 | 148 | kpts1, feats1 = data1["keypoints"], data1["descriptors"] 149 | kpts2, feats2 = data2["keypoints"], data2["descriptors"] 150 | 151 | cossim = feats1 @ feats2.t() 152 | cossim_t = feats2 @ feats1.t() 153 | 154 | _, match12 = cossim.max(dim=1) 155 | _, match21 = cossim_t.max(dim=1) 156 | 157 | idx0 = torch.arange(len(match12), device=match12.device) 158 | mutual = match21[match12] == idx0 159 | 160 | if min_cossim > 0: 161 | cossim, _ = cossim.max(dim=1) 162 | good = cossim > min_cossim 163 | idx0 = idx0[mutual & good] 164 | idx1 = match12[mutual & good] 165 | else: 166 | idx0 = idx0[mutual] 167 | idx1 = match12[mutual] 168 | 169 | mkpts1, mkpts2 = kpts1[idx0], kpts2[idx1] 170 | mkpts1, mkpts2 = mkpts1.cpu().numpy(), mkpts2.cpu().numpy() 171 | 172 | return mkpts1, mkpts2 173 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 4 | """ 5 | 6 | import numpy as np 7 | import os 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | import tqdm 13 | import math 14 | import cv2 15 | 16 | import sys 17 | 
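# The absolute path below is specific to the original development machine; point it at your own
# LiftFeat checkout (or run from the repository root) so the utils imports resolve when this
# file is executed directly.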
sys.path.append('/home/yepeng_liu/code_python/laiwenpeng/LiftFeat') 18 | from utils.featurebooster import FeatureBooster 19 | from utils.config import featureboost_config 20 | 21 | # from models.model_dfb import LiftFeatModel 22 | # from models.interpolator import InterpolateSparse2d 23 | # from third_party.config import featureboost_config 24 | 25 | """ 26 | foundational functions 27 | """ 28 | def simple_nms(scores, radius): 29 | """Perform non maximum suppression on the heatmap using max-pooling. 30 | This method does not suppress contiguous points that have the same score. 31 | Args: 32 | scores: the score heatmap of size `(B, H, W)`. 33 | radius: an integer scalar, the radius of the NMS window. 34 | """ 35 | 36 | def max_pool(x): 37 | return torch.nn.functional.max_pool2d( 38 | x, kernel_size=radius * 2 + 1, stride=1, padding=radius 39 | ) 40 | 41 | zeros = torch.zeros_like(scores) 42 | max_mask = scores == max_pool(scores) 43 | for _ in range(2): 44 | supp_mask = max_pool(max_mask.float()) > 0 45 | supp_scores = torch.where(supp_mask, zeros, scores) 46 | new_max_mask = supp_scores == max_pool(supp_scores) 47 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 48 | return torch.where(max_mask, scores, zeros) 49 | 50 | 51 | def top_k_keypoints(keypoints, scores, k): 52 | if k >= len(keypoints): 53 | return keypoints, scores 54 | scores, indices = torch.topk(scores, k, dim=0, sorted=True) 55 | return keypoints[indices], scores 56 | 57 | 58 | def sample_k_keypoints(keypoints, scores, k): 59 | if k >= len(keypoints): 60 | return keypoints, scores 61 | indices = torch.multinomial(scores, k, replacement=False) 62 | return keypoints[indices], scores[indices] 63 | 64 | 65 | def soft_argmax_refinement(keypoints, scores, radius: int): 66 | width = 2 * radius + 1 67 | sum_ = torch.nn.functional.avg_pool2d( 68 | scores[:, None], width, 1, radius, divisor_override=1 69 | ) 70 | ar = torch.arange(-radius, radius + 1).to(scores) 71 | kernel_x = ar[None].expand(width, -1)[None, None] 72 | dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius) 73 | dy = torch.nn.functional.conv2d( 74 | scores[:, None], kernel_x.transpose(2, 3), padding=radius 75 | ) 76 | dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None] 77 | refined_keypoints = [] 78 | for i, kpts in enumerate(keypoints): 79 | delta = dydx[i][tuple(kpts.t())] 80 | refined_keypoints.append(kpts.float() + delta) 81 | return refined_keypoints 82 | 83 | 84 | # Legacy (broken) sampling of the descriptors 85 | def sample_descriptors(keypoints, descriptors, s): 86 | b, c, h, w = descriptors.shape 87 | keypoints = keypoints - s / 2 + 0.5 88 | keypoints /= torch.tensor( 89 | [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], 90 | ).to( 91 | keypoints 92 | )[None] 93 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 94 | args = {"align_corners": True} if torch.__version__ >= "1.3" else {} 95 | descriptors = torch.nn.functional.grid_sample( 96 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args 97 | ) 98 | descriptors = torch.nn.functional.normalize( 99 | descriptors.reshape(b, c, -1), p=2, dim=1 100 | ) 101 | return descriptors 102 | 103 | 104 | # The original keypoint sampling is incorrect. We patch it here but 105 | # keep the original one above for legacy. 
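# (Concretely: the legacy version above shifts keypoints by half a cell and normalizes by
#  (w*s - s/2 - 0.5), while the fixed version below normalizes by the full extent (w*s, h*s)
#  and samples with align_corners=False.)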
106 | def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8): 107 | """Interpolate descriptors at keypoint locations""" 108 | b, c, h, w = descriptors.shape 109 | keypoints = keypoints / (keypoints.new_tensor([w, h]) * s) 110 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 111 | descriptors = torch.nn.functional.grid_sample( 112 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False 113 | ) 114 | descriptors = torch.nn.functional.normalize( 115 | descriptors.reshape(b, c, -1), p=2, dim=1 116 | ) 117 | return descriptors 118 | 119 | 120 | class UpsampleLayer(nn.Module): 121 | def __init__(self, in_channels): 122 | super().__init__() 123 | # 定义特征提取层,减少通道数同时增加特征提取能力 124 | self.conv = nn.Conv2d(in_channels, in_channels//2, kernel_size=3, stride=1, padding=1) 125 | # 使用BN层 126 | self.bn = nn.BatchNorm2d(in_channels//2) 127 | # 使用LeakyReLU激活函数 128 | self.leaky_relu = nn.LeakyReLU(0.1) 129 | 130 | def forward(self, x): 131 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) 132 | x = self.leaky_relu(self.bn(self.conv(x))) 133 | 134 | return x 135 | 136 | 137 | class KeypointHead(nn.Module): 138 | def __init__(self,in_channels,out_channels): 139 | super().__init__() 140 | self.layer1=BaseLayer(in_channels,32) 141 | self.layer2=BaseLayer(32,32) 142 | self.layer3=BaseLayer(32,64) 143 | self.layer4=BaseLayer(64,64) 144 | self.layer5=BaseLayer(64,128) 145 | 146 | self.conv=nn.Conv2d(128,out_channels,kernel_size=3,stride=1,padding=1) 147 | self.bn=nn.BatchNorm2d(65) 148 | 149 | def forward(self,x): 150 | x=self.layer1(x) 151 | x=self.layer2(x) 152 | x=self.layer3(x) 153 | x=self.layer4(x) 154 | x=self.layer5(x) 155 | x=self.bn(self.conv(x)) 156 | return x 157 | 158 | 159 | class DescriptorHead(nn.Module): 160 | def __init__(self,in_channels,out_channels): 161 | super().__init__() 162 | self.layer=nn.Sequential( 163 | BaseLayer(in_channels,32), 164 | BaseLayer(32,32,activation=False), 165 | BaseLayer(32,64,activation=False), 166 | BaseLayer(64,out_channels,activation=False) 167 | ) 168 | 169 | def forward(self,x): 170 | x=self.layer(x) 171 | # x=nn.functional.softmax(x,dim=1) 172 | return x 173 | 174 | 175 | class HeatmapHead(nn.Module): 176 | def __init__(self,in_channels,mid_channels,out_channels): 177 | super().__init__() 178 | self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) 179 | self.bnHa = nn.BatchNorm2d(mid_channels) 180 | self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1) 181 | self.bnHb = nn.BatchNorm2d(out_channels) 182 | self.leaky_relu = nn.LeakyReLU(0.1) 183 | 184 | def forward(self,x): 185 | x = self.leaky_relu(self.bnHa(self.convHa(x))) 186 | x = self.leaky_relu(self.bnHb(self.convHb(x))) 187 | 188 | x = torch.sigmoid(x) 189 | return x 190 | 191 | 192 | class DepthHead(nn.Module): 193 | def __init__(self, in_channels): 194 | super().__init__() 195 | self.upsampleDa = UpsampleLayer(in_channels) 196 | self.upsampleDb = UpsampleLayer(in_channels//2) 197 | self.upsampleDc = UpsampleLayer(in_channels//4) 198 | 199 | self.convDepa = nn.Conv2d(in_channels//2+in_channels, in_channels//2, kernel_size=3, stride=1, padding=1) 200 | self.bnDepa = nn.BatchNorm2d(in_channels//2) 201 | self.convDepb = nn.Conv2d(in_channels//4+in_channels//2, in_channels//4, kernel_size=3, stride=1, padding=1) 202 | self.bnDepb = nn.BatchNorm2d(in_channels//4) 203 | self.convDepc = nn.Conv2d(in_channels//8+in_channels//4, 3, kernel_size=3, stride=1, padding=1) 204 | 
self.bnDepc = nn.BatchNorm2d(3) 205 | 206 | self.leaky_relu = nn.LeakyReLU(0.1) 207 | 208 | def forward(self, x): 209 | x0 = F.interpolate(x, scale_factor=2,mode='bilinear',align_corners=False) 210 | x1 = self.upsampleDa(x) 211 | x1 = torch.cat([x0,x1],dim=1) 212 | x1 = self.leaky_relu(self.bnDepa(self.convDepa(x1))) 213 | 214 | x1_0 = F.interpolate(x1,scale_factor=2,mode='bilinear',align_corners=False) 215 | x2 = self.upsampleDb(x1) 216 | x2 = torch.cat([x1_0,x2],dim=1) 217 | x2 = self.leaky_relu(self.bnDepb(self.convDepb(x2))) 218 | 219 | x2_0 = F.interpolate(x2,scale_factor=2,mode='bilinear',align_corners=False) 220 | x3 = self.upsampleDc(x2) 221 | x3 = torch.cat([x2_0,x3],dim=1) 222 | x = self.leaky_relu(self.bnDepc(self.convDepc(x3))) 223 | 224 | x = F.normalize(x,p=2,dim=1) 225 | return x 226 | 227 | 228 | class BaseLayer(nn.Module): 229 | def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False,activation=True): 230 | super().__init__() 231 | if activation: 232 | self.layer=nn.Sequential( 233 | nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias), 234 | nn.BatchNorm2d(out_channels,affine=False), 235 | nn.ReLU(inplace=True) 236 | ) 237 | else: 238 | self.layer=nn.Sequential( 239 | nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias), 240 | nn.BatchNorm2d(out_channels,affine=False) 241 | ) 242 | 243 | def forward(self,x): 244 | return self.layer(x) 245 | 246 | 247 | class LiftFeatSPModel(nn.Module): 248 | default_conf = { 249 | "has_detector": True, 250 | "has_descriptor": True, 251 | "descriptor_dim": 64, 252 | # Inference 253 | "sparse_outputs": True, 254 | "dense_outputs": False, 255 | "nms_radius": 4, 256 | "refinement_radius": 0, 257 | "detection_threshold": 0.005, 258 | "max_num_keypoints": -1, 259 | "max_num_keypoints_val": None, 260 | "force_num_keypoints": False, 261 | "randomize_keypoints_training": False, 262 | "remove_borders": 4, 263 | "legacy_sampling": True, # True to use the old broken sampling 264 | } 265 | 266 | def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True): 267 | super().__init__() 268 | self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') 269 | self.descriptor_dim = 64 270 | 271 | self.norm = nn.InstanceNorm2d(1) 272 | 273 | self.relu = nn.ReLU(inplace=True) 274 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 275 | c1,c2,c3,c4,c5 = 24,24,64,64,128 276 | 277 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 278 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 279 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 280 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 281 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 282 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 283 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 284 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 285 | self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 286 | self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1) 287 | 288 | self.upsample4 = UpsampleLayer(c4) 289 | self.upsample5 = UpsampleLayer(c5) 290 | self.conv_fusion45 = nn.Conv2d(c5//2+c4,c4,kernel_size=3,stride=1,padding=1) 291 | self.conv_fusion34 = nn.Conv2d(c4//2+c3,c3,kernel_size=3,stride=1,padding=1) 292 | 293 | # detector 294 | self.keypoint_head = KeypointHead(in_channels=c3,out_channels=65) 295 | # 
descriptor 296 | self.descriptor_head = DescriptorHead(in_channels=c3,out_channels=self.descriptor_dim) 297 | # # heatmap 298 | # self.heatmap_head = HeatmapHead(in_channels=c3,mid_channels=c3,out_channels=1) 299 | # depth 300 | self.depth_head = DepthHead(c3) 301 | 302 | self.fine_matcher = nn.Sequential( 303 | nn.Linear(128, 512), 304 | nn.BatchNorm1d(512, affine=False), 305 | nn.ReLU(inplace = True), 306 | nn.Linear(512, 512), 307 | nn.BatchNorm1d(512, affine=False), 308 | nn.ReLU(inplace = True), 309 | nn.Linear(512, 512), 310 | nn.BatchNorm1d(512, affine=False), 311 | nn.ReLU(inplace = True), 312 | nn.Linear(512, 512), 313 | nn.BatchNorm1d(512, affine=False), 314 | nn.ReLU(inplace = True), 315 | nn.Linear(512, 64), 316 | ) 317 | 318 | # feature_booster 319 | self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal) 320 | 321 | def feature_extract(self, x): 322 | x1 = self.relu(self.conv1a(x)) 323 | x1 = self.relu(self.conv1b(x1)) 324 | x1 = self.pool(x1) 325 | x2 = self.relu(self.conv2a(x1)) 326 | x2 = self.relu(self.conv2b(x2)) 327 | x2 = self.pool(x2) 328 | x3 = self.relu(self.conv3a(x2)) 329 | x3 = self.relu(self.conv3b(x3)) 330 | x3 = self.pool(x3) 331 | x4 = self.relu(self.conv4a(x3)) 332 | x4 = self.relu(self.conv4b(x4)) 333 | x4 = self.pool(x4) 334 | x5 = self.relu(self.conv5a(x4)) 335 | x5 = self.relu(self.conv5b(x5)) 336 | x5 = self.pool(x5) 337 | return x3,x4,x5 338 | 339 | def fuse_multi_features(self,x3,x4,x5): 340 | # upsample x5 feature 341 | x5 = self.upsample5(x5) 342 | x4 = torch.cat([x4,x5],dim=1) 343 | x4 = self.conv_fusion45(x4) 344 | 345 | # upsample x4 feature 346 | x4 = self.upsample4(x4) 347 | x3 = torch.cat([x3,x4],dim=1) 348 | x = self.conv_fusion34(x3) 349 | return x 350 | 351 | def _unfold2d(self, x, ws = 2): 352 | """ 353 | Unfolds tensor in 2D with desired ws (window size) and concat the channels 354 | """ 355 | B, C, H, W = x.shape 356 | x = x.unfold(2, ws , ws).unfold(3, ws,ws).reshape(B, C, H//ws, W//ws, ws**2) 357 | return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws) 358 | 359 | 360 | def forward1(self, x): 361 | """ 362 | input: 363 | x -> torch.Tensor(B, C, H, W) grayscale or rgb images 364 | return: 365 | feats -> torch.Tensor(B, 64, H/8, W/8) dense local features 366 | keypoints -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map 367 | heatmap -> torch.Tensor(B, 1, H/8, W/8) reliability map 368 | 369 | """ 370 | with torch.no_grad(): 371 | x = x.mean(dim=1, keepdim = True) 372 | x = self.norm(x) 373 | 374 | x3,x4,x5 = self.feature_extract(x) 375 | 376 | # features fusion 377 | x = self.fuse_multi_features(x3,x4,x5) 378 | 379 | # keypoint 380 | keypoint_map = self.keypoint_head(x) 381 | # descriptor 382 | des_map = self.descriptor_head(x) 383 | # # heatmap 384 | # heatmap = self.heatmap_head(x) 385 | 386 | # import pdb;pdb.set_trace() 387 | # depth 388 | d_feats = self.depth_head(x) 389 | 390 | return des_map, keypoint_map, d_feats 391 | # return des_map, keypoint_map, heatmap, d_feats 392 | 393 | def forward2(self, descs, kpts, normals): 394 | # import pdb;pdb.set_trace() 395 | normals_feat=self._unfold2d(normals, ws=8) 396 | normals_v=normals_feat.squeeze(0).permute(1,2,0).reshape(-1,normals_feat.shape[1]) 397 | descs_v=descs.squeeze(0).permute(1,2,0).reshape(-1,descs.shape[1]) 398 | kpts_v=kpts.squeeze(0).permute(1,2,0).reshape(-1,kpts.shape[1]) 399 | descs_refine = self.feature_boost(descs_v, kpts_v, normals_v) 400 | return descs_refine 401 | 402 | def forward(self,x): 403 | 
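        # Full pass: forward1 produces dense descriptors (M1), keypoint logits (K1) and surface
        # normals (D1); forward2 then refines the descriptors with the feature booster.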
M1,K1,D1=self.forward1(x) 404 | descs_refine=self.forward2(M1,K1,D1) 405 | return descs_refine,M1,K1,D1 406 | 407 | 408 | if __name__ == "__main__": 409 | img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg') 410 | img=cv2.imread(img_path,cv2.IMREAD_GRAYSCALE) 411 | img=cv2.resize(img,(800,608)) 412 | import pdb;pdb.set_trace() 413 | img=torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()/255.0 414 | img=img.cuda() if torch.cuda.is_available() else img 415 | liftfeat_sp=LiftFeatSPModel(featureboost_config).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 416 | des_map, keypoint_map, d_feats=liftfeat_sp.forward1(img) 417 | des_fine=liftfeat_sp.forward2(des_map,keypoint_map,d_feats) 418 | print(des_map.shape) 419 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | einops==0.8.0 4 | kornia==0.7.3 5 | timm==1.0.15 6 | albumentations==1.4.12 7 | imgaug==0.4.0 8 | opencv-python==4.10.0.84 9 | matplotlib==3.7.5 10 | numpy==1.24.4 11 | scikit-image==0.21.0 12 | scipy==1.10.1 13 | pillow==10.3.0 14 | tensorboard==2.14.0 15 | tqdm==4.66.4 16 | omegaconf==2.3.0 17 | thop==0.1.1.post2209072238 18 | poselib 19 | -------------------------------------------------------------------------------- /tools/demo_match_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import yaml 6 | import matplotlib.cm as cm 7 | import argparse 8 | 9 | import sys 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 11 | from models.liftfeat_wrapper import LiftFeat,MODEL_PATH 12 | from utils.post_process import match_features 13 | os.environ['CUDA_VISIBLE_DEVICES']='0' 14 | 15 | use_cuda = torch.cuda.is_available() 16 | device = torch.device("cuda" if use_cuda else "cpu") 17 | 18 | class VideoHandler: 19 | def __init__(self,video_path,size=[640,360]): 20 | self.video_path=video_path 21 | self.size=size 22 | self.cap=cv2.VideoCapture(video_path) 23 | 24 | def get_frame(self): 25 | ret,frame=self.cap.read() 26 | if ret==True: 27 | frame=cv2.resize(frame,(int(self.size[0]),int(self.size[1]))) 28 | return ret,frame 29 | 30 | def draw_video_match(img0,img1,kpts0,kpts1,mkpts0,mkpts1,match_scores,mask,max_match_num=512,margin=15): 31 | H0, W0, c = img0.shape 32 | H1, W1, c = img1.shape 33 | H, W = max(H0, H1), W0 + W1 + margin 34 | 35 | # 构建画布,把两个图像先拼接到一起 36 | out = 255*np.ones((H, W, 3), np.uint8) 37 | out[:H0, :W0, :] = img0 38 | out[:H1, W0+margin:, :] = img1 39 | #out = np.stack([out]*3, -1) 40 | 41 | kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) 42 | 43 | mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) 44 | mkpts0_correct,mkpts1_correct=mkpts0[mask],mkpts1[mask] 45 | mkpts0_wrong,mkpts1_wrong=mkpts0[~mask],mkpts1[~mask] 46 | match_s=match_scores[mask] 47 | 48 | print(f"correct: {mkpts0_correct.shape[0]} wrong: {mkpts0_wrong.shape[0]}") 49 | 50 | if mkpts0_correct.shape[0] > max_match_num: 51 | # perm=np.random.randint(low=0,high=mkpts0_correct.shape[0],size=max_match_num) 52 | # mkpts0_show,mkpts1_show=mkpts0_correct[perm],mkpts1_correct[perm] 53 | mkpts0_show,mkpts1_show=mkpts0_correct,mkpts1_correct 54 | else: 55 | mkpts0_show,mkpts1_show=mkpts0_correct,mkpts1_correct 56 | 57 | # 普通的点 58 | vis_normal_point = True 59 | if 
(vis_normal_point): 60 | for x, y in mkpts0_show: 61 | cv2.circle(out, (x, y), 2, (47,132,250), -1, lineType=cv2.LINE_AA) 62 | for x, y in mkpts1_show: 63 | cv2.circle(out, (x + margin + W0, y), 2, (47,132,250), -1,lineType=cv2.LINE_AA) 64 | 65 | vis_match_line = True 66 | if (vis_match_line): 67 | for pt0, pt1,score in zip(mkpts0_show, mkpts1_show,match_s): 68 | color_cm = cm.jet(1.0 - score, alpha=0) 69 | color = (int(color_cm[0] * 255), int(color_cm[1] * 255), int(color_cm[2] * 255)) 70 | cv2.line(out, pt0, (W0 + margin + pt1[0], pt1[1]), color, 1) 71 | 72 | return out 73 | 74 | def run_video_demo(std_img_path,video_path): 75 | 76 | 77 | liftfeat=LiftFeat(weight=MODEL_PATH,detect_threshold=0.15) 78 | 79 | std_img=cv2.imread(std_img_path) 80 | std_img=cv2.resize(std_img,(640,360)) 81 | 82 | handler=VideoHandler(video_path) 83 | 84 | # 定义编解码器并创建VideoWriter对象 85 | if not os.path.exists('./output'): 86 | os.makedirs('./output') 87 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 或者使用 'XVID' 88 | out = cv2.VideoWriter('./output/video_demo.mp4', fourcc, 20.0, (1300, 360)) 89 | K=[[1084.8,0,640.24],[0,1085,354.87],[0,0,1]] 90 | K=np.array(K) 91 | data_std=liftfeat.extract(std_img) 92 | 93 | while True: 94 | ret,frame=handler.get_frame() 95 | if ret==False: 96 | break 97 | 98 | if frame is not None: 99 | data=liftfeat.extract(frame) 100 | idx0, idx1, match_scores=match_features(data_std["descriptors"],data["descriptors"],-1) 101 | mkpts0=data_std["keypoints"][idx0] 102 | mkpts1=data["keypoints"][idx1] 103 | mkpts0_np=mkpts0.cpu().numpy() 104 | mkpts1_np=mkpts1.cpu().numpy() 105 | match_scores_np=match_scores.detach().cpu().numpy() 106 | kpts0 = (mkpts0_np - K[[0, 1], [2, 2]][None]) / K[[0, 1], [0, 1]][None] 107 | kpts1 = (mkpts1_np - K[[0, 1], [2, 2]][None]) / K[[0, 1], [0, 1]][None] 108 | 109 | # normalize ransac threshold 110 | ransac_thr = 0.5 / np.mean([K[0, 0], K[1, 1], K[0, 0], K[1, 1]]) 111 | 112 | if mkpts0_np.shape[0] < 5: 113 | print(f"mkpts size less then 5") 114 | else: 115 | # compute pose with cv2 116 | 117 | E, mask = cv2.findEssentialMat(kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=0.999, method=cv2.RANSAC) 118 | if E is None: 119 | print("\nE is None while trying to recover pose.\n") 120 | continue 121 | match_mask=mask.squeeze(axis=1)>0 122 | show_kpts0,show_kpts1=mkpts0_np[match_mask],mkpts1_np[match_mask] 123 | show_match_scores=match_scores_np[match_mask] 124 | show_mask=np.ones(show_kpts0.shape[0])>0 125 | match_img=draw_video_match(std_img,frame,show_kpts0,show_kpts1,show_kpts0,show_kpts1,show_match_scores,show_mask,margin=20) 126 | kpts0_num,kpts1_num=data_std["keypoints"].shape[0],data["keypoints"].shape[0] 127 | cv2.putText(match_img,f"LiftFeat",(10,20),cv2.FONT_HERSHEY_TRIPLEX,0.5,(0,0,241)) 128 | cv2.putText(match_img,f"Keypoints: {kpts0_num}:{kpts1_num}",(10,40),cv2.FONT_HERSHEY_TRIPLEX,0.5,(0,0,255)) 129 | cv2.putText(match_img,f"Matches: {show_kpts0.shape[0]}",(10,60),cv2.FONT_HERSHEY_TRIPLEX,0.5,(0,0,255)) 130 | out.write(match_img) 131 | 132 | 133 | out.release() 134 | 135 | 136 | 137 | 138 | if __name__=="__main__": 139 | parser = argparse.ArgumentParser(description="Run LiftFeat video matching demo.") 140 | parser.add_argument('--img', type=str, required=True, help='Path to the template image') 141 | parser.add_argument('--video', type=str, required=True, help='Path to the input video') 142 | 143 | args = parser.parse_args() 144 | 145 | run_video_demo(args.img, args.video) 146 | 
-------------------------------------------------------------------------------- /tools/demo_vo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import argparse 4 | import yaml 5 | import logging 6 | import os 7 | import sys 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 9 | from utils.VisualOdometry import VisualOdometry, AbosluteScaleComputer, create_dataloader, \ 10 | plot_keypoints, create_detector, create_matcher 11 | from models.liftfeat_wrapper import LiftFeat,MODEL_PATH 12 | 13 | 14 | vo_config = { 15 | 'dataset': { 16 | 'name': 'KITTILoader', 17 | 'root_path': '/home/yepeng_liu/code_python/dataset/visual_odometry/kitty/gray', 18 | 'sequence': '10', 19 | 'start': 0 20 | }, 21 | 'detector': { 22 | 'name': 'LiftFeatDetector', 23 | 'descriptor_dim': 64, 24 | 'nms_radius': 5, 25 | 'keypoint_threshold': 0.005, 26 | 'max_keypoints': 4096, 27 | 'remove_borders': 4, 28 | 'cuda': 1 29 | }, 30 | 'matcher': { 31 | 'name': 'FrameByFrameMatcher', 32 | 'type': 'FLANN', 33 | 'FLANN': { 34 | 'kdTrees': 5, 35 | 'searchChecks': 50 36 | }, 37 | 'distance_ratio': 0.75 38 | } 39 | } 40 | 41 | # 可视化当前frame的关键点 42 | def keypoints_plot(img, vo, img_id, path2): 43 | img_ = cv2.imread(path2+str(img_id-1).zfill(6)+".png") 44 | 45 | if not vo.match_kps: 46 | img_ = plot_keypoints(img_, vo.kptdescs["cur"]["keypoints"]) 47 | else: 48 | for index in range(vo.match_kps["ref"].shape[0]): 49 | ref_point = tuple(map(int, vo.match_kps['ref'][index,:])) # 将关键点转换为整数元组 50 | cur_point = tuple(map(int, vo.match_kps['cur'][index,:])) 51 | cv2.line(img_, ref_point, cur_point, (0, 255, 0), 2) # Draw green line 52 | cv2.circle(img_, cur_point, 3, (0, 0, 255), -1) # Draw red circle at current keypoint 53 | 54 | return img_ 55 | 56 | # 负责绘制相机的轨迹并计算估计轨迹与真实轨迹的误差。 57 | class TrajPlotter(object): 58 | def __init__(self): 59 | self.errors = [] 60 | self.traj = np.zeros((800, 800, 3), dtype=np.uint8) 61 | pass 62 | 63 | def update(self, est_xyz, gt_xyz): 64 | x, z = est_xyz[0], est_xyz[2] 65 | gt_x, gt_z = gt_xyz[0], gt_xyz[2] 66 | est = np.array([x, z]).reshape(2) 67 | gt = np.array([gt_x, gt_z]).reshape(2) 68 | error = np.linalg.norm(est - gt) 69 | self.errors.append(error) 70 | avg_error = np.mean(np.array(self.errors)) 71 | # === drawer ================================== 72 | # each point 73 | draw_x, draw_y = int(x) + 80, int(z) + 230 74 | true_x, true_y = int(gt_x) + 80, int(gt_z) + 230 75 | 76 | # draw trajectory 77 | cv2.circle(self.traj, (draw_x, draw_y), 1, (0, 0, 255), 1) 78 | cv2.circle(self.traj, (true_x, true_y), 1, (0, 255, 0), 2) 79 | cv2.rectangle(self.traj, (10, 5), (450, 120), (0, 0, 0), -1) 80 | 81 | # draw text 82 | text = "[AvgError] %2.4fm" % (avg_error) 83 | print(text) 84 | cv2.putText(self.traj, text, (20, 40), 85 | cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA) 86 | note = "Green: GT, Red: Predict" 87 | cv2.putText(self.traj, note, (20, 80), 88 | cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA) 89 | 90 | return self.traj 91 | 92 | def run_video(args): 93 | # create dataloader 94 | vo_config["dataset"]['root_path'] = args.path1 95 | vo_config["dataset"]['sequence'] = args.id 96 | loader = create_dataloader(vo_config["dataset"]) 97 | # create detector 98 | liftfeat=LiftFeat(weight=MODEL_PATH, detect_threshold=0.25) 99 | # create matcher 100 | matcher = create_matcher(vo_config["matcher"]) 101 | 102 | absscale = AbosluteScaleComputer() 103 | traj_plotter = TrajPlotter() 
104 | 105 | 106 | if not os.path.exists('./output'): 107 | os.makedirs('./output') 108 | fname = "kitti_liftfeat_flannmatch" 109 | log_fopen = open("output/" + fname + ".txt", mode='a') 110 | 111 | vo = VisualOdometry(liftfeat, matcher, loader.cam) 112 | 113 | # Initialize video writer for keypoints and trajectory videos 114 | keypoints_video_path = "output/" + fname + "_keypoints_liftfeat.avi" 115 | trajectory_video_path = "output/" + fname + "_trajectory_liftfeat.avi" 116 | 117 | # Set up video writer: choose codec and set FPS and frame size 118 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 119 | fps = 10 # Adjust the FPS according to your input data 120 | frame_size = (1200, 400) # Get frame size from first image 121 | 122 | # Video writers for keypoints and trajectory 123 | keypoints_writer = cv2.VideoWriter(keypoints_video_path, fourcc, fps, frame_size) 124 | trajectory_writer = cv2.VideoWriter(trajectory_video_path, fourcc, fps, (800, 800)) 125 | 126 | for i, img in enumerate(loader): 127 | img_id = loader.img_id 128 | gt_pose = loader.get_cur_pose() 129 | 130 | R, t = vo.update(img, absscale.update(gt_pose)) 131 | 132 | # === log writer ============================== 133 | print(i, t[0, 0], t[1, 0], t[2, 0], gt_pose[0, 3], gt_pose[1, 3], gt_pose[2, 3], file=log_fopen) 134 | 135 | # === drawer ================================== 136 | img1 = keypoints_plot(img, vo, img_id, args.path2) 137 | img1 = cv2.resize(img1, (1200, 400)) 138 | img2 = traj_plotter.update(t, gt_pose[:, 3]) 139 | 140 | # Write frames to videos 141 | keypoints_writer.write(img1) 142 | trajectory_writer.write(img2) 143 | 144 | # Release the video writers 145 | keypoints_writer.release() 146 | trajectory_writer.release() 147 | print(f"Videos saved as {keypoints_video_path} and {trajectory_video_path}") 148 | 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser(description='python_vo') 153 | parser.add_argument('--path1', type=str, default='/home/yepeng_liu/code_python/dataset/visual_odometry/kitty/gray', 154 | help='config file') 155 | parser.add_argument('--path2', type=str, default="/home/yepeng_liu/code_python/dataset/visual_odometry/kitty/color/sequences/03/image_2/", 156 | help='config file') 157 | parser.add_argument('--id', type=str, default="03", 158 | help='config file') 159 | 160 | 161 | args = parser.parse_args() 162 | 163 | run_video(args) 164 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | training script 4 | """ 5 | 6 | import argparse 7 | import os 8 | import time 9 | import sys 10 | sys.path.append(os.path.dirname(__file__)) 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description="LiftFeat training script.") 14 | parser.add_argument('--name',type=str,default='LiftFeat',help='set process name') 15 | 16 | # MegaDepth dataset setting 17 | parser.add_argument('--use_megadepth',action='store_true') 18 | parser.add_argument('--megadepth_root_path', type=str, 19 | default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548', 20 | help='Path to the MegaDepth dataset root directory.') 21 | parser.add_argument('--megadepth_batch_size', type=int, default=6) 22 | 23 | # COCO20k dataset setting 24 | parser.add_argument('--use_coco',action='store_true') 25 | parser.add_argument('--coco_root_path', type=str, 
default='/home/yepeng_liu/code_python/dataset/coco_20k', 26 | help='Path to the COCO20k dataset root directory.') 27 | parser.add_argument('--coco_batch_size',type=int,default=4) 28 | 29 | parser.add_argument('--ckpt_save_path', type=str, default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test', 30 | help='Path to save the checkpoints.') 31 | parser.add_argument('--n_steps', type=int, default=160_000, 32 | help='Number of training steps. Default is 160000.') 33 | parser.add_argument('--lr', type=float, default=3e-4, 34 | help='Learning rate. Default is 0.0003.') 35 | parser.add_argument('--gamma_steplr', type=float, default=0.5, 36 | help='Gamma value for StepLR scheduler. Default is 0.5.') 37 | parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))), 38 | default=(800, 608), help='Training resolution as width,height. Default is (800, 608).') 39 | parser.add_argument('--device_num', type=str, default='0', 40 | help='Device number to use for training. Default is "0".') 41 | parser.add_argument('--dry_run', action='store_true', 42 | help='If set, perform a dry run training with a mini-batch for sanity check.') 43 | parser.add_argument('--save_ckpt_every', type=int, default=500, 44 | help='Save checkpoints every N steps. Default is 500.') 45 | parser.add_argument('--use_coord_loss',action='store_true',help='Enable coordinate loss') 46 | 47 | args = parser.parse_args() 48 | 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num 50 | 51 | return args 52 | 53 | args = parse_arguments() 54 | 55 | import torch 56 | from torch import nn 57 | from torch import optim 58 | import torch.nn.functional as F 59 | from torch.utils.tensorboard import SummaryWriter 60 | from torch.utils.data import Dataset, DataLoader 61 | 62 | import numpy as np 63 | import tqdm 64 | import glob 65 | 66 | from models.model import LiftFeatSPModel 67 | from loss.loss import LiftFeatLoss 68 | from utils.config import featureboost_config 69 | from models.interpolator import InterpolateSparse2d 70 | from utils.depth_anything_wrapper import DepthAnythingExtractor 71 | from utils.alike_wrapper import ALikeExtractor 72 | 73 | from dataset import megadepth_wrapper 74 | from dataset import coco_wrapper 75 | from dataset.megadepth import MegaDepthDataset 76 | from dataset.coco_augmentor import COCOAugmentor 77 | 78 | import setproctitle 79 | 80 | 81 | class Trainer(): 82 | def __init__(self, megadepth_root_path,use_megadepth,megadepth_batch_size, 83 | coco_root_path,use_coco,coco_batch_size, 84 | ckpt_save_path, 85 | model_name = 'LiftFeat', 86 | n_steps = 160_000, lr= 3e-4, gamma_steplr=0.5, 87 | training_res = (800, 608), device_num="0", dry_run = False, 88 | save_ckpt_every = 500, use_coord_loss = False): 89 | print(f'MegeDepth: {use_megadepth}-{megadepth_batch_size}') 90 | print(f'COCO20k: {use_coco}-{coco_batch_size}') 91 | print(f'Coordinate loss: {use_coord_loss}') 92 | self.dev = torch.device ('cuda' if torch.cuda.is_available() else 'cpu') 93 | 94 | # training model 95 | self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True).to(self.dev) 96 | self.loss_fn=LiftFeatLoss(self.dev,lam_descs=1,lam_kpts=2,lam_heatmap=1) 97 | 98 | # depth-anything model 99 | self.depth_net=DepthAnythingExtractor('vits',self.dev,256) 100 | 101 | # alike model 102 | self.alike_net=ALikeExtractor('alike-t',self.dev) 103 | 104 | #Setup optimizer 105 | self.steps = n_steps 106 | self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()) , lr = lr) 107 
| self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr) 108 | 109 | ##################### COCO INIT ########################## 110 | self.use_coco=use_coco 111 | self.coco_batch_size=coco_batch_size 112 | if self.use_coco: 113 | self.augmentor=COCOAugmentor( 114 | img_dir=coco_root_path, 115 | device=self.dev,load_dataset=True, 116 | batch_size=self.coco_batch_size, 117 | out_resolution=training_res, 118 | warp_resolution=training_res, 119 | sides_crop=0.1, 120 | max_num_imgs=3000, 121 | num_test_imgs=5, 122 | photometric=True, 123 | geometric=True, 124 | reload_step=4000 125 | ) 126 | ##################### COCO END ####################### 127 | 128 | 129 | ##################### MEGADEPTH INIT ########################## 130 | self.use_megadepth=use_megadepth 131 | self.megadepth_batch_size=megadepth_batch_size 132 | if self.use_megadepth: 133 | TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices" 134 | TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1" 135 | 136 | TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" 137 | 138 | npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:] 139 | megadepth_dataset = torch.utils.data.ConcatDataset( [MegaDepthDataset(root_dir = TRAINVAL_DATA_SOURCE, 140 | npz_path = path) for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")] ) 141 | 142 | self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size, shuffle=True) 143 | self.megadepth_data_iter = iter(self.megadepth_dataloader) 144 | ##################### MEGADEPTH INIT END ####################### 145 | 146 | os.makedirs(ckpt_save_path, exist_ok=True) 147 | os.makedirs(ckpt_save_path + '/logdir', exist_ok=True) 148 | 149 | self.dry_run = dry_run 150 | self.save_ckpt_every = save_ckpt_every 151 | self.ckpt_save_path = ckpt_save_path 152 | self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S")) 153 | self.model_name = model_name 154 | self.use_coord_loss = use_coord_loss 155 | 156 | 157 | def generate_train_data(self): 158 | imgs1_t,imgs2_t=[],[] 159 | imgs1_np,imgs2_np=[],[] 160 | # norms0,norms1=[],[] 161 | positives_coarse=[] 162 | 163 | if self.use_coco: 164 | coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1) 165 | h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8 166 | _ , positives_coco_coarse = coco_wrapper.get_corresponding_pts(coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse) 167 | coco_imgs1=coco_imgs1.mean(1,keepdim=True);coco_imgs2=coco_imgs2.mean(1,keepdim=True) 168 | imgs1_t.append(coco_imgs1);imgs2_t.append(coco_imgs2) 169 | positives_coarse += positives_coco_coarse 170 | 171 | if self.use_megadepth: 172 | try: 173 | megadepth_data=next(self.megadepth_data_iter) 174 | except StopIteration: 175 | print('End of MD DATASET') 176 | self.megadepth_data_iter=iter(self.megadepth_dataloader) 177 | megadepth_data=next(self.megadepth_data_iter) 178 | if megadepth_data is not None: 179 | for k in megadepth_data.keys(): 180 | if isinstance(megadepth_data[k],torch.Tensor): 181 | megadepth_data[k]=megadepth_data[k].to(self.dev) 182 | megadepth_imgs1_t,megadepth_imgs2_t=megadepth_data['image0'],megadepth_data['image1'] 183 | megadepth_imgs1_t=megadepth_imgs1_t.mean(1,keepdim=True);megadepth_imgs2_t=megadepth_imgs2_t.mean(1,keepdim=True) 184 | imgs1_t.append(megadepth_imgs1_t);imgs2_t.append(megadepth_imgs2_t) 185 | 
megadepth_imgs1_np,megadepth_imgs2_np=megadepth_data['image0_np'],megadepth_data['image1_np'] 186 | for np_idx in range(megadepth_imgs1_np.shape[0]): 187 | img1_np,img2_np=megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy(),megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy() 188 | imgs1_np.append(img1_np);imgs2_np.append(img2_np) 189 | positives_megadepth_coarse=megadepth_wrapper.spvs_coarse(megadepth_data,8) 190 | positives_coarse += positives_megadepth_coarse 191 | 192 | with torch.no_grad(): 193 | imgs1_t=torch.cat(imgs1_t,dim=0) 194 | imgs2_t=torch.cat(imgs2_t,dim=0) 195 | 196 | return imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse 197 | 198 | 199 | def train(self): 200 | self.net.train() 201 | 202 | with tqdm.tqdm(total=self.steps) as pbar: 203 | for i in range(self.steps): 204 | # import pdb;pdb.set_trace() 205 | imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse=self.generate_train_data() 206 | 207 | #Check if batch is corrupted with too few correspondences 208 | is_corrupted = False 209 | for p in positives_coarse: 210 | if len(p) < 30: 211 | is_corrupted = True 212 | 213 | if is_corrupted: 214 | continue 215 | 216 | # import pdb;pdb.set_trace() 217 | #Forward pass 218 | # start=time.perf_counter() 219 | feats1,kpts1,normals1 = self.net.forward1(imgs1_t) 220 | feats2,kpts2,normals2 = self.net.forward1(imgs2_t) 221 | 222 | coordinates,fb_coordinates=[],[] 223 | alike_kpts1,alike_kpts2=[],[] 224 | DA_normals1,DA_normals2=[],[] 225 | 226 | # import pdb;pdb.set_trace() 227 | 228 | fb_feats1,fb_feats2=[],[] 229 | for b in range(feats1.shape[0]): 230 | feat1=feats1[b].permute(1,2,0).reshape(-1,feats1.shape[1]) 231 | feat2=feats2[b].permute(1,2,0).reshape(-1,feats2.shape[1]) 232 | 233 | coordinate=self.net.fine_matcher(torch.cat([feat1,feat2],dim=-1)) 234 | coordinates.append(coordinate) 235 | 236 | fb_feat1=self.net.forward2(feats1[b].unsqueeze(0),kpts1[b].unsqueeze(0),normals1[b].unsqueeze(0)) 237 | fb_feat2=self.net.forward2(feats2[b].unsqueeze(0),kpts2[b].unsqueeze(0),normals2[b].unsqueeze(0)) 238 | 239 | fb_coordinate=self.net.fine_matcher(torch.cat([fb_feat1,fb_feat2],dim=-1)) 240 | fb_coordinates.append(fb_coordinate) 241 | 242 | fb_feats1.append(fb_feat1.unsqueeze(0));fb_feats2.append(fb_feat2.unsqueeze(0)) 243 | 244 | img1,img2=imgs1_t[b],imgs2_t[b] 245 | img1=img1.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255 246 | img2=img2.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255 247 | alike_kpt1=torch.tensor(self.alike_net.extract_alike_kpts(img1),device=self.dev) 248 | alike_kpt2=torch.tensor(self.alike_net.extract_alike_kpts(img2),device=self.dev) 249 | alike_kpts1.append(alike_kpt1);alike_kpts2.append(alike_kpt2) 250 | 251 | # import pdb;pdb.set_trace() 252 | for b in range(len(imgs1_np)): 253 | megadepth_depth1,megadepth_norm1=self.depth_net.extract(imgs1_np[b]) 254 | megadepth_depth2,megadepth_norm2=self.depth_net.extract(imgs2_np[b]) 255 | DA_normals1.append(megadepth_norm1);DA_normals2.append(megadepth_norm2) 256 | 257 | # import pdb;pdb.set_trace() 258 | fb_feats1=torch.cat(fb_feats1,dim=0) 259 | fb_feats2=torch.cat(fb_feats2,dim=0) 260 | fb_feats1=fb_feats1.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2) 261 | fb_feats2=fb_feats2.reshape(feats2.shape[0],feats2.shape[2],feats2.shape[3],-1).permute(0,3,1,2) 262 | 263 | coordinates=torch.cat(coordinates,dim=0) 264 | coordinates=coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2) 265 | 266 | fb_coordinates=torch.cat(fb_coordinates,dim=0) 267 | 
fb_coordinates=fb_coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2) 268 | 269 | # end=time.perf_counter() 270 | # print(f"forward1 cost {end-start} seconds") 271 | 272 | loss_items = [] 273 | 274 | # import pdb;pdb.set_trace() 275 | loss_info=self.loss_fn( 276 | feats1,fb_feats1,kpts1,normals1, 277 | feats2,fb_feats2,kpts2,normals2, 278 | positives_coarse, 279 | coordinates,fb_coordinates, 280 | alike_kpts1,alike_kpts2, 281 | DA_normals1,DA_normals2, 282 | self.megadepth_batch_size,self.coco_batch_size) 283 | 284 | loss_descs,acc_coarse=loss_info['loss_descs'],loss_info['acc_coarse'] 285 | loss_coordinates,acc_coordinates=loss_info['loss_coordinates'],loss_info['acc_coordinates'] 286 | loss_fb_descs,acc_fb_coarse=loss_info['loss_fb_descs'],loss_info['acc_fb_coarse'] 287 | loss_fb_coordinates,acc_fb_coordinates=loss_info['loss_fb_coordinates'],loss_info['acc_fb_coordinates'] 288 | loss_kpts,acc_kpt=loss_info['loss_kpts'],loss_info['acc_kpt'] 289 | loss_normals=loss_info['loss_normals'] 290 | 291 | loss_items.append(loss_fb_descs.unsqueeze(0)) 292 | loss_items.append(loss_kpts.unsqueeze(0)) 293 | loss_items.append(loss_normals.unsqueeze(0)) 294 | 295 | if self.use_coord_loss: 296 | loss_items.append(loss_fb_coordinates.unsqueeze(0)) 297 | 298 | # nb_coarse = len(m1) 299 | # nb_coarse = len(fb_m1) 300 | loss = torch.cat(loss_items, -1).mean() 301 | 302 | # Compute Backward Pass 303 | loss.backward() 304 | torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.) 305 | self.opt.step() 306 | self.opt.zero_grad() 307 | self.scheduler.step() 308 | 309 | # import pdb;pdb.set_trace() 310 | if (i+1) % self.save_ckpt_every == 0: 311 | print('saving iter ', i+1) 312 | torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth') 313 | 314 | pbar.set_description( 315 | 'Loss: {:.4f} \ 316 | loss_descs: {:.3f} acc_coarse: {:.3f} \ 317 | loss_coordinates: {:.3f} acc_coordinates: {:.3f} \ 318 | loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} \ 319 | loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} \ 320 | loss_kpts: {:.3f} acc_kpts: {:.3f} \ 321 | loss_normals: {:.3f}'.format( \ 322 | loss.item(), \ 323 | loss_descs.item(), acc_coarse, \ 324 | loss_coordinates.item(), acc_coordinates, \ 325 | loss_fb_descs.item(), acc_fb_coarse, \ 326 | loss_fb_coordinates.item(), acc_fb_coordinates, \ 327 | loss_kpts.item(), acc_kpt, \ 328 | loss_normals.item()) ) 329 | pbar.update(1) 330 | 331 | # Log metrics 332 | self.writer.add_scalar('Loss/total', loss.item(), i) 333 | self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i) 334 | self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i) 335 | self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i) 336 | self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i) 337 | self.writer.add_scalar('Loss/descs', loss_descs.item(), i) 338 | self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i) 339 | self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i) 340 | self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i) 341 | self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i) 342 | self.writer.add_scalar('Loss/normals', loss_normals.item(), i) 343 | 344 | 345 | 346 | if __name__ == '__main__': 347 | 348 | setproctitle.setproctitle(args.name) 349 | 350 | trainer = Trainer( 351 | megadepth_root_path=args.megadepth_root_path, 352 | use_megadepth=args.use_megadepth, 353 | megadepth_batch_size=args.megadepth_batch_size, 
354 | coco_root_path=args.coco_root_path, 355 | use_coco=args.use_coco, 356 | coco_batch_size=args.coco_batch_size, 357 | ckpt_save_path=args.ckpt_save_path, 358 | n_steps=args.n_steps, 359 | lr=args.lr, 360 | gamma_steplr=args.gamma_steplr, 361 | training_res=args.training_res, 362 | device_num=args.device_num, 363 | dry_run=args.dry_run, 364 | save_ckpt_every=args.save_ckpt_every, 365 | use_coord_loss=args.use_coord_loss 366 | ) 367 | 368 | # The most fun part 369 | trainer.train() 370 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # default training 2 | nohup python /home/yepeng_liu/code_python/LiftFeat/train.py \ 3 | --name LiftFeat_test \ 4 | --ckpt_save_path /home/yepeng_liu/code_python/LiftFeat/trained_weights/test \ 5 | --device_num 1 \ 6 | --use_megadepth \ 7 | --megadepth_batch_size 8 \ 8 | --use_coco \ 9 | --coco_batch_size 4 \ 10 | --save_ckpt_every 1000 \ 11 | > /home/yepeng_liu/code_python/LiftFeat/trained_weights/test/training.log 2>&1 & -------------------------------------------------------------------------------- /utils/VisualOdometry.py: -------------------------------------------------------------------------------- 1 | # based on: https://github.com/uoip/monoVO-python 2 | 3 | import numpy as np 4 | import cv2 5 | import logging 6 | import glob 7 | 8 | def create_dataloader(conf): 9 | try: 10 | code_line = f"{conf['name']}(conf)" 11 | loader = eval(code_line) 12 | except NameError: 13 | raise NotImplementedError(f"{conf['name']} is not implemented yet.") 14 | 15 | return loader 16 | 17 | """ 18 | Pinhole camera model: defines the camera intrinsics. 19 | fx, fy: focal lengths 20 | cx, cy: principal point 21 | k1, k2, p1, p2, k3: distortion coefficients 22 | """ 23 | class PinholeCamera(object): 24 | def __init__(self, width, height, fx, fy, cx, cy, 25 | k1=0.0, k2=0.0, p1=0.0, p2=0.0, k3=0.0): 26 | self.width = width 27 | self.height = height 28 | self.fx = fx 29 | self.fy = fy 30 | self.cx = cx 31 | self.cy = cy 32 | self.distortion = (abs(k1) > 0.0000001) 33 | self.d = [k1, k2, p1, p2, k3] 34 | 35 | class KITTILoader(object): 36 | default_config = { 37 | "root_path": "../test_imgs", 38 | "sequence": "00", 39 | "start": 0 40 | } 41 | 42 | def __init__(self, config={}): 43 | self.config = self.default_config 44 | self.config = {**self.config, **config} 45 | logging.info("KITTI Dataset config: ") 46 | logging.info(self.config) 47 | 48 | if self.config["sequence"] in ["00", "01", "02"]: 49 | self.cam = PinholeCamera(1241.0, 376.0, 718.8560, 718.8560, 607.1928, 185.2157) 50 | elif self.config["sequence"] in ["03"]: 51 | self.cam = PinholeCamera(1242.0, 375.0, 721.5377, 721.5377, 609.5593, 172.854) 52 | elif self.config["sequence"] in ["04", "05", "06", "07", "08", "09", "10"]: 53 | self.cam = PinholeCamera(1226.0, 370.0, 707.0912, 707.0912, 601.8873, 183.1104) 54 | else: 55 | raise ValueError(f"Unknown sequence number: {self.config['sequence']}") 56 | 57 | # read ground truth pose 58 | self.pose_path = self.config["root_path"] + "/poses/" + self.config["sequence"] + ".txt" 59 | self.gt_poses = [] 60 | with open(self.pose_path) as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | ss = line.strip().split() 64 | pose = np.zeros((1, len(ss))) 65 | for i in range(len(ss)): 66 | pose[0, i] = float(ss[i]) 67 | 68 | pose.resize([3, 4]) 69 | self.gt_poses.append(pose) 70 | 71 | # image id 72 | self.img_id = self.config["start"] 73 | self.img_N = len(glob.glob(pathname=self.config["root_path"] + "/sequences/" \ 74 | +
self.config["sequence"] + "/image_0/*.png")) 75 | 76 | def get_cur_pose(self): 77 | return self.gt_poses[self.img_id - 1] 78 | 79 | def __getitem__(self, item): 80 | file_name = self.config["root_path"] + "/sequences/" + self.config["sequence"] \ 81 | + "/image_0/" + str(item).zfill(6) + ".png" 82 | img = cv2.imread(file_name) 83 | return img 84 | 85 | def __iter__(self): 86 | return self 87 | 88 | def __next__(self): 89 | if self.img_id < self.img_N: 90 | file_name = self.config["root_path"] + "/sequences/" + self.config["sequence"] \ 91 | + "/image_0/" + str(self.img_id).zfill(6) + ".png" 92 | img = cv2.imread(file_name) 93 | 94 | self.img_id += 1 95 | 96 | return img 97 | raise StopIteration() 98 | 99 | def __len__(self): 100 | return self.img_N - self.config["start"] 101 | 102 | 103 | def create_detector(conf): 104 | try: 105 | code_line = f"{conf['name']}(conf)" 106 | detector = eval(code_line) 107 | except NameError: 108 | raise NotImplementedError(f"{conf['name']} is not implemented yet.") 109 | 110 | return detector 111 | 112 | 113 | def create_matcher(conf): 114 | try: 115 | code_line = f"{conf['name']}(conf)" 116 | matcher = eval(code_line) 117 | except NameError: 118 | raise NotImplementedError(f"{conf['name']} is not implemented yet.") 119 | 120 | return matcher 121 | 122 | class FrameByFrameMatcher(object): 123 | default_config = { 124 | "type": "FLANN", 125 | "KNN": { 126 | "HAMMING": True, # ORB binary descriptors can only be matched with the Hamming norm 127 | "first_N": 300, # for Hamming matching, keep the first N closest matches 128 | }, 129 | "FLANN": { 130 | "kdTrees": 5, 131 | "searchChecks": 50 132 | }, 133 | "distance_ratio": 0.75 134 | } 135 | 136 | def __init__(self, config={}): 137 | self.config = self.default_config 138 | self.config = {**self.config, **config} 139 | logging.info("Frame by frame matcher config: ") 140 | logging.info(self.config) 141 | 142 | if self.config["type"] == "KNN": 143 | logging.info("creating brute force matcher...") 144 | if self.config["KNN"]["HAMMING"]: 145 | logging.info("brute force with Hamming norm.") 146 | self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) 147 | else: 148 | self.matcher = cv2.BFMatcher() 149 | elif self.config["type"] == "FLANN": 150 | logging.info("creating FLANN matcher...") 151 | # FLANN parameters 152 | FLANN_INDEX_KDTREE = 1 153 | index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=self.config["FLANN"]["kdTrees"]) 154 | search_params = dict(checks=self.config["FLANN"]["searchChecks"]) # or pass empty dictionary 155 | self.matcher = cv2.FlannBasedMatcher(index_params, search_params) 156 | else: 157 | raise ValueError(f"Unknown matcher type: {self.config['type']}") 158 | 159 | def match(self, kptdescs): 160 | self.good = [] 161 | # get shape of the descriptor 162 | self.descriptor_shape = kptdescs["ref"]["descriptors"].shape[1] 163 | 164 | if self.config["type"] == "KNN" and self.config["KNN"]["HAMMING"]: 165 | logging.debug("KNN keypoints matching...") 166 | matches = self.matcher.match(kptdescs["ref"]["descriptors"], kptdescs["cur"]["descriptors"]) 167 | # Sort them in the order of their distance.
168 | matches = sorted(matches, key=lambda x: x.distance) 169 | # self.good = matches[:self.config["KNN"]["first_N"]] 170 | for i in range(self.config["KNN"]["first_N"]): 171 | self.good.append([matches[i]]) 172 | else: 173 | logging.debug("FLANN keypoints matching...") 174 | matches = self.matcher.knnMatch(kptdescs["ref"]["descriptors"], kptdescs["cur"]["descriptors"], k=2) 175 | # Apply ratio test 176 | for m, n in matches: 177 | if m.distance < self.config["distance_ratio"] * n.distance: 178 | self.good.append([m]) 179 | # Sort them in the order of their distance. 180 | self.good = sorted(self.good, key=lambda x: x[0].distance) 181 | return self.good 182 | 183 | def get_good_keypoints(self, kptdescs): 184 | logging.debug("getting matched keypoints...") 185 | kp_ref = np.zeros([len(self.good), 2]) 186 | kp_cur = np.zeros([len(self.good), 2]) 187 | match_dist = np.zeros([len(self.good)]) 188 | for i, m in enumerate(self.good): 189 | kp_ref[i, :] = kptdescs["ref"]["keypoints"][m[0].queryIdx] 190 | kp_cur[i, :] = kptdescs["cur"]["keypoints"][m[0].trainIdx] 191 | match_dist[i] = m[0].distance 192 | 193 | ret_dict = { 194 | "ref_keypoints": kp_ref, 195 | "cur_keypoints": kp_cur, 196 | "match_score": self.normalised_matching_scores(match_dist) 197 | } 198 | return ret_dict 199 | 200 | def __call__(self, kptdescs): 201 | self.match(kptdescs) 202 | return self.get_good_keypoints(kptdescs) 203 | 204 | def normalised_matching_scores(self, match_dist): 205 | 206 | if self.config["type"] == "KNN" and self.config["KNN"]["HAMMING"]: 207 | # ORB Hamming distance 208 | best, worst = 0, self.descriptor_shape * 8 # min and max hamming distance 209 | worst = worst / 4 # scale 210 | else: 211 | # for non-normalized descriptor 212 | if match_dist.max() > 1: 213 | best, worst = 0, self.descriptor_shape * 2 # estimated range 214 | else: 215 | best, worst = 0, 1 216 | 217 | # normalise the score! 
218 | match_scores = match_dist / worst 219 | # range constraint 220 | match_scores[match_scores > 1] = 1 221 | match_scores[match_scores < 0] = 0 222 | # 1: for best match, 0: for worst match 223 | match_scores = 1 - match_scores 224 | 225 | return match_scores 226 | 227 | def draw_matched(self, img0, img1): 228 | pass 229 | 230 | # --- VISUALIZATION --- 231 | # based on: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/models/utils.py 232 | def plot_keypoints(image, kpts): 233 | kpts = np.round(kpts).astype(int) 234 | for x, y in kpts: 235 | cv2.drawMarker(image, (x, y), (0, 255, 0), cv2.MARKER_CROSS, 6) 236 | 237 | return image 238 | 239 | class VisualOdometry(object): 240 | """ 241 | A simple frame by frame visual odometry 242 | """ 243 | 244 | def __init__(self, detector, matcher, cam): 245 | """ 246 | :param detector: a feature detector that detects keypoints and computes their descriptors 247 | :param matcher: a matcher that matches keypoints between two frames 248 | :param cam: camera parameters 249 | """ 250 | # feature detector and keypoints matcher 251 | self.detector = detector 252 | self.matcher = matcher 253 | 254 | # camera parameters 255 | self.focal = cam.fx 256 | self.pp = (cam.cx, cam.cy) 257 | 258 | # frame index counter 259 | self.index = 0 260 | 261 | # keypoints and descriptors 262 | self.kptdescs = {} 263 | 264 | # match points 265 | self.match_kps = {} 266 | 267 | # pose of current frame 268 | self.cur_R = None 269 | self.cur_t = None 270 | 271 | def update(self, image, absolute_scale=1): 272 | """ 273 | feed a new image to the visual odometry pipeline and compute its pose 274 | :param image: input image 275 | :param absolute_scale: the absolute scale between current frame and last frame 276 | :return: R and t of current frame 277 | """ 278 | predict_data = self.detector.extract(image) 279 | kptdesc = { 280 | "keypoints": predict_data["keypoints"].cpu().detach().numpy(), 281 | "descriptors": predict_data["descriptors"].cpu().detach().numpy() 282 | } 283 | 284 | # first frame 285 | if self.index == 0: 286 | # save keypoints and descriptors 287 | self.kptdescs["cur"] = kptdesc 288 | 289 | # start point 290 | self.cur_R = np.identity(3) 291 | self.cur_t = np.zeros((3, 1)) 292 | else: 293 | # update keypoints and descriptors 294 | self.kptdescs["cur"] = kptdesc 295 | 296 | # match keypoints 297 | matches = self.matcher(self.kptdescs) 298 | self.match_kps = {"cur":matches['cur_keypoints'], "ref":matches['ref_keypoints']} 299 | 300 | # compute relative R,t between ref and cur frame 301 | E, mask = cv2.findEssentialMat(matches['cur_keypoints'], matches['ref_keypoints'], 302 | focal=self.focal, pp=self.pp, 303 | method=cv2.RANSAC, prob=0.999, threshold=1.0) 304 | _, R, t, mask = cv2.recoverPose(E, matches['cur_keypoints'], matches['ref_keypoints'], 305 | focal=self.focal, pp=self.pp) 306 | 307 | # get absolute pose based on absolute_scale 308 | if (absolute_scale > 0.1): 309 | self.cur_t = self.cur_t + absolute_scale * self.cur_R.dot(t) 310 | self.cur_R = R.dot(self.cur_R) 311 | 312 | self.kptdescs["ref"] = self.kptdescs["cur"] 313 | 314 | self.index += 1 315 | return self.cur_R, self.cur_t 316 | 317 | # Computes the absolute displacement between the current and previous frames, used to scale the camera translation 318 | class AbosluteScaleComputer(object): 319 | def __init__(self): 320 | self.prev_pose = None 321 | self.cur_pose = None 322 | self.count = 0 323 | 324 | def update(self, pose): 325 | self.cur_pose = pose 326 | 327 | scale = 1.0 328 | if self.count != 0: 329 | scale = np.sqrt( 330 | (self.cur_pose[0, 3] - self.prev_pose[0, 3]) *
(self.cur_pose[0, 3] - self.prev_pose[0, 3]) 331 | + (self.cur_pose[1, 3] - self.prev_pose[1, 3]) * (self.cur_pose[1, 3] - self.prev_pose[1, 3]) 332 | + (self.cur_pose[2, 3] - self.prev_pose[2, 3]) * (self.cur_pose[2, 3] - self.prev_pose[2, 3])) 333 | 334 | self.count += 1 335 | self.prev_pose = self.cur_pose 336 | return scale 337 | 338 | 339 | 340 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/VisualOdometry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/VisualOdometry.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/alike_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/alike_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/depth_anything_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/depth_anything_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/featurebooster.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/featurebooster.cpython-310.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/featurebooster.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/featurebooster.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/post_process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/utils/__pycache__/post_process.cpython-38.pyc -------------------------------------------------------------------------------- /utils/alike_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | "LiftFeat: 3D Geometry-Aware Local Feature Matching" 3 | """ 4 | 5 | 6 | import sys 7 | import os 8 | 9 | ALIKE_PATH = '/home/yepeng_liu/code_python/multimodal_remote/ALIKE' 10 | sys.path.append(ALIKE_PATH) 11 | 12 | import torch 13 | import torch.nn as nn 14 | from alike import ALike 15 | import cv2 16 | import numpy as np 17 | 18 | import pdb 19 | 20 | dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | configs = { 23 | 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2, 24 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-t.pth')}, 25 | 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2, 26 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-s.pth')}, 27 | 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2, 28 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-n.pth')}, 29 | 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2, 30 | 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-l.pth')}, 31 | } 32 | 33 | 34 | class ALikeExtractor(nn.Module): 35 | def __init__(self,model_type,device) -> None: 36 | super().__init__() 37 | self.net=ALike(**configs[model_type],device=device,top_k=4096,scores_th=0.1,n_limit=8000) 38 | 39 | 40 | @torch.inference_mode() 41 | def extract_alike_kpts(self,img): 42 | pred0=self.net(img,sub_pixel=True) 43 | return pred0['keypoints'] 44 | 45 | 46 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | featureboost_config = { 6 | "keypoint_dim": 65, 7 | "keypoint_encoder": [128, 64, 64], 8 | "normal_dim": 192, 9 | "normal_encoder": [128, 64, 64], 10 | "descriptor_encoder": [64, 64], 11 | "descriptor_dim": 64, 12 | "Attentional_layers": 3, 13 | "last_activation": None, 14 | "l2_normalization": None, 15 | "output_dim": 64, 16 | } -------------------------------------------------------------------------------- /utils/depth_anything_wrapper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import matplotlib 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Compose 11 | import sys 12 | 13 | 
sys.path.append("/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2") 14 | from depth_anything_v2.dpt_opt import DepthAnythingV2 15 | from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet 16 | 17 | import time 18 | 19 | VITS_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vits.pth" 20 | VITB_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitb.pth" 21 | VITL_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth" 22 | 23 | model_configs = { 24 | "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]}, 25 | "vitb": { 26 | "encoder": "vitb", 27 | "features": 128, 28 | "out_channels": [96, 192, 384, 768], 29 | }, 30 | "vitl": { 31 | "encoder": "vitl", 32 | "features": 256, 33 | "out_channels": [256, 512, 1024, 1024], 34 | }, 35 | "vitg": { 36 | "encoder": "vitg", 37 | "features": 384, 38 | "out_channels": [1536, 1536, 1536, 1536], 39 | }, 40 | } 41 | 42 | class DepthAnythingExtractor(nn.Module): 43 | def __init__(self, encoder_type, device, input_size, process_size=(608,800)): 44 | super().__init__() 45 | self.net = DepthAnythingV2(**model_configs[encoder_type]) 46 | self.device = device 47 | if encoder_type == "vits": 48 | print(f"loading {VITS_MODEL_PATH}") 49 | self.net.load_state_dict(torch.load(VITS_MODEL_PATH, map_location="cpu")) 50 | elif encoder_type == "vitb": 51 | print(f"loading {VITB_MODEL_PATH}") 52 | self.net.load_state_dict(torch.load(VITB_MODEL_PATH, map_location="cpu")) 53 | elif encoder_type == "vitl": 54 | print(f"loading {VITL_MODEL_PATH}") 55 | self.net.load_state_dict(torch.load(VITL_MODEL_PATH, map_location="cpu")) 56 | else: 57 | raise RuntimeError("unsupported encoder type") 58 | self.net.to(self.device).eval() 59 | self.transform = Compose([ 60 | Resize( 61 | width=input_size, 62 | height=input_size, 63 | resize_target=False, 64 | keep_aspect_ratio=True, 65 | ensure_multiple_of=14, 66 | resize_method='lower_bound', 67 | image_interpolation_method=cv2.INTER_CUBIC, 68 | ), 69 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 70 | PrepareForNet(), 71 | ]) 72 | self.process_size=process_size 73 | self.input_size=input_size 74 | 75 | @torch.inference_mode() 76 | def infer_image(self,img): 77 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 78 | 79 | img = self.transform({'image': img})['image'] 80 | 81 | img = torch.from_numpy(img).unsqueeze(0) 82 | 83 | img = img.to(self.device) 84 | 85 | with torch.no_grad(): 86 | depth = self.net.forward(img) 87 | 88 | depth = F.interpolate(depth[:, None], self.process_size, mode="bilinear", align_corners=True)[0, 0] 89 | 90 | return depth.cpu().numpy() 91 | 92 | @torch.inference_mode() 93 | def compute_normal_map_torch(self, depth_map, scale=1.0): 94 | """ 95 | Compute a surface normal map from a depth map (PyTorch implementation). 96 | 97 | Args: 98 | depth_map (torch.Tensor): depth map of shape (H, W) 99 | scale (float): scale factor applied to the depth gradients 100 | 101 | Returns: 102 | torch.Tensor: normal map of shape (H, W, 3) 103 | """ 104 | if depth_map.ndim != 2: 105 | raise ValueError("depth_map must be a 2-D tensor.") 106 | 107 | # compute depth gradients 108 | dzdx = torch.diff(depth_map, dim=1, append=depth_map[:, -1:]) * scale 109 | dzdy = torch.diff(depth_map, dim=0, append=depth_map[-1:, :]) * scale 110 | 111 | # initialize the normal map 112 | H, W = depth_map.shape 113 | normal_map = torch.zeros((H, W, 3), dtype=depth_map.dtype, device=depth_map.device) 114 | normal_map[:, :, 0] = -dzdx # x component 115 |
normal_map[:, :, 1] = -dzdy # y component 116 | normal_map[:, :, 2] = 1.0 # z component 117 | 118 | # normalize the normal vectors 119 | norm = torch.linalg.norm(normal_map, dim=2, keepdim=True) 120 | norm = torch.where(norm == 0, torch.tensor(1.0, device=depth_map.device), norm) # avoid division by zero 121 | normal_map /= norm 122 | 123 | return normal_map 124 | 125 | @torch.inference_mode() 126 | def extract(self, img): 127 | depth = self.infer_image(img) 128 | depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 129 | depth_t=torch.from_numpy(depth).float().to(self.device) 130 | normal_map = self.compute_normal_map_torch(depth_t,1.0) 131 | return depth_t,normal_map 132 | 133 | 134 | if __name__=="__main__": 135 | img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg') 136 | img=cv2.imread(img_path) 137 | img=cv2.resize(img,(800,608)) 138 | import pdb;pdb.set_trace() 139 | DAExtractor=DepthAnythingExtractor('vitb',torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),256) 140 | depth_t,norm=DAExtractor.extract(img) 141 | norm=norm.cpu().numpy() 142 | norm=(norm+1)/2*255 143 | norm=norm.astype(np.uint8) 144 | cv2.imwrite(os.path.join(os.path.dirname(__file__),"norm.png"),norm) 145 | start=time.perf_counter() 146 | for i in range(20): 147 | depth_t,norm=DAExtractor.extract(img) 148 | end=time.perf_counter() 149 | print(f"cost {end-start} seconds") 150 | -------------------------------------------------------------------------------- /utils/featurebooster.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def MLP(channels: List[int], do_bn: bool = False) -> nn.Module: 9 | """ Multi-layer perceptron """ 10 | n = len(channels) 11 | layers = [] 12 | for i in range(1, n): 13 | layers.append(nn.Linear(channels[i - 1], channels[i])) 14 | if i < (n-1): 15 | if do_bn: 16 | layers.append(nn.BatchNorm1d(channels[i])) 17 | layers.append(nn.ReLU()) 18 | return nn.Sequential(*layers) 19 | 20 | def MLP_no_ReLU(channels: List[int], do_bn: bool = False) -> nn.Module: 21 | """ Multi-layer perceptron without ReLU activations """ 22 | n = len(channels) 23 | layers = [] 24 | for i in range(1, n): 25 | layers.append(nn.Linear(channels[i - 1], channels[i])) 26 | if i < (n-1): 27 | if do_bn: 28 | layers.append(nn.BatchNorm1d(channels[i])) 29 | return nn.Sequential(*layers) 30 | 31 | 32 | class KeypointEncoder(nn.Module): 33 | """ Encoding of geometric properties using MLP """ 34 | def __init__(self, keypoint_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None: 35 | super().__init__() 36 | self.encoder = MLP([keypoint_dim] + layers + [feature_dim]) 37 | self.use_dropout = dropout 38 | self.dropout = nn.Dropout(p=p) 39 | 40 | def forward(self, kpts): 41 | if self.use_dropout: 42 | return self.dropout(self.encoder(kpts)) 43 | return self.encoder(kpts) 44 | 45 | class NormalEncoder(nn.Module): 46 | """ Encoding of surface-normal features using MLP """ 47 | def __init__(self, normal_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None: 48 | super().__init__() 49 | self.encoder = MLP_no_ReLU([normal_dim] + layers + [feature_dim]) 50 | self.use_dropout = dropout 51 | self.dropout = nn.Dropout(p=p) 52 | 53 | def forward(self, kpts): 54 | if self.use_dropout: 55 | return self.dropout(self.encoder(kpts)) 56 | return self.encoder(kpts) 57 | 58 | 59 | class DescriptorEncoder(nn.Module): 60 | """ Encoding of visual
descriptor using MLP """ 61 | def __init__(self, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None: 62 | super().__init__() 63 | self.encoder = MLP([feature_dim] + layers + [feature_dim]) 64 | self.use_dropout = dropout 65 | self.dropout = nn.Dropout(p=p) 66 | 67 | def forward(self, descs): 68 | residual = descs 69 | if self.use_dropout: 70 | return residual + self.dropout(self.encoder(descs)) 71 | return residual + self.encoder(descs) 72 | 73 | 74 | class AFTAttention(nn.Module): 75 | """ Attention-free attention """ 76 | def __init__(self, d_model: int, dropout: bool = False, p: float = 0.1) -> None: 77 | super().__init__() 78 | self.dim = d_model 79 | self.query = nn.Linear(d_model, d_model) 80 | self.key = nn.Linear(d_model, d_model) 81 | self.value = nn.Linear(d_model, d_model) 82 | self.proj = nn.Linear(d_model, d_model) 83 | # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 84 | self.use_dropout = dropout 85 | self.dropout = nn.Dropout(p=p) 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | residual = x 89 | q = self.query(x) 90 | k = self.key(x) 91 | v = self.value(x) 92 | # q = torch.sigmoid(q) 93 | k = k.T 94 | k = torch.softmax(k, dim=-1) 95 | k = k.T 96 | kv = (k * v).sum(dim=-2, keepdim=True) 97 | x = q * kv 98 | x = self.proj(x) 99 | if self.use_dropout: 100 | x = self.dropout(x) 101 | x += residual 102 | # x = self.layer_norm(x) 103 | return x 104 | 105 | 106 | class PositionwiseFeedForward(nn.Module): 107 | def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1) -> None: 108 | super().__init__() 109 | self.mlp = MLP([feature_dim, feature_dim*2, feature_dim]) 110 | # self.layer_norm = nn.LayerNorm(feature_dim, eps=1e-6) 111 | self.use_dropout = dropout 112 | self.dropout = nn.Dropout(p=p) 113 | 114 | def forward(self, x: torch.Tensor) -> torch.Tensor: 115 | residual = x 116 | x = self.mlp(x) 117 | if self.use_dropout: 118 | x = self.dropout(x) 119 | x += residual 120 | # x = self.layer_norm(x) 121 | return x 122 | 123 | 124 | class AttentionalLayer(nn.Module): 125 | def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1): 126 | super().__init__() 127 | self.attn = AFTAttention(feature_dim, dropout=dropout, p=p) 128 | self.ffn = PositionwiseFeedForward(feature_dim, dropout=dropout, p=p) 129 | 130 | def forward(self, x: torch.Tensor) -> torch.Tensor: 131 | # import pdb;pdb.set_trace() 132 | x = self.attn(x) 133 | x = self.ffn(x) 134 | return x 135 | 136 | 137 | class AttentionalNN(nn.Module): 138 | def __init__(self, feature_dim: int, layer_num: int, dropout: bool = False, p: float = 0.1) -> None: 139 | super().__init__() 140 | self.layers = nn.ModuleList([ 141 | AttentionalLayer(feature_dim, dropout=dropout, p=p) 142 | for _ in range(layer_num)]) 143 | 144 | def forward(self, desc: torch.Tensor) -> torch.Tensor: 145 | for layer in self.layers: 146 | desc = layer(desc) 147 | return desc 148 | 149 | 150 | class FeatureBooster(nn.Module): 151 | default_config = { 152 | 'descriptor_dim': 128, 153 | 'keypoint_encoder': [32, 64, 128], 154 | 'Attentional_layers': 3, 155 | 'last_activation': 'relu', 156 | 'l2_normalization': True, 157 | 'output_dim': 128 158 | } 159 | 160 | def __init__(self, config, dropout=False, p=0.1, use_kenc=True, use_normal=True, use_cross=True): 161 | super().__init__() 162 | self.config = {**self.default_config, **config} 163 | self.use_kenc = use_kenc 164 | self.use_cross = use_cross 165 | self.use_normal = use_normal 166 | 167 | if use_kenc: 168 | self.kenc = 
KeypointEncoder(self.config['keypoint_dim'], self.config['descriptor_dim'], self.config['keypoint_encoder'], dropout=dropout) 169 | 170 | if use_normal: 171 | self.nenc = NormalEncoder(self.config['normal_dim'], self.config['descriptor_dim'], self.config['normal_encoder'], dropout=dropout) 172 | 173 | if self.config.get('descriptor_encoder', False): 174 | self.denc = DescriptorEncoder(self.config['descriptor_dim'], self.config['descriptor_encoder'], dropout=dropout) 175 | else: 176 | self.denc = None 177 | 178 | if self.use_cross: 179 | self.attn_proj = AttentionalNN(feature_dim=self.config['descriptor_dim'], layer_num=self.config['Attentional_layers'], dropout=dropout) 180 | 181 | # self.final_proj = nn.Linear(self.config['descriptor_dim'], self.config['output_dim']) 182 | 183 | self.use_dropout = dropout 184 | self.dropout = nn.Dropout(p=p) 185 | 186 | # self.layer_norm = nn.LayerNorm(self.config['descriptor_dim'], eps=1e-6) 187 | 188 | if self.config.get('last_activation', False): 189 | if self.config['last_activation'].lower() == 'relu': 190 | self.last_activation = nn.ReLU() 191 | elif self.config['last_activation'].lower() == 'sigmoid': 192 | self.last_activation = nn.Sigmoid() 193 | elif self.config['last_activation'].lower() == 'tanh': 194 | self.last_activation = nn.Tanh() 195 | else: 196 | raise Exception('Not supported activation "%s".' % self.config['last_activation']) 197 | else: 198 | self.last_activation = None 199 | 200 | def forward(self, desc, kpts, normals): 201 | # import pdb;pdb.set_trace() 202 | ## Self boosting 203 | # Descriptor MLP encoder 204 | if self.denc is not None: 205 | desc = self.denc(desc) 206 | # Geometric MLP encoder 207 | if self.use_kenc: 208 | desc = desc + self.kenc(kpts) 209 | if self.use_dropout: 210 | desc = self.dropout(desc) 211 | 212 | # Surface-normal feature encoder 213 | if self.use_normal: 214 | desc = desc + self.nenc(normals) 215 | if self.use_dropout: 216 | desc = self.dropout(desc) 217 | 218 | ## Cross boosting 219 | # Multi-layer Transformer network.
220 | if self.use_cross: 221 | # desc = self.attn_proj(self.layer_norm(desc)) 222 | desc = self.attn_proj(desc) 223 | 224 | ## Post processing 225 | # Final MLP projection 226 | # desc = self.final_proj(desc) 227 | if self.last_activation is not None: 228 | desc = self.last_activation(desc) 229 | # L2 normalization 230 | if self.config['l2_normalization']: 231 | desc = F.normalize(desc, dim=-1) 232 | 233 | return desc 234 | 235 | if __name__ == "__main__": 236 | from config import t1_featureboost_config 237 | fb_net = FeatureBooster(t1_featureboost_config) 238 | 239 | descs=torch.randn([1900,64]) 240 | kpts=torch.randn([1900,65]) 241 | normals=torch.randn([1900,3]) 242 | 243 | import pdb;pdb.set_trace() 244 | 245 | descs_refine=fb_net(descs,kpts,normals) 246 | 247 | print(descs_refine.shape) 248 | -------------------------------------------------------------------------------- /utils/post_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def match_features(feats1, feats2, min_cossim=0.82): 4 | cossim = feats1 @ feats2.t() 5 | cossim_t = feats2 @ feats1.t() 6 | _, match12 = cossim.max(dim=1) 7 | _, match21 = cossim_t.max(dim=1) 8 | idx0 = torch.arange(len(match12), device=match12.device) 9 | mutual = match21[match12] == idx0 10 | # import pdb; pdb.set_trace() 11 | if min_cossim > 0: 12 | best_sim, _ = cossim.max(dim=1) 13 | good = best_sim > min_cossim 14 | idx0 = idx0[mutual & good] 15 | idx1 = match12[mutual & good] 16 | else: 17 | idx0 = idx0[mutual] 18 | idx1 = match12[mutual] 19 | 20 | match_scores = cossim[idx0, idx1] 21 | return idx0, idx1, match_scores -------------------------------------------------------------------------------- /weights/LiftFeat.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyp-deeplearning/LiftFeat/ffe1f46576ad54ca178b724b00421da0eddc383a/weights/LiftFeat.pth --------------------------------------------------------------------------------
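
For reference, the sketch below shows one way the pieces above can be combined outside the provided demos: extracting LiftFeat keypoints and descriptors from an image pair and matching them with the mutual-nearest-neighbour matcher in `utils/post_process.py`. This is a minimal, illustrative sketch, not part of the repository: it assumes `LiftFeat.extract(img)` returns a dict with torch tensors `"keypoints"` and `"descriptors"` (the interface consumed by `utils/VisualOdometry.py`), and it uses the sample images from the `assert/` folder only as placeholders.

```python
import cv2

from models.liftfeat_wrapper import LiftFeat, MODEL_PATH
from utils.post_process import match_features

# Load the pretrained model (same call as in tools/demo_vo.py).
liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.25)

# Placeholder images from the repository's assert/ folder.
img1 = cv2.imread("assert/ref.jpg")
img2 = cv2.imread("assert/query.jpg")

# Assumed interface: dicts with "keypoints" (N, 2) and "descriptors" (N, 64) tensors,
# as consumed by VisualOdometry.update().
out1 = liftfeat.extract(img1)
out2 = liftfeat.extract(img2)

# Mutual nearest-neighbour matching with a cosine-similarity threshold.
idx0, idx1, scores = match_features(out1["descriptors"], out2["descriptors"], min_cossim=0.82)

mkpts1 = out1["keypoints"][idx0].cpu().numpy()
mkpts2 = out2["keypoints"][idx1].cpu().numpy()
print(f"{len(mkpts1)} mutual matches, mean cosine similarity {scores.mean().item():.3f}")
```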