├── LICENSE
├── README.md
├── extract_kpts_dsc.py
└── model
    ├── HyNet
    │   ├── hynet_model.py
    │   └── weights
    │       ├── HardNet++.pth
    │       ├── HyNet_LIB.pth
    │       ├── HyNet_ND.pth
    │       └── HyNet_YOS.pth
    ├── config_files
    │   └── keynet_configs.py
    ├── extraction_tools.py
    ├── kornia_tools
    │   └── utils.py
    ├── modules.py
    ├── network.py
    └── weights
        └── keynet_pytorch.pth

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Axel Barroso

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters
This repository contains the PyTorch implementation of the Key.Net keypoint detector:

```text
"Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters".
Axel Barroso-Laguna, Edgar Riba, Daniel Ponsa, Krystian Mikolajczyk. ICCV 2019.
```
[[Paper on arXiv](https://arxiv.org/abs/1904.00889)]

The training code will be published soon. In the meantime, please check our official [TensorFlow implementation](https://github.com/axelBarroso/Key.Net) for notes about training Key.Net.


## Prerequisites

Python 3.7 is required to run the Key.Net code. Use Conda to install the dependencies:

```bash
conda create --name keyNet_torch
conda activate keyNet_torch
conda install pytorch==1.2.0 -c pytorch
conda install -c conda-forge opencv tqdm
conda install -c anaconda pandas
conda install -c pytorch torchvision
pip install kornia==0.1.4
```

## Feature Extraction

`extract_kpts_dsc.py` extracts Key.Net + HyNet features for a given list of images. The list must contain the paths to the images relative to a root directory, which you provide separately. The descriptor used together with Key.Net is [HyNet](https://github.com/yuruntian/HyNet).

For every image, the script saves two numpy files: a `.kpt` file for keypoints and a `.dsc` file for descriptors. The output format of the keypoints is as follows:

- `keypoints` [`N x 4`] array containing the positions of keypoints `x, y`, scales `s` and their scores `sc`.


Arguments:

* list-images: File containing the image paths for extracting features.
* root-images: Indicates the root of the directory containing the images.
* method-name: The output name of the method. Default: keynet_hynet_default.
* results-dir: The output path to save the extracted features. Default: extracted_features/.
* config-file: Indicates the configuration file to load Key.Net. Default: KeyNet_default_config.
* num-kpts: The maximum number of keypoints to extract. Default: 5000.
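
A minimal end-to-end example, assuming a list file `image_list.txt` with one relative image path per line (the file and directory names here are illustrative):

```bash
python extract_kpts_dsc.py --list-images image_list.txt --root-images /data/images/ --num-kpts 5000
```

Since `np.save` appends the `.npy` extension, the features of an image `scene/im1.jpg` can be read back with:

```python
import numpy as np

kpts = np.load('extracted_features/keynet_hynet_default/scene/im1.jpg.kpt.npy')   # N x 4: x, y, scale, score
descs = np.load('extracted_features/keynet_hynet_default/scene/im1.jpg.dsc.npy')  # N x 128 HyNet descriptors
```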
## BibTeX

If you use this code in your research, please cite our paper:

```bibtex
@InProceedings{Barroso-Laguna2019ICCV,
    author = {Barroso-Laguna, Axel and Riba, Edgar and Ponsa, Daniel and Mikolajczyk, Krystian},
    title = {{Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters}},
    booktitle = {Proceedings of the 2019 IEEE/CVF International Conference on Computer Vision},
    year = {2019},
}
```

In addition, if you also use the descriptors extracted by HyNet, please consider citing:
```bibtex
@inproceedings{hynet2020,
    author = {Tian, Yurun and Barroso Laguna, Axel and Ng, Tony and Balntas, Vassileios and Mikolajczyk, Krystian},
    title = {HyNet: Learning Local Descriptor with Hybrid Similarity Measure and Triplet Loss},
    booktitle = {NeurIPS},
    year = {2020}
}
```

--------------------------------------------------------------------------------
/extract_kpts_dsc.py:
--------------------------------------------------------------------------------
import os
import argparse
import numpy as np
from tqdm import tqdm
import torch
from model.extraction_tools import initialize_networks, compute_kpts_desc, create_result_dir
from model.config_files.keynet_configs import keynet_config

def extract_features():

    # Parse command line arguments.
    parser = argparse.ArgumentParser(description='Key.Net PyTorch + HyNet local descriptor. '
                                                 'It returns local features as: '
                                                 'kpts: Num_kpts x 4 - [x, y, scale, score]; '
                                                 'desc: Num_kpts x 128')

    parser.add_argument('--list-images', type=str, help='File containing the image paths for extracting features.',
                        required=True)
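    # The list file is expected to be plain text with one image path per line,
    # relative to --root-images, e.g. (illustrative, not part of the repository):
    #     scene1/im1.jpg
    #     scene1/im2.jpg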
    parser.add_argument('--root-images', type=str, default='',
                        help='Indicates the root of the directory containing the images. '
                             'The code will copy the structure and save the extracted features accordingly.')

    parser.add_argument('--method-name', type=str, default='keynet_hynet_default',
                        help='The output name of the method.')

    parser.add_argument('--results-dir', type=str, default='extracted_features/',
                        help='The output path to save the extracted keypoints.')

    parser.add_argument('--config-file', type=str, default='KeyNet_default_config',
                        help='Indicates the configuration file to load Key.Net.')

    parser.add_argument('--num-kpts', type=int, default=5000,
                        help='Indicates the maximum number of keypoints to be extracted.')

    parser.add_argument('--gpu-visible-devices', type=str, default='0',
                        help='Indicates the device where the model should run.')


    args = parser.parse_known_args()[0]

    # Set CUDA GPU environment
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_visible_devices
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    # Read Key.Net model and extraction configuration
    conf = keynet_config[args.config_file]
    keynet_model, desc_model = initialize_networks(conf)

    # Read the image list and extract keypoints and descriptors
    with open(args.list_images, "r") as f:
        lines = f.readlines()
    for idx_im in tqdm(range(len(lines))):
        tmp_line = lines[idx_im].rstrip('\n')
        im_path = os.path.join(args.root_images, tmp_line)

        xys, desc = compute_kpts_desc(im_path, keynet_model, desc_model, conf, device, num_points=args.num_kpts)

        result_path = os.path.join(args.results_dir, args.method_name, tmp_line)
        create_result_dir(result_path)

        np.save(result_path + '.kpt', xys)
        np.save(result_path + '.dsc', desc)

    print('{} feature extraction finished.'.format(args.method_name))
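
# Assumed entry point (not in the original dump) so the script can be run
# directly from the command line as described in the README.
if __name__ == '__main__':
    extract_features()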
15 | """ 16 | super(FRN, self).__init__() 17 | 18 | self.num_features = num_features 19 | self.init_eps = eps 20 | self.is_eps_leanable = is_eps_leanable 21 | self.is_bias = is_bias 22 | self.is_scale = is_scale 23 | 24 | 25 | self.weight = nn.parameter.Parameter(torch.Tensor(1, num_features, 1, 1), requires_grad=True) 26 | self.bias = nn.parameter.Parameter(torch.Tensor(1, num_features, 1, 1), requires_grad=True) 27 | if is_eps_leanable: 28 | self.eps = nn.parameter.Parameter(torch.Tensor(1), requires_grad=True) 29 | else: 30 | self.register_buffer('eps', torch.Tensor([eps])) 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | nn.init.ones_(self.weight) 35 | nn.init.zeros_(self.bias) 36 | if self.is_eps_leanable: 37 | nn.init.constant_(self.eps, self.init_eps) 38 | 39 | def extra_repr(self): 40 | return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__) 41 | 42 | def forward(self, x): 43 | """ 44 | 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow 45 | 0, 1, 2, 3 -> (B, C, H, W) in PyTorch 46 | TensorFlow code 47 | nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) 48 | x = x * tf.rsqrt(nu2 + tf.abs(eps)) 49 | 50 | # This Code include TLU function max(y, tau) 51 | return tf.maximum(gamma * x + beta, tau) 52 | """ 53 | # Compute the mean norm of activations per channel. 54 | nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True) 55 | 56 | # Perform FRN. 57 | x = x * torch.rsqrt(nu2 + self.eps.abs()) 58 | 59 | # Scale and Bias 60 | if self.is_scale: 61 | x = self.weight * x 62 | if self.is_bias: 63 | x = x + self.bias 64 | return x 65 | 66 | class TLU(nn.Module): 67 | def __init__(self, num_features): 68 | """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau""" 69 | super(TLU, self).__init__() 70 | self.num_features = num_features 71 | self.tau = nn.parameter.Parameter(torch.Tensor(1, num_features, 1, 1), requires_grad=True) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | # nn.init.zeros_(self.tau) 76 | nn.init.constant_(self.tau, -1) 77 | 78 | def extra_repr(self): 79 | return 'num_features={num_features}'.format(**self.__dict__) 80 | 81 | def forward(self, x): 82 | return torch.max(x, self.tau) 83 | 84 | class HyNet(nn.Module): 85 | """ 86 | HyNet model definition. 
class HyNet(nn.Module):
    """
    HyNet model definition.
    The FRN and TLU layers are from the paper
    `Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks`
    https://github.com/yukkyo/PyTorch-FilterResponseNormalizationLayer
    """
    def __init__(self, is_bias=True, is_bias_FRN=True, dim_desc=128, drop_rate=0.3):
        super(HyNet, self).__init__()
        self.dim_desc = dim_desc
        self.drop_rate = drop_rate

        self.layer1 = nn.Sequential(
            FRN(1, is_bias=is_bias_FRN),
            TLU(1),
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=is_bias),
            FRN(32, is_bias=is_bias_FRN),
            TLU(32),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=is_bias),
            FRN(32, is_bias=is_bias_FRN),
            TLU(32),
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=is_bias),
            FRN(64, is_bias=is_bias_FRN),
            TLU(64),
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=is_bias),
            FRN(64, is_bias=is_bias_FRN),
            TLU(64),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=is_bias),
            FRN(128, is_bias=is_bias_FRN),
            TLU(128),
        )

        self.layer6 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=is_bias),
            FRN(128, is_bias=is_bias_FRN),
            TLU(128),
        )

        self.layer7 = nn.Sequential(
            nn.Dropout(self.drop_rate),
            nn.Conv2d(128, self.dim_desc, kernel_size=8, bias=False),
            nn.BatchNorm2d(self.dim_desc, affine=False)
        )

        # L2 normalisation of the descriptor, implemented via LocalResponseNorm
        # with a window covering all channels
        self.desc_norm = nn.Sequential(
            nn.LocalResponseNorm(2 * self.dim_desc, alpha=2 * self.dim_desc, beta=0.5, k=0)
        )

        return

    def forward(self, x):
        for layer in [self.layer1, self.layer2, self.layer3, self.layer4, self.layer5, self.layer6, self.layer7]:
            x = layer(x)

        x = self.desc_norm(x + eps_l2_norm)
        x = x.view(x.size(0), -1)
        return x
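
# Usage sketch (an illustration, not part of the original file): HyNet consumes
# 32x32 single-channel patches in the range [0, 255] and outputs L2-normalised
# 128-D descriptors:
#     model = HyNet()
#     model.load_state_dict(torch.load('model/HyNet/weights/HyNet_LIB.pth'))
#     with torch.no_grad():
#         descs = model(torch.rand(16, 1, 32, 32) * 255.)   # descs.shape == (16, 128)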

--------------------------------------------------------------------------------
/model/HyNet/weights/HardNet++.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axelBarroso/Key.Net-Pytorch/74fe0a75998445bb47b74b642b764543b25dc90a/model/HyNet/weights/HardNet++.pth

--------------------------------------------------------------------------------
/model/HyNet/weights/HyNet_LIB.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axelBarroso/Key.Net-Pytorch/74fe0a75998445bb47b74b642b764543b25dc90a/model/HyNet/weights/HyNet_LIB.pth

--------------------------------------------------------------------------------
/model/HyNet/weights/HyNet_ND.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axelBarroso/Key.Net-Pytorch/74fe0a75998445bb47b74b642b764543b25dc90a/model/HyNet/weights/HyNet_ND.pth

--------------------------------------------------------------------------------
/model/HyNet/weights/HyNet_YOS.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axelBarroso/Key.Net-Pytorch/74fe0a75998445bb47b74b642b764543b25dc90a/model/HyNet/weights/HyNet_YOS.pth

--------------------------------------------------------------------------------
/model/config_files/keynet_configs.py:
--------------------------------------------------------------------------------
import numpy as np

keynet_config = {

    'KeyNet_default_config':
        {
            # Key.Net Model
            'num_filters': 8,
            'num_levels': 3,
            'kernel_size': 5,

            # Trained weights
            'weights_detector': 'model/weights/keynet_pytorch.pth',
            'weights_descriptor': 'model/HyNet/weights/HyNet_LIB.pth',

            # Extraction Parameters
            'nms_size': 15,
            'pyramid_levels': 4,
            'up_levels': 1,
            'scale_factor_levels': np.sqrt(2),
            's_mult': 22,
        },
}
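
# A custom configuration can be registered alongside the default and selected
# with --config-file (a sketch; 'KeyNet_fast_config' is a hypothetical name,
# not part of the original file):
#     keynet_config['KeyNet_fast_config'] = dict(keynet_config['KeyNet_default_config'],
#                                                pyramid_levels=2, up_levels=0)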

--------------------------------------------------------------------------------
/model/extraction_tools.py:
--------------------------------------------------------------------------------
import torch
import cv2
import numpy as np
from os import mkdir, path
import torch.nn.functional as F
from model.network import KeyNet
from model.modules import NonMaxSuppression
from model.HyNet.hynet_model import HyNet
from model.kornia_tools.utils import custom_pyrdown
from model.kornia_tools.utils import laf_from_center_scale_ori as to_laf
from model.kornia_tools.utils import extract_patches_from_pyramid as extract_patch


def create_result_dir(result_path):
    '''
    It creates the directory where the features will be stored
    '''
    directories = result_path.split('/')
    tmp = ''
    for idx, dir in enumerate(directories):
        tmp += (dir + '/')
        if idx == len(directories)-1:
            continue
        if not path.isdir(tmp):
            mkdir(tmp)


def remove_borders(score_map, borders):
    '''
    It masks out the borders of the score map to avoid detections near the image boundary
    '''
    shape = score_map.shape
    mask = torch.ones_like(score_map)

    mask[:, :, 0:borders, :] = 0
    mask[:, :, :, 0:borders] = 0
    mask[:, :, shape[2] - borders:shape[2], :] = 0
    mask[:, :, :, shape[3] - borders:shape[3]] = 0

    return mask*score_map


def extract_ms_feats(keynet_model, desc_model, image, factor, s_mult, device,
                     num_kpts_i=1000, nms=None, down_level=0, up_level=False, im_size=[]):
    '''
    Extracts the features for a specific scale level from the pyramid
    :param keynet_model: Key.Net model
    :param desc_model: HyNet model
    :param image: image as a PyTorch tensor
    :param factor: rescaling pyramid factor
    :param s_mult: descriptor area multiplier
    :param device: GPU or CPU
    :param num_kpts_i: number of desired keypoints in the level
    :param nms: NonMaxSuppression module
    :param down_level: indicates if the image needs to go down one pyramid level
    :param up_level: indicates if the image is an upsampled scale level
    :param im_size: original image size
    :return: the local features for a specific image level
    '''

    if down_level and not up_level:
        image = custom_pyrdown(image, factor=factor)
        _, _, nh, nw = image.shape
        factor = (im_size[0]/nw, im_size[1]/nh)
    elif not up_level:
        factor = (1., 1.)

    # Detect keypoints on the current level
    with torch.no_grad():
        det_map = keynet_model(image)
    det_map = remove_borders(det_map, borders=15)

    # Keep the num_kpts_i best positive keypoints after NMS
    kps = nms(det_map)
    c = det_map[0, 0, kps[0], kps[1]]
    sc, indices = torch.sort(c, descending=True)
    indices = indices[torch.where(sc > 0.)]
    kps = kps[:, indices[:num_kpts_i]]
    kps_np = torch.cat([kps[1].view(-1, 1).float(), kps[0].view(-1, 1).float(), c[indices[:num_kpts_i]].view(-1, 1).float()],
                       dim=1).detach().cpu().numpy()
    num_kpts = len(kps_np)
    kp = torch.cat([kps[1].view(-1, 1).float(), kps[0].view(-1, 1).float()], dim=1).unsqueeze(0).cpu()
    s = s_mult * torch.ones((1, num_kpts, 1, 1))
    src_laf = to_laf(kp, s, torch.zeros((1, num_kpts, 1)))

    # HyNet takes images in the range [0, 255]
    patches = extract_patch(255*image.cpu(), src_laf, PS=32, normalize_lafs_before_extraction=True)[0]

    # Compute descriptors in chunks of 1000 patches to bound GPU memory usage
    if len(patches) > 1000:
        descs = []
        for i_patches in range(0, len(patches), 1000):
            descs.append(desc_model(patches[i_patches:i_patches + 1000].to(device)))
        descs = torch.cat(descs, dim=0).cpu().detach().numpy()
    else:
        descs = desc_model(patches.to(device)).cpu().detach().numpy()

    # Map keypoint coordinates back to the original image resolution
    kps_np[:, 0] *= factor[0]
    kps_np[:, 1] *= factor[1]

    return kps_np, descs, image.to(device)


def compute_kpts_desc(im_path, keynet_model, desc_model, conf, device, num_points):
    '''
    Computes multi-scale keypoints and descriptors for an image.

    :param im_path: path to the image
    :param keynet_model: detector model
    :param desc_model: descriptor model
    :param conf: configuration dict with the extraction settings
    :param device: GPU or CPU
    :param num_points: number of total local features
    :return: keypoints and descriptors associated with the image
    '''

    # Load extraction configuration
    pyramid_levels = conf['pyramid_levels']
    up_levels = conf['up_levels']
    scale_factor_levels = conf['scale_factor_levels']
    s_mult = conf['s_mult']
    nms_size = conf['nms_size']
    nms = NonMaxSuppression(nms_size=nms_size)

    # Distribute the keypoint budget across levels: each coarser level receives
    # 1/scale_factor_levels**2 of the previous level's share
    point_level = []
    tmp = 0.0
    factor_points = (scale_factor_levels ** 2)
    levels = pyramid_levels + up_levels + 1
    for idx_level in range(levels):
        tmp += factor_points ** (-1 * (idx_level - up_levels))
        point_level.append(num_points * factor_points ** (-1 * (idx_level - up_levels)))

    point_level = np.asarray(list(map(lambda x: int(x / tmp), point_level)))
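
    # Worked example (an illustration, not part of the original file): with the
    # default config (num_points=5000, pyramid_levels=4, up_levels=1,
    # scale_factor_levels=sqrt(2)), factor_points = 2 and the normalised budget
    # becomes [2539, 1269, 634, 317, 158, 79]: the single upsampled level first,
    # then the original resolution and each coarser level with half as many.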

    im_np = np.asarray(cv2.imread(im_path, 0) / 255., np.float32)

    im = torch.from_numpy(im_np).unsqueeze(0).unsqueeze(0)
    im = im.to(device)

    if up_levels:
        im_up = torch.from_numpy(im_np).unsqueeze(0).unsqueeze(0)
        im_up = im_up.to(device)

    src_kp = []
    _, _, h, w = im.shape
    # Extract features from the upsampled levels
    for idx_level in range(up_levels):

        num_points_level = point_level[len(point_level) - pyramid_levels - 1 - (idx_level+1)]

        # Resize input image
        up_factor = scale_factor_levels ** (1 + idx_level)
        nh, nw = int(h * up_factor), int(w * up_factor)
        up_factor_kpts = (w/nw, h/nh)
        im_up = F.interpolate(im_up, (nh, nw), mode='bilinear', align_corners=False)

        src_kp_i, src_dsc_i, im_up = extract_ms_feats(keynet_model, desc_model, im_up, up_factor_kpts,
                                                      s_mult=s_mult, device=device, num_kpts_i=num_points_level,
                                                      nms=nms, down_level=idx_level+1, up_level=True, im_size=[w, h])

        src_kp_i = np.asarray(list(map(lambda x: [x[0], x[1], (1 / scale_factor_levels) ** (1 + idx_level), x[2]], src_kp_i)))

        if len(src_kp) == 0:
            src_kp = src_kp_i
            src_dsc = src_dsc_i
        else:
            src_kp = np.concatenate([src_kp, src_kp_i], axis=0)
            src_dsc = np.concatenate([src_dsc, src_dsc_i], axis=0)

    # Extract features from the downsampling pyramid
    for idx_level in range(pyramid_levels + 1):

        num_points_level = point_level[idx_level]
        if idx_level > 0 or up_levels:
            res_points = int(np.asarray([point_level[a] for a in range(0, idx_level + 1 + up_levels)]).sum() - len(src_kp))
            num_points_level = res_points

        src_kp_i, src_dsc_i, im = extract_ms_feats(keynet_model, desc_model, im, scale_factor_levels, s_mult=s_mult,
                                                   device=device, num_kpts_i=num_points_level, nms=nms,
                                                   down_level=idx_level, im_size=[w, h])

        src_kp_i = np.asarray(list(map(lambda x: [x[0], x[1], scale_factor_levels ** idx_level, x[2]], src_kp_i)))

        if len(src_kp) == 0:
            src_kp = src_kp_i
            src_dsc = src_dsc_i
        else:
            src_kp = np.concatenate([src_kp, src_kp_i], axis=0)
            src_dsc = np.concatenate([src_dsc, src_dsc_i], axis=0)

    return src_kp, src_dsc


def initialize_networks(conf):
    '''
    It loads the detector and descriptor models
    :param conf: contains the configuration and weights paths of the models
    :return: Key.Net and HyNet models
    '''
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    detector_path = conf['weights_detector']
    descriptor_path = conf['weights_descriptor']

    # Define the Key.Net detector model
    keynet_model = KeyNet(conf)
    checkpoint = torch.load(detector_path, map_location=device)
    keynet_model.load_state_dict(checkpoint['state_dict'])
    keynet_model = keynet_model.to(device)
    keynet_model.eval()

    # Define the HyNet descriptor model
    desc_model = HyNet()
    checkpoint = torch.load(descriptor_path, map_location=device)
    desc_model.load_state_dict(checkpoint)

    desc_model = desc_model.to(device)
    desc_model.eval()

    return keynet_model, desc_model
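
# Usage sketch (an illustration, not part of the original file):
#     from model.config_files.keynet_configs import keynet_config
#     conf = keynet_config['KeyNet_default_config']
#     keynet_model, desc_model = initialize_networks(conf)
#     kpts, descs = compute_kpts_desc('im.jpg', keynet_model, desc_model, conf,
#                                     torch.device('cpu'), num_points=5000)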
21 | """ 22 | names = ['xy', 'scale', 'ori'] 23 | for var_name, var, req_shape in zip(names, 24 | [xy, scale, ori], 25 | [("B", "N", 2), ("B", "N", 1, 1), ("B", "N", 1)]): 26 | if not isinstance(var, torch.Tensor): 27 | raise TypeError("{} type is not a torch.Tensor. Got {}" 28 | .format(var_name, type(var))) 29 | if len(var.shape) != len(req_shape): # type: ignore # because it does not like len(tensor.shape) 30 | raise TypeError( 31 | "{} shape should be must be [{}]. " 32 | "Got {}".format(var_name, str(req_shape), var.size())) 33 | for i, dim in enumerate(req_shape): # type: ignore # because it wants typing for dim 34 | if dim is not int: 35 | continue 36 | if var.size(i) != dim: 37 | raise TypeError( 38 | "{} shape should be must be [{}]. " 39 | "Got {}".format(var_name, str(req_shape), var.size())) 40 | unscaled_laf: torch.Tensor = torch.cat([kornia.angle_to_rotation_matrix(ori.squeeze(-1)), 41 | xy.unsqueeze(-1)], dim=-1) 42 | laf: torch.Tensor = scale_laf(unscaled_laf, scale) 43 | return laf 44 | 45 | 46 | def extract_patches_from_pyramid(img: torch.Tensor, 47 | laf: torch.Tensor, 48 | PS: int = 32, 49 | normalize_lafs_before_extraction: bool = True) -> torch.Tensor: 50 | """Extract patches defined by LAFs from image tensor. 51 | Patches are extracted from appropriate pyramid level 52 | 53 | Args: 54 | laf: (torch.Tensor). 55 | images: (torch.Tensor) images, LAFs are detected in 56 | PS: (int) patch size, default = 32 57 | normalize_lafs_before_extraction (bool): if True, lafs are normalized to image size, default = True 58 | 59 | Returns: 60 | patches: (torch.Tensor) :math:`(B, N, CH, PS,PS)` 61 | """ 62 | raise_error_if_laf_is_not_valid(laf) 63 | if normalize_lafs_before_extraction: 64 | nlaf: torch.Tensor = normalize_laf(laf, img) 65 | else: 66 | nlaf = laf 67 | B, N, _, _ = laf.size() 68 | num, ch, h, w = img.size() 69 | scale = 2.0 * get_laf_scale(denormalize_laf(nlaf, img)) / float(PS) 70 | half: float = 0.5 71 | pyr_idx = (scale.log2() + half).relu().long() 72 | cur_img = img 73 | cur_pyr_level = 0 74 | out = torch.zeros(B, N, ch, PS, PS).to(nlaf.dtype).to(nlaf.device) 75 | while min(cur_img.size(2), cur_img.size(3)) >= PS: 76 | num, ch, h, w = cur_img.size() 77 | # for loop temporarily, to be refactored 78 | for i in range(B): 79 | scale_mask = (pyr_idx[i] == cur_pyr_level).squeeze() 80 | if (scale_mask.float().sum()) == 0: 81 | continue 82 | scale_mask = (scale_mask > 0).view(-1) 83 | grid = generate_patch_grid_from_normalized_LAF( 84 | cur_img[i:i + 1], 85 | nlaf[i:i + 1, scale_mask, :, :], 86 | PS) 87 | patches = F.grid_sample(cur_img[i:i + 1].expand(grid.size(0), ch, h, w), grid, # type: ignore 88 | # padding_mode="border", align_corners=False) 89 | padding_mode="border") 90 | out[i].masked_scatter_(scale_mask.view(-1, 1, 1, 1), patches) 91 | cur_img = kornia.pyrdown(cur_img) 92 | cur_pyr_level += 1 93 | return out 94 | 95 | 96 | # Utility from Kornia: https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/transform/pyramid.html 97 | def _get_pyramid_gaussian_kernel() -> torch.Tensor: 98 | """Utility function that return a pre-computed gaussian kernel.""" 99 | return torch.tensor([[ 100 | [1., 4., 6., 4., 1.], 101 | [4., 16., 24., 16., 4.], 102 | [6., 24., 36., 24., 6.], 103 | [4., 16., 24., 16., 4.], 104 | [1., 4., 6., 4., 1.] 105 | ]]) / 256. 106 | 107 | 108 | def custom_pyrdown(input: torch.Tensor, factor: float = 2., border_type: str = 'reflect', align_corners: bool = False) -> torch.Tensor: 109 | r"""Blurs a tensor and downsamples it. 

# Utility from Kornia: https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/transform/pyramid.html
def _get_pyramid_gaussian_kernel() -> torch.Tensor:
    """Utility function that returns a pre-computed gaussian kernel."""
    return torch.tensor([[
        [1., 4., 6., 4., 1.],
        [4., 16., 24., 16., 4.],
        [6., 24., 36., 24., 6.],
        [4., 16., 24., 16., 4.],
        [1., 4., 6., 4., 1.]
    ]]) / 256.


def custom_pyrdown(input: torch.Tensor, factor: float = 2., border_type: str = 'reflect', align_corners: bool = False) -> torch.Tensor:
    r"""Blurs a tensor and downsamples it.

    Args:
        input (tensor): the tensor to be downsampled.
        factor (float): the downsampling factor. Default: 2.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
        align_corners (bool): interpolation flag. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail.

    Return:
        torch.Tensor: the downsampled tensor.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
        >>> custom_pyrdown(input, align_corners=True)
        tensor([[[[ 3.7500,  5.2500],
                  [ 9.7500, 11.2500]]]])
    """
    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
    kernel: torch.Tensor = _get_pyramid_gaussian_kernel()
    b, c, height, width = input.shape
    # Blur the image
    x_blur: torch.Tensor = filter2D(input, kernel, border_type)

    # Downsample
    out: torch.Tensor = F.interpolate(x_blur, size=(int(height // factor), int(width // factor)), mode='bilinear',
                                      align_corners=align_corners)
    return out

--------------------------------------------------------------------------------
/model/modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import kornia


class feature_extractor(nn.Module):
    '''
    It loads both the handcrafted and learnable blocks
    '''
    def __init__(self):
        super(feature_extractor, self).__init__()

        self.hc_block = handcrafted_block()
        self.lb_block = learnable_block()

    def forward(self, x):
        x_hc = self.hc_block(x)
        x_lb = self.lb_block(x_hc)
        return x_lb


class handcrafted_block(nn.Module):
    '''
    It defines the handcrafted filters within the Key.Net handcrafted block
    '''
    def __init__(self):
        super(handcrafted_block, self).__init__()

    def forward(self, x):

        # First- and second-order derivatives via Sobel filters
        sobel = kornia.spatial_gradient(x)
        dx, dy = sobel[:, :, 0, :, :], sobel[:, :, 1, :, :]

        sobel_dx = kornia.spatial_gradient(dx)
        dxx, dxy = sobel_dx[:, :, 0, :, :], sobel_dx[:, :, 1, :, :]

        sobel_dy = kornia.spatial_gradient(dy)
        dyy = sobel_dy[:, :, 1, :, :]

        # 10 handcrafted feature maps combining first- and second-order terms
        hc_feats = torch.cat([dx, dy, dx**2., dy**2., dx*dy, dxy, dxy**2., dxx, dyy, dxx*dyy], dim=1)

        return hc_feats


class learnable_block(nn.Module):
    '''
    It defines the learnable blocks within Key.Net
    '''
    def __init__(self, in_channels=10):
        super(learnable_block, self).__init__()

        self.conv0 = conv_blck(in_channels)
        self.conv1 = conv_blck()
        self.conv2 = conv_blck()

    def forward(self, x):
        x = self.conv2(self.conv1(self.conv0(x)))
        return x
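
# Sanity-check sketch (an illustration, not part of the original file): the
# handcrafted block expands a single-channel image into 10 feature maps, and the
# learnable block maps them to 8 channels at the same resolution:
#     feats = feature_extractor()(torch.randn(1, 1, 64, 64))   # (1, 8, 64, 64)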

def conv_blck(in_channels=8, out_channels=8, kernel_size=5,
              stride=1, padding=2, dilation=1):
    '''
    Default learnable convolutional block.
    '''
    return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation),
                         nn.BatchNorm2d(out_channels),
                         nn.ReLU(inplace=True))


class NonMaxSuppression(torch.nn.Module):
    '''
    NonMaxSuppression class
    '''
    def __init__(self, thr=0.0, nms_size=5):
        nn.Module.__init__(self)
        padding = nms_size // 2
        self.max_filter = torch.nn.MaxPool2d(kernel_size=nms_size, stride=1, padding=padding)
        self.thr = thr

    def forward(self, scores):

        # Local maxima
        maxima = (scores == self.max_filter(scores))

        # Remove low peaks
        maxima *= (scores > self.thr)

        # Return the (row, col) coordinates of the surviving maxima
        return maxima.nonzero().t()[2:4]

--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.modules import feature_extractor
from model.kornia_tools.utils import custom_pyrdown

class KeyNet(nn.Module):
    '''
    Key.Net model definition
    '''
    def __init__(self, keynet_conf):
        super(KeyNet, self).__init__()

        num_filters = keynet_conf['num_filters']
        self.num_levels = keynet_conf['num_levels']
        kernel_size = keynet_conf['kernel_size']
        padding = kernel_size // 2

        self.feature_extractor = feature_extractor()
        self.last_conv = nn.Sequential(nn.Conv2d(in_channels=num_filters*self.num_levels,
                                                 out_channels=1, kernel_size=kernel_size, padding=padding),
                                       nn.ReLU(inplace=True))

    def forward(self, x):
        """
        x - input image
        """
        shape_im = x.shape
        # Run the feature extractor over a num_levels-deep pyramid (factor 1.2)
        # and upsample each level back to the input resolution before concatenating
        for i in range(self.num_levels):
            if i == 0:
                feats = self.feature_extractor(x)
            else:
                x = custom_pyrdown(x, factor=1.2)
                feats_i = self.feature_extractor(x)
                feats_i = F.interpolate(feats_i, size=(shape_im[2], shape_im[3]), mode='bilinear')
                feats = torch.cat([feats, feats_i], dim=1)

        scores = self.last_conv(feats)
        return scores

--------------------------------------------------------------------------------
/model/weights/keynet_pytorch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axelBarroso/Key.Net-Pytorch/74fe0a75998445bb47b74b642b764543b25dc90a/model/weights/keynet_pytorch.pth
--------------------------------------------------------------------------------