├── configs
│   ├── readme.md
│   ├── AWDesc_eva.yaml
│   ├── SuperPoint_train.yaml
│   ├── AWDesc_train_CA.yaml
│   └── AWDesc_train_Tiny.yaml
├── nets
│   ├── vit
│   │   ├── __init__.py
│   │   ├── vit_seg_configs.py
│   │   └── vit_seg_modeling_resnet_skip.py
│   ├── utils.py
│   ├── __init__.py
│   └── network.py
├── ckpt
│   ├── awdesc-ca
│   │   └── readme.md
│   ├── awdesc-t16
│   │   └── readme.md
│   └── awdesc-t32
│       └── readme.md
├── evaluation_hpatch
│   ├── __init__.py
│   ├── hpatches_sequences
│   │   ├── readme.md
│   │   ├── cache
│   │   │   └── mtldesc.npy
│   │   ├── convert_to_png.sh
│   │   ├── download_cache.sh
│   │   ├── download.sh
│   │   └── README.md
│   ├── models
│   │   ├── __init__.py
│   │   └── MTLDesc.py
│   ├── utils
│   │   ├── logger.py
│   │   ├── common.py
│   │   ├── utils.py
│   │   ├── d2net_pyramid.py
│   │   ├── d2net_utils.py
│   │   ├── evaluator.py
│   │   └── evaluation_tools.py
│   ├── export.py
│   └── hpatch_related
│       └── hpatch_dataset.py
├── readme.md
├── .gitattributes
├── trainers
│   ├── __init__.py
│   ├── utils.py
│   ├── base_trainer.py
│   ├── mtldesc_trainer.py
│   ├── awdesc_trainer.py
│   └── superpoint_trainer.py
├── .gitignore
├── data_utils
│   ├── __init__.py
│   ├── megadepth_train_dataset.py
│   └── megadepth_train_dataset_dl.py
├── requirements.txt
├── utils
│   ├── logger.py
│   └── evaluation_tools.py
├── LICENSE
├── train.py
└── README.md

--------------------------------------------------------------------------------
/configs/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/nets/vit/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/ckpt/awdesc-ca/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/ckpt/awdesc-t16/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/ckpt/awdesc-t32/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/evaluation_hpatch/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/evaluation_hpatch/hpatches_sequences/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | Attention Weighted Local Descriptors
2 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=python
2 | *.py linguist-language=python
--------------------------------------------------------------------------------
/nets/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class qkv_transform(nn.Conv1d):
5 |     """Conv1d for qkv_transform"""
6 | 
--------------------------------------------------------------------------------
/evaluation_hpatch/hpatches_sequences/cache/mtldesc.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vignywang/AWDesc/HEAD/evaluation_hpatch/hpatches_sequences/cache/mtldesc.npy -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | def get_trainer(name): 2 | f_name, c_name = name.split('.') 3 | mod = __import__('{}.{}'.format(__name__, f_name), fromlist=['']) 4 | return getattr(mod, c_name) 5 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | magicpoint_ckpt/ 3 | magicpoint_log/ 4 | megpoint_ckpt/ 5 | megpoint_log/ 6 | superpoint_ckpt/ 7 | superpoint_log/ 8 | ckpt/ 9 | log/ 10 | 11 | # configs/ 12 | 13 | __pycache__/ 14 | data_utils/__pycache__/ 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2020/8/26 3 | # 4 | 5 | 6 | def get_dataset(name): 7 | f_name, c_name = name.split('.') 8 | mod = __import__('{}.{}'.format(__name__, f_name), fromlist=['']) 9 | return getattr(mod, c_name) 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /evaluation_hpatch/models/__init__.py: -------------------------------------------------------------------------------- 1 | def get_model(name): 2 | mod = __import__('{}.{}'.format(__name__, name), fromlist=['']) 3 | return getattr(mod, _module_to_class(name)) 4 | 5 | 6 | def _module_to_class(name): 7 | return ''.join(n.capitalize() for n in name.split('_')) 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /evaluation_hpatch/hpatches_sequences/convert_to_png.sh: -------------------------------------------------------------------------------- 1 | # DELF Extraction script doesn't support .ppm images. 
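# (mogrify ships with ImageMagick; the loop below writes a .png copy of every
#  .ppm frame in each sequence folder, leaving the original .ppm files in place)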
2 | current_dir=`pwd` 3 | echo $current_dir 4 | for dir in `ls hpatches-sequences-release`; do 5 | echo $dir 6 | cd hpatches-sequences-release/$dir 7 | mogrify -format png *.ppm 8 | cd $current_dir 9 | done 10 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2020/8/25 3 | # 4 | 5 | 6 | def get_model(name): 7 | # f_name, c_name = name.split('.') 8 | names = name.split('.') 9 | f_name = '.'.join(names[:-1]) 10 | c_name = names[-1] 11 | mod = __import__('{}.{}'.format(__name__, f_name), fromlist=['']) 12 | return getattr(mod, c_name) 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /evaluation_hpatch/hpatches_sequences/download_cache.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | wget https://dsmn.ml/files/d2-net/hpatches-sequences-cache.tar.gz 4 | tar xvzf hpatches-sequences-cache.tar.gz 5 | rm -rf hpatches-sequences-cache.tar.gz 6 | 7 | wget https://dsmn.ml/files/d2-net/hpatches-sequences-cache-top.tar.gz 8 | tar xvzf hpatches-sequences-cache-top.tar.gz 9 | rm -rf hpatches-sequences-cache-top.tar.gz 10 | 11 | -------------------------------------------------------------------------------- /evaluation_hpatch/hpatches_sequences/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Download the dataset 4 | wget http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz 5 | 6 | # Extract the dataset 7 | tar xvzf hpatches-sequences-release.tar.gz 8 | 9 | # Remove the high-resolution sequences 10 | cd hpatches-sequences-release 11 | rm -rf i_contruction i_crownnight i_dc i_pencils i_whitebuilding v_artisans v_astronautis v_talent 12 | cd .. 
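# (the eight sequences removed above contain images larger than 1600x1200; they
#  are excluded to match the D2-Net protocol, see hpatches_sequences/README.md)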
13 | 
--------------------------------------------------------------------------------
/configs/AWDesc_eva.yaml:
--------------------------------------------------------------------------------
1 | hpatches:
2 |   dataset_dir: hpatches_sequences/hpatches-sequences-release
3 |   resize: false
4 |   grayscale: false
5 | 
6 | model:
7 |   name: MTLDesc
8 |   backbone: network.MTLDesc
9 |   detection_threshold: 0.9
10 |   nms_dist: 4
11 |   nms_radius: 4
12 |   border_remove: 4
13 |   weight_path: "../ckpt"
14 |   ckpt_name: mtl_mtldesc_0 # alternatives: mtl_mtl_6, scalepoint_evo_old, scalepoint_mulhead
15 |   weights_id: '29'
16 | 
17 | keys: keypoints,descriptors,shape
18 | output_type: normal # options: normal, benchmark
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | certifi==2016.2.28
3 | cffi==1.10.0
4 | contextlib2==21.6.0
5 | cycler==0.11.0
6 | decorator==4.4.2
7 | imageio==2.13.5
8 | imgaug==0.4.0
9 | kiwisolver==1.3.1
10 | matplotlib==3.1.3
11 | ml-collections==0.1.0
12 | networkx==2.5.1
13 | numpy==1.16.0  # original pin "numpy == 3.6" is not a valid numpy release; 1.16.0 is an assumed torch-1.2-era version
14 | olefile==0.44
15 | opencv-python==3.4.0.12
16 | Pillow==8.4.0
17 | protobuf==3.19.3
18 | pycparser==2.18
19 | pyparsing==3.0.6
20 | python-dateutil==2.8.2
21 | PyWavelets==1.1.1
22 | PyYAML==5.3
23 | scikit-image==0.17.2
24 | scipy==1.1.0
25 | Shapely==1.8.0
26 | six==1.10.0
27 | tensorboardX==2.1
28 | tifffile==2020.9.3
29 | torch==1.2.0+cu92
30 | torchvision==0.4.0+cu92
--------------------------------------------------------------------------------
/configs/SuperPoint_train.yaml:
--------------------------------------------------------------------------------
1 | name: superpoint
2 | trainer: superpoint_trainer.SuperPoint
3 | 
4 | model:
5 |   backbone: network.SuperPointNet
6 | 
7 | train:
8 |   adjust_lr: true
9 |   lr: 0.001
10 |   weight_decay: 0.0001
11 |   lr_mod: LambdaLR
12 |   batch_size: 12
13 |   epoch_num: 30
14 |   maintain_epoch: 0
15 |   decay_epoch: 30
16 |   log_freq: 100
17 |   num_workers: 8
18 |   validate_after: 1000
19 | 
20 |   dataset: megadepth_train_dataset.MegaDepthTrainDataset
21 |   mega_image_dir: /data/Mega_train/image
22 |   mega_keypoint_dir: /data/Mega_train/keypoint
23 |   mega_despoint_dir: /data/Mega_train/despoint
24 |   height: 400
25 |   width: 400
26 | 
27 |   fix_grid_option: 400
28 |   fix_sample: false
29 |   rotation_option: none
30 |   do_augmentation: true
31 |   sydesp_type: nomal # random
32 |   point_loss_weight: 200
33 |   w_weight: 0.1
--------------------------------------------------------------------------------
/configs/AWDesc_train_CA.yaml:
--------------------------------------------------------------------------------
1 | name: mtl
2 | trainer: mtldesc_trainer.MTLDescTrainer
3 | 
4 | model:
5 |   backbone: network.MTLDesc
6 | 
7 | train:
8 |   adjust_lr: true
9 |   lr: 0.001
10 |   weight_decay: 0.0001
11 |   lr_mod: LambdaLR
12 |   batch_size: 12
13 |   epoch_num: 30
14 |   maintain_epoch: 0
15 |   decay_epoch: 30
16 |   log_freq: 100
17 |   num_workers: 8
18 |   validate_after: 1000
19 | 
20 |   dataset: megadepth_train_dataset.MegaDepthTrainDataset
21 |   mega_image_dir: /data/Mega_train/image
22 |   mega_keypoint_dir: /data/Mega_train/keypoint
23 |   mega_despoint_dir: /data/Mega_train/despoint
24 |   height: 400
25 |   width: 400
26 | 
27 |   T: 15
28 |   fix_grid_option: 400
29 |   fix_sample: false
30 |   rotation_option: none
31 |   do_augmentation: true
32 |   sydesp_type: nomal # random
33 |   point_loss_weight: 200
34 |   w_weight: 0.1
35 | 
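# (the mega_*_dir paths above are machine-specific; point them at your extracted
#  Mega_train data before launching train.py with this config)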
-------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created by Yuyang on 19-3-31 3 | # 4 | 5 | import logging 6 | import time 7 | 8 | 9 | def get_logger(log_root): 10 | # create a logger 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | # create a formatter 14 | formatter = logging.Formatter( 15 | fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S' 16 | ) 17 | # writing 18 | log_dir = log_root + "/" 19 | c_t = time.strftime('%Y-%m-%d %H%M%S', time.localtime(time.time())) 20 | # if not os.path.exists(log_dir): 21 | # os.mkdir(log_dir) 22 | log_name = log_dir + c_t + '.log' 23 | handler_file = logging.FileHandler(log_name, mode='w') 24 | handler_file.setFormatter(formatter) 25 | 26 | # showing 27 | handler = logging.StreamHandler() 28 | handler.setFormatter(formatter) 29 | 30 | logger.addHandler(handler_file) 31 | logger.addHandler(handler) 32 | return logger 33 | -------------------------------------------------------------------------------- /evaluation_hpatch/utils/logger.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created by Yuyang on 19-3-31 3 | # 4 | 5 | import logging 6 | import time 7 | 8 | 9 | def get_logger(log_root): 10 | # create a logger 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | # create a formatter 14 | formatter = logging.Formatter( 15 | fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S' 16 | ) 17 | # writing 18 | log_dir = log_root + "/" 19 | c_t = time.strftime('%Y-%m-%d %H%M%S', time.localtime(time.time())) 20 | # if not os.path.exists(log_dir): 21 | # os.mkdir(log_dir) 22 | log_name = log_dir + c_t + '.log' 23 | handler_file = logging.FileHandler(log_name, mode='w') 24 | handler_file.setFormatter(formatter) 25 | 26 | # showing 27 | handler = logging.StreamHandler() 28 | handler.setFormatter(formatter) 29 | 30 | logger.addHandler(handler_file) 31 | logger.addHandler(handler) 32 | return logger 33 | -------------------------------------------------------------------------------- /configs/AWDesc_train_Tiny.yaml: -------------------------------------------------------------------------------- 1 | name: t16 # or t32 2 | trainer: awdesc_trainer.AWDescTrainer 3 | 4 | model: 5 | backbone: network.Lite16 # or network.Lite32 6 | 7 | train: 8 | adjust_lr: true 9 | lr: 0.001 10 | weight_decay: 0.0001 11 | lr_mod: LambdaLR 12 | batch_size: 12 13 | epoch_num: 30 14 | maintain_epoch: 0 15 | decay_epoch: 30 16 | log_freq: 100 17 | num_workers: 16 18 | validate_after: 1000 19 | 20 | dataset: megadepth_train_dataset_dl.MegaDepthTrainDataset 21 | mega_image_dir: /data/Mega_train/image 22 | mega_keypoint_dir: /data/Mega_train/keypoint 23 | mega_despoint_dir: /data/Mega_train/despoint 24 | mega_dl_dir1: /data/Mega_train/Mega_train/dl_teacher0 25 | mega_dl_dir2: /data/Mega_train/Mega_train/dl_teacher1 26 | height: 400 27 | width: 400 28 | balance: 0.5 29 | T: 15 30 | fix_grid_option: 400 31 | fix_sample: false 32 | rotation_option: none 33 | do_augmentation: true 34 | sydesp_type: nomal # random 35 | point_loss_weight: 200 36 | w_weight: 0.1 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 vignywang 4 | 5 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /evaluation_hpatch/utils/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright 2017, Zixin Luo, HKUST. 4 | Commonly used functions 5 | """ 6 | 7 | from datetime import datetime 8 | 9 | 10 | class ClassProperty(property): 11 | """For dynamically obtaining system time""" 12 | 13 | def __get__(self, cls, owner): 14 | return classmethod(self.fget).__get__(None, owner)() 15 | 16 | 17 | class Notify(object): 18 | """Colorful printing prefix. 19 | A quick example: 20 | print(Notify.INFO, YOUR TEXT, Notify.ENDC) 21 | """ 22 | 23 | def __init__(self): 24 | pass 25 | 26 | @ClassProperty 27 | def HEADER(cls): 28 | return str(datetime.now()) + ': \033[95m' 29 | 30 | @ClassProperty 31 | def INFO(cls): 32 | return str(datetime.now()) + ': \033[92mI' 33 | 34 | @ClassProperty 35 | def OKBLUE(cls): 36 | return str(datetime.now()) + ': \033[94m' 37 | 38 | @ClassProperty 39 | def WARNING(cls): 40 | return str(datetime.now()) + ': \033[93mW' 41 | 42 | @ClassProperty 43 | def FAIL(cls): 44 | return str(datetime.now()) + ': \033[91mF' 45 | 46 | @ClassProperty 47 | def BOLD(cls): 48 | return str(datetime.now()) + ': \033[1mB' 49 | 50 | @ClassProperty 51 | def UNDERLINE(cls): 52 | return str(datetime.now()) + ': \033[4mU' 53 | ENDC = '\033[0m' 54 | -------------------------------------------------------------------------------- /evaluation_hpatch/hpatches_sequences/README.md: -------------------------------------------------------------------------------- 1 | # HPatches Sequences / Image Pairs Matching Benchmark 2 | 3 | Please check the [official repository](https://github.com/hpatches/hpatches-dataset) for more information regarding references. 4 | 5 | The dataset can be downloaded by running `bash download.sh` - this script downloads and extracts the HPatches Sequences dataset and removes the sequences containing high resolution images (`> 1600x1200`) as mentioned in the D2-Net paper. You can also download the cache with results for all methods from the D2-Net paper by running `bash download_cache.sh`. 6 | 7 | New methods can be added in cell 4 of the notebook. 
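For orientation, here is a minimal sketch that writes features in the expected on-disk format (detailed next); the method name, image path, keypoint count, and descriptor size are hypothetical:

```python
import numpy as np

keypoints = np.random.rand(512, 2) * [640.0, 480.0]   # x, y (COLMAP convention)
scores = np.random.rand(512)                          # higher is better
descriptors = np.random.rand(512, 128).astype(np.float32)
descriptors /= np.linalg.norm(descriptors, axis=1, keepdims=True)  # L2-normalize

# Save alongside the image, using the method name as the file extension.
# Passing an open file handle stops np.savez from appending ".npz".
path = 'hpatches-sequences-release/i_ajuntament/1.ppm.my-method'
with open(path, 'wb') as f:
    np.savez(f, keypoints=keypoints, scores=scores, descriptors=descriptors)
```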
The local features are supposed to be stored in the [`npz`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html) format with three fields: 8 | 9 | - `keypoints` - `N x 2` matrix with `x, y` coordinates of each keypoint in COLMAP format (the `X` axis points to the right, the `Y` axis to the bottom) 10 | 11 | - `scores` - `N` array with detection scores for each keypoint (higher is better) - only required for the "top K" version of the benchmark 12 | 13 | - `descriptors` - `N x D` matrix with the descriptors (L2 normalized if you plan on using the provided mutual nearest neighbors matcher) 14 | 15 | Moreover, the `npz` files are supposed to be saved alongside their corresponding images with the same extension as the `method` (e.g. if `method = d2-net`, the features for the image `hpatches-sequences-release/i_ajuntament/1.ppm` should be in the file `hpatches-sequences-release/i_ajuntament/1.ppm.d2-net`). 16 | 17 | We provide a simple script to extract Hessian Affine keypoints with SIFT descriptors (`extract_hesaff.m`); this script requires MATLAB and [VLFeat](http://www.vlfeat.org/). 18 | 19 | D2-Net features can be extracted by running: 20 | ``` 21 | python extract_features.py --image_list_file image_list_hpatches_sequences.txt 22 | ``` 23 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2020/08/25 3 | # 4 | import os 5 | import yaml 6 | from pathlib import Path 7 | 8 | import argparse 9 | import torch 10 | import numpy as np 11 | 12 | from utils.logger import get_logger 13 | from trainers import get_trainer 14 | 15 | 16 | def setup_seed(): 17 | # make the result reproducible 18 | torch.manual_seed(3928) 19 | torch.cuda.manual_seed_all(2342) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | np.random.seed(2933) 23 | 24 | 25 | def write_config(logger, prefix, config): 26 | for k, v in config.items(): 27 | if isinstance(v, dict): 28 | logger.info('{}: '.format(k)) 29 | write_config(logger, prefix+' '*4, v) 30 | else: 31 | logger.info('{}{}: {}'.format(prefix, k, v)) 32 | 33 | 34 | def main(): 35 | setup_seed() 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--gpus', type=str, default='0') 38 | parser.add_argument('--configs', type=str, required=True) 39 | parser.add_argument('--indicator', type=str, required=True) 40 | args = parser.parse_args() 41 | 42 | # read configs 43 | with open(args.configs, 'r') as f: 44 | config = yaml.load(f) 45 | 46 | # initialize ckpt_path 47 | ckpt_path = Path('ckpt', config['name']+'_'+args.indicator) 48 | ckpt_path.mkdir(parents=True, exist_ok=True) 49 | config['ckpt_path'] = str(ckpt_path) 50 | 51 | # initialize logger 52 | log_path = Path('log', config['name']+'_'+args.indicator) 53 | log_path.mkdir(parents=True, exist_ok=True) 54 | config['logger'] = get_logger(str(log_path)) 55 | 56 | # write config 57 | write_config(config['logger'], '', config) 58 | 59 | # set gpu devices 60 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 61 | args.gpus = [i for i in range(len(args.gpus.split(',')))] 62 | config['logger'].info("Set CUDA_VISIBLE_DEVICES to %s" % args.gpus) 63 | 64 | # initialize trainer and train 65 | with get_trainer(config['trainer'])(**config) as trainer: 66 | trainer.train() 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # AWDesc (local feature detection and description)
2 | 
3 | Implementation of Attention Weighted Local Descriptors (TPAMI 2023).
4 | 
5 | Also includes an unofficial PyTorch implementation of SuperPoint.
6 | 
7 | To do:
8 | - [x] Evaluation code and trained models for AWDesc
9 | - [x] Training code
10 | - [x] Training code for SuperPoint
11 | - [ ] More detailed readme (coming soon)
12 | 
13 | # Requirements
14 | ```
15 | pip install -r requirements.txt
16 | ```
17 | 
18 | # Quick start
19 | HPatches Image Matching Benchmark
20 | 
21 | 1. Download the trained models:
22 | 
23 | AWDesc_CA:
24 | 
25 | https://drive.google.com/file/d/1qrvdd3KVYFl6EwH8s5IS5p_Hs26xIKRD/view?usp=sharing
26 | 
27 | AWDesc_Tiny:
28 | 
29 | https://drive.google.com/drive/folders/1PGHiGojkE7qCp1T-l9JSn4aJ7gN0_ua6?usp=sharing
30 | 
31 | and place them in "ckpt/mtldesc".
32 | 
33 | 
34 | 2. Download the HPatches dataset:
35 | 
36 | ```
37 | cd evaluation_hpatch/hpatches_sequences
38 | bash download.sh
39 | ```
40 | 3. Extract local descriptors:
41 | ```
42 | cd evaluation_hpatch
43 | CUDA_VISIBLE_DEVICES=0 python export.py --tag [Descriptor_suffix_name] --top-k 10000 --output_root [out_dir] --config ../configs/AWDesc_eva.yaml
44 | ```
45 | 4. Run the evaluation:
46 | ```
47 | cd evaluation_hpatch/hpatches_sequences
48 | jupyter-notebook
49 | 
50 | run HPatches-Sequences-Matching-Benchmark.ipynb
51 | ```
52 | 
53 | ## Training
54 | AWDesc-CA
55 | 
56 | Download the dataset: https://drive.google.com/file/d/1Uz0hVFPxWsE71V77kXZ973iY2GuXC20b/view?usp=sharing
57 | 
58 | Set the dataset paths in the configuration file configs/AWDesc_train_CA.yaml:
59 | 
60 | ```
61 | mega_image_dir: /data/Mega_train/image # images
62 | mega_keypoint_dir: /data/Mega_train/keypoint # keypoints
63 | mega_despoint_dir: /data/Mega_train/despoint # descriptor correspondence points
64 | ```
65 | ```
66 | python train.py --gpus 0 --configs configs/AWDesc_train_CA.yaml --indicator awdesc_ca
67 | ```
68 | 
69 | SuperPoint
70 | 
71 | Set the dataset paths in the configuration file configs/SuperPoint_train.yaml:
72 | 
73 | ```
74 | mega_image_dir: /data/Mega_train/image # images
75 | mega_keypoint_dir: /data/Mega_train/keypoint # keypoints
76 | mega_despoint_dir: /data/Mega_train/despoint # descriptor correspondence points
77 | ```
78 | ```
79 | python train.py --gpus 0 --configs configs/SuperPoint_train.yaml --indicator superpoint
80 | ```
81 | 
82 | 
83 | 
84 | AWDesc-Tiny
85 | 
86 | Download the dataset:
87 | https://pan.baidu.com/s/1-1rpNxYsNl5fVRKB6EWo4A?pwd=elcb
88 | 
89 | extraction code: elcb
90 | 
91 | Set the dataset paths in the configuration file configs/AWDesc_train_Tiny.yaml:
92 | ```
93 | mega_image_dir: /data/Mega_train/image # images
94 | mega_keypoint_dir: /data/Mega_train/keypoint # keypoints
95 | mega_despoint_dir: /data/Mega_train/despoint # knowledge distilled from teacher 0
96 | mega_dl_dir1: /data/Mega_train/dl_teacher0 # knowledge distilled from teacher 0
97 | mega_dl_dir2: /data/Mega_train/dl_teacher1 # knowledge distilled from teacher 1
98 | ```
99 | ```
100 | python train.py --gpus 0 --configs configs/AWDesc_train_Tiny.yaml --indicator awdesc_t16
101 | ```
--------------------------------------------------------------------------------
/evaluation_hpatch/utils/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2020/2/23
3 | #
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as f
7 | 
8 | 
9 | class Matcher(object):
10 | 
11 |     def __init__(self,
dtype='float'):
12 |         if dtype == 'float':
13 |             self.compute_desp_dist = self._compute_desp_dist
14 |         elif dtype == 'binary':
15 |             self.compute_desp_dist = self._compute_desp_dist_binary
16 |         else:
17 |             assert False
18 | 
19 |     def __call__(self, point_0, desp_0, point_1, desp_1):
20 |         dist_0_1 = self.compute_desp_dist(desp_0, desp_1)  # [n,m]
21 |         dist_1_0 = dist_0_1.transpose((1, 0))  # [m,n]
22 |         nearest_idx_0_1 = np.argmin(dist_0_1, axis=1)  # [n]
23 |         nearest_idx_1_0 = np.argmin(dist_1_0, axis=1)  # [m]
24 |         matched_src = []
25 |         matched_tgt = []
26 |         for i, idx_0_1 in enumerate(nearest_idx_0_1):  # keep mutual nearest neighbours only
27 |             if i == nearest_idx_1_0[idx_0_1]:
28 |                 matched_src.append(point_0[i])
29 |                 matched_tgt.append(point_1[idx_0_1])
30 |         if len(matched_src) <= 4:
31 |             print("Too few matches found")
32 |             # assert False
33 |             return None
34 |         if len(matched_src) != 0:
35 |             matched_src = np.stack(matched_src, axis=0)
36 |             matched_tgt = np.stack(matched_tgt, axis=0)
37 |         return matched_src, matched_tgt
38 | 
39 |     @staticmethod
40 |     def _compute_desp_dist(desp_0, desp_1):
41 |         # desp_0: [n,256], desp_1: [m,256]
42 |         square_norm_0 = (np.linalg.norm(desp_0, axis=1, keepdims=True)) ** 2  # [n,1]
43 |         square_norm_1 = (np.linalg.norm(desp_1, axis=1, keepdims=True).transpose((1, 0))) ** 2  # [1,m]
44 |         xty = np.matmul(desp_0, desp_1.transpose((1, 0)))  # [n,m]
45 |         dist = np.sqrt((square_norm_0 + square_norm_1 - 2 * xty + 1e-4))
46 |         return dist
47 | 
48 |     @staticmethod
49 |     def _compute_desp_dist_binary(desp_0, desp_1):
50 |         # desp_0: [n,256], desp_1: [m,256]
51 |         dist_0_1 = np.logical_xor(desp_0[:, np.newaxis, :], desp_1[np.newaxis, :, :]).sum(axis=2)
52 |         return dist_0_1
53 | 
54 | 
55 | def spatial_nms(prob, kernel_size=9):
56 |     """
57 |     Non-maximum suppression on the predicted keypoint probability map via max pooling.
58 |     Args:
59 |         prob: probability map of shape [h,w]
60 |         kernel_size: window size used when suppressing non-maxima around each point
61 | 
62 |     Returns:
63 |         the probability map after non-maximum suppression
64 |     """
65 |     padding = int(kernel_size//2)
66 |     pooled = f.max_pool2d(prob, kernel_size=kernel_size, stride=1, padding=padding)
67 |     prob = torch.where(torch.eq(prob, pooled), prob, torch.zeros_like(prob))
68 |     return prob
69 | 
70 | 
71 | def convert_cv2pt(cv_point):
72 |     point_list = []
73 |     for i, cv_pt in enumerate(cv_point):
74 |         pt = np.array((cv_pt.pt[1], cv_pt.pt[0]))  # (y, x) order
75 |         point_list.append(pt)
76 |     point = np.stack(point_list, axis=0)
77 |     return point
78 | 
79 | 
80 | def model_size(model):
81 |     ''' Computes the number of parameters of the model
82 |     '''
83 |     size = 0
84 |     for weights in model.state_dict().values():
85 |         size += np.prod(weights.shape)
86 |     return size
87 | 
88 | 
89 | def torch_set_gpu(gpus):
90 |     import os
91 |     if type(gpus) is int:
92 |         gpus = [gpus]
93 | 
94 |     cuda = all(gpu >= 0 for gpu in gpus)
95 | 
96 |     if cuda:
97 |         os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus])
98 |         assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % (
99 |             os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES'])
100 |         torch.backends.cudnn.benchmark = True  # speed-up cudnn
101 |         torch.backends.cudnn.fastest = True  # even more speed-up?
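        # (cudnn.benchmark lets cuDNN auto-tune convolution algorithms per input
        #  shape, which pays off when image sizes stay fixed across a run)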
102 | print( 'Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'] ) 103 | 104 | else: 105 | print( 'Launching on CPU' ) 106 | 107 | return cuda 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /evaluation_hpatch/utils/d2net_pyramid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils.d2net_utils import EmptyTensorError 6 | from utils.d2net_utils import interpolate_dense_features, upscale_positions 7 | 8 | 9 | def process_multiscale(image, model, scales=[.5, 1, 2]): 10 | b, _, h_init, w_init = image.size() 11 | device = image.device 12 | assert(b == 1) 13 | 14 | all_keypoints = torch.zeros([3, 0]) 15 | all_descriptors = torch.zeros([ 16 | model.dense_feature_extraction.num_channels, 0 17 | ]) 18 | all_scores = torch.zeros(0) 19 | 20 | previous_dense_features = None 21 | banned = None 22 | for idx, scale in enumerate(scales): 23 | current_image = F.interpolate( 24 | image, scale_factor=scale, 25 | mode='bilinear', align_corners=True 26 | ) 27 | _, _, h_level, w_level = current_image.size() 28 | 29 | dense_features = model.dense_feature_extraction(current_image) 30 | del current_image 31 | 32 | _, _, h, w = dense_features.size() 33 | 34 | # Sum the feature maps. 35 | if previous_dense_features is not None: 36 | dense_features += F.interpolate( 37 | previous_dense_features, size=[h, w], 38 | mode='bilinear', align_corners=True 39 | ) 40 | del previous_dense_features 41 | 42 | # Recover detections. 43 | detections = model.detection(dense_features) 44 | if banned is not None: 45 | banned = F.interpolate(banned.float(), size=[h, w]).bool() 46 | detections = torch.min(detections, ~banned) 47 | banned = torch.max( 48 | torch.max(detections, dim=1)[0].unsqueeze(1), banned 49 | ) 50 | else: 51 | banned = torch.max(detections, dim=1)[0].unsqueeze(1) 52 | fmap_pos = torch.nonzero(detections[0].cpu()).t() 53 | del detections 54 | 55 | # Recover displacements. 
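        # (the localization head predicts per-detection sub-pixel offsets; the
        #  mask further below keeps only offsets inside (-0.5, 0.5))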
56 | displacements = model.localization(dense_features)[0].cpu() 57 | displacements_i = displacements[ 58 | 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :] 59 | ] 60 | displacements_j = displacements[ 61 | 1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :] 62 | ] 63 | del displacements 64 | 65 | mask = torch.min( 66 | torch.abs(displacements_i) < 0.5, 67 | torch.abs(displacements_j) < 0.5 68 | ) 69 | fmap_pos = fmap_pos[:, mask] 70 | valid_displacements = torch.stack([ 71 | displacements_i[mask], 72 | displacements_j[mask] 73 | ], dim=0) 74 | del mask, displacements_i, displacements_j 75 | 76 | fmap_keypoints = fmap_pos[1 :, :].float() + valid_displacements 77 | del valid_displacements 78 | 79 | try: 80 | raw_descriptors, _, ids = interpolate_dense_features( 81 | fmap_keypoints.to(device), 82 | dense_features[0] 83 | ) 84 | except EmptyTensorError: 85 | continue 86 | fmap_pos = fmap_pos[:, ids] 87 | fmap_keypoints = fmap_keypoints[:, ids] 88 | del ids 89 | 90 | keypoints = upscale_positions(fmap_keypoints, scaling_steps=2) 91 | del fmap_keypoints 92 | 93 | descriptors = F.normalize(raw_descriptors, dim=0).cpu() 94 | del raw_descriptors 95 | 96 | keypoints[0, :] *= h_init / h_level 97 | keypoints[1, :] *= w_init / w_level 98 | 99 | fmap_pos = fmap_pos.cpu() 100 | keypoints = keypoints.cpu() 101 | 102 | keypoints = torch.cat([ 103 | keypoints, 104 | torch.ones([1, keypoints.size(1)]) * 1 / scale, 105 | ], dim=0) 106 | 107 | scores = dense_features[ 108 | 0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :] 109 | ].cpu() / (idx + 1) 110 | del fmap_pos 111 | 112 | all_keypoints = torch.cat([all_keypoints, keypoints], dim=1) 113 | all_descriptors = torch.cat([all_descriptors, descriptors], dim=1) 114 | all_scores = torch.cat([all_scores, scores], dim=0) 115 | del keypoints, descriptors 116 | 117 | previous_dense_features = dense_features 118 | del dense_features 119 | del previous_dense_features, banned 120 | 121 | keypoints = all_keypoints.t().numpy() 122 | del all_keypoints 123 | scores = all_scores.numpy() 124 | del all_scores 125 | descriptors = all_descriptors.t().numpy() 126 | del all_descriptors 127 | return keypoints, scores, descriptors 128 | -------------------------------------------------------------------------------- /nets/vit/vit_seg_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | def get_b16_config(): 4 | """Returns the ViT-B/16 configuration.""" 5 | config = ml_collections.ConfigDict() 6 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 7 | config.hidden_size = 768 8 | config.transformer = ml_collections.ConfigDict() 9 | config.transformer.mlp_dim = 3072 10 | config.transformer.num_heads = 12 11 | config.transformer.num_layers = 12 12 | config.transformer.attention_dropout_rate = 0.0 13 | config.transformer.dropout_rate = 0.1 14 | 15 | config.classifier = 'seg' 16 | config.representation_size = None 17 | config.resnet_pretrained_path = None 18 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' 19 | config.patch_size = 16 20 | 21 | config.decoder_channels = (256, 128, 64, 16) 22 | config.n_classes = 2 23 | config.activation = 'softmax' 24 | return config 25 | 26 | def get_testing(): 27 | """Returns a minimal configuration for testing.""" 28 | config = ml_collections.ConfigDict() 29 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 30 | config.hidden_size = 1 31 | config.transformer = ml_collections.ConfigDict() 32 | config.transformer.mlp_dim = 1 
33 | config.transformer.num_heads = 1 34 | config.transformer.num_layers = 1 35 | config.transformer.attention_dropout_rate = 0.0 36 | config.transformer.dropout_rate = 0.1 37 | config.classifier = 'token' 38 | config.representation_size = None 39 | return config 40 | 41 | def get_r50_b16_config(): 42 | """Returns the Resnet50 + ViT-B/16 configuration.""" 43 | config = get_b16_config() 44 | config.patches.grid = (16, 16) 45 | config.resnet = ml_collections.ConfigDict() 46 | config.resnet.num_layers = (3, 4, 9) 47 | config.resnet.width_factor = 1 48 | 49 | config.classifier = 'seg' 50 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 51 | config.decoder_channels = (256, 128, 64, 16) 52 | config.skip_channels = [512, 256, 64, 16] 53 | config.n_classes = 2 54 | config.n_skip = 3 55 | config.activation = 'softmax' 56 | 57 | return config 58 | 59 | 60 | def get_b32_config(): 61 | """Returns the ViT-B/32 configuration.""" 62 | config = get_b16_config() 63 | config.patches.size = (32, 32) 64 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' 65 | return config 66 | 67 | 68 | def get_l16_config(): 69 | """Returns the ViT-L/16 configuration.""" 70 | config = ml_collections.ConfigDict() 71 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 72 | config.hidden_size = 128 73 | config.transformer = ml_collections.ConfigDict() 74 | config.transformer.mlp_dim = 2048 75 | config.transformer.num_heads = 8 76 | config.transformer.num_layers = 12 77 | config.transformer.attention_dropout_rate = 0.0 78 | config.transformer.dropout_rate = 0.1 79 | config.representation_size = None 80 | 81 | # custom 82 | config.classifier = 'seg' 83 | config.resnet_pretrained_path = None 84 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' 85 | config.decoder_channels = (256, 128, 64, 16) 86 | config.n_classes = 2 87 | config.activation = 'softmax' 88 | return config 89 | 90 | 91 | def get_r50_l16_config(): 92 | """Returns the Resnet50 + ViT-L/16 configuration. 
customized """
93 |     config = get_l16_config()
94 |     config.patches.grid = (16, 16)
95 |     config.resnet = ml_collections.ConfigDict()
96 |     config.resnet.num_layers = (3, 4, 9)
97 |     config.resnet.width_factor = 1
98 | 
99 |     config.classifier = 'seg'
100 |     config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
101 |     config.decoder_channels = (256, 128, 64, 16)
102 |     config.skip_channels = [512, 256, 64, 16]
103 |     config.n_classes = 2
104 |     config.activation = 'softmax'
105 |     return config
106 | 
107 | 
108 | def get_l32_config():
109 |     """Returns the ViT-L/32 configuration."""
110 |     config = get_l16_config()
111 |     config.patches.size = (32, 32)
112 |     return config
113 | 
114 | 
115 | def get_h14_config():
116 |     """Returns the ViT-H/14 configuration."""
117 |     config = ml_collections.ConfigDict()
118 |     config.patches = ml_collections.ConfigDict({'size': (14, 14)})
119 |     config.hidden_size = 1280
120 |     config.transformer = ml_collections.ConfigDict()
121 |     config.transformer.mlp_dim = 5120
122 |     config.transformer.num_heads = 16
123 |     config.transformer.num_layers = 32
124 |     config.transformer.attention_dropout_rate = 0.0
125 |     config.transformer.dropout_rate = 0.1
126 |     config.classifier = 'token'
127 |     config.representation_size = None
128 | 
129 |     return config
130 | 
--------------------------------------------------------------------------------
/trainers/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2020/8/31
3 | #
4 | from collections import defaultdict
5 | 
6 | import cv2
7 | import torch
8 | import torch.nn.functional as f
9 | from torch.optim.lr_scheduler import _LRScheduler
10 | from PIL import Image
11 | import numpy as np
12 | 
13 | 
14 | class PolynomialLR(_LRScheduler):
15 |     def __init__(self, optimizer, step_size, iter_max, power, last_epoch=-1):
16 |         self.step_size = step_size
17 |         self.iter_max = iter_max
18 |         self.power = power
19 |         super(PolynomialLR, self).__init__(optimizer, last_epoch)
20 | 
21 |     def polynomial_decay(self, lr):
22 |         return lr * (1 - float(self.last_epoch) / self.iter_max) ** self.power
23 | 
24 |     def get_lr(self):
25 |         if (
26 |             (self.last_epoch == 0)
27 |             or (self.last_epoch % self.step_size != 0)
28 |             or (self.last_epoch > self.iter_max)
29 |         ):
30 |             return [group["lr"] for group in self.optimizer.param_groups]
31 |         return [self.polynomial_decay(lr) for lr in self.base_lrs]
32 | 
33 | 
34 | def resize_labels(labels, size):
35 |     """
36 |     Downsample labels for 0.5x and 0.75x logits by PIL nearest interpolation.
37 |     Other nearest-neighbour resizers produce misaligned labels, e.g.:
38 |     -> F.interpolate(labels, shape, mode='nearest')
39 |     -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST)
40 |     """
41 |     new_labels = []
42 |     for label in labels:
43 |         label = label.float().numpy()
44 |         label = Image.fromarray(label).resize(size, resample=Image.NEAREST)
45 |         new_labels.append(np.asarray(label))
46 |     new_labels = torch.LongTensor(new_labels)
47 |     return new_labels
48 | 
49 | 
50 | def resize_labels_torch(label, size, mode='nearest'):
51 |     """
52 |     Similar to resize_labels, but operates directly on torch tensors.
53 |     size: expected shape [h,w]
54 |     label: expected shape [bt,h,w]
55 |     """
56 |     label = f.interpolate(label.unsqueeze(dim=1), size, mode=mode).squeeze(dim=1)
57 |     return label
58 | 
59 | 
60 | def _fast_hist(label_true, label_pred, n_class):
61 |     mask = (label_true >= 0) & (label_true < n_class)
62 |     hist = np.bincount(
63 |         n_class * label_true[mask].astype(int) + label_pred[mask],
64 |         minlength=n_class ** 2,
65 |     ).reshape(n_class, n_class)
66 |     return hist
67 | 
68 | 
69 | def scores(label_trues, label_preds, n_class):
70 |     hist = np.zeros((n_class, n_class))
71 |     for lt, lp in zip(label_trues, label_preds):
72 |         hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
73 |     acc = np.diag(hist).sum() / hist.sum()
74 |     acc_cls = np.diag(hist) / hist.sum(axis=1)
75 |     acc_cls = np.nanmean(acc_cls)
76 |     iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
77 |     valid = hist.sum(axis=1) > 0  # added
78 |     mean_iu = np.nanmean(iu[valid])
79 |     freq = hist.sum(axis=1) / hist.sum()
80 |     fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
81 |     cls_iu = dict(zip(range(n_class), iu))
82 | 
83 |     return {
84 |         "Pixel Accuracy": acc,
85 |         "Mean Accuracy": acc_cls,
86 |         "Frequency Weighted IoU": fwavacc,
87 |         "Mean IoU": mean_iu,
88 |         # "Class IoU": cls_iu,
89 |     }
90 | 
91 | 
92 | class DepthEvaluator(object):
93 | 
94 |     def __init__(self):
95 |         self.errors = defaultdict(list)
96 |         self.min_depth = 1e-3
97 |         self.max_depth = 70
98 | 
99 |     def reset(self):
100 |         self.errors = defaultdict(list)
101 | 
102 |     def val(self):
103 |         error = {}
104 |         for k, v in self.errors.items():
105 |             error[k] = np.stack(v).mean()
106 |         return error
107 | 
108 |     def eval(self, pred_depth, gt_depth):
109 |         gt_depth = np.clip(gt_depth, a_min=None, a_max=self.max_depth)
110 | 
111 |         # mask = np.logical_and(gt_depth > self.min_depth, gt_depth <= self.max_depth)
112 |         mask = gt_depth > self.min_depth
113 | 
114 |         scalor = np.median(gt_depth[mask]) / np.median(pred_depth[mask])  # median scaling aligns the scale-ambiguous prediction with GT
115 |         pred_depth[mask] *= scalor
116 | 
117 |         pred_depth[pred_depth < self.min_depth] = self.min_depth
118 |         pred_depth[pred_depth > self.max_depth] = self.max_depth
119 | 
120 |         errors = self.compute_errors(gt_depth[mask], pred_depth[mask])
121 |         for k, v in errors.items():
122 |             self.errors[k].append(v)
123 | 
124 |     @staticmethod
125 |     def compute_errors(gt, pred):
126 |         thresh = np.maximum((gt / pred), (pred / gt))
127 |         a1 = (thresh < 1.25).mean()
128 |         a2 = (thresh < 1.25 ** 2).mean()
129 |         a3 = (thresh < 1.25 ** 3).mean()
130 | 
131 |         rmse = (gt - pred) ** 2
132 |         rmse = np.sqrt(rmse.mean())
133 | 
134 |         rmse_log = (np.log(gt) - np.log(pred)) ** 2
135 |         rmse_log = np.sqrt(rmse_log.mean())
136 | 
137 |         abs_rel = np.abs((gt - pred) / gt)
138 |         abs_rel = abs_rel.mean()
139 | 
140 |         sq_rel = (gt - pred) ** 2 / gt
141 |         sq_rel = sq_rel.mean()
142 | 
143 |         return {
144 |             'abs_rel': abs_rel,
145 |             'sq_rel': sq_rel,
146 |             'rmse': rmse,
147 |             'rmse_log': rmse_log,
148 |             'a1': a1,
149 |             'a2': a2,
150 |             'a3': a3,
151 |         }
152 | 
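# A minimal usage sketch for DepthEvaluator (hypothetical inputs: pred_depth and
# gt_depth are HxW numpy depth maps):
#
#     evaluator = DepthEvaluator()
#     for pred_depth, gt_depth in frames:
#         evaluator.eval(pred_depth, gt_depth)
#     print(evaluator.val())  # mean abs_rel/sq_rel/rmse/rmse_log/a1-a3 over frames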
153 | 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /evaluation_hpatch/utils/d2net_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | 8 | class EmptyTensorError(Exception): 9 | pass 10 | 11 | 12 | class NoGradientError(Exception): 13 | pass 14 | 15 | 16 | def preprocess_image(image, preprocessing=None): 17 | image = image.astype(np.float32) 18 | image = np.transpose(image, [2, 0, 1]) 19 | if preprocessing is None: 20 | pass 21 | elif preprocessing == 'caffe': 22 | # RGB -> BGR 23 | image = image[:: -1, :, :] 24 | # Zero-center by mean pixel 25 | mean = np.array([103.939, 116.779, 123.68]) 26 | image = image - mean.reshape([3, 1, 1]) 27 | elif preprocessing == 'torch': 28 | image /= 255.0 29 | mean = np.array([0.485, 0.456, 0.406]) 30 | std = np.array([0.229, 0.224, 0.225]) 31 | image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1]) 32 | else: 33 | raise ValueError('Unknown preprocessing parameter.') 34 | return image 35 | 36 | 37 | def imshow_image(image, preprocessing=None): 38 | if preprocessing is None: 39 | pass 40 | elif preprocessing == 'caffe': 41 | mean = np.array([103.939, 116.779, 123.68]) 42 | image = image + mean.reshape([3, 1, 1]) 43 | # RGB -> BGR 44 | image = image[:: -1, :, :] 45 | elif preprocessing == 'torch': 46 | mean = np.array([0.485, 0.456, 0.406]) 47 | std = np.array([0.229, 0.224, 0.225]) 48 | image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1]) 49 | image *= 255.0 50 | else: 51 | raise ValueError('Unknown preprocessing parameter.') 52 | image = np.transpose(image, [1, 2, 0]) 53 | image = np.round(image).astype(np.uint8) 54 | return image 55 | 56 | 57 | def grid_positions(h, w, device, matrix=False): 58 | lines = torch.arange( 59 | 0, h, device=device 60 | ).view(-1, 1).float().repeat(1, w) 61 | columns = torch.arange( 62 | 0, w, device=device 63 | ).view(1, -1).float().repeat(h, 1) 64 | if matrix: 65 | return torch.stack([lines, columns], dim=0) 66 | else: 67 | return torch.cat([lines.view(1, -1), columns.view(1, -1)], dim=0) 68 | 69 | 70 | def upscale_positions(pos, scaling_steps=0): 71 | for _ in range(scaling_steps): 72 | pos = pos * 2 + 0.5 73 | return pos 74 | 75 | 76 | def downscale_positions(pos, scaling_steps=0): 77 | for _ in range(scaling_steps): 78 | pos = (pos - 0.5) / 2 79 | return pos 80 | 81 | 82 | def interpolate_dense_features(pos, dense_features, return_corners=False): 83 | device = pos.device 84 | 85 | ids = torch.arange(0, pos.size(1), device=device) 86 | 87 | _, h, w = dense_features.size() 88 | 89 | i = pos[0, :] 90 | j = pos[1, :] 91 | 92 | # Valid corners 93 | i_top_left = torch.floor(i).long() 94 | j_top_left = torch.floor(j).long() 95 | valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0) 96 | 97 | i_top_right = torch.floor(i).long() 98 | j_top_right = torch.ceil(j).long() 99 | valid_top_right = torch.min(i_top_right >= 0, j_top_right < w) 100 | 101 | i_bottom_left = torch.ceil(i).long() 102 | j_bottom_left = torch.floor(j).long() 103 | valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0) 104 | 105 | i_bottom_right = torch.ceil(i).long() 106 | j_bottom_right = torch.ceil(j).long() 107 | valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w) 108 | 109 | valid_corners = torch.min( 110 | torch.min(valid_top_left, valid_top_right), 111 | torch.min(valid_bottom_left, valid_bottom_right) 
112 | ) 113 | 114 | i_top_left = i_top_left[valid_corners] 115 | j_top_left = j_top_left[valid_corners] 116 | 117 | i_top_right = i_top_right[valid_corners] 118 | j_top_right = j_top_right[valid_corners] 119 | 120 | i_bottom_left = i_bottom_left[valid_corners] 121 | j_bottom_left = j_bottom_left[valid_corners] 122 | 123 | i_bottom_right = i_bottom_right[valid_corners] 124 | j_bottom_right = j_bottom_right[valid_corners] 125 | 126 | ids = ids[valid_corners] 127 | if ids.size(0) == 0: 128 | raise EmptyTensorError 129 | 130 | # Interpolation 131 | i = i[ids] 132 | j = j[ids] 133 | dist_i_top_left = i - i_top_left.float() 134 | dist_j_top_left = j - j_top_left.float() 135 | w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) 136 | w_top_right = (1 - dist_i_top_left) * dist_j_top_left 137 | w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) 138 | w_bottom_right = dist_i_top_left * dist_j_top_left 139 | 140 | descriptors = ( 141 | w_top_left * dense_features[:, i_top_left, j_top_left] + 142 | w_top_right * dense_features[:, i_top_right, j_top_right] + 143 | w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] + 144 | w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right] 145 | ) 146 | 147 | pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0) 148 | 149 | if not return_corners: 150 | return [descriptors, pos, ids] 151 | else: 152 | corners = torch.stack([ 153 | torch.stack([i_top_left, j_top_left], dim=0), 154 | torch.stack([i_top_right, j_top_right], dim=0), 155 | torch.stack([i_bottom_left, j_bottom_left], dim=0), 156 | torch.stack([i_bottom_right, j_bottom_right], dim=0) 157 | ], dim=0) 158 | return [descriptors, pos, ids, corners] 159 | 160 | 161 | def savefig(filepath, fig=None, dpi=None): 162 | # TomNorway - https://stackoverflow.com/a/53516034 163 | if not fig: 164 | fig = plt.gcf() 165 | 166 | plt.subplots_adjust(0, 0, 1, 1, 0, 0) 167 | for ax in fig.axes: 168 | ax.axis('off') 169 | ax.margins(0, 0) 170 | ax.xaxis.set_major_locator(plt.NullLocator()) 171 | ax.yaxis.set_major_locator(plt.NullLocator()) 172 | 173 | fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi) 174 | -------------------------------------------------------------------------------- /evaluation_hpatch/export.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2020/6/27 3 | # 4 | from pathlib import Path 5 | import argparse 6 | 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from hpatch_related.hpatch_dataset import OrgHPatchDataset 12 | from models import get_model 13 | import cv2 as cv 14 | import torch 15 | import time 16 | def average_inference_time(time_collect): 17 | average_time = sum(time_collect) / len(time_collect) 18 | info = ('Average inference time: {}ms / {}fps'.format( 19 | round(average_time*1000), round(1/average_time)) 20 | ) 21 | print(info) 22 | return info 23 | 24 | def extract_multiscale(net, img, scale_f=2 ** 0.25, 25 | min_scale=0.125, max_scale=2.0, 26 | min_size=0, max_size=9999,top_k=10000, 27 | verbose=False): 28 | old_bm = torch.backends.cudnn.benchmark 29 | torch.backends.cudnn.benchmark = False 30 | H, W,three= img.shape 31 | shape=img.shape 32 | assert three == 3, "should be a batch with a single RGB image" 33 | assert max_scale <= 2 34 | s = max_scale # current scale factor 35 | 36 | X, Y, S, C, Q, D = [], [], [], [], [], [] 37 | while s + 0.001 >= max(min_scale, min_size / max(H, W)): 38 | if s - 0.001 <= min(max_scale, max_size / max(H, W)): 39 | nh = 
img.shape[0]
40 |             nw = img.shape[1]
41 |             if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
42 |             # extract descriptors
43 | 
44 |             with torch.no_grad():
45 |                 res = net.predict(img=img)
46 |                 x = res['keypoints'][:, 0]
47 |                 y = res['keypoints'][:, 1]
48 |                 d = res['descriptors']
49 |                 scores = res['scores']
50 | 
51 |             X.append(x * W / nw)
52 |             Y.append(y * H / nh)
53 |             C.append(scores)
54 |             D.append(d)
55 | 
56 |         s /= scale_f
57 | 
58 |         # down-scale the image for the next iteration
59 |         nh, nw = round(H * s), round(W * s)
60 |         img = cv.resize(img, dsize=(nw, nh), interpolation=cv.INTER_LINEAR)
61 |     torch.backends.cudnn.benchmark = old_bm
62 |     Y = np.hstack(Y)
63 |     X = np.hstack(X)
64 |     scores = np.hstack(C)
65 |     XY = np.stack([X, Y])
66 |     XY = np.swapaxes(XY, 0, 1)
67 |     D = np.vstack(D)
68 |     idxs = scores.argsort()[-top_k or None:]
69 |     predictions = {
70 |         "keypoints": XY[idxs],
71 |         "descriptors": D[idxs],
72 |         "scores": scores[idxs],
73 |         "shape": shape
74 |     }
75 | 
76 |     return predictions
77 | def extract_singlescale(net, img, top_k=10000, image_name=None):
78 |     old_bm = torch.backends.cudnn.benchmark
79 |     torch.backends.cudnn.benchmark = False
80 |     shape = img.shape
81 |     X, Y, S, C, Q, D = [], [], [], [], [], []
82 |     with torch.no_grad():
83 |         # res = net.predict(img=img, image_name=image_name)
84 |         res = net.predict(img=img)
85 |         x = res['keypoints'][:, 0]
86 |         y = res['keypoints'][:, 1]
87 |         d = res['descriptors']
88 |         scores = res['scores']
89 | 
90 |         X.append(x)
91 |         Y.append(y)
92 |         C.append(scores)
93 |         D.append(d)
94 |     torch.backends.cudnn.benchmark = old_bm
95 |     Y = np.hstack(Y)
96 |     X = np.hstack(X)
97 |     scores = np.hstack(C)
98 |     XY = np.stack([X, Y])
99 |     XY = np.swapaxes(XY, 0, 1)
100 |     D = np.vstack(D)
101 |     idxs = scores.argsort()[-top_k or None:]
102 |     predictions = {
103 |         "keypoints": XY[idxs],
104 |         "descriptors": D[idxs],
105 |         "scores": scores[idxs],
106 |         "shape": shape
107 |     }
108 | 
109 |     return predictions
110 | 
111 | if __name__ == '__main__':
112 |     parser = argparse.ArgumentParser()
113 |     parser.add_argument('--config', type=str, default='../configs/AWDesc_eva.yaml')
114 |     parser.add_argument('--single', type=bool, default=True)  # note: argparse's type=bool treats any non-empty string as True
115 |     parser.add_argument('--output_root', type=str, default='hpatches_sequences/hpatches-sequences-release')
116 |     parser.add_argument("--top-k", type=int, default=10000)
117 |     parser.add_argument("--scale-f", type=float, default=2 ** 0.25)
118 |     parser.add_argument("--min-size", type=int, default=0)
119 |     parser.add_argument("--max-size", type=int, default=9999)
120 |     parser.add_argument("--min-scale", type=float, default=0.3)
121 |     parser.add_argument("--max-scale", type=float, default=1)
122 |     parser.add_argument('--tag', type=str, default='mtldesc', required=True)
123 |     args = parser.parse_args()
124 | 
125 |     with open(args.config, 'r') as f:
126 |         config = yaml.load(f)
127 |     keys = '*' if config['keys'] == '*' else config['keys'].split(',')
128 | 
129 |     output_root = Path(args.output_root)
130 |     output_root.mkdir(parents=True, exist_ok=True)
131 | 
132 |     dataset = OrgHPatchDataset(**config['hpatches'])
133 |     # time_collect = []
134 |     with get_model(config['model']['name'])(**config['model']) as net:
135 |         for i, data in tqdm(enumerate(dataset)):
136 |             image_name = data['image_name']
137 |             folder_name = data['folder_name']
138 |             if args.single:
139 |                 start_time = time.time()
140 |                 predictions = extract_singlescale(net, data['image'], top_k=args.top_k, image_name=folder_name+'_'+image_name)
141 |                 # time_collect.append(time.time() -
start_time) 142 | else: 143 | predictions = extract_multiscale(net, data['image'], scale_f=args.scale_f, 144 | min_scale=args.min_scale, max_scale=args.max_scale, 145 | min_size=args.min_size, max_size=args.max_size,top_k=args.top_k,verbose=True) 146 | 147 | if config['output_type']=='benchmark': 148 | output_dir = Path(output_root,args.tag,folder_name) 149 | output_dir.mkdir(parents=True, exist_ok=True) 150 | outpath = Path(output_dir, image_name) 151 | np.savez(str(outpath), **predictions) 152 | else: 153 | output_dir = Path(output_root, folder_name) 154 | output_dir.mkdir(parents=True, exist_ok=True) 155 | outpath = Path(output_dir, image_name + '.ppm.' + args.tag) 156 | np.savez(open(outpath, 'wb'), **predictions) 157 | # info = average_inference_time(time_collect) 158 | # print(info) 159 | 160 | 161 | -------------------------------------------------------------------------------- /nets/vit/vit_seg_modeling_resnet_skip.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from os.path import join as pjoin 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def np2th(weights, conv=False): 12 | """Possibly convert HWIO to OIHW.""" 13 | if conv: 14 | weights = weights.transpose([3, 2, 0, 1]) 15 | return torch.from_numpy(weights) 16 | 17 | 18 | class StdConv2d(nn.Conv2d): 19 | 20 | def forward(self, x): 21 | w = self.weight 22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 23 | w = (w - m) / torch.sqrt(v + 1e-5) 24 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 25 | self.dilation, self.groups) 26 | 27 | 28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 30 | padding=1, bias=bias, groups=groups) 31 | 32 | 33 | def conv1x1(cin, cout, stride=1, bias=False): 34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 35 | padding=0, bias=bias) 36 | 37 | 38 | class PreActBottleneck(nn.Module): 39 | """Pre-activation (v2) bottleneck block. 40 | """ 41 | 42 | def __init__(self, cin, cout=None, cmid=None, stride=1): 43 | super().__init__() 44 | cout = cout or cin 45 | cmid = cmid or cout//4 46 | 47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 48 | self.conv1 = conv1x1(cin, cmid, bias=False) 49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 52 | self.conv3 = conv1x1(cmid, cout, bias=False) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | if (stride != 1 or cin != cout): 56 | # Projection also with pre-activation according to paper. 
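            # (a strided 1x1 conv aligns the shortcut's channels/resolution with
            #  the residual branch; gn_proj then group-normalizes the projection)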
57 | self.downsample = conv1x1(cin, cout, stride, bias=False) 58 | self.gn_proj = nn.GroupNorm(cout, cout) 59 | 60 | def forward(self, x): 61 | 62 | # Residual branch 63 | residual = x 64 | if hasattr(self, 'downsample'): 65 | residual = self.downsample(x) 66 | residual = self.gn_proj(residual) 67 | 68 | # Unit's branch 69 | y = self.relu(self.gn1(self.conv1(x))) 70 | y = self.relu(self.gn2(self.conv2(y))) 71 | y = self.gn3(self.conv3(y)) 72 | 73 | y = self.relu(residual + y) 74 | return y 75 | 76 | def load_from(self, weights, n_block, n_unit): 77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 80 | 81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 83 | 84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 86 | 87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 89 | 90 | self.conv1.weight.copy_(conv1_weight) 91 | self.conv2.weight.copy_(conv2_weight) 92 | self.conv3.weight.copy_(conv3_weight) 93 | 94 | self.gn1.weight.copy_(gn1_weight.view(-1)) 95 | self.gn1.bias.copy_(gn1_bias.view(-1)) 96 | 97 | self.gn2.weight.copy_(gn2_weight.view(-1)) 98 | self.gn2.bias.copy_(gn2_bias.view(-1)) 99 | 100 | self.gn3.weight.copy_(gn3_weight.view(-1)) 101 | self.gn3.bias.copy_(gn3_bias.view(-1)) 102 | 103 | if hasattr(self, 'downsample'): 104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 107 | 108 | self.downsample.weight.copy_(proj_conv_weight) 109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 111 | 112 | class ResNetV2(nn.Module): 113 | """Implementation of Pre-activation (v2) ResNet mode.""" 114 | 115 | def __init__(self, block_units, width_factor): 116 | super().__init__() 117 | width = int(64 * width_factor) 118 | self.width = width 119 | 120 | self.root = nn.Sequential(OrderedDict([ 121 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), 122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 123 | ('relu', nn.ReLU(inplace=True)), 124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 125 | ])) 126 | 127 | self.body = nn.Sequential(OrderedDict([ 128 | ('block1', nn.Sequential(OrderedDict( 129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 131 | ))), 132 | ('block2', nn.Sequential(OrderedDict( 133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 135 | ))), 136 | ('block3', nn.Sequential(OrderedDict( 137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 139 | ))), 140 | ])) 141 | 142 | def forward(self, x): 
143 |         features = []
144 |         b, c, in_size, _ = x.size()
145 |         x = self.root(x)
146 |         features.append(x)
147 |         x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
148 |         for i in range(len(self.body)-1):
149 |             x = self.body[i](x)
150 |             right_size = int(in_size / 4 / (i+1))
151 |             if x.size()[2] != right_size:
152 |                 pad = right_size - x.size()[2]
153 |                 assert pad < 3 and pad > 0, "x {} should have spatial size {}".format(x.size(), right_size)
154 |                 feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
155 |                 feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
156 |             else:
157 |                 feat = x
158 |             features.append(feat)
159 |         x = self.body[-1](x)
160 |         return x, features[::-1]
161 | 
--------------------------------------------------------------------------------
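A minimal usage sketch of the ResNetV2 backbone above, assuming the usual R50 hybrid setting of block_units=(3, 4, 9) and width_factor=1 from vit_seg_configs.py (the input size is illustrative):

import torch
from nets.vit.vit_seg_modeling_resnet_skip import ResNetV2

backbone = ResNetV2(block_units=(3, 4, 9), width_factor=1)
x = torch.randn(1, 3, 224, 224)              # RGB input
feat, skips = backbone(x)                    # final feature map plus skip features, deepest first
print(feat.shape, [s.shape for s in skips])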
/evaluation_hpatch/hpatch_related/hpatch_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2020/2/23
3 | #
4 | import os
5 | import glob
6 | import cv2 as cv
7 | import numpy as np
8 | from torch.utils.data import Dataset
9 | 
10 | 
11 | class HPatchDataset(Dataset):
12 | 
13 |     def __init__(self, **configs):
14 |         default_config = {
15 |             'dataset_dir': '',
16 |             'grayscale': False,
17 |             'resize': False,
18 |             'height': 240,
19 |             'width': 320,
20 |         }
21 |         default_config.update(configs)
22 | 
23 |         self.hpatch_height = default_config['height']
24 |         self.hpatch_width = default_config['width']
25 |         self.resize = default_config['resize']
26 |         if default_config['dataset_dir'] == '':
27 |             assert False, 'dataset_dir must be specified!'
28 |         self.dataset_dir = default_config['dataset_dir']
29 |         self.grayscale = default_config['grayscale']
30 | 
31 |         self.data_list = self._format_file_list()
32 | 
33 |     def __len__(self):
34 |         return len(self.data_list)
35 | 
36 |     def __getitem__(self, idx):
37 |         first_image_dir = self.data_list[idx]['first']
38 |         second_image_dir = self.data_list[idx]['second']
39 |         homo_dir = self.data_list[idx]['homo_dir']
40 |         image_type = self.data_list[idx]['type']
41 | 
42 |         if self.grayscale:
43 |             first_image = cv.imread(first_image_dir, cv.IMREAD_GRAYSCALE)
44 |             second_image = cv.imread(second_image_dir, cv.IMREAD_GRAYSCALE)
45 |         else:
46 |             first_image = cv.imread(first_image_dir)[:, :, ::-1].copy()  # convert bgr to rgb
47 |             second_image = cv.imread(second_image_dir)[:, :, ::-1].copy()
48 |         homo = np.loadtxt(homo_dir, dtype=float)  # np.float is deprecated in recent numpy
49 | 
50 |         org_first_shape = [np.shape(first_image)[0], np.shape(first_image)[1]]
51 |         org_second_shape = [np.shape(second_image)[0], np.shape(second_image)[1]]
52 |         if self.resize:
53 |             resize_shape = np.array((self.hpatch_height, self.hpatch_width), dtype=float)
54 |             first_scale = resize_shape / org_first_shape
55 |             second_scale = resize_shape / org_second_shape
56 |             homo = self._generate_adjust_homography(first_scale, second_scale, homo)
57 | 
58 |             first_image = cv.resize(first_image, (self.hpatch_width, self.hpatch_height), interpolation=cv.INTER_LINEAR)
59 |             second_image = cv.resize(second_image, (self.hpatch_width, self.hpatch_height),
60 |                                      interpolation=cv.INTER_LINEAR)
61 | 
62 |             first_shape = resize_shape
63 |             second_shape = resize_shape
64 |         else:
65 |             first_shape = np.array(org_first_shape)
66 |             second_shape = np.array(org_second_shape)
67 |         # scale is used to recover the location in the original image scale
68 | 
69 |         sample = {
70 |             'first_image': first_image, 'second_image': second_image,
71 |             'image_type': image_type, 'gt_homography': homo,
72 |             'first_shape': first_shape, 'second_shape': second_shape,
73 |         }
74 |         return sample
75 | 
76 |     @staticmethod
77 |     def _generate_adjust_homography(first_scale, second_scale, homography):
78 |         first_inv_scale_mat = np.diag((1. / first_scale[1], 1. / first_scale[0], 1))
79 |         second_scale_mat = np.diag((second_scale[1], second_scale[0], 1))
80 |         adjust_homography = np.matmul(second_scale_mat, np.matmul(homography, first_inv_scale_mat))
81 |         return adjust_homography
82 | 
83 |     def _format_file_list(self):
84 |         data_list = []
85 |         with open(os.path.join(self.dataset_dir, 'illumination_list.txt'), 'r') as ilf:
86 |             illumination_lines = ilf.readlines()
87 |             for line in illumination_lines:
88 |                 line = line[:-1]
89 |                 first_dir, second_dir, homo_dir = line.split(',')
90 |                 dir_slice = {'first': first_dir, 'second': second_dir, 'homo_dir': homo_dir, 'type': 'illumination'}
91 |                 data_list.append(dir_slice)
92 | 
93 |         with open(os.path.join(self.dataset_dir, 'viewpoint_list.txt'), 'r') as vf:
94 |             viewpoint_lines = vf.readlines()
95 |             for line in viewpoint_lines:
96 |                 line = line[:-1]
97 |                 first_dir, second_dir, homo_dir = line.split(',')
98 |                 dir_slice = {'first': first_dir, 'second': second_dir, 'homo_dir': homo_dir, 'type': 'viewpoint'}
99 |                 data_list.append(dir_slice)
100 | 
101 |         return data_list
102 | 
103 | 
104 | class OrgHPatchDataset(Dataset):
105 | 
106 |     def __init__(self, **configs):
107 |         default_config = {
108 |             'dataset_dir': '',
109 |             'grayscale': False,
110 |             'resize': False,
111 |             'height': 240,
112 |             'width': 320,
113 |         }
114 |         default_config.update(configs)
115 | 
116 |         self.hpatch_height = default_config['height']
117 |         self.hpatch_width = default_config['width']
118 |         self.resize = default_config['resize']
119 |         if default_config['dataset_dir'] == '':
120 |             assert False, 'dataset_dir must be specified!'
121 |         self.dataset_dir = default_config['dataset_dir']
122 |         self.grayscale = default_config['grayscale']
123 | 
124 |         self.data_list = self._format_file_list()
125 | 
126 |     def __len__(self):
127 |         return len(self.data_list)
128 | 
129 |     def __getitem__(self, idx):
130 |         image_dir = self.data_list[idx]['image_dir']
131 |         image_name = self.data_list[idx]['image_name']
132 |         folder_name = self.data_list[idx]['folder_name']
133 | 
134 |         if self.grayscale:
135 |             image = cv.imread(image_dir, cv.IMREAD_GRAYSCALE)
136 |         else:
137 |             image = cv.imread(image_dir)[:, :, ::-1].copy()  # convert bgr to rgb
138 | 
139 |         if self.resize:
140 |             image = cv.resize(image, (self.hpatch_width, self.hpatch_height), interpolation=cv.INTER_LINEAR)
141 | 
142 |         sample = {
143 |             'image': image,
144 |             'image_name': image_name,
145 |             'folder_name': folder_name,
146 |         }
147 | 
148 |         return sample
149 | 
150 |     def _format_file_list(self):
151 |         data_list = []
152 |         folder_list = os.listdir(self.dataset_dir)
153 |         for folder in folder_list:
154 |             images = glob.glob(os.path.join(self.dataset_dir, folder, "*.ppm"))
155 |             images = sorted(images)
156 |             for image in images:
157 |                 image_name = image.split('/')[-1].split('.')[0]
158 |                 data_list.append(
159 |                     {
160 |                         'image_dir': image,
161 |                         'image_name': image_name,
162 |                         'folder_name': folder,
163 |                     }
164 |                 )
165 | 
166 |         return data_list
167 | 
168 | 
169 | if __name__ == "__main__":
170 |     # # uncomment to generate the data list
171 |     # hpatch_dir = '/data/MegPoint/dataset/hpatch'
172 |     # generate_hpatch_data_list(hpatch_dir)
173 | 
174 |     pass
175 | 
176 | 
177 | 
178 | 
--------------------------------------------------------------------------------
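A minimal sketch of how HPatchDataset is typically consumed (the dataset_dir path is illustrative; the directory must contain the illumination_list.txt and viewpoint_list.txt files parsed above):

from torch.utils.data import DataLoader
from evaluation_hpatch.hpatch_related.hpatch_dataset import HPatchDataset

dataset = HPatchDataset(dataset_dir='/path/to/hpatch', resize=True, height=240, width=320)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for sample in loader:
    # the two images are related by the (resize-adjusted) homography sample['gt_homography']
    first, second = sample['first_image'], sample['second_image']
    break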
/trainers/base_trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2020/8/28
3 | #
4 | import time
5 | 
6 | import torch
7 | import cv2 as cv
8 | from tensorboardX import SummaryWriter
9 | 
10 | from utils.utils import Matcher
11 | from utils.evaluation_tools import *
12 | from data_utils import get_dataset
13 | 
14 | 
15 | class BaseTrainer(object):
16 | 
17 |     def __init__(self, **config):
18 |         self.config = config
19 |         self.logger = config['logger']
20 | 
21 |         if torch.cuda.is_available():
22 |             self.logger.info('gpu is available, set device to cuda !')
23 |             self.device = torch.device('cuda:0')
24 |             self.gpu_count = 1
25 |         else:
26 |             self.logger.info('gpu is not available, set device to cpu !')
27 |             self.device = torch.device('cpu')
28 |         self.multi_gpus = False
29 |         self.drop_last = False
30 |         if torch.cuda.device_count() > 1:
31 |             self.gpu_count = torch.cuda.device_count()
32 |             self.config['train']['batch_size'] *= self.gpu_count
33 |             self.multi_gpus = True
34 |             self.drop_last = True
35 |             self.logger.info("Multi gpus is available, let's use %d GPUS" % torch.cuda.device_count())
36 | 
37 |         # initialize the summary writer
38 |         self.summary_writer = SummaryWriter(self.config['ckpt_path'])
39 |         self._initialize_dataset()
40 |         self._initialize_model()
41 |         self._initialize_optimizer()
42 |         self._initialize_scheduler()
43 |         self._initialize_loss()
44 |         self._initialize_matcher()
45 | 
46 | 
47 |         self.logger.info("Initialize cat func: _cat_c1c2c3c4")
48 |         self.cat = self._cat_c1c2c3c4
49 | 
50 |     def __enter__(self):
51 |         return self
52 | 
53 |     def __exit__(self, *args):
54 |         pass
55 | 
56 |     def _inference_func(self, *args, **kwargs):
57 |         raise NotImplementedError
58 | 
59 |     def _initialize_dataset(self, *args, **kwargs):
60 |         raise NotImplementedError
61 | 
62 |     def _initialize_model(self, *args, **kwargs):
63 |         self.model = None
64 |         raise NotImplementedError
65 | 
66 |     def _initialize_optimizer(self, *args, **kwargs):
67 |         self.optimizer = None
68 |         raise NotImplementedError
69 | 
70 |     def _initialize_scheduler(self, *args, **kwargs):
71 |         self.scheduler = None
72 |         raise NotImplementedError
73 | 
74 |     def _train_func(self, *args, **kwargs):
75 |         raise NotImplementedError
76 | 
77 |     def _initialize_loss(self, *args, **kwargs):
78 |         raise NotImplementedError
79 | 
80 |     def _train_one_epoch(self, *args, **kwargs):
81 |         raise NotImplementedError
82 | 
83 |     def train(self):
84 |         start_time = time.time()
85 | 
86 |         # start training
87 |         for i in range(self.config['train']['epoch_num']):
88 | 
89 |             # train
90 |             self._train_one_epoch(i)
91 |             # break  # todo
92 | 
93 |             # validation
94 |             if i >= int(self.config['train']['epoch_num'] * self.config['train']['validate_after']):
95 |                 self._validate_one_epoch(i)
96 | 
97 |             if self.config['train']['adjust_lr']:
98 |                 # adjust learning rate
99 |                 self.scheduler.step(i)
100 | 
101 |         end_time = time.time()
102 |         self.logger.info("The whole training process takes %.3f h" % ((end_time - start_time)/3600))
103 | 
104 |     def _initialize_matcher(self):
105 |         # initialize the nearest-neighbor matcher
106 |         self.logger.info("Initialize matcher of Nearest Neighbor.")
107 |         self.general_matcher = Matcher('float')
108 | 
109 |     def _load_model_params(self, ckpt_file, previous_model):
110 |         if ckpt_file is None:
111 |             print("Please input correct checkpoint file dir!")
112 |             return False
113 | 
114 |         self.logger.info("Load pretrained model %s " % ckpt_file)
115 |         if not self.multi_gpus:
116 |             model_dict = previous_model.state_dict()
117 |             pretrain_dict = torch.load(ckpt_file, map_location=self.device)
118 |             pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
119 |             model_dict.update(pretrain_dict)
120 |             previous_model.load_state_dict(model_dict)
121 |         else:
122 |             model_dict = previous_model.module.state_dict()
123 |             pretrain_dict = torch.load(ckpt_file, map_location=self.device)
124 |             pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
125 |             model_dict.update(pretrain_dict)
126 |             previous_model.module.load_state_dict(model_dict)
127 |         return previous_model
128 | 
129 | 
130 | 
131 |     @staticmethod
132 |     def _compute_total_metric(illum_metric, view_metric):
133 |         illum_acc, illum_sum, illum_num = illum_metric.average()
134 |         view_acc, view_sum, view_num = view_metric.average()
135 |         return illum_acc, view_acc, (illum_sum+view_sum)/(illum_num+view_num+1e-4)
136 | 
137 |     @staticmethod
138 |     def _compute_match_outlier_distribution(illum_metric, view_metric):
139 |         illum_distribution = illum_metric.average_outlier()
140 |         view_distribution = view_metric.average_outlier()
141 |         return illum_distribution, view_distribution
142 | 
143 |     @staticmethod
144 |     def _cat_c1c2c3c4(c1, c2, c3, c4, dim):
145 |         return torch.cat((c1, c2, c3, c4), dim=dim)
146 | 
147 |     @staticmethod
148 |     def _cat_c2c3c4(c1, c2, c3, c4, dim):
149 |         return torch.cat((c2, c3, c4), dim=dim)
150 | 
151 |     @staticmethod
152 |     def _cat_c1c2c4(c1, c2, c3, c4, dim):
153 |         return torch.cat((c1, c2, c4), dim=dim)
154 | 
155 |     @staticmethod
156 |     def _cat_c1c3c4(c1, c2, c3, c4, dim):
157 |         return torch.cat((c1, c3, c4), dim=dim)
158 | 
159 |     @staticmethod
160 |     def _cat_c1c4(c1, c2, c3, c4, dim):
161 |         return torch.cat((c1, c4), dim=dim)
162 | 
163 |     @staticmethod
164 |     def _cat_c2c4(c1, c2, c3, c4, dim):
165 |         return torch.cat((c2, c4), dim=dim)
166 | 
167 |     @staticmethod
168 |     def _cat_c3c4(c1, c2, c3, c4, dim):
169 |         return torch.cat((c3, c4), dim=dim)
170 | 
171 |     @staticmethod
172 |     def _cat_c4(c1, c2, c3, c4, dim):
173 |         return c4
174 | 
175 |     @staticmethod
176 |     def _convert_pt2cv(point_list):
177 |         cv_point_list = []
178 | 
179 |         for i in range(len(point_list)):
180 |             cv_point = cv.KeyPoint()
181 |             cv_point.pt = tuple(point_list[i][::-1])
182 |             cv_point_list.append(cv_point)
183 | 
184 |         return cv_point_list
185 | 
186 |     @staticmethod
187 |     def _convert_pt2cv_np(point):
188 |         cv_point_list = []
189 |         for i in range(point.shape[0]):
190 |             cv_point = cv.KeyPoint()
191 |             cv_point.pt = tuple(point[i, ::-1])
192 |             cv_point_list.append(cv_point)
193 | 
194 |         return cv_point_list
195 | 
196 |     @staticmethod
197 |     def _convert_cv2pt(cv_point):
198 |         point_list = []
199 |         for i, cv_pt in enumerate(cv_point):
200 |             pt = np.array((cv_pt.pt[1], cv_pt.pt[0]))  # (y, x) order
201 |             point_list.append(pt)
202 |         point = np.stack(point_list, axis=0)
203 |         return point
204 | 
205 |     @staticmethod
206 |     def _compute_masked_loss(unmasked_loss, mask):
207 |         total_num = torch.sum(mask, dim=(1, 2))
208 |         loss = torch.sum(mask*unmasked_loss, dim=(1, 2)) / total_num
209 |         loss = torch.mean(loss)
210 |         return loss
211 | 
212 |     @staticmethod
213 |     def _convert_match2cv(first_point_list, second_point_list, sample_ratio=1.0):
214 |         cv_first_point = []
215 |         cv_second_point = []
216 |         cv_matched_list = []
217 | 
218 |         assert len(first_point_list) == len(second_point_list)
219 | 
220 |         inc = 1
221 |         if sample_ratio < 1:
222 |             inc = int(1.0 / sample_ratio)
223 | 
224 |         count = 0
225 |         if len(first_point_list) > 0:
226 |             for j in range(0, len(first_point_list), inc):
227 |                 cv_point = cv.KeyPoint()
228 |                 cv_point.pt = tuple(first_point_list[j][::-1])
229 |                 cv_first_point.append(cv_point)
230 | 
231 |                 cv_point = cv.KeyPoint()
232 |                 cv_point.pt = tuple(second_point_list[j][::-1])
233 |                 cv_second_point.append(cv_point)
234 | 
235 |                 cv_match = cv.DMatch()
236 |                 cv_match.queryIdx = count
237 |                 cv_match.trainIdx = count
238 |                 cv_matched_list.append(cv_match)
239 | 
240 |                 count += 1
241 | 
242 |         return cv_first_point, cv_second_point, cv_matched_list
243 | 
244 |     @staticmethod
245 |     def _generate_predict_point(prob, detection_threshold, scale=None, top_k=0):
246 |         point_idx = np.where(prob > detection_threshold)
247 | 
248 |         if len(point_idx[0]) == 0 or len(point_idx[1]) == 0:
249 |             point = np.empty((0, 2))
250 |             return point, 0
251 | 
252 |         prob = prob[point_idx]
253 |         sorted_idx = np.argsort(prob)[::-1]
254 |         if sorted_idx.shape[0] >= top_k:
255 |             sorted_idx = sorted_idx[:top_k]
256 | 
257 |         point = np.stack(point_idx, axis=1)  # [n,2]
258 |         top_k_point = []
259 |         for idx in sorted_idx:
260 |             top_k_point.append(point[idx])
261 | 
262 |         point = np.stack(top_k_point, axis=0)
263 |         point_num = point.shape[0]
264 | 
265 |         if scale is not None:
266 |             point = point*scale
267 |         return point, point_num
268 | 
269 |     @staticmethod
270 |     def _cvpoint2numpy(point_cv):
271 |         """Convert OpenCV keypoints to a numpy array."""
272 |         point_list = []
273 |         for pt_cv in point_cv:
274 |             point = np.array((pt_cv.pt[1], pt_cv.pt[0]))
275 |             point_list.append(point)
276 |         point_np = np.stack(point_list, axis=0)
277 |         return point_np
278 | 
279 | 
280 | 
281 | 
282 | 
283 | 
284 | 
285 | 
286 | 
287 | 
288 | 
--------------------------------------------------------------------------------
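A minimal sketch of the keypoint extraction implemented by BaseTrainer._generate_predict_point above: threshold a detector heatmap, keep the top_k strongest responses, and return the points in (y, x) order (the random heatmap is purely illustrative):

import numpy as np

prob = np.random.rand(240, 320)                        # detector heatmap in [0, 1]
point, point_num = BaseTrainer._generate_predict_point(
    prob, detection_threshold=0.9, top_k=1000)         # point: [n,2], (y, x)
print(point_num, point.shape)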
/trainers/mtldesc_trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2019/9/18
3 | #
4 | # base trainer implementation
5 | import os
6 | import time
7 | 
8 | import torch
9 | import torch.nn.functional as f
10 | from torch.utils.data import DataLoader
11 | 
12 | from nets import get_model
13 | from data_utils import get_dataset
14 | from trainers.base_trainer import BaseTrainer
15 | from utils.utils import spatial_nms
16 | from utils.utils import AttentionWeightedTripletLoss
17 | from utils.utils import PointHeatmapWeightedBCELoss
18 | 
19 | class MTLDescTrainer(BaseTrainer):
20 | 
21 |     def __init__(self, **config):
22 |         super(MTLDescTrainer, self).__init__(**config)
23 | 
24 |     def _initialize_dataset(self):
25 |         # initialize the dataset
26 |         self.logger.info('Initialize {}'.format(self.config['train']['dataset']))
27 |         self.train_dataset = get_dataset(self.config['train']['dataset'])(**self.config['train'])
28 | 
29 |         self.train_dataloader = DataLoader(
30 |             dataset=self.train_dataset,
31 |             batch_size=self.config['train']['batch_size'],
32 |             shuffle=True,
33 |             num_workers=self.config['train']['num_workers'],
34 |             drop_last=True
35 |         )
36 |         self.epoch_length = len(self.train_dataset) // self.config['train']['batch_size']
37 | 
38 |     def _initialize_model(self):
39 |         self.logger.info("Initialize network arch {}".format(self.config['model']['backbone']))
40 |         model = get_model(self.config['model']['backbone'])()
41 | 
42 |         '''
43 |         from torch.autograd import Variable as V
44 |         from thop import profile
45 |         input = torch.randn(1, 3, 640, 480).cuda()
46 |         input = V(input).to(self.device)
47 |         flops, params = profile(model.cuda(), inputs=(input,))
48 |         print("Number of flops: %.2fGFLOPs" % (flops / 1e9))
49 |         print("Number of parameter: %.2fM" % (params / 1e6))
50 |         exit(0)
51 |         '''
52 |         if self.multi_gpus:
53 |             model = torch.nn.DataParallel(model)
54 |         self.model = model.to(self.device)
55 | 
56 |     def _initialize_loss(self):
57 |         # initialize the loss functions
58 |         # initialize the heatmap loss
59 |         self.logger.info("Initialize the PointHeatmapWeightedBCELoss.")
60 |         self.point_loss = PointHeatmapWeightedBCELoss(weight=self.config['train']['point_loss_weight'])
61 | 
62 |         # initialize the descriptor loss
63 |         self.logger.info("Initialize the AttentionWeightedTripletLoss.")
64 |         self.descriptor_loss = AttentionWeightedTripletLoss(self.device, T=self.config['train']['T'])
65 |     def _initialize_optimizer(self):
66 |         # initialize the optimizer for network training
67 |         self.logger.info("Initialize Adam optimizer with weight_decay: {:.5f}.".format(self.config['train']['weight_decay']))
68 |         self.optimizer = torch.optim.Adam(
69 |             params=self.model.parameters(),
70 |             lr=self.config['train']['lr'],
71 |             weight_decay=self.config['train']['weight_decay'])
72 | 
73 |     def _initialize_scheduler(self):
74 | 
75 |         # initialize the learning-rate scheduler
76 |         if self.config['train']['lr_mod'] == 'LambdaLR':
77 |             self.logger.info("Initialize lr_scheduler of LambdaLR: (%d, %d)" % (self.config['train']['maintain_epoch'], self.config['train']['decay_epoch']))
78 |             def lambda_rule(epoch):
79 |                 lr_l = 1.0 - max(0, epoch - self.config['train']['maintain_epoch']) / float(self.config['train']['decay_epoch'] + 1)
80 |                 return lr_l
81 |             self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda_rule)
82 |         else:
83 |             milestones = [20, 30]
84 |             self.logger.info("Initialize lr_scheduler of MultiStepLR: (%d, %d)" % (milestones[0], milestones[1]))
85 |             self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=milestones, gamma=0.1)
86 | 
87 |     def _train_one_epoch(self, epoch_idx):
88 |         self.model.train()
89 | 
90 |         self.logger.info("-----------------------------------------------------")
91 |         self.logger.info("Training epoch %2d begin:" % epoch_idx)
92 | 
93 |         self._train_func(epoch_idx)
94 | 
95 |         self.logger.info("Training epoch %2d done." % epoch_idx)
96 |         self.logger.info("-----------------------------------------------------")
97 | 
98 |     def _train_func(self, epoch_idx):
99 |         self.model.train()
100 |         stime = time.time()
101 |         total_loss = 0
102 |         for i, data in enumerate(self.train_dataloader):
103 | 
104 |             # read the data
105 |             image = data["image"].to(self.device)
106 |             heatmap_gt = data['heatmap'].to(self.device)
107 |             point_mask = data['point_mask'].to(self.device)
108 |             desp_point = data["desp_point"].to(self.device)
109 | 
110 |             warped_image = data["warped_image"].to(self.device)
111 |             warped_heatmap_gt = data['warped_heatmap'].to(self.device)
112 |             warped_point_mask = data['warped_point_mask'].to(self.device)
113 |             warped_desp_point = data["warped_desp_point"].to(self.device)
114 | 
115 |             valid_mask = data["valid_mask"].to(self.device)
116 |             not_search_mask = data["not_search_mask"].to(self.device)
117 | 
118 |             image_pair = torch.cat((image, warped_image), dim=0)
119 | 
120 |             # model prediction
121 |             heatmap_pred_pair, feature, weight_map = self.model(image_pair)
122 | 
123 |             # compute the descriptor loss
124 |             desp_point_pair = torch.cat((desp_point, warped_desp_point), dim=0)
125 |             feature_pair = f.grid_sample(feature, desp_point_pair, mode="bilinear", padding_mode="border")
126 |             weight_pair = f.grid_sample(weight_map, desp_point_pair, mode="bilinear", padding_mode="border").squeeze(
127 |                 dim=1)
128 |             feature_pair = feature_pair[:, :, :, 0].transpose(1, 2)
129 |             desp_pair = feature_pair / torch.norm(feature_pair, p=2, dim=2, keepdim=True)  # L2 Normalization
130 |             weight_0, weight_1 = torch.chunk(weight_pair, 2, dim=0)
131 |             desp_0, desp_1 = torch.chunk(desp_pair, 2, dim=0)
132 |             desp_loss = self.descriptor_loss(desp_0, desp_1, weight_0, weight_1, valid_mask, not_search_mask)
133 |             # compute the keypoint loss
134 |             heatmap_gt_pair = torch.cat((heatmap_gt, warped_heatmap_gt), dim=0)
135 |             point_mask_pair = torch.cat((point_mask, warped_point_mask), dim=0)
136 |             point_loss = self.point_loss(heatmap_pred_pair[:, 0, :, :], heatmap_gt_pair, point_mask_pair)
137 | 
138 |             loss = desp_loss + point_loss
139 |             total_loss += loss
140 |             if torch.isnan(loss):
141 |                 self.logger.error('loss is nan!')
142 | 
143 |             self.optimizer.zero_grad()
144 | 
145 |             loss.backward()
146 | 
147 |             self.optimizer.step()
148 | 
149 |             if i % self.config['train']['log_freq'] == 0:
150 | 
151 |                 point_loss_val = point_loss.item()
152 |                 desp_loss_val = desp_loss.item()
153 |                 loss_val = loss.item()
154 | 
155 |                 self.logger.info(
156 |                     "[Epoch:%2d][Step:%5d:%5d]: loss = %.4f, point_loss = %.4f, desp_loss = %.4f"
157 |                     " one step cost %.4fs. " % (
158 |                         epoch_idx, i, self.epoch_length,
159 |                         loss_val,
160 |                         point_loss_val,
161 |                         desp_loss_val,
162 |                         (time.time() - stime) / self.config['train']['log_freq'],
163 |                     ))
164 |                 stime = time.time()
165 |         self.logger.info("Total_loss:" + str(total_loss.detach().cpu().numpy()))
166 |         # save the model
167 |         if self.multi_gpus:
168 |             torch.save(
169 |                 self.model.module.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx))
170 |         else:
171 |             torch.save(
172 |                 self.model.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx))
173 |     def _inference_func(self, image_pair):
174 |         """
175 |         image_pair: [2,1,h,w]
176 |         """
177 |         self.model.eval()
178 |         _, _, height, width = image_pair.shape
179 |         heatmap_pair, feature_pair, weightmap_pair = self.model(image_pair)
180 |         c1, c2 = torch.chunk(feature_pair, 2, dim=0)
181 |         w1, w2 = torch.chunk(weightmap_pair, 2, dim=0)
182 |         heatmap_pair = torch.sigmoid(heatmap_pair)
183 |         prob_pair = spatial_nms(heatmap_pair)
184 | 
185 |         prob_pair = prob_pair.detach().cpu().numpy()
186 |         first_prob = prob_pair[0, 0]
187 |         second_prob = prob_pair[1, 0]
188 | 
189 |         # get the predicted points
190 |         first_point, first_point_num = self._generate_predict_point(
191 |             first_prob,
192 |             detection_threshold=self.config['test']['detection_threshold'],
193 |             top_k=self.config['test']['top_k'])  # [n,2]
194 | 
195 |         second_point, second_point_num = self._generate_predict_point(
196 |             second_prob,
197 |             detection_threshold=self.config['test']['detection_threshold'],
198 |             top_k=self.config['test']['top_k'])  # [n,2]
199 | 
200 |         if first_point_num <= 4 or second_point_num <= 4:
201 |             print("skip this pair because there are too few points!")
202 |             return None
203 | 
204 |         # get the descriptors of the points
205 |         select_first_desp = self._generate_combined_descriptor_fast(first_point, c1, w1, height, width)
206 |         select_second_desp = self._generate_combined_descriptor_fast(second_point, c2, w2, height, width)
207 | 
208 |         return first_point, first_point_num, second_point, second_point_num, select_first_desp, select_second_desp
209 | 
210 |     def _generate_combined_descriptor_fast(self, point, feature, weight, height, width):
211 |         """
212 |         Construct descriptors from the feature map weighted by the attention map.
213 |         Args:
214 |             point: [n,2] in (y, x) order
215 |             feature, weight: the feature map and attention weight map (batch size 1)
216 |         Returns:
217 |             desp: [n,dim]
218 |         """
219 |         point = torch.from_numpy(point[:, ::-1].copy()).to(torch.float).to(self.device)
220 |         # normalize the sampling coordinates to [-1,1]
221 |         point = point * 2. / torch.tensor((width - 1, height - 1), dtype=torch.float, device=self.device) - 1
222 |         point = point.unsqueeze(dim=0).unsqueeze(dim=2)  # [1,n,1,2]
223 | 
224 |         feature = f.grid_sample(feature, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)[0]
225 |         weight = f.grid_sample(weight, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)[0]
226 |         desp_pair = feature / torch.norm(feature, p=2, dim=1, keepdim=True)
227 |         desp = desp_pair * weight.expand_as(desp_pair)
228 |         desp = desp.detach().cpu().numpy()
229 | 
230 |         return desp
231 | 
--------------------------------------------------------------------------------
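The LambdaLR branch used by both trainers keeps the learning rate constant for maintain_epoch epochs and then decays the multiplier linearly; a standalone sketch with illustrative config values:

maintain_epoch, decay_epoch = 10, 20                   # illustrative values

def lambda_rule(epoch):
    return 1.0 - max(0, epoch - maintain_epoch) / float(decay_epoch + 1)

print([round(lambda_rule(e), 2) for e in (0, 10, 20, 30)])   # [1.0, 1.0, 0.52, 0.05]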
/trainers/awdesc_trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | 
3 | import os
4 | import time
5 | 
6 | import torch
7 | import torch.nn.functional as f
8 | from torch.utils.data import DataLoader
9 | 
10 | from nets import get_model
11 | from data_utils import get_dataset
12 | from trainers.base_trainer import BaseTrainer
13 | from utils.utils import spatial_nms
14 | from utils.utils import AttentionWeightedTripletLoss, DescriptorGeneralTripletLoss, L1_loss, IMLoss, Cosine_Loss, Similarity
15 | from utils.utils import PointHeatmapWeightedBCELoss, PointHeatmapMSELoss, PointHeatmapL1Loss, PointHeatmapSigmodMSELoss
16 | 
17 | class AWDescTrainer(BaseTrainer):
18 | 
19 |     def __init__(self, **config):
20 |         super(AWDescTrainer, self).__init__(**config)
21 | 
22 |     def _initialize_dataset(self):
23 |         self.logger.info('Initialize {}'.format(self.config['train']['dataset']))
24 |         self.train_dataset = get_dataset(self.config['train']['dataset'])(**self.config['train'])
25 | 
26 |         self.train_dataloader = DataLoader(
27 |             dataset=self.train_dataset,
28 |             batch_size=self.config['train']['batch_size'],
29 |             shuffle=True,
30 |             num_workers=self.config['train']['num_workers'],
31 |             drop_last=True,
32 |             pin_memory=True,
33 |         )
34 |         self.epoch_length = len(self.train_dataset) // self.config['train']['batch_size']
35 | 
36 |     def _initialize_model(self):
37 |         self.logger.info("Initialize network arch {}".format(self.config['model']['backbone']))
38 |         model = get_model(self.config['model']['backbone'])()
39 | 
40 |         if self.multi_gpus:
41 |             model = torch.nn.DataParallel(model)
42 |         self.model = model.to(self.device)
43 | 
44 |     def _initialize_loss(self):
45 |         self.logger.info("Initialize the PointHeatmapWeightedBCELoss.")
46 |         self.point_loss = PointHeatmapWeightedBCELoss(weight=self.config['train']['point_loss_weight'])
47 |         self.heatmap_loss = PointHeatmapMSELoss()
48 |         self.sigmodmse = PointHeatmapSigmodMSELoss()
49 |         self.similarity = Similarity()
50 |         self.imloss = IMLoss()
51 |         self.cosloss = Cosine_Loss(self.device)
52 |         self.logger.info("Initialize the AttentionWeightedTripletLoss.")
53 |         self.descriptor_loss = AttentionWeightedTripletLoss(self.device, T=self.config['train']['T'])
54 |         self.desclss = DescriptorGeneralTripletLoss(self.device)
55 |     def _initialize_optimizer(self):
56 |         self.logger.info("Initialize Adam optimizer with weight_decay: {:.5f}.".format(self.config['train']['weight_decay']))
57 |         self.optimizer = torch.optim.Adam(
58 |             params=self.model.parameters(),
59 |             lr=self.config['train']['lr'],
60 |             weight_decay=self.config['train']['weight_decay'])
61 | 
62 |     def _initialize_scheduler(self):
63 |         if self.config['train']['lr_mod'] == 'LambdaLR':
64 |             self.logger.info("Initialize lr_scheduler of LambdaLR: (%d, %d)" % (self.config['train']['maintain_epoch'], self.config['train']['decay_epoch']))
65 |             def lambda_rule(epoch):
66 |                 lr_l = 1.0 - max(0, epoch - self.config['train']['maintain_epoch']) / float(self.config['train']['decay_epoch'] + 1)
67 |                 return lr_l
68 |             self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda_rule)
69 |         else:
70 |             milestones = [20, 30]
71 |             self.logger.info("Initialize lr_scheduler of MultiStepLR: (%d, %d)" % (milestones[0], milestones[1]))
72 |             self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=milestones, gamma=0.1)
73 | 
74 |     def _train_one_epoch(self, epoch_idx):
75 |         self.model.train()
76 | 
77 |         self.logger.info("-----------------------------------------------------")
78 |         self.logger.info("Training epoch %2d begin:" % epoch_idx)
79 | 
80 |         self._train_func(epoch_idx)
81 | 
82 |         self.logger.info("Training epoch %2d done." % epoch_idx)
83 |         self.logger.info("-----------------------------------------------------")
84 | 
85 |     def _train_func(self, epoch_idx):
86 |         self.model.train()
87 |         stime = time.time()
88 |         total_loss = 0
89 |         bord = 4
90 |         # fixed border mask for the distillation losses (assumes 400x400 inputs and 2 x 12 images per step)
91 |         self.mask = torch.ones(1, 1, 400 - 2 * bord, 400 - 2 * bord).to(self.device)
92 |         self.mask = f.pad(self.mask, (bord, bord, bord, bord), "constant", value=0).float().squeeze().repeat(24, 1, 1)
93 |         for i, data in enumerate(self.train_dataloader):
94 | 
95 |             image = data["image"].to(self.device)
96 |             heatmap_gt = data['heatmap'].to(self.device)
97 |             point_mask = data['point_mask'].to(self.device)
98 |             desp_point = data["desp_point"].to(self.device)
99 |             warped_image = data["warped_image"].to(self.device)
100 |             warped_heatmap_gt = data['warped_heatmap'].to(self.device)
101 |             warped_point_mask = data['warped_point_mask'].to(self.device)
102 |             warped_desp_point = data["warped_desp_point"].to(self.device)
103 |             valid_mask = data["valid_mask"].to(self.device)
104 |             not_search_mask = data["not_search_mask"].to(self.device)
105 |             dl_heatmap = data['dl_heatmap'].to(self.device)
106 |             warped_dl_heatmap = data['warped_dl_heatmap'].to(self.device)
107 |             dl_attmap = data['dl_attmap'].to(self.device)
108 |             warped_dl_attmap = data['warped_dl_attmap'].to(self.device)
109 |             dl_descriptor = data['dl_descriptor'].to(self.device)  # [:, :, :, 0].transpose(1, 2)
110 |             warped_dl_descriptor = data['warped_dl_descriptor'].to(self.device)  # [:, :, :, 0].transpose(1, 2)
111 | 
112 |             image_pair = torch.cat((image, warped_image), dim=0)
113 | 
114 |             heatmap_pred_pair, feature, weight_map = self.model(image_pair)
115 | 
116 |             desp_point_pair = torch.cat((desp_point, warped_desp_point), dim=0)
117 |             feature_pair = f.grid_sample(feature, desp_point_pair, mode="bilinear", padding_mode="border")
118 |             weight_pair = f.grid_sample(weight_map, desp_point_pair, mode="bilinear", padding_mode="border").squeeze(
119 |                 dim=1)
120 |             feature_pair = feature_pair[:, :, :, 0].transpose(1, 2)
121 |             desp_pair = feature_pair / torch.norm(feature_pair, p=2, dim=2, keepdim=True)  # L2 Normalization
122 |             weight_0, weight_1 = torch.chunk(weight_pair, 2, dim=0)
123 |             desp_0, desp_1 = torch.chunk(desp_pair, 2, dim=0)
124 | 
125 |             dl_desp_pair = torch.cat((dl_descriptor, warped_dl_descriptor), dim=0)
126 |             dl3 = self.cosloss(desp_pair, dl_desp_pair)
127 |             desp_loss = self.descriptor_loss(desp_0, desp_1, weight_0, weight_1, valid_mask, not_search_mask)
128 | 
129 | 
130 |             # compute the keypoint loss
131 |             heatmap_gt_pair = torch.cat((heatmap_gt, warped_heatmap_gt), dim=0)
132 |             dlheatmap_gt_pair = torch.cat((dl_heatmap, warped_dl_heatmap), dim=0)
133 |             dlattmap_gt_pair = torch.cat((dl_attmap, warped_dl_attmap), dim=0)
133 |             point_mask_pair = torch.cat((point_mask, warped_point_mask), dim=0)
134 |             point_loss = self.point_loss(heatmap_pred_pair[:, 0, :, :], heatmap_gt_pair, point_mask_pair)
135 |             dl1 = self.sigmodmse(heatmap_pred_pair[:, 0, :, :], dlheatmap_gt_pair, self.mask)  # *2
136 |             dl2 = self.sigmodmse(weight_map[:, 0, :, :], dlattmap_gt_pair, self.mask)
137 |             loss = desp_loss + point_loss + dl1 + dl2 + dl3
138 |             total_loss += loss
139 |             if torch.isnan(loss):
140 |                 self.logger.error('loss is nan!')
141 | 
142 |             self.optimizer.zero_grad()
143 | 
144 |             loss.backward()
145 | 
146 |             self.optimizer.step()
147 | 
148 |             if i % self.config['train']['log_freq'] == 0:
149 |                 lossterm3 = dl3.item()
150 |                 lossterm2 = dl2.item()
151 |                 lossterm1 = dl1.item()
152 |                 point_loss_val = point_loss.item()
153 |                 desp_loss_val = desp_loss.item()
154 |                 loss_val = loss.item()
155 | 
156 |                 self.logger.info(
157 |                     "[Epoch:%2d][Step:%5d:%5d]: loss = %.4f, point_loss = %.4f, desp_loss = %.4f, dl_heatmap = %.4f, dl_attmap = %.4f, dl_desc = %.4f"
158 |                     " one step cost %.4fs. " % (
159 |                         epoch_idx, i, self.epoch_length,
160 |                         loss_val,
161 |                         point_loss_val,
162 |                         desp_loss_val,
163 |                         lossterm1,
164 |                         lossterm2,
165 |                         lossterm3,
166 |                         (time.time() - stime) / self.config['train']['log_freq'],
167 |                     ))
168 |                 stime = time.time()
169 |         self.logger.info("Total_loss:" + str(total_loss.detach().cpu().numpy()))
170 |         # save the model
171 |         if self.multi_gpus:
172 |             torch.save(
173 |                 self.model.module.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx))
174 |         else:
175 |             torch.save(
176 |                 self.model.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx))
177 |     def _inference_func(self, image_pair):
178 |         """
179 |         image_pair: [2,1,h,w]
180 |         """
181 |         self.model.eval()
182 |         _, _, height, width = image_pair.shape
183 |         heatmap_pair, feature_pair, weightmap_pair = self.model(image_pair)
184 |         c1, c2 = torch.chunk(feature_pair, 2, dim=0)
185 |         w1, w2 = torch.chunk(weightmap_pair, 2, dim=0)
186 |         heatmap_pair = torch.sigmoid(heatmap_pair)
187 |         prob_pair = spatial_nms(heatmap_pair)
188 | 
189 |         prob_pair = prob_pair.detach().cpu().numpy()
190 |         first_prob = prob_pair[0, 0]
191 |         second_prob = prob_pair[1, 0]
192 | 
193 | 
194 |         first_point, first_point_num = self._generate_predict_point(
195 |             first_prob,
196 |             detection_threshold=self.config['test']['detection_threshold'],
197 |             top_k=self.config['test']['top_k'])  # [n,2]
198 | 
199 |         second_point, second_point_num = self._generate_predict_point(
200 |             second_prob,
201 |             detection_threshold=self.config['test']['detection_threshold'],
202 |             top_k=self.config['test']['top_k'])  # [n,2]
203 | 
204 |         if first_point_num <= 4 or second_point_num <= 4:
205 |             print("skip this pair because there are too few points!")
206 |             return None
207 | 
208 | 
209 |         select_first_desp = self._generate_combined_descriptor_fast(first_point, c1, w1, height, width)
210 |         select_second_desp = self._generate_combined_descriptor_fast(second_point, c2, w2, height, width)
211 | 
212 |         return first_point, first_point_num, second_point, second_point_num, select_first_desp, select_second_desp
213 | 
214 |     def _generate_combined_descriptor_fast(self, point, feature, weight, height, width):
215 |         point = torch.from_numpy(point[:, ::-1].copy()).to(torch.float).to(self.device)
216 |         point = point * 2. / torch.tensor((width - 1, height - 1), dtype=torch.float, device=self.device) - 1
217 |         point = point.unsqueeze(dim=0).unsqueeze(dim=2)  # [1,n,1,2]
218 | 
219 |         feature = f.grid_sample(feature, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)[0]
220 |         weight = f.grid_sample(weight, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)[0]
221 |         desp_pair = feature / torch.norm(feature, p=2, dim=1, keepdim=True)
222 |         desp = desp_pair * weight.expand_as(desp_pair)
223 |         desp = desp.detach().cpu().numpy()
224 | 
225 |         return desp
226 | 
227 | 
--------------------------------------------------------------------------------
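Both trainers build descriptors the same way: sample the feature and attention-weight maps at the keypoints, L2-normalize the features, then scale each descriptor by its attention weight. A self-contained sketch of that weighting step with dummy tensors:

import torch

feature = torch.randn(100, 128)                        # n raw sampled features
weight = torch.rand(100, 1)                            # per-point attention weights
desp = feature / torch.norm(feature, p=2, dim=1, keepdim=True)   # unit-length descriptors
desp = desp * weight.expand_as(desp)                   # attention-weighted descriptors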
/evaluation_hpatch/models/MTLDesc.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2020/2/25
3 | #
4 | import os
5 | import sys
6 | sys.path.append("..")
7 | import cv2 as cv
8 | from nets import get_model
9 | from nets.network import *
10 | 
11 | class Mtldesc(object):
12 | 
13 |     def __init__(self, **config):
14 |         self.name = 'MTLDesc'
15 |         self.config = {
16 |             "detection_threshold": 0.9,
17 |             "nms_dist": 4,
18 |             "dim": 128,
19 |             "nms_radius": 4,
20 |             "border_remove": 4,
21 |         }
22 |         self.config.update(config)
23 | 
24 |         self.detection_threshold = self.config["detection_threshold"]
25 |         self.nms_dist = self.config["nms_dist"]
26 | 
27 |         if torch.cuda.is_available():
28 |             print('gpu is available, set device to cuda !')
29 |             self.device = torch.device('cuda:0')
30 |             self.gpu_count = 1
31 |         else:
32 |             print('gpu is not available, set device to cpu !')
33 |             self.device = torch.device('cpu')
34 | 
35 |         # initialize the model
36 |         self.model_name = self.config['backbone'].split('.')[-1]
37 |         model = get_model(self.config['backbone'])()
38 |         self.model = model.to(self.device)
39 |         print("Initialize " + str(self.model_name))
40 | 
41 |         if self.config['ckpt_name'] == '':
42 |             assert False, 'ckpt_name must be specified!'
43 |         self.load(self.config['weight_path'], self.config['ckpt_name'], self.config['weights_id'])
44 | 
45 |     def _load_model_params(self, ckpt_file, previous_model):
46 |         if ckpt_file is None:
47 |             print("Please input correct checkpoint file dir!")
48 |             return False
49 | 
50 |         print("Load pretrained model %s " % ckpt_file)
51 | 
52 |         model_dict = previous_model.state_dict()
53 |         pretrain_dict = torch.load(ckpt_file, map_location=self.device)
54 |         model_dict.update(pretrain_dict)
55 |         previous_model.load_state_dict(model_dict)
56 |         return previous_model
57 | 
58 |     def load(self, weight_path, checkpoint_root, model_idx):
59 |         backbone_ckpt = os.path.join(weight_path, checkpoint_root, "model_" + str(model_idx) + ".pt")
60 |         self.model = self._load_model_params(backbone_ckpt, self.model)
61 |         total = sum([param.nelement() for param in self.model.parameters()])
62 | 
63 |     def load_split(self, model_ckpt, extractor_ckpt):
64 |         self.model = self._load_model_params(model_ckpt, self.model)
65 | 
66 | 
67 |     def _generate_predict_point(self, heatmap, height, width):
68 |         xs, ys = np.where(heatmap >= self.config['detection_threshold'])
69 |         pts = np.zeros((3, len(xs)))  # Populate point data sized 3xN.
70 |         if len(xs) > 0:
71 |             pts[0, :] = ys
72 |             pts[1, :] = xs
73 |             pts[2, :] = heatmap[xs, ys]
74 | 
75 |             if self.config['nms_radius']:
76 |                 pts, _ = self.nms_fast(
77 |                     pts, height, width, dist_thresh=self.config['nms_radius'])
78 |             inds = np.argsort(pts[2, :])
79 |             pts = pts[:, inds[::-1]]  # Sort by confidence.
80 | 
81 |             # Remove points along border.
82 |             bord = self.config['border_remove']
83 |             toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (width-bord))
84 |             toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (height-bord))
85 |             toremove = np.logical_or(toremoveW, toremoveH)
86 |             pts = pts[:, ~toremove]
87 |         pts = pts.transpose()
88 | 
89 |         point = pts[:, :2][:, ::-1]
90 |         score = pts[:, 2]
91 | 
92 |         return point, score
93 | 
94 |     def nms_fast(self, in_corners, H, W, dist_thresh):
95 |         """
96 |         Run a faster approximate Non-Max-Suppression on numpy corners shaped:
97 |           3xN [x_i,y_i,conf_i]^T
98 | 
99 |         Algo summary: Create a grid sized HxW. Assign each corner location a 1, rest
100 |         are zeros. Iterate through all the 1's and convert them either to -1 or 0.
101 |         Suppress points by setting nearby values to 0.
102 | 
103 |         Grid Value Legend:
104 |         -1 : Kept.
105 |          0 : Empty or suppressed.
106 |          1 : To be processed (converted to either kept or suppressed).
107 | 
108 |         NOTE: The NMS first rounds points to integers, so NMS distance might not
109 |         be exactly dist_thresh. It also assumes points are within image boundaries.
110 | 
111 |         Inputs
112 |           in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
113 |           H - Image height.
114 |           W - Image width.
115 |           dist_thresh - Distance to suppress, measured as an infinity norm distance.
116 |         Returns
117 |           nmsed_corners - 3xN numpy matrix with surviving corners.
118 |           nmsed_inds - N length numpy vector with surviving corner indices.
119 |         """
120 |         grid = np.zeros((H, W)).astype(int)  # Track NMS data.
121 |         inds = np.zeros((H, W)).astype(int)  # Store indices of points.
122 |         # Sort by confidence and round to nearest int.
123 |         inds1 = np.argsort(-in_corners[2, :])
124 |         corners = in_corners[:, inds1]
125 |         rcorners = corners[:2, :].round().astype(int)  # Rounded corners.
126 |         # Check for edge case of 0 or 1 corners.
127 |         if rcorners.shape[1] == 0:
128 |             return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
129 |         if rcorners.shape[1] == 1:
130 |             out = np.vstack((rcorners, in_corners[2])).reshape(3, 1)
131 |             return out, np.zeros((1)).astype(int)
132 |         # Initialize the grid.
133 |         for i, rc in enumerate(rcorners.T):
134 |             grid[rcorners[1, i], rcorners[0, i]] = 1
135 |             inds[rcorners[1, i], rcorners[0, i]] = i
136 |         # Pad the border of the grid, so that we can NMS points near the border.
137 |         pad = dist_thresh
138 |         grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant')
139 |         # Iterate through points, highest to lowest conf, suppress neighborhood.
140 |         count = 0
141 |         for i, rc in enumerate(rcorners.T):
142 |             # Account for top and left padding.
143 |             pt = (rc[0]+pad, rc[1]+pad)
144 |             if grid[pt[1], pt[0]] == 1:  # If not yet suppressed.
145 |                 grid[pt[1]-pad:pt[1]+pad+1, pt[0]-pad:pt[0]+pad+1] = 0
146 |                 grid[pt[1], pt[0]] = -1
147 |                 count += 1
148 |         # Get all surviving -1's and return sorted array of remaining corners.
149 |         keepy, keepx = np.where(grid == -1)
150 |         keepy, keepx = keepy - pad, keepx - pad
151 |         inds_keep = inds[keepy, keepx]
152 |         out = corners[:, inds_keep]
153 |         values = out[-1, :]
154 |         inds2 = np.argsort(-values)
155 |         out = out[:, inds2]
156 |         out_inds = inds1[inds_keep[inds2]]
157 | 
158 |         return out, out_inds
159 | 
160 |     def predict(self, img, keys="*"):
161 |         """
162 |         Extract the keypoints and descriptors of an image.
163 |         Args:
164 |             img: [h,w,3] RGB image; h and w are internally rescaled to multiples of 16
165 |         Returns:
166 |             point: [n,2] keypoints, returned in (x, y) order at the original scale
167 |             descriptor: [n,128] descriptors
168 |         """
169 |         # switch to eval mode
170 |         self.model.eval()
171 |         # self.extractor.eval()
172 | 
173 |         shape = img.shape
174 |         assert shape[2] == 3  # must be rgb
175 | 
176 |         org_h, org_w = shape[0], shape[1]
177 | 
178 |         # rescale so that h, w are multiples of 16
179 |         if org_h % 16 != 0:
180 |             scale_h = int(np.round(org_h / 16.) * 16.)
181 |             sh = org_h / scale_h
182 |         else:
183 |             scale_h = org_h
184 |             sh = 1.0
185 | 
186 |         if org_w % 16 != 0:
187 |             scale_w = int(np.round(org_w / 16.) * 16.)
188 |             sw = org_w / scale_w
189 |         else:
190 |             scale_w = org_w
191 |             sw = 1.0
192 | 
193 |         img = cv.resize(img, dsize=(scale_w, scale_h), interpolation=cv.INTER_LINEAR)
194 | 
195 |         # to torch and scale to [-1,1]
196 |         img = torch.from_numpy(img).to(torch.float).unsqueeze(dim=0).permute((0, 3, 1, 2)).to(self.device)
197 |         img = (img / 255.) * 2. - 1.
198 | 
199 |         # detector
200 |         heatmap, feature, weightmap = self.model(img)
201 |         # heatmap2 = f.interpolate(weightmap, heatmap.shape[2:], mode='bilinear')
202 |         prob = torch.sigmoid(heatmap)
203 |         # prob2 = torch.sigmoid(heatmap2)
204 |         # prob = (prob + prob2) / 2
205 |         # get the predicted points
206 |         prob = prob.detach().cpu().numpy()
207 |         prob = prob[0, 0]
208 | 
209 |         point, score = self._generate_predict_point(prob, height=scale_h, width=scale_w)  # [n,2]
210 |         # weightmap = heatmap
211 |         # descriptor
212 |         desp = self._generate_combined_descriptor_fast(point, feature, weightmap, scale_h, scale_w)
213 |         # print(weightmap)
214 |         # exit(0)
215 |         # scale point back to the original scale and change to x-y
216 |         point = (point * np.array((sh, sw)))[:, ::-1]
217 | 
218 |         predictions = {
219 |             "shape": shape,
220 |             "keypoints": point,
221 |             "descriptors": desp,
222 |             "scores": score,
223 |         }
224 | 
225 |         if keys != '*':
226 |             predictions = {k: predictions[k] for k in keys}
227 | 
228 |         return predictions
229 | 
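    # Usage sketch of predict() (the paths and config values here are hypothetical):
    #
    #   model = Mtldesc(backbone='network.MTLDesc', weight_path='ckpt',
    #                   ckpt_name='mtldesc', weights_id=20)
    #   img = cv.imread('demo.jpg')[:, :, ::-1]        # RGB, any size
    #   pred = model.predict(img)
    #   kpts, descs, scores = pred['keypoints'], pred['descriptors'], pred['scores']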
230 |     def generate_descriptor(self, input_image, point, image_shape):
231 |         """
232 |         Given keypoints, compute their descriptors.
233 |         """
234 |         # switch to eval mode
235 |         self.model.eval()
236 |         # self.extractor.eval()
237 | 
238 |         img = input_image
239 | 
240 |         shape = img.shape
241 |         if len(shape) == 3:
242 |             assert shape[2] == 1  # only support grayscale image
243 |             img = img[:, :, 0]
244 | 
245 |         org_h, org_w = shape[0], shape[1]
246 | 
247 |         # rescale so that h, w are multiples of 16
248 |         if org_h % 16 != 0:
249 |             scale_h = np.round(org_h / 16.) * 16.
250 |         else:
251 |             scale_h = org_h
252 | 
253 |         if org_w % 16 != 0:
254 |             scale_w = np.round(org_w / 16.) * 16.
255 |         else:
256 |             scale_w = org_w
257 | 
258 |         img = cv.resize(img, dsize=(int(scale_w), int(scale_h)), interpolation=cv.INTER_LINEAR)
259 | 
260 |         # to torch and scale to [-1,1]
261 |         img = torch.from_numpy(img).to(torch.float).unsqueeze(dim=0).unsqueeze(dim=0).to(self.device)
262 |         img = (img / 255.) * 2. - 1.
263 | 
264 |         # detector (the model returns heatmap, feature map and weight map, as in predict())
265 |         _, feature, weightmap = self.model(img)
266 | 
267 |         # descriptor
268 |         descriptor = self._generate_combined_descriptor_fast(
269 |             point[:, ::-1], feature, weightmap, image_shape[0], image_shape[1]
270 |         )
271 | 
272 |         return descriptor
273 | 
274 |     def _generate_combined_descriptor_fast(self, point, feature, weight_map, height, width):
275 |         """
276 |         Construct descriptors from the feature map weighted by the attention map.
277 |         Args:
278 |             point: [n,2] in (y, x) order
279 |             feature, weight_map: feature map and attention weight map (batch size 1)
280 |         Returns:
281 |             desp: [n,dim]
282 |         """
283 |         point = torch.from_numpy(point[:, ::-1].copy()).to(torch.float).to(self.device)
284 |         # normalize the sampling coordinates to [-1,1]
285 |         point = point * 2. / torch.tensor((width-1, height-1), dtype=torch.float, device=self.device) - 1
286 |         point = point.unsqueeze(dim=0).unsqueeze(dim=2)  # [1,n,1,2]
287 | 
288 |         feature_pair = f.grid_sample(feature, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)[0]
289 |         weight_pair = f.grid_sample(weight_map, point, mode="bilinear", padding_mode="border")[:, :, :, 0].transpose(1, 2)[0]  # .squeeze(dim=1)
290 |         desp_pair = feature_pair / torch.norm(feature_pair, p=2, dim=1, keepdim=True)
291 |         # desp = desp_pair
292 |         desp = desp_pair * weight_pair.expand_as(desp_pair)
293 |         # desp = desp / torch.norm(desp, p=2, dim=1, keepdim=True)
294 | 
295 |         desp = desp.detach().cpu().numpy()
296 | 
297 |         return desp
298 | 
299 |     def __call__(self, *args, **kwargs):
300 |         raise NotImplementedError
301 | 
302 |     def __enter__(self):
303 |         return self
304 | 
305 |     def __exit__(self, *args):
306 |         pass
307 | 
308 | 
309 | 
--------------------------------------------------------------------------------
/evaluation_hpatch/utils/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import cv2
5 | import torch
6 | from tqdm import tqdm
7 | 
8 | 
9 | class Evaluator(object):
10 | 
11 |     def __init__(self):
12 |         self.mutual_check = True
13 |         self.err_thld = np.arange(1, 16)  # range [1,15]
14 |         if torch.cuda.is_available():
15 |             self.device = torch.device('cuda:0')
16 |         else:
17 |             self.device = torch.device('cpu:0')
18 |         self.stats = {
19 |             'i_eval_stats': np.zeros((len(self.err_thld), 8), np.float32),
20 |             'v_eval_stats': np.zeros((len(self.err_thld), 8), np.float32),
21 |             'all_eval_stats': np.zeros((len(self.err_thld), 8), np.float32),
22 |         }
23 | 
24 |     def homo_trans(self, coord, H):
25 |         kpt_num = coord.shape[0]
26 |         homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1)
27 |         proj_coord = np.matmul(H, homo_coord.T).T
28 |         proj_coord = proj_coord / proj_coord[:, 2][..., None]
29 |         proj_coord = proj_coord[:, 0:2]
30 |         return proj_coord
31 | 
32 |     def mnn_matcher(self, descriptors_a, descriptors_b):
33 |         descriptors_a = torch.from_numpy(descriptors_a).to(self.device).to(torch.float)
34 |         descriptors_b = torch.from_numpy(descriptors_b).to(self.device).to(torch.float)
35 |         sim = descriptors_a @ descriptors_b.t()
36 |         nn12 = torch.max(sim, dim=1)[1]
37 |         nn21 = torch.max(sim, dim=0)[1]
38 |         ids1 = torch.arange(0, sim.shape[0], device=self.device)
39 |         mask = (ids1 == nn21[nn12])
40 |         matches = torch.stack([ids1[mask], nn12[mask]])
41 |         return matches.t().detach().cpu().numpy()
42 | 
43 |     def feature_matcher(self, ref_feat, test_feat):
44 |         matches = self.mnn_matcher(ref_feat, test_feat)
45 |         matches = [cv2.DMatch(matches[i][0], matches[i][1], 0) for i in range(matches.shape[0])]
46 |         return matches
47 | 
48 |     def get_covisible_mask(self, ref_coord, test_coord, ref_img_shape, test_img_shape, gt_homo,
scaling=1.): 49 | ref_coord = ref_coord / scaling 50 | test_coord = test_coord / scaling 51 | 52 | proj_ref_coord = self.homo_trans(ref_coord, gt_homo) 53 | proj_test_coord = self.homo_trans(test_coord, np.linalg.inv(gt_homo)) 54 | 55 | ref_mask = np.logical_and( 56 | np.logical_and(proj_ref_coord[:, 0] < test_img_shape[1] - 1, 57 | proj_ref_coord[:, 1] < test_img_shape[0] - 1), 58 | np.logical_and(proj_ref_coord[:, 0] > 0, proj_ref_coord[:, 1] > 0) 59 | ) 60 | 61 | test_mask = np.logical_and( 62 | np.logical_and(proj_test_coord[:, 0] < ref_img_shape[1] - 1, 63 | proj_test_coord[:, 1] < ref_img_shape[0] - 1), 64 | np.logical_and(proj_test_coord[:, 0] > 0, proj_test_coord[:, 1] > 0) 65 | ) 66 | 67 | return ref_mask, test_mask 68 | 69 | def get_inlier_matches(self, ref_coord, test_coord, putative_matches, gt_homo, scaling=1.): 70 | p_ref_coord = np.float32([ref_coord[m.queryIdx] for m in putative_matches]) / scaling 71 | p_test_coord = np.float32([test_coord[m.trainIdx] for m in putative_matches]) / scaling 72 | 73 | proj_p_ref_coord = self.homo_trans(p_ref_coord, gt_homo) 74 | dist = np.sqrt(np.sum(np.square(proj_p_ref_coord - p_test_coord[:, 0:2]), axis=-1)) 75 | inlier_matches_list = [] 76 | for err_thld in self.err_thld: 77 | inlier_mask = dist <= err_thld 78 | inlier_matches = [putative_matches[z] for z in np.nonzero(inlier_mask)[0]] 79 | inlier_matches_list.append(inlier_matches) 80 | return inlier_matches_list 81 | 82 | def get_gt_matches(self, ref_coord, test_coord, gt_homo, scaling=1.): 83 | ref_coord = ref_coord / scaling 84 | test_coord = test_coord / scaling 85 | proj_ref_coord = self.homo_trans(ref_coord, gt_homo) 86 | 87 | pt0 = np.expand_dims(proj_ref_coord, axis=1) 88 | pt1 = np.expand_dims(test_coord, axis=0) 89 | norm = np.linalg.norm(pt0 - pt1, ord=None, axis=2) 90 | min_dist0 = np.min(norm, axis=1) 91 | min_dist1 = np.min(norm, axis=0) 92 | gt_num_list = [] 93 | for err_thld in self.err_thld: 94 | gt_num0 = np.sum(min_dist0 <= err_thld) 95 | gt_num1 = np.sum(min_dist1 <= err_thld) 96 | gt_num = (gt_num0 + gt_num1) / 2 97 | gt_num_list.append(gt_num) 98 | return gt_num_list 99 | 100 | def compute_homography_accuracy(self, ref_coord, test_coord, ref_img_shape, putative_matches, gt_homo, scaling=1.): 101 | ref_coord = np.float32([ref_coord[m.queryIdx] for m in putative_matches]) / scaling 102 | test_coord = np.float32([test_coord[m.trainIdx] for m in putative_matches]) / scaling 103 | 104 | pred_homo, _ = cv2.findHomography(ref_coord, test_coord, cv2.RANSAC) 105 | if pred_homo is None: 106 | correctness_list = [0 for i in range(len(self.err_thld))] 107 | else: 108 | corners = np.array([[0, 0], 109 | [ref_img_shape[1] / scaling - 1, 0], 110 | [0, ref_img_shape[0] / scaling - 1], 111 | [ref_img_shape[1] / scaling - 1, ref_img_shape[0] / scaling - 1]]) 112 | real_warped_corners = self.homo_trans(corners, gt_homo) 113 | warped_corners = self.homo_trans(corners, pred_homo) 114 | mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) 115 | correctness_list = [] 116 | for err_thld in self.err_thld: 117 | correctness = float(mean_dist <= err_thld) 118 | correctness_list.append(correctness) 119 | return correctness_list 120 | 121 | def print_stats(self, key): 122 | for i, err_thld in enumerate(self.err_thld): 123 | avg_stats = self.stats[key][i] / max(self.stats[key][i][0], 1) 124 | avg_stats = avg_stats[1:] 125 | print('----------%s----------' % key) 126 | print('threshold: %d' % err_thld) 127 | print('avg_n_feat', int(avg_stats[0])) 128 | print('avg_rep', 
avg_stats[1]) 129 | print('avg_precision', avg_stats[2]) 130 | print('avg_matching_score', avg_stats[3]) 131 | print('avg_recall', avg_stats[4]) 132 | print('avg_MMA', avg_stats[5]) 133 | print('avg_homography_accuracy', avg_stats[6]) 134 | 135 | def save_results(self, file): 136 | for i, err_thld in enumerate(self.err_thld): 137 | for key in ['i_eval_stats', 'v_eval_stats', 'all_eval_stats']: 138 | avg_stats = self.stats[key][i] / max(self.stats[key][i][0], 1) 139 | avg_stats = avg_stats[1:] 140 | file.write('----------%s----------\n' % key) 141 | file.write('threshold: %d\n' % err_thld) 142 | file.write('avg_n_feat: %d\n' % int(avg_stats[0])) 143 | file.write('avg_rep: %.4f\n' % avg_stats[1]) 144 | file.write('avg_precision: %.4f\n' % avg_stats[2]) 145 | file.write('avg_matching_score: %.4f\n' % avg_stats[3]) 146 | file.write('avg_recall: %.4f\n' % avg_stats[4]) 147 | file.write('avg_MMA: %.4f\n' % avg_stats[5]) 148 | file.write('avg_homography_accuracy: %.4f\n' % avg_stats[6]) 149 | 150 | 151 | def evaluate(read_feats, dataset_path, evaluator): 152 | seq_names = sorted(os.listdir(dataset_path)) 153 | 154 | for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)): 155 | ref_img_shape, ref_kpts, ref_descs = read_feats(seq_name, 1) 156 | 157 | eval_stats = np.zeros((len(evaluator.err_thld), 8), np.float32) 158 | 159 | # print(seq_idx, seq_name) 160 | 161 | for im_idx in range(2, 7): 162 | test_img_shape, test_kpts, test_descs = read_feats(seq_name, im_idx) 163 | gt_homo = np.loadtxt(os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx))) 164 | 165 | # get MMA 166 | num_feat = min(ref_kpts.shape[0], test_kpts.shape[0]) 167 | if num_feat > 0: 168 | mma_putative_matches = evaluator.feature_matcher(ref_descs, test_descs) 169 | else: 170 | mma_putative_matches = [] 171 | mma_inlier_matches_list = evaluator.get_inlier_matches(ref_kpts, test_kpts, mma_putative_matches, gt_homo) 172 | num_mma_putative = len(mma_putative_matches) 173 | num_mma_inlier_list = [len(mma_inlier_matches) for mma_inlier_matches in mma_inlier_matches_list] 174 | 175 | # get covisible keypoints 176 | ref_mask, test_mask = evaluator.get_covisible_mask(ref_kpts, test_kpts, 177 | ref_img_shape, test_img_shape, 178 | gt_homo) 179 | cov_ref_coord, cov_test_coord = ref_kpts[ref_mask], test_kpts[test_mask] 180 | cov_ref_feat, cov_test_feat = ref_descs[ref_mask], test_descs[test_mask] 181 | num_cov_feat = (cov_ref_coord.shape[0] + cov_test_coord.shape[0]) / 2 182 | 183 | # get gt matches 184 | gt_num_list = evaluator.get_gt_matches(cov_ref_coord, cov_test_coord, gt_homo) 185 | # establish putative matches 186 | if num_cov_feat > 0: 187 | putative_matches = evaluator.feature_matcher(cov_ref_feat, cov_test_feat) 188 | else: 189 | putative_matches = [] 190 | num_putative = max(len(putative_matches), 1) 191 | 192 | # get homography accuracy 193 | correctness_list = evaluator.compute_homography_accuracy(cov_ref_coord, cov_test_coord, ref_img_shape, 194 | putative_matches, gt_homo) 195 | # get inlier matches 196 | inlier_matches_list = evaluator.get_inlier_matches(cov_ref_coord, cov_test_coord, putative_matches, gt_homo) 197 | num_inlier_list = [len(inlier_matches) for inlier_matches in inlier_matches_list] 198 | 199 | eval_stats += np.stack([np.array((1, # counter 200 | num_feat, # feature number 201 | gt_num_list[i] / max(num_cov_feat, 1), # repeatability 202 | num_inlier_list[i] / max(num_putative, 1), # precision 203 | num_inlier_list[i] / max(num_cov_feat, 1), # matching score 204 | num_inlier_list[i] / 
max(gt_num_list[i], 1),  # recall
205 |                                           num_mma_inlier_list[i] / max(num_mma_putative, 1),  # MMA
206 |                                           correctness_list[i])) / 5  # homography accuracy; / 5 averages over the 5 pairs
207 |                                for i in range(len(evaluator.err_thld))
208 |                                ], axis=0)  # [len(evaluator.err_thld), 8]
209 | 
210 |         # print(int(eval_stats[1]), eval_stats[2:])
211 |         evaluator.stats['all_eval_stats'] += eval_stats
212 |         if os.path.basename(seq_name)[0] == 'i':
213 |             evaluator.stats['i_eval_stats'] += eval_stats
214 |         if os.path.basename(seq_name)[0] == 'v':
215 |             evaluator.stats['v_eval_stats'] += eval_stats
216 | 
217 |     # evaluator.print_stats('i_eval_stats')
218 |     # evaluator.print_stats('v_eval_stats')
219 |     evaluator.print_stats('all_eval_stats')
220 | 
221 |     err_thld = evaluator.err_thld
222 |     i_eval_stats = evaluator.stats['i_eval_stats'].T  # [8, 15]
223 |     i_eval_count = i_eval_stats[0, 0]
224 |     i_err = {
225 |         'Rep.': {thr: i_eval_stats[2][i] for i, thr in enumerate(err_thld)},
226 |         'Precision': {thr: i_eval_stats[3][i] for i, thr in enumerate(err_thld)},
227 |         'M.S.': {thr: i_eval_stats[4][i] for i, thr in enumerate(err_thld)},
228 |         'MMA': {thr: i_eval_stats[6][i] for i, thr in enumerate(err_thld)},
229 |         'HA': {thr: i_eval_stats[7][i] for i, thr in enumerate(err_thld)},
230 |     }
231 | 
232 |     v_eval_stats = evaluator.stats['v_eval_stats'].T
233 |     v_eval_count = v_eval_stats[0, 0]
234 |     v_err = {
235 |         'Rep.': {thr: v_eval_stats[2][i] for i, thr in enumerate(err_thld)},
236 |         'Precision': {thr: v_eval_stats[3][i] for i, thr in enumerate(err_thld)},
237 |         'M.S.': {thr: v_eval_stats[4][i] for i, thr in enumerate(err_thld)},
238 |         'MMA': {thr: v_eval_stats[6][i] for i, thr in enumerate(err_thld)},
239 |         'HA': {thr: v_eval_stats[7][i] for i, thr in enumerate(err_thld)},
240 |     }
241 | 
242 |     return {
243 |         'i_err': i_err,
244 |         'i_count': i_eval_count,
245 |         'v_err': v_err,
246 |         'v_count': v_eval_count,
247 |     }
248 | 
249 | 
250 | 
251 | 
--------------------------------------------------------------------------------
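A minimal sketch of how evaluate() above is driven: the caller supplies a read_feats callback returning (image_shape, keypoints, descriptors) for image im_idx of sequence seq_name (the cache layout used here is illustrative, not the repo's actual format):

import os
import numpy as np

def read_feats(seq_name, im_idx):
    # hypothetical cache: one .npz with shape/keypoints/descriptors per image
    data = np.load(os.path.join('cache', seq_name, '%d.npz' % im_idx))
    return data['shape'], data['keypoints'], data['descriptors']

results = evaluate(read_feats, 'hpatches-sequences-release', Evaluator())
print(results['i_count'], results['v_count'])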
::-1].copy() # convert BGR to RGB 51 | image1, image2 = np.split(image12, 2, axis=1) 52 | h, w, _ = image1.shape 53 | 54 | if torch.rand([]).item() < 0.5: 55 | image1 = self.photometric(image1) 56 | image2 = self.photometric(image2) 57 | 58 | info = np.load(info_dir) 59 | desp_point1 = info["desp_point1"] 60 | desp_point2 = info["desp_point2"] 61 | valid_mask = info["valid_mask"] 62 | not_search_mask = info["not_search_mask"] 63 | 64 | label = np.load(label_dir) 65 | points1 = label["points_0"] 66 | points2 = label["points_1"] 67 | 68 | # 2.1 build the heatmap from the keypoints of the first image 69 | heatmap1 = self._convert_points_to_heatmap(points1) 70 | point_mask1 = torch.ones_like(heatmap1) 71 | 72 | # 2.2 build the heatmap from the keypoints of the second image 73 | heatmap2 = self._convert_points_to_heatmap(points2) 74 | point_mask2 = torch.ones_like(heatmap2) 75 | 76 | # debug use 77 | # desp_point1 = ((desp_point1 + 1) * np.array((self.width - 1, self.height - 1)) / 2.)[:, 0, :] 78 | # desp_point2 = ((desp_point2 + 1) * np.array((self.width - 1, self.height - 1)) / 2.)[:, 0, :] 79 | # image_point1 = draw_image_keypoints(image1[:, :, ::-1], desp_point1[valid_mask][:, ::-1], show=False) 80 | # image_point2 = draw_image_keypoints(image2[:, :, ::-1], desp_point2[valid_mask][:, ::-1], show=False) 81 | # image_point1 = draw_image_keypoints(image1[:, :, ::-1], points1, show=False) 82 | # image_point2 = draw_image_keypoints(image2[:, :, ::-1], points2, show=False) 83 | # cat_all = np.concatenate((image_point1, image_point2), axis=0) 84 | # cv.imwrite("/home/yuyang/tmp/debug_%05d.jpg" % idx, cat_all) 85 | # cv.imshow("cat_all", cat_all) 86 | # cv.waitKey() 87 | 88 | image1 = (torch.from_numpy(image1).to(torch.float) * 2. / 255. - 1.).permute((2, 0, 1)).contiguous() 89 | image2 = (torch.from_numpy(image2).to(torch.float) * 2. / 255. - 1.).permute((2, 0, 1)).contiguous() 90 | 91 | desp_point1 = torch.from_numpy(desp_point1) 92 | desp_point2 = torch.from_numpy(desp_point2) 93 | 94 | valid_mask = torch.from_numpy(valid_mask).to(torch.float) 95 | not_search_mask = torch.from_numpy(not_search_mask).to(torch.float) 96 | 97 | return { 98 | "image": image1, 99 | "point_mask": point_mask1, 100 | "heatmap": heatmap1, 101 | "warped_image": image2, 102 | "warped_point_mask": point_mask2, 103 | "warped_heatmap": heatmap2, 104 | "desp_point": desp_point1, 105 | "warped_desp_point": desp_point2, 106 | "valid_mask": valid_mask, 107 | "not_search_mask": not_search_mask, 108 | } 109 | 110 | def _get_synthesis_data(self, data_info): 111 | image12 = cv.imread(data_info['image'])[:, :, ::-1].copy() # convert BGR to RGB 112 | image1, image2 = np.split(image12, 2, axis=1) 113 | point = np.load(data_info['label']) 114 | info = np.load(data_info['info']) 115 | if torch.rand([]).item() < 0.5: 116 | image = cv.resize(image1, dsize=(self.width, self.height), interpolation=cv.INTER_LINEAR) 117 | point = point["points_0"] 118 | desp_point_load = info["raw_desp_point1"] 119 | else: 120 | image = cv.resize(image2, dsize=(self.width, self.height), interpolation=cv.INTER_LINEAR) 121 | point = point["points_1"] 122 | desp_point_load = info["raw_desp_point2"] 123 | point_mask = np.ones_like(image).astype(np.float32)[:, :, 0].copy() 124 | 125 | # 1. obtain the second image, its keypoint locations, the original mask and the homography itself from a randomly sampled homography 126 | if torch.rand([]).item() < 0.5: 127 | warped_image, warped_point_mask, warped_point, homography = \ 128 | image.copy(), point_mask.copy(), point.copy(), np.eye(3) 129 | else: 130 | warped_image, warped_point_mask, warped_point, homography = self.homography(image, point, return_homo=True) 131 | warped_point_mask = warped_point_mask[:, :, 
0].copy() 132 | 133 | if torch.rand([]).item() < 0.5: 134 | image = self.photometric(image) 135 | warped_image = self.photometric(warped_image) 136 | 137 | # 2.1 build the heatmap from the keypoints of the first image 138 | heatmap = self._convert_points_to_heatmap(point) 139 | 140 | # 2.2 build the heatmap from the keypoints of the second image 141 | warped_heatmap = self._convert_points_to_heatmap(warped_point) 142 | 143 | # 3. sample the points used for descriptor training 144 | if self.sydesp_type == 'random': 145 | desp_point = self._random_sample_point() 146 | else: 147 | desp_point = desp_point_load 148 | 149 | shape = image.shape 150 | 151 | warped_desp_point, valid_mask, not_search_mask = self._generate_warped_point( 152 | desp_point, homography, shape[0], shape[1]) 153 | 154 | # debug use 155 | # image_point = draw_image_keypoints(image, desp_point, show=False) 156 | # warped_image_point = draw_image_keypoints(warped_image, warped_desp_point, show=False) 157 | # cat_all = np.concatenate((image, warped_image), axis=1) 158 | # cat_all = np.concatenate((image_point, warped_image_point), axis=1) 159 | # cv.imwrite("/home/yuyang/tmp/coco_tmp/%d.jpg" % idx, cat_all) 160 | # cv.imshow("cat_all", cat_all) 161 | # cv.waitKey() 162 | 163 | image = image.astype(np.float32) * 2. / 255. - 1. 164 | warped_image = warped_image.astype(np.float32) * 2. / 255. - 1. 165 | 166 | image = torch.from_numpy(image).permute((2, 0, 1)) 167 | warped_image = torch.from_numpy(warped_image).permute((2, 0, 1)) 168 | 169 | point_mask = torch.from_numpy(point_mask) 170 | warped_point_mask = torch.from_numpy(warped_point_mask) 171 | 172 | desp_point = torch.from_numpy(self._scale_point_for_sample(desp_point)) 173 | warped_desp_point = torch.from_numpy(self._scale_point_for_sample(warped_desp_point)) 174 | 175 | valid_mask = torch.from_numpy(valid_mask) 176 | not_search_mask = torch.from_numpy(not_search_mask) 177 | 178 | return { 179 | "image": image, # [3,h,w] 180 | "point_mask": point_mask, # [h,w] 181 | "heatmap": heatmap, # [h,w] 182 | "warped_image": warped_image, # [3,h,w] 183 | "warped_point_mask": warped_point_mask, # [h,w] 184 | "warped_heatmap": warped_heatmap, # [h,w] 185 | "desp_point": desp_point, # [n,1,2] 186 | "warped_desp_point": warped_desp_point, # [n,1,2] 187 | "valid_mask": valid_mask, # [n] 188 | "not_search_mask": not_search_mask, # [n,n] 189 | } 190 | 191 | @staticmethod 192 | def _generate_warped_point(point, homography, height, width, threshold=16): 193 | """ 194 | Warp the sampled points with the homography and build the validity mask and the matrix of pairs excluded from negative-sample search 195 | Args: 196 | point: [n,2] in one-to-one correspondence with warped_point 197 | homography: the transform between the point pairs 198 | 199 | Returns: 200 | not_search_mask: [n,n] float32 mask, 1 at positions that must not be searched for negatives 201 | """ 202 | # coordinates of the projected points 203 | point = np.concatenate((point[:, ::-1], np.ones((point.shape[0], 1))), axis=1)[:, :, np.newaxis] # [n,3,1] 204 | project_point = np.matmul(homography, point)[:, :, 0] 205 | project_point = project_point[:, :2] / project_point[:, 2:3] 206 | project_point = project_point[:, ::-1] # swap back to (y,x) order 207 | 208 | # projected points inside the image are valid, the rest are invalid 209 | border_0 = np.array((0, 0), dtype=np.float32) 210 | border_1 = np.array((height-1, width-1), dtype=np.float32) 211 | valid_mask = (project_point >= border_0) & (project_point <= border_1) 212 | valid_mask = np.all(valid_mask, axis=1) 213 | invalid_mask = ~valid_mask 214 | 215 | # build the negative-search exclusion matrix from the invalid points and the distances between projected points 216 | 217 | dist = np.linalg.norm(project_point[:, np.newaxis, :] - project_point[np.newaxis, :, :], axis=2) 218 | not_search_mask = ((dist <= threshold) | invalid_mask[np.newaxis, :]).astype(np.float32) 219 | return project_point.astype(np.float32), valid_mask.astype(np.float32), not_search_mask 
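# A minimal sketch (not repository code) of how the three outputs above are
# typically consumed: the trainers in trainers/ sample descriptors at
# desp_point / warped_desp_point and mine hard negatives under not_search_mask.
# The names desc0/desc1 and the margin value are illustrative assumptions.
# Note that not_search_mask covers the diagonal as well (a projected point is
# within `threshold` of itself), so adding a large constant removes both the
# positive pair and the too-close / invalid pairs from the negative search.
import torch

def hard_negative_triplet_sketch(desc0, desc1, valid_mask, not_search_mask, margin=1.0):
    """desc0/desc1: [n,dim] descriptors sampled at desp_point / warped_desp_point.
    valid_mask: [n], 1 where the warped point lies inside the image.
    not_search_mask: [n,n], 1 where a pair must not be used as a negative."""
    dist = torch.cdist(desc0, desc1)      # [n,n] pairwise L2 distances
    pos = torch.diag(dist)                # matched pairs sit on the diagonal
    dist = dist + not_search_mask * 10.0  # push excluded pairs out of the search
    neg = dist.min(dim=1)[0]              # hardest admissible negative per point
    loss = torch.relu(margin + pos - neg) * valid_mask
    return loss.sum() / valid_mask.sum().clamp(min=1.0)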
220 | 221 | def _scale_point_for_sample(self, point): 222 | """ 223 | Normalize the points to [-1,1] and swap the order to (x,y) for grid sampling 224 | Args: 225 | point: [n,2] in (y,x) order, original range [0,height-1], [0,width-1] 226 | Returns: 227 | point: [n,1,2] in (x,y) order, range [-1,1] 228 | """ 229 | org_size = np.array((self.height-1, self.width-1), dtype=np.float32) 230 | point = ((point * 2. / org_size - 1.)[:, ::-1])[:, np.newaxis, :].copy() 231 | return point 232 | 233 | def _random_sample_point(self): 234 | """ 235 | Uniformly sample random coordinate points according to the preset input image size 236 | """ 237 | grid = self.fix_grid.copy() 238 | # sample one random point inside every grid cell 239 | 240 | point_list = [] 241 | for i in range(grid.shape[0]): 242 | y_start, x_start, y_end, x_end = grid[i] 243 | rand_y = np.random.randint(y_start, y_end) 244 | rand_x = np.random.randint(x_start, x_end) 245 | point_list.append(np.array((rand_y, rand_x), dtype=np.float32)) 246 | point = np.stack(point_list, axis=0) 247 | 248 | return point 249 | 250 | def _generate_fixed_grid(self, option=None): 251 | """ 252 | Pre-compute a fixed grid of evenly spaced image cells (20x20 = 400 cells by default) 253 | """ 254 | if option is None: 255 | y_num = 20 256 | x_num = 20 257 | else: 258 | y_num = option[0] 259 | x_num = option[1] 260 | 261 | grid_y = np.linspace(0, self.height-1, y_num+1, dtype=np.int) 262 | grid_x = np.linspace(0, self.width-1, x_num+1, dtype=np.int) 263 | 264 | grid_y_start = grid_y[:y_num].copy() 265 | grid_y_end = grid_y[1:y_num+1].copy() 266 | grid_x_start = grid_x[:x_num].copy() 267 | grid_x_end = grid_x[1:x_num+1].copy() 268 | 269 | grid_start = np.stack((np.tile(grid_y_start[:, np.newaxis], (1, x_num)), 270 | np.tile(grid_x_start[np.newaxis, :], (y_num, 1))), axis=2).reshape((-1, 2)) 271 | grid_end = np.stack((np.tile(grid_y_end[:, np.newaxis], (1, x_num)), 272 | np.tile(grid_x_end[np.newaxis, :], (y_num, 1))), axis=2).reshape((-1, 2)) 273 | grid = np.concatenate((grid_start, grid_end), axis=1) 274 | 275 | return grid 276 | 277 | def _convert_points_to_heatmap(self, points): 278 | """ 279 | Convert the raw keypoint locations into a heatmap: each point is rounded to the nearest integer pixel 280 | and that position is set to 1, all other positions are 0 281 | Args: 282 | points: [n,2] 283 | 284 | Returns: 285 | heatmap: [h,w] 1 at keypoint positions, 0 elsewhere 286 | 287 | 288 | 289 | """ 290 | height = self.height 291 | width = self.width 292 | 293 | # localmap = self.localmap.clone() 294 | # padded_heatmap = torch.zeros( 295 | # (height+self.g_paddings*2, width+self.g_paddings*2), dtype=torch.float) 296 | heatmap = torch.zeros((height, width), dtype=torch.float) 297 | 298 | num_pt = points.shape[0] 299 | if num_pt > 0: 300 | for i in range(num_pt): 301 | pt = points[i] 302 | pt_y_float, pt_x_float = pt 303 | 304 | pt_y_int = round(pt_y_float) 305 | pt_x_int = round(pt_x_float) 306 | 307 | pt_y = int(pt_y_int) # round the ground-truth location to an integer pixel; this introduces a quantization error 308 | pt_x = int(pt_x_int) 309 | 310 | # skip points that fall outside the image after rounding 311 | if pt_y < 0 or pt_y > height - 1: 312 | continue 313 | if pt_x < 0 or pt_x > width - 1: 314 | continue 315 | 316 | # set the keypoint position to 1 on the heatmap 317 | heatmap[pt_y, pt_x] = 1.0 318 | 319 | return heatmap 320 | 321 | def convert_points_to_label(self, points): 322 | 323 | height = self.height 324 | width = self.width 325 | n_height = int(height / 8) 326 | n_width = int(width / 8) 327 | assert n_height * 8 == height and n_width * 8 == width 328 | 329 | num_pt = points.shape[0] 330 | label = torch.zeros((height * width)) 331 | if num_pt > 0: 332 | points_h, points_w = torch.split(points, 1, dim=1) 333 | points_idx = points_w + points_h * 
width 334 | label = label.scatter_(dim=0, index=points_idx[:, 0], value=1.0).reshape((height, width)) 335 | else: 336 | label = label.reshape((height, width)) 337 | 338 | dense_label = space_to_depth(label) 339 | dense_label = torch.cat((dense_label, 0.5 * torch.ones((1, n_height, n_width))), dim=0) # [65, 30, 40] 340 | sparse_label = torch.argmax(dense_label, dim=0) # [30,40] 341 | 342 | return sparse_label 343 | @staticmethod 344 | def _format_file_list(mega_image_dir, mega_keypoint_dir,mega_despoint_dir): 345 | data_list = [] 346 | 347 | # format megadepth related list 348 | mega_image_list = glob(os.path.join(mega_image_dir, '*.jpg')) 349 | mega_image_list = sorted(mega_image_list) 350 | data_type = 'real' 351 | for img in mega_image_list: 352 | img_name = img.split('/')[-1].split('.')[0] 353 | info = os.path.join(mega_despoint_dir, img_name + '.npz') 354 | label = os.path.join(mega_keypoint_dir, img_name + '.npz') 355 | data_list.append( 356 | { 357 | 'type': data_type, 358 | 'image': img, 359 | 'info': info, 360 | 'label': label, 361 | } 362 | ) 363 | 364 | # format coco related list 365 | data_type = 'synthesis' 366 | for img in mega_image_list: 367 | img_name = img.split('/')[-1].split('.')[0] 368 | label = os.path.join(mega_keypoint_dir, img_name + '.npz') 369 | info = os.path.join(mega_despoint_dir, img_name + '.npz') 370 | data_list.append( 371 | { 372 | 'type': data_type, 373 | 'image': img, 374 | 'info': info, 375 | 'label': label, 376 | } 377 | ) 378 | 379 | return data_list 380 | 381 | 382 | 383 | 384 | 385 | 386 | -------------------------------------------------------------------------------- /nets/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from nets.vit.vit_seg_modeling import * 5 | from nets.vit.vit_seg_modeling import PatchTransfomer 6 | 7 | class SuperPointNet(nn.Module): 8 | 9 | def __init__(self): 10 | super(SuperPointNet, self).__init__() 11 | self.relu = torch.nn.ReLU(inplace=True) 12 | self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) 13 | c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256 14 | # Shared Encoder. 15 | self.conv1a = nn.Conv2d(3, c1, kernel_size=3, stride=1, padding=1) 16 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 17 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 18 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 19 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 20 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 21 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 22 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 23 | # Detector Head. 24 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 25 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 26 | # Descriptor Head. 
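# (Context note: the detector head above emits 65 channels per 1/8-resolution
#  cell -- 64 possible keypoint positions inside an 8x8 cell plus one "dustbin"
#  channel -- which matches the 65-channel labels built by
#  convert_points_to_label in the datasets. forward() applies a softmax and
#  drops the dustbin, so the remaining 64 channels can be folded back to a
#  full-resolution probability map, e.g.
#      full_prob = torch.nn.functional.pixel_shuffle(prob, 8)
#  which, up to channel ordering, inverts the space_to_depth labelling. The
#  descriptor head below yields a d1-channel semi-dense map that forward()
#  L2-normalizes.)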
27 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 28 | self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0) 29 | 30 | self.softmax = nn.Softmax(dim=1) 31 | self.tanh = nn.Tanh() 32 | 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 36 | 37 | def forward(self, x): 38 | x = self.relu(self.conv1a(x)) 39 | x = self.relu(self.conv1b(x)) 40 | x = self.pool(x) 41 | x = self.relu(self.conv2a(x)) 42 | x = self.relu(self.conv2b(x)) 43 | x = self.pool(x) 44 | x = self.relu(self.conv3a(x)) 45 | x = self.relu(self.conv3b(x)) 46 | x = self.pool(x) 47 | x = self.relu(self.conv4a(x)) 48 | x = self.relu(self.conv4b(x)) 49 | 50 | # detect head 51 | cPa = self.relu(self.convPa(x)) 52 | logit = self.convPb(cPa) 53 | prob = self.softmax(logit)[:, :-1, :, :] 54 | 55 | # descriptor head 56 | cDa = self.relu(self.convDa(x)) 57 | feature = self.convDb(cDa) 58 | 59 | dn = torch.norm(feature, p=2, dim=1, keepdim=True) 60 | desc = feature.div(dn) 61 | 62 | return logit, desc, prob 63 | 64 | class MTLDesc(nn.Module): 65 | def __init__(self): 66 | super(MTLDesc, self).__init__() 67 | self.relu = torch.nn.ReLU(inplace=True) 68 | self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) 69 | # Shared Encoder. 70 | self.conv1a = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 71 | self.conv1b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 72 | 73 | self.conv2a = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 74 | self.conv2b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 75 | 76 | self.conv3a = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 77 | self.conv3b = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 78 | 79 | self.conv4a = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 80 | self.conv4b = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 81 | 82 | #关键点金字塔 83 | self.heatmap1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 84 | self.heatmap2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 85 | self.heatmap3 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 86 | self.heatmap4 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 87 | 88 | self.fuse_weight_1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 89 | self.fuse_weight_2 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 90 | self.fuse_weight_3 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 91 | self.fuse_weight_4 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 92 | #权重初始化 93 | self.fuse_weight_1.data.fill_(0.1) 94 | self.fuse_weight_2.data.fill_(0.2) 95 | self.fuse_weight_3.data.fill_(0.3) 96 | self.fuse_weight_4.data.fill_(0.4) 97 | 98 | self.scalemap = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1) 99 | self.active = f.softplus 100 | self.conv_avg = nn.Conv2d(128, 384, kernel_size=3, stride=1, padding=1) 101 | 102 | self.transfomer=PatchTransfomer(img_size=64,in_channels=128) 103 | self.pool_size=64 104 | self.adapool = nn.AdaptiveAvgPool2d((self.pool_size, self.pool_size)) 105 | self.mask=nn.Conv2d(128,1,kernel_size=3,stride=1,padding=1) 106 | self.conv_des = nn.Conv2d(384, 128, kernel_size=1, stride=1, padding=0) 107 | self.conv_des_1 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1, dilation=1) 108 | self.conv_des_2 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=6, dilation=6) 109 | self.conv_des_3 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=12, dilation=12) 110 | 
self.conv_des_4 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=18, dilation=18) 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 114 | 115 | def forward(self, x): 116 | x = self.relu(self.conv1a(x)) 117 | c1 = self.relu(self.conv1b(x)) # 64 118 | 119 | c2 = self.pool(c1) 120 | c2 = self.relu(self.conv2a(c2)) 121 | c2 = self.relu(self.conv2b(c2)) # 64 122 | 123 | c3 = self.pool(c2) 124 | c3 = self.relu(self.conv3a(c3)) 125 | c3 = self.relu(self.conv3b(c3)) # 128 126 | 127 | c4 = self.pool(c3) 128 | c4 = self.relu(self.conv4a(c4)) 129 | c4 = self.relu(self.conv4b(c4)) # 128 130 | #top=c4 131 | # KeyPoint Map 132 | heatmap1 = self.heatmap1(c1) 133 | heatmap2 = self.heatmap2(c2) 134 | heatmap3 = self.heatmap3(c3) 135 | heatmap4 = self.heatmap3(c4) # note: reuses self.heatmap3 on c4; the dedicated self.heatmap4 defined above is never used in this forward pass 136 | des_size = heatmap1.shape[2:] # full-resolution HxW 137 | heatmap2 = f.interpolate(heatmap2, des_size, mode='bilinear') 138 | heatmap3 = f.interpolate(heatmap3, des_size, mode='bilinear') 139 | heatmap4 = f.interpolate(heatmap4, des_size, mode='bilinear') 140 | heatmap = heatmap1 * self.fuse_weight_1 + heatmap2 * self.fuse_weight_2 + heatmap3 * self.fuse_weight_3 + heatmap4 * self.fuse_weight_4 141 | 142 | # Descriptor 143 | des_size = c3.shape[2:] # 1/4 HxW 144 | c1 = f.interpolate(c1, des_size, mode='bilinear') 145 | c2 = f.interpolate(c2, des_size, mode='bilinear') 146 | c3 = c3 # already at 1/4 resolution 147 | c4 = f.interpolate(c4, des_size, mode='bilinear') 148 | feature = torch.cat((c1, c2, c3, c4), dim=1) 149 | 150 | # attention map 151 | meanmap = torch.mean(feature, dim=1, keepdim=True) 152 | attmap = self.scalemap(meanmap) 153 | attmap = self.active(attmap) 154 | 155 | # Global Context 156 | top = self.adapool(c4) 157 | mask = self.relu(self.mask(top)) 158 | avg = self.transfomer(top) 159 | avg = f.interpolate(avg, des_size, mode='bilinear') 160 | mask = f.interpolate(mask, des_size, mode='bilinear') 161 | descriptor = feature 162 | descriptor = self.conv_des(descriptor) + avg * mask 163 | descriptor_1 = self.conv_des_1(descriptor) 164 | descriptor_2 = self.conv_des_2(descriptor) 165 | descriptor_3 = self.conv_des_3(descriptor) 166 | descriptor_4 = self.conv_des_4(descriptor) 167 | descriptor_refine = torch.cat((descriptor_1, descriptor_2, descriptor_3, descriptor_4), dim=1) 168 | descriptor = descriptor + descriptor_refine 169 | return heatmap, descriptor, attmap 170 | class Lite16(nn.Module): 171 | def __init__(self): 172 | super(Lite16, self).__init__() 173 | self.relu = torch.nn.ReLU(inplace=True) 174 | self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) 175 | 176 | # Shared Encoder. 
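# (Context note: in forward() each stage concatenates its input with its conv
#  output -- torch.cat((x_i, c_i), dim=1) -- so conv2a/conv3a/conv4a take
#  32 = 16 + 16 input channels while each stage computes only 16 new channels.
#  This DenseNet-style feature reuse is what keeps the Tiny variants light;
#  Lite32 below repeats the same pattern with 32-channel convolutions.)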
177 | self.conv1a = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) 178 | self.conv1b = nn.Conv2d(16, 16, kernel_size=3, stride=1,padding=1) 179 | 180 | self.conv2a = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1) 181 | self.conv2b = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 182 | 183 | self.conv3a = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1) 184 | self.conv3b = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 185 | 186 | self.conv4a = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1) 187 | self.conv4b = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 188 | 189 | self.heatmap1 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1) 190 | self.heatmap2 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1) 191 | self.heatmap3 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1) 192 | self.heatmap4 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1) 193 | #self.heatmap = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 194 | 195 | self.scalemap = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1) 196 | self.active = f.softplus 197 | self.descriptor = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 198 | 199 | self.fuse_weight_1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 200 | self.fuse_weight_2 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 201 | self.fuse_weight_3 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 202 | self.fuse_weight_4 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 203 | # 权重初始化 204 | self.fuse_weight_1.data.fill_(0.1) 205 | self.fuse_weight_2.data.fill_(0.2) 206 | self.fuse_weight_3.data.fill_(0.3) 207 | self.fuse_weight_4.data.fill_(0.4) 208 | 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 212 | 213 | def forward(self, x): 214 | x1 = self.relu(self.conv1a(x)) 215 | c1 = self.relu(self.conv1b(x1)) 216 | c1 = torch.cat((x1, c1), dim=1) # 1 217 | 218 | x2 = self.pool(c1) 219 | x2 = self.relu(self.conv2a(x2)) 220 | c2 = self.relu(self.conv2b(x2)) 221 | c2 = torch.cat((x2, c2), dim=1) # 1/2 222 | 223 | x3 = self.pool(c2) 224 | x3 = self.relu(self.conv3a(x3)) 225 | c3 = self.relu(self.conv3b(x3)) 226 | c3 = torch.cat((x3, c3), dim=1) # 1/4 227 | 228 | x4 = self.pool(c3) 229 | x4 = self.relu(self.conv4a(x4)) 230 | c4 = self.relu(self.conv4b(x4)) 231 | c4 = torch.cat((x4, c4), dim=1)# 1/8 232 | 233 | # heatmap = self.heatmap(c1) 234 | heatmap1 = self.heatmap1(c1) 235 | heatmap2 = self.heatmap2(c2) 236 | heatmap3 = self.heatmap3(c3) 237 | heatmap4 = self.heatmap4(c4) 238 | 239 | des_size = c1.shape[2:] # 1/4 HxW 240 | heatmap2 = f.interpolate(heatmap2, des_size, mode='bilinear') 241 | heatmap3 = f.interpolate(heatmap3, des_size, mode='bilinear') 242 | heatmap4 = f.interpolate(heatmap4, des_size, mode='bilinear') 243 | heatmap = heatmap1 * self.fuse_weight_1 + heatmap2 * self.fuse_weight_2 + heatmap3 * self.fuse_weight_3 + heatmap4 * self.fuse_weight_4 244 | #heatmap = torch.cat((c1,heatmap2,heatmap3,heatmap4),dim=1) 245 | #heatmap = self.heatmap(heatmap) 246 | # Descriptor 247 | des_size = c3.shape[2:] # 1/4 HxW 248 | c1 = f.interpolate(c1, des_size, mode='bilinear') 249 | c2 = f.interpolate(c2, des_size, mode='bilinear') 250 | c3 = c3 251 | c4 = f.interpolate(c4, des_size, mode='bilinear') 252 | features = torch.cat((c1, c2, c3,c4), dim=1) 253 | descriptor = self.descriptor(features) 254 | 255 | # attention map 256 | meanmap = torch.mean(features, 
dim=1, keepdim=True) 257 | attmap = self.scalemap(meanmap) 258 | attmap = self.active(attmap) 259 | 260 | return heatmap, descriptor, attmap 261 | class Lite32(nn.Module): 262 | def __init__(self): 263 | super(Lite32, self).__init__() 264 | self.relu = torch.nn.ReLU(inplace=True) 265 | self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) 266 | 267 | # Shared Encoder. 268 | self.conv1a = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1) 269 | self.conv1b = nn.Conv2d(32, 32, kernel_size=3, stride=1,padding=1) 270 | 271 | self.conv2a = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1) 272 | self.conv2b = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 273 | 274 | self.conv3a = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1) 275 | self.conv3b = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 276 | 277 | self.conv4a = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1) 278 | self.conv4b = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 279 | 280 | self.heatmap1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 281 | self.heatmap2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 282 | self.heatmap3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 283 | self.heatmap4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 284 | #self.heatmap = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 285 | 286 | self.scalemap = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1) 287 | self.active = f.softplus 288 | self.descriptor = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 289 | 290 | self.fuse_weight_1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 291 | self.fuse_weight_2 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 292 | self.fuse_weight_3 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 293 | self.fuse_weight_4 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 294 | # 权重初始化 295 | self.fuse_weight_1.data.fill_(0.1) 296 | self.fuse_weight_2.data.fill_(0.2) 297 | self.fuse_weight_3.data.fill_(0.3) 298 | self.fuse_weight_4.data.fill_(0.4) 299 | 300 | for m in self.modules(): 301 | if isinstance(m, nn.Conv2d): 302 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 303 | 304 | def forward(self, x): 305 | x1 = self.relu(self.conv1a(x)) 306 | c1 = self.relu(self.conv1b(x1)) 307 | c1 = torch.cat((x1, c1), dim=1) # 1 308 | 309 | x2 = self.pool(c1) 310 | x2 = self.relu(self.conv2a(x2)) 311 | c2 = self.relu(self.conv2b(x2)) 312 | c2 = torch.cat((x2, c2), dim=1) # 1/2 313 | 314 | x3 = self.pool(c2) 315 | x3 = self.relu(self.conv3a(x3)) 316 | c3 = self.relu(self.conv3b(x3)) 317 | c3 = torch.cat((x3, c3), dim=1) # 1/4 318 | 319 | x4 = self.pool(c3) 320 | x4 = self.relu(self.conv4a(x4)) 321 | c4 = self.relu(self.conv4b(x4)) 322 | c4 = torch.cat((x4, c4), dim=1)# 1/8 323 | 324 | # heatmap = self.heatmap(c1) 325 | heatmap1 = self.heatmap1(c1) 326 | heatmap2 = self.heatmap2(c2) 327 | heatmap3 = self.heatmap3(c3) 328 | heatmap4 = self.heatmap4(c4) 329 | 330 | des_size = c1.shape[2:] # 1/4 HxW 331 | heatmap2 = f.interpolate(heatmap2, des_size, mode='bilinear') 332 | heatmap3 = f.interpolate(heatmap3, des_size, mode='bilinear') 333 | heatmap4 = f.interpolate(heatmap4, des_size, mode='bilinear') 334 | heatmap = heatmap1 * self.fuse_weight_1 + heatmap2 * self.fuse_weight_2 + heatmap3 * self.fuse_weight_3 + heatmap4 * self.fuse_weight_4 335 | #heatmap = torch.cat((c1,heatmap2,heatmap3,heatmap4),dim=1) 336 | #heatmap = self.heatmap(heatmap) 337 | # Descriptor 338 | des_size 
= c3.shape[2:] # 1/4 HxW 339 | c1 = f.interpolate(c1, des_size, mode='bilinear') 340 | c2 = f.interpolate(c2, des_size, mode='bilinear') 341 | c3 = c3 342 | c4 = f.interpolate(c4, des_size, mode='bilinear') 343 | features = torch.cat((c1, c2, c3,c4), dim=1) 344 | descriptor = self.descriptor(features) 345 | 346 | # attention map 347 | meanmap = torch.mean(features, dim=1, keepdim=True) 348 | attmap = self.scalemap(meanmap) 349 | attmap = self.active(attmap) 350 | 351 | return heatmap, descriptor, attmap 352 | -------------------------------------------------------------------------------- /utils/evaluation_tools.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2019/8/13 3 | # 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class MovingAverage(object): 9 | 10 | def __init__(self, max_size=10): 11 | self.max_size = max_size 12 | self.insert_pos = 0 13 | self.queue = [] 14 | 15 | def reset(self): 16 | self.queue = [] 17 | self.insert_pos = 0 18 | 19 | def push(self, x): 20 | current_size = len(self.queue) 21 | if current_size < self.max_size: 22 | self.queue.append(x) 23 | else: 24 | self.queue[int(self.insert_pos % self.max_size)] = x 25 | self.insert_pos += 1 26 | 27 | def average(self): 28 | if len(self.queue) == 0: 29 | return 0 30 | current_queue = np.array(self.queue) 31 | avg = np.mean(current_queue) 32 | return avg 33 | 34 | def current_size(self): 35 | return len(self.queue) 36 | 37 | 38 | class PointStatistics(object): 39 | 40 | def __init__(self): 41 | self.point_num_list = [] 42 | self.sample_num = 0 43 | 44 | def reset(self): 45 | self.point_num_list = [] 46 | self.sample_num = 0 47 | 48 | def average(self): 49 | avg = 0 50 | var = 0 51 | std = 0 52 | 53 | for pt_num in self.point_num_list: 54 | avg += pt_num 55 | if self.sample_num == 0: 56 | return 0, 0 57 | avg /= self.sample_num 58 | 59 | for pt_num in self.point_num_list: 60 | var += (pt_num-avg)**2 61 | var /= self.sample_num 62 | std = np.sqrt(var) 63 | return avg, std 64 | 65 | def update(self, point_num): 66 | self.point_num_list.append(point_num) 67 | self.sample_num += 1 68 | 69 | 70 | class HomoAccuracyCalculator(object): 71 | 72 | def __init__(self, epsilon, height, width): 73 | self.height = height 74 | self.width = width 75 | self.epsilon = epsilon 76 | self.sum_accuracy = 0 77 | self.sum_sample_num = 0 78 | self.corner = self._generate_corner() 79 | 80 | def reset(self): 81 | self.sum_accuracy = 0 82 | self.sum_sample_num = 0 83 | 84 | def average(self): 85 | if self.sum_sample_num == 0: 86 | return 0, 0, 0 87 | return self.sum_accuracy / self.sum_sample_num, self.sum_accuracy, self.sum_sample_num 88 | 89 | def update(self, pred_homography, gt_homography, return_diff=False): 90 | warped_corner_by_pred = np.matmul(pred_homography, self.corner[:, :, np.newaxis])[:, :, 0] 91 | warped_corner_by_gt = np.matmul(gt_homography, self.corner[:, :, np.newaxis])[:, :, 0] 92 | warped_corner_by_pred = warped_corner_by_pred[:, :2] / warped_corner_by_pred[:, 2:3] 93 | warped_corner_by_gt = warped_corner_by_gt[:, :2] / warped_corner_by_gt[:, 2:3] 94 | diff = np.linalg.norm((warped_corner_by_pred-warped_corner_by_gt), axis=1, keepdims=False) 95 | diff = np.mean(diff) 96 | accuracy = (diff <= self.epsilon).astype(np.float) 97 | self.sum_accuracy += accuracy 98 | self.sum_sample_num += 1 99 | if not return_diff: 100 | return accuracy.astype(np.bool) 101 | else: 102 | return accuracy.astype(np.bool), diff 103 | 104 | def _generate_corner(self): 105 | 
pt_00 = np.array((0, 0, 1), dtype=np.float) 106 | pt_01 = np.array((0, self.height-1, 1), dtype=np.float) 107 | pt_10 = np.array((self.width-1, 0, 1), dtype=np.float) 108 | pt_11 = np.array((self.width-1, self.height-1, 1), dtype=np.float) 109 | corner = np.stack((pt_00, pt_01, pt_10, pt_11), axis=0) 110 | return corner 111 | 112 | 113 | class MeanMatchingAccuracy(object): 114 | 115 | def __init__(self, epsilon): 116 | self.epsilon = epsilon 117 | self.sum_accuracy = 0 118 | self.sum_sample_num = 0 119 | 120 | # 不匹配点间距离统计相关 121 | self.sum_outlier_ratio = [0, 0, 0, 0, 0] 122 | 123 | def reset(self): 124 | self.sum_accuracy = 0 125 | self.sum_sample_num = 0 126 | 127 | # 不匹配点间距离统计相关 128 | self.sum_outlier_ratio = [0, 0, 0, 0, 0] 129 | 130 | def update(self, gt_homography, matched_point): 131 | """ 132 | 计算单个样本对的匹配准确度 133 | Args: 134 | gt_homography: 该样本对的单应变换真值 135 | matched_point: List or Array. 136 | matched_point[0]是source image上的点,顺序为(y,x) 137 | matched_point[1]是target image上的点,顺序为(y,x), 138 | 计算要先颠倒顺序为(x,y) 139 | """ 140 | inv_homography = np.linalg.inv(gt_homography) 141 | src_point, tgt_point = matched_point[0], matched_point[1] 142 | src_point = src_point[:, ::-1] 143 | tgt_point = tgt_point[:, ::-1] 144 | num_matched = np.shape(src_point)[0] 145 | ones = np.ones((num_matched, 1), dtype=np.float) 146 | 147 | homo_src_point = np.concatenate((src_point, ones), axis=1) 148 | homo_tgt_point = np.concatenate((tgt_point, ones), axis=1) 149 | 150 | project_src_point = np.matmul(gt_homography, homo_src_point[:, :, np.newaxis])[:, :, 0] 151 | project_tgt_point = np.matmul(inv_homography, homo_tgt_point[:, :, np.newaxis])[:, :, 0] 152 | 153 | project_src_point = project_src_point[:, :2] / project_src_point[:, 2:3] 154 | project_tgt_point = project_tgt_point[:, :2] / project_tgt_point[:, 2:3] 155 | 156 | dist_src = np.linalg.norm(tgt_point - project_src_point, axis=1) 157 | dist_tgt = np.linalg.norm(src_point - project_tgt_point, axis=1) 158 | 159 | dist_all = np.concatenate((dist_src, dist_tgt)) 160 | self.statistic_dist(dist_all) 161 | 162 | correct_src = (dist_src <= self.epsilon) 163 | correct_tgt = (dist_tgt <= self.epsilon) 164 | correct = (correct_src & correct_tgt).astype(np.float) 165 | correct_ratio = np.mean(correct) 166 | self.sum_accuracy += correct_ratio 167 | self.sum_sample_num += 1 168 | 169 | def statistic_dist(self, dist): 170 | """ 171 | 统计不匹配的点间距离的分布情况,分别统计[0,e/2], (e/2,e], (e,2e], (2e,4e], (4e,+)五个区间中分布的百分比 172 | Args: 173 | dist: [n,] n个不匹配点间的距离 174 | """ 175 | count_0 = (dist <= 0.5*self.epsilon).astype(np.float) 176 | count_1 = ((dist > 0.5*self.epsilon) & (dist <= self.epsilon)).astype(np.float) 177 | count_2 = ((dist > self.epsilon) & (dist <= 2*self.epsilon)).astype(np.float) # (e,2e] 178 | count_3 = ((dist > 2*self.epsilon) & (dist <= 4*self.epsilon)).astype(np.float) # (2e,4e] 179 | count_4 = (dist > 4*self.epsilon).astype(np.float) # (4e,+) 180 | 181 | ratio_0 = np.mean(count_0) 182 | ratio_1 = np.mean(count_1) 183 | ratio_2 = np.mean(count_2) 184 | ratio_3 = np.mean(count_3) 185 | ratio_4 = np.mean(count_4) 186 | 187 | self.sum_outlier_ratio[0] += ratio_0 188 | self.sum_outlier_ratio[1] += ratio_1 189 | self.sum_outlier_ratio[2] += ratio_2 190 | self.sum_outlier_ratio[3] += ratio_3 191 | self.sum_outlier_ratio[4] += ratio_4 192 | 193 | def average(self): 194 | """ 195 | Returns: 平均匹配准确度 196 | """ 197 | if self.sum_sample_num == 0: 198 | return 0, 0, 0 199 | return self.sum_accuracy/self.sum_sample_num, self.sum_accuracy, self.sum_sample_num 200 | 201 | def 
average_outlier(self): 202 | """ 203 | 返回outlier重投影误差在各个区间的比例 204 | """ 205 | if self.sum_sample_num == 0: 206 | return 0, 0, 0, 0, 0 207 | avg_ratio_0 = self.sum_outlier_ratio[0] / self.sum_sample_num 208 | avg_ratio_1 = self.sum_outlier_ratio[1] / self.sum_sample_num 209 | avg_ratio_2 = self.sum_outlier_ratio[2] / self.sum_sample_num 210 | avg_ratio_3 = self.sum_outlier_ratio[3] / self.sum_sample_num 211 | avg_ratio_4 = self.sum_outlier_ratio[4] / self.sum_sample_num 212 | return avg_ratio_0, avg_ratio_1, avg_ratio_2, avg_ratio_3, avg_ratio_4 213 | 214 | 215 | class RepeatabilityCalculator(object): 216 | 217 | def __init__(self, epsilon, height, width): 218 | self.epsilon = epsilon 219 | self.sum_repeatability = 0 220 | self.sum_sample_num = 0 221 | self.height = height 222 | self.width = width 223 | 224 | def reset(self): 225 | self.sum_repeatability = 0 226 | self.sum_sample_num = 0 227 | 228 | def update(self, point_0, point_1, homography, return_repeat=False): 229 | repeatability, repeat_0, nonrepeat_0, repeat_1, nonrepeat_1 = self.compute_one_sample_repeatability( 230 | point_0, point_1, homography) 231 | self.sum_repeatability += repeatability 232 | self.sum_sample_num += 1 233 | if return_repeat: 234 | return repeat_0, nonrepeat_0, repeat_1, nonrepeat_1 235 | 236 | def average(self): 237 | if self.sum_sample_num == 0: 238 | return 0, 0, 0 239 | average_repeatability = self.sum_repeatability/self.sum_sample_num 240 | return average_repeatability, self.sum_repeatability, self.sum_sample_num 241 | 242 | def compute_one_sample_repeatability(self, point_0, point_1, homography): 243 | inv_homography = np.linalg.inv(homography) 244 | 245 | num_0 = np.shape(point_0)[0] 246 | num_1 = np.shape(point_1)[0] 247 | one_0 = np.ones((num_0, 1), dtype=np.float) 248 | one_1 = np.ones((num_1, 1), dtype=np.float) 249 | 250 | # recover to the original size and flip the order (y,x) to (x,y) 251 | point_0 = point_0[:, ::-1] 252 | point_1 = point_1[:, ::-1] 253 | homo_point_0 = np.concatenate((point_0, one_0), axis=1)[:, :, np.newaxis] # [n, 3, 1] 254 | homo_point_1 = np.concatenate((point_1, one_1), axis=1)[:, :, np.newaxis] 255 | 256 | # compute correctness from 0 to 1 257 | project_point_0 = np.matmul(homography, homo_point_0) 258 | project_point_0 = project_point_0[:, :2, 0] / project_point_0[:, 2:3, 0] 259 | project_point_0, inlier_point_0 = self._exclude_outlier(project_point_0, point_0) 260 | if project_point_0.size > 0: 261 | correctness_0_1, repeat_0 = self.compute_correctness(project_point_0, point_1) 262 | else: 263 | correctness_0_1 = 0 264 | repeat_0 = None 265 | 266 | repeat_list_0 = [] 267 | nonrepeat_list_0 = [] 268 | if repeat_0 is not None: 269 | for i in range(repeat_0.size): 270 | if repeat_0[i]: 271 | repeat_list_0.append(inlier_point_0[i]) 272 | else: 273 | nonrepeat_list_0.append(inlier_point_0[i]) 274 | if len(repeat_list_0) > 0: 275 | repeat_0 = np.stack(repeat_list_0, axis=0)[:, ::-1] # y,x顺序 276 | else: 277 | repeat_0 = np.empty((0, 2)) 278 | if len(nonrepeat_list_0) > 0: 279 | nonrepeat_0 = np.stack(nonrepeat_list_0, axis=0)[:, ::-1] 280 | else: 281 | nonrepeat_0 = np.empty((0, 2)) 282 | else: 283 | repeat_0 = np.empty((0, 2)) 284 | nonrepeat_0 = np.empty((0, 2)) 285 | 286 | # compute correctness from 1 to 0 287 | project_point_1 = np.matmul(inv_homography, homo_point_1) 288 | project_point_1 = project_point_1[:, :2, 0] / project_point_1[:, 2:3, 0] 289 | project_point_1, inlier_point_1 = self._exclude_outlier(project_point_1, point_1) 290 | if project_point_1.size > 0: 291 | 
correctness_1_0, repeat_1 = self.compute_correctness(project_point_1, point_0) 292 | else: 293 | correctness_1_0 = 0 294 | repeat_1 = None 295 | 296 | repeat_list_1 = [] 297 | nonrepeat_list_1 = [] 298 | if repeat_1 is not None: 299 | for i in range(repeat_1.size): 300 | if repeat_1[i]: 301 | repeat_list_1.append(inlier_point_1[i]) 302 | else: 303 | nonrepeat_list_1.append(inlier_point_1[i]) 304 | if len(repeat_list_1) > 0: 305 | repeat_1 = np.stack(repeat_list_1, axis=0)[:, ::-1] # y,x顺序 306 | else: 307 | repeat_1 = np.empty((0, 2)) 308 | if len(nonrepeat_list_1) > 0: 309 | nonrepeat_1 = np.stack(nonrepeat_list_1, axis=0)[:, ::-1] 310 | else: 311 | nonrepeat_1 = np.empty((0, 2)) 312 | else: 313 | repeat_1 = np.empty((0, 2)) 314 | nonrepeat_1 = np.empty((0, 2)) 315 | 316 | # compute repeatability 317 | total_point = np.shape(project_point_0)[0] + np.shape(project_point_1)[0] 318 | repeatability = (correctness_0_1 + correctness_1_0) / (total_point + 1e-3) 319 | return repeatability, repeat_0, nonrepeat_0, repeat_1, nonrepeat_1 320 | 321 | def _exclude_outlier(self, point, org_point): 322 | inlier = [] 323 | org_inlier = [] 324 | for i in range(point.shape[0]): 325 | x, y = point[i] 326 | if x < 0 or x > self.width - 1: 327 | continue 328 | if y < 0 or y > self.height - 1: 329 | continue 330 | inlier.append(point[i]) 331 | org_inlier.append(org_point[i]) 332 | if len(inlier) > 0: 333 | return np.stack(inlier, axis=0), np.stack(org_inlier, axis=0) 334 | else: 335 | return np.empty((0, 2)), np.empty((0, 2)) 336 | 337 | def compute_correctness(self, point_0, point_1): 338 | # compute the distance of two set of point 339 | # point_0: [n, 2], point_1: [m,2] 340 | point_0 = np.expand_dims(point_0, axis=1) # [n, 1, 2] 341 | point_1 = np.expand_dims(point_1, axis=0) # [1, m, 2] 342 | dist = np.linalg.norm(point_0 - point_1, axis=2) # [n, m] 343 | 344 | min_dist = np.min(dist, axis=1, keepdims=False) # [n] 345 | repeat = np.less_equal(min_dist, self.epsilon) 346 | correctness = np.sum(repeat.astype(np.float)) 347 | 348 | return correctness, repeat 349 | 350 | 351 | class mAPCalculator(object): 352 | 353 | def __init__(self): 354 | self.tp = [] 355 | self.fp = [] 356 | self.prob = [] 357 | self.total_num = 0 358 | 359 | def reset(self): 360 | self.tp = [] 361 | self.fp = [] 362 | self.prob = [] 363 | self.total_num = 0 364 | 365 | def update(self, org_prob, gt_point): 366 | tp, fp, prob, n_gt = self._compute_tp_fp(org_prob, gt_point) 367 | self.tp.append(tp) 368 | self.fp.append(fp) 369 | self.prob.append(prob) 370 | self.total_num += n_gt 371 | 372 | def compute_mAP(self): 373 | if len(self.tp) == 0: 374 | print("There has nothing to compute from! 
Please Check!") 375 | return 376 | tp = np.concatenate(self.tp) 377 | fp = np.concatenate(self.fp) 378 | prob = np.concatenate(self.prob) 379 | 380 | # 对整体进行排序 381 | sort_idx = np.argsort(prob)[::-1] 382 | tp = tp[sort_idx] 383 | fp = fp[sort_idx] 384 | prob = prob[sort_idx] 385 | 386 | # 进行累加计算 387 | tp_cum = np.cumsum(tp) 388 | fp_cum = np.cumsum(fp) 389 | recall = tp_cum / self.total_num 390 | precision = tp_cum / (tp_cum + fp_cum) 391 | prob = np.concatenate([[1], prob, [0]]) 392 | recall = np.concatenate([[0], recall, [1]]) 393 | precision = np.concatenate([[0], precision, [0]]) 394 | mAP = np.sum(precision[1:] * (recall[1:] - recall[:-1])) 395 | 396 | test_data = np.stack((recall, precision, prob), axis=0) 397 | return mAP, test_data 398 | 399 | def plot_threshold_curve(self, test_data, curve_name, curve_dir): 400 | recall = test_data[0, 1:-1] 401 | precision = test_data[1, 1:-1] 402 | prob = test_data[2, 1:-1] 403 | 404 | tmp_idx = np.where(prob <= 0.15) 405 | recall = recall[tmp_idx] 406 | precision = precision[tmp_idx] 407 | prob = prob[tmp_idx] 408 | title = curve_name 409 | 410 | plt.figure(figsize=(10, 5)) 411 | x_ticks = np.arange(0, 1, 0.01) 412 | y_ticks = np.arange(0, 1, 0.05) 413 | plt.title(title) 414 | plt.xticks(x_ticks) 415 | plt.yticks(y_ticks) 416 | plt.xlabel('probability threshold') 417 | plt.plot(prob, recall, label='recall') 418 | plt.plot(prob, precision, label='precision') 419 | plt.legend(loc='lower right') 420 | plt.grid() 421 | plt.savefig(curve_dir) 422 | 423 | @staticmethod 424 | def _compute_tp_fp(prob, gt_point, remove_zero=1e-4, distance_thresh=2): 425 | # 这里只能计算一个样本的tp以及fp,而不是一个batch 426 | assert len(np.shape(prob)) == 2 427 | 428 | mask = np.where(prob > remove_zero) 429 | # 留下满足满足要求的点 430 | prob = prob[mask] 431 | # 得到对应点的坐标, [n, 2] 432 | pred = np.array(mask).T 433 | 434 | sort_idx = np.argsort(prob)[::-1] 435 | prob = prob[sort_idx] 436 | pred = pred[sort_idx] 437 | 438 | # 得到每个点与真值点间的距离,最终得到[n,m]的距离表达式 439 | diff = np.expand_dims(pred, axis=1) - np.expand_dims(gt_point, axis=0) 440 | dist = np.linalg.norm(diff, axis=-1) 441 | matches = np.less_equal(dist, distance_thresh) 442 | 443 | tp = [] 444 | matched = np.zeros(np.shape(gt_point)[0]) 445 | for m in matches: 446 | correct = np.any(m) 447 | if correct: 448 | gt_idx = np.argmax(m) 449 | # 已匹配则为False 450 | tp.append(not matched[gt_idx]) 451 | # 标记已匹配的点 452 | matched[gt_idx] = 1 453 | else: 454 | tp.append(False) 455 | tp = np.array(tp, bool) 456 | fp = np.logical_not(tp) 457 | n_gt = np.shape(gt_point)[0] 458 | 459 | return tp, fp, prob, n_gt 460 | -------------------------------------------------------------------------------- /evaluation_hpatch/utils/evaluation_tools.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2020/2/23 3 | # 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class MovingAverage(object): 9 | 10 | def __init__(self, max_size=10): 11 | self.max_size = max_size 12 | self.insert_pos = 0 13 | self.queue = [] 14 | 15 | def reset(self): 16 | self.queue = [] 17 | self.insert_pos = 0 18 | 19 | def push(self, x): 20 | current_size = len(self.queue) 21 | if current_size < self.max_size: 22 | self.queue.append(x) 23 | else: 24 | self.queue[int(self.insert_pos % self.max_size)] = x 25 | self.insert_pos += 1 26 | 27 | def average(self): 28 | if len(self.queue) == 0: 29 | return 0 30 | current_queue = np.array(self.queue) 31 | avg = np.mean(current_queue) 32 | return avg 33 | 34 | def current_size(self): 35 | 
return len(self.queue) 36 | 37 | 38 | class PointStatistics(object): 39 | 40 | def __init__(self): 41 | self.point_num_list = [] 42 | self.sample_num = 0 43 | 44 | def reset(self): 45 | self.point_num_list = [] 46 | self.sample_num = 0 47 | 48 | def average(self): 49 | avg = 0 50 | var = 0 51 | std = 0 52 | 53 | for pt_num in self.point_num_list: 54 | avg += pt_num 55 | if self.sample_num == 0: 56 | return 0, 0 57 | avg /= self.sample_num 58 | 59 | for pt_num in self.point_num_list: 60 | var += (pt_num-avg)**2 61 | var /= self.sample_num 62 | std = np.sqrt(var) 63 | return avg, std 64 | 65 | def update(self, point_num): 66 | self.point_num_list.append(point_num) 67 | self.sample_num += 1 68 | 69 | 70 | class HomoAccuracyCalculator(object): 71 | 72 | def __init__(self, epsilon): 73 | self.epsilon = epsilon 74 | self.sum_accuracy = 0 75 | self.sum_sample_num = 0 76 | 77 | def reset(self): 78 | self.sum_accuracy = 0 79 | self.sum_sample_num = 0 80 | 81 | def average(self): 82 | if self.sum_sample_num == 0: 83 | return 0, 0, 0 84 | return self.sum_accuracy / self.sum_sample_num, self.sum_accuracy, self.sum_sample_num 85 | 86 | def update(self, pred_homography, gt_homography, shape_0, return_diff=False): 87 | corner = self._generate_corner(height=shape_0[0], width=shape_0[1]) 88 | warped_corner_by_pred = np.matmul(pred_homography, corner[:, :, np.newaxis])[:, :, 0] 89 | warped_corner_by_gt = np.matmul(gt_homography, corner[:, :, np.newaxis])[:, :, 0] 90 | warped_corner_by_pred = warped_corner_by_pred[:, :2] / warped_corner_by_pred[:, 2:3] 91 | warped_corner_by_gt = warped_corner_by_gt[:, :2] / warped_corner_by_gt[:, 2:3] 92 | diff = np.linalg.norm((warped_corner_by_pred-warped_corner_by_gt), axis=1, keepdims=False) 93 | diff = np.mean(diff) 94 | accuracy = (diff <= self.epsilon).astype(np.float) 95 | self.sum_accuracy += accuracy 96 | self.sum_sample_num += 1 97 | if not return_diff: 98 | return accuracy.astype(np.bool) 99 | else: 100 | return accuracy.astype(np.bool), diff 101 | 102 | @staticmethod 103 | def _generate_corner(height, width): 104 | pt_00 = np.array((0, 0, 1), dtype=np.float) 105 | pt_01 = np.array((0, height-1, 1), dtype=np.float) 106 | pt_10 = np.array((width-1, 0, 1), dtype=np.float) 107 | pt_11 = np.array((width-1, height-1, 1), dtype=np.float) 108 | corner = np.stack((pt_00, pt_01, pt_10, pt_11), axis=0) 109 | return corner 110 | 111 | 112 | class MeanMatchingAccuracy(object): 113 | 114 | def __init__(self, epsilon): 115 | self.epsilon = epsilon 116 | self.sum_accuracy = 0 117 | self.sum_sample_num = 0 118 | 119 | # 不匹配点间距离统计相关 120 | self.sum_outlier_ratio = [0, 0, 0, 0, 0] 121 | 122 | def reset(self): 123 | self.sum_accuracy = 0 124 | self.sum_sample_num = 0 125 | 126 | # 不匹配点间距离统计相关 127 | self.sum_outlier_ratio = [0, 0, 0, 0, 0] 128 | 129 | def update(self, gt_homography, matched_point): 130 | """ 131 | 计算单个样本对的匹配准确度 132 | Args: 133 | gt_homography: 该样本对的单应变换真值 134 | matched_point: List or Array. 
135 | matched_point[0]是source image上的点,顺序为(y,x) 136 | matched_point[1]是target image上的点,顺序为(y,x), 137 | 计算要先颠倒顺序为(x,y) 138 | """ 139 | # inv_homography = np.linalg.inv(gt_homography) 140 | src_point, tgt_point = matched_point[0], matched_point[1] 141 | src_point = src_point[:, ::-1] 142 | tgt_point = tgt_point[:, ::-1] 143 | num_matched = np.shape(src_point)[0] 144 | ones = np.ones((num_matched, 1), dtype=np.float) 145 | 146 | homo_src_point = np.concatenate((src_point, ones), axis=1) 147 | # homo_tgt_point = np.concatenate((tgt_point, ones), axis=1) 148 | 149 | project_src_point = np.matmul(gt_homography, homo_src_point[:, :, np.newaxis])[:, :, 0] 150 | # project_tgt_point = np.matmul(inv_homography, homo_tgt_point[:, :, np.newaxis])[:, :, 0] 151 | 152 | project_src_point = project_src_point[:, :2] / project_src_point[:, 2:3] 153 | # project_tgt_point = project_tgt_point[:, :2] / project_tgt_point[:, 2:3] 154 | 155 | dist_src = np.linalg.norm(tgt_point - project_src_point, axis=1) 156 | # dist_tgt = np.linalg.norm(src_point - project_tgt_point, axis=1) 157 | 158 | # dist_all = np.concatenate((dist_src, dist_tgt)) 159 | # self.statistic_dist(dist_all) 160 | 161 | correct_src = (dist_src <= self.epsilon) 162 | # correct_tgt = (dist_tgt <= self.epsilon) 163 | # correct = (correct_src & correct_tgt).astype(np.float) 164 | # correct_ratio = np.mean(correct) 165 | correct_ratio = np.mean(correct_src) 166 | self.sum_accuracy += correct_ratio 167 | self.sum_sample_num += 1 168 | 169 | def statistic_dist(self, dist, epsilon=3): 170 | """ 171 | 统计不匹配的点间距离的分布情况,分别统计[0,e/2], (e/2,e], (e,2e], (2e,4e], (4e,+)五个区间中分布的百分比 172 | Args: 173 | dist: [n,] n个不匹配点间的距离 174 | """ 175 | count_0 = (dist <= 0.5*epsilon).astype(np.float) 176 | count_1 = ((dist > 0.5*epsilon) & (dist <= epsilon)).astype(np.float) 177 | count_2 = ((dist > epsilon) & (dist <= 2*epsilon)).astype(np.float) # (e,2e] 178 | count_3 = ((dist > 2*epsilon) & (dist <= 4*epsilon)).astype(np.float) # (2e,4e] 179 | count_4 = (dist > 4*epsilon).astype(np.float) # (4e,+) 180 | 181 | ratio_0 = np.mean(count_0) 182 | ratio_1 = np.mean(count_1) 183 | ratio_2 = np.mean(count_2) 184 | ratio_3 = np.mean(count_3) 185 | ratio_4 = np.mean(count_4) 186 | 187 | self.sum_outlier_ratio[0] += ratio_0 188 | self.sum_outlier_ratio[1] += ratio_1 189 | self.sum_outlier_ratio[2] += ratio_2 190 | self.sum_outlier_ratio[3] += ratio_3 191 | self.sum_outlier_ratio[4] += ratio_4 192 | 193 | def average(self): 194 | """ 195 | Returns: 平均匹配准确度 196 | """ 197 | if self.sum_sample_num == 0: 198 | return 0, 0, 0 199 | return self.sum_accuracy/self.sum_sample_num, self.sum_accuracy, self.sum_sample_num 200 | 201 | def average_outlier(self): 202 | """ 203 | 返回outlier重投影误差在各个区间的比例 204 | """ 205 | if self.sum_sample_num == 0: 206 | return 0, 0, 0, 0, 0 207 | avg_ratio_0 = self.sum_outlier_ratio[0] / self.sum_sample_num 208 | avg_ratio_1 = self.sum_outlier_ratio[1] / self.sum_sample_num 209 | avg_ratio_2 = self.sum_outlier_ratio[2] / self.sum_sample_num 210 | avg_ratio_3 = self.sum_outlier_ratio[3] / self.sum_sample_num 211 | avg_ratio_4 = self.sum_outlier_ratio[4] / self.sum_sample_num 212 | return avg_ratio_0, avg_ratio_1, avg_ratio_2, avg_ratio_3, avg_ratio_4 213 | 214 | 215 | class RepeatabilityCalculator(object): 216 | 217 | def __init__(self, epsilon): 218 | self.epsilon = epsilon 219 | self.sum_repeatability = 0 220 | self.sum_sample_num = 0 221 | 222 | def reset(self): 223 | self.sum_repeatability = 0 224 | self.sum_sample_num = 0 225 | 226 | def update(self, point_0, point_1, 
homography, shape_0, shape_1, return_repeat=False): 227 | repeatability, repeat_0, nonrepeat_0, repeat_1, nonrepeat_1 = self.compute_one_sample_repeatability( 228 | point_0, point_1, homography, shape_0=shape_0, shape_1=shape_1) 229 | self.sum_repeatability += repeatability 230 | self.sum_sample_num += 1 231 | if return_repeat: 232 | return repeat_0, nonrepeat_0, repeat_1, nonrepeat_1 233 | 234 | def average(self): 235 | if self.sum_sample_num == 0: 236 | return 0, 0, 0 237 | average_repeatability = self.sum_repeatability/self.sum_sample_num 238 | return average_repeatability, self.sum_repeatability, self.sum_sample_num 239 | 240 | def compute_one_sample_repeatability(self, point_0, point_1, homography, shape_0, shape_1): 241 | inv_homography = np.linalg.inv(homography) 242 | 243 | num_0 = np.shape(point_0)[0] 244 | num_1 = np.shape(point_1)[0] 245 | one_0 = np.ones((num_0, 1), dtype=np.float) 246 | one_1 = np.ones((num_1, 1), dtype=np.float) 247 | 248 | # recover to the original size and flip the order (y,x) to (x,y) 249 | point_0 = point_0[:, ::-1] 250 | point_1 = point_1[:, ::-1] 251 | homo_point_0 = np.concatenate((point_0, one_0), axis=1)[:, :, np.newaxis] # [n, 3, 1] 252 | homo_point_1 = np.concatenate((point_1, one_1), axis=1)[:, :, np.newaxis] 253 | 254 | # compute correctness from 0 to 1 255 | project_point_0 = np.matmul(homography, homo_point_0) 256 | project_point_0 = project_point_0[:, :2, 0] / project_point_0[:, 2:3, 0] 257 | project_point_0, inlier_point_0 = self._exclude_outlier(project_point_0, point_0, height=shape_1[0], width=shape_1[1]) 258 | if project_point_0.size > 0: 259 | correctness_0_1, repeat_0 = self.compute_correctness(project_point_0, point_1) 260 | else: 261 | correctness_0_1 = 0 262 | repeat_0 = None 263 | 264 | repeat_list_0 = [] 265 | nonrepeat_list_0 = [] 266 | if repeat_0 is not None: 267 | for i in range(repeat_0.size): 268 | if repeat_0[i]: 269 | repeat_list_0.append(inlier_point_0[i]) 270 | else: 271 | nonrepeat_list_0.append(inlier_point_0[i]) 272 | if len(repeat_list_0) > 0: 273 | repeat_0 = np.stack(repeat_list_0, axis=0)[:, ::-1] # y,x顺序 274 | else: 275 | repeat_0 = np.empty((0, 2)) 276 | if len(nonrepeat_list_0) > 0: 277 | nonrepeat_0 = np.stack(nonrepeat_list_0, axis=0)[:, ::-1] 278 | else: 279 | nonrepeat_0 = np.empty((0, 2)) 280 | else: 281 | repeat_0 = np.empty((0, 2)) 282 | nonrepeat_0 = np.empty((0, 2)) 283 | 284 | # compute correctness from 1 to 0 285 | project_point_1 = np.matmul(inv_homography, homo_point_1) 286 | project_point_1 = project_point_1[:, :2, 0] / project_point_1[:, 2:3, 0] 287 | project_point_1, inlier_point_1 = self._exclude_outlier(project_point_1, point_1, height=shape_0[0], width=shape_0[1]) 288 | if project_point_1.size > 0: 289 | correctness_1_0, repeat_1 = self.compute_correctness(project_point_1, point_0) 290 | else: 291 | correctness_1_0 = 0 292 | repeat_1 = None 293 | 294 | repeat_list_1 = [] 295 | nonrepeat_list_1 = [] 296 | if repeat_1 is not None: 297 | for i in range(repeat_1.size): 298 | if repeat_1[i]: 299 | repeat_list_1.append(inlier_point_1[i]) 300 | else: 301 | nonrepeat_list_1.append(inlier_point_1[i]) 302 | if len(repeat_list_1) > 0: 303 | repeat_1 = np.stack(repeat_list_1, axis=0)[:, ::-1] # y,x顺序 304 | else: 305 | repeat_1 = np.empty((0, 2)) 306 | if len(nonrepeat_list_1) > 0: 307 | nonrepeat_1 = np.stack(nonrepeat_list_1, axis=0)[:, ::-1] 308 | else: 309 | nonrepeat_1 = np.empty((0, 2)) 310 | else: 311 | repeat_1 = np.empty((0, 2)) 312 | nonrepeat_1 = np.empty((0, 2)) 313 | 314 | # compute 
repeatability 315 | total_point = np.shape(project_point_0)[0] + np.shape(project_point_1)[0] 316 | repeatability = (correctness_0_1 + correctness_1_0) / (total_point + 1e-3) 317 | return repeatability, repeat_0, nonrepeat_0, repeat_1, nonrepeat_1 318 | 319 | @staticmethod 320 | def _exclude_outlier(point, org_point, height, width): 321 | inlier = [] 322 | org_inlier = [] 323 | for i in range(point.shape[0]): 324 | x, y = point[i] 325 | if x < 0 or x > width - 1: 326 | continue 327 | if y < 0 or y > height - 1: 328 | continue 329 | inlier.append(point[i]) 330 | org_inlier.append(org_point[i]) 331 | if len(inlier) > 0: 332 | return np.stack(inlier, axis=0), np.stack(org_inlier, axis=0) 333 | else: 334 | return np.empty((0, 2)), np.empty((0, 2)) 335 | 336 | def compute_correctness(self, point_0, point_1): 337 | # compute the distance of two set of point 338 | # point_0: [n, 2], point_1: [m,2] 339 | point_0 = np.expand_dims(point_0, axis=1) # [n, 1, 2] 340 | point_1 = np.expand_dims(point_1, axis=0) # [1, m, 2] 341 | dist = np.linalg.norm(point_0 - point_1, axis=2) # [n, m] 342 | 343 | min_dist = np.min(dist, axis=1, keepdims=False) # [n] 344 | repeat = np.less_equal(min_dist, self.epsilon) 345 | correctness = np.sum(repeat.astype(np.float)) 346 | 347 | return correctness, repeat 348 | 349 | 350 | class mAPCalculator(object): 351 | 352 | def __init__(self): 353 | self.tp = [] 354 | self.fp = [] 355 | self.prob = [] 356 | self.total_num = 0 357 | 358 | def reset(self): 359 | self.tp = [] 360 | self.fp = [] 361 | self.prob = [] 362 | self.total_num = 0 363 | 364 | def update(self, org_prob, gt_point): 365 | tp, fp, prob, n_gt = self._compute_tp_fp(org_prob, gt_point) 366 | self.tp.append(tp) 367 | self.fp.append(fp) 368 | self.prob.append(prob) 369 | self.total_num += n_gt 370 | 371 | def compute_mAP(self): 372 | if len(self.tp) == 0: 373 | print("There has nothing to compute from! 
Please Check!") 374 | return 375 | tp = np.concatenate(self.tp) 376 | fp = np.concatenate(self.fp) 377 | prob = np.concatenate(self.prob) 378 | 379 | # 对整体进行排序 380 | sort_idx = np.argsort(prob)[::-1] 381 | tp = tp[sort_idx] 382 | fp = fp[sort_idx] 383 | prob = prob[sort_idx] 384 | 385 | # 进行累加计算 386 | tp_cum = np.cumsum(tp) 387 | fp_cum = np.cumsum(fp) 388 | recall = tp_cum / self.total_num 389 | precision = tp_cum / (tp_cum + fp_cum) 390 | prob = np.concatenate([[1], prob, [0]]) 391 | recall = np.concatenate([[0], recall, [1]]) 392 | precision = np.concatenate([[0], precision, [0]]) 393 | mAP = np.sum(precision[1:] * (recall[1:] - recall[:-1])) 394 | 395 | test_data = np.stack((recall, precision, prob), axis=0) 396 | return mAP, test_data 397 | 398 | def plot_threshold_curve(self, test_data, curve_name, curve_dir): 399 | recall = test_data[0, 1:-1] 400 | precision = test_data[1, 1:-1] 401 | prob = test_data[2, 1:-1] 402 | 403 | tmp_idx = np.where(prob <= 0.15) 404 | recall = recall[tmp_idx] 405 | precision = precision[tmp_idx] 406 | prob = prob[tmp_idx] 407 | title = curve_name 408 | 409 | plt.figure(figsize=(10, 5)) 410 | x_ticks = np.arange(0, 1, 0.01) 411 | y_ticks = np.arange(0, 1, 0.05) 412 | plt.title(title) 413 | plt.xticks(x_ticks) 414 | plt.yticks(y_ticks) 415 | plt.xlabel('probability threshold') 416 | plt.plot(prob, recall, label='recall') 417 | plt.plot(prob, precision, label='precision') 418 | plt.legend(loc='lower right') 419 | plt.grid() 420 | plt.savefig(curve_dir) 421 | 422 | @staticmethod 423 | def _compute_tp_fp(prob, gt_point, remove_zero=1e-4, distance_thresh=2): 424 | # 这里只能计算一个样本的tp以及fp,而不是一个batch 425 | assert len(np.shape(prob)) == 2 426 | 427 | mask = np.where(prob > remove_zero) 428 | # 留下满足满足要求的点 429 | prob = prob[mask] 430 | # 得到对应点的坐标, [n, 2] 431 | pred = np.array(mask).T 432 | 433 | sort_idx = np.argsort(prob)[::-1] 434 | prob = prob[sort_idx] 435 | pred = pred[sort_idx] 436 | 437 | # 得到每个点与真值点间的距离,最终得到[n,m]的距离表达式 438 | diff = np.expand_dims(pred, axis=1) - np.expand_dims(gt_point, axis=0) 439 | dist = np.linalg.norm(diff, axis=-1) 440 | matches = np.less_equal(dist, distance_thresh) 441 | 442 | tp = [] 443 | matched = np.zeros(np.shape(gt_point)[0]) 444 | for m in matches: 445 | correct = np.any(m) 446 | if correct: 447 | gt_idx = np.argmax(m) 448 | # 已匹配则为False 449 | tp.append(not matched[gt_idx]) 450 | # 标记已匹配的点 451 | matched[gt_idx] = 1 452 | else: 453 | tp.append(False) 454 | tp = np.array(tp, bool) 455 | fp = np.logical_not(tp) 456 | n_gt = np.shape(gt_point)[0] 457 | 458 | return tp, fp, prob, n_gt 459 | -------------------------------------------------------------------------------- /data_utils/megadepth_train_dataset_dl.py: -------------------------------------------------------------------------------- 1 | # 2 | # Created on 2020/6/29 3 | # 4 | import os 5 | from glob import glob 6 | from nets.network import MTLDesc 7 | import cv2 as cv 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset 11 | from nets.network import MTLDesc 12 | from data_utils.dataset_tools import HomographyAugmentation 13 | from data_utils.dataset_tools import ImgAugTransform 14 | from data_utils.dataset_tools import space_to_depth 15 | class MegaDepthTrainDataset(Dataset): 16 | """ 17 | Combination of MegaDetph and COCO 18 | """ 19 | def __init__(self, **config): 20 | self.data_list = self._format_file_list( 21 | config['mega_image_dir'], 22 | config['mega_keypoint_dir'], 23 | config['mega_despoint_dir'], 24 | config['mega_dl_dir1'], 25 | 
--------------------------------------------------------------------------------
/data_utils/megadepth_train_dataset_dl.py:
--------------------------------------------------------------------------------
1 | #
2 | # Created on 2020/6/29
3 | #
4 | import os
5 | from glob import glob
6 | from nets.network import MTLDesc
7 | import cv2 as cv
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 | 
12 | from data_utils.dataset_tools import HomographyAugmentation
13 | from data_utils.dataset_tools import ImgAugTransform
14 | from data_utils.dataset_tools import space_to_depth
15 | class MegaDepthTrainDataset(Dataset):
16 |     """
17 |     Combination of MegaDepth and COCO
18 |     """
19 |     def __init__(self, **config):
20 |         self.data_list = self._format_file_list(
21 |             config['mega_image_dir'],
22 |             config['mega_keypoint_dir'],
23 |             config['mega_despoint_dir'],
24 |             config['mega_dl_dir1'],
25 |             config['mega_dl_dir2'],
26 |         )
27 |         self.sydesp_type = config['sydesp_type']
28 |         self.height = config['height']
29 |         self.width = config['width']
30 |         self.balance = config['balance']
31 |         self.homography = HomographyAugmentation()
32 |         self.photometric = ImgAugTransform()
33 |         self.fix_grid = self._generate_fixed_grid()
34 | 
35 |     def __len__(self):
36 |         return len(self.data_list)
37 | 
38 |     def __getitem__(self, idx):
39 |         data_info = self.data_list[idx]
40 | 
41 |         if data_info['type'] == 'synthesis':
42 |             return self._get_synthesis_data(data_info)
43 | 
44 |         elif data_info['type'] == 'real':
45 |             return self._get_real_data(data_info)
46 | 
47 |         else:
48 |             assert False
49 | 
50 | 
51 |     def _get_real_data(self, data_info):
52 |         image_dir = data_info['image']
53 |         info_dir = data_info['info']
54 |         label_dir = data_info['label']
55 |         dl_dir = data_info['dl1']
56 |         image12 = cv.imread(image_dir)[:, :, ::-1].copy()  # convert BGR to RGB
57 |         image1, image2 = np.split(image12, 2, axis=1)
58 |         h, w, _ = image1.shape
59 | 
60 |         if torch.rand([]).item() < 0.5:
61 |             image1 = self.photometric(image1)
62 |             image2 = self.photometric(image2)
63 | 
64 |         info = np.load(info_dir)
65 |         desp_point1 = info["desp_point1"]
66 |         desp_point2 = info["desp_point2"]
67 |         valid_mask = info["valid_mask"]
68 |         not_search_mask = info["not_search_mask"]
69 | 
70 |         label = np.load(label_dir)
71 |         points1 = label["points_0"]
72 |         points2 = label["points_1"]
73 | 
74 |         dl = np.load(dl_dir)
75 |         dl_heatmap1 = dl["heatmap0"].squeeze(axis=0).squeeze(axis=0)
76 |         dl_heatmap2 = dl["heatmap1"].squeeze(axis=0).squeeze(axis=0)
77 |         dl_attmap1 = dl["attmap0"].squeeze(axis=0).squeeze(axis=0)
78 |         dl_attmap2 = dl["attmap1"].squeeze(axis=0).squeeze(axis=0)
79 |         dl_descriptor1 = dl["descriptor0"].squeeze(axis=0)
80 |         dl_descriptor2 = dl["descriptor1"].squeeze(axis=0)
81 | 
82 | 
83 |         # 2.1 heatmap built from the first image's keypoints
84 |         heatmap1 = self._convert_points_to_heatmap(points1)
85 |         point_mask1 = torch.ones_like(heatmap1)
86 | 
87 |         # 2.2 heatmap built from the second image's keypoints
88 |         heatmap2 = self._convert_points_to_heatmap(points2)
89 |         point_mask2 = torch.ones_like(heatmap2)
90 | 
91 |         image1 = (torch.from_numpy(image1).to(torch.float) * 2. / 255. - 1.).permute((2, 0, 1)).contiguous()
92 |         image2 = (torch.from_numpy(image2).to(torch.float) * 2. / 255. - 1.).permute((2, 0, 1)).contiguous()
93 | 
94 |         desp_point1 = torch.from_numpy(desp_point1)
95 |         desp_point2 = torch.from_numpy(desp_point2)
96 | 
97 |         valid_mask = torch.from_numpy(valid_mask).to(torch.float)
98 |         not_search_mask = torch.from_numpy(not_search_mask).to(torch.float)
99 | 
100 |         return {
101 |             "image": image1,
102 |             "point_mask": point_mask1,
103 |             "heatmap": heatmap1,
104 |             "dl_heatmap": dl_heatmap1,
105 |             "dl_attmap": dl_attmap1,
106 |             "dl_descriptor": dl_descriptor1,
107 |             "warped_image": image2,
108 |             "warped_point_mask": point_mask2,
109 |             "warped_heatmap": heatmap2,
110 |             "warped_dl_heatmap": dl_heatmap2,
111 |             "warped_dl_attmap": dl_attmap2,
112 |             "warped_dl_descriptor": dl_descriptor2,
113 |             "desp_point": desp_point1,
114 |             "warped_desp_point": desp_point2,
115 |             "valid_mask": valid_mask,
116 |             "not_search_mask": not_search_mask,
117 | 
118 |         }
119 | 
120 |     def _get_synthesis_data(self, data_info):
121 |         image12 = cv.imread(data_info['image'])[:, :, ::-1].copy()  # convert BGR to RGB
122 |         image1, image2 = np.split(image12, 2, axis=1)
123 |         point = np.load(data_info['label'])
124 |         info = np.load(data_info['info'])
125 |         dl = np.load(data_info['dl2'])
126 | 
127 |         if torch.rand([]).item() < 0.5:
128 |             image = cv.resize(image1, dsize=(self.width, self.height), interpolation=cv.INTER_LINEAR)
129 |             point = point["points_0"]
130 |             desp_point_load = info["raw_desp_point1"]
131 |             dl_heatmap1 = dl["heatmap0"].squeeze(axis=0).squeeze(axis=0)
132 |             dl_attmap1 = dl["attmap0"].squeeze(axis=0).squeeze(axis=0)
133 |             dl_descriptor1 = dl["descriptor0"].squeeze(axis=0)
134 |             dl_homography = dl["homography1"]
135 |             dl_heatmap2 = dl["hheatmap0"].squeeze(axis=0).squeeze(axis=0)
136 |             dl_attmap2 = dl["hattmap0"].squeeze(axis=0).squeeze(axis=0)
137 |             dl_descriptor2 = dl["hdescriptor0"].squeeze(axis=0)
138 |         else:
139 |             image = cv.resize(image2, dsize=(self.width, self.height), interpolation=cv.INTER_LINEAR)
140 |             point = point["points_1"]
141 |             desp_point_load = info["raw_desp_point2"]
142 |             dl_heatmap1 = dl["heatmap1"].squeeze(axis=0).squeeze(axis=0)
143 |             dl_attmap1 = dl["attmap1"].squeeze(axis=0).squeeze(axis=0)
144 |             dl_descriptor1 = dl["descriptor1"].squeeze(axis=0)
145 |             dl_homography = dl["homography2"]
146 |             dl_heatmap2 = dl["hheatmap1"].squeeze(axis=0).squeeze(axis=0)
147 |             dl_attmap2 = dl["hattmap1"].squeeze(axis=0).squeeze(axis=0)
148 |             dl_descriptor2 = dl["hdescriptor1"].squeeze(axis=0)
149 | 
150 |         point_mask = np.ones_like(image).astype(np.float32)[:, :, 0].copy()
151 | 
152 |         # 1. get the second image, its keypoints, mask and homography from a randomly sampled homography
153 |         if torch.rand([]).item() < self.balance:
154 |             warped_image, warped_point_mask, warped_point, homography = \
155 |                 image.copy(), point_mask.copy(), point.copy(), np.eye(3)
156 |             dl_heatmap2 = dl_heatmap1.copy()
157 |             dl_attmap2 = dl_attmap1.copy()
158 |             dl_descriptor2 = dl_descriptor1.copy()
159 | 
160 |         else:
161 |             warped_image, warped_point_mask, warped_point, homography = self.homography(image, point, return_homo=True, homo=dl_homography)
162 |             warped_point_mask = warped_point_mask[:, :, 0].copy()
163 | 
164 |         # dl_attmap2 = cv.resize(dl_attmap2, (400, 400))
165 |         # dl_attmap2 = torch.from_numpy(dl_attmap2)
166 |         # min = torch.min(dl_attmap2)
167 |         # meanmap = torch.where(dl_attmap2 > 0, dl_attmap2, min)
168 |         # b = meanmap.squeeze(0).squeeze(0).cpu().numpy()
169 |         # norm_img = np.ones(b.shape)
170 |         # norm_img = cv.normalize(b, norm_img, 0, 255, cv.NORM_MINMAX)
171 |         # norm_img = np.asarray(norm_img, dtype=np.uint8)
172 |         # heat_img = cv.applyColorMap(norm_img, cv.COLORMAP_JET)  # note: OpenCV colormaps produce BGR channel order
173 |         # img_add = cv.addWeighted(warped_image, 0.3, heat_img, 0.7, 0)
174 |         # cv.imwrite('0005.png', img_add)
175 |         # exit(0)
176 | 
177 | 
178 | 
179 |         if torch.rand([]).item() < 0.5:
180 |             image = self.photometric(image)
181 |             warped_image = self.photometric(warped_image)
182 | 
183 |         # 2.1 heatmap built from the first image's keypoints
184 |         heatmap = self._convert_points_to_heatmap(point)
185 | 
186 |         # 2.2 heatmap built from the second image's keypoints
187 |         warped_heatmap = self._convert_points_to_heatmap(warped_point)
188 | 
189 |         # 3. sample the points used for descriptor training
190 |         if self.sydesp_type == 'random':
191 |             desp_point = self._random_sample_point()
192 |         else:
193 |             desp_point = desp_point_load
194 | 
195 |         shape = image.shape
196 | 
197 |         warped_desp_point, valid_mask, not_search_mask = self._generate_warped_point(
198 |             desp_point, homography, shape[0], shape[1])
199 | 
200 |         image = image.astype(np.float32) * 2. / 255. - 1.
201 |         warped_image = warped_image.astype(np.float32) * 2. / 255. - 1.
202 | 
203 |         image = torch.from_numpy(image).permute((2, 0, 1))
204 |         warped_image = torch.from_numpy(warped_image).permute((2, 0, 1))
205 | 
206 |         point_mask = torch.from_numpy(point_mask)
207 |         warped_point_mask = torch.from_numpy(warped_point_mask)
208 | 
209 |         desp_point = torch.from_numpy(self._scale_point_for_sample(desp_point))
210 |         warped_desp_point = torch.from_numpy(self._scale_point_for_sample(warped_desp_point))
211 | 
212 |         valid_mask = torch.from_numpy(valid_mask)
213 |         not_search_mask = torch.from_numpy(not_search_mask)
214 | 
215 |         return {
216 |             "image": image,  # [3,h,w]
217 |             "point_mask": point_mask,  # [h,w]
218 |             "heatmap": heatmap,  # [h,w]
219 |             "dl_heatmap": dl_heatmap1,
220 |             "dl_attmap": dl_attmap1,
221 |             "dl_descriptor": dl_descriptor1,
222 |             "warped_image": warped_image,  # [3,h,w]
223 |             "warped_point_mask": warped_point_mask,  # [h,w]
224 |             "warped_heatmap": warped_heatmap,  # [h,w]
225 |             "warped_dl_heatmap": dl_heatmap2,
226 |             "warped_dl_attmap": dl_attmap2,
227 |             "warped_dl_descriptor": dl_descriptor2,
228 |             "desp_point": desp_point,  # [n,1,2]
229 |             "warped_desp_point": warped_desp_point,  # [n,1,2]
230 |             "valid_mask": valid_mask,  # [n]
231 |             "not_search_mask": not_search_mask,  # [n,n]
232 |         }
233 | 
234 |     @staticmethod
235 |     def _generate_warped_point(point, homography, height, width, threshold=16):
236 |         """
237 |         Project the points through the homography; also build the validity mask and the matrix excluding pairs from negative-sample search.
238 |         Args:
239 |             point: [n,2], in one-to-one correspondence with warped_point
240 |             homography: the transform relating the point pairs
241 | 
242 |         Returns:
243 |             not_search_mask: [n,n] float32 mask, 1 at positions not to be searched
244 |         """
245 |         # project the point coordinates (homogeneous (x, y, 1) first)
246 |         point = np.concatenate((point[:, ::-1], np.ones((point.shape[0], 1))), axis=1)[:, :, np.newaxis]  # [n,3,1]
247 |         project_point = np.matmul(homography, point)[:, :, 0]
248 |         project_point = project_point[:, :2] / project_point[:, 2:3]
249 |         project_point = project_point[:, ::-1]  # swap back to (y, x) order
250 | 
251 |         # projected points inside the image are valid, the rest are invalid
252 |         border_0 = np.array((0, 0), dtype=np.float32)
253 |         border_1 = np.array((height - 1, width - 1), dtype=np.float32)
254 |         valid_mask = (project_point >= border_0) & (project_point <= border_1)
255 |         valid_mask = np.all(valid_mask, axis=1)
256 |         invalid_mask = ~valid_mask
257 | 
258 |         # build the not-search matrix from the invalid points and the pairwise projected distances
259 | 
260 |         dist = np.linalg.norm(project_point[:, np.newaxis, :] - project_point[np.newaxis, :, :], axis=2)
261 |         not_search_mask = ((dist <= threshold) | invalid_mask[np.newaxis, :]).astype(np.float32)
262 |         return project_point.astype(np.float32), valid_mask.astype(np.float32), not_search_mask
263 | 
264 |     def _scale_point_for_sample(self, point):
265 |         """
266 |         Normalize points to the [-1,1] range and swap to (x, y) order for sampling
267 |         Args:
268 |             point: [n,2] in (y, x) order, original range [0,height-1], [0,width-1]
269 |         Returns:
270 |             point: [n,1,2] in (x, y) order, range [-1,1]
271 |         """
272 |         org_size = np.array((self.height - 1, self.width - 1), dtype=np.float32)
273 |         point = ((point * 2. / org_size - 1.)[:, ::-1])[:, np.newaxis, :].copy()
274 |         return point
275 | 
276 |     def _random_sample_point(self):
277 |         """
278 |         Uniformly sample random coordinates over the preset input image size
279 |         """
280 |         grid = self.fix_grid.copy()
281 |         # draw one random point inside every grid cell
282 | 
283 |         point_list = []
284 |         for i in range(grid.shape[0]):
285 |             y_start, x_start, y_end, x_end = grid[i]
286 |             rand_y = np.random.randint(y_start, y_end)
287 |             rand_x = np.random.randint(x_start, x_end)
288 |             point_list.append(np.array((rand_y, rand_x), dtype=np.float32))
289 |         point = np.stack(point_list, axis=0)
290 | 
291 |         return point
292 | 
293 |     def _generate_fixed_grid(self, option=None):
294 |         """
295 |         Pre-compute a fixed grid of y_num * x_num image cells (20 x 20 = 400 by default)
296 |         """
297 |         if option is None:
298 |             y_num = 20
299 |             x_num = 20
300 |         else:
301 |             y_num = option[0]
302 |             x_num = option[1]
303 | 
304 |         grid_y = np.linspace(0, self.height - 1, y_num + 1, dtype=np.int64)
305 |         grid_x = np.linspace(0, self.width - 1, x_num + 1, dtype=np.int64)
306 | 
307 |         grid_y_start = grid_y[:y_num].copy()
308 |         grid_y_end = grid_y[1:y_num + 1].copy()
309 |         grid_x_start = grid_x[:x_num].copy()
310 |         grid_x_end = grid_x[1:x_num + 1].copy()
311 | 
312 |         grid_start = np.stack((np.tile(grid_y_start[:, np.newaxis], (1, x_num)),
313 |                                np.tile(grid_x_start[np.newaxis, :], (y_num, 1))), axis=2).reshape((-1, 2))
314 |         grid_end = np.stack((np.tile(grid_y_end[:, np.newaxis], (1, x_num)),
315 |                              np.tile(grid_x_end[np.newaxis, :], (y_num, 1))), axis=2).reshape((-1, 2))
316 |         grid = np.concatenate((grid_start, grid_end), axis=1)
317 | 
318 |         return grid
319 | 
320 |     def _convert_points_to_heatmap(self, points):
321 |         """
322 |         Convert raw point positions into a heatmap: positions are rounded to
323 |         integer pixels and the heatmap is 1 at those pixels, 0 elsewhere.
324 |         (The incmap / incmap_valid sub-pixel offset maps mentioned in the
325 |         original comment are not computed; only the heatmap is returned.)
326 |         Args:
327 |             points: [n,2]
328 | 
329 |         Returns:
330 |             heatmap: [h,w], 1 at keypoint positions, 0 elsewhere
331 | 
332 |         """
333 |         height = self.height
334 |         width = self.width
335 | 
336 |         # localmap = self.localmap.clone()
337 |         # padded_heatmap = torch.zeros(
338 |         #     (height+self.g_paddings*2, width+self.g_paddings*2), dtype=torch.float)
339 |         heatmap = torch.zeros((height, width), dtype=torch.float)
340 | 
341 |         num_pt = points.shape[0]
342 |         if num_pt > 0:
343 |             for i in range(num_pt):
344 |                 pt = points[i]
345 |                 pt_y_float, pt_x_float = pt
346 | 
347 |                 pt_y_int = round(pt_y_float)
348 |                 pt_x_int = round(pt_x_float)
349 | 
350 |                 pt_y = int(pt_y_int)  # quantize the ground-truth position; this introduces rounding error
351 |                 pt_x = int(pt_x_int)
352 | 
353 |                 # skip points that fall outside the image after rounding
354 |                 if pt_y < 0 or pt_y > height - 1:
355 |                     continue
356 |                 if pt_x < 0 or pt_x > width - 1:
357 |                     continue
358 | 
359 |                 # set the keypoint position to 1 on the heatmap
360 |                 heatmap[pt_y, pt_x] = 1.0
361 | 
362 |         return heatmap
363 | 
364 |     def convert_points_to_label(self, points):
365 | 
366 |         height = self.height
367 |         width = self.width
368 |         n_height = int(height / 8)
369 |         n_width = int(width / 8)
370 |         assert n_height * 8 == height and n_width * 8 == width
371 | 
372 |         num_pt = points.shape[0]
373 |         label = torch.zeros((height * width))
374 |         if num_pt > 0:
375 |             points_h, points_w = torch.split(points, 1, dim=1)
376 |             points_idx = points_w + points_h * width
377 |             label = label.scatter_(dim=0, index=points_idx[:, 0], value=1.0).reshape((height, width))
378 |         else:
379 |             label = label.reshape((height, width))
380 | 
381 |         dense_label = space_to_depth(label)
382 |         dense_label = torch.cat((dense_label, 0.5 * torch.ones((1, n_height, n_width))), dim=0)  # [65, 30, 40]
383 |         sparse_label = torch.argmax(dense_label, dim=0)  # [30,40]
384 | 
385 |         return sparse_label
386 |     @staticmethod
387 |     def _format_file_list(mega_image_dir, mega_keypoint_dir, mega_despoint_dir, mega_dl_dir1, mega_dl_dir2):
388 |         data_list = []
389 | 
390 |         # format megadepth related list
391 |         mega_image_list = glob(os.path.join(mega_image_dir, '*.jpg'))
392 |         mega_image_list = sorted(mega_image_list)
393 |         data_type = 'real'
394 |         for img in mega_image_list:
395 |             img_name = img.split('/')[-1].split('.')[0]
396 |             info = os.path.join(mega_despoint_dir, img_name + '.npz')
397 |             label = os.path.join(mega_keypoint_dir, img_name + '.npz')
398 |             dl1 = os.path.join(mega_dl_dir1, img_name + '.npz')
399 |             dl2 = os.path.join(mega_dl_dir2, img_name + '.npz')
400 |             data_list.append(
401 |                 {
402 |                     'type': data_type,
403 |                     'image': img,
404 |                     'info': info,
405 |                     'label': label,
406 |                     'dl1': dl1,
407 |                     'dl2': dl2,
408 |                 }
409 |             )
410 | 
411 |         # format coco related list
412 |         data_type = 'synthesis'
413 |         for img in mega_image_list:
414 |             img_name = img.split('/')[-1].split('.')[0]
415 |             label = os.path.join(mega_keypoint_dir, img_name + '.npz')
416 |             info = os.path.join(mega_despoint_dir, img_name + '.npz')
417 |             dl1 = os.path.join(mega_dl_dir1, img_name + '.npz')
418 |             dl2 = os.path.join(mega_dl_dir2, img_name + '.npz')
419 |             data_list.append(
420 |                 {
421 |                     'type': data_type,
422 |                     'image': img,
423 |                     'info': info,
424 |                     'label': label,
425 |                     'dl1': dl1,
426 |                     'dl2': dl2,
427 |                 }
428 |             )
429 | 
430 |         return data_list
431 | 
432 | 
433 | 
434 | 
435 | 
436 | 
437 | 
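438 | 
439 | if __name__ == '__main__':
440 |     # Minimal sketch (an editor's illustration, not part of the original file):
441 |     # _generate_warped_point is a @staticmethod, so the projection logic can be
442 |     # checked without building the dataset. Points are (y, x); this homography
443 |     # shifts everything 10 pixels to the right.
444 |     pts = np.array([[20., 30.], [100., 200.], [150., 310.]], dtype=np.float32)
445 |     homo = np.array([[1., 0., 10.], [0., 1., 0.], [0., 0., 1.]], dtype=np.float32)
446 |     warped, valid, not_search = MegaDepthTrainDataset._generate_warped_point(
447 |         pts, homo, height=240, width=320)
448 |     print(warped)            # x coordinates shifted by +10, still in (y, x) order
449 |     print(valid)             # [1. 1. 0.]: the last point leaves the image
450 |     print(not_search.shape)  # (3, 3)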
--------------------------------------------------------------------------------
/trainers/superpoint_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | import torch
5 | import torch.nn.functional as f
6 | from torch.utils.data import DataLoader
7 | 
8 | from nets import get_model
9 | from data_utils import get_dataset
10 | from trainers.base_trainer import BaseTrainer
11 | from utils.utils import spatial_nms
12 | from utils.utils import DescriptorGeneralTripletLoss
13 | from utils.utils import PointHeatmapWeightedBCELoss
14 | 
15 | 
16 | class MegPointTrainer(BaseTrainer):
17 | 
18 |     def __init__(self, **config):
19 |         super(MegPointTrainer, self).__init__(**config)
20 |         self.point_weight = 1
21 | 
22 |     def _initialize_dataset(self):
23 |         # initialize the training dataset
24 |         self.logger.info('Initialize {}'.format(self.config['train']['dataset']))
25 |         self.train_dataset = get_dataset(self.config['train']['dataset'])(**self.config['train'])
26 | 
27 |         self.train_dataloader = DataLoader(
28 |             dataset=self.train_dataset,
29 |             batch_size=self.config['train']['batch_size'],
30 |             shuffle=True,
31 |             num_workers=self.config['train']['num_workers'],
32 |             drop_last=True
33 |         )
34 |         self.epoch_length = len(self.train_dataset) // self.config['train']['batch_size']
35 | 
36 |     def _initialize_model(self):
37 |         self.logger.info("Initialize network arch {}".format(self.config['model']['backbone']))
38 |         model = get_model(self.config['model']['backbone'])()
39 | 
40 |         self.logger.info("Initialize network arch {}".format(self.config['model']['extractor']))
41 |         extractor = get_model(self.config['model']['extractor'])()
42 | 
43 |         if self.multi_gpus:
44 |             model = torch.nn.DataParallel(model)
45 |             extractor = torch.nn.DataParallel(extractor)
46 |         self.model = model.to(self.device)
47 |         self.extractor = extractor.to(self.device)
48 | 
49 |     def _initialize_loss(self):
50 |         # initialize the loss operators
51 |         # heatmap loss
52 |         self.logger.info("Initialize the PointHeatmapWeightedBCELoss.")
53 |         self.point_loss = PointHeatmapWeightedBCELoss()
54 | 
55 |         # descriptor loss
56 |         self.logger.info("Initialize the DescriptorGeneralTripletLoss.")
57 |         self.descriptor_loss = DescriptorGeneralTripletLoss(self.device)
58 | 
59 |     def _initialize_optimizer(self):
60 |         # initialize the training optimizer
61 |         self.logger.info("Initialize Adam optimizer with weight_decay: {:.5f}.".format(self.config['train']['weight_decay']))
62 |         self.optimizer = torch.optim.Adam(
63 |             params=self.model.parameters(),
64 |             lr=self.config['train']['lr'],
65 |             weight_decay=self.config['train']['weight_decay'])
66 |         self.extractor_optimizer = torch.optim.Adam(
67 |             params=self.extractor.parameters(),
68 |             lr=self.config['train']['lr'],
69 |             weight_decay=self.config['train']['weight_decay'])
70 | 
71 |     def _initialize_scheduler(self):
72 | 
73 |         # initialize the learning-rate scheduler
74 |         if self.config['train']['lr_mod'] == 'LambdaLR':
75 |             self.logger.info("Initialize lr_scheduler of LambdaLR: (%d, %d)" % (
76 |                 self.config['train']['maintain_epoch'], self.config['train']['decay_epoch']))
77 | 
78 |             def lambda_rule(epoch):
79 |                 lr_l = 1.0 - max(0, epoch - self.config['train']['maintain_epoch']) / float(
80 |                     self.config['train']['decay_epoch'] + 1)
81 |                 return lr_l
82 | 
83 |             self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda_rule)
84 |         else:
85 |             milestones = [20, 30]
86 |             self.logger.info("Initialize lr_scheduler of MultiStepLR: (%d, %d)" % (milestones[0], milestones[1]))
87 |             self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=milestones, gamma=0.1)
88 | 
89 |     def _train_one_epoch(self, epoch_idx):
90 |         self.model.train()
91 | 
92 |         self.logger.info("-----------------------------------------------------")
93 |         self.logger.info("Training epoch %2d begin:" % epoch_idx)
94 | 
95 |         self._train_func(epoch_idx)
96 | 
97 |         self.logger.info("Training epoch %2d done." % epoch_idx)
98 |         self.logger.info("-----------------------------------------------------")
99 | 
100 |     def _train_func(self, epoch_idx):
101 |         self.model.train()
102 |         self.extractor.train()
103 |         stime = time.time()
104 |         for i, data in enumerate(self.train_dataloader):
105 | 
106 |             # fetch the batch data
107 |             image = data["image"].to(self.device)
108 |             heatmap_gt = data['heatmap'].to(self.device)
109 |             point_mask = data['point_mask'].to(self.device)
110 |             desp_point = data["desp_point"].to(self.device)
111 | 
112 |             warped_image = data["warped_image"].to(self.device)
113 |             warped_heatmap_gt = data['warped_heatmap'].to(self.device)
114 |             warped_point_mask = data['warped_point_mask'].to(self.device)
115 |             warped_desp_point = data["warped_desp_point"].to(self.device)
116 | 
117 |             valid_mask = data["valid_mask"].to(self.device)
118 |             not_search_mask = data["not_search_mask"].to(self.device)
119 | 
120 |             image_pair = torch.cat((image, warped_image), dim=0)
121 | 
122 |             # forward pass
123 |             heatmap_pred_pair, c1_pair, c2_pair, c3_pair, c4_pair = self.model(image_pair)
124 | 
125 |             # descriptor loss
126 |             desp_point_pair = torch.cat((desp_point, warped_desp_point), dim=0)
127 |             c1_feature_pair = f.grid_sample(c1_pair, desp_point_pair, mode="bilinear", padding_mode="border")
128 |             c2_feature_pair = f.grid_sample(c2_pair, desp_point_pair, mode="bilinear", padding_mode="border")
129 |             c3_feature_pair = f.grid_sample(c3_pair, desp_point_pair, mode="bilinear", padding_mode="border")
130 |             c4_feature_pair = f.grid_sample(c4_pair, desp_point_pair, mode="bilinear", padding_mode="border")
131 | 
132 |             feature_pair = self.cat(c1_feature_pair, c2_feature_pair, c3_feature_pair, c4_feature_pair, dim=1)
133 |             feature_pair = feature_pair[:, :, :, 0].transpose(1, 2)
134 |             desp_pair = self.extractor(feature_pair)
135 |             desp_0, desp_1 = torch.chunk(desp_pair, 2, dim=0)
136 | 
137 |             desp_loss = self.descriptor_loss(desp_0, desp_1, valid_mask, not_search_mask)
138 | 
139 |             # keypoint (heatmap) loss
140 |             heatmap_gt_pair = torch.cat((heatmap_gt, warped_heatmap_gt), dim=0)
141 |             point_mask_pair = torch.cat((point_mask, warped_point_mask), dim=0)
142 |             point_loss = self.point_loss(heatmap_pred_pair[:, 0, :, :], heatmap_gt_pair, point_mask_pair)
143 | 
144 |             loss = desp_loss + point_loss
145 | 
146 |             if torch.isnan(loss):
147 |                 self.logger.error('loss is nan!')
148 | 
149 |             self.optimizer.zero_grad()
150 |             self.extractor_optimizer.zero_grad()
151 | 
152 |             loss.backward()
153 | 
154 |             self.optimizer.step()
155 |             self.extractor_optimizer.step()
156 | 
157 |             # debug use
158 |             # if i == 200:
159 |             #     break
160 | 
161 |             if i % self.config['train']['log_freq'] == 0:
162 | 
163 |                 point_loss_val = point_loss.item()
164 |                 desp_loss_val = desp_loss.item()
165 |                 loss_val = loss.item()
166 | 
167 |                 self.logger.info(
168 |                     "[Epoch:%2d][Step:%5d:%5d]: loss = %.4f, point_loss = %.4f, desp_loss = %.4f"
169 |                     " one step cost %.4fs. " % (
" % ( 170 | epoch_idx, i, self.epoch_length, 171 | loss_val, 172 | point_loss_val, 173 | desp_loss_val, 174 | (time.time() - stime) / self.config['train']['log_freq'], 175 | )) 176 | stime = time.time() 177 | 178 | # save the model 179 | if self.multi_gpus: 180 | torch.save( 181 | self.model.module.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx)) 182 | torch.save( 183 | self.extractor.module.state_dict(), os.path.join(self.config['ckpt_path'], 'extractor_%02d.pt' % epoch_idx)) 184 | else: 185 | torch.save( 186 | self.model.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx)) 187 | torch.save( 188 | self.extractor.state_dict(), os.path.join(self.config['ckpt_path'], 'extractor_%02d.pt' % epoch_idx)) 189 | 190 | def _inference_func(self, image_pair): 191 | """ 192 | image_pair: [2,1,h,w] 193 | """ 194 | self.model.eval() 195 | self.extractor.eval() 196 | _, _, height, width = image_pair.shape 197 | heatmap_pair, c1_pair, c2_pair, c3_pair, c4_pair = self.model(image_pair) 198 | 199 | c1_0, c1_1 = torch.chunk(c1_pair, 2, dim=0) 200 | c2_0, c2_1 = torch.chunk(c2_pair, 2, dim=0) 201 | c3_0, c3_1 = torch.chunk(c3_pair, 2, dim=0) 202 | c4_0, c4_1 = torch.chunk(c4_pair, 2, dim=0) 203 | 204 | heatmap_pair = torch.sigmoid(heatmap_pair) 205 | prob_pair = spatial_nms(heatmap_pair) 206 | 207 | prob_pair = prob_pair.detach().cpu().numpy() 208 | first_prob = prob_pair[0, 0] 209 | second_prob = prob_pair[1, 0] 210 | 211 | # 得到对应的预测点 212 | first_point, first_point_num = self._generate_predict_point( 213 | first_prob, 214 | detection_threshold=self.config['test']['detection_threshold'], 215 | top_k=self.config['test']['top_k']) # [n,2] 216 | 217 | second_point, second_point_num = self._generate_predict_point( 218 | second_prob, 219 | detection_threshold=self.config['test']['detection_threshold'], 220 | top_k=self.config['test']['top_k']) # [n,2] 221 | 222 | if first_point_num <= 4 or second_point_num <= 4: 223 | print("skip this pair because there's little point!") 224 | return None 225 | 226 | # 得到点对应的描述子 227 | select_first_desp = self._generate_combined_descriptor_fast(first_point, c1_0, c2_0, c3_0, c4_0, height, width) 228 | select_second_desp = self._generate_combined_descriptor_fast(second_point, c1_1, c2_1, c3_1, c4_1, height, width) 229 | 230 | return first_point, first_point_num, second_point, second_point_num, select_first_desp, select_second_desp 231 | 232 | def _generate_combined_descriptor_fast(self, point, c1, c2, c3, c4, height, width): 233 | """ 234 | 用多层级的组合特征构造描述子 235 | Args: 236 | point: [n,2] 顺序是y,x 237 | c1,c2,c3,c4: 分别对应resnet4个block输出的特征,batchsize都是1 238 | Returns: 239 | desp: [n,dim] 240 | """ 241 | point = torch.from_numpy(point[:, ::-1].copy()).to(torch.float).to(self.device) 242 | # 归一化采样坐标到[-1,1] 243 | point = point * 2. 
244 |         point = point.unsqueeze(dim=0).unsqueeze(dim=2)  # [1,n,1,2]
245 | 
246 |         c1_feature = f.grid_sample(c1, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)
247 |         c2_feature = f.grid_sample(c2, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)
248 |         c3_feature = f.grid_sample(c3, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)
249 |         c4_feature = f.grid_sample(c4, point, mode="bilinear")[:, :, :, 0].transpose(1, 2)
250 | 
251 |         feature = self.cat(c1_feature, c2_feature, c3_feature, c4_feature, dim=2)
252 |         desp = self.extractor(feature)[0]  # [n,128]
253 | 
254 |         desp = desp.detach().cpu().numpy()
255 | 
256 |         return desp
257 | 
258 |     def _generate_descriptor_for_superpoint_desp_head(self, point, desp, height, width):
259 |         """
260 |         Build descriptors from the SuperPoint descriptor head.
261 |         """
262 |         point = torch.from_numpy(point[:, ::-1].copy()).to(torch.float).to(self.device)
263 |         # normalize the sampling coordinates to [-1,1]
264 |         point = point * 2. / torch.tensor((width - 1, height - 1), dtype=torch.float, device=self.device) - 1
265 |         point = point.unsqueeze(dim=0).unsqueeze(dim=2)  # [1,n,1,2]
266 | 
267 |         desp = f.grid_sample(desp, point, mode="bilinear")[0, :, :, 0].transpose(0, 1)
268 |         desp = desp / torch.norm(desp, dim=1, keepdim=True).clamp(1e-5)
269 | 
270 |         desp = desp.detach().cpu().numpy()
271 | 
272 |         return desp
273 | 
274 | class SuperPointTrainer(MegPointTrainer):
275 | 
276 |     def __init__(self, **config):
277 |         super(SuperPointTrainer, self).__init__(**config)
278 | 
279 |     def _initialize_model(self):
280 |         self.logger.info("Initialize network arch {}".format(self.config['model']['backbone']))
281 |         model = get_model(self.config['model']['backbone'])()
282 | 
283 |         if self.multi_gpus:
284 |             model = torch.nn.DataParallel(model)
285 |         self.model = model.to(self.device)
286 | 
287 |     def _initialize_loss(self):
288 |         # keypoint loss
289 |         self.logger.info("Initialize the CrossEntropyLoss for SuperPoint.")
290 |         self.point_loss = torch.nn.CrossEntropyLoss(reduction="none")
291 | 
292 |         # descriptor loss
293 |         self.logger.info("Initialize the DescriptorGeneralTripletLoss for SuperPoint.")
294 |         self.descriptor_loss = DescriptorGeneralTripletLoss(self.device)
295 | 
296 |     def _initialize_optimizer(self):
297 |         # initialize the training optimizer
298 |         self.logger.info("Initialize Adam optimizer with weight_decay: {:.5f}.".format(self.config['train']['weight_decay']))
299 |         self.optimizer = torch.optim.Adam(
300 |             params=self.model.parameters(),
301 |             lr=self.config['train']['lr'],
302 |             weight_decay=self.config['train']['weight_decay'])
303 | 
304 |     def _train_func(self, epoch_idx):
305 |         self.model.train()
306 | 
307 |         stime = time.time()
308 |         for i, data in enumerate(self.train_dataloader):
309 | 
310 |             image = data['image'].to(self.device)
311 |             label = data['label'].to(self.device)
312 |             mask = data['mask'].to(self.device)
313 | 
314 |             warped_image = data['warped_image'].to(self.device)
315 |             warped_label = data['warped_label'].to(self.device)
316 |             warped_mask = data['warped_mask'].to(self.device)
317 | 
318 |             desp_point = data["desp_point"].to(self.device)
319 |             warped_desp_point = data["warped_desp_point"].to(self.device)
320 |             valid_mask = data["valid_mask"].to(self.device)
321 |             not_search_mask = data["not_search_mask"].to(self.device)
322 | 
323 |             shape = image.shape
324 | 
325 |             image_pair = torch.cat((image, warped_image), dim=0)
326 |             label_pair = torch.cat((label, warped_label), dim=0)
327 |             mask_pair = torch.cat((mask, warped_mask), dim=0)
328 | 
329 |             logit_pair, desp_pair, _ = self.model(image_pair)
330 | 
331 |             unmasked_point_loss = self.point_loss(logit_pair, label_pair)
332 |             point_loss = self._compute_masked_loss(unmasked_point_loss, mask_pair)
333 | 
334 |             # compute descriptor loss
335 |             desp_point_pair = torch.cat((desp_point, warped_desp_point), dim=0)
336 |             desp_pair = f.grid_sample(desp_pair, desp_point_pair, mode="bilinear", padding_mode="border")
337 |             desp_pair = desp_pair[:, :, :, 0].transpose(1, 2)
338 |             desp_pair = desp_pair / torch.norm(desp_pair, dim=2, keepdim=True)
339 |             desp_0, desp_1 = torch.chunk(desp_pair, 2, dim=0)
340 | 
341 |             desp_loss = self.descriptor_loss(desp_0, desp_1, valid_mask, not_search_mask)
342 | 
343 |             loss = point_loss + desp_loss
344 | 
345 |             if torch.isnan(loss):
346 |                 self.logger.error('loss is nan!')
347 | 
348 |             self.optimizer.zero_grad()
349 |             loss.backward()
350 | 
351 |             self.optimizer.step()
352 | 
353 |             if i % self.config['train']['log_freq'] == 0:
354 | 
355 |                 point_loss_val = point_loss.item()
356 |                 desp_loss_val = desp_loss.item()
357 |                 loss_val = loss.item()
358 | 
359 |                 self.summary_writer.add_histogram('descriptor', desp_pair)
360 |                 self.logger.info("[Epoch:%2d][Step:%5d:%5d]: loss = %.4f, point_loss = %.4f, desp_loss = %.4f"
361 |                                  " one step cost %.4fs. "
362 |                                  % (epoch_idx, i, self.epoch_length, loss_val,
363 |                                     point_loss_val, desp_loss_val,
364 |                                     (time.time() - stime) / self.config['train']['log_freq'],
365 |                                     ))
366 |                 stime = time.time()
367 | 
368 |         # save the model
369 |         if self.multi_gpus:
370 |             torch.save(self.model.module.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx))
371 |         else:
372 |             torch.save(self.model.state_dict(), os.path.join(self.config['ckpt_path'], 'model_%02d.pt' % epoch_idx))
373 | 
374 |     def _inference_func(self, image_pair):
375 |         """
376 |         image_pair: [2,1,h,w]
377 |         """
378 |         self.model.eval()
379 |         _, _, height, width = image_pair.shape
380 |         _, desp_pair, prob_pair = self.model(image_pair)
381 |         prob_pair = f.pixel_shuffle(prob_pair, 8)
382 |         prob_pair = spatial_nms(prob_pair)
383 | 
384 |         # get the predicted keypoints
385 |         prob_pair = prob_pair.detach().cpu().numpy()
386 |         first_prob = prob_pair[0, 0]
387 |         second_prob = prob_pair[1, 0]
388 | 
389 |         first_point, first_point_num = self._generate_predict_point(
390 |             first_prob,
391 |             detection_threshold=self.config['test']['detection_threshold'],
392 |             top_k=self.config['test']['top_k'])  # [n,2]
393 | 
394 |         second_point, second_point_num = self._generate_predict_point(
395 |             second_prob,
396 |             detection_threshold=self.config['test']['detection_threshold'],
397 |             top_k=self.config['test']['top_k'])  # [n,2]
398 | 
399 |         if first_point_num <= 4 or second_point_num <= 4:
400 |             print("skip this pair because there are too few points!")
401 |             return None
402 | 
403 |         # get the descriptors at the keypoints
404 |         first_desp, second_desp = torch.chunk(desp_pair, 2, dim=0)
405 | 
406 |         select_first_desp = self._generate_descriptor_for_superpoint_desp_head(first_point, first_desp, height, width)
407 |         select_second_desp = self._generate_descriptor_for_superpoint_desp_head(second_point, second_desp, height, width)
408 | 
409 |         return first_point, first_point_num, second_point, second_point_num, select_first_desp, select_second_desp
410 | 
411 | 
412 | 
--------------------------------------------------------------------------------
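A note on the sampling convention shared by _generate_combined_descriptor_fast and
_generate_descriptor_for_superpoint_desp_head above: pixel coordinates arrive in
(y, x) order, are flipped to (x, y), and are normalized to [-1, 1] with
p * 2 / (size - 1) - 1 before being handed to grid_sample. That normalization
corresponds to grid_sample's align_corners=True convention, so the minimal,
self-contained sketch below (an illustration, not a file from this repository)
passes the flag explicitly:

    import torch
    import torch.nn.functional as f

    height, width = 6, 8
    feat = torch.arange(height * width, dtype=torch.float).reshape(1, 1, height, width)
    point_yx = torch.tensor([[0., 0.], [5., 7.]])     # top-left and bottom-right pixels
    point = point_yx[:, [1, 0]]                       # flip (y, x) -> (x, y)
    point = point * 2. / torch.tensor((width - 1., height - 1.)) - 1.
    point = point.unsqueeze(dim=0).unsqueeze(dim=2)   # [1, n, 1, 2], grid_sample layout
    sampled = f.grid_sample(feat, point, mode="bilinear", align_corners=True)
    print(sampled[0, 0, :, 0])                        # tensor([ 0., 47.]): the exact corner values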