├── .gitattributes ├── .gitignore ├── README.md ├── Results └── ALLSS │ ├── superglue_descriptor_128 │ ├── checkpoints │ │ └── SuperGlue_epoch_80.pth │ └── logdir │ │ └── events.out.tfevents.1617122770.DESKTOP-MEF2DI0 │ ├── superglue_descriptor_64 │ └── checkpoints │ │ └── SuperGlue_epoch_200.pth │ ├── superpoint_descriptor_128 │ ├── checkpoints │ │ └── superPointNet_62000_checkpoint.pth.tar │ └── config.yml │ └── superpoint_descriptor_64 │ ├── checkpoints │ └── superPointNet_100000_checkpoint.pth.tar │ └── config.yml ├── Traditional └── registration.py ├── datasets ├── ALLSS.py ├── GlueSparse.py ├── SSHIDataset.py ├── __init__.py └── data_tools.py ├── requirements.txt ├── superglue └── models │ ├── __init__.py │ ├── matching.py │ ├── matching_test.py │ ├── superglue_test.py │ ├── superglue_train.py │ ├── superpoint.py │ ├── utils.py │ └── weights │ ├── SuperGlue_allss_descriptor_128.pth │ ├── SuperGlue_allss_descriptor_64.pth │ ├── superglue_indoor.pth │ ├── superglue_outdoor.pth │ └── superpoint_v1.pth ├── superpoint ├── Train_model_frontend.py ├── Train_model_heatmap.py ├── configs │ ├── magicpoint_allss_export.yaml │ └── superpoint_allss_train_heatmap.yaml ├── correspondence_tools │ ├── __init__.py │ ├── correspondence_augmentation.py │ ├── correspondence_finder.py │ └── correspondence_plotter.py ├── loss_functions │ ├── __init__.py │ ├── loss_composer.py │ ├── pixelwise_contrastive_loss.py │ └── sparse_loss.py └── models │ ├── __init__.py │ ├── model_utils.py │ ├── model_wrap.py │ ├── superpoint_test.py │ ├── superpoint_train.py │ ├── unet_parts.py │ └── weights │ ├── magicpoint │ └── superPointNet_100000_checkpoint.pth.tar │ ├── superPointNet_allss_descriptor_128.pth.tar │ ├── superPointNet_allss_descriptor_64.pth.tar │ └── superPointNet_coco_descriptor_256.pth.tar ├── superpoint_export_pseudo.py ├── superpoint_flann_test.py ├── superpoint_glue_official_test.py ├── superpoint_glue_test.py ├── superpoint_glue_train.py ├── superpoint_train_descriptor.py ├── traditional.py └── utils ├── __init__.py ├── correspondence_tools ├── __init__.py ├── correspondence_augmentation.py ├── correspondence_finder.py └── correspondence_plotter.py ├── cp_labels.py ├── d2s.py ├── draw.py ├── homographies.py ├── loader.py ├── logging.py ├── loss_functions ├── __init__.py ├── loss_composer.py ├── pixelwise_contrastive_loss.py └── sparse_loss.py ├── losses.py ├── photometric.py ├── photometric_augmentation.py ├── print_tool.py ├── tools.py ├── utils.py └── var_dim.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.vscode 3 | *.bmp 4 | *.pyc 5 | *.json 6 | *.jpg 7 | *.ini 8 | *.npz 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Registration 2 | 3 | ## Introduction 4 | 5 | This is Python code for the image registration task. It contains OpenCV implementations of the traditional registration methods SIFT and ORB, and PyTorch implementations of the deep learning methods SuperPoint and SuperGlue. 6 | 7 | SuperPoint and SuperGlue are CVPR 2018 and CVPR 2020 research projects from Magic Leap, respectively. SuperPoint is a CNN framework used for feature extraction and feature description.
SuperGlue uses a deep graph matching method to replace traditional local feature matching; it uses an attention mechanism to aggregate context information. For more details, please see their papers and GitHub repos: 8 | 9 | - SuperPoint Paper: [SuperPoint: Self-Supervised Interest Point Detection and Description](https://arxiv.org/abs/1712.07629). 10 | - SuperPoint GitHub repo: https://github.com/magicleap/SuperPointPretrainedNetwork. 11 | - SuperGlue Paper: [SuperGlue: Learning Feature Matching with Graph Neural Networks](https://arxiv.org/abs/1911.11763). 12 | - SuperGlue GitHub repo: https://github.com/magicleap/SuperGluePretrainedNetwork. 13 | 14 | Because the authors only released the test code and pretrained networks, you need to implement the training code yourself. Some unofficial projects exist for reference. In this project, the SuperPoint training code is based on the PyTorch implementation https://github.com/eric-yyjau/pytorch-superpoint, and the SuperGlue training code is based on the PyTorch implementation https://github.com/HeatherJiaZG/SuperGlue-pytorch. 15 | 16 | ## Requirements 17 | 18 | ### Dependencies 19 | 20 | This repo depends on a few standard Python modules, plus OpenCV and PyTorch. 21 | 22 | - Python3==3.7 23 | - Pytorch>=1.2 24 | - opencv-python==4.5.1.48 25 | - opencv-contrib-python==4.5.1.48 26 | - CUDA (tested with CUDA 10.1) 27 | 28 | ``` 29 | conda create --name matching python=3.7 30 | conda activate matching 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ### Datasets 35 | 36 | You need to prepare the datasets yourself. The folder structure contains two parts: the datasets for training and for evaluation. 37 | 38 | The training dataset should look like this: 39 | 40 | ``` 41 | |-- ALLSS (your dataset) 42 | | |-- train 43 | | | |-- file.jpg 44 | | | `-- ... 45 | | `-- val 46 | | |-- file.jpg 47 | | | `-- ... 48 | ``` 49 | 50 | Our background is defect inspection in industrial settings. We have many pictures of the same product, but there are shifts, rotations, and scale changes between them. Our task is to align all the images to the same coordinate system, which benefits downstream tasks. Therefore, our evaluation dataset uses one image as the template image and the others as source images. It should look like this: 51 | 52 | ``` 53 | datasets/ 54 | |-- Amazon 55 | | `-- template 56 | | |-- template.jpg 57 | | |-- source 58 | | | |-- source1.jpg 59 | | | |-- source2.jpg 60 | | | `-- ... 61 | ``` 62 | 63 | ## Run the Code 64 | 65 | ### Traditional method 66 | 67 | The main top-level script for the traditional method is: 68 | 69 | - `traditional.py`: runs the traditional method, using SIFT or ORB for feature extraction and description, KNN (FLANN) for feature matching, and RANSAC for outlier rejection (see the sketch below).
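This is the pipeline implemented in `Traditional/registration.py`. A minimal, simplified sketch of it is shown here; the helper name `sift_register` and the hard-coded thresholds are illustrative, not part of the repo:

```
import cv2
import numpy as np

def sift_register(src_img, temp_img, ratio=0.7):
    # 1. Feature extraction and description with SIFT (needs opencv-contrib-python)
    gray1 = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(temp_img, cv2.COLOR_BGR2GRAY)
    sift = cv2.xfeatures2d.SIFT_create()
    kp1, des1 = sift.detectAndCompute(gray1, None)
    kp2, des2 = sift.detectAndCompute(gray2, None)

    # 2. KNN matching with FLANN, filtered by Lowe's ratio test
    FLANN_INDEX_KDTREE = 0
    flann = cv2.FlannBasedMatcher(dict(algorithm=FLANN_INDEX_KDTREE, trees=5),
                                  dict(checks=50))
    good = [m for m, n in flann.knnMatch(des1, des2, k=2)
            if m.distance < ratio * n.distance]
    if len(good) < 10:
        raise RuntimeError('Not enough good matches: %d' % len(good))

    # 3. RANSAC rejects outliers while estimating a partial affine transform
    src_pts = np.float32([kp1[m.queryIdx].pt for m in good])
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good])
    M, inliers = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC,
                                             ransacReprojThreshold=7)

    # 4. Warp the source image into the template's coordinate system
    h, w = temp_img.shape[:2]
    return cv2.warpAffine(src_img, M, (w, h)), M, inliers
```

ORB registration works the same way, except that a Hamming-distance brute-force matcher replaces FLANN (see `ORB_REGIS` in `Traditional/registration.py`).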
70 | 71 | Run the SIFT method for testing: 72 | 73 | ``` 74 | python traditional.py --Method SIFT --img_dir datasets/Amazon/ --Result_dir Results/Amazon/ --resize_scale 0.5 --match_viz True 75 | ``` 76 | 77 | Run the ORB method for testing: 78 | 79 | ``` 80 | python traditional.py --Method ORB --img_dir datasets/Amazon/ --Result_dir Results/Amazon/ --resize_scale 0.5 --match_viz True 81 | ``` 82 | 83 | 84 | ### SuperPoint method 85 | 86 | #### Training 87 | 88 | SuperPoint uses a self-supervised training method, which consists of three steps: 89 | 90 | - **Step 1:** train an initial interest point detector (MagicPoint) on synthetic data. We do not implement this step for our task and simply use the pretrained model weights provided by the implementation: https://github.com/eric-yyjau/pytorch-superpoint 91 | 92 | - **Step 2:** apply a novel Homographic Adaptation procedure to automatically label images from a target, unlabeled domain 93 | 94 | Export the pseudo labels on the training dataset: 95 | 96 | ``` 97 | python superpoint_export_pseudo.py --config superpoint/configs/magicpoint_allss_export.yaml --export_task train --outputImg 98 | ``` 99 | 100 | Export the pseudo labels on the validation dataset: 101 | 102 | ``` 103 | python superpoint_export_pseudo.py --config superpoint/configs/magicpoint_allss_export.yaml --export_task val --outputImg 104 | ``` 105 | 106 | - **Step 3:** train a fully-convolutional network that jointly extracts interest points and descriptors from an image 107 | 108 | ``` 109 | python superpoint_train_descriptor.py --config superpoint/configs/superpoint_allss_train_heatmap.yaml 110 | ``` 111 | 112 | #### Testing 113 | 114 | After training on your own dataset, test on the evaluation datasets. We apply SuperPoint for feature extraction and description, KNN (FLANN) for feature matching, and RANSAC for outlier rejection. 115 | 116 | Run the SuperPoint + FLANN method for testing: 117 | 118 | ``` 119 | python superpoint_flann_test.py --img_dir datasets/Amazon/ --descriptor_dim 128 --weights_path superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar 120 | ``` 121 | 122 | **Note:** the **descriptor_dim** and the **weights** should be paired; if descriptor_dim=64, the weights_path should be superpoint/models/weights/superPointNet_allss_descriptor_64.pth.tar 123 | 124 | ### SuperGlue method 125 | 126 | #### Training 127 | 128 | SuperGlue is a GNN-based method for feature matching and outlier rejection. The input of SuperGlue is the output of SuperPoint, and you can use other deep learning or traditional (SIFT or ORB) feature extraction and description methods in place of SuperPoint, as sketched below.
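To make this interface concrete, the following is a rough sketch of matching a template/source pair with the `Matching` wrapper included in this repo (`superglue/models/matching_test.py`). The SuperGlue config keys follow `superglue/models/superglue_test.py`; the SuperPoint config keys, the image paths, and the use of `frame2tensor` are illustrative assumptions rather than the exact contents of the test scripts:

```
import cv2
import torch
from superglue.models.matching_test import Matching
from utils.utils import frame2tensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Illustrative config; descriptor_dim, keypoint_encoder and weights must be paired.
config = {
    'superpoint': {'descriptor_dim': 128,
                   'weights': 'superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar'},
    'superglue': {'descriptor_dim': 128,
                  'keypoint_encoder': [32, 64, 128],
                  'sinkhorn_iterations': 30,
                  'weights': 'superglue/models/weights/SuperGlue_allss_descriptor_128.pth'},
}
matching = Matching(config).eval().to(device)

# Template and source images as 1x1xHxW float tensors in [0, 1]
image0 = frame2tensor(cv2.imread('datasets/Amazon/template/template.jpg', 0), device)
image1 = frame2tensor(cv2.imread('datasets/Amazon/source/source1.jpg', 0), device)

with torch.no_grad():
    pred = matching({'image0': image0, 'image1': image1})

# 'matches0' holds, for every keypoint in image0, the index of its match in
# image1 (-1 means unmatched); the matched coordinates can then be passed to
# cv2.estimateAffinePartial2D for RANSAC-based alignment, as in the other tests.
kpts0 = pred['keypoints0'][0].cpu().numpy()
kpts1 = pred['keypoints1'][0].cpu().numpy()
matches = pred['matches0'][0].cpu().numpy()
mkpts0 = kpts0[matches > -1]
mkpts1 = kpts1[matches[matches > -1]]
```

During training, `datasets/GlueSparse.py` builds the same kind of dictionary, adding the ground-truth `all_matches` derived from a random homography.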
Run the SuperGlue training code with the SuperPoint pretrained model: 129 | 130 | ``` 131 | python superpoint_glue_train.py --descriptor_dim 128 --keypoint_encoder [32,64,128] --superpoint_weights superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar --sinkhorn_iterations 30 132 | ``` 133 | 134 | **Note:** the **descriptor_dim**, **weights** and **keypoint_encoder** should be paired, like this: 135 | 136 | superPointNet_allss_descriptor_64.pth.tar --descriptor_dim 64 --keypoint_encoder [32,64] 137 | 138 | superPointNet_allss_descriptor_128.pth.tar --descriptor_dim 128 --keypoint_encoder [32,64,128] 139 | 140 | superPointNet_allss_descriptor_256.pth.tar --descriptor_dim 256 --keypoint_encoder [32,64,128,256] 141 | 142 | #### Testing 143 | 144 | Run the SuperPoint+SuperGlue code with the models we trained ourselves for evaluation: 145 | 146 | ``` 147 | python superpoint_glue_test.py --descriptor_dim 128 --keypoint_encoder [32, 64, 128] --sinkhorn_iterations 30 --superpoint_weights superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar --superglue_weights superglue/models/weights/SuperGlue_allss_descriptor_128.pth 148 | ``` 149 | 150 | **Note:** We did not get a good result from the SuperPoint+SuperGlue method trained on our data; the reason is that our training dataset is too small. You can train on a large dataset (like COCO) and then fine-tune on your own dataset. You can also test the official models pretrained by Magic Leap for comparison, which give good results. 151 | 152 | Run the SuperPoint+SuperGlue official code for evaluation: 153 | 154 | ``` 155 | python superpoint_glue_official_test.py --descriptor_dim 256 --superpoint_weights superglue/models/weights/superpoint_v1.pth 156 | --superglue_weights superglue/models/weights/superglue_indoor.pth --sinkhorn_iterations 30 157 | ``` 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /Results/ALLSS/superglue_descriptor_128/checkpoints/SuperGlue_epoch_80.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eae18d22a70d037e6d6ae9b67e9523909829cdb4d17b4e8b579593c430775622 3 | size 12234575 4 | -------------------------------------------------------------------------------- /Results/ALLSS/superglue_descriptor_128/logdir/events.out.tfevents.1617122770.DESKTOP-MEF2DI0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/Results/ALLSS/superglue_descriptor_128/logdir/events.out.tfevents.1617122770.DESKTOP-MEF2DI0 -------------------------------------------------------------------------------- /Results/ALLSS/superglue_descriptor_64/checkpoints/SuperGlue_epoch_200.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8d20e8fcf65d6455d43f6570a60f69d33c77c42e26b7be0e2c37e686da52edf3 3 | size 3181207 4 | -------------------------------------------------------------------------------- /Results/ALLSS/superpoint_descriptor_128/checkpoints/superPointNet_62000_checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/Results/ALLSS/superpoint_descriptor_128/checkpoints/superPointNet_62000_checkpoint.pth.tar
-------------------------------------------------------------------------------- /Results/ALLSS/superpoint_descriptor_128/config.yml: -------------------------------------------------------------------------------- 1 | data: 2 | augmentation: 3 | homographic: 4 | enable: false 5 | photometric: 6 | enable: true 7 | params: 8 | additive_gaussian_noise: 9 | stddev_range: 10 | - 0 11 | - 10 12 | additive_shade: 13 | kernel_size_range: 14 | - 100 15 | - 150 16 | transparency_range: 17 | - -0.5 18 | - 0.5 19 | additive_speckle_noise: 20 | prob_range: 21 | - 0 22 | - 0.0035 23 | motion_blur: 24 | max_kernel_size: 3 25 | random_brightness: 26 | max_abs_change: 50 27 | random_contrast: 28 | strength_range: 29 | - 0.5 30 | - 1.5 31 | primitives: 32 | - random_brightness 33 | - random_contrast 34 | - additive_speckle_noise 35 | - additive_gaussian_noise 36 | - additive_shade 37 | - motion_blur 38 | cache_in_memory: false 39 | dataset: ALLSS 40 | gaussian_label: 41 | enable: true 42 | params: 43 | GaussianBlur: 44 | sigma: 0.2 45 | labels: Results/ALLSS/magicpoint_homoAdapt_pseudo 46 | preprocessing: 47 | resize: 48 | - 480 49 | - 640 50 | root: null 51 | root_split_txt: null 52 | warped_pair: 53 | enable: true 54 | params: 55 | allow_artifacts: true 56 | max_angle: 1.57 57 | patch_ratio: 0.85 58 | perspective: true 59 | perspective_amplitude_x: 0.2 60 | perspective_amplitude_y: 0.2 61 | rotation: true 62 | scaling: true 63 | scaling_amplitude: 0.2 64 | translation: true 65 | valid_border_margin: 3 66 | front_end_model: Train_model_heatmap 67 | model: 68 | batch_size: 16 69 | dense_loss: 70 | enable: false 71 | params: 72 | descriptor_dist: 4 73 | lambda_d: 800 74 | descriptor_length: 128 75 | detection_threshold: 0.015 76 | detector_loss: 77 | loss_type: softmax 78 | eval_batch_size: 16 79 | lambda_loss: 1 80 | learning_rate: 0.0001 81 | name: SuperPointNet_gauss 82 | nms: 4 83 | other_settings: train 2d, gauss 0.2 84 | sparse_loss: 85 | enable: true 86 | params: 87 | dist: cos 88 | lamda_d: 1 89 | method: 2d 90 | num_masked_non_matches_per_match: 100 91 | num_matching_attempts: 1000 92 | pretrained: null 93 | reset_iter: true 94 | retrain: true 95 | save_interval: 2000 96 | tensorboard_interval: 200 97 | train_iter: 100000 98 | training: 99 | workers_train: 4 100 | workers_val: 2 101 | validation_interval: 2000 102 | validation_size: 5 103 | -------------------------------------------------------------------------------- /Results/ALLSS/superpoint_descriptor_64/checkpoints/superPointNet_100000_checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/Results/ALLSS/superpoint_descriptor_64/checkpoints/superPointNet_100000_checkpoint.pth.tar -------------------------------------------------------------------------------- /Results/ALLSS/superpoint_descriptor_64/config.yml: -------------------------------------------------------------------------------- 1 | data: 2 | augmentation: 3 | homographic: 4 | enable: false 5 | photometric: 6 | enable: true 7 | params: 8 | additive_gaussian_noise: 9 | stddev_range: 10 | - 0 11 | - 10 12 | additive_shade: 13 | kernel_size_range: 14 | - 100 15 | - 150 16 | transparency_range: 17 | - -0.5 18 | - 0.5 19 | additive_speckle_noise: 20 | prob_range: 21 | - 0 22 | - 0.0035 23 | motion_blur: 24 | max_kernel_size: 3 25 | random_brightness: 26 | max_abs_change: 50 27 | random_contrast: 28 | strength_range: 29 | - 0.5 30 | - 1.5 31 | 
primitives: 32 | - random_brightness 33 | - random_contrast 34 | - additive_speckle_noise 35 | - additive_gaussian_noise 36 | - additive_shade 37 | - motion_blur 38 | cache_in_memory: false 39 | dataset: ALLSS 40 | gaussian_label: 41 | enable: true 42 | params: 43 | GaussianBlur: 44 | sigma: 0.2 45 | labels: Results/ALLSS/magicpoint_homoAdapt_pseudo 46 | preprocessing: 47 | resize: 48 | - 480 49 | - 640 50 | root: null 51 | root_split_txt: null 52 | warped_pair: 53 | enable: true 54 | params: 55 | allow_artifacts: true 56 | max_angle: 1.57 57 | patch_ratio: 0.85 58 | perspective: true 59 | perspective_amplitude_x: 0.2 60 | perspective_amplitude_y: 0.2 61 | rotation: true 62 | scaling: true 63 | scaling_amplitude: 0.2 64 | translation: true 65 | valid_border_margin: 3 66 | front_end_model: Train_model_heatmap 67 | model: 68 | batch_size: 8 69 | dense_loss: 70 | enable: false 71 | params: 72 | descriptor_dist: 4 73 | lambda_d: 800 74 | descriptor_length: 64 75 | detection_threshold: 0.015 76 | detector_loss: 77 | loss_type: softmax 78 | eval_batch_size: 8 79 | lambda_loss: 1 80 | learning_rate: 0.0001 81 | name: SuperPointNet_gauss 82 | nms: 4 83 | other_settings: train 2d, gauss 0.2 84 | sparse_loss: 85 | enable: true 86 | params: 87 | dist: cos 88 | lamda_d: 1 89 | method: 2d 90 | num_masked_non_matches_per_match: 100 91 | num_matching_attempts: 1000 92 | pretrained: null 93 | reset_iter: true 94 | retrain: true 95 | save_interval: 2000 96 | tensorboard_interval: 200 97 | train_iter: 100000 98 | training: 99 | workers_train: 4 100 | workers_val: 2 101 | validation_interval: 2000 102 | validation_size: 5 103 | -------------------------------------------------------------------------------- /Traditional/registration.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | MIN_MATCH_COUNT = 10 5 | 6 | def SIFT_REGIS(src_img,temp_img,RESIZE_SCALE=None,MATCH_VIZ=False): 7 | height,width = temp_img.shape[:2] 8 | 9 | if RESIZE_SCALE is not None: 10 | src_img=cv2.resize(src_img,(int(RESIZE_SCALE*width),int(RESIZE_SCALE*height)),interpolation=cv2.INTER_CUBIC)#source image 11 | temp_img=cv2.resize(temp_img,(int(RESIZE_SCALE*width),int(RESIZE_SCALE*height)),interpolation=cv2.INTER_CUBIC)#template image 12 | 13 | image1 = cv2.cvtColor(src_img, cv2.COLOR_RGB2GRAY) 14 | image2 = cv2.cvtColor(temp_img, cv2.COLOR_RGB2GRAY) 15 | 16 | siftDetector = cv2.xfeatures2d.SIFT_create() 17 | keyPoint1, imageDesc1 = siftDetector.detectAndCompute(image1, None) 18 | keyPoint2, imageDesc2 = siftDetector.detectAndCompute(image2, None) 19 | 20 | FLANN_INDEX_KDTREE = 0 21 | index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5) 22 | search_params = dict(checks=50) 23 | flann = cv2.FlannBasedMatcher(index_params,search_params) 24 | matchePoints = flann.knnMatch(imageDesc1, imageDesc2, k=2) 25 | 26 | good = [] 27 | for m,n in matchePoints: 28 | if m.distance < 0.7*n.distance: 29 | good.append(m) 30 | 31 | if len(good)>MIN_MATCH_COUNT: 32 | src_pts = np.float32([keyPoint1[m.queryIdx].pt for m in good]) 33 | dst_pts = np.float32([keyPoint2[m.trainIdx].pt for m in good]) 34 | # M, mask = cv2.estimateAffine2D(src_pts, dst_pts, method=cv2.RANSAC,ransacReprojThreshold=7) 35 | M, mask = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC,ransacReprojThreshold=7) 36 | if MATCH_VIZ: 37 | matchesMask = mask.ravel().tolist() 38 | draw_params = dict(matchColor = (0,255,0), # draw matches in green color 39 | singlePointColor = None, 40 | 
matchesMask = matchesMask, # draw only inliers 41 | flags = 0) 42 | match_img = cv2.drawMatches(src_img,keyPoint1,temp_img,keyPoint2,good,None,**draw_params) 43 | else: 44 | match_img=None 45 | else: 46 | print ("SIFT:Not enough matches are found - %d/%d" % (len(good),MIN_MATCH_COUNT)) 47 | return None 48 | 49 | return M, match_img 50 | 51 | def ORB_REGIS(src_img,temp_img,RESIZE_SCALE=None,MATCH_VIZ=False): 52 | height,width = temp_img.shape[:2] 53 | 54 | if RESIZE_SCALE is not None: 55 | src_img=cv2.resize(src_img,(int(RESIZE_SCALE*width),int(RESIZE_SCALE*height)),interpolation=cv2.INTER_CUBIC)#source image 56 | temp_img=cv2.resize(temp_img,(int(RESIZE_SCALE*width),int(RESIZE_SCALE*height)),interpolation=cv2.INTER_CUBIC)#template image 57 | 58 | image1 = cv2.cvtColor(src_img, cv2.COLOR_RGB2GRAY) 59 | image2 = cv2.cvtColor(temp_img, cv2.COLOR_RGB2GRAY) 60 | 61 | ORB=cv2.ORB_create() 62 | keyPoint1, imageDesc1 = ORB.detectAndCompute(image1, None) 63 | keyPoint2, imageDesc2 = ORB.detectAndCompute(image2, None) 64 | 65 | #暴力匹配 66 | bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) 67 | matches = bf.match(imageDesc1,imageDesc2) 68 | matches = sorted(matches, key = lambda x:x.distance) 69 | 70 | if len(matches)>MIN_MATCH_COUNT: 71 | src_pts = np.float32([keyPoint1[m.queryIdx].pt for m in matches]) 72 | dst_pts = np.float32([keyPoint2[m.trainIdx].pt for m in matches]) 73 | # M, mask = cv2.estimateAffine2D(src_pts, dst_pts, method=cv2.RANSAC,ransacReprojThreshold=7) 74 | M, mask = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC,ransacReprojThreshold=7) 75 | if MATCH_VIZ: 76 | matchesMask = mask.ravel().tolist() 77 | draw_params = dict(matchColor = (0,255,0), # draw matches in green color 78 | singlePointColor = None, 79 | matchesMask = matchesMask, # draw only inliers 80 | flags = 0) 81 | match_img = cv2.drawMatches(src_img,keyPoint1,temp_img,keyPoint2,matches,None,**draw_params) 82 | else: 83 | match_img=None 84 | else: 85 | print ("ORB:Not enough matches are found - %d/%d" % (len(good),MIN_MATCH_COUNT)) 86 | return None 87 | 88 | return M,match_img 89 | 90 | -------------------------------------------------------------------------------- /datasets/ALLSS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from pathlib import Path 6 | from torch.utils.data import Dataset 7 | from utils.utils import dict_update 8 | from datasets.data_tools import np_to_tensor,warpLabels 9 | from numpy.linalg import inv 10 | 11 | class ALLSS(Dataset): 12 | default_config = { 13 | 'labels': None, 14 | 'cache_in_memory': False, 15 | 'validation_size': 100, 16 | 'truncate': None, 17 | 'preprocessing': { 18 | 'resize': [240, 320] 19 | }, 20 | 'num_parallel_calls': 10, 21 | 'augmentation': { 22 | 'photometric': { 23 | 'enable': False, 24 | 'primitives': 'all', 25 | 'params': {}, 26 | 'random_order': True, 27 | }, 28 | 'homographic': { 29 | 'enable': False, 30 | 'params': {}, 31 | 'valid_border_margin': 0, 32 | }, 33 | }, 34 | 'warped_pair': { 35 | 'enable': False, 36 | 'params': {}, 37 | 'valid_border_margin': 0, 38 | }, 39 | 'homography_adaptation': { 40 | 'enable': False 41 | } 42 | } 43 | def __init__(self,export=False, transform=None, task='train', **config): 44 | self.config = self.default_config 45 | self.config = dict_update(self.config, config) 46 | self.transforms = transform 47 | self.action = 'train' if task == 'train' else 'val' 48 | 49 | # get files 50 | base_path = Path( 
'datasets/ALLSS/' + task) 51 | image_paths = list(base_path.iterdir()) 52 | names = [p.stem for p in image_paths] 53 | image_paths = [str(p) for p in image_paths] 54 | files = {'image_paths': image_paths, 'names': names} 55 | 56 | sequence_set = [] 57 | if self.config['labels']: 58 | print("load labels from: ", self.config['labels']+'/'+task) 59 | for (img, name) in zip(files['image_paths'], files['names']): 60 | p = Path(self.config['labels'], task, '{}.npz'.format(name)) 61 | if p.exists(): 62 | sample = {'image': img, 'name': name, 'points': str(p)} 63 | sequence_set.append(sample) 64 | else: 65 | for (img, name) in zip(files['image_paths'], files['names']): 66 | sample = {'image': img, 'name': name} 67 | sequence_set.append(sample) 68 | 69 | self.samples = sequence_set 70 | self.init_var() 71 | 72 | def init_var(self): 73 | torch.set_default_tensor_type(torch.FloatTensor) 74 | from utils.homographies import sample_homography_np as sample_homography 75 | from utils.utils import compute_valid_mask 76 | from utils.photometric import ImgAugTransform, customizedTransform 77 | from utils.utils import inv_warp_image, inv_warp_image_batch, warp_points 78 | 79 | self.sample_homography = sample_homography 80 | self.inv_warp_image = inv_warp_image 81 | self.inv_warp_image_batch = inv_warp_image_batch 82 | self.compute_valid_mask = compute_valid_mask 83 | self.ImgAugTransform = ImgAugTransform 84 | self.customizedTransform = customizedTransform 85 | self.warp_points = warp_points 86 | 87 | self.enable_photo_train = self.config['augmentation']['photometric']['enable'] 88 | self.enable_homo_train = self.config['augmentation']['homographic']['enable'] 89 | self.enable_homo_val = False 90 | self.enable_photo_val = False 91 | 92 | self.cell_size = 8 93 | if self.config['preprocessing']['resize']: 94 | self.sizer = self.config['preprocessing']['resize'] 95 | 96 | self.gaussian_label = False 97 | if self.config['gaussian_label']['enable']: 98 | self.gaussian_label = True 99 | 100 | def gaussian_blur(self, image): 101 | """ 102 | image: np [H, W] 103 | return: 104 | blurred_image: np [H, W] 105 | """ 106 | aug_par = {'photometric': {}} 107 | aug_par['photometric']['enable'] = True 108 | aug_par['photometric']['params'] = self.config['gaussian_label']['params'] 109 | augmentation = self.ImgAugTransform(**aug_par) 110 | # get label_2D 111 | # labels = points_to_2D(pnts, H, W) 112 | image = image[:,:,np.newaxis] 113 | heatmaps = augmentation(image) 114 | return heatmaps.squeeze() 115 | 116 | def imgPhotometric(self,img): 117 | """ 118 | 119 | :param img: 120 | numpy (H, W) 121 | :return: 122 | """ 123 | augmentation = self.ImgAugTransform(**self.config['augmentation']) 124 | img = img[:,:,np.newaxis] 125 | img = augmentation(img) 126 | cusAug = self.customizedTransform() 127 | img = cusAug(img, **self.config['augmentation']) 128 | return img 129 | def points_to_2D(self,pnts, H, W): 130 | labels = np.zeros((H, W)) 131 | pnts = pnts.astype(int) 132 | labels[pnts[:, 1], pnts[:, 0]] = 1 133 | return labels 134 | 135 | def __getitem__(self, index): 136 | 137 | sample = self.samples[index] 138 | input = {} 139 | input.update(sample) 140 | 141 | img_o = cv2.imread(sample['image']) 142 | img_o = cv2.resize(img_o, (self.sizer[1], self.sizer[0]),interpolation=cv2.INTER_AREA) 143 | img_o = cv2.cvtColor(img_o, cv2.COLOR_RGB2GRAY) 144 | img_o = img_o.astype('float32') / 255.0 145 | H, W = img_o.shape[0], img_o.shape[1] 146 | 147 | img_aug = img_o.copy() 148 | if (self.enable_photo_train == True and self.action == 
'train') or (self.enable_photo_val and self.action == 'val'): 149 | img_aug = self.imgPhotometric(img_o) # numpy array (H, W, 1) 150 | img_aug = torch.tensor(img_aug, dtype=torch.float32).view(-1, H, W) 151 | 152 | valid_mask = self.compute_valid_mask(torch.tensor([H, W]), inv_homography=torch.eye(3)) 153 | input.update({'image': img_aug}) 154 | input.update({'valid_mask': valid_mask}) 155 | 156 | if self.config['homography_adaptation']['enable']: 157 | homoAdapt_iter = self.config['homography_adaptation']['num'] 158 | homographies = np.stack([self.sample_homography(np.array([2, 2]), shift=-1, 159 | **self.config['homography_adaptation']['homographies']['params']) 160 | for i in range(homoAdapt_iter)]) 161 | ##### use inverse from the sample homography 162 | homographies = np.stack([inv(homography) for homography in homographies]) 163 | homographies[0,:,:] = np.identity(3) 164 | ###### 165 | homographies = torch.tensor(homographies, dtype=torch.float32) 166 | inv_homographies = torch.stack([torch.inverse(homographies[i, :, :]) for i in range(homoAdapt_iter)]) 167 | # images 168 | warped_img = self.inv_warp_image_batch(img_aug.squeeze().repeat(homoAdapt_iter,1,1,1), inv_homographies, mode='bilinear').unsqueeze(0) 169 | warped_img = warped_img.squeeze() 170 | # masks 171 | valid_mask = self.compute_valid_mask(torch.tensor([H, W]), inv_homography=inv_homographies, 172 | erosion_radius=self.config['augmentation']['homographic'][ 173 | 'valid_border_margin']) 174 | input.update({'image': warped_img, 'valid_mask': valid_mask, 'image_2D':img_aug}) 175 | input.update({'homographies': homographies, 'inv_homographies': inv_homographies}) 176 | 177 | # labels 178 | to_floatTensor = lambda x: torch.tensor(x).type(torch.FloatTensor) 179 | if self.config['labels']: 180 | pnts = np.load(sample['points'])['pts'] 181 | labels = self.points_to_2D(pnts, H, W)#float->int,keypoints:1,others:0 182 | labels_2D = to_floatTensor(labels[np.newaxis,:,:]) 183 | input.update({'labels_2D': labels_2D}) 184 | 185 | ## residual 186 | labels_res = torch.zeros((2, H, W)).type(torch.FloatTensor) 187 | input.update({'labels_res': labels_res}) 188 | 189 | if (self.enable_homo_train == True and self.action == 'train') or (self.enable_homo_val and self.action == 'val'): 190 | homography = self.sample_homography(np.array([2, 2]), shift=-1, 191 | **self.config['augmentation']['homographic']['params']) 192 | ##### use inverse from the sample homography 193 | homography = inv(homography) 194 | inv_homography = inv(homography) 195 | inv_homography = torch.tensor(inv_homography).to(torch.float32) 196 | homography = torch.tensor(homography).to(torch.float32) 197 | warped_img = self.inv_warp_image(img_aug.squeeze(), inv_homography, mode='bilinear').unsqueeze(0) 198 | 199 | ##### check ##### 200 | warped_set = warpLabels(pnts, H, W, homography) 201 | warped_labels = warped_set['labels'] 202 | valid_mask = self.compute_valid_mask(torch.tensor([H, W]), inv_homography=inv_homography, 203 | erosion_radius=self.config['augmentation']['homographic']['valid_border_margin']) 204 | 205 | input.update({'image': warped_img, 'labels_2D': warped_labels, 'valid_mask': valid_mask}) 206 | 207 | if self.config['warped_pair']['enable']: 208 | homography = self.sample_homography(np.array([2, 2]), shift=-1, 209 | **self.config['warped_pair']['params']) 210 | 211 | ##### use inverse from the sample homography 212 | homography = np.linalg.inv(homography) 213 | inv_homography = np.linalg.inv(homography) 214 | 215 | homography = 
torch.tensor(homography).type(torch.FloatTensor) 216 | inv_homography = torch.tensor(inv_homography).type(torch.FloatTensor) 217 | 218 | # warp original image 219 | warped_img = torch.tensor(img_o, dtype=torch.float32) 220 | warped_img = self.inv_warp_image(warped_img.squeeze(), inv_homography, mode='bilinear').unsqueeze(0) 221 | if (self.enable_photo_train == True and self.action == 'train') or (self.enable_photo_val and self.action == 'val'): 222 | warped_img = self.imgPhotometric(warped_img.numpy().squeeze()) # numpy array (H, W, 1) 223 | warped_img = torch.tensor(warped_img, dtype=torch.float32) 224 | pass 225 | warped_img = warped_img.view(-1, H, W) 226 | 227 | # warped_labels = warpLabels(pnts, H, W, homography) 228 | warped_set = warpLabels(pnts, H, W, homography, bilinear=True) 229 | warped_labels = warped_set['labels'] 230 | warped_res = warped_set['res'] 231 | warped_res = warped_res.transpose(1,2).transpose(0,1) 232 | if self.gaussian_label: 233 | from utils.var_dim import squeezeToNumpy 234 | warped_labels_bi = warped_set['labels_bi'] 235 | warped_labels_gaussian = self.gaussian_blur(squeezeToNumpy(warped_labels_bi)) 236 | warped_labels_gaussian = np_to_tensor(warped_labels_gaussian, H, W) 237 | input['warped_labels_gaussian'] = warped_labels_gaussian 238 | input.update({'warped_labels_bi': warped_labels_bi}) 239 | 240 | input.update({'warped_img': warped_img, 'warped_labels': warped_labels, 'warped_res': warped_res}) 241 | valid_mask = self.compute_valid_mask(torch.tensor([H, W]), inv_homography=inv_homography, 242 | erosion_radius=self.config['warped_pair']['valid_border_margin']) # can set to other value 243 | input.update({'warped_valid_mask': valid_mask}) 244 | input.update({'homographies': homography, 'inv_homographies': inv_homography}) 245 | 246 | if self.gaussian_label: 247 | labels_gaussian = self.gaussian_blur(squeezeToNumpy(labels_2D)) 248 | labels_gaussian = np_to_tensor(labels_gaussian, H, W) 249 | input['labels_2D_gaussian'] = labels_gaussian 250 | 251 | name = sample['name'] 252 | input.update({'name': name, 'scene_name': "./"}) # dummy scene name 253 | return input 254 | 255 | 256 | def __len__(self): 257 | return len(self.samples) 258 | 259 | -------------------------------------------------------------------------------- /datasets/GlueSparse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from scipy.spatial.distance import cdist 6 | from torch.utils.data import Dataset 7 | from superpoint.models.superpoint_test import SuperPoint 8 | from utils.utils import frame2tensor 9 | 10 | class GlueSparse(Dataset): 11 | """Sparse correspondences dataset.""" 12 | 13 | def __init__(self,train_path,sp_config,resize,device): 14 | self.device=device 15 | self.resize=resize 16 | self.files=[] 17 | self.files=[train_path+'/'+ f for f in os.listdir(train_path)] 18 | self.superpoint = SuperPoint(sp_config).to(device) 19 | self.superpoint.eval() 20 | 21 | def __len__(self): 22 | return len(self.files) 23 | 24 | def __getitem__(self, index): 25 | file_name=self.files[index] 26 | image=cv2.imread(file_name,0) 27 | image=cv2.resize(image,(self.resize[0],self.resize[1])) 28 | width,height=image.shape[:2] 29 | corners=np.array([[0,0],[0,height],[width,0],[width,height]],dtype=np.float32) 30 | warp=np.random.randint(-100,100,size=(4,2)).astype(np.float32) 31 | M = cv2.getPerspectiveTransform(corners, corners + warp) 32 | warped = cv2.warpPerspective(src=image, M=M, 
dsize=(image.shape[1], image.shape[0])) 33 | 34 | # extract keypoints of the image pair using SuperPoint 35 | image_tensor = frame2tensor(image, self.device) 36 | warped_tensor = frame2tensor(warped, self.device) 37 | 38 | pred1 = self.superpoint(image_tensor) 39 | pred2 = self.superpoint(warped_tensor) 40 | 41 | #keypoints,descriptors,scores 42 | kp1_np=pred1["keypoints"][0].cpu().detach().numpy()#(636,2) 43 | kp2_np=pred2["keypoints"][0].cpu().detach().numpy()#(666,2) 44 | 45 | descs1=pred1["descriptors"][0].cpu().detach().numpy().transpose()#(636,256) 46 | descs2=pred2["descriptors"][0].cpu().detach().numpy().transpose()#(666,256) 47 | 48 | scores1_np=pred1["scores"][0].cpu().detach().numpy()#(636,) 49 | scores2_np=pred2["scores"][0].cpu().detach().numpy()#(666,) 50 | 51 | # skip this image pair if no keypoints detected in image 52 | if len(kp1_np) < 1 or len(kp2_np) < 1: 53 | return{ 54 | 'keypoints0': torch.zeros([0, 0, 2], dtype=torch.double), 55 | 'keypoints1': torch.zeros([0, 0, 2], dtype=torch.double), 56 | 'descriptors0': torch.zeros([0, 2], dtype=torch.double), 57 | 'descriptors1': torch.zeros([0, 2], dtype=torch.double), 58 | 'image0': image, 59 | 'image1': warped, 60 | 'file_name': file_name 61 | } 62 | 63 | # obtain the matching matrix of the image pair 64 | kp1_projected = cv2.perspectiveTransform(kp1_np.reshape((1, -1, 2)), M)[0, :, :] 65 | dists = cdist(kp1_projected, kp2_np)#(636,666) 66 | 67 | min1 = np.argmin(dists, axis=0) # 在axis=0方向上找最小的值并返回索引 68 | min2 = np.argmin(dists, axis=1) # 在axis=1方向上找最小的值并返回索引 69 | 70 | min1v = np.min(dists, axis=1) 71 | min1f = min2[min1v < 3] 72 | 73 | xx = np.where(min2[min1] == np.arange(min1.shape[0]))[0]#最佳匹配点,两个方向最近的匹配点 74 | matches = np.intersect1d(min1f, xx)#34 75 | 76 | missing1 = np.setdiff1d(np.arange(kp1_np.shape[0]), min1[matches])#返回两个数组的差集,非配对点 77 | missing2 = np.setdiff1d(np.arange(kp2_np.shape[0]), matches)#返回两个数组的差集,非配对点 78 | 79 | MN = np.concatenate([min1[matches][np.newaxis, :], matches[np.newaxis, :]])#正确匹配的点对应 80 | MN2 = np.concatenate([missing1[np.newaxis, :], (len(kp2_np)) * np.ones((1, len(missing1)), dtype=np.int64)])#没有匹配点的dustbin 列 81 | MN3 = np.concatenate([(len(kp1_np)) * np.ones((1, len(missing2)), dtype=np.int64), missing2[np.newaxis, :]])#没有匹配点的dustbin 行 82 | all_matches = np.concatenate([MN, MN2, MN3], axis=1) 83 | 84 | kp1_np = kp1_np.reshape((1, -1, 2))#(1,636,2) 85 | kp2_np = kp2_np.reshape((1, -1, 2))#(1,666,2) 86 | descs1 = np.transpose(descs1)#(256,636) 87 | descs2 = np.transpose(descs2)#(256,666) 88 | 89 | image = torch.from_numpy(image/255.).double()[None].to(self.device) 90 | warped = torch.from_numpy(warped/255.).double()[None].to(self.device) 91 | 92 | return{ 93 | 'keypoints0': list(kp1_np), 94 | 'keypoints1': list(kp2_np), 95 | 'descriptors0': list(descs1), 96 | 'descriptors1': list(descs2), 97 | 'scores0': list(scores1_np), 98 | 'scores1': list(scores2_np), 99 | 'image0': image, 100 | 'image1': warped, 101 | 'matches':MN, 102 | 'all_matches': list(all_matches), 103 | 'file_name': file_name 104 | } -------------------------------------------------------------------------------- /datasets/SSHIDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | class SSHIDataset(Dataset): 7 | def __init__(self,source_dir,template_path,resize_scale): 8 | self.source_list=os.listdir(source_dir) 9 | self.source_dir = [] 10 | self.source_dir += [source_dir + f for f in 
self.source_list] 11 | self.template_path=template_path 12 | self.resize_scale=resize_scale 13 | 14 | def __getitem__(self, index): 15 | filename=self.source_list[index] 16 | source_path=self.source_dir[index] 17 | source_original = cv2.imread(source_path, cv2.IMREAD_GRAYSCALE) 18 | template_original = cv2.imread(self.template_path,cv2.IMREAD_GRAYSCALE) 19 | 20 | if self.resize_scale is not None: 21 | source_image = cv2.resize(source_original,(int(self.resize_scale*source_original.shape[1]),int(self.resize_scale*source_original.shape[0]))) 22 | template_image = cv2.resize(template_original,(int(self.resize_scale*template_original.shape[1]),int(self.resize_scale*template_original.shape[0]))) 23 | else: 24 | source_image = source_original 25 | template_image = template_original 26 | source_original=source_original[None]/255 27 | source_image=source_image[None]/255 28 | template_image = template_image[None]/255 29 | return source_original,source_image,template_image,filename 30 | 31 | def __len__(self): 32 | return len(self.source_list) 33 | 34 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | def get_dataset(name): 2 | mod = __import__('datasets.{}'.format(name), fromlist=['']) 3 | return getattr(mod, _module_to_class(name)) 4 | 5 | 6 | def _module_to_class(name): 7 | return ''.join(n.capitalize() for n in name.split('_')) 8 | -------------------------------------------------------------------------------- /datasets/data_tools.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from utils.utils import filter_points 4 | from utils.utils import warp_points 5 | from utils.utils import homography_scaling_torch as homography_scaling 6 | 7 | quan = lambda x: x.round().long() 8 | 9 | def extrapolate_points(pnts): 10 | pnts_int = pnts.long().type(torch.FloatTensor) 11 | pnts_x, pnts_y = pnts_int[:,0], pnts_int[:,1] 12 | 13 | stack_1 = lambda x, y: torch.stack((x, y), dim=1) 14 | pnts_ext = torch.cat((pnts_int, stack_1(pnts_x, pnts_y+1), 15 | stack_1(pnts_x+1, pnts_y), pnts_int+1), dim=0) 16 | 17 | pnts_res = pnts - pnts_int # (x, y) 18 | x_res, y_res = pnts_res[:,0], pnts_res[:,1] # residuals 19 | res_ext = torch.cat(((1-x_res)*(1-y_res), (1-x_res)*y_res, 20 | x_res*(1-y_res), x_res*y_res), dim=0) 21 | return pnts_ext, res_ext 22 | 23 | def scatter_points(warped_pnts, H, W, res_ext = 1): 24 | warped_labels = torch.zeros(H, W) 25 | warped_labels[quan(warped_pnts)[:, 1], quan(warped_pnts)[:, 0]] = res_ext 26 | warped_labels = warped_labels.view(-1, H, W) 27 | return warped_labels 28 | 29 | def get_labels_bi(warped_pnts, H, W): 30 | pnts_ext, res_ext = extrapolate_points(warped_pnts) 31 | pnts_ext, mask = filter_points(pnts_ext, torch.tensor([W, H]), return_mask=True) 32 | res_ext = res_ext[mask] 33 | warped_labels_bi = scatter_points(pnts_ext, H, W, res_ext = res_ext) 34 | return warped_labels_bi 35 | 36 | def warpLabels(pnts, H, W, homography, bilinear = False): 37 | if isinstance(pnts, torch.Tensor): 38 | pnts = pnts.long() 39 | else: 40 | pnts = torch.tensor(pnts).long() 41 | warped_pnts = warp_points(torch.stack((pnts[:, 0], pnts[:, 1]), dim=1), 42 | homography_scaling(homography, H, W)) # check the (x, y) 43 | outs = {} 44 | 45 | if bilinear == True: 46 | warped_labels_bi = get_labels_bi(warped_pnts, H, W) 47 | outs['labels_bi'] = warped_labels_bi 48 | 49 | warped_pnts = filter_points(warped_pnts, 
torch.tensor([W, H])) 50 | warped_labels = scatter_points(warped_pnts, H, W, res_ext = 1) 51 | warped_labels_res = torch.zeros(H, W, 2) 52 | warped_labels_res[quan(warped_pnts)[:, 1], quan(warped_pnts)[:, 0], :] = warped_pnts - warped_pnts.round() 53 | outs.update({'labels': warped_labels, 'res': warped_labels_res, 'warped_pnts': warped_pnts}) 54 | return outs 55 | 56 | def np_to_tensor(img, H, W): 57 | img = torch.tensor(img).type(torch.FloatTensor).view(-1, H, W) 58 | return img 59 | 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | scipy 3 | opencv-python 4 | opencv-contrib-python 5 | matplotlib 6 | imageio 7 | tqdm 8 | tensorboard 9 | tensorboardX 10 | tqdm 11 | pyyaml 12 | imageio 13 | imgaug 14 | jupyter 15 | scikit-learn 16 | torchgeometry 17 | torchsummary 18 | coloredlogs -------------------------------------------------------------------------------- /superglue/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superglue/models/__init__.py -------------------------------------------------------------------------------- /superglue/models/matching.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | import torch 44 | from superglue.models.superpoint import SuperPoint 45 | from superglue.models.superglue_test import SuperGlue 46 | 47 | class Matching(torch.nn.Module): 48 | """ Image Matching Frontend (SuperPoint + SuperGlue) """ 49 | def __init__(self, config={}): 50 | super().__init__() 51 | self.superpoint = SuperPoint(config.get('superpoint', {})) 52 | self.superglue = SuperGlue(config.get('superglue', {})) 53 | 54 | def forward(self, data): 55 | """ Run SuperPoint (optionally) and SuperGlue 56 | SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input 57 | Args: 58 | data: dictionary with minimal keys: ['image0', 'image1'] 59 | """ 60 | pred = {} 61 | 62 | # Extract SuperPoint (keypoints, scores, descriptors) if not provided 63 | if 'keypoints0' not in data: 64 | pred0 = self.superpoint({'image': data['image0']}) 65 | pred = {**pred, **{k+'0': v for k, v in pred0.items()}} 66 | if 'keypoints1' not in data: 67 | pred1 = self.superpoint({'image': data['image1']}) 68 | pred = {**pred, **{k+'1': v for k, v in pred1.items()}} 69 | 70 | # Batch all features 71 | # We should either have i) one image per batch, or 72 | # ii) the same number of local features for all images in the batch. 73 | data = {**data, **pred} 74 | 75 | for k in data: 76 | if isinstance(data[k], (list, tuple)): 77 | data[k] = torch.stack(data[k]) 78 | 79 | # Perform the matching 80 | pred = {**pred, **self.superglue(data)} 81 | 82 | return pred 83 | -------------------------------------------------------------------------------- /superglue/models/matching_test.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. 
THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | import torch 44 | from superpoint.models.superpoint_test import SuperPoint 45 | from superglue.models.superglue_test import SuperGlue 46 | 47 | class Matching(torch.nn.Module): 48 | """ Image Matching Frontend (SuperPoint + SuperGlue) """ 49 | def __init__(self, config={}): 50 | super().__init__() 51 | self.superpoint = SuperPoint(config.get('superpoint', {})) 52 | self.superglue = SuperGlue(config.get('superglue', {})) 53 | 54 | def forward(self, data): 55 | """ Run SuperPoint (optionally) and SuperGlue 56 | SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input 57 | Args: 58 | data: dictionary with minimal keys: ['image0', 'image1'] 59 | """ 60 | pred = {} 61 | 62 | # Extract SuperPoint (keypoints, scores, descriptors) if not provided 63 | if 'keypoints0' not in data: 64 | pred0 = self.superpoint(data['image0']) 65 | pred = {**pred, **{k+'0': v for k, v in pred0.items()}} 66 | if 'keypoints1' not in data: 67 | pred1 = self.superpoint(data['image1']) 68 | pred = {**pred, **{k+'1': v for k, v in pred1.items()}} 69 | 70 | # Batch all features 71 | # We should either have i) one image per batch, or 72 | # ii) the same number of local features for all images in the batch. 73 | data = {**data, **pred} 74 | 75 | for k in data: 76 | if isinstance(data[k], (list, tuple)): 77 | data[k] = torch.stack(data[k]) 78 | 79 | # Perform the matching 80 | pred = {**pred, **self.superglue(data)} 81 | 82 | return pred 83 | -------------------------------------------------------------------------------- /superglue/models/superglue_test.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. 
ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | from copy import deepcopy 44 | from pathlib import Path 45 | import torch 46 | from torch import nn 47 | 48 | 49 | def MLP(channels: list, do_bn=True): 50 | """ Multi-layer perceptron """ 51 | n = len(channels) 52 | layers = [] 53 | for i in range(1, n): 54 | layers.append( 55 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 56 | if i < (n-1): 57 | if do_bn: 58 | layers.append(nn.BatchNorm1d(channels[i])) 59 | layers.append(nn.ReLU()) 60 | return nn.Sequential(*layers) 61 | 62 | 63 | def normalize_keypoints(kpts, image_shape): 64 | """ Normalize keypoints locations based on image image_shape""" 65 | _, _, height, width = image_shape 66 | one = kpts.new_tensor(1) 67 | size = torch.stack([one*width, one*height])[None] 68 | center = size / 2 69 | scaling = size.max(1, keepdim=True).values * 0.7 70 | return (kpts - center[:, None, :]) / scaling[:, None, :] 71 | 72 | 73 | class KeypointEncoder(nn.Module): 74 | """ Joint encoding of visual appearance and location using MLPs""" 75 | def __init__(self, feature_dim, layers): 76 | super().__init__() 77 | self.encoder = MLP([3] + layers + [feature_dim]) 78 | nn.init.constant_(self.encoder[-1].bias, 0.0) 79 | 80 | def forward(self, kpts, scores): 81 | inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] 82 | return self.encoder(torch.cat(inputs, dim=1)) 83 | 84 | 85 | def attention(query, key, value): 86 | dim = query.shape[1] 87 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 88 | prob = torch.nn.functional.softmax(scores, dim=-1) 89 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob 90 | 91 | 92 | class MultiHeadedAttention(nn.Module): 93 | """ Multi-head attention to increase model expressivitiy """ 94 | def __init__(self, num_heads: int, d_model: int): 95 | super().__init__() 96 | assert d_model % num_heads == 0 97 | self.dim = d_model // num_heads 98 | self.num_heads = num_heads 99 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 100 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 101 | 102 | def forward(self, query, key, value): 103 | batch_dim = query.size(0) 104 | query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) 105 | for l, x in zip(self.proj, (query, key, value))] 106 | x, _ = attention(query, key, value) 107 | return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) 108 | 109 | 110 | class AttentionalPropagation(nn.Module): 111 | def __init__(self, feature_dim: int, num_heads: int): 112 | super().__init__() 113 | self.attn = MultiHeadedAttention(num_heads, feature_dim) 114 | self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) 115 | 
nn.init.constant_(self.mlp[-1].bias, 0.0) 116 | 117 | def forward(self, x, source): 118 | message = self.attn(x, source, source) 119 | return self.mlp(torch.cat([x, message], dim=1)) 120 | 121 | 122 | class AttentionalGNN(nn.Module): 123 | def __init__(self, feature_dim: int, layer_names: list): 124 | super().__init__() 125 | self.layers = nn.ModuleList([ 126 | AttentionalPropagation(feature_dim, 4) 127 | for _ in range(len(layer_names))]) 128 | self.names = layer_names 129 | 130 | def forward(self, desc0, desc1): 131 | for layer, name in zip(self.layers, self.names): 132 | if name == 'cross': 133 | src0, src1 = desc1, desc0 134 | else: # if name == 'self': 135 | src0, src1 = desc0, desc1 136 | delta0, delta1 = layer(desc0, src0), layer(desc1, src1) 137 | desc0, desc1 = (desc0 + delta0), (desc1 + delta1) 138 | return desc0, desc1 139 | 140 | 141 | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): 142 | """ Perform Sinkhorn Normalization in Log-space for stability""" 143 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 144 | for _ in range(iters): 145 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) 146 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) 147 | return Z + u.unsqueeze(2) + v.unsqueeze(1) 148 | 149 | 150 | def log_optimal_transport(scores, alpha, iters: int): 151 | """ Perform Differentiable Optimal Transport in Log-space for stability""" 152 | b, m, n = scores.shape 153 | one = scores.new_tensor(1) 154 | ms, ns = (m*one).to(scores), (n*one).to(scores) 155 | 156 | bins0 = alpha.expand(b, m, 1) 157 | bins1 = alpha.expand(b, 1, n) 158 | alpha = alpha.expand(b, 1, 1) 159 | 160 | couplings = torch.cat([torch.cat([scores, bins0], -1), 161 | torch.cat([bins1, alpha], -1)], 1) 162 | 163 | norm = - (ms + ns).log() 164 | log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) 165 | log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) 166 | log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) 167 | 168 | Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) 169 | Z = Z - norm # multiply probabilities by M+N 170 | return Z 171 | 172 | 173 | def arange_like(x, dim: int): 174 | return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 175 | 176 | 177 | class SuperGlue(nn.Module): 178 | """SuperGlue feature matching middle-end 179 | 180 | Given two sets of keypoints and locations, we determine the 181 | correspondences by: 182 | 1. Keypoint Encoding (normalization + visual feature and location fusion) 183 | 2. Graph Neural Network with multiple self and cross-attention layers 184 | 3. Final projection layer 185 | 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) 186 | 5. Thresholding matrix based on mutual exclusivity and a match_threshold 187 | 188 | The correspondence ids use -1 to indicate non-matching points. 189 | 190 | Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew 191 | Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural 192 | Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763 193 | 194 | """ 195 | default_config = { 196 | 'descriptor_dim': 256, 197 | 'weights': 'indoor', 198 | 'keypoint_encoder': [32, 64, 128, 256], 199 | 'GNN_layers': ['self', 'cross'] * 9, 200 | 'sinkhorn_iterations': 100, 201 | 'match_threshold': 0.2, 202 | } 203 | 204 | def __init__(self, config): 205 | super().__init__() 206 | self.config = {**self.default_config, **config} 207 | 208 | self.kenc = KeypointEncoder( 209 | self.config['descriptor_dim'], self.config['keypoint_encoder']) 210 | 211 | self.gnn = AttentionalGNN( 212 | self.config['descriptor_dim'], self.config['GNN_layers']) 213 | 214 | self.final_proj = nn.Conv1d( 215 | self.config['descriptor_dim'], self.config['descriptor_dim'], 216 | kernel_size=1, bias=True) 217 | 218 | bin_score = torch.nn.Parameter(torch.tensor(1.)) 219 | self.register_parameter('bin_score', bin_score) 220 | 221 | if self.config['weights']: 222 | checkpoints = torch.load(config['weights']) 223 | if 'indoor' in self.config['weights'] or 'outdoor' in self.config['weights']: 224 | state_dict=checkpoints 225 | else: 226 | state_dict = checkpoints['net'] 227 | self.load_state_dict(state_dict) 228 | print('Loaded SuperGlue model weights') 229 | 230 | def forward(self, data): 231 | """Run SuperGlue on a pair of keypoints and descriptors""" 232 | desc0, desc1 = data['descriptors0'], data['descriptors1'] 233 | kpts0, kpts1 = data['keypoints0'], data['keypoints1'] 234 | 235 | if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints 236 | shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] 237 | return { 238 | 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int), 239 | 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int), 240 | 'matching_scores0': kpts0.new_zeros(shape0), 241 | 'matching_scores1': kpts1.new_zeros(shape1), 242 | } 243 | 244 | # Keypoint normalization. 245 | kpts0 = normalize_keypoints(kpts0, data['image0'].shape) 246 | kpts1 = normalize_keypoints(kpts1, data['image1'].shape) 247 | 248 | # Keypoint MLP encoder. 249 | desc0 = desc0 + self.kenc(kpts0, data['scores0'])#(1,256,n_p0) 250 | desc1 = desc1 + self.kenc(kpts1, data['scores1'])#(1,256,n_p1) 251 | 252 | # Multi-layer Transformer network. 253 | desc0, desc1 = self.gnn(desc0, desc1) 254 | 255 | # Final MLP projection. 256 | mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) 257 | 258 | # Compute matching descriptor distance. 259 | scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) 260 | scores = scores / self.config['descriptor_dim']**.5 261 | 262 | # Run the optimal transport. 263 | scores = log_optimal_transport( 264 | scores, self.bin_score, 265 | iters=self.config['sinkhorn_iterations']) 266 | 267 | # Get the matches with score above "match_threshold". 
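        # `scores` now has shape (b, m+1, n+1): log_optimal_transport above
        # augments the m x n similarity matrix with a learned "dustbin"
        # row/column (bin_score) that absorbs unmatched keypoints, then runs
        # Sinkhorn iterations in log-space. The block below drops the dustbin,
        # takes row-wise and column-wise argmaxes, keeps only mutual best
        # matches, converts log-scores to probabilities with exp(), discards
        # matches whose score is not above match_threshold, and marks
        # everything else with index -1.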
268 | max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) 269 | indices0, indices1 = max0.indices, max1.indices 270 | mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) 271 | mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) 272 | zero = scores.new_tensor(0) 273 | mscores0 = torch.where(mutual0, max0.values.exp(), zero) 274 | mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) 275 | valid0 = mutual0 & (mscores0 > self.config['match_threshold']) 276 | valid1 = mutual1 & valid0.gather(1, indices1) 277 | indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) 278 | indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) 279 | 280 | return { 281 | 'matches0': indices0, # use -1 for invalid match 282 | 'matches1': indices1, # use -1 for invalid match 283 | 'matching_scores0': mscores0, 284 | 'matching_scores1': mscores1, 285 | } 286 | -------------------------------------------------------------------------------- /superglue/models/superpoint.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | from pathlib import Path 44 | import torch 45 | from torch import nn 46 | 47 | def simple_nms(scores, nms_radius: int): 48 | """ Fast Non-maximum suppression to remove nearby points """ 49 | assert(nms_radius >= 0) 50 | 51 | def max_pool(x): 52 | return torch.nn.functional.max_pool2d( 53 | x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) 54 | 55 | zeros = torch.zeros_like(scores) 56 | max_mask = scores == max_pool(scores) 57 | for _ in range(2): 58 | supp_mask = max_pool(max_mask.float()) > 0 59 | supp_scores = torch.where(supp_mask, zeros, scores) 60 | new_max_mask = supp_scores == max_pool(supp_scores) 61 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 62 | return torch.where(max_mask, scores, zeros) 63 | 64 | 65 | def remove_borders(keypoints, scores, border: int, height: int, width: int): 66 | """ Removes keypoints too close to the border """ 67 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 68 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 69 | mask = mask_h & mask_w 70 | return keypoints[mask], scores[mask] 71 | 72 | 73 | def top_k_keypoints(keypoints, scores, k: int): 74 | if k >= len(keypoints): 75 | return keypoints, scores 76 | scores, indices = torch.topk(scores, k, dim=0) 77 | return keypoints[indices], scores 78 | 79 | 80 | def sample_descriptors(keypoints, descriptors, s: int = 8): 81 | """ Interpolate descriptors at keypoint locations """ 82 | b, c, h, w = descriptors.shape 83 | keypoints = keypoints - s / 2 + 0.5 84 | keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], 85 | ).to(keypoints)[None] 86 | keypoints = keypoints*2 - 1 # normalize to (-1, 1) 87 | args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} 88 | descriptors = torch.nn.functional.grid_sample( 89 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 90 | descriptors = torch.nn.functional.normalize( 91 | descriptors.reshape(b, c, -1), p=2, dim=1) 92 | return descriptors 93 | 94 | 95 | class SuperPoint(nn.Module): 96 | """SuperPoint Convolutional Detector and Descriptor 97 | 98 | SuperPoint: Self-Supervised Interest Point Detection and 99 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 100 | Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 101 | 102 | """ 103 | default_config = { 104 | 'descriptor_dim': 256, 105 | 'nms_radius': 4, 106 | 'keypoint_threshold': 0.005, 107 | 'max_keypoints': -1, 108 | 'remove_borders': 4, 109 | } 110 | 111 | def __init__(self, config): 112 | super().__init__() 113 | self.config = {**self.default_config, **config} 114 | 115 | self.relu = nn.ReLU(inplace=True) 116 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 117 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 118 | 119 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 120 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 121 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 122 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 123 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 124 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 125 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 126 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 127 | 128 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 129 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 130 | 131 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 132 | self.convDb = nn.Conv2d( 133 | c5, self.config['descriptor_dim'], 134 | kernel_size=1, stride=1, padding=0) 135 | 136 | path = Path(__file__).parent / 'weights/superpoint_v1.pth' 137 | self.load_state_dict(torch.load(str(path))) 138 | 139 | mk = self.config['max_keypoints'] 140 | if mk == 0 or mk < -1: 141 | raise ValueError('\"max_keypoints\" must be positive or \"-1\"') 142 | 143 | print('Loaded SuperPoint model') 144 | 145 | def forward(self, data): 146 | """ Compute keypoints, scores, descriptors for image """ 147 | # Shared Encoder 148 | x = self.relu(self.conv1a(data['image'])) 149 | x = self.relu(self.conv1b(x)) 150 | x = self.pool(x) 151 | x = self.relu(self.conv2a(x)) 152 | x = self.relu(self.conv2b(x)) 153 | x = self.pool(x) 154 | x = self.relu(self.conv3a(x)) 155 | x = self.relu(self.conv3b(x)) 156 | x = self.pool(x) 157 | x = self.relu(self.conv4a(x)) 158 | x = self.relu(self.conv4b(x))#(1,128,64,64) 159 | 160 | # Compute the dense keypoint scores 161 | cPa = self.relu(self.convPa(x))#(1,256,64,64) 162 | scores = self.convPb(cPa)#(1,65,64,64) 163 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1]#(1,64,64,64) 164 | b, _, h, w = scores.shape 165 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 166 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) 167 | scores = simple_nms(scores, self.config['nms_radius']) 168 | 169 | # Extract keypoints 170 | keypoints = [ 171 | torch.nonzero(s > self.config['keypoint_threshold']) 172 | for s in scores] 173 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 174 | 175 | # Discard keypoints near the image borders 176 | keypoints, scores = list(zip(*[ 177 | remove_borders(k, s, self.config['remove_borders'], h*8, w*8) 178 | for k, s in zip(keypoints, scores)])) 179 | 180 | # Keep the k keypoints with highest score 181 | if self.config['max_keypoints'] >= 0: 182 | keypoints, scores = list(zip(*[ 183 | top_k_keypoints(k, s, self.config['max_keypoints']) 184 | for k, s in zip(keypoints, scores)])) 185 | 186 | # Convert (h, w) to (x, y) 187 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 188 | 189 | # Compute the dense descriptors 190 | cDa = self.relu(self.convDa(x))#(1,256,64,64) 191 | 
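        # convDb projects the shared encoder features to descriptor_dim
        # channels at 1/8 of the input resolution; the dense map is
        # L2-normalised below and descriptors are then bilinearly sampled at
        # the detected keypoint locations by sample_descriptors().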
descriptors = self.convDb(cDa)#(1,256,64,64) 192 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)#(1,256,64,64) 193 | 194 | # Extract descriptors 195 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 196 | for k, d in zip(keypoints, descriptors)] 197 | 198 | return { 199 | 'keypoints': keypoints, 200 | 'scores': scores, 201 | 'descriptors': descriptors, 202 | } 203 | -------------------------------------------------------------------------------- /superglue/models/weights/SuperGlue_allss_descriptor_128.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eae18d22a70d037e6d6ae9b67e9523909829cdb4d17b4e8b579593c430775622 3 | size 12234575 4 | -------------------------------------------------------------------------------- /superglue/models/weights/SuperGlue_allss_descriptor_64.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8d20e8fcf65d6455d43f6570a60f69d33c77c42e26b7be0e2c37e686da52edf3 3 | size 3181207 4 | -------------------------------------------------------------------------------- /superglue/models/weights/superglue_indoor.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0e710469be25ebe1e2ccf68edcae8b2945b0617c8e7e68412251d9d47f5052b1 3 | size 48233807 4 | -------------------------------------------------------------------------------- /superglue/models/weights/superglue_outdoor.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2f5f5e9bb3febf07b69df633c4c3ff7a17f8af26a023aae2b9303d22339195bd 3 | size 48233807 4 | -------------------------------------------------------------------------------- /superglue/models/weights/superpoint_v1.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:52b6708629640ca883673b5d5c097c4ddad37d8048b33f09c8ca0d69db12c40e 3 | size 5206086 4 | -------------------------------------------------------------------------------- /superpoint/configs/magicpoint_allss_export.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: 'ALLSS' # 'coco' 'hpatches' 3 | export_folder: 'train' # train, val 4 | preprocessing: 5 | resize: [480, 640] #[240,320] 6 | gaussian_label: 7 | enable: false # false 8 | sigma: 1. 
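  # The blocks below configure pseudo-label export: photometric augmentation is
  # disabled, while homography adaptation aggregates ('sum') the detections
  # obtained under `num` random homographic warps of each image, producing the
  # keypoint pseudo-ground-truth used later for training (the homography
  # adaptation step described in the SuperPoint paper).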
9 | augmentation: 10 | photometric: 11 | enable: false 12 | homography_adaptation: 13 | enable: true 14 | num: 50 # 100 15 | aggregation: 'sum' 16 | filter_counts: 0 17 | homographies: 18 | params: 19 | translation: true 20 | rotation: true 21 | scaling: true 22 | perspective: true 23 | scaling_amplitude: 0.2 24 | perspective_amplitude_x: 0.2 25 | perspective_amplitude_y: 0.2 26 | allow_artifacts: true 27 | patch_ratio: 0.85 28 | 29 | training: 30 | workers_test: 0 31 | 32 | model: 33 | name: 'superpoint_train' 34 | params: { 35 | } 36 | batch_size: 1 37 | eval_batch_size: 1 38 | detection_threshold: 0.015 39 | nms: 4 40 | top_k: 1200 41 | subpixel: 42 | enable: true 43 | 44 | pretrained: 'superpoint/models/weights/magicpoint/superPointNet_100000_checkpoint.pth.tar' 45 | 46 | 47 | -------------------------------------------------------------------------------- /superpoint/configs/superpoint_allss_train_heatmap.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: 'ALLSS' # 'coco' 3 | labels: Results/ALLSS/magicpoint_homoAdapt_pseudo 4 | root: # datasets/COCO 5 | root_split_txt: # /datasets/COCO 6 | 7 | gaussian_label: 8 | enable: true 9 | params: 10 | GaussianBlur: {sigma: 0.2} 11 | 12 | cache_in_memory: false 13 | preprocessing: 14 | resize: [480, 640] 15 | augmentation: 16 | photometric: 17 | enable: true 18 | primitives: [ 19 | 'random_brightness', 'random_contrast', 'additive_speckle_noise', 20 | 'additive_gaussian_noise', 'additive_shade', 'motion_blur'] 21 | params: 22 | random_brightness: {max_abs_change: 50} 23 | random_contrast: {strength_range: [0.5, 1.5]} 24 | additive_gaussian_noise: {stddev_range: [0, 10]} 25 | additive_speckle_noise: {prob_range: [0, 0.0035]} 26 | additive_shade: 27 | transparency_range: [-0.5, 0.5] 28 | kernel_size_range: [100, 150] 29 | motion_blur: {max_kernel_size: 3} 30 | homographic: 31 | enable: false # not implemented 32 | warped_pair: 33 | enable: true 34 | params: 35 | translation: true 36 | rotation: true 37 | scaling: true 38 | perspective: true 39 | scaling_amplitude: 0.2 40 | perspective_amplitude_x: 0.2 41 | perspective_amplitude_y: 0.2 42 | patch_ratio: 0.85 43 | max_angle: 1.57 44 | allow_artifacts: true # true 45 | valid_border_margin: 3 46 | 47 | front_end_model: 'Train_model_heatmap' # 'Train_model_frontend' 48 | 49 | training: 50 | workers_train: 4 # 16 51 | workers_val: 2 # 2 52 | 53 | model: 54 | name: 'superpoint_train' 55 | descriptor_length: 128 56 | detector_loss: 57 | loss_type: 'softmax' 58 | 59 | batch_size: 8 # 32 60 | eval_batch_size: 8 # 32 61 | learning_rate: 0.0001 # 0.0001 62 | detection_threshold: 0.015 # 0.015 63 | lambda_loss: 1 # 1 64 | nms: 4 65 | dense_loss: 66 | enable: false 67 | params: 68 | descriptor_dist: 4 # 4, 7.5 69 | lambda_d: 800 # 800 70 | sparse_loss: 71 | enable: true 72 | params: 73 | num_matching_attempts: 1000 74 | num_masked_non_matches_per_match: 100 75 | lamda_d: 1 76 | dist: 'cos' 77 | method: '2d' 78 | other_settings: 'train 2d, gauss 0.2' 79 | # subpixel: 80 | # enable: false 81 | # params: 82 | # subpixel_channel: 2 83 | # settings: 'predict flow directly' 84 | # loss_func: 'subpixel_loss_no_argmax' # subpixel_loss, subpixel_loss_no_argmax 85 | 86 | retrain: True # set true for new model 87 | reset_iter: True # set true to set the iteration number to 0 88 | train_iter: 100000 # 170000 89 | validation_interval: 2000 # 2000 90 | tensorboard_interval: 200 # 200 91 | save_interval: 2000 # 2000 92 | validation_size: 5 93 | 94 | pretrained: 
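These YAML files are plain mappings, so they can be inspected directly with PyYAML; a minimal sketch (not this repo's own loader, and assuming the nesting shown above):

```python
import yaml  # PyYAML

with open("superpoint/configs/superpoint_allss_train_heatmap.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["descriptor_length"])       # 128 in this config
print(cfg["data"]["preprocessing"]["resize"])  # [480, 640]
```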
-------------------------------------------------------------------------------- /superpoint/correspondence_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/correspondence_tools/__init__.py -------------------------------------------------------------------------------- /superpoint/correspondence_tools/correspondence_plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib.image as mpimg 2 | import matplotlib.pyplot as plt 3 | from matplotlib.patches import Circle 4 | 5 | def plot_correspondences(images, uv_a, uv_b, use_previous_plot=None, circ_color='g', show=True): 6 | if use_previous_plot is None: 7 | fig, axes = plt.subplots(nrows=2, ncols=2) 8 | else: 9 | fig, axes = use_previous_plot[0], use_previous_plot[1] 10 | 11 | fig.set_figheight(10) 12 | fig.set_figwidth(15) 13 | pixel_locs = [uv_a, uv_b, uv_a, uv_b] 14 | axes = axes.flat[0:] 15 | if use_previous_plot is not None: 16 | axes = [axes[1], axes[3]] 17 | images = [images[1], images[3]] 18 | pixel_locs = [pixel_locs[1], pixel_locs[3]] 19 | for ax, img, pixel_loc in zip(axes[0:], images, pixel_locs): 20 | ax.set_aspect('equal') 21 | if isinstance(pixel_loc[0], int) or isinstance(pixel_loc[0], float): 22 | circ = Circle(pixel_loc, radius=10, facecolor=circ_color, edgecolor='white', fill=True ,linewidth = 2.0, linestyle='solid') 23 | ax.add_patch(circ) 24 | else: 25 | for x,y in zip(pixel_loc[0],pixel_loc[1]): 26 | circ = Circle((x,y), radius=10, facecolor=circ_color, edgecolor='white', fill=True ,linewidth = 2.0, linestyle='solid') 27 | ax.add_patch(circ) 28 | ax.imshow(img) 29 | if show: 30 | plt.show() 31 | return None 32 | else: 33 | return fig, axes 34 | 35 | def plot_correspondences_from_dir(log_dir, img_a, img_b, uv_a, uv_b, use_previous_plot=None, circ_color='g', show=True): 36 | img1_filename = log_dir+"/images/"+img_a+"_rgb.png" 37 | img2_filename = log_dir+"/images/"+img_b+"_rgb.png" 38 | img1_depth_filename = log_dir+"/images/"+img_a+"_depth.png" 39 | img2_depth_filename = log_dir+"/images/"+img_b+"_depth.png" 40 | images = [img1_filename, img2_filename, img1_depth_filename, img2_depth_filename] 41 | images = [mpimg.imread(x) for x in images] 42 | return plot_correspondences(images, uv_a, uv_b, use_previous_plot=use_previous_plot, circ_color=circ_color, show=show) 43 | 44 | def plot_correspondences_direct(img_a_rgb, img_a_depth, img_b_rgb, img_b_depth, uv_a, uv_b, use_previous_plot=None, circ_color='g', show=True): 45 | """ 46 | 47 | Plots rgb and depth image pair along with circles at pixel locations 48 | :param img_a_rgb: PIL.Image.Image 49 | :param img_a_depth: PIL.Image.Image 50 | :param img_b_rgb: PIL.Image.Image 51 | :param img_b_depth: PIL.Image.Image 52 | :param uv_a: (u,v) pixel location, or list of pixel locations 53 | :param uv_b: (u,v) pixel location, or list of pixel locations 54 | :param use_previous_plot: 55 | :param circ_color: str 56 | :param show: 57 | :return: 58 | """ 59 | images = [img_a_rgb, img_b_rgb, img_a_depth, img_b_depth] 60 | return plot_correspondences(images, uv_a, uv_b, use_previous_plot=use_previous_plot, circ_color=circ_color, show=show) 61 | 62 | -------------------------------------------------------------------------------- /superpoint/loss_functions/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/loss_functions/__init__.py -------------------------------------------------------------------------------- /superpoint/loss_functions/loss_composer.py: -------------------------------------------------------------------------------- 1 | from dense_correspondence.dataset.spartan_dataset_masked import SpartanDataset, SpartanDatasetDataType 2 | from dense_correspondence.loss_functions.pixelwise_contrastive_loss import PixelwiseContrastiveLoss 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | def get_loss(pixelwise_contrastive_loss, match_type, 8 | image_a_pred, image_b_pred, 9 | matches_a, matches_b, 10 | masked_non_matches_a, masked_non_matches_b, 11 | background_non_matches_a, background_non_matches_b, 12 | blind_non_matches_a, blind_non_matches_b): 13 | """ 14 | This function serves the purpose of: 15 | - parsing the different types of SpartanDatasetDataType... 16 | - parsing different types of matches / non matches.. 17 | - into different pixelwise contrastive loss functions 18 | 19 | :return args: loss, match_loss, masked_non_match_loss, \ 20 | background_non_match_loss, blind_non_match_loss 21 | :rtypes: each pytorch Variables 22 | 23 | """ 24 | if (match_type == SpartanDatasetDataType.SINGLE_OBJECT_WITHIN_SCENE).all(): 25 | print "applying SINGLE_OBJECT_WITHIN_SCENE loss" 26 | return get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 27 | matches_a, matches_b, 28 | masked_non_matches_a, masked_non_matches_b, 29 | background_non_matches_a, background_non_matches_b, 30 | blind_non_matches_a, blind_non_matches_b) 31 | 32 | if (match_type == SpartanDatasetDataType.SINGLE_OBJECT_ACROSS_SCENE).all(): 33 | print "applying SINGLE_OBJECT_ACROSS_SCENE loss" 34 | return get_same_object_across_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 35 | blind_non_matches_a, blind_non_matches_b) 36 | 37 | if (match_type == SpartanDatasetDataType.DIFFERENT_OBJECT).all(): 38 | print "applying DIFFERENT_OBJECT loss" 39 | return get_different_object_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 40 | blind_non_matches_a, blind_non_matches_b) 41 | 42 | 43 | if (match_type == SpartanDatasetDataType.MULTI_OBJECT).all(): 44 | print "applying MULTI_OBJECT loss" 45 | return get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 46 | matches_a, matches_b, 47 | masked_non_matches_a, masked_non_matches_b, 48 | background_non_matches_a, background_non_matches_b, 49 | blind_non_matches_a, blind_non_matches_b) 50 | 51 | if (match_type == SpartanDatasetDataType.SYNTHETIC_MULTI_OBJECT).all(): 52 | print "applying SYNTHETIC_MULTI_OBJECT loss" 53 | return get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 54 | matches_a, matches_b, 55 | masked_non_matches_a, masked_non_matches_b, 56 | background_non_matches_a, background_non_matches_b, 57 | blind_non_matches_a, blind_non_matches_b) 58 | 59 | else: 60 | raise ValueError("Should only have above scenes?") 61 | 62 | 63 | def get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 64 | matches_a, matches_b, 65 | masked_non_matches_a, masked_non_matches_b, 66 | background_non_matches_a, background_non_matches_b, 67 | blind_non_matches_a, blind_non_matches_b): 68 | """ 69 | Simple wrapper for pixelwise_contrastive_loss functions. 
Args and return args documented above in get_loss() 70 | """ 71 | pcl = pixelwise_contrastive_loss 72 | 73 | match_loss, masked_non_match_loss, num_masked_hard_negatives =\ 74 | pixelwise_contrastive_loss.get_loss_matched_and_non_matched_with_l2(image_a_pred, image_b_pred, 75 | matches_a, matches_b, 76 | masked_non_matches_a, masked_non_matches_b, 77 | M_descriptor=pcl._config["M_masked"]) 78 | 79 | if pcl._config["use_l2_pixel_loss_on_background_non_matches"]: 80 | background_non_match_loss, num_background_hard_negatives =\ 81 | pixelwise_contrastive_loss.non_match_loss_with_l2_pixel_norm(image_a_pred, image_b_pred, matches_b, 82 | background_non_matches_a, background_non_matches_b, M_descriptor=pcl._config["M_background"]) 83 | 84 | else: 85 | background_non_match_loss, num_background_hard_negatives =\ 86 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 87 | background_non_matches_a, background_non_matches_b, 88 | M_descriptor=pcl._config["M_background"]) 89 | 90 | 91 | 92 | blind_non_match_loss = zero_loss() 93 | num_blind_hard_negatives = 1 94 | if not (SpartanDataset.is_empty(blind_non_matches_a.data)): 95 | blind_non_match_loss, num_blind_hard_negatives =\ 96 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 97 | blind_non_matches_a, blind_non_matches_b, 98 | M_descriptor=pcl._config["M_masked"]) 99 | 100 | 101 | 102 | total_num_hard_negatives = num_masked_hard_negatives + num_background_hard_negatives 103 | total_num_hard_negatives = max(total_num_hard_negatives, 1) 104 | 105 | if pcl._config["scale_by_hard_negatives"]: 106 | scale_factor = total_num_hard_negatives 107 | 108 | masked_non_match_loss_scaled = masked_non_match_loss*1.0/max(num_masked_hard_negatives, 1) 109 | 110 | background_non_match_loss_scaled = background_non_match_loss*1.0/max(num_background_hard_negatives, 1) 111 | 112 | blind_non_match_loss_scaled = blind_non_match_loss*1.0/max(num_blind_hard_negatives, 1) 113 | else: 114 | # we are not currently using blind non-matches 115 | num_masked_non_matches = max(len(masked_non_matches_a),1) 116 | num_background_non_matches = max(len(background_non_matches_a),1) 117 | num_blind_non_matches = max(len(blind_non_matches_a),1) 118 | scale_factor = num_masked_non_matches + num_background_non_matches 119 | 120 | 121 | masked_non_match_loss_scaled = masked_non_match_loss*1.0/num_masked_non_matches 122 | 123 | background_non_match_loss_scaled = background_non_match_loss*1.0/num_background_non_matches 124 | 125 | blind_non_match_loss_scaled = blind_non_match_loss*1.0/num_blind_non_matches 126 | 127 | 128 | 129 | non_match_loss = 1.0/scale_factor * (masked_non_match_loss + background_non_match_loss) 130 | 131 | loss = pcl._config["match_loss_weight"] * match_loss + \ 132 | pcl._config["non_match_loss_weight"] * non_match_loss 133 | 134 | 135 | 136 | return loss, match_loss, masked_non_match_loss_scaled, background_non_match_loss_scaled, blind_non_match_loss_scaled 137 | 138 | def get_within_scene_loss_triplet(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 139 | matches_a, matches_b, 140 | masked_non_matches_a, masked_non_matches_b, 141 | background_non_matches_a, background_non_matches_b, 142 | blind_non_matches_a, blind_non_matches_b): 143 | """ 144 | Simple wrapper for pixelwise_contrastive_loss functions. 
Args and return args documented above in get_loss() 145 | """ 146 | 147 | pcl = pixelwise_contrastive_loss 148 | 149 | masked_triplet_loss =\ 150 | pixelwise_contrastive_loss.get_triplet_loss(image_a_pred, image_b_pred, matches_a, 151 | matches_b, masked_non_matches_a, masked_non_matches_b, pcl._config["alpha_triplet"]) 152 | 153 | background_triplet_loss =\ 154 | pixelwise_contrastive_loss.get_triplet_loss(image_a_pred, image_b_pred, matches_a, 155 | matches_b, background_non_matches_a, background_non_matches_b, pcl._config["alpha_triplet"]) 156 | 157 | total_loss = masked_triplet_loss + background_triplet_loss 158 | 159 | return total_loss, zero_loss(), zero_loss(), zero_loss(), zero_loss() 160 | 161 | def get_different_object_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 162 | blind_non_matches_a, blind_non_matches_b): 163 | """ 164 | Simple wrapper for pixelwise_contrastive_loss functions. Args and return args documented above in get_loss() 165 | """ 166 | 167 | scale_by_hard_negatives = pixelwise_contrastive_loss.config["scale_by_hard_negatives_DIFFERENT_OBJECT"] 168 | blind_non_match_loss = zero_loss() 169 | if not (SpartanDataset.is_empty(blind_non_matches_a.data)): 170 | M_descriptor = pixelwise_contrastive_loss.config["M_background"] 171 | 172 | blind_non_match_loss, num_hard_negatives =\ 173 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 174 | blind_non_matches_a, blind_non_matches_b, 175 | M_descriptor=M_descriptor) 176 | 177 | if scale_by_hard_negatives: 178 | scale_factor = max(num_hard_negatives, 1) 179 | else: 180 | scale_factor = max(len(blind_non_matches_a), 1) 181 | 182 | blind_non_match_loss = 1.0/scale_factor * blind_non_match_loss 183 | loss = blind_non_match_loss 184 | return loss, zero_loss(), zero_loss(), zero_loss(), blind_non_match_loss 185 | 186 | def get_same_object_across_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 187 | blind_non_matches_a, blind_non_matches_b): 188 | """ 189 | Simple wrapper for pixelwise_contrastive_loss functions. 
Args and return args documented above in get_loss() 190 | """ 191 | blind_non_match_loss = zero_loss() 192 | if not (SpartanDataset.is_empty(blind_non_matches_a.data)): 193 | blind_non_match_loss, num_hard_negatives =\ 194 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 195 | blind_non_matches_a, blind_non_matches_b, 196 | M_descriptor=pcl._config["M_masked"], invert=True) 197 | 198 | if pixelwise_contrastive_loss._config["scale_by_hard_negatives"]: 199 | scale_factor = max(num_hard_negatives, 1) 200 | else: 201 | scale_factor = max(len(blind_non_matches_a), 1) 202 | 203 | loss = 1.0/scale_factor * blind_non_match_loss 204 | blind_non_match_loss_scaled = 1.0/scale_factor * blind_non_match_loss 205 | return loss, zero_loss(), zero_loss(), zero_loss(), blind_non_match_loss 206 | 207 | def zero_loss(): 208 | return Variable(torch.FloatTensor([0]).cuda()) 209 | 210 | def is_zero_loss(loss): 211 | return loss.data[0] < 1e-20 212 | 213 | 214 | -------------------------------------------------------------------------------- /superpoint/loss_functions/sparse_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.utils import filter_points 4 | from utils.utils import crop_or_pad_choice 5 | from utils.utils import normPts 6 | from utils.homographies import scale_homography_torch 7 | import superpoint.correspondence_tools.correspondence_finder as correspondence_finder 8 | from superpoint.loss_functions.pixelwise_contrastive_loss import PixelwiseContrastiveLoss 9 | 10 | def get_coor_cells(Hc, Wc, cell_size, device='cpu', uv=False): 11 | coor_cells = torch.stack(torch.meshgrid(torch.arange(Hc), torch.arange(Wc)), dim=2) 12 | coor_cells = coor_cells.type(torch.FloatTensor).to(device) 13 | coor_cells = coor_cells.view(-1, 2) 14 | # change vu to uv 15 | if uv: 16 | coor_cells = torch.stack((coor_cells[:,1], coor_cells[:,0]), dim=1) # (y, x) to (x, y) 17 | 18 | return coor_cells.to(device) 19 | 20 | def warp_coor_cells_with_homographies(coor_cells, homographies, uv=False, device='cpu'): 21 | from utils.utils import warp_points 22 | warped_coor_cells = coor_cells 23 | if uv == False: 24 | warped_coor_cells = torch.stack((warped_coor_cells[:,1], warped_coor_cells[:,0]), dim=1) # (y, x) to (x, y) 25 | 26 | warped_coor_cells = warp_points(warped_coor_cells, homographies, device) 27 | 28 | if uv == False: 29 | warped_coor_cells = torch.stack((warped_coor_cells[:, :, 1], warped_coor_cells[:, :, 0]), dim=2) # (batch, x, y) to (batch, y, x) 30 | return warped_coor_cells 31 | 32 | def create_non_matches(uv_a, uv_b_non_matches, multiplier): 33 | """ 34 | Simple wrapper for repeated code 35 | :param uv_a: 36 | :type uv_a: 37 | :param uv_b_non_matches: 38 | :type uv_b_non_matches: 39 | :param multiplier: 40 | :type multiplier: 41 | :return: 42 | :rtype: 43 | """ 44 | uv_a_long = (torch.t(uv_a[0].repeat(multiplier, 1)).contiguous().view(-1, 1), 45 | torch.t(uv_a[1].repeat(multiplier, 1)).contiguous().view(-1, 1)) 46 | 47 | uv_b_non_matches_long = (uv_b_non_matches[0].view(-1, 1), uv_b_non_matches[1].view(-1, 1)) 48 | 49 | return uv_a_long, uv_b_non_matches_long 50 | 51 | def uv_to_tuple(uv): 52 | return (uv[:, 0], uv[:, 1]) 53 | 54 | def tuple_to_uv(uv_tuple): 55 | return torch.stack([uv_tuple[0], uv_tuple[1]]) 56 | 57 | def tuple_to_1d(uv_tuple, W, uv=True): 58 | if uv: 59 | return uv_tuple[0] + uv_tuple[1]*W 60 | else: 61 | return uv_tuple[0]*W + uv_tuple[1] 62 | 63 | def uv_to_1d(points, 
W, uv=True): 64 | if uv: 65 | return points[..., 0] + points[..., 1]*W 66 | else: 67 | return points[..., 0]*W + points[..., 1] 68 | 69 | ## calculate matches loss 70 | def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b, dist='cos', method='1d'): 71 | match_loss, matches_a_descriptors, matches_b_descriptors = \ 72 | PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred, 73 | matches_a, matches_b, dist=dist, method=method) 74 | return match_loss 75 | 76 | def get_non_matches_corr(img_b_shape, uv_a, uv_b_matches, num_masked_non_matches_per_match=10, device='cpu'): 77 | ## sample non matches 78 | uv_b_matches = uv_b_matches.squeeze() 79 | uv_b_matches_tuple = uv_to_tuple(uv_b_matches) 80 | uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(uv_b_matches_tuple, 81 | img_b_shape, num_non_matches_per_match=num_masked_non_matches_per_match, 82 | img_b_mask=None) 83 | 84 | uv_a_tuple, uv_b_non_matches_tuple = \ 85 | create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match) 86 | return uv_a_tuple, uv_b_non_matches_tuple 87 | 88 | def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a, non_matches_b, dist='cos'): 89 | ## non matches loss 90 | non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \ 91 | PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred, 92 | non_matches_a.long().squeeze(), 93 | non_matches_b.long().squeeze(), 94 | M=0.2, invert=True, dist=dist) 95 | non_match_loss = non_match_loss.sum()/(num_hard_negatives + 1) 96 | return non_match_loss 97 | 98 | def descriptor_loss_sparse(descriptors, descriptors_warped, homographies, mask_valid=None, 99 | cell_size=8, device='cpu', descriptor_dist=4, lamda_d=250, 100 | num_matching_attempts=1000, num_masked_non_matches_per_match=10, 101 | dist='cos', method='1d', **config): 102 | """ 103 | consider batches of descriptors 104 | :param descriptors: 105 | Output from descriptor head 106 | tensor [descriptors, Hc, Wc] 107 | :param descriptors_warped: 108 | Output from descriptor head of warped image 109 | tensor [descriptors, Hc, Wc] 110 | """ 111 | Hc, Wc = descriptors.shape[1], descriptors.shape[2] 112 | img_shape = (Hc, Wc) 113 | 114 | image_a_pred = descriptors.view(-1, Hc * Wc).transpose(0, 1).unsqueeze(0) # torch [1, H*W, D] 115 | image_b_pred = descriptors_warped.view(-1, Hc * Wc).transpose(0, 1).unsqueeze(0) # torch [1, H*W, D] 116 | 117 | # matches 118 | uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')#[1200,2] 119 | homographies_H = scale_homography_torch(homographies, img_shape, shift=(-1, -1)) 120 | uv_b_matches = warp_coor_cells_with_homographies(uv_a, homographies_H.to('cpu'), uv=True, device='cpu') 121 | uv_b_matches.round_() 122 | uv_b_matches = uv_b_matches.squeeze(0) 123 | uv_b_matches, mask = filter_points(uv_b_matches, torch.tensor([Wc, Hc]).to(device='cpu'), return_mask=True) 124 | uv_a = uv_a[mask] 125 | 126 | # crop to the same length 127 | shuffle = True 128 | if not shuffle: print("shuffle: ", shuffle) 129 | choice = crop_or_pad_choice(uv_b_matches.shape[0], num_matching_attempts, shuffle=shuffle) 130 | choice = list(torch.tensor(choice)) 131 | uv_a = uv_a[choice] 132 | uv_b_matches = uv_b_matches[choice] 133 | 134 | if method == '2d': 135 | matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float()) # [u, v] 136 | matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float()) 137 | else: 138 | matches_a = uv_to_1d(uv_a, Wc) 139 | matches_b = 
uv_to_1d(uv_b_matches, Wc) 140 | 141 | if method == '2d': 142 | match_loss = get_match_loss(descriptors, descriptors_warped, matches_a.to(device), 143 | matches_b.to(device), dist=dist, method='2d') 144 | else: 145 | match_loss = get_match_loss(image_a_pred, image_b_pred, 146 | matches_a.long().to(device), matches_b.long().to(device), dist=dist) 147 | 148 | uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(img_shape, 149 | uv_a, uv_b_matches, 150 | num_masked_non_matches_per_match=num_masked_non_matches_per_match) 151 | 152 | non_matches_a = tuple_to_1d(uv_a_tuple, Wc) 153 | non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc) 154 | non_match_loss = get_non_match_loss(image_a_pred, image_b_pred, non_matches_a.to(device), 155 | non_matches_b.to(device), dist=dist) 156 | 157 | loss = lamda_d * match_loss + non_match_loss 158 | return loss, lamda_d * match_loss, non_match_loss 159 | pass 160 | 161 | def batch_descriptor_loss_sparse(descriptors, descriptors_warped, homographies, **options): 162 | loss = [] 163 | pos_loss = [] 164 | neg_loss = [] 165 | batch_size = descriptors.shape[0] 166 | for i in range(batch_size): 167 | losses = descriptor_loss_sparse(descriptors[i], descriptors_warped[i], 168 | # torch.tensor(homographies[i], dtype=torch.float32), **options) 169 | homographies[i].type(torch.float32), **options) 170 | loss.append(losses[0]) 171 | pos_loss.append(losses[1]) 172 | neg_loss.append(losses[2]) 173 | loss, pos_loss, neg_loss = torch.stack(loss), torch.stack(pos_loss), torch.stack(neg_loss) 174 | return loss.mean(), None, pos_loss.mean(), neg_loss.mean() 175 | 176 | if __name__ == '__main__': 177 | # config 178 | H, W = 240, 320 179 | cell_size = 8 180 | Hc, Wc = H // cell_size, W // cell_size 181 | 182 | D = 3 183 | torch.manual_seed(0) 184 | np.random.seed(0) 185 | 186 | batch_size = 2 187 | device = 'cpu' 188 | method = '2d' 189 | 190 | num_matching_attempts = 1000 191 | num_masked_non_matches_per_match = 200 192 | lamda_d = 1 193 | 194 | homographies = np.identity(3)[np.newaxis, :, :] 195 | homographies = np.tile(homographies, [batch_size, 1, 1]) 196 | 197 | def randomDescriptor(): 198 | descriptors = torch.tensor(np.random.rand(2, D, Hc, Wc)-0.5, dtype=torch.float32) 199 | dn = torch.norm(descriptors, p=2, dim=1) # Compute the norm. 200 | descriptors = descriptors.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize. 
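            # Unit-norm random descriptors are enough to smoke-test the sparse
            # descriptor loss: when the same tensor is passed as both inputs of
            # batch_descriptor_loss_sparse below, the positive (match) term
            # should be close to zero, which is what the final print checks.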
201 | return descriptors 202 | 203 | descriptors = randomDescriptor() 204 | print("descriptors: ", descriptors.shape) 205 | descriptors_warped = randomDescriptor() 206 | descriptor_loss = descriptor_loss_sparse(descriptors[0], descriptors_warped[0], 207 | torch.tensor(homographies[0], dtype=torch.float32), 208 | method=method) 209 | 210 | print("descriptor_loss: ", descriptor_loss) 211 | 212 | loss = batch_descriptor_loss_sparse(descriptors, descriptors, 213 | torch.tensor(homographies, dtype=torch.float32), 214 | num_matching_attempts = num_matching_attempts, 215 | num_masked_non_matches_per_match = num_masked_non_matches_per_match, 216 | device=device, 217 | lamda_d = lamda_d, 218 | method=method) 219 | print("same descriptor_loss (pos should be 0): ", loss) 220 | 221 | -------------------------------------------------------------------------------- /superpoint/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/models/__init__.py -------------------------------------------------------------------------------- /superpoint/models/model_utils.py: -------------------------------------------------------------------------------- 1 | """ class to process superpoint net 2 | # may be some duplication with model_wrap.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from utils.var_dim import toNumpy 9 | from utils.utils import getPtsFromHeatmap 10 | from utils.utils import crop_or_pad_choice 11 | from utils.losses import norm_patches 12 | from utils.losses import extract_patches 13 | from utils.losses import soft_argmax_2d 14 | 15 | class DepthToSpace(nn.Module): 16 | def __init__(self, block_size): 17 | super(DepthToSpace, self).__init__() 18 | self.block_size = block_size 19 | self.block_size_sq = block_size*block_size 20 | 21 | def forward(self, input): 22 | output = input.permute(0, 2, 3, 1)#(batch,64,15,20)->(batch,15,20,64) 23 | (batch_size, d_height, d_width, d_depth) = output.size() 24 | s_depth = int(d_depth / self.block_size_sq) 25 | s_width = int(d_width * self.block_size) 26 | s_height = int(d_height * self.block_size) 27 | t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)#(batch,15,20,64,1) 28 | spl = t_1.split(self.block_size, 3)#turple:8,(batch,15,20,8,1) 29 | stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]#list:8,(batch,15,160,1) 30 | output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)#(batch,120,160,1) 31 | output = output.permute(0, 3, 1, 2)#(batch,1,120,160) 32 | return output 33 | 34 | class SpaceToDepth(nn.Module): 35 | def __init__(self, block_size): 36 | super(SpaceToDepth, self).__init__() 37 | self.block_size = block_size 38 | self.block_size_sq = block_size*block_size 39 | 40 | def forward(self, input): 41 | output = input.permute(0, 2, 3, 1) 42 | (batch_size, s_height, s_width, s_depth) = output.size() 43 | d_depth = s_depth * self.block_size_sq 44 | d_width = int(s_width / self.block_size) 45 | d_height = int(s_height / self.block_size) 46 | t_1 = output.split(self.block_size, 2)#turple:20,(batch,120,8,1) 47 | stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]#list:20,(batch,15,64) 48 | output = torch.stack(stack, 1)#(batch,20,15,64) 49 | output = output.permute(0, 2, 1, 3)#(batch,15,20,64) 50 | output = output.permute(0, 3, 1, 
2)#(batch,64,15,20) 51 | return output 52 | 53 | def flattenDetection(semi, tensor=False): 54 | ''' 55 | Flatten detection output 56 | 57 | :param semi: 58 | output from detector head 59 | tensor [65, Hc, Wc] 60 | :or 61 | tensor (batch_size, 65, Hc, Wc) 62 | 63 | :return: 64 | 3D heatmap 65 | np (1, H, C) 66 | :or 67 | tensor (batch_size, 65, Hc, Wc) 68 | 69 | ''' 70 | batch = False 71 | if len(semi.shape) == 4: 72 | batch = True 73 | batch_size = semi.shape[0] 74 | 75 | if batch: 76 | dense = nn.functional.softmax(semi, dim=1) # [batch, 65, Hc, Wc] 77 | # Remove dustbin. 78 | nodust = dense[:, :-1, :, :] 79 | else: 80 | dense = nn.functional.softmax(semi, dim=0) # [65, Hc, Wc] 81 | nodust = dense[:-1, :, :].unsqueeze(0) 82 | depth2space = DepthToSpace(8) 83 | heatmap = depth2space(nodust) 84 | heatmap = heatmap.squeeze(0) if not batch else heatmap 85 | return heatmap 86 | 87 | class SuperPointNet_process(object): 88 | 89 | def __init__(self, **config): 90 | # N=500, patch_size=5, device='cuda:0' 91 | self.out_num_points = config.get('out_num_points', 500) 92 | self.patch_size = config.get('patch_size', 5) 93 | self.device = config.get('device', 'cuda:0') 94 | self.nms_dist = config.get('nms_dist', 4) 95 | self.conf_thresh = config.get('conf_thresh', 0.015) 96 | self.heatmap = None 97 | self.heatmap_nms_batch = None 98 | pass 99 | 100 | # @staticmethod 101 | def pred_soft_argmax(self, labels_2D, heatmap): 102 | """ 103 | 104 | return: 105 | dict {'loss': mean of difference btw pred and res} 106 | """ 107 | patch_size=self.patch_size 108 | device=self.device 109 | 110 | 111 | outs = {} 112 | # extract patches 113 | label_idx = labels_2D[...].nonzero() 114 | # patch_size = self.config['params']['patch_size'] 115 | patches = extract_patches(label_idx.to(device), heatmap.to(device), 116 | patch_size=patch_size) 117 | # norm patches 118 | # patches = norm_patches(patches) 119 | 120 | # predict offsets 121 | from utils.losses import do_log 122 | patches_log = do_log(patches) 123 | # soft_argmax 124 | dxdy = soft_argmax_2d(patches_log, normalized_coordinates=False) # tensor [B, N, patch, patch] 125 | dxdy = dxdy.squeeze(1) # tensor [N, 2] 126 | dxdy = dxdy-patch_size//2 127 | 128 | # loss 129 | outs['pred'] = dxdy 130 | # ls = lambda x, y: dxdy.cpu() - points_res.cpu() 131 | outs['patches'] = patches 132 | return outs 133 | 134 | # torch 135 | @staticmethod 136 | def sample_desc_from_points(coarse_desc, pts, cell_size=8): 137 | """ 138 | inputs: 139 | coarse_desc: tensor [1, 256, Hc, Wc] 140 | pts: tensor [N, 2] (should be the same device as desc) 141 | return: 142 | desc: tensor [1, N, D] 143 | """ 144 | # --- Process descriptor. 145 | samp_pts = pts.transpose(0,1) 146 | H, W = coarse_desc.shape[2]*cell_size, coarse_desc.shape[3]*cell_size 147 | D = coarse_desc.shape[1] 148 | if pts.shape[1] == 0: 149 | # desc = torch.zeros((D, 0)) 150 | desc = torch.ones((1, 1, D)) 151 | else: 152 | # Interpolate into descriptor map using 2D point locations. 153 | # samp_pts = torch.from_numpy(pts[:2, :].copy()) 154 | samp_pts[0, :] = (samp_pts[0, :] / (float(W) / 2.)) - 1. 155 | samp_pts[1, :] = (samp_pts[1, :] / (float(H) / 2.)) - 1. 
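        # grid_sample expects sampling locations in normalized [-1, 1] image
        # coordinates, so the pixel coordinates are rescaled by W/2 and H/2
        # above before being reshaped into the (1, 1, N, 2) sampling grid below.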
156 | samp_pts = samp_pts.transpose(0, 1).contiguous() 157 | samp_pts = samp_pts.view(1, 1, -1, 2) 158 | samp_pts = samp_pts.float() 159 | # samp_pts = samp_pts.to(self.device) 160 | desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, align_corners=True) # tensor [batch_size(1), D, 1, N] 161 | # desc = desc.data.cpu().numpy().reshape(D, -1) 162 | # desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :] 163 | desc = desc.squeeze().transpose(0,1).unsqueeze(0) 164 | return desc 165 | 166 | # extract residual 167 | @staticmethod 168 | def ext_from_points(labels_res, points): 169 | """ 170 | input: 171 | labels_res: tensor [batch, channel, H, W] 172 | points: tensor [N, 4(pos0(batch), pos1(0), pos2(H), pos3(W) )] 173 | return: 174 | tensor [N, channel] 175 | """ 176 | labels_res = labels_res.transpose(1,2).transpose(2,3).unsqueeze(1) 177 | points_res = labels_res[points[:,0],points[:,1],points[:,2],points[:,3],:] # tensor [N, 2] 178 | return points_res 179 | 180 | # points_res = ext_from_points(labels_res, label_idx) 181 | 182 | @staticmethod 183 | def soft_argmax_2d(patches): 184 | """ 185 | params: 186 | patches: (B, N, H, W) 187 | return: 188 | coor: (B, N, 2) (x, y) 189 | 190 | """ 191 | import torchgeometry as tgm 192 | m = tgm.contrib.SpatialSoftArgmax2d() 193 | coords = m(patches) # 1x4x2 194 | return coords 195 | 196 | 197 | def heatmap_to_nms(self, heatmap, tensor=False, boxnms=False): 198 | """ 199 | return: 200 | heatmap_nms_batch: np [batch, 1, H, W] 201 | """ 202 | to_floatTensor = lambda x: torch.from_numpy(x).type(torch.FloatTensor) 203 | heatmap_np = toNumpy(heatmap) 204 | ## heatmap_nms 205 | if boxnms: 206 | from utils.utils import box_nms 207 | heatmap_nms_batch = [box_nms(h.detach().squeeze(), self.nms_dist, min_prob=self.conf_thresh) \ 208 | for h in heatmap] # [batch, H, W] 209 | heatmap_nms_batch = torch.stack(heatmap_nms_batch, dim=0).unsqueeze(1) 210 | # print('heatmap_nms_batch: ', heatmap_nms_batch.shape) 211 | else: 212 | heatmap_nms_batch = [self.heatmap_nms(h, self.nms_dist, self.conf_thresh) \ 213 | for h in heatmap_np] # [batch, H, W] 214 | heatmap_nms_batch = np.stack(heatmap_nms_batch, axis=0) 215 | heatmap_nms_batch = heatmap_nms_batch[:,np.newaxis,...] 
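        # np.newaxis restores the channel dimension so that both branches
        # return the [batch, 1, H, W] layout described in the docstring above.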
216 | if tensor: 217 | heatmap_nms_batch = to_floatTensor(heatmap_nms_batch) 218 | heatmap_nms_batch = heatmap_nms_batch.to(self.device) 219 | self.heatmap = heatmap 220 | self.heatmap_nms_batch = heatmap_nms_batch 221 | return heatmap_nms_batch 222 | pass 223 | 224 | 225 | @staticmethod 226 | def heatmap_nms(heatmap, nms_dist=4, conf_thresh=0.015): 227 | """ 228 | input: 229 | heatmap: np [(1), H, W] 230 | """ 231 | # nms_dist = self.config['model']['nms'] 232 | # conf_thresh = self.config['model']['detection_threshold'] 233 | heatmap = heatmap.squeeze() 234 | boxnms = False 235 | # print("heatmap: ", heatmap.shape) 236 | pts_nms = getPtsFromHeatmap(heatmap, conf_thresh, nms_dist) 237 | 238 | semi_thd_nms_sample = np.zeros_like(heatmap) 239 | semi_thd_nms_sample[pts_nms[1, :].astype(np.int), pts_nms[0, :].astype(np.int)] = 1 240 | 241 | 242 | return semi_thd_nms_sample 243 | 244 | 245 | def batch_extract_features(self, desc, heatmap_nms_batch, residual): 246 | # extract pts, residuals for pts, descriptors 247 | """ 248 | return: -- type: tensorFloat 249 | pts: tensor [batch, N, 2] (no grad) (x, y) 250 | pts_offset: tensor [batch, N, 2] (grad) (x, y) 251 | pts_desc: tensor [batch, N, 256] (grad) 252 | """ 253 | batch_size = heatmap_nms_batch.shape[0] 254 | 255 | pts_int, pts_offset, pts_desc = [], [], [] 256 | pts_idx = heatmap_nms_batch[...].nonzero() # [N, 4(batch, 0, y, x)] 257 | for i in range(batch_size): 258 | mask_b = (pts_idx[:,0] == i) # first column == batch 259 | pts_int_b = pts_idx[mask_b][:,2:].float() # default floatTensor 260 | pts_int_b = pts_int_b[:, [1, 0]] # tensor [N, 2(x,y)] 261 | res_b = residual[mask_b] 262 | # print("res_b: ", res_b.shape) 263 | # print("pts_int_b: ", pts_int_b.shape) 264 | pts_b = pts_int_b + res_b # .no_grad() 265 | # extract desc 266 | pts_desc_b = self.sample_desc_from_points(desc[i].unsqueeze(0), pts_b).squeeze(0) 267 | # print("pts_desc_b: ", pts_desc_b.shape) 268 | # get random shuffle 269 | choice = crop_or_pad_choice(pts_int_b.shape[0], out_num_points=self.out_num_points, shuffle=True) 270 | choice = torch.tensor(choice).tolist() 271 | pts_int.append(pts_int_b[choice]) 272 | pts_offset.append(res_b[choice]) 273 | pts_desc.append(pts_desc_b[choice]) 274 | 275 | pts_int = torch.stack((pts_int), dim=0) 276 | pts_offset = torch.stack((pts_offset), dim=0) 277 | pts_desc = torch.stack((pts_desc), dim=0) 278 | return {'pts_int': pts_int, 'pts_offset': pts_offset, 'pts_desc': pts_desc} 279 | 280 | -------------------------------------------------------------------------------- /superpoint/models/superpoint_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import xavier_uniform_, zeros_ 4 | from superpoint.models.unet_parts import * 5 | import numpy as np 6 | 7 | def simple_nms(scores, nms_radius: int): 8 | """ Fast Non-maximum suppression to remove nearby points """ 9 | assert(nms_radius >= 0) 10 | 11 | def max_pool(x): 12 | return torch.nn.functional.max_pool2d( 13 | x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) 14 | 15 | zeros = torch.zeros_like(scores) 16 | max_mask = scores == max_pool(scores) 17 | for _ in range(2): 18 | supp_mask = max_pool(max_mask.float()) > 0 19 | supp_scores = torch.where(supp_mask, zeros, scores) 20 | new_max_mask = supp_scores == max_pool(supp_scores) 21 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 22 | return torch.where(max_mask, scores, zeros) 23 | 24 | 25 | def remove_borders(keypoints, 
scores, border: int, height: int, width: int): 26 | """ Removes keypoints too close to the border """ 27 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 28 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 29 | mask = mask_h & mask_w 30 | return keypoints[mask], scores[mask] 31 | 32 | 33 | def top_k_keypoints(keypoints, scores, k: int): 34 | if k >= len(keypoints): 35 | return keypoints, scores 36 | scores, indices = torch.topk(scores, k, dim=0) 37 | return keypoints[indices], scores 38 | 39 | 40 | def sample_descriptors(keypoints, descriptors, s: int = 8): 41 | """ Interpolate descriptors at keypoint locations """ 42 | b, c, h, w = descriptors.shape 43 | keypoints = keypoints - s / 2 + 0.5 44 | keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], 45 | ).to(keypoints)[None] 46 | keypoints = keypoints*2 - 1 # normalize to (-1, 1) 47 | args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} 48 | descriptors = torch.nn.functional.grid_sample( 49 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 50 | descriptors = torch.nn.functional.normalize( 51 | descriptors.reshape(b, c, -1), p=2, dim=1) 52 | return descriptors 53 | 54 | # from models.SubpixelNet import SubpixelNet 55 | class SuperPoint(torch.nn.Module): 56 | """ Pytorch definition of SuperPoint Network. """ 57 | default_config = { 58 | 'descriptor_dim': 256, 59 | 'nms_radius': 4, 60 | 'keypoint_threshold': 0.005, 61 | 'max_keypoints': -1, 62 | 'remove_borders': 4, 63 | } 64 | def __init__(self, config): 65 | super(SuperPoint, self).__init__() 66 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 67 | self.config = {**self.default_config, **config} 68 | det_h = 65 69 | d1=self.config['descriptor_dim'] 70 | self.inc = inconv(1, c1) 71 | self.down1 = down(c1, c2) 72 | self.down2 = down(c2, c3) 73 | self.down3 = down(c3, c4) 74 | self.relu = torch.nn.ReLU(inplace=True) 75 | # Detector Head. 76 | self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 77 | self.bnPa = nn.BatchNorm2d(c5) 78 | self.convPb = torch.nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0) 79 | self.bnPb = nn.BatchNorm2d(det_h) 80 | # Descriptor Head. 81 | self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 82 | self.bnDa = nn.BatchNorm2d(c5) 83 | self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0) 84 | self.bnDb = nn.BatchNorm2d(d1) 85 | self.output = None 86 | 87 | if self.config['weights']: 88 | checkpoints=torch.load(self.config['weights']) 89 | pretrained_dict = checkpoints['model_state_dict'] 90 | #多卡训练单卡加载,带module时 91 | from collections import OrderedDict 92 | new_state_dict = OrderedDict() 93 | for k, v in pretrained_dict.items(): 94 | if "module" in k: 95 | name = k[7:] ## 多gpu 训练带moudule默认参数名字,预训练删除 96 | new_state_dict[name] = v 97 | else: 98 | new_state_dict[k] = v 99 | self.load_state_dict(new_state_dict) 100 | print("Loaded SuperPoint model") 101 | 102 | 103 | def forward(self, x): 104 | """ Forward pass that jointly computes unprocessed point and descriptor 105 | tensors. 106 | Input 107 | x: Image pytorch tensor shaped N x 1 x patch_size x patch_size. 108 | Output 109 | semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8. 110 | desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8. 
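        Note: unlike the training-time network in superpoint_train.py, this
        test-time forward additionally post-processes semi/desc into per-image
        lists of 'keypoints', 'scores' and 'descriptors' (see the dict
        returned at the end of this method).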
111 | """ 112 | # Let's stick to this version: first BN, then relu 113 | x1 = self.inc(x)#(batch,64,120,160) 114 | x2 = self.down1(x1)#(batch,64,60,80) 115 | x3 = self.down2(x2)#(batch,128,30,40) 116 | x4 = self.down3(x3)#(batch,128,15,30) 117 | 118 | # Detector Head. 119 | cPa = self.relu(self.bnPa(self.convPa(x4)))#(batch,256,15,30) 120 | semi = self.bnPb(self.convPb(cPa))#(batch,65,15,30) 121 | # Descriptor Head. 122 | cDa = self.relu(self.bnDa(self.convDa(x4)))#(batch,256,15,30) 123 | desc = self.bnDb(self.convDb(cDa))#(batch,256,15,30) 124 | 125 | dn = torch.norm(desc, p=2, dim=1) # Compute the norm:(batch,15,30) 126 | desc = desc.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize:(batch,256,15,30) 127 | 128 | scores = torch.nn.functional.softmax(semi, 1)[:, :-1]#(1,64,64,64) 129 | b, _, h, w = scores.shape 130 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 131 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) 132 | scores = simple_nms(scores, self.config['nms_radius']) 133 | 134 | # Extract keypoints 135 | keypoints = [ 136 | torch.nonzero(s > self.config['keypoint_threshold']) 137 | for s in scores] 138 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 139 | 140 | # Discard keypoints near the image borders 141 | keypoints, scores = list(zip(*[ 142 | remove_borders(k, s, self.config['remove_borders'], h*8, w*8) 143 | for k, s in zip(keypoints, scores)])) 144 | 145 | # Keep the k keypoints with highest score 146 | if self.config['max_keypoints'] >= 0: 147 | keypoints, scores = list(zip(*[ 148 | top_k_keypoints(k, s, self.config['max_keypoints']) 149 | for k, s in zip(keypoints, scores)])) 150 | # Convert (h, w) to (x, y) 151 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 152 | 153 | # Extract descriptors 154 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 155 | for k, d in zip(keypoints, desc)] 156 | 157 | return { 158 | 'keypoints': keypoints, 159 | 'scores': scores, 160 | 'descriptors': descriptors, 161 | } 162 | 163 | 164 | if __name__ == "__main__": 165 | weights_path="superpoint/models/weights/superPointNet_120000.pth.tar" 166 | input=torch.randn(3,1,240,320) 167 | net=SuperPointNet_gauss2() 168 | checkpoint = torch.load(weights_path,map_location=lambda storage, loc: storage) 169 | net.load_state_dict(checkpoint['model_state_dict']) 170 | net.eval() 171 | output=net(input) 172 | 173 | 174 | -------------------------------------------------------------------------------- /superpoint/models/superpoint_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import xavier_uniform_, zeros_ 4 | from superpoint.models.unet_parts import * 5 | from superpoint.models.model_utils import flattenDetection 6 | import numpy as np 7 | 8 | class SuperPoint(torch.nn.Module): 9 | """ Pytorch definition of SuperPoint Network. """ 10 | def __init__(self, descriptor_length=256): 11 | super(SuperPoint, self).__init__() 12 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 13 | det_h = 65 14 | d1 = descriptor_length 15 | self.inc = inconv(1, c1) 16 | self.down1 = down(c1, c2) 17 | self.down2 = down(c2, c3) 18 | self.down3 = down(c3, c4) 19 | self.relu = torch.nn.ReLU(inplace=True) 20 | self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 21 | self.bnPa = nn.BatchNorm2d(c5) 22 | self.convPb = torch.nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0) 23 | self.bnPb = nn.BatchNorm2d(det_h) 24 | # Descriptor Head. 
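        # Mirrors the detector head above, but the final 1x1 conv outputs d1
        # channels (descriptor_length: 256 by default, 64 or 128 in the ALLSS
        # configs/weights shipped with this repo).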
25 | self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 26 | self.bnDa = nn.BatchNorm2d(c5) 27 | self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0) 28 | self.bnDb = nn.BatchNorm2d(d1) 29 | self.output = None 30 | 31 | def forward(self, x): 32 | """ Forward pass that jointly computes unprocessed point and descriptor 33 | tensors. 34 | Input 35 | x: Image pytorch tensor shaped N x 1 x patch_size x patch_size. 36 | Output 37 | semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8. 38 | desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8. 39 | """ 40 | # Let's stick to this version: first BN, then relu 41 | x1 = self.inc(x)#(batch,64,120,160) 42 | x2 = self.down1(x1)#(batch,64,60,80) 43 | x3 = self.down2(x2)#(batch,128,30,40) 44 | x4 = self.down3(x3)#(batch,128,15,30) 45 | 46 | # Detector Head. 47 | cPa = self.relu(self.bnPa(self.convPa(x4)))#(batch,256,15,30) 48 | semi = self.bnPb(self.convPb(cPa))#(batch,65,15,30) 49 | # Descriptor Head. 50 | cDa = self.relu(self.bnDa(self.convDa(x4)))#(batch,256,15,30) 51 | desc = self.bnDb(self.convDb(cDa))#(batch,256,15,30) 52 | 53 | dn = torch.norm(desc, p=2, dim=1) # Compute the norm:(batch,15,30) 54 | desc = desc.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize:(batch,256,15,30) 55 | output = {'semi': semi, 'desc': desc} 56 | self.output = output 57 | 58 | return output -------------------------------------------------------------------------------- /superpoint/models/unet_parts.py: -------------------------------------------------------------------------------- 1 | """U-net parts used for SuperPointNet_gauss2.py 2 | """ 3 | # sub-parts of the U-Net model 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class double_conv(nn.Module): 11 | '''(conv => BN => ReLU) * 2''' 12 | def __init__(self, in_ch, out_ch): 13 | super(double_conv, self).__init__() 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 16 | nn.BatchNorm2d(out_ch), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 19 | nn.BatchNorm2d(out_ch), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | return x 26 | 27 | 28 | class inconv(nn.Module): 29 | def __init__(self, in_ch, out_ch): 30 | super(inconv, self).__init__() 31 | self.conv = double_conv(in_ch, out_ch) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class down(nn.Module): 39 | def __init__(self, in_ch, out_ch): 40 | super(down, self).__init__() 41 | self.mpconv = nn.Sequential( 42 | nn.MaxPool2d(2), 43 | double_conv(in_ch, out_ch) 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.mpconv(x) 48 | return x 49 | 50 | 51 | class up(nn.Module): 52 | def __init__(self, in_ch, out_ch, bilinear=True): 53 | super(up, self).__init__() 54 | 55 | # would be a nice idea if the upsampling could be learned too, 56 | # but my machine do not have enough memory to handle all those weights 57 | if bilinear: 58 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 59 | else: 60 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 61 | 62 | self.conv = double_conv(in_ch, out_ch) 63 | 64 | def forward(self, x1, x2): 65 | x1 = self.up(x1) 66 | 67 | # input is CHW 68 | diffY = x2.size()[2] - x1.size()[2] 69 | diffX = x2.size()[3] - x1.size()[3] 70 | 71 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 72 | diffY // 2, diffY - diffY//2)) 73 | 74 | # for padding issues, see 75 | # 
https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 76 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 77 | 78 | x = torch.cat([x2, x1], dim=1) 79 | x = self.conv(x) 80 | return x 81 | 82 | 83 | class outconv(nn.Module): 84 | def __init__(self, in_ch, out_ch): 85 | super(outconv, self).__init__() 86 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 87 | 88 | def forward(self, x): 89 | x = self.conv(x) 90 | return x 91 | -------------------------------------------------------------------------------- /superpoint/models/weights/magicpoint/superPointNet_100000_checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/models/weights/magicpoint/superPointNet_100000_checkpoint.pth.tar -------------------------------------------------------------------------------- /superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar -------------------------------------------------------------------------------- /superpoint/models/weights/superPointNet_allss_descriptor_64.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/models/weights/superPointNet_allss_descriptor_64.pth.tar -------------------------------------------------------------------------------- /superpoint/models/weights/superPointNet_coco_descriptor_256.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/superpoint/models/weights/superPointNet_coco_descriptor_256.pth.tar -------------------------------------------------------------------------------- /superpoint_export_pseudo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import time 5 | import cv2 6 | import yaml 7 | import logging 8 | from pathlib import Path 9 | import numpy as np 10 | from imageio import imread 11 | from tqdm import tqdm 12 | import torch.utils.data as data 13 | from datasets.ALLSS import ALLSS 14 | import torchvision.transforms as transforms 15 | from utils.utils import combine_heatmap,draw_keypoints 16 | from superpoint.models.model_wrap import SuperPointFrontend_torch, PointTracker 17 | 18 | if __name__ == "__main__": 19 | # add parser 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--command", type=str,default='export_detector_homoAdapt') 22 | parser.add_argument("--config", type=str,default='superpoint/configs/magicpoint_allss_export.yaml') 23 | parser.add_argument("--exper_name", type=str,default='magicpoint_synth_homoAdapt_allss_50_[640,480]') 24 | parser.add_argument("--export_task", type=str,default='train',help="export mode: train or val") 25 | parser.add_argument("--save_output", type=str,default='Results/ALLSS',help="export mode: train or val") 26 | parser.add_argument("--eval", action="store_true",default=False,help="turn on eval mode") 27 | 
parser.add_argument("--outputImg", action="store_true",default=True, help="output image for visualization") 28 | parser.add_argument("--debug", action="store_true", default=False, help="turn on debuging mode") 29 | args = parser.parse_args() 30 | 31 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | torch.set_default_tensor_type(torch.FloatTensor) 33 | logging.basicConfig(format="[%(asctime)s %(levelname)s] %(message)s",datefmt="%m/%d/%Y %H:%M:%S",level=logging.INFO,) 34 | 35 | with open(args.config, "r") as f: 36 | config = yaml.load(f) 37 | print("check config!! ", config) 38 | 39 | # data loading 40 | test_set = ALLSS(export=True,task=args.export_task,**config['data']) 41 | test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False) 42 | 43 | fe = SuperPointFrontend_torch( 44 | config=config, 45 | weights_path=config["pretrained"], 46 | nms_dist=config["model"]["nms"], 47 | conf_thresh=config["model"]["detection_threshold"], 48 | nn_thresh=0.7, 49 | cuda=False, 50 | device=device, 51 | ) 52 | print("==> Successfully loaded pre-trained network.") 53 | fe.net_parallel() 54 | tracker = PointTracker(max_length=5, nn_thresh=fe.nn_thresh) 55 | 56 | count = 0 57 | for i, sample in tqdm(enumerate(test_loader)): 58 | img, mask_2D = sample["image"], sample["valid_mask"] 59 | img = img.transpose(0, 1) 60 | img_2D = sample["image_2D"].numpy().squeeze() 61 | mask_2D = mask_2D.transpose(0, 1) 62 | 63 | inv_homographies, homographies = ( 64 | sample["homographies"], 65 | sample["inv_homographies"], 66 | ) 67 | img, mask_2D, homographies, inv_homographies = ( 68 | img.to(device), 69 | mask_2D.to(device), 70 | homographies.to(device), 71 | inv_homographies.to(device), 72 | ) 73 | # sample = test_set[i] 74 | name = sample["name"][0] 75 | logging.info(f"name: {name}") 76 | 77 | # pass through network 78 | heatmap = fe.run(img, onlyHeatmap=True, train=False)#(100,1,240,320) 79 | outputs = combine_heatmap(heatmap, inv_homographies, mask_2D, device=device) 80 | pts = fe.getPtsFromHeatmap(outputs.detach().cpu().squeeze()) # (x,y, prob) 81 | 82 | # subpixel prediction 83 | if config["model"]["subpixel"]["enable"]: 84 | fe.heatmap = outputs # tensor [batch, 1, H, W] 85 | print("outputs: ", outputs.shape) 86 | print("pts: ", pts.shape) 87 | pts = fe.soft_argmax_points([pts]) 88 | pts = pts[0] 89 | 90 | ## top K points 91 | pts = pts.transpose() 92 | print("total points: ", pts.shape) 93 | print("pts: ", pts[:5]) 94 | 95 | top_k=config["model"]["top_k"] 96 | if top_k: 97 | if pts.shape[0] > top_k: 98 | pts = pts[:top_k, :] 99 | print("topK filter: ", pts.shape) 100 | 101 | ## save keypoints 102 | pred = {} 103 | pred.update({"pts": pts}) 104 | 105 | ## - make directories 106 | filename = str(name) 107 | save_output=os.path.join(args.save_output,args.exper_name,args.export_task) 108 | os.makedirs(save_output,exist_ok=True) 109 | path = Path(save_output, "{}.npz".format(filename)) 110 | np.savez_compressed(path, **pred) 111 | 112 | ## output images for visualization labels 113 | output_images = args.outputImg 114 | if output_images: 115 | img_pts = draw_keypoints(img_2D * 255, pts.transpose(),s=1) 116 | f = save_output+'/'+filename + ".png" 117 | cv2.imwrite(str(f), img_pts) 118 | count += 1 119 | 120 | print("output pseudo ground truth: ", count) 121 | -------------------------------------------------------------------------------- /superpoint_flann_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import 
torch 4 | import argparse 5 | import numpy as np 6 | import torch.nn as nn 7 | from pathlib import Path 8 | from torch.utils.data import DataLoader 9 | from datasets.SSHIDataset import SSHIDataset 10 | from superpoint.models.superpoint_test import SuperPoint 11 | from utils.utils import (make_plot_matches, frame2tensor) 12 | 13 | MIN_MATCH_COUNT = 4 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='SuperPoint_flann test',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--img_dir', type=str, default='datasets/Amazon/',help='path to source image directory') 18 | parser.add_argument('--Result_dir', type=str, default='Results/Camera/superpoint_allss_descriptor_128',help='Directory where to write matching Results ') 19 | parser.add_argument('--resize_scale', type=float, default=0.25,help='resize scale;height,weight=scale*height,scale*weight') 20 | parser.add_argument('--match_viz', default=True, help='Whether write the match result or not') 21 | 22 | parser.add_argument('--weights_path', type=str, default='superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar',help='pretrain model path') 23 | parser.add_argument('--descriptor_dim', type=int, default=128, help='the dimension of descriptor') 24 | parser.add_argument('--max_keypoints', type=int, default=1200, help='Maximum number of keypoints detected by Superpoint'' (\'-1\' keeps all keypoints)') 25 | parser.add_argument('--keypoint_threshold', type=float, default=0.005,help='SuperPoint keypoint detector confidence threshold') 26 | parser.add_argument('--nms_radius', type=int, default=4,help='SuperPoint Non Maximum Suppression (NMS) radius'' (Must be positive)') 27 | opt = parser.parse_args() 28 | print(opt) 29 | 30 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 31 | torch.set_grad_enabled(False) 32 | 33 | config = { 34 | 'superpoint': { 35 | 'weights':opt.weights_path, 36 | 'descriptor_dim': opt.descriptor_dim, 37 | 'nms_radius': opt.nms_radius, 38 | 'keypoint_threshold': opt.keypoint_threshold, 39 | 'max_keypoints': opt.max_keypoints}, 40 | } 41 | 42 | # data path 43 | source_dir=opt.img_dir+'source1/' 44 | template_dir=opt.img_dir+'template1/' 45 | template_img_name=os.listdir(template_dir)[0] 46 | template_img_path=template_dir+template_img_name 47 | 48 | # load eval data 49 | test_dataset=SSHIDataset(source_dir,template_img_path,opt.resize_scale) 50 | test_loader=DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=1) 51 | 52 | superpoint = SuperPoint(config.get('superpoint', {})) 53 | superpoint.to(device).eval() 54 | 55 | for iter,(source_original,source_image,template_image,filename) in enumerate(test_loader): 56 | 57 | image1_tensor = source_image.float().to(device) 58 | image2_tensor = template_image.float().to(device) 59 | 60 | pred1 = superpoint(image1_tensor) 61 | pred2 = superpoint(image2_tensor) 62 | 63 | KeyP1=pred1["keypoints"][0].cpu().detach().numpy() 64 | KeyP2=pred2["keypoints"][0].cpu().detach().numpy() 65 | 66 | Desc1=pred1["descriptors"][0].cpu().detach().numpy().transpose() 67 | Desc2=pred2["descriptors"][0].cpu().detach().numpy().transpose() 68 | 69 | FlANN_INDEX_KDTREE=1 70 | index_params=dict(algorithm=FlANN_INDEX_KDTREE,trees=5) 71 | search_params=dict(checks=50) 72 | flann=cv2.FlannBasedMatcher(index_params,search_params) 73 | knn_matches=flann.knnMatch(Desc1,Desc2,k=2) 74 | 75 | good_matches=[] 76 | for m,n in knn_matches: 77 | if m.distance<0.7*n.distance: 78 | good_matches.append([m]) 79 | 80 | if 
len(good_matches)>MIN_MATCH_COUNT: 81 | src_pts = np.float32([KeyP1[m[0].queryIdx] for m in good_matches]) 82 | dst_pts = np.float32([KeyP2[m[0].trainIdx] for m in good_matches]) 83 | match_dist = np.float32([m[0].distance for m in good_matches]) 84 | Matrix, mask = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC,ransacReprojThreshold=7) 85 | 86 | #output the transformed images 87 | if opt.resize_scale is not None: 88 | Matrix[:,2]=Matrix[:,2]/opt.resize_scale 89 | source_original=source_original.squeeze().cpu().numpy()*255 90 | Transform=cv2.warpAffine(source_original,Matrix,(source_original.shape[1],source_original.shape[0])) 91 | Transform_dir=os.path.join(opt.Result_dir,'transformed/') 92 | os.makedirs(Transform_dir,exist_ok=True) 93 | cv2.imwrite(Transform_dir+'trans_{}'.format(filename[0]),Transform) 94 | 95 | #output the matching images 96 | RansacMask = (mask==1).ravel().tolist() 97 | src_ransac_pts=src_pts[RansacMask] 98 | dst_ransac_pts=dst_pts[RansacMask] 99 | match_ransac_dist=match_dist[RansacMask] 100 | if match_ransac_dist.max() > 1: 101 | best, worst = 0, Desc1.shape[1] * 2 # estimated range 102 | else: 103 | best, worst = 0, 1 104 | 105 | # 1: for best match, 0: for worst match 106 | match_scores = match_ransac_dist / worst 107 | match_scores[match_scores > 1] = 1 108 | match_scores[match_scores < 0] = 0 109 | match_scores = 1 - match_scores 110 | 111 | image01=source_image.squeeze().cpu().numpy()*255 112 | image02=template_image.squeeze().cpu().numpy()*255 113 | image1=np.repeat(image01[...,np.newaxis],3,2) 114 | image2=np.repeat(image02[...,np.newaxis],3,2) 115 | 116 | img = make_plot_matches(image1, image2, src_ransac_pts, dst_ransac_pts, match_scores, layout='lr') 117 | Match_dir=os.path.join(opt.Result_dir,"Match/") 118 | os.makedirs(Match_dir,exist_ok=True) 119 | cv2.imwrite(Match_dir+"match_{}".format(filename[0]),img) -------------------------------------------------------------------------------- /superpoint_glue_official_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import time 5 | import argparse 6 | from pathlib import Path 7 | import matplotlib.cm as cm 8 | from torch.utils.data import DataLoader 9 | from datasets.SSHIDataset import SSHIDataset 10 | from superglue.models.matching import Matching 11 | from superglue.models.utils import (make_matching_plot_fast, frame2tensor) 12 | 13 | torch.set_grad_enabled(False) 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='SuperPoint + SuperGlue registration test',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--exper_name', type=str, default='superpoint_glue_official',help='path to source image directory') 18 | parser.add_argument('--img_dir', type=str, default='datasets/Camera/',help='path to source image directory') 19 | parser.add_argument('--Result_dir', type=str, default='Results/Camera/',help='Directory where to write matching Results ') 20 | parser.add_argument('--resize_scale', type=float, default=0.125,help='resize scale;height,weight=scale*height,scale*weight') 21 | parser.add_argument('--match_viz', default=True, help='Whether write the match result or not') 22 | parser.add_argument('--show_keypoints', default=True,help='Show the detected keypoints') 23 | parser.add_argument('--descriptor_dim',type=int, default=256,help='The dimension of feature descriptor') 24 | 25 | #superglue hyper parameter 26 | 
parser.add_argument('--superpoint_weights',type=str,default="superglue/models/weights/superpoint_v1.pth",help='SuperPoint official weights') 27 | parser.add_argument('--superglue_weights', type=str, default='superglue/models/weights/superglue_indoor.pth',help='SuperGlue official weights') 28 | parser.add_argument('--sinkhorn_iterations', type=int, default=30,help='Number of Sinkhorn iterations performed by SuperGlue') 29 | parser.add_argument('--match_threshold', type=float, default=0.1,help='SuperGlue match threshold') 30 | parser.add_argument('--keypoint_threshold', type=float, default=0.005,help='SuperPoint keypoint detector confidence threshold') 31 | parser.add_argument('--nms_radius', type=int, default=4,help='SuperPoint Non Maximum Suppression (NMS) radius'' (Must be positive)') 32 | parser.add_argument('--max_keypoints', type=int, default=-1, help='Maximum number of keypoints detected by Superpoint'' (\'-1\' keeps all keypoints)') 33 | 34 | opt = parser.parse_args() 35 | print(opt) 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | config = { 39 | 'superpoint': { 40 | 'weights':opt.superpoint_weights, 41 | 'descriptor_dim':opt.descriptor_dim, 42 | 'nms_radius': opt.nms_radius, 43 | 'keypoint_threshold': opt.keypoint_threshold, 44 | 'max_keypoints': opt.max_keypoints 45 | }, 46 | 'superglue': { 47 | 'weights': opt.superglue_weights, 48 | 'descriptor_dim':opt.descriptor_dim, 49 | 'sinkhorn_iterations': opt.sinkhorn_iterations, 50 | 'match_threshold': opt.match_threshold, 51 | } 52 | } 53 | 54 | # data path 55 | source_dir=opt.img_dir+'source1/' 56 | template_dir=opt.img_dir+'template1/' 57 | template_img_name=os.listdir(template_dir)[0] 58 | template_img_path=template_dir+template_img_name 59 | 60 | # load eval data 61 | test_dataset=SSHIDataset(source_dir,template_img_path,opt.resize_scale) 62 | test_loader=DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=1) 63 | 64 | # model 65 | matching = Matching(config).eval().to(device) 66 | 67 | #eval 68 | for iter,(source_original,source_image,template_image,filename) in enumerate(test_loader): 69 | 70 | source_tensor=source_image.float().to(device) 71 | template_tensor=template_image.float().to(device) 72 | #registration and compute time 73 | start = time.perf_counter() 74 | pred = matching({'image0': source_tensor,'image1':template_tensor}) 75 | kpts0 = pred['keypoints0'][0].cpu().numpy() 76 | kpts1 = pred['keypoints1'][0].cpu().numpy() 77 | matches = pred['matches0'][0].cpu().numpy() 78 | confidence = pred['matching_scores0'][0].cpu().numpy() 79 | valid = matches > -1 80 | mkpts0 = kpts0[valid] 81 | mkpts1 = kpts1[matches[valid]] 82 | if len(mkpts0)>3: 83 | # Matrix,mask=cv2.estimateAffine2D(mkpts0,mkpts1,method=cv2.RANSAC,ransacReprojThreshold=7) 84 | Matrix,mask=cv2.estimateAffinePartial2D(mkpts0,mkpts1,method=cv2.RANSAC,ransacReprojThreshold=7) 85 | if opt.resize_scale is not None: 86 | Matrix[:,2]=Matrix[:,2]/opt.resize_scale 87 | flag=(mask>0).ravel().tolist() 88 | mkpts0,mkpts1=mkpts0[flag],mkpts1[flag] 89 | end = time.perf_counter() 90 | elapsed = end-start 91 | print("Time used:",elapsed) 92 | 93 | source_image=source_image.squeeze().cpu().numpy()*255 94 | template_image=template_image.squeeze().cpu().numpy()*255 95 | 96 | source_original=source_original.squeeze().cpu().numpy()*255 97 | Transform=cv2.warpAffine(source_original,Matrix,(source_original.shape[1],source_original.shape[0])) 98 | 99 | #output the Results 100 | if not os.path.exists(opt.Result_dir): 101 | os.makedirs(opt.Result_dir) 102 
| 103 | #Output the Transform Results 104 | exper_name=opt.exper_name 105 | Transform_dir=os.path.join(opt.Result_dir,opt.exper_name,'Transform/') 106 | if not os.path.exists(Transform_dir): 107 | os.makedirs(Transform_dir) 108 | cv2.imwrite(Transform_dir+'trans_{}'.format(filename[0]),Transform) 109 | 110 | #Output the Match Results 111 | color = cm.jet(confidence[valid]) 112 | text = [ 113 | 'SuperGlue', 114 | 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)), 115 | 'Matches: {}'.format(len(mkpts0)) 116 | ] 117 | 118 | k_thresh = matching.superpoint.config['keypoint_threshold'] 119 | m_thresh = matching.superglue.config['match_threshold'] 120 | small_text = [ 121 | 'Keypoint Threshold: {:.4f}'.format(k_thresh), 122 | 'Match Threshold: {:.2f}'.format(m_thresh), 123 | ' ' 124 | ] 125 | 126 | out = make_matching_plot_fast( 127 | source_image, template_image, kpts0, kpts1, mkpts0, mkpts1, color, text, 128 | path=None, show_keypoints=opt.show_keypoints, small_text=small_text) 129 | 130 | Match_dir=os.path.join(opt.Result_dir,opt.exper_name,'Match/') 131 | if not os.path.exists(Match_dir): 132 | os.makedirs(Match_dir) 133 | 134 | out_file = str(Path(Match_dir, filename[0])) 135 | print('\nWriting image to {}'.format(out_file)) 136 | cv2.imwrite(out_file, out) -------------------------------------------------------------------------------- /superpoint_glue_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import time 5 | import argparse 6 | from pathlib import Path 7 | import matplotlib.cm as cm 8 | from torch.utils.data import DataLoader 9 | from datasets.SSHIDataset import SSHIDataset 10 | from superglue.models.matching_test import Matching 11 | from superglue.models.utils import (make_matching_plot_fast, frame2tensor) 12 | 13 | torch.set_grad_enabled(False) 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='SuperPoint + SuperGlue registration test',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--exper_name', type=str, default='superpoint_glue_descriptor',help='path to source image directory') 18 | parser.add_argument('--img_dir', type=str, default='datasets/Amazon/',help='path to source image directory') 19 | parser.add_argument('--Result_dir', type=str, default='Results/Amazon/',help='Directory where to write matching Results ') 20 | parser.add_argument('--resize_scale', type=float, default=0.125,help='resize scale;height,weight=scale*height,scale*weight') 21 | parser.add_argument('--match_viz', default=True, help='Whether write the match result or not') 22 | parser.add_argument('--show_keypoints', default=True,help='Show the detected keypoints') 23 | parser.add_argument('--descriptor_dim',type=int, default=128,help='The dimension of feature descriptor') 24 | 25 | #superpoint hyper parameter 26 | parser.add_argument('--superpoint_weights',type=str,default="superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar") 27 | parser.add_argument('--keypoint_threshold', type=float, default=0.005,help='SuperPoint keypoint detector confidence threshold') 28 | parser.add_argument('--nms_radius', type=int, default=4,help='SuperPoint Non Maximum Suppression (NMS) radius'' (Must be positive)') 29 | parser.add_argument('--max_keypoints', type=int, default=-1, help='Maximum number of keypoints detected by Superpoint'' (\'-1\' keeps all keypoints)') 30 | 31 | #superglue hyper parameter 32 | parser.add_argument('--superglue_weights', 
type=str, default='superglue/models/weights/SuperGlue_allss_descriptor_128.pth',help='SuperGlue weights') 33 | parser.add_argument('--keypoint_encoder', default=[32, 64, 128],help='The dimension of keypoint encoder') 34 | parser.add_argument('--sinkhorn_iterations', type=int, default=30,help='Number of Sinkhorn iterations performed by SuperGlue') 35 | parser.add_argument('--match_threshold', type=float, default=0.1,help='SuperGlue match threshold') 36 | 37 | opt = parser.parse_args() 38 | print(opt) 39 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 40 | 41 | config = { 42 | 'superpoint': { 43 | 'weights':opt.superpoint_weights, 44 | 'descriptor_dim':opt.descriptor_dim, 45 | 'nms_radius': opt.nms_radius, 46 | 'keypoint_threshold': opt.keypoint_threshold, 47 | 'max_keypoints': opt.max_keypoints 48 | }, 49 | 'superglue': { 50 | 'weights': opt.superglue_weights, 51 | 'descriptor_dim':opt.descriptor_dim, 52 | 'keypoint_encoder':opt.keypoint_encoder, 53 | 'sinkhorn_iterations': opt.sinkhorn_iterations, 54 | 'match_threshold': opt.match_threshold, 55 | } 56 | } 57 | 58 | # data path 59 | source_dir=opt.img_dir+'source1/' 60 | template_dir=opt.img_dir+'template1/' 61 | template_img_name=os.listdir(template_dir)[0] 62 | template_img_path=template_dir+template_img_name 63 | 64 | # load eval data 65 | test_dataset=SSHIDataset(source_dir,template_img_path,opt.resize_scale) 66 | test_loader=DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=1) 67 | 68 | # model 69 | matching = Matching(config).eval().to(device) 70 | 71 | #eval 72 | for iter,(source_original,source_image,template_image,filename) in enumerate(test_loader): 73 | 74 | source_tensor=source_image.float().to(device) 75 | template_tensor=template_image.float().to(device) 76 | #registration and compute time 77 | start = time.perf_counter() 78 | pred = matching({'image0': source_tensor,'image1':template_tensor}) 79 | kpts0 = pred['keypoints0'][0].cpu().numpy() 80 | kpts1 = pred['keypoints1'][0].cpu().numpy() 81 | matches = pred['matches0'][0].cpu().numpy() 82 | confidence = pred['matching_scores0'][0].cpu().numpy() 83 | valid = matches > -1 84 | mkpts0 = kpts0[valid] 85 | mkpts1 = kpts1[matches[valid]] 86 | if len(mkpts0)>3: 87 | # Matrix,mask=cv2.estimateAffine2D(mkpts0,mkpts1,method=cv2.RANSAC,ransacReprojThreshold=7) 88 | Matrix,mask=cv2.estimateAffinePartial2D(mkpts0,mkpts1,method=cv2.RANSAC,ransacReprojThreshold=7) 89 | if opt.resize_scale is not None: 90 | Matrix[:,2]=Matrix[:,2]/opt.resize_scale 91 | flag=(mask>0).ravel().tolist() 92 | mkpts0,mkpts1=mkpts0[flag],mkpts1[flag] 93 | end = time.perf_counter() 94 | elapsed = end-start 95 | print("Time used:",elapsed) 96 | 97 | source_image=source_image.squeeze().cpu().numpy()*255 98 | template_image=template_image.squeeze().cpu().numpy()*255 99 | 100 | source_original=source_original.squeeze().cpu().numpy()*255 101 | Transform=cv2.warpAffine(source_original,Matrix,(source_original.shape[1],source_original.shape[0])) 102 | 103 | #output the Results 104 | if not os.path.exists(opt.Result_dir): 105 | os.makedirs(opt.Result_dir) 106 | 107 | #Output the Transform Results 108 | exper_name=opt.exper_name+"_"+str(opt.exper_name) 109 | Transform_dir=os.path.join(opt.Result_dir,opt.exper_name,'Transform/') 110 | if not os.path.exists(Transform_dir): 111 | os.makedirs(Transform_dir) 112 | cv2.imwrite(Transform_dir+'trans_{}'.format(filename[0]),Transform) 113 | 114 | #Output the Match Results 115 | color = cm.jet(confidence[valid]) 116 | text = [ 117 | 'SuperGlue', 118 | 
'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)), 119 | 'Matches: {}'.format(len(mkpts0)) 120 | ] 121 | 122 | k_thresh = matching.superpoint.config['keypoint_threshold'] 123 | m_thresh = matching.superglue.config['match_threshold'] 124 | small_text = [ 125 | 'Keypoint Threshold: {:.4f}'.format(k_thresh), 126 | 'Match Threshold: {:.2f}'.format(m_thresh), 127 | ' ' 128 | ] 129 | 130 | out = make_matching_plot_fast( 131 | source_image, template_image, kpts0, kpts1, mkpts0, mkpts1, color, text, 132 | path=None, show_keypoints=opt.show_keypoints, small_text=small_text) 133 | 134 | Match_dir=os.path.join(opt.Result_dir,opt.exper_name,'Match/') 135 | if not os.path.exists(Match_dir): 136 | os.makedirs(Match_dir) 137 | 138 | out_file = str(Path(Match_dir, filename[0])) 139 | print('\nWriting image to {}'.format(out_file)) 140 | cv2.imwrite(out_file, out) 141 | -------------------------------------------------------------------------------- /superpoint_glue_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import argparse 5 | import numpy as np 6 | 7 | import torch.nn as nn 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | import matplotlib.cm as cm 11 | import torch.multiprocessing 12 | from torch.autograd import Variable 13 | from tensorboardX import SummaryWriter 14 | from datasets.GlueSparse import GlueSparse 15 | from superglue.models.superglue_train import SuperGlue 16 | from utils.utils import make_matching_plot 17 | 18 | #model introduction and path 19 | parser = argparse.ArgumentParser(description='Image pair matching and pose evaluation with SuperGlue',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument('--image_path', type=str, default='datasets/ALLSS/', help='Path to the directory of training imgs.') 21 | parser.add_argument('--Result_dir', type=str, default='Results/ALLSS/superglue_descriptor_128',help='Path to the directory') 22 | 23 | #base parameter 24 | parser.add_argument('--epoch', type=int, default=200,help='Number of epoches') 25 | parser.add_argument('--batch_size', type=int, default=1,help='batch_size') 26 | parser.add_argument('--learning_rate', type=int, default=0.0001,help='Learning rate') 27 | parser.add_argument('--shuffle', default=True ,help='Shuffle ordering of pairs before processing') 28 | parser.add_argument('--show_keypoints', default=True ,help='Plot the keypoints in addition to the matches') 29 | parser.add_argument('--viz_extension', type=str, default='png',choices=['png','pdf'],help='visualization file extension.Use pdf for highest-quality') 30 | 31 | #model hyper parameter 32 | parser.add_argument('--superpoint_weights',type=str,default="superpoint/models/weights/superPointNet_allss_descriptor_128.pth.tar") 33 | parser.add_argument('--descriptor_dim',type=int, default=128,help='The dimension of feature descriptor') 34 | parser.add_argument('--keypoint_encoder', default=[32,64,128],help='The dimension of keypoint encoder') 35 | parser.add_argument('--max_keypoints', type=int, default=1200,help='Maximum number of keypoints detected by Superpoint'' (\'-1\' keeps all keypoints)') 36 | parser.add_argument('--keypoint_threshold', type=float, default=0.005,help='SuperPoint keypoint detector confidence threshold') 37 | parser.add_argument('--nms_radius', type=int, default=4,help='SuperPoint Non Maximum Suppression (NMS) radius'' (Must be positive)') 38 | 39 | parser.add_argument('--sinkhorn_iterations', type=int, 
default=30,help='Number of Sinkhorn iterations performed by SuperGlue') 40 | parser.add_argument('--match_threshold', type=float, default=0.2,help='SuperGlue match threshold') 41 | parser.add_argument('--resize',default=[640,480],help='The size of image') 42 | 43 | #pretrain or checkpoint model 44 | parser.add_argument('--pretrain_weights', type=str, default='',help="load official SuperGlue weights before training: 'indoor', 'outdoor' or '' for none") 45 | parser.add_argument('--checkpoints_dir', type=str, default='checkpoints/', help='models saved here') 46 | parser.add_argument('--checkpoints_name', type=str, default='', help='checkpoint model name') 47 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints') 48 | parser.add_argument('--resume', default=False, help='if specified, training resumes from the given checkpoint') 49 | 50 | 51 | if __name__ == '__main__': 52 | opt = parser.parse_args() 53 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 54 | print(opt) 55 | 56 | config = { 57 | 'superpoint':{ 58 | 'weights':opt.superpoint_weights, 59 | 'descriptor_dim': opt.descriptor_dim, 60 | 'nms_radius': opt.nms_radius, 61 | 'keypoint_threshold': opt.keypoint_threshold, 62 | 'max_keypoints': opt.max_keypoints, 63 | }, 64 | 'superglue': { 65 | 'descriptor_dim': opt.descriptor_dim, 66 | 'keypoint_encoder': opt.keypoint_encoder, 67 | 'sinkhorn_iterations': opt.sinkhorn_iterations, 68 | 'match_threshold': opt.match_threshold, 69 | } 70 | } 71 | 72 | # load training data 73 | train_path= os.path.join(opt.image_path,'train') 74 | train_set = GlueSparse(train_path,config.get('superpoint', {}),opt.resize,device) 75 | train_loader = torch.utils.data.DataLoader(dataset=train_set, shuffle=True, batch_size=opt.batch_size, drop_last=True) 76 | 77 | start_epoch=0 78 | writer =SummaryWriter(opt.Result_dir+'/logdir') 79 | superglue = SuperGlue(config.get('superglue', {})).to(device) 80 | optimizer = torch.optim.Adam(superglue.parameters(), lr=opt.learning_rate) 81 | 82 | if opt.pretrain_weights in ['indoor','outdoor']: 83 | path_pretrain = 'superglue/models/weights/superglue_{}.pth'.format(opt.pretrain_weights) 84 | pretrain = torch.load(path_pretrain) 85 | superglue.load_state_dict(pretrain) 86 | print('Loaded SuperGlue official model (\"{}\" weights)'.format(opt.pretrain_weights)) 87 | 88 | if opt.resume: 89 | path_checkpoints= os.path.join(opt.Result_dir,opt.checkpoints_dir,opt.checkpoints_name) 90 | checkpoint= torch.load(path_checkpoints) 91 | superglue.load_state_dict(checkpoint['net']) 92 | start_epoch=checkpoint['epoch'] 93 | print('Loaded checkpoint model (\"{}\" weights)'.format(opt.checkpoints_name)) 94 | 95 | 96 | # store viz results 97 | eval_output_dir = Path(opt.Result_dir,'match') 98 | eval_output_dir.mkdir(exist_ok=True, parents=True) 99 | print('Will write visualization images to','directory \"{}\"'.format(eval_output_dir)) 100 | 101 | # start training 102 | for epoch in range(start_epoch+1, opt.epoch+1): 103 | epoch_loss = 0 104 | mean_loss = [] 105 | 106 | for i, pred in enumerate(train_loader): 107 | for k in pred: 108 | if k != 'file_name' and k!='image0' and k!='image1': 109 | if type(pred[k]) == torch.Tensor: 110 | pred[k] = Variable(pred[k].cuda()).type(torch.cuda.FloatTensor) 111 | else: 112 | pred[k] = Variable(torch.stack(pred[k]).cuda()) 113 | 114 | superglue.train() 115 | data = superglue(pred) 116 | for k, v in pred.items(): 117 | pred[k] = v[0] 118 | pred = {**pred, **data} 119 | 120 | if pred['skip_train'] == True: # image has no keypoint 121 | continue 122 | 123 | # 
process loss 124 | Loss = pred['loss'] 125 | epoch_loss += Loss.item() 126 | mean_loss.append(Loss) 127 | optimizer.zero_grad() 128 | Loss.backward() 129 | optimizer.step() 130 | 131 | # for every 50 images, print progress and visualize the matches 132 | if (i+1) % 5 == 0: 133 | mean_loss_item = torch.mean(torch.stack(mean_loss)).item() 134 | writer.add_scalar('Mean_Loss',mean_loss_item,len(train_loader)*(epoch-1)+i+1) 135 | print ('Epoch [{}/{}], Step [{}/{}], Mean Loss: {:.4f}' 136 | .format(epoch, opt.epoch, i+1, len(train_loader), mean_loss_item)) 137 | mean_loss = [] 138 | 139 | ### eval ### 140 | # Visualize the matches. 141 | superglue.eval() 142 | image0, image1 = pred['image0'].cpu().numpy()[0]*255., pred['image1'].cpu().numpy()[0]*255. 143 | kpts0, kpts1 = pred['keypoints0'].cpu().numpy()[0], pred['keypoints1'].cpu().numpy()[0] 144 | matches, conf = pred['matches0'].cpu().detach().numpy(), pred['matching_scores0'].cpu().detach().numpy() 145 | valid = matches > -1 146 | mkpts0 = kpts0[valid] 147 | mkpts1 = kpts1[matches[valid]] 148 | mconf = conf[valid] 149 | viz_path = eval_output_dir / '{}_matches.{}'.format(str(i), opt.viz_extension) 150 | color = cm.jet(mconf) 151 | stem = pred['file_name'] 152 | text = [] 153 | 154 | make_matching_plot( 155 | image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, 156 | text, viz_path, stem, stem, opt.show_keypoints, 157 | True, False, 'Matches') 158 | 159 | # save checkpoint when an epoch finishes 160 | epoch_loss /= len(train_loader) 161 | writer.add_scalar("epoch_loss",epoch_loss,epoch) 162 | model_out_path = opt.Result_dir+"/checkpoints" 163 | os.makedirs(model_out_path,exist_ok=True) 164 | model_checkpoint = model_out_path+"/SuperGlue_epoch_{}.pth".format(epoch) 165 | 166 | checkpoint = {'epoch': epoch,'net': superglue.state_dict()} 167 | torch.save(checkpoint, model_checkpoint) 168 | print("Epoch [{}/{}] done. Epoch Loss {}. 
Checkpoint saved to {}".format(epoch, opt.epoch, epoch_loss, model_out_path)) 169 | 170 | writer.close() -------------------------------------------------------------------------------- /superpoint_train_descriptor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | import logging 5 | import torch 6 | import torch.optim 7 | import torch.utils.data as data 8 | from tensorboardX import SummaryWriter 9 | from datasets.ALLSS import ALLSS 10 | from superpoint.Train_model_heatmap import Train_model_heatmap 11 | 12 | if __name__ == '__main__': 13 | parser=argparse.ArgumentParser() 14 | parser.add_argument("--config", type=str,default='superpoint/configs/superpoint_allss_train_heatmap.yaml') 15 | parser.add_argument("--exper_name", type=str,default='superpoint_allss_descriptor_128') 16 | parser.add_argument("--output_dir",type=str,default='Results/ALLSS/') 17 | args=parser.parse_args() 18 | 19 | device=torch.device("cuda" if torch.cuda.is_available() else 'cpu') 20 | logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 21 | logging.info('Start training on {}'.format(device)) 22 | 23 | #output path for saving training record 24 | output_dir=os.path.join(args.output_dir,args.exper_name) 25 | checkpoints_dir=os.path.join(output_dir,'checkpoints') 26 | os.makedirs(output_dir,exist_ok=True) 27 | os.makedirs(checkpoints_dir,exist_ok=True) 28 | writer=SummaryWriter(output_dir) 29 | 30 | #load configs and save configs 31 | with open(args.config, 'r') as f: 32 | config = yaml.load(f) 33 | with open(os.path.join(output_dir, 'config.yml'), 'w') as f: 34 | yaml.dump(config, f, default_flow_style=False) 35 | 36 | train_set = ALLSS(task='train',**config['data']) 37 | val_set = ALLSS(task='val',**config['data']) 38 | train_loader = data.DataLoader(train_set, batch_size=config['model']['batch_size'], shuffle=False) 39 | val_loader = data.DataLoader(val_set, batch_size=config['model']['eval_batch_size'], shuffle=False) 40 | 41 | train_agent = Train_model_heatmap(config, save_path=checkpoints_dir, device=device) 42 | train_agent.writer = writer 43 | train_agent.train_loader = train_loader 44 | train_agent.val_loader = val_loader 45 | train_agent.loadModel() 46 | train_agent.dataParallel() 47 | 48 | try: 49 | train_agent.train() 50 | except KeyboardInterrupt: 51 | print ("press ctrl + c, save model!") 52 | train_agent.saveModel() 53 | pass 54 | 55 | -------------------------------------------------------------------------------- /traditional.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import argparse 5 | import numpy as np 6 | from Traditional.registration import SIFT_REGIS,ORB_REGIS 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser(description='Traditional Registration',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | parser.add_argument('--Method', type=str, default='SIFT',help='The method of feature based registration') 11 | parser.add_argument('--img_dir', type=str, default='datasets/Amazon/',help='path to source image directory') 12 | parser.add_argument('--Result_dir', type=str, default='Results/Amazon/',help='Directory where to write matching Results ') 13 | parser.add_argument('--resize_scale', type=float, default=0.5,help='resize scale;height,weight=scale*height,scale*weight') 14 | parser.add_argument('--match_viz', default=True, 
help='Whether write the match result or not') 15 | opt = parser.parse_args() 16 | 17 | #path for source and template img 18 | source_dir=opt.img_dir+'source1/' 19 | template_dir=opt.img_dir+'template1/' 20 | source_dir_List=os.listdir(source_dir) 21 | template_img_name=os.listdir(template_dir)[0] 22 | template_img=cv2.imread(template_dir+template_img_name) 23 | 24 | for i in source_dir_List: 25 | source_img_name=os.path.join(source_dir+i) 26 | source_img=cv2.imread(source_img_name) 27 | 28 | #registration and compute time 29 | start = time.perf_counter() 30 | if opt.Method=='SIFT': 31 | Matrix,match_img=SIFT_REGIS(source_img,template_img,opt.resize_scale,opt.match_viz) 32 | if opt.Method=='ORB': 33 | Matrix,match_img=ORB_REGIS(source_img,template_img,opt.resize_scale,opt.match_viz) 34 | 35 | if opt.resize_scale is not None: 36 | Matrix[:,2]=Matrix[:,2]/opt.resize_scale 37 | end = time.perf_counter() 38 | elapsed = end-start 39 | print("Time used:",elapsed) 40 | 41 | #output the matching results 42 | if not os.path.exists(opt.Result_dir): 43 | os.makedirs(opt.Result_dir) 44 | 45 | Transform_dir=os.path.join(opt.Result_dir,opt.Method+'/Transform1/') 46 | if not os.path.exists(Transform_dir): 47 | os.makedirs(Transform_dir) 48 | 49 | Match_dir=os.path.join(opt.Result_dir,opt.Method+'/Match1/') 50 | if not os.path.exists(Match_dir): 51 | os.makedirs(Match_dir) 52 | 53 | Transform=cv2.warpAffine(source_img,Matrix,(template_img.shape[1],template_img.shape[0])) 54 | cv2.imwrite(Transform_dir+'trans_{}'.format(i),Transform) 55 | 56 | if opt.match_viz: 57 | cv2.imwrite(Match_dir+'match_{}'.format(i),match_img) 58 | 59 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/utils/__init__.py -------------------------------------------------------------------------------- /utils/correspondence_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/utils/correspondence_tools/__init__.py -------------------------------------------------------------------------------- /utils/correspondence_tools/correspondence_plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib.image as mpimg 2 | import matplotlib.pyplot as plt 3 | from matplotlib.patches import Circle 4 | 5 | def plot_correspondences(images, uv_a, uv_b, use_previous_plot=None, circ_color='g', show=True): 6 | if use_previous_plot is None: 7 | fig, axes = plt.subplots(nrows=2, ncols=2) 8 | else: 9 | fig, axes = use_previous_plot[0], use_previous_plot[1] 10 | 11 | fig.set_figheight(10) 12 | fig.set_figwidth(15) 13 | pixel_locs = [uv_a, uv_b, uv_a, uv_b] 14 | axes = axes.flat[0:] 15 | if use_previous_plot is not None: 16 | axes = [axes[1], axes[3]] 17 | images = [images[1], images[3]] 18 | pixel_locs = [pixel_locs[1], pixel_locs[3]] 19 | for ax, img, pixel_loc in zip(axes[0:], images, pixel_locs): 20 | ax.set_aspect('equal') 21 | if isinstance(pixel_loc[0], int) or isinstance(pixel_loc[0], float): 22 | circ = Circle(pixel_loc, radius=10, facecolor=circ_color, edgecolor='white', fill=True ,linewidth = 2.0, linestyle='solid') 23 | ax.add_patch(circ) 24 | else: 25 | for x,y in zip(pixel_loc[0],pixel_loc[1]): 26 | circ = Circle((x,y), 
radius=10, facecolor=circ_color, edgecolor='white', fill=True ,linewidth = 2.0, linestyle='solid') 27 | ax.add_patch(circ) 28 | ax.imshow(img) 29 | if show: 30 | plt.show() 31 | return None 32 | else: 33 | return fig, axes 34 | 35 | def plot_correspondences_from_dir(log_dir, img_a, img_b, uv_a, uv_b, use_previous_plot=None, circ_color='g', show=True): 36 | img1_filename = log_dir+"/images/"+img_a+"_rgb.png" 37 | img2_filename = log_dir+"/images/"+img_b+"_rgb.png" 38 | img1_depth_filename = log_dir+"/images/"+img_a+"_depth.png" 39 | img2_depth_filename = log_dir+"/images/"+img_b+"_depth.png" 40 | images = [img1_filename, img2_filename, img1_depth_filename, img2_depth_filename] 41 | images = [mpimg.imread(x) for x in images] 42 | return plot_correspondences(images, uv_a, uv_b, use_previous_plot=use_previous_plot, circ_color=circ_color, show=show) 43 | 44 | def plot_correspondences_direct(img_a_rgb, img_a_depth, img_b_rgb, img_b_depth, uv_a, uv_b, use_previous_plot=None, circ_color='g', show=True): 45 | """ 46 | 47 | Plots rgb and depth image pair along with circles at pixel locations 48 | :param img_a_rgb: PIL.Image.Image 49 | :param img_a_depth: PIL.Image.Image 50 | :param img_b_rgb: PIL.Image.Image 51 | :param img_b_depth: PIL.Image.Image 52 | :param uv_a: (u,v) pixel location, or list of pixel locations 53 | :param uv_b: (u,v) pixel location, or list of pixel locations 54 | :param use_previous_plot: 55 | :param circ_color: str 56 | :param show: 57 | :return: 58 | """ 59 | images = [img_a_rgb, img_b_rgb, img_a_depth, img_b_depth] 60 | return plot_correspondences(images, uv_a, uv_b, use_previous_plot=use_previous_plot, circ_color=circ_color, show=show) 61 | 62 | -------------------------------------------------------------------------------- /utils/cp_labels.py: -------------------------------------------------------------------------------- 1 | """copy labels out of images (step 2) 2 | """ 3 | 4 | import subprocess 5 | from glob import glob 6 | import os 7 | 8 | source_folder = 'magicpoint_synth20_homoAdapt100_kitti_h384' 9 | target_folder = f"{source_folder}_labels" 10 | base_path = '/data/kitti' 11 | middle_path = 'predictions/' 12 | final_folder = 'train' 13 | folders = glob(f'{base_path}/{source_folder}/{middle_path}/{final_folder}/*') 14 | 15 | # print(f"folders: {folders}") 16 | for f in folders: 17 | if os.path.isdir(f) == False: 18 | continue 19 | f_target = str(f).replace(source_folder, target_folder) 20 | command = f'rsync -rh {f}/*.npz {f_target}' 21 | print(f"command: {command}") 22 | subprocess.run(f"{command}", shell=True, check=True) 23 | 24 | print(f"total folders: {len(folders)}") 25 | -------------------------------------------------------------------------------- /utils/d2s.py: -------------------------------------------------------------------------------- 1 | """Module used to change 2D labels to 3D labels and vise versa. 2 | Mimic function from tensorflow. 
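Example (minimal sketch of the intended shapes, block_size=8):
    SpaceToDepth(8)(torch.zeros(1, 1, 240, 320)).shape  # -> torch.Size([1, 64, 30, 40])
    DepthToSpace(8)(torch.zeros(1, 64, 30, 40)).shape   # -> torch.Size([1, 1, 240, 320])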
3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | class DepthToSpace(nn.Module): 9 | def __init__(self, block_size): 10 | super(DepthToSpace, self).__init__() 11 | self.block_size = block_size 12 | self.block_size_sq = block_size*block_size 13 | 14 | def forward(self, input): 15 | output = input.permute(0, 2, 3, 1)#(batch,64,15,20)->(batch,15,20,64) 16 | (batch_size, d_height, d_width, d_depth) = output.size() 17 | s_depth = int(d_depth / self.block_size_sq) 18 | s_width = int(d_width * self.block_size) 19 | s_height = int(d_height * self.block_size) 20 | t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)#(batch,15,20,64,1) 21 | spl = t_1.split(self.block_size, 3)#turple:8,(batch,15,20,8,1) 22 | stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]#list:8,(batch,15,160,1) 23 | output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)#(batch,120,160,1) 24 | output = output.permute(0, 3, 1, 2)#(batch,1,120,160) 25 | return output 26 | 27 | class SpaceToDepth(nn.Module): 28 | def __init__(self, block_size): 29 | super(SpaceToDepth, self).__init__() 30 | self.block_size = block_size 31 | self.block_size_sq = block_size*block_size 32 | 33 | def forward(self, input): 34 | output = input.permute(0, 2, 3, 1) 35 | (batch_size, s_height, s_width, s_depth) = output.size() 36 | d_depth = s_depth * self.block_size_sq 37 | d_width = int(s_width / self.block_size) 38 | d_height = int(s_height / self.block_size) 39 | t_1 = output.split(self.block_size, 2)#turple:20,(batch,120,8,1) 40 | stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]#list:20,(batch,15,64) 41 | output = torch.stack(stack, 1)#(batch,20,15,64) 42 | output = output.permute(0, 2, 1, 3)#(batch,15,20,64) 43 | output = output.permute(0, 3, 1, 2)#(batch,64,15,20) 44 | return output 45 | -------------------------------------------------------------------------------- /utils/draw.py: -------------------------------------------------------------------------------- 1 | """ 2 | util functions for visualization 3 | """ 4 | 5 | import argparse 6 | import time 7 | import csv 8 | import yaml 9 | import os 10 | import logging 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | from tensorboardX import SummaryWriter 17 | import cv2 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def plot_imgs(imgs, titles=None, cmap='brg', ylabel='', normalize=False, ax=None, dpi=100): 22 | n = len(imgs) 23 | if not isinstance(cmap, list): 24 | cmap = [cmap]*n 25 | if ax is None: 26 | fig, ax = plt.subplots(1, n, figsize=(6*n, 6), dpi=dpi) 27 | if n == 1: 28 | ax = [ax] 29 | else: 30 | if not isinstance(ax, list): 31 | ax = [ax] 32 | assert len(ax) == len(imgs) 33 | for i in range(n): 34 | if imgs[i].shape[-1] == 3: 35 | imgs[i] = imgs[i][..., ::-1] # BGR to RGB 36 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmap[i]), 37 | vmin=None if normalize else 0, 38 | vmax=None if normalize else 1) 39 | if titles: 40 | ax[i].set_title(titles[i]) 41 | ax[i].get_yaxis().set_ticks([]) 42 | ax[i].get_xaxis().set_ticks([]) 43 | for spine in ax[i].spines.values(): # remove frame 44 | spine.set_visible(False) 45 | ax[0].set_ylabel(ylabel) 46 | plt.tight_layout() 47 | 48 | 49 | # from utils.draw import img_overlap 50 | def img_overlap(img_r, img_g, img_gray): # img_b repeat 51 | img = np.concatenate((img_gray, img_gray, img_gray), axis=0) 52 | img[0, :, :] += img_r[0, :, :] 53 | img[1, :, :] += img_g[0, :, :] 
54 | img[img > 1] = 1 55 | img[img < 0] = 0 56 | return img 57 | 58 | def draw_keypoints(img, corners, color=(0, 255, 0), radius=3, s=3): 59 | ''' 60 | 61 | :param img: 62 | image: 63 | numpy [H, W] 64 | :param corners: 65 | Points 66 | numpy [N, 2] 67 | :param color: 68 | :param radius: 69 | :param s: 70 | :return: 71 | overlaying image 72 | numpy [H, W] 73 | ''' 74 | img = np.repeat(cv2.resize(img, None, fx=s, fy=s)[..., np.newaxis], 3, -1) 75 | for c in np.stack(corners).T: 76 | # cv2.circle(img, tuple(s * np.flip(c, 0)), radius, color, thickness=-1) 77 | cv2.circle(img, tuple((s * c[:2]).astype(int)), radius, color, thickness=-1) 78 | return img 79 | 80 | # def draw_keypoints(img, corners, color=(0, 255, 0), radius=3, s=3): 81 | # ''' 82 | 83 | # :param img: 84 | # np (H, W) 85 | # :param corners: 86 | # np (3, N) 87 | # :param color: 88 | # :param radius: 89 | # :param s: 90 | # :return: 91 | # ''' 92 | # img = np.repeat(cv2.resize(img, None, fx=s, fy=s)[..., np.newaxis], 3, -1) 93 | # for c in np.stack(corners).T: 94 | # # cv2.circle(img, tuple(s * np.flip(c, 0)), radius, color, thickness=-1) 95 | # cv2.circle(img, tuple((s*c[:2]).astype(int)), radius, color, thickness=-1) 96 | # return img 97 | 98 | def draw_matches(rgb1, rgb2, match_pairs, lw = 0.5, color='g', if_fig=True, 99 | filename='matches.png', show=False): 100 | ''' 101 | 102 | :param rgb1: 103 | image1 104 | numpy (H, W) 105 | :param rgb2: 106 | image2 107 | numpy (H, W) 108 | :param match_pairs: 109 | numpy (keypoiny1 x, keypoint1 y, keypoint2 x, keypoint 2 y) 110 | :return: 111 | None 112 | ''' 113 | from matplotlib import pyplot as plt 114 | 115 | h1, w1 = rgb1.shape[:2] 116 | h2, w2 = rgb2.shape[:2] 117 | canvas = np.zeros((max(h1, h2), w1 + w2, 3), dtype=rgb1.dtype) 118 | canvas[:h1, :w1] = rgb1[:,:,np.newaxis] 119 | canvas[:h2, w1:] = rgb2[:,:,np.newaxis] 120 | # fig = plt.figure(frameon=False) 121 | if if_fig: 122 | fig = plt.figure(figsize=(15,5)) 123 | plt.axis("off") 124 | plt.imshow(canvas, zorder=1) 125 | 126 | xs = match_pairs[:, [0, 2]] 127 | xs[:, 1] += w1 128 | ys = match_pairs[:, [1, 3]] 129 | 130 | alpha = 1 131 | sf = 5 132 | # lw = 0.5 133 | # markersize = 1 134 | markersize = 2 135 | 136 | plt.plot( 137 | xs.T, ys.T, 138 | alpha=alpha, 139 | linestyle="-", 140 | linewidth=lw, 141 | aa=False, 142 | marker='o', 143 | markersize=markersize, 144 | fillstyle='none', 145 | color=color, 146 | zorder=2, 147 | # color=[0.0, 0.8, 0.0], 148 | ); 149 | plt.tight_layout() 150 | if filename is not None: 151 | plt.savefig(filename, dpi=300, bbox_inches='tight') 152 | print('#Matches = {}'.format(len(match_pairs))) 153 | if show: 154 | plt.show() 155 | 156 | 157 | 158 | # from utils.draw import draw_matches_cv 159 | def draw_matches_cv(data): 160 | keypoints1 = [cv2.KeyPoint(p[1], p[0], 1) for p in data['keypoints1']] 161 | keypoints2 = [cv2.KeyPoint(p[1], p[0], 1) for p in data['keypoints2']] 162 | inliers = data['inliers'].astype(bool) 163 | matches = np.array(data['matches'])[inliers].tolist() 164 | def to3dim(img): 165 | if img.ndim == 2: 166 | img = img[:, :, np.newaxis] 167 | return img 168 | img1 = to3dim(data['image1']) 169 | img2 = to3dim(data['image2']) 170 | img1 = np.concatenate([img1, img1, img1], axis=2) 171 | img2 = np.concatenate([img2, img2, img2], axis=2) 172 | return cv2.drawMatches(img1, keypoints1, img2, keypoints2, matches, 173 | None, matchColor=(0,255,0), singlePointColor=(0, 0, 255)) 174 | 175 | 176 | def drawBox(points, img, offset=np.array([0,0]), color=(0,255,0)): 177 | # print("origin", 
points) 178 | offset = offset[::-1] 179 | points = points + offset 180 | points = points.astype(int) 181 | for i in range(len(points)): 182 | img = img + cv2.line(np.zeros_like(img),tuple(points[-1+i]), tuple(points[i]), color,5) 183 | return img 184 | 185 | -------------------------------------------------------------------------------- /utils/homographies.py: -------------------------------------------------------------------------------- 1 | """Sample homography matrices 2 | # mimic the function from tensorflow 3 | # very tricky. Need to be careful for using the parameters. 4 | 5 | """ 6 | import torch 7 | from math import pi 8 | import cv2 9 | import numpy as np 10 | from utils.utils import dict_update 11 | 12 | def sample_homography_np( 13 | shape, shift=0, perspective=True, scaling=True, rotation=True, translation=True, 14 | n_scales=5, n_angles=25, scaling_amplitude=0.1, perspective_amplitude_x=0.1, 15 | perspective_amplitude_y=0.1, patch_ratio=0.5, max_angle=pi/2, 16 | allow_artifacts=False, translation_overflow=0.): 17 | """Sample a random valid homography. 18 | 19 | Computes the homography transformation between a random patch in the original image 20 | and a warped projection with the same image size. 21 | As in `tf.contrib.image.transform`, it maps the output point (warped patch) to a 22 | transformed input point (original patch). 23 | The original patch, which is initialized with a simple half-size centered crop, is 24 | iteratively projected, scaled, rotated and translated. 25 | 26 | Arguments: 27 | shape: A rank-2 `Tensor` specifying the height and width of the original image. 28 | perspective: A boolean that enables the perspective and affine transformations. 29 | scaling: A boolean that enables the random scaling of the patch. 30 | rotation: A boolean that enables the random rotation of the patch. 31 | translation: A boolean that enables the random translation of the patch. 32 | n_scales: The number of tentative scales that are sampled when scaling. 33 | n_angles: The number of tentatives angles that are sampled when rotating. 34 | scaling_amplitude: Controls the amount of scale. 35 | perspective_amplitude_x: Controls the perspective effect in x direction. 36 | perspective_amplitude_y: Controls the perspective effect in y direction. 37 | patch_ratio: Controls the size of the patches used to create the homography. 38 | max_angle: Maximum angle used in rotations. 39 | allow_artifacts: A boolean that enables artifacts when applying the homography. 40 | translation_overflow: Amount of border artifacts caused by translation. 41 | 42 | Returns: 43 | A `Tensor` of shape `[1, 8]` corresponding to the flattened homography transform. 
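    Note: this NumPy/OpenCV port returns a 3x3 homography matrix (np.ndarray)
    computed with cv2.getPerspectiveTransform(pts1, pts2), i.e. it maps the
    corners of the full image onto the corners of the sampled patch; use
    np.linalg.inv(H) when the opposite direction is needed. Minimal sketch
    (the 240x320 shape is only an illustrative example):
        H = sample_homography_np(np.array([240, 320]), patch_ratio=0.8)
        H.shape  # -> (3, 3)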
44 | """ 45 | 46 | pts1 = np.stack([[0., 0.], [0., 1.], [1., 1.], [1., 0.]], axis=0) 47 | margin = (1 - patch_ratio) / 2 48 | pts2 = margin + np.array([[0, 0], [0, patch_ratio], 49 | [patch_ratio, patch_ratio], [patch_ratio, 0]]) 50 | 51 | from numpy.random import normal 52 | from numpy.random import uniform 53 | from scipy.stats import truncnorm 54 | 55 | std_trunc = 2 56 | 57 | if perspective: 58 | if not allow_artifacts: 59 | perspective_amplitude_x = min(perspective_amplitude_x, margin) 60 | perspective_amplitude_y = min(perspective_amplitude_y, margin) 61 | perspective_displacement = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_y/2).rvs(1) 62 | h_displacement_left = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1) 63 | h_displacement_right = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1) 64 | pts2 += np.array([[h_displacement_left, perspective_displacement], 65 | [h_displacement_left, -perspective_displacement], 66 | [h_displacement_right, perspective_displacement], 67 | [h_displacement_right, -perspective_displacement]]).squeeze() 68 | 69 | if scaling: 70 | scales = truncnorm(-1*std_trunc, std_trunc, loc=1, scale=scaling_amplitude/2).rvs(n_scales) 71 | scales = np.concatenate((np.array([1]), scales), axis=0) 72 | center = np.mean(pts2, axis=0, keepdims=True) 73 | scaled = (pts2 - center)[np.newaxis, :, :] * scales[:, np.newaxis, np.newaxis] + center 74 | if allow_artifacts: 75 | valid = np.arange(n_scales) 76 | else: 77 | valid = (scaled >= 0.) * (scaled < 1.) 78 | valid = valid.prod(axis=1).prod(axis=1) 79 | valid = np.where(valid)[0] 80 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int) 81 | pts2 = scaled[idx,:,:] 82 | 83 | # Random translation 84 | if translation: 85 | t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0) 86 | if allow_artifacts: 87 | t_min += translation_overflow 88 | t_max += translation_overflow 89 | pts2 += np.array([uniform(-t_min[0], t_max[0],1), uniform(-t_min[1], t_max[1], 1)]).T 90 | 91 | # Random rotation 92 | if rotation: 93 | angles = np.linspace(-max_angle, max_angle, num=n_angles) 94 | angles = np.concatenate((angles, np.array([0.])), axis=0) # in case no rotation is valid 95 | center = np.mean(pts2, axis=0, keepdims=True) 96 | rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), 97 | np.cos(angles)], axis=1), [-1, 2, 2]) 98 | rotated = np.matmul( (pts2 - center)[np.newaxis,:,:], rot_mat) + center 99 | if allow_artifacts: 100 | valid = np.arange(n_angles) # all scales are valid except scale=1 101 | else: 102 | valid = (rotated >= 0.) * (rotated < 1.) 
103 | valid = valid.prod(axis=1).prod(axis=1) 104 | valid = np.where(valid)[0] 105 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int) 106 | pts2 = rotated[idx,:,:] 107 | 108 | # Rescale to actual size 109 | shape = shape[::-1] # different convention [y, x] 110 | pts1 *= shape[np.newaxis,:] 111 | pts2 *= shape[np.newaxis,:] 112 | 113 | def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] 114 | def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] 115 | 116 | homography = cv2.getPerspectiveTransform(np.float32(pts1+shift), np.float32(pts2+shift)) 117 | return homography 118 | 119 | 120 | 121 | def scale_homography_torch(H, shape, shift=(-1,-1), dtype=torch.float32): 122 | height, width = shape[0], shape[1] 123 | trans = torch.tensor([[2./width, 0., shift[0]], [0., 2./height, shift[1]], [0., 0., 1.]], dtype=dtype) 124 | H_tf = torch.inverse(trans) @ H @ trans 125 | return H_tf 126 | 127 | def scale_homography(H, shape, shift=(-1,-1)): 128 | height, width = shape[0], shape[1] 129 | trans = np.array([[2./width, 0., shift[0]], [0., 2./height, shift[1]], [0., 0., 1.]]) 130 | H_tf = np.linalg.inv(trans) @ H @ trans 131 | return H_tf -------------------------------------------------------------------------------- /utils/loader.py: -------------------------------------------------------------------------------- 1 | """many loaders 2 | # loader for model, dataset, testing dataset 3 | """ 4 | 5 | import os 6 | import logging 7 | from pathlib import Path 8 | import numpy as np 9 | import torch 10 | import torch.optim 11 | import torch.utils.data 12 | from utils.utils import tensor2array, save_checkpoint, load_checkpoint, save_path_formatter 13 | # from settings import EXPER_PATH 14 | 15 | # from utils.loader import get_save_path 16 | def get_save_path(output_dir): 17 | """ 18 | This func 19 | :param output_dir: 20 | :return: 21 | """ 22 | save_path = Path(output_dir) 23 | save_path = save_path / 'checkpoints' 24 | logging.info('=> will save everything to {}'.format(save_path)) 25 | os.makedirs(save_path, exist_ok=True) 26 | return save_path 27 | 28 | def worker_init_fn(worker_id): 29 | """The function is designed for pytorch multi-process dataloader. 30 | Note that we use the pytorch random generator to generate a base_seed. 31 | Please try to be consistent. 
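    Without this reseeding, every worker process forked by the DataLoader would inherit the same NumPy random state and apply identical "random" augmentations.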
32 | 33 | References: 34 | https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed 35 | 36 | """ 37 | base_seed = torch.IntTensor(1).random_().item() 38 | # print(worker_id, base_seed) 39 | np.random.seed(base_seed + worker_id) 40 | 41 | 42 | def dataLoader(config, dataset='syn', warp_input=False, train=True, val=True): 43 | import torchvision.transforms as transforms 44 | training_params = config.get('training', {}) 45 | workers_train = training_params.get('workers_train', 1) # 16 46 | workers_val = training_params.get('workers_val', 1) # 16 47 | 48 | logging.info(f"workers_train: {workers_train}, workers_val: {workers_val}") 49 | data_transforms = { 50 | 'train': transforms.Compose([ 51 | transforms.ToTensor(), 52 | ]), 53 | 'val': transforms.Compose([ 54 | transforms.ToTensor(), 55 | ]), 56 | } 57 | # if dataset == 'syn': 58 | # from datasets.SyntheticDataset_gaussian import SyntheticDataset as Dataset 59 | # else: 60 | Dataset = get_module('datasets', dataset) 61 | print(f"dataset: {dataset}") 62 | 63 | train_set = Dataset( 64 | transform=data_transforms['train'], 65 | task = 'train', 66 | **config['data'], 67 | ) 68 | train_loader = torch.utils.data.DataLoader( 69 | train_set, batch_size=config['model']['batch_size'], shuffle=True, 70 | pin_memory=True, 71 | num_workers=workers_train, 72 | worker_init_fn=worker_init_fn 73 | ) 74 | val_set = Dataset( 75 | transform=data_transforms['train'], 76 | task = 'val', 77 | **config['data'], 78 | ) 79 | val_loader = torch.utils.data.DataLoader( 80 | val_set, batch_size=config['model']['eval_batch_size'], shuffle=True, 81 | pin_memory=True, 82 | num_workers=workers_val, 83 | worker_init_fn=worker_init_fn 84 | ) 85 | # val_set, val_loader = None, None 86 | return {'train_loader': train_loader, 'val_loader': val_loader, 87 | 'train_set': train_set, 'val_set': val_set} 88 | 89 | def dataLoader_test(config, dataset='syn', warp_input=False, export_task='train'): 90 | import torchvision.transforms as transforms 91 | training_params = config.get('training', {}) 92 | workers_test = training_params.get('workers_test', 1) # 16 93 | logging.info(f"workers_test: {workers_test}") 94 | 95 | data_transforms = { 96 | 'test': transforms.Compose([ 97 | transforms.ToTensor(), 98 | ]) 99 | } 100 | test_loader = None 101 | if dataset == 'syn': 102 | from datasets.SyntheticDataset import SyntheticDataset 103 | test_set = SyntheticDataset( 104 | transform=data_transforms['test'], 105 | train=False, 106 | warp_input=warp_input, 107 | getPts=True, 108 | seed=1, 109 | **config['data'], 110 | ) 111 | elif dataset == 'hpatches': 112 | from datasets.patches_dataset import PatchesDataset 113 | if config['data']['preprocessing']['resize']: 114 | size = config['data']['preprocessing']['resize'] 115 | test_set = PatchesDataset( 116 | transform=data_transforms['test'], 117 | **config['data'], 118 | ) 119 | test_loader = torch.utils.data.DataLoader( 120 | test_set, batch_size=1, shuffle=False, 121 | pin_memory=True, 122 | num_workers=workers_test, 123 | worker_init_fn=worker_init_fn 124 | ) 125 | # elif dataset == 'Coco' or 'Kitti' or 'Tum': 126 | else: 127 | # from datasets.Kitti import Kitti 128 | logging.info(f"load dataset from : {dataset}") 129 | Dataset = get_module('datasets', dataset) 130 | test_set = Dataset( 131 | export=True, 132 | task=export_task, 133 | **config['data'], 134 | ) 135 | test_loader = torch.utils.data.DataLoader( 136 | test_set, batch_size=1, shuffle=False, 137 | pin_memory=True, 138 | num_workers=workers_test, 139 | 
worker_init_fn=worker_init_fn 140 | 141 | ) 142 | return {'test_set': test_set, 'test_loader': test_loader} 143 | 144 | def get_module(path, name): 145 | import importlib 146 | if path == '': 147 | mod = importlib.import_module(name) 148 | else: 149 | mod = importlib.import_module('{}.{}'.format(path, name)) 150 | return getattr(mod, name) 151 | 152 | def get_model(name): 153 | mod = __import__('models.{}'.format(name), fromlist=['']) 154 | return getattr(mod, name) 155 | 156 | def modelLoader(model='SuperPointNet', **options): 157 | # create model 158 | logging.info("=> creating model: %s", model) 159 | net = get_model(model) 160 | net = net(**options) 161 | return net 162 | 163 | 164 | # mode: 'full' means the formats include the optimizer and epoch 165 | # full_path: if not full path, we need to go through another helper function 166 | def pretrainedLoader(net, optimizer, epoch, path, mode='full', full_path=False): 167 | # load checkpoint 168 | if full_path == True: 169 | checkpoint = torch.load(path) 170 | else: 171 | checkpoint = load_checkpoint(path) 172 | # apply checkpoint 173 | if mode == 'full': 174 | net.load_state_dict(checkpoint['model_state_dict']) 175 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 176 | # epoch = checkpoint['epoch'] 177 | epoch = checkpoint['n_iter'] 178 | # epoch = 0 179 | else: 180 | net.load_state_dict(checkpoint) 181 | # net.load_state_dict(torch.load(path,map_location=lambda storage, loc: storage)) 182 | return net, optimizer, epoch 183 | 184 | if __name__ == '__main__': 185 | net = modelLoader(model='SuperPointNet') 186 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | """colorful logging 2 | # import the whole file 3 | """ 4 | 5 | import coloredlogs, logging 6 | logging.basicConfig() 7 | logger = logging.getLogger() 8 | coloredlogs.install(level='INFO', logger=logger) 9 | 10 | from termcolor import colored, cprint 11 | # from sty import fg, bg, ef, rs 12 | 13 | def toRed(text): 14 | return colored(text, 'red', attrs=['reverse']) 15 | 16 | def toCyan(text): 17 | return colored(text, 'cyan', attrs=['reverse']) 18 | -------------------------------------------------------------------------------- /utils/loss_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PH8411/image-matching/26bdfe05fef4f4dbccde79e5d5efc3a77eb59ea8/utils/loss_functions/__init__.py -------------------------------------------------------------------------------- /utils/loss_functions/loss_composer.py: -------------------------------------------------------------------------------- 1 | from dense_correspondence.dataset.spartan_dataset_masked import SpartanDataset, SpartanDatasetDataType 2 | from dense_correspondence.loss_functions.pixelwise_contrastive_loss import PixelwiseContrastiveLoss 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | def get_loss(pixelwise_contrastive_loss, match_type, 8 | image_a_pred, image_b_pred, 9 | matches_a, matches_b, 10 | masked_non_matches_a, masked_non_matches_b, 11 | background_non_matches_a, background_non_matches_b, 12 | blind_non_matches_a, blind_non_matches_b): 13 | """ 14 | This function serves the purpose of: 15 | - parsing the different types of SpartanDatasetDataType... 16 | - parsing different types of matches / non matches.. 
17 | - into different pixelwise contrastive loss functions 18 | 19 | :return args: loss, match_loss, masked_non_match_loss, \ 20 | background_non_match_loss, blind_non_match_loss 21 | :rtype: each is a pytorch Variable 22 | 23 | """ 24 | if (match_type == SpartanDatasetDataType.SINGLE_OBJECT_WITHIN_SCENE).all(): 25 | print("applying SINGLE_OBJECT_WITHIN_SCENE loss") 26 | return get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 27 | matches_a, matches_b, 28 | masked_non_matches_a, masked_non_matches_b, 29 | background_non_matches_a, background_non_matches_b, 30 | blind_non_matches_a, blind_non_matches_b) 31 | 32 | if (match_type == SpartanDatasetDataType.SINGLE_OBJECT_ACROSS_SCENE).all(): 33 | print("applying SINGLE_OBJECT_ACROSS_SCENE loss") 34 | return get_same_object_across_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 35 | blind_non_matches_a, blind_non_matches_b) 36 | 37 | if (match_type == SpartanDatasetDataType.DIFFERENT_OBJECT).all(): 38 | print("applying DIFFERENT_OBJECT loss") 39 | return get_different_object_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 40 | blind_non_matches_a, blind_non_matches_b) 41 | 42 | 43 | if (match_type == SpartanDatasetDataType.MULTI_OBJECT).all(): 44 | print("applying MULTI_OBJECT loss") 45 | return get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 46 | matches_a, matches_b, 47 | masked_non_matches_a, masked_non_matches_b, 48 | background_non_matches_a, background_non_matches_b, 49 | blind_non_matches_a, blind_non_matches_b) 50 | 51 | if (match_type == SpartanDatasetDataType.SYNTHETIC_MULTI_OBJECT).all(): 52 | print("applying SYNTHETIC_MULTI_OBJECT loss") 53 | return get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 54 | matches_a, matches_b, 55 | masked_non_matches_a, masked_non_matches_b, 56 | background_non_matches_a, background_non_matches_b, 57 | blind_non_matches_a, blind_non_matches_b) 58 | 59 | else: 60 | raise ValueError("Should only have above scenes?") 61 | 62 | 63 | def get_within_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 64 | matches_a, matches_b, 65 | masked_non_matches_a, masked_non_matches_b, 66 | background_non_matches_a, background_non_matches_b, 67 | blind_non_matches_a, blind_non_matches_b): 68 | """ 69 | Simple wrapper for pixelwise_contrastive_loss functions.
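    Combines the match loss with the masked and background non-match losses (the blind non-match loss is computed and returned scaled, but not added to the total), scaling the non-match terms by the number of hard negatives or non-matches.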
Args and return args documented above in get_loss() 70 | """ 71 | pcl = pixelwise_contrastive_loss 72 | 73 | match_loss, masked_non_match_loss, num_masked_hard_negatives =\ 74 | pixelwise_contrastive_loss.get_loss_matched_and_non_matched_with_l2(image_a_pred, image_b_pred, 75 | matches_a, matches_b, 76 | masked_non_matches_a, masked_non_matches_b, 77 | M_descriptor=pcl._config["M_masked"]) 78 | 79 | if pcl._config["use_l2_pixel_loss_on_background_non_matches"]: 80 | background_non_match_loss, num_background_hard_negatives =\ 81 | pixelwise_contrastive_loss.non_match_loss_with_l2_pixel_norm(image_a_pred, image_b_pred, matches_b, 82 | background_non_matches_a, background_non_matches_b, M_descriptor=pcl._config["M_background"]) 83 | 84 | else: 85 | background_non_match_loss, num_background_hard_negatives =\ 86 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 87 | background_non_matches_a, background_non_matches_b, 88 | M_descriptor=pcl._config["M_background"]) 89 | 90 | 91 | 92 | blind_non_match_loss = zero_loss() 93 | num_blind_hard_negatives = 1 94 | if not (SpartanDataset.is_empty(blind_non_matches_a.data)): 95 | blind_non_match_loss, num_blind_hard_negatives =\ 96 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 97 | blind_non_matches_a, blind_non_matches_b, 98 | M_descriptor=pcl._config["M_masked"]) 99 | 100 | 101 | 102 | total_num_hard_negatives = num_masked_hard_negatives + num_background_hard_negatives 103 | total_num_hard_negatives = max(total_num_hard_negatives, 1) 104 | 105 | if pcl._config["scale_by_hard_negatives"]: 106 | scale_factor = total_num_hard_negatives 107 | 108 | masked_non_match_loss_scaled = masked_non_match_loss*1.0/max(num_masked_hard_negatives, 1) 109 | 110 | background_non_match_loss_scaled = background_non_match_loss*1.0/max(num_background_hard_negatives, 1) 111 | 112 | blind_non_match_loss_scaled = blind_non_match_loss*1.0/max(num_blind_hard_negatives, 1) 113 | else: 114 | # we are not currently using blind non-matches 115 | num_masked_non_matches = max(len(masked_non_matches_a),1) 116 | num_background_non_matches = max(len(background_non_matches_a),1) 117 | num_blind_non_matches = max(len(blind_non_matches_a),1) 118 | scale_factor = num_masked_non_matches + num_background_non_matches 119 | 120 | 121 | masked_non_match_loss_scaled = masked_non_match_loss*1.0/num_masked_non_matches 122 | 123 | background_non_match_loss_scaled = background_non_match_loss*1.0/num_background_non_matches 124 | 125 | blind_non_match_loss_scaled = blind_non_match_loss*1.0/num_blind_non_matches 126 | 127 | 128 | 129 | non_match_loss = 1.0/scale_factor * (masked_non_match_loss + background_non_match_loss) 130 | 131 | loss = pcl._config["match_loss_weight"] * match_loss + \ 132 | pcl._config["non_match_loss_weight"] * non_match_loss 133 | 134 | 135 | 136 | return loss, match_loss, masked_non_match_loss_scaled, background_non_match_loss_scaled, blind_non_match_loss_scaled 137 | 138 | def get_within_scene_loss_triplet(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 139 | matches_a, matches_b, 140 | masked_non_matches_a, masked_non_matches_b, 141 | background_non_matches_a, background_non_matches_b, 142 | blind_non_matches_a, blind_non_matches_b): 143 | """ 144 | Simple wrapper for pixelwise_contrastive_loss functions. 
Args and return args documented above in get_loss() 145 | """ 146 | 147 | pcl = pixelwise_contrastive_loss 148 | 149 | masked_triplet_loss =\ 150 | pixelwise_contrastive_loss.get_triplet_loss(image_a_pred, image_b_pred, matches_a, 151 | matches_b, masked_non_matches_a, masked_non_matches_b, pcl._config["alpha_triplet"]) 152 | 153 | background_triplet_loss =\ 154 | pixelwise_contrastive_loss.get_triplet_loss(image_a_pred, image_b_pred, matches_a, 155 | matches_b, background_non_matches_a, background_non_matches_b, pcl._config["alpha_triplet"]) 156 | 157 | total_loss = masked_triplet_loss + background_triplet_loss 158 | 159 | return total_loss, zero_loss(), zero_loss(), zero_loss(), zero_loss() 160 | 161 | def get_different_object_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 162 | blind_non_matches_a, blind_non_matches_b): 163 | """ 164 | Simple wrapper for pixelwise_contrastive_loss functions. Args and return args documented above in get_loss() 165 | """ 166 | 167 | scale_by_hard_negatives = pixelwise_contrastive_loss.config["scale_by_hard_negatives_DIFFERENT_OBJECT"] 168 | blind_non_match_loss = zero_loss() 169 | if not (SpartanDataset.is_empty(blind_non_matches_a.data)): 170 | M_descriptor = pixelwise_contrastive_loss.config["M_background"] 171 | 172 | blind_non_match_loss, num_hard_negatives =\ 173 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 174 | blind_non_matches_a, blind_non_matches_b, 175 | M_descriptor=M_descriptor) 176 | 177 | if scale_by_hard_negatives: 178 | scale_factor = max(num_hard_negatives, 1) 179 | else: 180 | scale_factor = max(len(blind_non_matches_a), 1) 181 | 182 | blind_non_match_loss = 1.0/scale_factor * blind_non_match_loss 183 | loss = blind_non_match_loss 184 | return loss, zero_loss(), zero_loss(), zero_loss(), blind_non_match_loss 185 | 186 | def get_same_object_across_scene_loss(pixelwise_contrastive_loss, image_a_pred, image_b_pred, 187 | blind_non_matches_a, blind_non_matches_b): 188 | """ 189 | Simple wrapper for pixelwise_contrastive_loss functions. Args and return args documented above in get_loss() 190 | """ 191 | blind_non_match_loss = zero_loss() 192 | if not (SpartanDataset.is_empty(blind_non_matches_a.data)): 193 | blind_non_match_loss, num_hard_negatives =\ 194 | pixelwise_contrastive_loss.non_match_loss_descriptor_only(image_a_pred, image_b_pred, 195 | blind_non_matches_a, blind_non_matches_b, 196 | M_descriptor=pcl._config["M_masked"], invert=True) 197 | 198 | if pixelwise_contrastive_loss._config["scale_by_hard_negatives"]: 199 | scale_factor = max(num_hard_negatives, 1) 200 | else: 201 | scale_factor = max(len(blind_non_matches_a), 1) 202 | 203 | loss = 1.0/scale_factor * blind_non_match_loss 204 | blind_non_match_loss_scaled = 1.0/scale_factor * blind_non_match_loss 205 | return loss, zero_loss(), zero_loss(), zero_loss(), blind_non_match_loss 206 | 207 | def zero_loss(): 208 | return Variable(torch.FloatTensor([0]).cuda()) 209 | 210 | def is_zero_loss(loss): 211 | return loss.data[0] < 1e-20 212 | 213 | 214 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | """losses 2 | # losses for heatmap residule 3 | # use it if you're computing residual loss. 
4 | # current disable residual loss 5 | 6 | """ 7 | # losse 8 | import torch 9 | 10 | def print_var(points): 11 | print("points: ", points.shape) 12 | print("points: ", points) 13 | pass 14 | 15 | # from utils.losses import pts_to_bbox 16 | def pts_to_bbox(points, patch_size): 17 | """ 18 | input: 19 | points: (y, x) 20 | output: 21 | bbox: (x1, y1, x2, y2) 22 | """ 23 | 24 | shift_l = (patch_size+1) / 2 25 | shift_r = patch_size - shift_l 26 | pts_l = points-shift_l 27 | pts_r = points+shift_r+1 28 | bbox = torch.stack((pts_l[:,1], pts_l[:,0], pts_r[:,1], pts_r[:,0]), dim=1) 29 | return bbox 30 | pass 31 | 32 | # roi pooling 33 | # from utils.losses import _roi_pool 34 | # def _roi_pool(pred_heatmap, rois, patch_size=8): 35 | # from utils.roi_pool import RoIPool # noqa: E402 36 | # m = RoIPool(patch_size, 1.0) 37 | # patches = m(pred_heatmap, rois.float()) 38 | # return patches 39 | 40 | # torchvision roi pooling 41 | def _roi_pool(pred_heatmap, rois, patch_size=8): 42 | from torchvision.ops import roi_pool 43 | patches = roi_pool(pred_heatmap, rois.float(), (patch_size, patch_size), spatial_scale=1.0) 44 | return patches 45 | pass 46 | 47 | # from utils.losses import norm_patches 48 | def norm_patches(patches): 49 | patch_size = patches.shape[-1] 50 | patches = patches.view(-1, 1, patch_size*patch_size) 51 | d = torch.sum(patches, dim=-1).unsqueeze(-1) + 1e-6 52 | patches = patches/d 53 | patches = patches.view(-1, 1, patch_size, patch_size) 54 | # print("patches: ", patches.shape) 55 | return patches 56 | 57 | # from utils.losses import extract_patch_from_points 58 | def extract_patch_from_points(heatmap, points, patch_size=5): 59 | """ 60 | this function works in numpy 61 | """ 62 | import numpy as np 63 | from utils.utils import toNumpy 64 | # numpy 65 | if type(heatmap) is torch.Tensor: 66 | heatmap = toNumpy(heatmap) 67 | heatmap = heatmap.squeeze() # [H, W] 68 | # padding 69 | pad_size = int(patch_size/2) 70 | heatmap = np.pad(heatmap, pad_size, 'constant') 71 | # crop it 72 | patches = [] 73 | ext = lambda img, pnt, wid: img[pnt[1]:pnt[1]+wid, pnt[0]:pnt[0]+wid] 74 | print("heatmap: ", heatmap.shape) 75 | for i in range(points.shape[0]): 76 | # print("point: ", points[i,:]) 77 | patch = ext(heatmap, points[i,:].astype(int), patch_size) 78 | # print("patch: ", patch.shape) 79 | patches.append(patch) 80 | 81 | # if i > 10: break 82 | # extract points 83 | return patches 84 | 85 | # from utils.losses import extract_patches 86 | def extract_patches(label_idx, image, patch_size=7): 87 | """ 88 | return: 89 | patches: tensor [N, 1, patch, patch] 90 | """ 91 | rois = pts_to_bbox(label_idx[:,2:], patch_size).long() 92 | # filter out?? 
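    # prepend the batch index so each roi has the (batch_idx, x1, y1, x2, y2) layout expected by torchvision's roi_pool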
93 | rois = torch.cat((label_idx[:,:1], rois), dim=1) 94 | # print_var(rois) 95 | # print_var(image) 96 | patches = _roi_pool(image, rois, patch_size=patch_size) 97 | return patches 98 | 99 | # from utils.losses import points_to_4d 100 | def points_to_4d(points): 101 | """ 102 | input: 103 | points: tensor [N, 2] check(y, x) 104 | """ 105 | num_of_points = points.shape[0] 106 | cols = torch.zeros(num_of_points, 1).float() 107 | points = torch.cat((cols, cols, points.float()), dim=1) 108 | return points 109 | 110 | # from utils.losses import soft_argmax_2d 111 | def soft_argmax_2d(patches, normalized_coordinates=True): 112 | """ 113 | params: 114 | patches: (B, N, H, W) 115 | return: 116 | coor: (B, N, 2) (x, y) 117 | 118 | """ 119 | import torchgeometry as tgm 120 | m = tgm.contrib.SpatialSoftArgmax2d(normalized_coordinates=normalized_coordinates) 121 | coords = m(patches) # 1x4x2 122 | return coords 123 | 124 | ## log on patches 125 | # from utils.losses import do_log 126 | def do_log(patches): 127 | patches[patches<0] = 1e-6 128 | patches_log = torch.log(patches) 129 | return patches_log 130 | 131 | # from utils.losses import subpixel_loss 132 | def subpixel_loss(labels_2D, labels_res, pred_heatmap, patch_size=7): 133 | """ 134 | input: 135 | (tensor should be in GPU) 136 | labels_2D: tensor [batch, 1, H, W] 137 | labels_res: tensor [batch, 2, H, W] 138 | pred_heatmap: tensor [batch, 1, H, W] 139 | 140 | return: 141 | loss: sum of all losses 142 | """ 143 | 144 | 145 | # soft argmax 146 | def _soft_argmax(patches): 147 | from models.SubpixelNet import SubpixelNet as subpixNet 148 | dxdy = subpixNet.soft_argmax_2d(patches) # tensor [B, N, patch, patch] 149 | dxdy = dxdy.squeeze(1) # tensor [N, 2] 150 | return dxdy 151 | 152 | points = labels_2D[...].nonzero() 153 | num_points = points.shape[0] 154 | if num_points == 0: 155 | return 0 156 | 157 | labels_res = labels_res.transpose(1,2).transpose(2,3).unsqueeze(1) 158 | rois = pts_to_bbox(points[:,2:], patch_size) 159 | # filter out?? 
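    # points comes from nonzero() on a [batch, 1, H, W] tensor, so its columns are (batch, channel, y, x); column 0 supplies the batch index that roi_pool expects in front of each box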
160 | rois = torch.cat((points[:,:1], rois), dim=1) 161 | points_res = labels_res[points[:,0],points[:,1],points[:,2],points[:,3],:] # tensor [N, 2] 162 | # print_var(rois) 163 | # print_var(labels_res) 164 | # print_var(points) 165 | # print("points max: ", points.max(dim=0)) 166 | # print_var(labels_2D) 167 | # print_var(points_res) 168 | 169 | patches = _roi_pool(pred_heatmap, rois, patch_size=patch_size) 170 | # get argsoft max 171 | dxdy = _soft_argmax(patches) 172 | 173 | loss = (points_res - dxdy) 174 | loss = torch.norm(loss, p=2, dim=-1) 175 | loss = loss.sum()/num_points 176 | # print("loss: ", loss) 177 | return loss 178 | 179 | def subpixel_loss_no_argmax(labels_2D, labels_res, pred_heatmap, **options): 180 | # extract points 181 | points = labels_2D[...].nonzero() 182 | num_points = points.shape[0] 183 | if num_points == 0: 184 | return 0 185 | 186 | def residual_from_points(labels_res, points): 187 | # extract residuals 188 | labels_res = labels_res.transpose(1,2).transpose(2,3).unsqueeze(1) 189 | points_res = labels_res[points[:,0],points[:,1],points[:,2],points[:,3],:] # tensor [N, 2] 190 | return points_res 191 | 192 | points_res = residual_from_points(labels_res, points) 193 | # print_var(points_res) 194 | # extract predicted residuals 195 | pred_res = residual_from_points(pred_heatmap, points) 196 | # print_var(pred_res) 197 | 198 | # loss 199 | loss = (points_res - pred_res) 200 | loss = torch.norm(loss, p=2, dim=-1).mean() 201 | # loss = loss.sum()/num_points 202 | return loss 203 | pass 204 | 205 | if __name__ == '__main__': 206 | 207 | ## example: 208 | # device='cuda:0' 209 | # patches = subpixel_loss(warped_labels.to(device), labels_warped_res.to(device), warped_labels.to(device), 8, patch_size=11) 210 | 211 | pass -------------------------------------------------------------------------------- /utils/photometric.py: -------------------------------------------------------------------------------- 1 | """ photometric augmentation 2 | # used in dataloader 3 | """ 4 | 5 | from imgaug import augmenters as iaa 6 | import numpy as np 7 | import cv2 8 | 9 | 10 | class ImgAugTransform: 11 | def __init__(self, **config): 12 | from numpy.random import uniform 13 | from numpy.random import randint 14 | 15 | ## old photometric 16 | self.aug = iaa.Sequential([ 17 | iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0))), 18 | iaa.Sometimes(0.25, 19 | iaa.OneOf([iaa.Dropout(p=(0, 0.1)), 20 | iaa.CoarseDropout(0.1, size_percent=0.5)])), 21 | iaa.Sometimes(0.25, 22 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05), per_channel=0.5), 23 | ) 24 | ]) 25 | 26 | if config['photometric']['enable']: 27 | params = config['photometric']['params'] 28 | aug_all = [] 29 | if params.get('random_brightness', False): 30 | change = params['random_brightness']['max_abs_change'] 31 | aug = iaa.Add((-change, change)) 32 | # aug_all.append(aug) 33 | aug_all.append(aug) 34 | # if params['random_contrast']: 35 | if params.get('random_contrast', False): 36 | change = params['random_contrast']['strength_range'] 37 | aug = iaa.LinearContrast((change[0], change[1])) 38 | aug_all.append(aug) 39 | # if params['additive_gaussian_noise']: 40 | if params.get('additive_gaussian_noise', False): 41 | change = params['additive_gaussian_noise']['stddev_range'] 42 | aug = iaa.AdditiveGaussianNoise(scale=(change[0], change[1])) 43 | aug_all.append(aug) 44 | # if params['additive_speckle_noise']: 45 | if params.get('additive_speckle_noise', False): 46 | change = params['additive_speckle_noise']['prob_range'] 47 | # 
aug = iaa.Dropout(p=(change[0], change[1])) 48 | aug = iaa.ImpulseNoise(p=(change[0], change[1])) 49 | aug_all.append(aug) 50 | # if params['motion_blur']: 51 | if params.get('motion_blur', False): 52 | change = params['motion_blur']['max_kernel_size'] 53 | if change > 3: 54 | change = randint(3, change) 55 | elif change == 3: 56 | aug = iaa.Sometimes(0.5, iaa.MotionBlur(change)) 57 | aug_all.append(aug) 58 | 59 | if params.get('GaussianBlur', False): 60 | change = params['GaussianBlur']['sigma'] 61 | aug = iaa.GaussianBlur(sigma=(change)) 62 | aug_all.append(aug) 63 | 64 | self.aug = iaa.Sequential(aug_all) 65 | 66 | 67 | else: 68 | self.aug = iaa.Sequential([ 69 | iaa.Noop(), 70 | ]) 71 | 72 | def __call__(self, img): 73 | img = np.array(img) 74 | img = (img * 255).astype(np.uint8) 75 | img = self.aug.augment_image(img) 76 | img = img.astype(np.float32) / 255 77 | return img 78 | 79 | 80 | 81 | class customizedTransform: 82 | def __init__(self): 83 | pass 84 | 85 | def additive_shade(self, image, nb_ellipses=20, transparency_range=[-0.5, 0.8], 86 | kernel_size_range=[250, 350]): 87 | def _py_additive_shade(img): 88 | min_dim = min(img.shape[:2]) / 4 89 | mask = np.zeros(img.shape[:2], np.uint8) 90 | for i in range(nb_ellipses): 91 | ax = int(max(np.random.rand() * min_dim, min_dim / 5)) 92 | ay = int(max(np.random.rand() * min_dim, min_dim / 5)) 93 | max_rad = max(ax, ay) 94 | x = np.random.randint(max_rad, img.shape[1] - max_rad) # center 95 | y = np.random.randint(max_rad, img.shape[0] - max_rad) 96 | angle = np.random.rand() * 90 97 | cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1) 98 | 99 | transparency = np.random.uniform(*transparency_range) 100 | kernel_size = np.random.randint(*kernel_size_range) 101 | if (kernel_size % 2) == 0: # kernel_size has to be odd 102 | kernel_size += 1 103 | mask = cv2.GaussianBlur(mask.astype(np.float32), (kernel_size, kernel_size), 0) 104 | # shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.) 105 | shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.) 106 | return np.clip(shaded, 0, 255) 107 | 108 | shaded = _py_additive_shade(image) 109 | return shaded 110 | 111 | def __call__(self, img, **config): 112 | if config['photometric']['params']['additive_shade']: 113 | params = config['photometric']['params'] 114 | img = self.additive_shade(img * 255, **params['additive_shade']) 115 | return img / 255 116 | 117 | 118 | -------------------------------------------------------------------------------- /utils/photometric_augmentation.py: -------------------------------------------------------------------------------- 1 | """ deprecated: photometric augmentation from tensorflow implementation 2 | # not used in our pipeline 3 | # need to verify if synthetic generation uses it. 
4 | """ 5 | import cv2 as cv 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | augmentations = [ 11 | 'additive_gaussian_noise', 12 | 'additive_speckle_noise', 13 | 'random_brightness', 14 | 'random_contrast', 15 | 'additive_shade', 16 | 'motion_blur' 17 | ] 18 | 19 | 20 | def additive_gaussian_noise(image, stddev_range=[5, 95]): 21 | stddev = tf.random_uniform((), *stddev_range) 22 | noise = tf.random_normal(tf.shape(image), stddev=stddev) 23 | noisy_image = tf.clip_by_value(image + noise, 0, 255) 24 | return noisy_image 25 | 26 | 27 | def additive_speckle_noise(image, prob_range=[0.0, 0.005]): 28 | prob = tf.random_uniform((), *prob_range) 29 | sample = tf.random_uniform(tf.shape(image)) 30 | noisy_image = tf.where(sample <= prob, tf.zeros_like(image), image) 31 | noisy_image = tf.where(sample >= (1. - prob), 255.*tf.ones_like(image), noisy_image) 32 | return noisy_image 33 | 34 | 35 | def random_brightness(image, max_abs_change=50): 36 | return tf.clip_by_value(tf.image.random_brightness(image, max_abs_change), 0, 255) 37 | 38 | 39 | def random_contrast(image, strength_range=[0.5, 1.5]): 40 | return tf.clip_by_value(tf.image.random_contrast(image, *strength_range), 0, 255) 41 | 42 | 43 | def additive_shade(image, nb_ellipses=20, transparency_range=[-0.5, 0.8], 44 | kernel_size_range=[250, 350]): 45 | 46 | def _py_additive_shade(img): 47 | min_dim = min(img.shape[:2]) / 4 48 | mask = np.zeros(img.shape[:2], np.uint8) 49 | for i in range(nb_ellipses): 50 | ax = int(max(np.random.rand() * min_dim, min_dim / 5)) 51 | ay = int(max(np.random.rand() * min_dim, min_dim / 5)) 52 | max_rad = max(ax, ay) 53 | x = np.random.randint(max_rad, img.shape[1] - max_rad) # center 54 | y = np.random.randint(max_rad, img.shape[0] - max_rad) 55 | angle = np.random.rand() * 90 56 | cv.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1) 57 | 58 | transparency = np.random.uniform(*transparency_range) 59 | kernel_size = np.random.randint(*kernel_size_range) 60 | if (kernel_size % 2) == 0: # kernel_size has to be odd 61 | kernel_size += 1 62 | mask = cv.GaussianBlur(mask.astype(np.float32), (kernel_size, kernel_size), 0) 63 | shaded = img * (1 - transparency * mask[..., np.newaxis]/255.) 64 | return np.clip(shaded, 0, 255) 65 | 66 | shaded = tf.py_func(_py_additive_shade, [image], tf.float32) 67 | res = tf.reshape(shaded, tf.shape(image)) 68 | return res 69 | 70 | 71 | def motion_blur(image, max_kernel_size=10): 72 | 73 | def _py_motion_blur(img): 74 | # Either vertial, hozirontal or diagonal blur 75 | mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) 76 | ksize = np.random.randint(0, (max_kernel_size+1)/2)*2 + 1 # make sure is odd 77 | center = int((ksize-1)/2) 78 | kernel = np.zeros((ksize, ksize)) 79 | if mode == 'h': 80 | kernel[center, :] = 1. 81 | elif mode == 'v': 82 | kernel[:, center] = 1. 83 | elif mode == 'diag_down': 84 | kernel = np.eye(ksize) 85 | elif mode == 'diag_up': 86 | kernel = np.flip(np.eye(ksize), 0) 87 | var = ksize * ksize / 16. 
88 | grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) 89 | gaussian = np.exp(-(np.square(grid-center)+np.square(grid.T-center))/(2.*var)) 90 | kernel *= gaussian 91 | kernel /= np.sum(kernel) 92 | img = cv.filter2D(img, -1, kernel) 93 | return img 94 | 95 | blurred = tf.py_func(_py_motion_blur, [image], tf.float32) 96 | return tf.reshape(blurred, tf.shape(image)) 97 | -------------------------------------------------------------------------------- /utils/print_tool.py: -------------------------------------------------------------------------------- 1 | """tools to print object shape or type 2 | 3 | """ 4 | 5 | 6 | # from utils.print_tool import print_config 7 | def print_config(config, file=None): 8 | print('='*10, ' important config: ', '='*10, file=file) 9 | for item in list(config): 10 | print(item, ": ", config[item], file=file) 11 | 12 | print('='*32) 13 | 14 | # from utils.print_tool import print_dict_attr 15 | def print_dict_attr(dictionary, attr=None, file=None): 16 | for item in list(dictionary): 17 | d = dictionary[item] 18 | if attr == None: 19 | print(item, ": ", d, file=file) 20 | else: 21 | if hasattr(d, attr): 22 | print(item, ": ", getattr(d, attr), file=file) 23 | else: 24 | print(item, ": ", len(d), file=file) 25 | 26 | import logging 27 | # from utils.print_tool import datasize 28 | def datasize(train_loader, config, tag='train'): 29 | logging.info('== %s split size %d in %d batches'%\ 30 | (tag, len(train_loader)*config['model']['batch_size'], len(train_loader))) 31 | pass -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | """tools to combine dictionary 2 | 3 | """ 4 | import collections 5 | 6 | 7 | def dict_update(d, u): 8 | """Improved update for nested dictionaries. 9 | 10 | Arguments: 11 | d: The dictionary to be updated. 12 | u: The update dictionary. 13 | 14 | Returns: 15 | The updated dictionary. 16 | """ 17 | for k, v in u.items(): 18 | if isinstance(v, collections.Mapping): 19 | d[k] = dict_update(d.get(k, {}), v) 20 | else: 21 | d[k] = v 22 | return d 23 | -------------------------------------------------------------------------------- /utils/var_dim.py: -------------------------------------------------------------------------------- 1 | """change the dimension of tensor/ numpy array 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | # from utils.var_dim import to3dim 9 | def to3dim(img): 10 | if img.ndim == 2: 11 | img = img[:, :, np.newaxis] 12 | return img 13 | 14 | 15 | # torch 16 | # from utils.var_dim import tensorto4d 17 | def tensorto4d(inp): 18 | if len(inp.shape) == 2: 19 | inp = inp.view(1, 1, inp.shape[0], inp.shape[1]) 20 | elif len(inp.shape) == 3: 21 | inp = inp.view(1, inp.shape[0], inp.shape[1], inp.shape[2]) 22 | return inp 23 | 24 | # torch 25 | # from utils.var_dim import squeezeToNumpy 26 | def squeezeToNumpy(tensor_arr): 27 | return tensor_arr.detach().cpu().numpy().squeeze() 28 | 29 | # from utils.var_dim import toNumpy 30 | def toNumpy(tensor): 31 | return tensor.detach().cpu().numpy() --------------------------------------------------------------------------------
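Usage note (illustrative sketch, not a file from this repository): the snippet below shows one way the homography sampler in utils/homographies.py could be exercised on a grayscale image; the image array and the parameter values are placeholders chosen for the example.

import cv2
import numpy as np
from utils.homographies import sample_homography_np

# stand-in for a real grayscale image (H x W); replace with your own data
img = (np.random.rand(240, 320) * 255).astype(np.uint8)

# sample a random homography for this image size; shape is passed as [H, W]
H_mat = sample_homography_np(np.array(img.shape[:2]), patch_ratio=0.8, max_angle=np.pi / 8)

# warp the image with OpenCV; dsize is given as (W, H)
warped = cv2.warpPerspective(img, H_mat, (img.shape[1], img.shape[0]))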