├── managers
│   ├── __init__.py
│   ├── extractor.py
│   └── trainer.py
├── imgs
│   ├── framework.png
│   └── hpatches_res.png
├── evaluations
│   ├── hpatches
│   │   ├── cache
│   │   │   ├── caps.npy
│   │   │   ├── r2d2.npy
│   │   │   ├── d2-net.npy
│   │   │   ├── hesaff.npy
│   │   │   ├── lf-net.npy
│   │   │   ├── aslfeat.npy
│   │   │   ├── delf-new.npy
│   │   │   ├── hesaffnet.npy
│   │   │   ├── contextdesc.npy
│   │   │   ├── superpoint.npy
│   │   │   ├── PoSFeat_CVPR.npy
│   │   │   ├── disk-epipolar.npy
│   │   │   ├── disk-d-2k-official.npy
│   │   │   └── disk-d-8k-official.npy
│   │   └── evaluation.py
│   ├── aachen
│   │   ├── camera.py
│   │   ├── utils.py
│   │   ├── matchers.py
│   │   ├── reconstruct_pipeline_v1_1.py
│   │   └── reconstruct_pipeline.py
│   └── ETH_local_feature
│       ├── custom_matcher.py
│       └── reconstruction_pipeline.py
├── networks
│   ├── __init__.py
│   ├── DeteNet.py
│   ├── PoSFeat_model.py
│   └── DescNet.py
├── losses
│   ├── __init__.py
│   ├── epipolarloss.py
│   ├── preprocess.py
│   └── kploss.py
├── datasets
│   ├── __init__.py
│   ├── hpatches.py
│   ├── ETH_local_feature.py
│   ├── aachen.py
│   └── data_utils.py
├── train.py
├── extract.py
├── configs
│   ├── extract_hpatches.yaml
│   ├── extract_aachen.yaml
│   ├── extract_ETH.yaml
│   ├── train_kp.yaml
│   └── train_desc.yaml
├── .gitignore
├── LICENSE
└── README.md
/managers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/imgs/framework.png -------------------------------------------------------------------------------- /imgs/hpatches_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/imgs/hpatches_res.png -------------------------------------------------------------------------------- /evaluations/hpatches/cache/caps.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/caps.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/r2d2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/r2d2.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/d2-net.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/d2-net.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/hesaff.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/hesaff.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/lf-net.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/lf-net.npy -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .DescNet import ResUNet 2 | from .DeteNet import KeypointDet 3 | from .PoSFeat_model import PoSFeat
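# ResUNet (DescNet.py) provides the dense coarse/fine descriptor maps, KeypointDet (DeteNet.py)
# scores keypoints on top of them, and PoSFeat (PoSFeat_model.py) wraps the two as backbone and
# localheader for training and feature extraction.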
-------------------------------------------------------------------------------- /evaluations/hpatches/cache/aslfeat.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/aslfeat.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/delf-new.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/delf-new.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/hesaffnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/hesaffnet.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/contextdesc.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/contextdesc.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/superpoint.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/superpoint.npy -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess import * 2 | from . import preprocess_utils 3 | from .epipolarloss import * 4 | from .kploss import * -------------------------------------------------------------------------------- /evaluations/hpatches/cache/PoSFeat_CVPR.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/PoSFeat_CVPR.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/disk-epipolar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/disk-epipolar.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/disk-d-2k-official.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/disk-d-2k-official.npy -------------------------------------------------------------------------------- /evaluations/hpatches/cache/disk-d-8k-official.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-SAIL/PoSFeat/HEAD/evaluations/hpatches/cache/disk-d-8k-official.npy -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .megadepth import MegaDepth_SIFT, MegaDepth_superpoint, MegaDepth_Depth 2 | from .hpatches import HPatch_SIFT 3 | from .aachen import Aachen_Day_Night 4 | from .ETH_local_feature import ETH_LFB -------------------------------------------------------------------------------- 
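The extraction datasets exported above (HPatch_SIFT, Aachen_Day_Night, ETH_LFB) share one interface: each is built from a plain config dict, crops the image to a multiple of 16, attaches SIFT keypoints, and yields a dict with the keys 'im1' (normalized tensor), 'im1_ori' (cropped original image), 'coord1' (keypoints), 'name1' and 'pad1'. The train.py and extract.py entry points that follow are driven by the YAML files under configs/, e.g. python extract.py --config configs/extract_hpatches.yaml. Below is a minimal usage sketch for the HPatches dataset; it is not part of the repository, the data_path is a placeholder, and it assumes the MegaDepth module referenced by datasets/__init__.py is importable.

from torch.utils.data import DataLoader
from datasets import HPatch_SIFT

# placeholder path: point it at the extracted hpatches-sequences-release folder
dataset = HPatch_SIFT({'data_path': '/path/to/hpatches-sequences-release'})
loader = DataLoader(dataset, batch_size=1, num_workers=4)

sample = next(iter(loader))
print(sample['name1'][0], sample['im1'].shape, sample['coord1'].shape)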
/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from managers.trainer import * 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--local_rank', type=int, default=-1) 6 | parser.add_argument('--config', type=str, default='./configs/debug.yaml') 7 | args = parser.parse_args() 8 | trainer = Trainer(args) 9 | trainer.train() -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from managers.extractor import * 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--local_rank', type=int, default=-1) 6 | parser.add_argument('--config', type=str, default='./configs/extract.yaml') 7 | args = parser.parse_args() 8 | extractor = Extractor(args) 9 | extractor.extract() -------------------------------------------------------------------------------- /evaluations/aachen/camera.py: -------------------------------------------------------------------------------- 1 | # Simple COLMAP camera class. 2 | class Camera: 3 | def __init__(self): 4 | self.camera_model = None 5 | self.intrinsics = None 6 | self.qvec = None 7 | self.t = None 8 | 9 | def set_intrinsics(self, camera_model, intrinsics): 10 | self.camera_model = camera_model 11 | self.intrinsics = intrinsics 12 | 13 | def set_pose(self, qvec, t): 14 | self.qvec = qvec 15 | self.t = t 16 | -------------------------------------------------------------------------------- /evaluations/aachen/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def quaternion_to_rotation_matrix(qvec): 4 | qvec = qvec / np.linalg.norm(qvec) 5 | w, x, y, z = qvec 6 | R = np.array([[1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w], 7 | [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w], 8 | [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y]]) 9 | return R 10 | 11 | 12 | def camera_center_to_translation(c, qvec): 13 | R = quaternion_to_rotation_matrix(qvec) 14 | return (-1) * np.matmul(R, c) 15 | 16 | -------------------------------------------------------------------------------- /configs/extract_hpatches.yaml: -------------------------------------------------------------------------------- 1 | output_root: 'hpatches/PoSFeat_mytrain' 2 | postfix: 'PoSFeat_mytrain' 3 | load_path: './ckpts/keypoint/005' 4 | 5 | loss_distance: 'cos' 6 | output_desc: True 7 | output_img: False 8 | 9 | model: 'PoSFeat' 10 | model_config: 11 | backbone: None 12 | backbone_config: None 13 | localheader: None 14 | localheader_config: None 15 | 16 | data: 'HPatch_SIFT' # the images in megadepth_caps have been resized to 640x480, therefore, we don't need to resize the images 17 | data_config_extract: 18 | data_path: '/home/kunbpc/data/kunb/hpatches/hpatches-sequences-release' 19 | prune_kp: True 20 | num_pts: 5000 21 | batch_size: 1 22 | workers: 4 23 | 24 | local_thr: 0.99 # only used in save_img, this is a percent thr instead of a abs thr 25 | 26 | use_sift: False 27 | detector: 'generate_kpts_single' 28 | detector_config: 29 | num_pts: 8192 30 | stable: True 31 | use_nms: True # softnms, True, False 32 | nms_radius: 1 33 | thr: 0.9 # False or a float 34 | thr_mod: abs # max mean abs -------------------------------------------------------------------------------- /configs/extract_aachen.yaml: 
-------------------------------------------------------------------------------- 1 | output_root: 'aachen/PoSFeat_mytrain' 2 | postfix: 'PoSFeat_mytrain' 3 | load_path: './ckpts/keypoint/005' 4 | 5 | loss_distance: 'cos' 6 | output_desc: True 7 | output_img: False 8 | 9 | model: 'PoSFeat' 10 | model_config: 11 | backbone: None 12 | backbone_config: None 13 | localheader: None 14 | localheader_config: None 15 | 16 | data: 'Aachen_Day_Night' 17 | data_config_extract: 18 | data_path: '/home/kunbpc/data/kunb/aachen/aachen/images/images_upright' 19 | prune_kp: False 20 | batch_size: 1 21 | workers: 4 22 | 23 | local_thr: 0.99 # only used in save image 24 | 25 | use_sift: False 26 | detector: 'generate_kpts_single' 27 | detector_config: 28 | num_pts: 20480 29 | stable: True 30 | use_nms: True # softnms, True, False 31 | nms_radius: 3 32 | thr: 0.5 # False or a float 33 | thr_mod: abs # max mean abs 34 | detector_config_query: 35 | num_pts: 20480 36 | stable: True 37 | use_nms: True # softnms, True, False 38 | nms_radius: 3 39 | thr: 0.5 # False or a float 40 | thr_mod: abs # max mean abs -------------------------------------------------------------------------------- /configs/extract_ETH.yaml: -------------------------------------------------------------------------------- 1 | # extractor config 2 | output_root: 'ETH/PoSFeat_mytrain' 3 | postfix: 'PoSFeat_mytrain' 4 | load_path: './ckpts/keypoint/005' 5 | 6 | loss_distance: 'cos' 7 | output_desc: True 8 | output_img: False 9 | 10 | model: 'PoSFeat' 11 | model_config: 12 | backbone: None 13 | backbone_config: None 14 | localheader: None 15 | localheader_config: None 16 | 17 | 18 | data: 'ETH_LFB' 19 | data_config_extract: 20 | data_path: '/home/kunbpc/data/kunb/ETH-local-feature' 21 | subfolder: 'South-Building' # Alamo ArtsQuad_dataset Fountain 22 | # Gendarmenmarkt Herzjesu Madrid_Metropolis 23 | # Oxford5k Roman_Forum South-Building Tower_of_London 24 | batch_size: 1 25 | workers: 4 26 | local_thr: 0.99 # only used in save image 27 | 28 | use_sift: False 29 | detector: 'generate_kpts_single' 30 | detector_config: 31 | num_pts: 20480 32 | stable: True 33 | use_nms: True # softnms, True, False 34 | nms_radius: 3 35 | thr: 0.9 # False or a float 36 | thr_mod: abs # max mean abs 37 | 38 | # reconstruction config 39 | colmap_path: /home/kunbpc/Installed/colmap/build/src/exe 40 | matcher: 'mutual_nn_ratio_matcher' # mutual_nn_matcher mutual_nn_ratio_matcher 41 | matcher_config: 42 | ratio: 0.75 -------------------------------------------------------------------------------- /configs/train_kp.yaml: -------------------------------------------------------------------------------- 1 | checkpoint_name: 'keypoint' 2 | epoch: 5 3 | epoch_step: 1000 4 | lr_decay_step: 4 5 | lr_decay_factor: 0.1 6 | log_freq: 200 7 | grad_clip: False 8 | clip_norm: 10. 
9 | test_kp: True 10 | 11 | optimal_modules: ['localheader'] 12 | optimal_lrs: [1.e-3] 13 | optimizer: SGD 14 | 15 | load_path: ckpts/descriptor/010 16 | 17 | model: 'PoSFeat' 18 | model_config: 19 | backbone: None 20 | backbone_config: None 21 | localheader: None 22 | localheader_config: None 23 | 24 | data: 'MegaDepth_SIFT' 25 | data_config_train: 26 | data_path: '/home/kunbpc/data/kunb/megadepth_caps/train' 27 | prune_kp: False 28 | num_pts: 2000 29 | batch_size: 6 30 | workers: 6 31 | random_percent: 0.5 32 | rot_thr: 80 33 | no_cuda: ['name1', 'name2'] 34 | 35 | val_config: 36 | data_config_val: 37 | data_path: '/home/kunbpc/data/kunb/megadepth_caps/train' 38 | prune_kp: False 39 | shuffle: True 40 | num_pts: 1024 41 | batch_size: 8 42 | workers: 1 43 | random_percent: 0 44 | rot_thr: 360 45 | detector: 'generate_kpts_single' # sift or something else 46 | detector_config: 47 | num_pts: 1024 48 | stable: True 49 | use_nms: True 50 | nms_radius: 1 51 | thr: False 52 | loss_distance: 'cos' 53 | vis_topk: 50 54 | vis_err_thr: 5 55 | 56 | losses: ['DiskLoss'] 57 | losses_weight: [1] 58 | tb_component: ['reinforce', 'kp_penalty'] 59 | 60 | DiskLoss_config: 61 | grid_size: 8 62 | loss_distance: 'cos' 63 | temperature_base: 60 64 | temperature_max: 60 65 | epipolar_reward: constant_reward # constant_reward, dynamic_reward 66 | reward_config: 67 | reward_thr: 2 68 | rescale_thr: False 69 | cor_detach: True 70 | good_reward: 1 71 | bad_reward: -0.25 72 | kp_penalty: -0.001 73 | match_grad: False -------------------------------------------------------------------------------- /datasets/hpatches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | import skimage.io as io 6 | from path import Path 7 | import cv2 8 | import torch.nn.functional as F 9 | 10 | class HPatch_SIFT(Dataset): 11 | def __init__(self, configs): 12 | super(HPatch_SIFT, self).__init__() 13 | self.configs = configs 14 | self.transform = transforms.Compose([transforms.ToTensor(), 15 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 16 | std=(0.229, 0.224, 0.225)), 17 | ]) 18 | # self.imfs = [] 19 | self.sift = cv2.SIFT_create() 20 | imdir = Path(self.configs['data_path']) 21 | self.imfs = imdir.glob('*/*.ppm') 22 | self.imfs.sort() 23 | 24 | 25 | def __getitem__(self, item): 26 | imf = self.imfs[item] 27 | im = io.imread(imf) 28 | name = imf.split('/')[-2:] 29 | name = '/'.join(name) 30 | im_tensor = self.transform(im) 31 | c, h, w = im_tensor.shape 32 | pad=(0,0,0,0) 33 | 34 | # now use crop to get suitable size 35 | crop_r = w%16 36 | crop_b = h%16 37 | im_tensor = im_tensor[:,:h-crop_b,:w-crop_r] 38 | im = im[:h-crop_b,:w-crop_r,:] 39 | gray = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) 40 | kpts = self.sift.detect(gray) 41 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) 42 | coord = torch.from_numpy(kpts).float() 43 | out = {'im1': im_tensor, 'im1_ori':im, 'coord1': coord, 'name1': name, 'pad1':pad} 44 | return out 45 | 46 | def __len__(self): 47 | return len(self.imfs) -------------------------------------------------------------------------------- /datasets/ETH_local_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | import skimage.io as io 6 | from path import Path 7 | import cv2 8 | 
import torch.nn.functional as F 9 | 10 | class ETH_LFB(Dataset): 11 | def __init__(self, configs): 12 | """ 13 | dataset for eth local feature benchmark 14 | """ 15 | super(ETH_LFB, self).__init__() 16 | self.configs = configs 17 | self.transform = transforms.Compose([transforms.ToTensor(), 18 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 19 | std=(0.229, 0.224, 0.225)), 20 | ]) 21 | # self.imfs = [] 22 | self.sift = cv2.SIFT_create() 23 | imdir = Path(self.configs['data_path']) 24 | folder_dir = imdir/self.configs['subfolder'] 25 | images_dir = folder_dir/'images' 26 | imgs = images_dir.glob('*') 27 | self.imfs = imgs 28 | self.imfs.sort() 29 | 30 | def __getitem__(self, item): 31 | imf = self.imfs[item] 32 | im = io.imread(imf) 33 | name = imf.name 34 | name = '{}/{}'.format(self.configs['subfolder'], name) 35 | if len(im.shape) != 3: #gray images 36 | im = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB) 37 | im = im.copy() 38 | im_tensor = self.transform(im) # 39 | c, h, w = im_tensor.shape 40 | # pad_b = 16 - h%16 41 | # pad_r = 16 - w%16 42 | # pad = (0,pad_r,0,pad_b) 43 | # im_tensor = F.pad(im_tensor.unsqueeze(0), pad, mode='replicate').squeeze(0) 44 | pad=(0,0,0,0) 45 | 46 | # now use crop to get suitable size 47 | crop_r = w%16 48 | crop_b = h%16 49 | im_tensor = im_tensor[:,:h-crop_b,:w-crop_r] 50 | im = im[:h-crop_b,:w-crop_r,:] 51 | # using sift keypoints 52 | gray = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) 53 | kpts = self.sift.detect(gray) 54 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) 55 | coord = torch.from_numpy(kpts).float() 56 | out = {'im1': im_tensor, 'im1_ori':im, 'coord1': coord, 'name1': name, 'pad1':pad} 57 | return out 58 | 59 | def __len__(self): 60 | return len(self.imfs) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /datasets/aachen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | import skimage.io as io 6 | from path import Path 7 | import cv2 8 | import torch.nn.functional as F 9 | 10 | class Aachen_Day_Night(Dataset): 11 | def __init__(self, configs): 12 | super(Aachen_Day_Night, self).__init__() 13 | self.configs = configs 14 | self.transform = transforms.Compose([transforms.ToTensor(), 15 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 16 | std=(0.229, 0.224, 0.225)), 17 | ]) 18 | # self.imfs = [] 19 | self.sift = cv2.SIFT_create() 20 | imdir = Path(self.configs['data_path']) 21 | dbimgs = imdir.glob('db/*.jpg') 22 | queryimgs = imdir.glob('query/*/*/*.jpg') 23 | sequences1 = imdir.glob('sequences/gopro3_undistorted/*.png') 24 | sequences2 = imdir.glob('sequences/nexus4_sequences/*/*.png') 25 | self.imfs = dbimgs 26 | self.imfs.extend(queryimgs) 27 | self.imfs.extend(sequences1) 28 | self.imfs.extend(sequences2) 29 | self.imfs.sort() 30 | 31 | 32 | def __getitem__(self, item): 33 | imf = self.imfs[item] 34 | im = io.imread(imf) 35 | imf_split = imf.split('/') 36 | if 'db' in imf_split: 37 | name = imf_split[-2:] 38 | name = '/'.join(name) 39 | elif 'query' in imf_split: 40 | name = imf_split[-4:] 41 | name = '/'.join(name) 42 | elif 'gopro3_undistorted' in imf_split: 43 | name = imf_split[-3:] 44 | name = '/'.join(name) 45 | elif 'nexus4_sequences' in imf_split: 46 | name = imf_split[-4:] 47 | name = '/'.join(name) 48 | im_tensor = self.transform(im) 49 | c, h, w = im_tensor.shape 50 | pad=(0,0,0,0) 51 | 52 | # now use crop to get suitable size 53 | crop_r = w%16 54 | crop_b = h%16 55 | im_tensor = im_tensor[:,:h-crop_b,:w-crop_r] 56 | im = im[:h-crop_b,:w-crop_r,:] 57 | gray = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) 58 | kpts = self.sift.detect(gray) 59 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) 60 | coord = torch.from_numpy(kpts).float() 61 | out = {'im1': im_tensor, 'im1_ori':im, 'coord1': coord, 'name1': name, 'pad1':pad} 62 | return out 63 | 64 | def __len__(self): 65 | return len(self.imfs) -------------------------------------------------------------------------------- /evaluations/aachen/matchers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Mutual nearest neighbors matcher for L2 normalized descriptors. 
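# A pair (i, j) survives only when j is the nearest neighbour of descriptor i in image 2 and,
# symmetrically, i is the nearest neighbour of j in image 1. Because the descriptors are
# L2 normalized, ranking by the dot-product similarity below is equivalent to ranking by
# Euclidean distance.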
5 | def mutual_nn_matcher(descriptors1, descriptors2): 6 | device = descriptors1.device 7 | sim = descriptors1 @ descriptors2.t() 8 | nn12 = torch.max(sim, dim=1)[1] 9 | nn21 = torch.max(sim, dim=0)[1] 10 | ids1 = torch.arange(0, sim.shape[0], device=device) 11 | mask = ids1 == nn21[nn12] 12 | matches = torch.stack([ids1[mask], nn12[mask]]).t() 13 | return matches.data.cpu().numpy() 14 | 15 | 16 | # Symmetric Lowe's ratio test matcher for L2 normalized descriptors. 17 | def ratio_matcher(descriptors1, descriptors2, ratio=0.95): 18 | device = descriptors1.device 19 | sim = descriptors1 @ descriptors2.t() 20 | 21 | # Retrieve top 2 nearest neighbors 1->2. 22 | nns_sim, nns = torch.topk(sim, 2, dim=1) 23 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 24 | # Compute Lowe's ratio. 25 | ratios12 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 26 | # Save first NN. 27 | nn12 = nns[:, 0] 28 | 29 | # Retrieve top 2 nearest neighbors 1->2. 30 | nns_sim, nns = torch.topk(sim.t(), 2, dim=1) 31 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 32 | # Compute Lowe's ratio. 33 | ratios21 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 34 | # Save first NN. 35 | nn21 = nns[:, 0] 36 | 37 | # Symmetric ratio test. 38 | ids1 = torch.arange(0, sim.shape[0], device=device) 39 | mask = torch.min(ratios12 <= ratio, ratios21[nn12] <= ratio) 40 | 41 | # Final matches. 42 | matches = torch.stack([ids1[mask], nn12[mask]], dim=-1) 43 | 44 | return matches.data.cpu().numpy() 45 | 46 | 47 | # Mutual NN + symmetric Lowe's ratio test matcher for L2 normalized descriptors. 48 | def mutual_nn_ratio_matcher(descriptors1, descriptors2, ratio=0.95): 49 | device = descriptors1.device 50 | sim = descriptors1 @ descriptors2.t() 51 | 52 | # Retrieve top 2 nearest neighbors 1->2. 53 | nns_sim, nns = torch.topk(sim, 2, dim=1) 54 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 55 | # Compute Lowe's ratio. 56 | ratios12 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 57 | # Save first NN and match similarity. 58 | nn12 = nns[:, 0] 59 | 60 | # Retrieve top 2 nearest neighbors 1->2. 61 | nns_sim, nns = torch.topk(sim.t(), 2, dim=1) 62 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 63 | # Compute Lowe's ratio. 64 | ratios21 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 65 | # Save first NN. 66 | nn21 = nns[:, 0] 67 | 68 | # Mutual NN + symmetric ratio test. 69 | ids1 = torch.arange(0, sim.shape[0], device=device) 70 | mask = torch.min(ids1 == nn21[nn12], torch.min(ratios12 <= ratio, ratios21[nn12] <= ratio)) 71 | 72 | # Final matches. 73 | matches = torch.stack([ids1[mask], nn12[mask]], dim=-1) 74 | 75 | return matches.data.cpu().numpy() 76 | 77 | -------------------------------------------------------------------------------- /evaluations/ETH_local_feature/custom_matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Mutual nearest neighbors matcher for L2 normalized descriptors. 5 | def mutual_nn_matcher(descriptors1, descriptors2, **args): 6 | device = descriptors1.device 7 | sim = descriptors1 @ descriptors2.t() 8 | nn12 = torch.max(sim, dim=1)[1] 9 | nn21 = torch.max(sim, dim=0)[1] 10 | ids1 = torch.arange(0, sim.shape[0], device=device) 11 | mask = ids1 == nn21[nn12] 12 | matches = torch.stack([ids1[mask], nn12[mask]]).t() 13 | return matches.data.cpu().numpy() 14 | 15 | # Symmetric Lowe's ratio test matcher for L2 normalized descriptors. 
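# For unit-length descriptors the Euclidean distance follows from the cosine similarity as
# d = sqrt(2 - 2 * s), which is how the distances below are recovered from the similarity
# matrix; a match is kept only when the best-to-second-best distance ratio is below `ratio`
# in both matching directions.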
16 | def ratio_matcher(descriptors1, descriptors2, ratio=0.95): 17 | device = descriptors1.device 18 | sim = descriptors1 @ descriptors2.t() 19 | 20 | # Retrieve top 2 nearest neighbors 1->2. 21 | nns_sim, nns = torch.topk(sim, 2, dim=1) 22 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 23 | # Compute Lowe's ratio. 24 | ratios12 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 25 | # Save first NN. 26 | nn12 = nns[:, 0] 27 | 28 | # Retrieve top 2 nearest neighbors 1->2. 29 | nns_sim, nns = torch.topk(sim.t(), 2, dim=1) 30 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 31 | # Compute Lowe's ratio. 32 | ratios21 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 33 | # Save first NN. 34 | nn21 = nns[:, 0] 35 | 36 | # Symmetric ratio test. 37 | ids1 = torch.arange(0, sim.shape[0], device=device) 38 | mask = torch.min(ratios12 <= ratio, ratios21[nn12] <= ratio) 39 | 40 | # Final matches. 41 | matches = torch.stack([ids1[mask], nn12[mask]], dim=-1) 42 | 43 | return matches.data.cpu().numpy() 44 | 45 | 46 | # Mutual NN + symmetric Lowe's ratio test matcher for L2 normalized descriptors. 47 | def mutual_nn_ratio_matcher(descriptors1, descriptors2, ratio=0.95): 48 | device = descriptors1.device 49 | sim = descriptors1 @ descriptors2.t() 50 | 51 | # Retrieve top 2 nearest neighbors 1->2. 52 | nns_sim, nns = torch.topk(sim, 2, dim=1) 53 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 54 | # Compute Lowe's ratio. 55 | ratios12 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 56 | # Save first NN and match similarity. 57 | nn12 = nns[:, 0] 58 | 59 | # Retrieve top 2 nearest neighbors 1->2. 60 | nns_sim, nns = torch.topk(sim.t(), 2, dim=1) 61 | nns_dist = torch.sqrt(2 - 2 * nns_sim) 62 | # Compute Lowe's ratio. 63 | ratios21 = nns_dist[:, 0] / (nns_dist[:, 1] + 1e-8) 64 | # Save first NN. 65 | nn21 = nns[:, 0] 66 | 67 | # Mutual NN + symmetric ratio test. 68 | ids1 = torch.arange(0, sim.shape[0], device=device) 69 | mask = torch.min(ids1 == nn21[nn12], torch.min(ratios12 <= ratio, ratios21[nn12] <= ratio)) 70 | 71 | # Final matches. 72 | matches = torch.stack([ids1[mask], nn12[mask]], dim=-1) 73 | 74 | return matches.data.cpu().numpy() -------------------------------------------------------------------------------- /configs/train_desc.yaml: -------------------------------------------------------------------------------- 1 | checkpoint_name: 'descriptor' 2 | epoch: 10 3 | epoch_step: 10000 4 | lr_decay_step: 9 5 | lr_decay_factor: 0.1 6 | log_freq: 1000 7 | grad_clip: False 8 | clip_norm: 10. 
9 | test_kp: False # if False, test with sift on hpatches 10 | 11 | optimal_modules: ['backbone'] 12 | optimal_lrs: [1.e-4] 13 | optimizer: Adam # SGD Adam 14 | 15 | # the model settings 16 | model: 'PoSFeat' 17 | model_config: 18 | backbone: 'ResUNet' 19 | backbone_config: 20 | encoder: 'resnet50' 21 | pretrained: True 22 | coarse_out_ch: 128 23 | fine_out_ch: 128 24 | localheader: 'KeypointDet' 25 | localheader_config: 26 | in_channels: 192 # 128 for localmap 64 for localmap_small 27 | prior: 'identity' # ASL_Peak D2 identity 28 | act: 'Softplus' 29 | align_local_grad: False 30 | local_input_elements: ['local_map', 'local_map_small'] 31 | local_with_img: True 32 | 33 | # the data settings 34 | data: 'MegaDepth_SIFT' # the images in megadepth_caps have been resized to 640x480, therefore, we don't need to resize the images 35 | data_config_train: 36 | data_path: '/home/kunbpc/data/kunb/megadepth_caps_(copy)/train' 37 | prune_kp: False 38 | num_pts: 2000 39 | batch_size: 8 40 | workers: 6 41 | random_percent: 0.5 42 | rot_thr: 80 43 | no_cuda: ['name1', 'name2'] 44 | 45 | val_config: 46 | data_config_val: 47 | data_path: '/home/kunbpc/data/kunb/megadepth_caps_(copy)/train' 48 | prune_kp: False 49 | shuffle: True 50 | num_pts: 1024 51 | batch_size: 8 52 | workers: 1 53 | random_percent: 0 54 | rot_thr: 360 55 | detector: 'sift' # sift or something else, e.g., generate_kpts_single2 56 | loss_distance: 'cos' 57 | vis_topk: 50 58 | vis_err_thr: 5 59 | 60 | # the losses settings 61 | preprocess_train: 'Preprocess_Line2Window' 62 | preprocess_train_config: 63 | kps_generator: 'generate_kpts_regular_grid_random' 64 | kps_generator_config: 65 | grid_size: 16 # the grid size on fine feature map 66 | map_init: 'identity' # SSIM D2 ASL_Peak identity 67 | keep_spatial: True 68 | random_select: 'random' # random for refinement with kpmap guidance, regular_random for initialization 69 | window_size: 0.1 70 | loss_distance: 'cos' 71 | use_nn_grid: False # for compute std_grid, and when not use line search, locate center 72 | use_line_search: True 73 | line_search_config: 74 | line_step: 100 75 | use_nn: True 76 | loc_rand: True 77 | temperature_base: 60 78 | temperature_max: 60 79 | # linelen_thr: 0 80 | # regular_grid_points: False #if false, use sift keypoints 81 | 82 | losses: ['EpipolarLoss_full'] 83 | losses_weight: [1] 84 | tb_component: ['loss_w1', 'loss_w2', 'percent_w'] 85 | 86 | EpipolarLoss_full_config: 87 | grid_cost_thr: 0.5 88 | win_cost_thr: 0.1 89 | use_std_as_weight: True 90 | weight_grid: 0 91 | weight_window: 1 -------------------------------------------------------------------------------- /networks/DeteNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class KeypointDet(nn.Module): 6 | """ 7 | spatical attention header 8 | """ 9 | def __init__(self, in_channels, out_channels=1, prior='SSIM', act='Sigmoid'): 10 | super(KeypointDet, self).__init__() 11 | self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) 12 | self.norm1 = nn.InstanceNorm2d(in_channels) 13 | self.conv2 = nn.Conv2d(in_channels+64, 128, 3, 1, 1) 14 | self.norm2 = nn.InstanceNorm2d(128) 15 | self.conv3 = nn.Conv2d(128, out_channels, 1, 1, 0) 16 | self.norm3 = nn.InstanceNorm2d(out_channels) 17 | self.relu = nn.PReLU() 18 | self.prior = getattr(self, prior) 19 | self.act = getattr(nn, act)() 20 | 21 | self.convimg = nn.Conv2d(3, 64, 3, 1, 1) 22 | self.normimg = nn.InstanceNorm2d(64) 23 | 24 | 
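# The prior functions below turn a feature tensor into a dense keypoint prior:
# SSIM scores the dissimilarity between the map and a one-pixel diagonally shifted copy of
# itself, D2 reproduces a D2-Net style soft local-max times channel-ratio score, ASL_Peak
# computes an ASLFeat-style peakiness measure, and identity returns a constant map.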
def SSIM(self, x): 25 | C1 = 0.01 ** 2 26 | C2 = 0.03 ** 2 27 | 28 | x_pad = F.pad(x.abs(), (0,1,0,1), 'reflect') 29 | x_lu = x_pad[:,:,:-1,:-1] 30 | x_rb = x_pad[:,:,1:,1:] 31 | 32 | x_lu = F.pad(x_lu, (1,1,1,1), 'reflect') 33 | x_rb = F.pad(x_rb, (1,1,1,1), 'reflect') 34 | 35 | m_x_lu = F.avg_pool2d(x_lu, 3, 1) 36 | m_x_rb = F.avg_pool2d(x_rb, 3, 1) 37 | 38 | sigma_x_lu = F.avg_pool2d(x_lu**2, 3, 1) - m_x_lu**2 39 | sigma_x_rb = F.avg_pool2d(x_rb**2, 3, 1) - m_x_rb**2 40 | sigma_x_lu_rb = F.avg_pool2d(x_lu*x_rb, 3, 1) - m_x_lu*m_x_rb 41 | 42 | SSIM_n = (2 * m_x_lu * m_x_rb + C1) * (2 * sigma_x_lu_rb + C2) 43 | SSIM_d = (m_x_lu ** 2 + m_x_rb ** 2 + C1) * (sigma_x_lu + sigma_x_rb + C2) 44 | 45 | return torch.clamp((1 - SSIM_n / SSIM_d)/2, 0, 1) 46 | 47 | def D2(self, x): 48 | b,c,h,w = x.shape 49 | window_size = 3 50 | padding_size = window_size//2 51 | 52 | x = F.relu(x) 53 | max_per_sample = torch.max(x.view(b,-1), dim=1)[0] 54 | exp = torch.exp(x/max_per_sample.view(b,1,1,1)) 55 | sum_exp = ( 56 | window_size**2* 57 | F.avg_pool2d( 58 | F.pad(exp, [padding_size]*4, mode='constant', value=1.), 59 | window_size, stride=1 60 | ) 61 | ) 62 | 63 | local_max_score = exp / sum_exp 64 | 65 | depth_wise_max = torch.max(x, dim=1)[0] 66 | depth_wise_max_score = x / depth_wise_max.unsqueeze(1) 67 | 68 | all_scores = local_max_score * depth_wise_max_score 69 | score = torch.max(all_scores, dim=1)[0] 70 | 71 | # score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) 72 | 73 | return score.unsqueeze(1) 74 | 75 | def ASL_Peak(self, x): 76 | b,c,h,w = x.shape 77 | window_size = 3 78 | padding_size = window_size//2 79 | 80 | # x = F.relu(x) 81 | max_per_sample = torch.max(x.view(b,-1), dim=1)[0] 82 | x = x/max_per_sample.view(b,1,1,1) 83 | 84 | alpha_input = x - F.avg_pool2d( 85 | F.pad(x, [padding_size]*4, mode='reflect'), 86 | window_size, stride=1 87 | ) 88 | alpha = F.softplus(alpha_input) 89 | 90 | beta_input = x - x.mean(1, True) 91 | beta = F.softplus(beta_input) 92 | 93 | all_scores = (alpha*beta).max(1,True)[0] 94 | 95 | return all_scores 96 | 97 | def identity(self, x): 98 | scores = torch.ones_like(x) 99 | return scores.mean(1,True) 100 | 101 | 102 | def forward(self, fine_maps): 103 | fine_map = fine_maps[0] 104 | img_tensor = fine_maps[1] 105 | x_pf = self.prior(fine_map) 106 | x_pi = self.prior(img_tensor) 107 | 108 | x = self.relu(self.norm1(self.conv1(x_pf*fine_map))) 109 | x = F.interpolate(x, img_tensor.shape[2:], align_corners=False, mode='bilinear') 110 | img_tensor = self.normimg(self.convimg(x_pi*img_tensor)) 111 | x = torch.cat([x, img_tensor], dim=1) 112 | x = self.relu(self.norm2(self.conv2(x))) 113 | score = self.act(self.norm3(self.conv3(x))) 114 | 115 | # thr = self.act(self.conv_thr(x)) 116 | # score = self.relu(score-thr) 117 | 118 | score =F.interpolate(x_pf, img_tensor.shape[2:], align_corners=False, mode='bilinear').mean(1,True) * \ 119 | x_pi.mean(1,True) * score 120 | 121 | return score -------------------------------------------------------------------------------- /losses/epipolarloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .preprocess_utils import * 6 | 7 | 8 | class EpipolarLoss_full(nn.Module): 9 | def __init__(self, configs, device=None): 10 | super(EpipolarLoss_full, self).__init__() 11 | self.__lossname__ = 'EpipolarLoss_fullinfo' 12 | self.config = configs 13 | self.w_g = self.config['weight_grid'] 14 | self.w_w = 
self.config['weight_window'] 15 | 16 | def epipolar_cost(self, coord1, coord2, fmatrix, im_size): 17 | coord1_h = homogenize(coord1).transpose(1, 2) 18 | coord2_h = homogenize(coord2).transpose(1, 2) 19 | epipolar_line = fmatrix.bmm(coord1_h) # Bx3xn 20 | epipolar_line_ = epipolar_line / torch.clamp(torch.norm(epipolar_line[:, :2, :], dim=1, keepdim=True), min=1e-8) 21 | essential_cost = torch.abs(torch.sum(coord2_h * epipolar_line_, dim=1)) # Bxn 22 | return essential_cost 23 | 24 | 25 | def set_weight(self, inverse_std, mask=None, regularizer=0.0): 26 | if self.config['use_std_as_weight']: 27 | # inverse_std = 1. / torch.clamp(std+regularizer, min=1e-10) 28 | weight = inverse_std / torch.mean(inverse_std) 29 | weight = weight.detach() # Bxn 30 | else: 31 | weight = torch.ones_like(std) 32 | 33 | if mask is not None: 34 | weight *= mask.float() 35 | weight /= (torch.mean(weight) + 1e-8) 36 | return weight 37 | 38 | def forward(self, inputs, outputs, processed): 39 | coord1 = processed['coord1'] 40 | coord2 = processed['coord2'] 41 | temperature = processed['temperature'] 42 | 43 | feat1g_corloc = processed['feat1g_corloc'] 44 | feat2g_corloc = processed['feat2g_corloc'] 45 | feat1w_corloc = processed['feat1w_corloc'] 46 | feat2w_corloc = processed['feat2w_corloc'] 47 | 48 | feat1g_std = processed['feat1g_std'] 49 | feat2g_std = processed['feat2g_std'] 50 | feat1w_std = processed['feat1w_std'] 51 | feat2w_std = processed['feat2w_std'] 52 | 53 | Fmat1 = inputs['F1'] 54 | Fmat2 = inputs['F2'] 55 | im_size1 = inputs['im1'].size()[2:] 56 | im_size2 = inputs['im2'].size()[2:] 57 | shorter_edge, longer_edge = min(im_size1), max(im_size1) 58 | 59 | cost_g1 = self.epipolar_cost(coord1, feat1g_corloc, Fmat1, im_size1) 60 | cost_w1 = self.epipolar_cost(coord1, feat1w_corloc, Fmat1, im_size1) 61 | 62 | cost_g2 = self.epipolar_cost(coord2, feat2g_corloc, Fmat2, im_size2) 63 | cost_w2 = self.epipolar_cost(coord2, feat2w_corloc, Fmat2, im_size2) 64 | 65 | # filter out the large values, similar to CAPS 66 | # 去除异常loss,参考CAPS 67 | mask_g1 = cost_g1 < (shorter_edge*self.config['grid_cost_thr']) 68 | mask_w1 = cost_w1 < (shorter_edge*self.config['win_cost_thr']) 69 | mask_g2 = cost_g2 < (shorter_edge*self.config['grid_cost_thr']) 70 | mask_w2 = cost_w2 < (shorter_edge*self.config['win_cost_thr']) 71 | 72 | if 'valid_epi1' in list(processed.keys()): 73 | mask_g1 = mask_g1 & processed['valid_epi1'] 74 | mask_w1 = mask_w1 & processed['valid_epi1'] 75 | mask_g2 = mask_g2 & processed['valid_epi2'] 76 | mask_w2 = mask_w2 & processed['valid_epi2'] 77 | weight_w1 = 1 78 | weight_w2 = 1 79 | 80 | weight_g1 = self.set_weight(1/feat1g_std.clamp(min=1e-10), mask_g1) 81 | weight_w1 = self.set_weight(weight_w1/feat1w_std.clamp(min=1e-10), mask_w1) 82 | weight_g2 = self.set_weight(1/feat2g_std.clamp(min=1e-10), mask_g2) 83 | weight_w2 = self.set_weight(weight_w2/feat2w_std.clamp(min=1e-10), mask_w2) 84 | 85 | loss_g1 = (weight_g1*cost_g1).mean() 86 | loss_w1 = (weight_w1*cost_w1).mean() 87 | loss_g2 = (weight_g2*cost_g2).mean() 88 | loss_w2 = (weight_w2*cost_w2).mean() 89 | 90 | loss = self.w_g*(loss_g1+loss_g2)+self.w_w*(loss_w1+loss_w2) 91 | 92 | percent_g = (mask_g1.sum()/(mask_g1.shape[0]*mask_g1.shape[1]) + mask_g2.sum()/(mask_g2.shape[0]*mask_g2.shape[1]))/2 93 | percent_w = (mask_w1.sum()/(mask_w1.shape[0]*mask_w1.shape[1]) + mask_w2.sum()/(mask_w2.shape[0]*mask_w2.shape[1]))/2 94 | 95 | components = { 96 | 'loss_g1': loss_g1, 'loss_w1':loss_w1, 97 | 'loss_g2':loss_g2, 'loss_w2':loss_w2, 98 | 'percent_g':percent_g, 
'percent_w':percent_w 99 | } 100 | 101 | return loss, components -------------------------------------------------------------------------------- /networks/PoSFeat_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | WSFModel without global header 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from abc import ABC, abstractmethod 10 | from path import Path 11 | import os 12 | 13 | import networks 14 | 15 | class PoSFeat(ABC): 16 | def __init__(self, configs, device, no_cuda=None): 17 | self.config = configs 18 | self.device = device 19 | self.no_cuda = no_cuda 20 | self.align_local_grad = self.config['align_local_grad'] 21 | self.local_input_elements = self.config['local_input_elements'] 22 | self.local_with_img = self.config['local_with_img'] 23 | self.parameters = [] 24 | 25 | backbone = getattr(networks, self.config['backbone']) 26 | self.backbone = backbone(**self.config['backbone_config']).to(self.device) 27 | self.parameters += list(self.backbone.parameters()) 28 | # self.backbone.eval() 29 | message = "backbone: {}\n".format(self.config['backbone']) 30 | 31 | if 'localheader' in list(self.config.keys()) and self.config['localheader'] != 'None': 32 | # if self.config['localheader'] is not None: 33 | localheader = getattr(networks, self.config['localheader']) 34 | self.localheader = localheader(**self.config['localheader_config']).to(self.device) 35 | message += "localheader: {}\n".format(self.config['localheader']) 36 | else: 37 | in_channel = self.backbone.out_channels[0] 38 | # if self.config['backbone'] == 'LiteHRNet': 39 | # in_channel = self.config['backbone_config']['extra']['stages_spec']['num_channels'][-1][0] 40 | # else: 41 | # in_channel = 128 42 | self.localheader = networks.KeypointDet(in_channels=in_channel, out_channels=2).to(self.device) 43 | message += "localheader: KeypointDet\n" 44 | self.parameters += list(self.localheader.parameters()) 45 | self.modules = ['localheader', 'backbone'] 46 | print(message) 47 | 48 | def set_parallel(self, local_rank): 49 | self.backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.backbone) 50 | self.backbone = torch.nn.parallel.DistributedDataParallel(self.backbone, 51 | find_unused_parameters=True,device_ids=[local_rank],output_device=local_rank) 52 | 53 | self.localheader = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.localheader) 54 | self.localheader = torch.nn.parallel.DistributedDataParallel(self.localheader, 55 | find_unused_parameters=True,device_ids=[local_rank],output_device=local_rank) 56 | 57 | def load_checkpoint(self, load_path): 58 | load_root = Path(load_path) 59 | model_list = ['backbone', 'localheader'] 60 | for name in model_list: 61 | model_path = load_root/'{}.pth'.format(name) 62 | if os.path.exists(model_path): 63 | print('load {} from checkpoint'.format(name)) 64 | else: 65 | print('{} does not exist, skipping load'.format(name)) 66 | continue 67 | model = getattr(self, name) 68 | model_param = torch.load(model_path) 69 | # print('\n\n {}\n'.format(name)) 70 | # for key, val in model_param.items(): 71 | # print(key) 72 | model.load_state_dict(model_param) 73 | 74 | def save_checkpoint(self, save_path): 75 | save_root = Path(save_path) 76 | model_list = ['backbone', 'localheader'] 77 | for name in model_list: 78 | model_path = save_root/'{}.pth'.format(name) 79 | model = getattr(self, name) 80 | model_param = model.state_dict() 81 | torch.save(model_param, model_path) 82 | 83 | def 
set_train(self): 84 | self.backbone.train() 85 | self.localheader.train() 86 | 87 | def set_eval(self): 88 | self.backbone.eval() 89 | self.localheader.eval() 90 | 91 | def extract(self, tensor, postfix=""): 92 | feat_maps = self.backbone(tensor) 93 | # g_map = self.globalheader(feat_maps['global_map']) 94 | b, c, h, w = feat_maps['global_map'].shape 95 | g_map = torch.ones(b,1, h, w).type_as(feat_maps['local_map']).to(feat_maps['local_map'].device) 96 | local_list = [] 97 | for name in self.local_input_elements: 98 | local_list.append(feat_maps[name]) 99 | local_input = torch.cat(local_list, dim=1) 100 | if not self.align_local_grad: 101 | # l_map = self.localheader(local_input) 102 | local_input = local_input.detach() 103 | # else: 104 | # l_map = self.localheader(local_input.detach()) 105 | if self.local_with_img: 106 | local_input = [local_input, tensor] 107 | l_map = self.localheader(local_input) 108 | 109 | if l_map.shape[1] == 1: 110 | local_thr = torch.zeros_like(l_map) 111 | elif l_map.shape[1] == 2: 112 | local_thr = l_map[:,1:,:,:] 113 | l_map = l_map[:,:1,:,:] 114 | 115 | g_desc = g_map*feat_maps['global_map'] 116 | # g_desc = g_desc.sum([2,3]) 117 | g_desc = F.normalize(g_desc, p=2, dim=1).mean([2,3]) 118 | 119 | outputs = { 120 | 'local_map': feat_maps['local_map'], 121 | 'global_map': feat_maps['global_map'], 122 | 'global_feat': g_desc, 123 | 'local_point': l_map, 124 | 'local_thr': local_thr, 125 | 'global_point': g_map 126 | } 127 | 128 | # outputs = { 129 | # 'local_feat{}'.format(postfix): feat_maps['fine_map'], 130 | # 'global_feat{}'.format(postfix): g_desc, 131 | # 'local_point{}'.format(postfix): l_map, 132 | # 'global_point{}'.format(postfix): g_map 133 | # } 134 | return outputs 135 | 136 | def forward(self, inputs): 137 | for key, val in inputs.items(): 138 | if key in self.no_cuda: 139 | continue 140 | inputs[key] = val.to(self.device) 141 | 142 | # preds = self.extract(inputs['im1'],1) 143 | # preds.update(self.extract(inputs['im2'],2)) 144 | preds1 = self.extract(inputs['im1'],1) 145 | preds2 = self.extract(inputs['im2'],2) 146 | 147 | return {'preds1':preds1, 'preds2':preds2} 148 | 149 | -------------------------------------------------------------------------------- /losses/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from . 
import preprocess_utils as putils 5 | from .preprocess_utils import * 6 | 7 | class Preprocess_Line2Window(nn.Module): 8 | ''' 9 | the preprocess class for grid-with-line pipeline 10 | ''' 11 | def __init__(self, configs, device=None, vis=False): 12 | super(Preprocess_Line2Window, self).__init__() 13 | self.__lossname__ = 'Preprocess_Line2Window' 14 | self.config = configs 15 | self.kps_generator = getattr(putils, self.config['kps_generator']) 16 | self.t_base = self.config['temperature_base'] 17 | self.t_max = self.config['temperature_max'] 18 | if device is not None: 19 | self.device = device 20 | 21 | def name(self): 22 | return self.__lossname__ 23 | 24 | def forward(self, inputs, outputs): 25 | preds1 = outputs['preds1'] 26 | preds2 = outputs['preds2'] 27 | 28 | xc1, xf1 = preds1['global_map'], preds1['local_map'] 29 | xc2, xf2 = preds2['global_map'], preds2['local_map'] 30 | h1i, w1i = inputs['im1'].size()[2:] 31 | h2i, w2i = inputs['im2'].size()[2:] 32 | b, _, hf, wf = xf1.shape 33 | temperature = min(self.t_base + outputs['epoch'], self.t_max) 34 | 35 | """ 36 | firstly, we search locate the correspondence with grid points 37 | with keep_spatial==True, coord (score) is with bxhxwx2 (bxhxwx1) 38 | with keep_spatial==False, coord (score) is with bx(h*w)x2 (bx(h*w)x1) 39 | the keep_spatial is defined in self.config['kps_generator_config'] 40 | 首先,我们随所有的抽样点进行匹配搜索 41 | 在配置文件中有一个选项 keep_spatial 可以控制输出的抽样点的shape 42 | 43 | This is a coarse search with grid points matching,which is similar to the coarse search in caps 44 | in fact this coarse matching is just for ablation, and the results are not used in the final loss 45 | you can comment out this search 46 | 这里包含了一部分粗略匹配的代码,类似于CAPS中的粗略匹配 47 | 粗匹配的结果是最开始实验时进行的探索,实际上并没有用于最后的损失函数计算 48 | 可以注释掉粗略匹配的代码 49 | """ 50 | 51 | coord1_n, coord2_n, score1, score2 = self.kps_generator(inputs, outputs, **self.config['kps_generator_config']) 52 | _, hkps, wkps, _ = coord1_n.shape 53 | coord1 = denormalize_coords(coord1_n.reshape(b,-1,2), h1i, w1i) 54 | coord2 = denormalize_coords(coord2_n.reshape(b,-1,2), h2i, w2i) 55 | 56 | feat1_fine = sample_feat_by_coord(xf1, coord1_n.reshape(b,-1,2), self.config['loss_distance']=='cos') 57 | feat2_fine = sample_feat_by_coord(xf2, coord2_n.reshape(b,-1,2), self.config['loss_distance']=='cos') 58 | 59 | cos_sim = feat1_fine @ feat2_fine.transpose(1,2) # bxmxn 60 | feat1g_corloc = (F.softmax(temperature*cos_sim, dim=2)).unsqueeze(-1)*coord2.reshape(b,-1,2).unsqueeze(1) #bxmxnx2 61 | feat1g_corloc = feat1g_corloc.sum(2) #bxmx2 62 | feat2g_corloc = (F.softmax(temperature*cos_sim, dim=1)).unsqueeze(-1)*coord1.reshape(b,-1,2).unsqueeze(2) #bxmxnx2 63 | feat2g_corloc = feat2g_corloc.sum(1) #bxnx2 64 | 65 | 66 | with torch.no_grad(): 67 | if self.config['use_nn_grid']: 68 | _, max_idx1 = cor_mat.max(2) 69 | feat1g_corloc_n = coord2_n.reshape(b,-1,2).gather(dim=1, index=max_idx1[:,:,None].repeat(1,1,2)) 70 | _, max_idx2 = cor_mat.max(1) 71 | feat2g_corloc_n = coord1_n.reshape(b,-1,2).gather(dim=1, index=max_idx2[:,:,None].repeat(1,1,2)) 72 | else: 73 | feat1g_corloc_n = normalize_coords(feat1g_corloc, h2i, w2i) 74 | feat2g_corloc_n = normalize_coords(feat2g_corloc, h1i, w1i) 75 | 76 | feat1g_std = (F.softmax(temperature*cos_sim, dim=2)).unsqueeze(-1)*(coord2_n.reshape(b,1,-1,2)**2) 77 | feat1g_std = feat1g_std.sum(2) - (feat1g_corloc_n**2) 78 | feat1g_std = feat1g_std.clamp(min=1e-6).sqrt().sum(-1) #bxn 79 | feat2g_std = (F.softmax(temperature*cos_sim, dim=1)).unsqueeze(-1)*(coord1_n.reshape(b,-1,1,2)**2) 80 | feat2g_std = 
feat2g_std.sum(1) - (feat2g_corloc_n**2) 81 | feat2g_std = feat2g_std.clamp(min=1e-6).sqrt().sum(-1) #bxn 82 | 83 | if self.config['use_line_search']: 84 | feat1_c_corloc_n_, feat1_c_corloc_n_org, valid1, epi_std1 = epipolar_line_search(coord1, inputs['F1'], feat1_fine, 85 | temperature*F.normalize(xf2,p=2.0,dim=1), h2i, w2i, window_size=self.config['window_size'], **self.config['line_search_config']) 86 | feat2_c_corloc_n_, feat2_c_corloc_n_org, valid2, epi_std2 = epipolar_line_search(coord2, inputs['F2'], feat2_fine, 87 | temperature*F.normalize(xf1,p=2.0,dim=1), h1i, w1i, window_size=self.config['window_size'], **self.config['line_search_config']) 88 | feat1c_corloc_org = denormalize_coords(feat1_c_corloc_n_org, h2i, w2i) 89 | feat2c_corloc_org = denormalize_coords(feat2_c_corloc_n_org, h1i, w1i) 90 | else: 91 | feat1_c_corloc_n_ = feat1g_corloc_n.detach() 92 | feat2_c_corloc_n_ = feat2g_corloc_n.detach() 93 | feat1c_corloc_org = feat1_c_corloc_n_ 94 | feat2c_corloc_org = feat2_c_corloc_n_ 95 | valid1 = torch.ones_like(feat1g_std).bool() 96 | valid2 = torch.ones_like(feat2g_std).bool() 97 | 98 | feat1w_corloc_n, window_coords_n_1in2, feat1w_std, _ = get_expected_correspondence_within_window( 99 | feat1_fine, temperature*F.normalize(xf2,p=2.0,dim=1), feat1_c_corloc_n_, self.config['window_size'], with_std=True) 100 | feat2w_corloc_n, window_coords_n_2in1, feat2w_std, _ = get_expected_correspondence_within_window( 101 | feat2_fine, temperature*F.normalize(xf1,p=2.0,dim=1), feat2_c_corloc_n_, self.config['window_size'], with_std=True) 102 | 103 | feat1w_corloc = denormalize_coords(feat1w_corloc_n, h2i, w2i) 104 | feat2w_corloc = denormalize_coords(feat2w_corloc_n, h1i, w1i) 105 | 106 | return { 107 | 'coord1':coord1, 'coord2':coord2, 108 | 'feat1g_corloc':feat1g_corloc, 109 | 'feat2g_corloc':feat2g_corloc, 110 | 'feat1w_corloc':feat1w_corloc, 111 | 'feat2w_corloc':feat2w_corloc, 112 | 'feat1c_corloc_org':feat1c_corloc_org, 113 | 'feat2c_corloc_org':feat2_c_corloc_n_org, 114 | 'feat1g_std':feat1g_std, 'feat2g_std':feat2g_std, 115 | 'feat1w_std':feat1w_std, 'feat2w_std':feat2w_std, 116 | 'temperature':temperature, 117 | 'valid_epi1':valid1, 'valid_epi2':valid2 118 | } 119 | 120 | class Preprocess_Skip(nn.Module): 121 | ''' 122 | the preprocess class for keypoint detection net training 123 | ''' 124 | def __init__(self, **kargs): 125 | super(Preprocess_Skip, self).__init__() 126 | self.__lossname__ = 'Preprocess_Skip' 127 | 128 | def forward(self, inputs, outputs): 129 | return None 130 | -------------------------------------------------------------------------------- /networks/DescNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import importlib 5 | 6 | def class_for_name(module_name, class_name): 7 | # load the module, will raise ImportError if module cannot be loaded 8 | m = importlib.import_module(module_name) 9 | return getattr(m, class_name) 10 | 11 | class ResUNet(nn.Module): 12 | def __init__(self, 13 | encoder='resnet50', 14 | pretrained=True, 15 | coarse_out_ch=128, 16 | fine_out_ch=128 17 | ): 18 | 19 | super(ResUNet, self).__init__() 20 | assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wide_resnet50_2'], "Incorrect encoder type" 21 | if encoder in ['resnet18', 'resnet34']: 22 | filters = [64, 128, 256, 512] 23 | else: 24 | filters = [256, 512, 1024, 2048] 25 | resnet = class_for_name("torchvision.models", 
encoder)(pretrained=pretrained) 26 | 27 | self.firstconv = resnet.conv1 # H/2 28 | self.firstbn = resnet.bn1 29 | self.firstrelu = resnet.relu 30 | self.firstmaxpool = resnet.maxpool # H/4 31 | 32 | # encoder 33 | self.layer1 = resnet.layer1 # H/4 34 | self.layer2 = resnet.layer2 # H/8 35 | self.layer3 = resnet.layer3 # H/16 36 | 37 | # coarse-level conv 38 | self.conv_coarse = conv(filters[2], coarse_out_ch, 1, 1) 39 | 40 | # decoder 41 | self.upconv3 = upconv(filters[2], 512, 3, 2) 42 | self.iconv3 = conv(filters[1] + 512, 512, 3, 1) 43 | self.upconv2 = upconv(512, 256, 3, 2) 44 | self.iconv2 = conv(filters[0] + 256, 256, 3, 1) 45 | 46 | # fine-level conv 47 | self.conv_fine = conv(256, fine_out_ch, 1, 1) 48 | self.out_channels = [fine_out_ch, coarse_out_ch] 49 | 50 | def skipconnect(self, x1, x2): 51 | diffY = x2.size()[2] - x1.size()[2] 52 | diffX = x2.size()[3] - x1.size()[3] 53 | 54 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 55 | diffY // 2, diffY - diffY // 2)) 56 | 57 | # for padding issues, see 58 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 59 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 60 | 61 | x = torch.cat([x2, x1], dim=1) 62 | return x 63 | 64 | def forward(self, x): 65 | x = self.firstrelu(self.firstbn(self.firstconv(x))) 66 | x_first = self.firstmaxpool(x) 67 | 68 | x1 = self.layer1(x_first) 69 | x2 = self.layer2(x1) 70 | x3 = self.layer3(x2) 71 | 72 | x_coarse = self.conv_coarse(x3) #H/16 73 | 74 | x = self.upconv3(x3) 75 | x = self.skipconnect(x2, x) 76 | x = self.iconv3(x) 77 | 78 | x = self.upconv2(x) 79 | x = self.skipconnect(x1, x) 80 | x = self.iconv2(x) 81 | 82 | x_fine = self.conv_fine(x) #H/4 83 | 84 | return {'global_map':x_coarse, 'local_map':x_fine, 'local_map_small':x_first} 85 | 86 | class ResUNetHR(nn.Module): 87 | def __init__(self, 88 | encoder='resnet50', 89 | pretrained=True, 90 | coarse_out_ch=128, 91 | fine_out_ch=128 92 | ): 93 | 94 | super(ResUNetHR, self).__init__() 95 | assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "Incorrect encoder type" 96 | if encoder in ['resnet18', 'resnet34']: 97 | filters = [64, 128, 256, 512] 98 | else: 99 | filters = [256, 512, 1024, 2048] 100 | resnet = class_for_name("torchvision.models", encoder)(pretrained=pretrained) 101 | 102 | self.firstconv = resnet.conv1 # H/2 103 | self.firstbn = resnet.bn1 104 | self.firstrelu = resnet.relu 105 | self.firstmaxpool = resnet.maxpool # H/4 106 | 107 | # encoder 108 | self.layer1 = resnet.layer1 # H/4 109 | self.layer2 = resnet.layer2 # H/8 110 | self.layer3 = resnet.layer3 # H/16 111 | 112 | # coarse-level conv 113 | self.conv_coarse = conv(filters[2], coarse_out_ch, 1, 1) 114 | 115 | # decoder 116 | self.upconv3 = upconv(filters[2], 512, 3, 2) 117 | self.iconv3 = conv(filters[1] + 512, 512, 3, 1) 118 | self.upconv2 = upconv(512, 256, 3, 2) 119 | self.iconv2 = conv(filters[0] + 256, 256, 3, 1) 120 | self.upconv1 = upconv(256,192,3,2) 121 | self.iconv1 = conv(64 + 192, 256, 3, 1) 122 | 123 | # fine-level conv 124 | self.conv_fine = conv(256, fine_out_ch, 1, 1) 125 | self.out_channels = [fine_out_ch, coarse_out_ch] 126 | 127 | def skipconnect(self, x1, x2): 128 | diffY = x2.size()[2] - x1.size()[2] 129 | diffX = x2.size()[3] - x1.size()[3] 130 | 131 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 132 | diffY // 2, diffY - diffY // 2)) 133 | 134 | # for padding issues, see 135 | # 
https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 136 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 137 | 138 | x = torch.cat([x2, x1], dim=1) 139 | return x 140 | 141 | def forward(self, x): 142 | x_first1 = self.firstrelu(self.firstbn(self.firstconv(x))) 143 | x_first = self.firstmaxpool(x_first1) 144 | 145 | x1 = self.layer1(x_first) 146 | x2 = self.layer2(x1) 147 | x3 = self.layer3(x2) 148 | 149 | x_coarse = self.conv_coarse(x3) #H/16 150 | 151 | x = self.upconv3(x3) 152 | x = self.skipconnect(x2, x) 153 | x = self.iconv3(x) 154 | 155 | x = self.upconv2(x) 156 | x = self.skipconnect(x1, x) 157 | x = self.iconv2(x) 158 | 159 | x = self.upconv1(x) 160 | x = self.skipconnect(x_first1, x) 161 | x = self.iconv1(x) 162 | 163 | x_fine = self.conv_fine(x) #H/2 164 | 165 | return {'global_map':x_coarse, 'local_map':x_fine, 'local_map_small':x_first1} 166 | 167 | class conv(nn.Module): 168 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): 169 | super(conv, self).__init__() 170 | self.kernel_size = kernel_size 171 | self.conv = nn.Conv2d(num_in_layers, 172 | num_out_layers, 173 | kernel_size=kernel_size, 174 | stride=stride, 175 | padding=(self.kernel_size - 1) // 2) 176 | self.bn = nn.BatchNorm2d(num_out_layers) 177 | 178 | def forward(self, x): 179 | return F.elu(self.bn(self.conv(x)), inplace=True) 180 | 181 | 182 | class upconv(nn.Module): 183 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): 184 | super(upconv, self).__init__() 185 | self.scale = scale 186 | self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) 187 | 188 | def forward(self, x): 189 | x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear') 190 | return self.conv(x) -------------------------------------------------------------------------------- /datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | from matplotlib import cm 6 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 7 | 8 | def skew(x): 9 | return np.array([[0, -x[2], x[1]], 10 | [x[2], 0, -x[0]], 11 | [-x[1], x[0], 0]]) 12 | 13 | 14 | def rotateImage(image, angle): 15 | h, w = image.shape[:2] 16 | angle_radius = np.abs(angle / 180. * np.pi) 17 | cos = np.cos(angle_radius) 18 | sin = np.sin(angle_radius) 19 | tan = np.tan(angle_radius) 20 | scale_h = (h / cos + (w - h * tan) * sin) / h 21 | scale_w = (h / sin + (w - h / tan) * cos) / w 22 | scale = max(scale_h, scale_w) 23 | image_center = tuple(np.array(image.shape[1::-1]) / 2.) 
24 | rot_mat = cv2.getRotationMatrix2D(image_center, angle, scale) 25 | result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=cv2.INTER_LINEAR) 26 | rotation = np.eye(4) 27 | rotation[:2, :2] = rot_mat[:2, :2] 28 | return result, rotation 29 | 30 | 31 | def perspective_transform(img, param=0.001): 32 | h, w = img.shape[:2] 33 | random_state = np.random.RandomState(None) 34 | M = np.array([[1 - param + 2 * param * random_state.rand(), 35 | -param + 2 * param * random_state.rand(), 36 | -param + 2 * param * random_state.rand()], 37 | [-param + 2 * param * random_state.rand(), 38 | 1 - param + 2 * param * random_state.rand(), 39 | -param + 2 * param * random_state.rand()], 40 | [-param + 2 * param * random_state.rand(), 41 | -param + 2 * param * random_state.rand(), 42 | 1 - param + 2 * param * random_state.rand()]]) 43 | 44 | dst = cv2.warpPerspective(img, M, (w, h)) 45 | return dst, M 46 | 47 | 48 | def generate_query_kpts(img, mode, num_pts, h, w): 49 | # generate candidate query points 50 | if mode == 'random': 51 | kp1_x = np.random.rand(num_pts) * (w - 1) 52 | kp1_y = np.random.rand(num_pts) * (h - 1) 53 | coord = np.stack((kp1_x, kp1_y)).T 54 | 55 | elif mode == 'sift': 56 | gray1 = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 57 | sift = cv2.xfeatures2d.SIFT_create(nfeatures=num_pts) 58 | kp1 = sift.detect(gray1) 59 | coord = np.array([[kp.pt[0], kp.pt[1]] for kp in kp1]) 60 | 61 | elif mode == 'mixed': 62 | kp1_x = np.random.rand(1 * int(0.1 * num_pts)) * (w - 1) 63 | kp1_y = np.random.rand(1 * int(0.1 * num_pts)) * (h - 1) 64 | kp1_rand = np.stack((kp1_x, kp1_y)).T 65 | 66 | sift = cv2.xfeatures2d.SIFT_create(nfeatures=int(0.9 * num_pts)) 67 | gray1 = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 68 | kp1_sift = sift.detect(gray1) 69 | kp1_sift = np.array([[kp.pt[0], kp.pt[1]] for kp in kp1_sift]) 70 | if len(kp1_sift) == 0: 71 | coord = kp1_rand 72 | else: 73 | coord = np.concatenate((kp1_rand, kp1_sift), 0) 74 | 75 | else: 76 | raise Exception('unknown type of keypoints') 77 | 78 | return coord 79 | 80 | 81 | def prune_kpts(coord1, F_gt, im2_size, intrinsic1, intrinsic2, pose, d_min, d_max): 82 | # compute the epipolar lines corresponding to coord1 83 | coord1_h = np.concatenate([coord1, np.ones_like(coord1[:, [0]])], axis=1).T # 3xn 84 | epipolar_line = F_gt.dot(coord1_h) # 3xn 85 | epipolar_line /= np.clip(np.linalg.norm(epipolar_line[:2], axis=0), a_min=1e-10, a_max=None) # 3xn 86 | 87 | # determine whether the epipolar lines intersect with the second image 88 | h2, w2 = im2_size 89 | corners = np.array([[0, 0, 1], [0, h2 - 1, 1], [w2 - 1, 0, 1], [w2 - 1, h2 - 1, 1]]) # 4x3 90 | dists = np.abs(corners.dot(epipolar_line)) 91 | # if the epipolar line is far away from any image corners than sqrt(h^2+w^2) 92 | # it doesn't intersect with the image 93 | non_intersect = (dists > np.sqrt(w2 ** 2 + h2 ** 2)).any(axis=0) 94 | 95 | # determine if points in coord1 is likely to have correspondence in the other image by the rough depth range 96 | intrinsic1_4x4 = np.eye(4) 97 | intrinsic1_4x4[:3, :3] = intrinsic1 98 | intrinsic2_4x4 = np.eye(4) 99 | intrinsic2_4x4[:3, :3] = intrinsic2 100 | coord1_h_min = np.concatenate([d_min * coord1, 101 | d_min * np.ones_like(coord1[:, [0]]), 102 | np.ones_like(coord1[:, [0]])], axis=1).T 103 | coord1_h_max = np.concatenate([d_max * coord1, 104 | d_max * np.ones_like(coord1[:, [0]]), 105 | np.ones_like(coord1[:, [0]])], axis=1).T 106 | coord2_h_min = intrinsic2_4x4.dot(pose).dot(np.linalg.inv(intrinsic1_4x4)).dot(coord1_h_min) 107 | coord2_h_max = 
intrinsic2_4x4.dot(pose).dot(np.linalg.inv(intrinsic1_4x4)).dot(coord1_h_max) 108 | coord2_min = coord2_h_min[:2] / (coord1_h_min[2] + 1e-10) 109 | coord2_max = coord2_h_max[:2] / (coord1_h_max[2] + 1e-10) 110 | out_range = ((coord2_min[0] < 0) & (coord2_max[0] < 0)) | \ 111 | ((coord2_min[1] < 0) & (coord2_max[1] < 0)) | \ 112 | ((coord2_min[0] > w2 - 1) & (coord2_max[0] > w2 - 1)) | \ 113 | ((coord2_min[1] > h2 - 1) & (coord2_max[1] > h2 - 1)) 114 | 115 | ind_intersect = ~(non_intersect | out_range) 116 | return ind_intersect 117 | 118 | def random_choice(array, size): 119 | rand = np.random.RandomState(1234) 120 | num_data = len(array) 121 | if num_data > size: 122 | idx = rand.choice(num_data, size, replace=False) 123 | else: 124 | idx = rand.choice(num_data, size, replace=True) 125 | return array[idx] 126 | 127 | def tensor2array(tensor, max_value=None, colormap='coolwarm'): 128 | def high_res_colormap(low_res_cmap, resolution=1000, max_value=1): 129 | # Construct the list colormap, with interpolated values for higer resolution 130 | # For a linear segmented colormap, you can just specify the number of point in 131 | # cm.get_cmap(name, lutsize) with the parameter lutsize 132 | x = np.linspace(0, 1, low_res_cmap.N) 133 | low_res = low_res_cmap(x) 134 | new_x = np.linspace(0, max_value, resolution) 135 | high_res = np.stack([np.interp(new_x, x, low_res[:, i]) 136 | for i in range(low_res.shape[1])], axis=1) 137 | return ListedColormap(high_res) 138 | 139 | 140 | def opencv_rainbow(resolution=1000): 141 | # Construct the opencv equivalent of Rainbow 142 | opencv_rainbow_data = ( 143 | (0.000, (1.00, 0.00, 0.00)), 144 | (0.400, (1.00, 1.00, 0.00)), 145 | (0.600, (0.00, 1.00, 0.00)), 146 | (0.800, (0.00, 0.00, 1.00)), 147 | (1.000, (0.60, 0.00, 1.00)) 148 | ) 149 | 150 | return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) 151 | COLORMAPS = {'rainbow': opencv_rainbow(), 152 | 'magma': high_res_colormap(cm.get_cmap('magma')), 153 | 'bone': cm.get_cmap('bone', 10000), 154 | 'seismic':high_res_colormap(cm.get_cmap('seismic')), 155 | 'coolwarm':high_res_colormap(cm.get_cmap('coolwarm'))} 156 | tensor = tensor.detach().cpu() 157 | if max_value is None: 158 | max_value = tensor.max().item() 159 | if tensor.ndimension() == 2 or tensor.size(0) == 1: 160 | norm_array = tensor.squeeze().numpy()/max_value 161 | if colormap in list(COLORMAPS.keys()): 162 | map_func = COLORMAPS[colormap] 163 | else: 164 | map_func = high_res_colormap(cm.get_cmap(colormap)) 165 | array = map_func(norm_array).astype(np.float32) 166 | array = array.transpose(2, 0, 1) 167 | 168 | elif tensor.ndimension() == 3: 169 | # assert(tensor.size(0) == 3) 170 | array = tensor.numpy() 171 | return array -------------------------------------------------------------------------------- /losses/kploss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .preprocess_utils import * 5 | from torch.distributions import Categorical, Bernoulli 6 | 7 | class DiskLoss(nn.Module): 8 | def __init__(self, configs, device=None): 9 | super(DiskLoss, self).__init__() 10 | self.__lossname__ = 'DiskLoss' 11 | self.config = configs 12 | self.unfold_size = self.config['grid_size'] 13 | self.t_base = self.config['temperature_base'] 14 | self.t_max = self.config['temperature_max'] 15 | self.reward = getattr(self, self.config['epipolar_reward']) 16 | self.good_reward = self.config['good_reward'] 17 | 
self.bad_reward = self.config['bad_reward'] 18 | self.kp_penalty = self.config['kp_penalty'] 19 | 20 | def point_distribution(self, logits): 21 | proposal_dist = Categorical(logits=logits) # bx1x(h//g)x(w//g)x(g*g) 22 | proposals = proposal_dist.sample() # bx1x(h//g)x(w//g) 23 | proposal_logp = proposal_dist.log_prob(proposals) # bx1x(h//g)x(w//g) 24 | 25 | # accept_logits = select_on_last(logits, proposals).squeeze(-1) 26 | accept_logits = torch.gather(logits, dim=-1, index=proposals[..., None]).squeeze(-1) # bx1x(h//g)x(w//g) 27 | 28 | accept_dist = Bernoulli(logits=accept_logits) 29 | accept_samples = accept_dist.sample() # bx1x(h//g)x(w//g) 30 | accept_logp = accept_dist.log_prob(accept_samples) # for accepted points, equals to sigmoid() then log(); for denied, (1-sigmoid).log 31 | accept_mask = accept_samples == 1. 32 | 33 | logp = proposal_logp + accept_logp 34 | 35 | return proposals, accept_mask, logp 36 | 37 | def point_sample(self, kp_map): 38 | kpmap_unfold = unfold(kp_map, self.unfold_size) 39 | proposals, accept_mask, logp = self.point_distribution(kpmap_unfold) 40 | 41 | b, _, h, w = kp_map.shape 42 | grids_org = gen_grid(h_min=0, h_max=h-1, w_min=0, w_max=w-1, len_h=h, len_w=w) 43 | grids_org = grids_org.reshape(h, w, 2)[None, :, :, :].repeat(b, 1, 1, 1).to(kp_map) 44 | grids_org = grids_org.permute(0,3,1,2) # bx2xhxw 45 | grids_unfold = unfold(grids_org, self.unfold_size) # bx2x(h//g)x(w//g)x(g*g) 46 | 47 | kps = grids_unfold.gather(dim=4, index=proposals.unsqueeze(-1).repeat(1,2,1,1,1)) 48 | return kps.squeeze(4).permute(0,2,3,1), logp, accept_mask 49 | 50 | @ torch.no_grad() 51 | def constant_reward(self, inputs, outputs, coord1, coord2, reward_thr, rescale_thr): 52 | coord1_h = homogenize(coord1).transpose(1, 2) #bx3xm 53 | coord2_h = homogenize(coord2).transpose(1, 2) #bx3xn 54 | fmatrix = inputs['F1'] 55 | fmatrix2 = inputs['F2'] 56 | 57 | # compute the distance of the points in the second image 58 | epipolar_line = fmatrix.bmm(coord1_h) 59 | epipolar_line_ = epipolar_line / torch.clamp( 60 | torch.norm(epipolar_line[:, :2, :], p=2, dim=1, keepdim=True), min=1e-8) 61 | epipolar_dist = torch.abs(epipolar_line_.transpose(1, 2)@coord2_h) #bxmxn 62 | 63 | # compute the distance of the points in the first image 64 | epipolar_line2 = fmatrix2.bmm(coord2_h) 65 | epipolar_line2_ = epipolar_line2 / torch.clamp( 66 | torch.norm(epipolar_line2[:, :2, :], p=2, dim=1, keepdim=True), min=1e-8) 67 | epipolar_dist2 = torch.abs(epipolar_line2_.transpose(1, 2)@coord1_h) #bxnxm 68 | epipolar_dist2 = epipolar_dist2.transpose(1,2) #bxmxn 69 | 70 | if rescale_thr: 71 | b, _, _ = epipolar_dist.shape 72 | dist1 = epipolar_dist.detach().reshape(b, -1).mean(1,True) 73 | dist2 = epipolar_dist2.detach().reshape(b,-1).mean(1,True) 74 | dist_ = torch.cat([dist1, dist2], dim=1) 75 | scale1 = dist1/dist_.min(1,True)[0].clamp(1e-6) 76 | scale2 = dist2/dist_.min(1,True)[0].clamp(1e-6) 77 | thr1 = reward_thr*scale1 78 | thr2 = reward_thr*scale2 79 | thr1 = thr1.reshape(b,1,1) 80 | thr2 = thr2.reshape(b,1,1) 81 | else: 82 | thr1 = reward_thr 83 | thr2 = reward_thr 84 | scale1 = epipolar_dist2.new_tensor(1.) 85 | scale2 = epipolar_dist2.new_tensor(1.) 
86 | 87 | good = (epipolar_dist 60000: 53 | keypoints_a = keypoints_a[:60000,:] 54 | descriptors_a = descriptors_a[:60000, :] 55 | n_feats.append(keypoints_a.shape[0]) 56 | 57 | for im_idx in range(2, 7): 58 | keypoints_b, descriptors_b = read_feats(seq_name, im_idx) 59 | if keypoints_b.shape[0] > 60000: 60 | keypoints_b = keypoints_b[:60000,:] 61 | descriptors_b = descriptors_b[:60000, :] 62 | n_feats.append(keypoints_b.shape[0]) 63 | 64 | matches = mnn_matcher( 65 | torch.from_numpy(descriptors_a).to(device=device), 66 | torch.from_numpy(descriptors_b).to(device=device) 67 | ) 68 | 69 | homography = np.loadtxt(dataset_path/"{}/H_1_{}".format(seq_name, im_idx)) 70 | 71 | pos_a = keypoints_a[matches[:, 0], : 2] 72 | pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1) 73 | pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h))) 74 | pos_b_proj = pos_b_proj_h[:, : 2] / pos_b_proj_h[:, 2 :] 75 | 76 | pos_b = keypoints_b[matches[:, 1], : 2] 77 | 78 | dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1)) 79 | 80 | n_matches.append(matches.shape[0]) 81 | seq_type.append(seq_name[0]) 82 | 83 | if dist.shape[0] == 0: 84 | dist = np.array([float("inf")]) 85 | 86 | for thr in rng: 87 | if seq_name[0] == 'i': 88 | i_err[thr] += np.mean(dist <= thr) 89 | else: 90 | v_err[thr] += np.mean(dist <= thr) 91 | 92 | seq_type = np.array(seq_type) 93 | n_feats = np.array(n_feats) 94 | n_matches = np.array(n_matches) 95 | 96 | return i_err, v_err, [seq_type, n_feats, n_matches] 97 | 98 | def summary(stats): 99 | seq_type, n_feats, n_matches = stats 100 | print('# Features: {:f} - [{:d}, {:d}]'.format(np.mean(n_feats), np.min(n_feats), np.max(n_feats))) 101 | print('# Matches: Overall {:f}, Illumination {:f}, Viewpoint {:f}'.format( 102 | np.sum(n_matches) / ((n_i + n_v) * 5), 103 | np.sum(n_matches[seq_type == 'i']) / (n_i * 5), 104 | np.sum(n_matches[seq_type == 'v']) / (n_v * 5)) 105 | ) 106 | def generate_read_function(method, extension='ppm'): 107 | def read_function(seq_name, im_idx): 108 | aux = np.load(features_path/"{}/{}.{}.{}".format(seq_name, im_idx, extension, method)) 109 | if top_k is None: 110 | return aux['keypoints'], aux['descriptors'] 111 | else: 112 | assert('scores' in aux) 113 | ids = np.argsort(aux['scores'])[-top_k :] 114 | return aux['keypoints'][ids, :], aux['descriptors'][ids, :] 115 | return read_function 116 | 117 | def sift_to_rootsift(descriptors): 118 | return np.sqrt(descriptors / np.expand_dims(np.sum(np.abs(descriptors), axis=1), axis=1) + 1e-16) 119 | def parse_mat(mat): 120 | keypoints = mat['keypoints'][:, : 2] 121 | raw_descriptors = mat['descriptors'] 122 | l2_norm_descriptors = raw_descriptors / np.expand_dims(np.sum(raw_descriptors ** 2, axis=1), axis=1) 123 | descriptors = sift_to_rootsift(l2_norm_descriptors) 124 | if top_k is None: 125 | return keypoints, descriptors 126 | else: 127 | assert('scores' in mat) 128 | ids = np.argsort(mat['scores'][0])[-top_k :] 129 | return keypoints[ids, :], descriptors[ids, :] 130 | 131 | if top_k is None: 132 | cache_dir = 'cache' 133 | else: 134 | cache_dir = 'cache-top' 135 | if not os.path.isdir(cache_dir): 136 | os.mkdir(cache_dir) 137 | 138 | errors = {} 139 | 140 | for method in methods: 141 | output_file = os.path.join(cache_dir, method + '.npy') 142 | print(method) 143 | if method == 'hesaff': 144 | read_function = lambda seq_name, im_idx: parse_mat(loadmat(os.path.join(dataset_path, seq_name, '%d.ppm.hesaff' % im_idx), appendmat=False)) 145 | else: 146 | if method == 'delf' or method 
== 'delf-new': 147 | read_function = generate_read_function(method, extension='png') 148 | else: 149 | read_function = generate_read_function(method) 150 | if os.path.exists(output_file): 151 | print('Loading precomputed errors...') 152 | errors[method] = np.load(output_file, allow_pickle=True) 153 | else: 154 | errors[method] = benchmark_features(read_function) 155 | saved = np.array(errors[method], dtype=object) 156 | np.save(output_file, saved) 157 | summary(errors[method][-1]) 158 | 159 | # evalute MMA score 160 | MMAscore = {} 161 | for method in methods: 162 | i_err, v_err, _ = errors[method] 163 | tmp_a = [] 164 | tmp_i = [] 165 | tmp_v = [] 166 | for thr in range(1,11): 167 | tmp_a.append((i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)) 168 | tmp_i.append(i_err[thr] / (n_i * 5)) 169 | tmp_v.append(v_err[thr] / (n_v * 5)) 170 | cur_a = 0 171 | cur_i = 0 172 | cur_v = 0 173 | upper_bound = 0 174 | for idx, (mma_a, mma_i, mma_v) in enumerate(zip(tmp_a, tmp_i, tmp_v)): 175 | cur_a += (2-(idx+1)/10.)*mma_a 176 | cur_i += (2-(idx+1)/10.)*mma_i 177 | cur_v += (2-(idx+1)/10.)*mma_v 178 | upper_bound += (2-(idx+1)/10.)*1 179 | MMAscore[method] = (cur_a/upper_bound, cur_i/upper_bound, cur_v/upper_bound) 180 | 181 | # plot 182 | plt_lim = [1, 10] 183 | plt_rng = np.arange(plt_lim[0], plt_lim[1] + 1) 184 | plt_ylim = [0, 1] 185 | 186 | plt.rc('axes', titlesize=25) 187 | plt.rc('axes', labelsize=25) 188 | 189 | labelsize = 20 190 | plt.figure(figsize=(15, 5)) 191 | 192 | plt.subplot(1, 3, 1) 193 | for method, name, color, ls in zip(methods, names, colors, linestyles): 194 | i_err, v_err, _ = errors[method] 195 | plt.plot(plt_rng, [(i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5) for thr in plt_rng], color=color, ls=ls, linewidth=3, label=name) 196 | plt.title('Overall') 197 | plt.xlim(plt_lim) 198 | plt.xticks(plt_rng) 199 | plt.ylabel('MMA') 200 | plt.ylim(plt_ylim) 201 | plt.grid() 202 | plt.tick_params(axis='both', which='major', labelsize=labelsize) 203 | # plt.legend() 204 | 205 | plt.subplot(1, 3, 2) 206 | for method, name, color, ls in zip(methods, names, colors, linestyles): 207 | i_err, v_err, _ = errors[method] 208 | plt.plot(plt_rng, [i_err[thr] / (n_i * 5) for thr in plt_rng], color=color, ls=ls, linewidth=3, label=name) 209 | plt.title('Illumination') 210 | plt.xlabel('threshold [px]') 211 | plt.xlim(plt_lim) 212 | plt.xticks(plt_rng) 213 | plt.ylim(plt_ylim) 214 | plt.gca().axes.set_yticklabels([]) 215 | plt.grid() 216 | plt.tick_params(axis='both', which='major', labelsize=labelsize) 217 | 218 | plt.subplot(1, 3, 3) 219 | for method, name, color, ls in zip(methods, names, colors, linestyles): 220 | i_err, v_err, _ = errors[method] 221 | plt.plot(plt_rng, [v_err[thr] / (n_v * 5) for thr in plt_rng], color=color, ls=ls, linewidth=3, label=name) 222 | plt.title('Viewpoint') 223 | plt.xlim(plt_lim) 224 | plt.xticks(plt_rng) 225 | plt.ylim(plt_ylim) 226 | plt.gca().axes.set_yticklabels([]) 227 | plt.grid() 228 | plt.tick_params(axis='both', which='major', labelsize=labelsize) 229 | 230 | import datetime 231 | timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M") 232 | 233 | if top_k is None: 234 | plt.savefig('hseq{}.pdf'.format(timestamp), bbox_inches='tight', dpi=300) 235 | plt.savefig('hseq{}.eps'.format(timestamp), bbox_inches='tight', dpi=300) 236 | else: 237 | plt.savefig('hseq-top.pdf', bbox_inches='tight', dpi=300) 238 | 239 | plt.legend() 240 | if top_k is None: 241 | plt.savefig('hseq{}_label.pdf'.format(timestamp), bbox_inches='tight', dpi=300) 242 | else: 243 | 
plt.savefig('hseq-top_label.pdf', bbox_inches='tight', dpi=300) 244 | 245 | with open('hseq{}.txt'.format(timestamp), 'w') as f: 246 | lines = '' 247 | for name, method in zip(names, methods): 248 | name = name.ljust(25, ' ') 249 | tmp_stat = errors[method][-1] 250 | seq_type, n_feats, n_matches = tmp_stat 251 | num_feat = np.mean(n_feats) 252 | num_match = np.sum(n_matches) / ((n_i + n_v) * 5) 253 | mmascore = MMAscore[method] 254 | lines += '{} & {:.1f} & {:.1f} & {:.3f} & {:.3f} & {:.3f}\n'.format( 255 | name, num_feat, num_match, mmascore[0], mmascore[1], mmascore[2]) 256 | 257 | f.write(lines) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decoupling Makes Weakly Supervised Local Feature Better (PoSFeat) 2 | This is the official implementation of **PoSFeat** (CVPR2022), a weakly supervised local feature training framework. 3 | 4 | **Decoupling Makes Weakly Supervised Local Feature Better**
5 | [Kunhong Li](https://scholar.google.co.uk/citations?user=_kzDdx8AAAAJ&hl=zh-CN&oi=ao), [Longguang Wang](https://longguangwang.github.io/), [Li Liu](http://lilyliliu.com/Default.aspx), [Qing Ran](https://scholar.google.co.uk/citations?user=6ydy5oEAAAAJ&hl=zh-CN&oi=ao), [Kai Xu](http://kevinkaixu.net/index.html), [Yulan Guo*](http://yulanguo.me/)
6 | **[[Paper](https://openaccess.thecvf.com/content/CVPR2022/html/Li_Decoupling_Makes_Weakly_Supervised_Local_Feature_Better_CVPR_2022_paper.html)] [[Arxiv](https://arxiv.org/abs/2201.02861)] [[Blog](https://zhuanlan.zhihu.com/p/477818450)] [[Bilibili](https://www.bilibili.com/video/BV1xg411R7wD?spm_id_from=333.337.search-card.all.click)] [[Youtube](https://www.youtube.com/watch?v=VnjdkAOIndc)]** 7 | 8 | ## Overview 9 | We decouple the description net training from the detection net training, and postpone the detection net training. This simple but effective framework allows us to detect robust keypoints based on the optimized descriptors. 10 |
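The description net backbone is the `ResUNet` defined in `networks/DescNet.py`; it outputs a coarse descriptor map at 1/16 resolution and a fine descriptor map at 1/4 resolution. The snippet below is only a minimal shape check of that backbone (the actual entry points are `train.py` and `extract.py` with the YAML configs; `pretrained=False` is used here just to skip the ImageNet weight download):
```
import torch
from networks.DescNet import ResUNet

# Build the descriptor backbone with its default channel widths and run a dummy image through it.
desc_net = ResUNet(encoder='resnet50', pretrained=False, coarse_out_ch=128, fine_out_ch=128)
desc_net.eval()
with torch.no_grad():
    out = desc_net(torch.randn(1, 3, 480, 640))
print(out['global_map'].shape)  # coarse map, 1/16 resolution: (1, 128, 30, 40)
print(out['local_map'].shape)   # fine map,   1/4  resolution: (1, 128, 120, 160)
```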

11 | 12 | ## Training 13 | **(1) Download training data** 14 | 15 | 下载[CAPS](https://github.com/qianqianwang68/caps)处理好的MegaDepth子集。注意,如果要参加[IMC](https://www.cs.ubc.ca/research/image-matching-challenge/current/),需要手动去掉一些场景(`0008 0021 0024 0063 1589`)。 16 | 17 | Down the preprocessed subset of MegaDepth from [CAPS](https://github.com/qianqianwang68/caps). If you want to test the local feature on [IMC](https://www.cs.ubc.ca/research/image-matching-challenge/current/), please manually remove the banned scenes (`0008 0021 0024 0063 1589`). 18 | 19 | 20 | **(2) Train the description net** 21 | 22 | 开始训练描述子网络之前,首先按照你自己的训练集路径去修改[config/train_desc.yaml](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/configs/train_desc.yaml)中`data_config_train`里的`data_path`。 23 | 24 | 因为某些原因(具体我也不知道为啥QAQ),我们现在使用的多卡训练极其慢,因此要先在终端设置单一GPU可见 25 | 26 | To start the description net training, please mannuly modify the `data_path` of `data_config_train` in [config/train_desc.yaml](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/configs/train_desc.yaml). 27 | 28 | Because of unknown reason, the multi-gpu training is really slow, so we should set single GPU available 29 | ``` 30 | export CUDA_VISIBLE_DEVICES=0 31 | ``` 32 | 33 | 随后开始跑训练 34 | 35 | Then run the following command 36 | ``` 37 | python train.py --config ./configs/train_desc.yaml 38 | ``` 39 | 40 | 描述子网络的训练大概需要24小时(24G显存的3090单卡) 41 | 42 | It takes about 24 hours to finish description net training on a single NVIDIA RTX3090 GPU. 43 | 44 | **(3) Train the detection net** 45 | 46 | 修改关键点配置文件中的`datapath`并且在终端设置单一GPU可见 47 | 48 | Similarly, modify the `datapath` and set single GPU available 49 | ``` 50 | export CUDA_VISIBLE_DEVICES=0 51 | ``` 52 | 53 | 随后开始跑训练 54 | 55 | And run the command 56 | ``` 57 | python train.py --config ./configs/train_kp.yaml 58 | ``` 59 | 60 | **(4) The difference between the results trained with this code repo and in the paper** 61 | 62 | 论文里我们用的`SGD`优化器和`lr=1e-3`,这各仓库中是`Adam`和`lr=1e-4`。注意`Adam`在`lr=1e-3`时可能无法收敛。 63 | 64 | In the paper, we use `SGD` optimizer with `lr=1e-3` to train the model, and here is the `Adam` with `lr=1e-4`. Note that, Adam with lr=1e-3 may not achieve convergence. 65 | 66 | **(5) Multi-GPU training** 67 | 68 | 我们使用pytorch的`DistributedDataParallel` API来实现单机多卡训练,但不知道为啥特别慢,所以都是禁掉了多GPU的。如果你实在需要多GPU训练,可能得自己修改一下代码,使用`DataParallel` API。 69 | 70 | In this code repo, we use the `DistributedDataParallel` API of pytorch to achieve multi-GPU training, which is slow because of unknown reason. If you really need multi-gpu training, please modify the codes to use `DataParallel` API. 71 | 72 | **(6) Visualization during training** 73 | 74 | 我们提供了可视化工具来监控训练进程,尤其是关键点检测训练的过程,损失函数的值是无法作为参考的,因此需要通过关键点得分图的情况来判断是否需要停止训练。可视化的结果包括了关键点得分图,关键点和原始匹配,所有结果都会存在checkpoint的路径中。注意在训练描述子时,可视化用到的关键点时sift(这个时候关键点检测网络还没训呢)。匹配用不同颜色代表了匹配的正确程度(绿色最优),但这个评价是依赖对极几何完成的,不一定完全正确。 75 | 76 | We also provide a visualization tool to give an intuition about the model performance during training. The results (including the heatmap, keypoints and raw matches) will be saved in the checkpoint path. 77 | The visualization results includes the scoremap of keypoints (meaningless for description net training), the keypoints (sift for description net training) and matches (we color the match line with epipolar constraint). 
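The criterion behind this coloring is the distance between a keypoint and the epipolar line induced by its matched partner, which is also what the weakly supervised reward in `losses/kploss.py` measures. A standalone NumPy sketch of the one-directional check (a simplified version of the batched computation in `constant_reward`) is:
```
import numpy as np

def epipolar_distance(pts1, pts2, F):
    # pts1, pts2: (N, 2) matched keypoints in image 1 / image 2.
    # F: 3x3 fundamental matrix mapping points of image 1 to epipolar lines in image 2 (l2 = F @ x1).
    pts1_h = np.concatenate([pts1, np.ones((len(pts1), 1))], axis=1)  # (N, 3) homogeneous points
    pts2_h = np.concatenate([pts2, np.ones((len(pts2), 1))], axis=1)  # (N, 3)
    lines = pts1_h @ F.T                                              # epipolar lines in image 2
    lines /= np.clip(np.linalg.norm(lines[:, :2], axis=1, keepdims=True), 1e-8, None)
    return np.abs(np.sum(lines * pts2_h, axis=1))                     # point-to-line distance in pixels
```
In the visualization, matches with a small distance are drawn in green (best). Since this check relies only on the epipolar geometry, it indicates how plausible a match is rather than giving an exact ground-truth correspondence.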
78 | 79 | **(7) Some dependencies** 80 | 81 | 其他的依赖库不做赘述,[path](https://path.readthedocs.io/en/latest/index.html)这个包因为有很多重名的所以单独列出来,请根据[readme on github](https://github.com/jaraco/path)或者[introduction on PyPI](https://pypi.org/project/path/)去安装path包。 82 | 83 | We depend on the [path](https://path.readthedocs.io/en/latest/index.html) package to manage the paths in this repo, please follow the [readme on github](https://github.com/jaraco/path) or [introduction on PyPI](https://pypi.org/project/path/) to install it. Users may be familiar with other dependencies, you can simply use `pip` and `conda` to install dependencies. 84 | 85 | ## Evaluation 86 | **(1) Feature extraction** 87 | 88 | 使用`extract.py`就可以提取PoSFeat特征,这个文件依赖于[managers/extractor.py](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/managers/extractor.py),使用者需要提供一个`.yaml`的配置文件,文件中需要包含datapath和detector config。输出的特征可以用`.npz`或`.h5`两种格式保存。 89 | 90 | 如果配置文件里`use_sift: True`,那么输出的关键点会使用sift而不是学习的关键点。这里的sift使用的是OpenCV的默认设置,提取过程在dataloader里面完成,直接包含在了inputs字典里。 91 | 92 | Using the `extract.py` can extract PoSFeat features. This file works with the [managers/extractor.py](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/managers/extractor.py), and users should provide a config file containing the datapath, detector config. The output can be `.npz` or `.h5`. 93 | 94 | With `use_sift: True` in the config file, the output would be the sift keypoint with PoSFeat descriptor. The SIFT keypoints are detected with the OpenCV default settings in the dataloader. 95 | 96 | 97 | **(2) HPatches** 98 | 99 | HPatches数据集的测试采用了[D2-Net](https://github.com/mihaidusmanu/d2-net/tree/master/hpatches_sequences)提出的方式。首先需要按照D2-Net介绍的方法取下载和处理数据集。为了方便,我们稍微修改了输入部分的代码,评测部分没有修改。评测的结果会以`.npy`文件的形式保存在[evaluations/hpatches/cache](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/tree/main/evaluations/hpatches/cache)中,这个文件夹里有一些现有方法的结果缓存。注意,一定要按D2-Net的要求去除原始数据集中的一些高分辨率的图像。 100 | 101 | We follow the evalutaion protocal proposed by [D2-Net](https://github.com/mihaidusmanu/d2-net/tree/master/hpatches_sequences) (please follow the introduction in D2-Net to download and modify the dataset), and modify the input codes for convenience. The result will be saved in [evaluations/hpatches/cache](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/tree/main/evaluations/hpatches/cache) as a `.npy` file, and we provide the results of several methods in the cache folder. Note that, you should mannuly remove the high resolution scenes in the original dataset. 
102 | 103 | Run the command 104 | ``` 105 | export CUDA_VISIBLE_DEVICES=0 106 | python extract.py --config ./configs/extract_hpatches.yaml 107 | ``` 108 | 109 | 提取完特征之后,跑评测,先将终端跳到[evaluations/hpatches](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/hpatches),修改一下脚本里包含的待评测方法(不改的话,只会有缓存的PoSFeat_CVPR结果) 110 | 111 | Then turn to the [evaluations/hpatches](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/hpatches) folder, modify the path in the evaluation script (if you donnot modify the script, there is only a PoSFeat_CVPR cache result) and run the script 112 | ``` 113 | cd ./evaluations/hpatches 114 | python evaluation.py 115 | ``` 116 | 117 | 完成评测后,[evaluations/hpatches](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/hpatches)中会有一个包含曲线的图像和一个包含数值结果的`.txt`文件 118 | 119 | When finishing the evaluation, you will get pictures of curves and a `.txt` file containing the quantitative results in the [evaluations/hpatches](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/hpatches) folder. 120 | 121 |
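For programmatic access, each cache file in `evaluations/hpatches/cache` stores the tuple produced by `benchmark_features`: the per-threshold accumulators `i_err` and `v_err` plus the per-sequence statistics. A minimal example of reading one entry is below (the sequence counts `n_i = 52` and `n_v = 56` correspond to the D2-Net split of HPatches with the high-resolution scenes removed; adjust them if your split differs):
```
import numpy as np

n_i, n_v = 52, 56  # illumination / viewpoint sequence counts in the D2-Net HPatches split
i_err, v_err, stats = np.load('evaluations/hpatches/cache/PoSFeat_CVPR.npy', allow_pickle=True)
seq_type, n_feats, n_matches = stats

print('mean #features per image:', n_feats.mean())
print('mean #matches per pair  :', n_matches.sum() / ((n_i + n_v) * 5))
print('MMA@3px overall         :', (i_err[3] + v_err[3]) / ((n_i + n_v) * 5))
```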

122 | 123 | **(3) Aachen-Day-Night** 124 | 125 | 这部分测试完全按照[The Visual Localization Benchmark](https://www.visuallocalization.net/)中standard Local feature challenge的[pipeline](https://github.com/tsattler/visuallocalizationbenchmark)来进行,因此按照pipeline的介绍,先下载数据集,然后按照以下的结构组织数据集 126 | 127 | We follow the standard Local feature challenge [pipeline](https://github.com/tsattler/visuallocalizationbenchmark) of [The Visual Localization Benchmark](https://www.visuallocalization.net/), please follow the introductions to download the dataset, then manage the data in this way 128 | ``` 129 | data_path_root_aachen 130 | ├── 3D-models/ 131 | │ ├── aachen_v_1/ 132 | │ │ ├── aachen_cvpr2018_db.nvm 133 | │ │ └── database_intrinsics.txt 134 | │ └── aachen_v_1_1/ 135 | │ ├── aachen_v_1_1.nvm 136 | │ ├── cameras.bin 137 | │ ├── database_intrinsics_v1_1.txt 138 | │ ├── images.bin 139 | │ ├── points3D.bin 140 | │ └── project.ini 141 | │ 142 | ├── images # the v1 data and v1.1 data are mixed in this folder 143 | │ └── images_upright/ 144 | │ ├── db/ 145 | │ ├── queries/ 146 | │ └── sequences/ 147 | │ 148 | ├── queries/ 149 | │ ├── day_time_queries_with_intrinsics.txt 150 | │ ├── night_time_queries_with_intrinsics.txt 151 | │ └── night_time_queries_with_intrinsics_v1_1.txt 152 | │ 153 | └── others/ 154 | ├── database.db 155 | ├── database_v1_1.db 156 | ├── image_pairs_to_match.txt 157 | └── image_pairs_to_match_v1_1.txt 158 | ``` 159 | 160 | 如果不想按照上述结构组织数据集,那么你需要手动的修改一下数据路径的设置([evauluations/aachen/reconstruct_pipeline.py](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/aachen/reconstruct_pipeline.py) (Line 329-339) and [evauluations/aachen/reconstruct_pipeline_v1_1.py](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/aachen/reconstruct_pipeline_v1_1.py) (Line 319-330)) 161 | 162 | If you do not want to manage the data, you should mannuly modify the datapath settings in [evauluations/aachen/reconstruct_pipeline.py](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/aachen/reconstruct_pipeline.py) (Line 329-339) and [evauluations/aachen/reconstruct_pipeline_v1_1.py](https://github.com/The-Learning-And-Vision-Atelier-LAVA/PoSFeat/blob/main/evaluations/aachen/reconstruct_pipeline_v1_1.py) (Line 319-330). 
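The reconstruction pipelines below match the extracted descriptors with the mutual nearest-neighbour matcher imported from `evaluations/aachen/matchers.py` before importing the match indices into the COLMAP database. That file is not shown in this listing; a functionally similar sketch (not the repository implementation, and assuming L2-normalized descriptors so that the dot product acts as a cosine similarity) looks like this:
```
import torch

def mutual_nn_match(desc1, desc2):
    # desc1: (N1, D), desc2: (N2, D) descriptor tensors on the same device.
    sim = desc1 @ desc2.t()                       # similarity matrix
    nn12 = sim.max(dim=1).indices                 # best match in image 2 for every point of image 1
    nn21 = sim.max(dim=0).indices                 # best match in image 1 for every point of image 2
    ids1 = torch.arange(desc1.shape[0], device=desc1.device)
    mutual = nn21[nn12] == ids1                   # keep only pairs that agree in both directions
    return torch.stack([ids1[mutual], nn12[mutual]], dim=1).cpu().numpy()
```
The returned `(K, 2)` index pairs are what `match_features` writes into the `matches` table of the database before geometric verification.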
163 | 164 | 测试之前,还是先提特征 165 | 166 | Before evaluation, we should extract the features first, 167 | ``` 168 | export CUDA_VISIBLE_DEVICES=0 169 | python extract.py --config ./configs/extract_aachen.yaml 170 | ``` 171 | 172 | 在aachen-v1上测试,用下列代码 173 | 174 | For evaulation on aachen-v1, run the command 175 | ``` 176 | cd ./evaluations/aachen 177 | python reconstruct_pipeline.py --dataset_path [YOUR_data_path_root_aachen] \ 178 | --feature_path ../../ckpts/aachen/PoSFeat_mytrain/desc \ 179 | --colmap_path [YOUR_PATH_TO_COLMAP] \ 180 | --method_name PoSFeat_mytrain \ 181 | --match_list_path image_pairs_to_match.txt 182 | ``` 183 | 184 | 在aachen-v1.1上测试,用下列代码 185 | 186 | For evaulation on aachen-v1.1, run the command 187 | ``` 188 | cd ./evaluations/aachen 189 | python reconstruct_pipeline_v1_1.py --dataset_path [YOUR_data_path_root_aachen] \ 190 | --feature_path ../../ckpts/aachen/PoSFeat_mytrain/desc \ 191 | --colmap_path [YOUR_PATH_TO_COLMAP] \ 192 | --method_name PoSFeat_mytrain \ 193 | --match_list_path image_pairs_to_match_v1_1.txt 194 | ``` 195 | 196 | 测试完成后,数据集路径中会有两个额外的文件夹,`intermedia`里面是一些中间结果(比如database和稀疏点云),`results`里面是可以提交到benchmark的`.txt`文件。 197 | 198 | After evaluation, there will be 2 more folders created, `intermedia` contains intermediate results (such as sparse model and database) and `results` contains the `.txt` files that can be upload to the benchmark. 199 | 200 | 201 | Note that, because the pose estimation (image registration) is based on the results of reconstruction, the results may be different each time. 202 | 203 | **(4) ETH local feature benchmark** 204 | 205 | 按照[ETH local feature benchmark](https://github.com/ahojnnes/local-feature-evaluation) ([download instruction](https://github.com/ahojnnes/local-feature-evaluation/blob/master/INSTRUCTIONS.md))中的介绍下载数据集。数据集需要按照下列方式组织 206 | 207 | Download the dataset following the introduction in [ETH local feature benchmark](https://github.com/ahojnnes/local-feature-evaluation) ([download instruction](https://github.com/ahojnnes/local-feature-evaluation/blob/master/INSTRUCTIONS.md)). Manage the dataset in this way 208 | ``` 209 | data_path_root_ETH_LFB 210 | ├── Alamo/ 211 | │ ├── images/ 212 | │ │ └── ... 213 | │ └── database.db 214 | │ 215 | ├── ArtsQuad_dataset/ 216 | │ ├── images/ 217 | │ │ └── ... 218 | │ └── database.db 219 | │ 220 | ├── Fountain/ 221 | │ ├── images/ 222 | │ │ └── ... 223 | │ └── database.db 224 | │ 225 | └── ... 
226 | ``` 227 | 还是先提特征,我们按不同的场景完成特征提取(手动修改配置文件里的`subfolder`) 228 | 229 | Extract features first, we extract features for different scenes individually (mannuly modify the subfolder) 230 | ``` 231 | export CUDA_VISIBLE_DEVICES=0 232 | python extract.py --config ./configs/extract_ETH.yaml 233 | ``` 234 | 然后对不同场景进行测试 235 | 236 | Then run evaluation for the scene 237 | ``` 238 | cd ./evaluations/ETH_local_feature 239 | python reconstruction_pipeline.py --config ../../configs/extract_ETH.yaml 240 | ``` 241 | 242 | ## BibeTeX 243 | 244 | If you use this code in your project, please cite the following paper 245 | ``` 246 | @InProceedings{li2022decoupling, 247 | title={Decoupling Makes Weakly Supervised Local Feature Better}, 248 | author={Li, Kunhong and Wang, Longguang and Liu, Li and Ran, Qing and Xu, Kai and Guo, Yulan}, 249 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 250 | month = {June}, 251 | year = {2022}, 252 | pages = {15838-15848} 253 | } 254 | ``` -------------------------------------------------------------------------------- /evaluations/aachen/reconstruct_pipeline_v1_1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import os 6 | 7 | import shutil 8 | 9 | import subprocess 10 | 11 | import sqlite3 12 | 13 | import torch 14 | 15 | import types 16 | 17 | from tqdm import tqdm 18 | 19 | from matchers import mutual_nn_matcher 20 | 21 | from camera import Camera 22 | 23 | from utils import quaternion_to_rotation_matrix, camera_center_to_translation 24 | 25 | from path import Path 26 | 27 | import sys 28 | IS_PYTHON3 = sys.version_info[0] >= 3 29 | 30 | def array_to_blob(array): 31 | if IS_PYTHON3: 32 | return array.tostring() 33 | else: 34 | return np.getbuffer(array) 35 | 36 | def recover_database_images_and_ids(paths, args): 37 | # Connect to the database. 38 | connection = sqlite3.connect(paths.database_path) 39 | cursor = connection.cursor() 40 | 41 | # Recover database images and ids. 42 | images = {} 43 | cameras = {} 44 | cursor.execute("SELECT name, image_id, camera_id FROM images;") 45 | for row in cursor: 46 | images[row[0]] = row[1] 47 | cameras[row[0]] = row[2] 48 | 49 | # Close the connection to the database. 50 | cursor.close() 51 | connection.close() 52 | 53 | return images, cameras 54 | 55 | 56 | def preprocess_reference_model(paths, args): 57 | print('Preprocessing the reference model...') 58 | 59 | # Recover intrinsics. 60 | with open(os.path.join(paths.reference_model_path, 'database_intrinsics_v1_1.txt')) as f: 61 | raw_intrinsics = f.readlines() 62 | 63 | camera_parameters = {} 64 | 65 | for intrinsics in raw_intrinsics: 66 | intrinsics = intrinsics.strip('\n').split(' ') 67 | 68 | image_name = intrinsics[0] 69 | 70 | camera_model = intrinsics[1] 71 | 72 | intrinsics = [float(param) for param in intrinsics[2 :]] 73 | 74 | camera = Camera() 75 | camera.set_intrinsics(camera_model=camera_model, intrinsics=intrinsics) 76 | 77 | camera_parameters[image_name] = camera 78 | 79 | # Recover poses. 80 | with open(os.path.join(paths.reference_model_path, 'aachen_v_1_1.nvm')) as f: 81 | raw_extrinsics = f.readlines() 82 | 83 | # Skip the header. 84 | n_cameras = int(raw_extrinsics[2]) 85 | raw_extrinsics = raw_extrinsics[3 : 3 + n_cameras] 86 | 87 | for extrinsics in raw_extrinsics: 88 | extrinsics = extrinsics.strip('\n').split(' ') 89 | 90 | image_name = extrinsics[0] 91 | 92 | # Skip the focal length. 
Skip the distortion and terminal 0. 93 | qw, qx, qy, qz, cx, cy, cz = [float(param) for param in extrinsics[2 : -2]] 94 | 95 | qvec = np.array([qw, qx, qy, qz]) 96 | c = np.array([cx, cy, cz]) 97 | 98 | # NVM -> COLMAP. 99 | t = camera_center_to_translation(c, qvec) 100 | 101 | camera_parameters[image_name].set_pose(qvec=qvec, t=t) 102 | 103 | return camera_parameters 104 | 105 | 106 | def generate_empty_reconstruction(images, cameras, camera_parameters, paths, args): 107 | print('Generating the empty reconstruction...') 108 | 109 | if not os.path.exists(paths.empty_model_path): 110 | os.mkdir(paths.empty_model_path) 111 | 112 | with open(os.path.join(paths.empty_model_path, 'cameras.txt'), 'w') as f: 113 | for image_name in images: 114 | image_id = images[image_name] 115 | camera_id = cameras[image_name] 116 | try: 117 | camera = camera_parameters[image_name] 118 | except: 119 | continue 120 | f.write('%d %s %s\n' % ( 121 | camera_id, 122 | camera.camera_model, 123 | ' '.join(map(str, camera.intrinsics)) 124 | )) 125 | 126 | with open(os.path.join(paths.empty_model_path, 'images.txt'), 'w') as f: 127 | for image_name in images: 128 | image_id = images[image_name] 129 | camera_id = cameras[image_name] 130 | try: 131 | camera = camera_parameters[image_name] 132 | except: 133 | continue 134 | f.write('%d %s %s %d %s\n\n' % ( 135 | image_id, 136 | ' '.join(map(str, camera.qvec)), 137 | ' '.join(map(str, camera.t)), 138 | camera_id, 139 | image_name 140 | )) 141 | 142 | with open(os.path.join(paths.empty_model_path, 'points3D.txt'), 'w') as f: 143 | pass 144 | 145 | 146 | def import_features(images, paths, args): 147 | # Connect to the database. 148 | connection = sqlite3.connect(paths.database_path) 149 | cursor = connection.cursor() 150 | 151 | # Import the features. 152 | print('Importing features...') 153 | 154 | for image_name, image_id in tqdm(images.items(), total=len(images.items())): 155 | features_path = os.path.join(paths.features_path, '%s.%s' % (image_name, args.method_name)) 156 | 157 | keypoints = np.load(features_path)['keypoints'] 158 | n_keypoints = keypoints.shape[0] 159 | 160 | # Keep only x, y coordinates. 161 | keypoints = keypoints[:, : 2] 162 | # Add placeholder scale, orientation. 163 | keypoints = np.concatenate([keypoints, np.ones((n_keypoints, 1)), np.zeros((n_keypoints, 1))], axis=1).astype(np.float32) 164 | 165 | keypoints_str = keypoints.tostring() 166 | cursor.execute("INSERT INTO keypoints(image_id, rows, cols, data) VALUES(?, ?, ?, ?);", 167 | (image_id, keypoints.shape[0], keypoints.shape[1], keypoints_str)) 168 | connection.commit() 169 | 170 | # Close the connection to the database. 171 | cursor.close() 172 | connection.close() 173 | 174 | 175 | def image_ids_to_pair_id(image_id1, image_id2): 176 | if image_id1 > image_id2: 177 | return 2147483647 * image_id2 + image_id1 178 | else: 179 | return 2147483647 * image_id1 + image_id2 180 | 181 | 182 | def match_features(images, paths, args): 183 | # Connect to the database. 184 | connection = sqlite3.connect(paths.database_path) 185 | cursor = connection.cursor() 186 | 187 | # Match the features and insert the matches in the database. 
188 | print('Matching...') 189 | 190 | with open(paths.match_list_path, 'r') as f: 191 | raw_pairs = f.readlines() 192 | 193 | image_pair_ids = set() 194 | for raw_pair in tqdm(raw_pairs, total=len(raw_pairs)): 195 | image_name1, image_name2 = raw_pair.strip('\n').split(' ') 196 | 197 | features_path1 = os.path.join(paths.features_path, '%s.%s' % (image_name1, args.method_name)) 198 | features_path2 = os.path.join(paths.features_path, '%s.%s' % (image_name2, args.method_name)) 199 | 200 | descriptors1 = torch.from_numpy(np.load(features_path1)['descriptors']).to(device) 201 | descriptors2 = torch.from_numpy(np.load(features_path2)['descriptors']).to(device) 202 | matches = mutual_nn_matcher(descriptors1, descriptors2).astype(np.uint32) 203 | 204 | image_id1, image_id2 = images[image_name1], images[image_name2] 205 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2) 206 | if image_pair_id in image_pair_ids: 207 | continue 208 | image_pair_ids.add(image_pair_id) 209 | 210 | if image_id1 > image_id2: 211 | matches = matches[:, [1, 0]] 212 | 213 | matches_str = matches.tostring() 214 | cursor.execute("INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);", 215 | (image_pair_id, matches.shape[0], matches.shape[1], matches_str)) 216 | connection.commit() 217 | 218 | # Close the connection to the database. 219 | cursor.close() 220 | connection.close() 221 | 222 | 223 | def geometric_verification(paths, args): 224 | print('Running geometric verification...') 225 | 226 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'matches_importer', 227 | '--database_path', paths.database_path, 228 | '--match_list_path', paths.match_list_path, 229 | '--match_type', 'pairs']) 230 | 231 | 232 | def reconstruct(paths, args): 233 | if not os.path.isdir(paths.database_model_path): 234 | os.mkdir(paths.database_model_path) 235 | 236 | # Reconstruct the database model. 237 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'point_triangulator', 238 | '--database_path', paths.database_path, 239 | '--image_path', paths.image_path, 240 | '--input_path', paths.empty_model_path, 241 | '--output_path', paths.database_model_path, 242 | '--Mapper.ba_refine_focal_length', '0', 243 | '--Mapper.ba_refine_principal_point', '0', 244 | '--Mapper.ba_refine_extra_params', '0']) 245 | 246 | 247 | def register_queries(paths, args): 248 | if not os.path.isdir(paths.final_model_path): 249 | os.mkdir(paths.final_model_path) 250 | 251 | # Register the query images. 252 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'image_registrator', 253 | '--database_path', paths.database_path, 254 | '--input_path', paths.database_model_path, 255 | '--output_path', paths.final_model_path, 256 | '--Mapper.ba_refine_focal_length', '0', 257 | '--Mapper.ba_refine_principal_point', '0', 258 | '--Mapper.ba_refine_extra_params', '0']) 259 | 260 | 261 | def recover_query_poses(paths, args): 262 | print('Recovering query poses...') 263 | 264 | if not os.path.isdir(paths.final_txt_model_path): 265 | os.mkdir(paths.final_txt_model_path) 266 | 267 | # Convert the model to TXT. 268 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'model_converter', 269 | '--input_path', paths.final_model_path, 270 | '--output_path', paths.final_txt_model_path, 271 | '--output_type', 'TXT']) 272 | 273 | # Recover query names. 
274 | query_image_list_path = os.path.join(args.dataset_path, 'queries/night_time_queries_with_intrinsics_v1_1.txt') 275 | 276 | with open(query_image_list_path) as f: 277 | raw_queries = f.readlines() 278 | 279 | query_names = set() 280 | for raw_query in raw_queries: 281 | raw_query = raw_query.strip('\n').split(' ') 282 | query_name = raw_query[0] 283 | query_names.add(query_name) 284 | 285 | with open(os.path.join(paths.final_txt_model_path, 'images.txt')) as f: 286 | raw_extrinsics = f.readlines() 287 | 288 | if not paths.prediction_path.parent.exists(): 289 | paths.prediction_path.parent.makedirs_p() 290 | f = open(paths.prediction_path, 'w') 291 | 292 | # Skip the header. 293 | for extrinsics in raw_extrinsics[4 :: 2]: 294 | extrinsics = extrinsics.strip('\n').split(' ') 295 | 296 | image_name = extrinsics[-1] 297 | 298 | if image_name in query_names: 299 | # Skip the IMAGE_ID ([0]), CAMERA_ID ([-2]), and IMAGE_NAME ([-1]). 300 | f.write('%s %s\n' % (image_name.split('/')[-1], ' '.join(extrinsics[1 : -2]))) 301 | 302 | f.close() 303 | 304 | 305 | if __name__ == "__main__": 306 | parser = argparse.ArgumentParser() 307 | parser.add_argument('--dataset_path', required=True, help='Path to the dataset') 308 | parser.add_argument('--feature_path', required=True, help='Path to the features') 309 | parser.add_argument('--colmap_path', required=True, help='Path to the COLMAP executable folder') 310 | parser.add_argument('--method_name', required=True, help='Name of the method') 311 | parser.add_argument('--match_list_path', type=str, default='image_pairs_to_match.txt', help='config of reconstruct and register') 312 | args = parser.parse_args() 313 | 314 | # Torch settings for the matcher. 315 | use_cuda = torch.cuda.is_available() 316 | device = torch.device("cuda:0" if use_cuda else "cpu") 317 | 318 | # Create the extra paths. 319 | paths = types.SimpleNamespace() 320 | paths.dummy_database_path = Path(args.dataset_path)/'others/database_v1_1.db' 321 | paths.database_path = Path(args.dataset_path)/'intermedia/{}/{}_v1_1.db'.format(args.method_name,args.method_name) 322 | paths.image_path = Path(args.dataset_path)/'images/images_upright' 323 | paths.features_path = Path(args.feature_path) 324 | paths.reference_model_path = Path(args.dataset_path)/'3D-models/aachen_v_1_1' 325 | paths.match_list_path = Path(args.dataset_path)/'others/{}'.format(args.match_list_path) 326 | paths.empty_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-empty-v_1_1'.format(args.method_name, args.method_name) 327 | paths.database_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-database-v_1_1'.format(args.method_name, args.method_name) 328 | paths.final_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-final-v_1_1'.format(args.method_name, args.method_name) 329 | paths.final_txt_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-final-txt-v_1_1'.format(args.method_name, args.method_name) 330 | paths.prediction_path = Path(args.dataset_path)/'results/Aachen_v1_1_eval_[{}].txt'.format(args.method_name) 331 | 332 | # Create a copy of the dummy database. 333 | if os.path.exists(paths.database_path): 334 | raise FileExistsError('The database file already exists for method %s.' % args.method_name) 335 | if not paths.database_path.parent.exists(): 336 | paths.database_path.parent.makedirs_p() 337 | shutil.copyfile(paths.dummy_database_path, paths.database_path) 338 | 339 | # Reconstruction pipeline. 
340 | camera_parameters = preprocess_reference_model(paths, args) 341 | images, cameras = recover_database_images_and_ids(paths, args) 342 | generate_empty_reconstruction(images, cameras, camera_parameters, paths, args) 343 | import_features(images, paths, args) 344 | match_features(images, paths, args) 345 | geometric_verification(paths, args) 346 | reconstruct(paths, args) 347 | register_queries(paths, args) 348 | recover_query_poses(paths, args) -------------------------------------------------------------------------------- /evaluations/aachen/reconstruct_pipeline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | from path import Path 6 | 7 | import os 8 | 9 | import shutil 10 | 11 | import subprocess 12 | 13 | import sqlite3 14 | 15 | import torch 16 | 17 | import types 18 | 19 | from tqdm import tqdm 20 | 21 | from matchers import mutual_nn_matcher 22 | 23 | from camera import Camera 24 | 25 | from utils import quaternion_to_rotation_matrix, camera_center_to_translation 26 | 27 | import sys 28 | IS_PYTHON3 = sys.version_info[0] >= 3 29 | 30 | def array_to_blob(array): 31 | if IS_PYTHON3: 32 | return array.tostring() 33 | else: 34 | return np.getbuffer(array) 35 | 36 | def recover_database_images_and_ids(paths, args): 37 | # Connect to the database. 38 | connection = sqlite3.connect(paths.database_path) 39 | cursor = connection.cursor() 40 | 41 | # Recover database images and ids. 42 | images = {} 43 | cameras = {} 44 | cursor.execute("SELECT name, image_id, camera_id FROM images;") 45 | for row in cursor: 46 | images[row[0]] = row[1] 47 | cameras[row[0]] = row[2] 48 | 49 | # Close the connection to the database. 50 | cursor.close() 51 | connection.close() 52 | 53 | return images, cameras 54 | 55 | 56 | def preprocess_reference_model(paths, args): 57 | print('Preprocessing the reference model...') 58 | 59 | # Recover intrinsics. 60 | with open(paths.reference_model_path/'database_intrinsics.txt') as f: 61 | raw_intrinsics = f.readlines() 62 | 63 | camera_parameters = {} 64 | 65 | for intrinsics in raw_intrinsics: 66 | intrinsics = intrinsics.strip('\n').split(' ') 67 | 68 | image_name = intrinsics[0] 69 | 70 | camera_model = intrinsics[1] 71 | 72 | intrinsics = [float(param) for param in intrinsics[2 :]] 73 | 74 | camera = Camera() 75 | camera.set_intrinsics(camera_model=camera_model, intrinsics=intrinsics) 76 | 77 | camera_parameters[image_name] = camera 78 | 79 | # Recover poses. 80 | with open(paths.reference_model_path/'aachen_cvpr2018_db.nvm') as f: 81 | raw_extrinsics = f.readlines() 82 | 83 | # Skip the header. 84 | n_cameras = int(raw_extrinsics[2]) 85 | raw_extrinsics = raw_extrinsics[3 : 3 + n_cameras] 86 | 87 | for extrinsics in raw_extrinsics: 88 | extrinsics = extrinsics.strip('\n').split(' ') 89 | 90 | image_name = extrinsics[0] 91 | 92 | # Skip the focal length. Skip the distortion and terminal 0. 93 | qw, qx, qy, qz, cx, cy, cz = [float(param) for param in extrinsics[2 : -2]] 94 | 95 | qvec = np.array([qw, qx, qy, qz]) 96 | c = np.array([cx, cy, cz]) 97 | 98 | # NVM -> COLMAP. 
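# NVM stores the rotation as a quaternion together with the camera centre c, while
# COLMAP expects the translation of the world-to-camera transform. The helper
# camera_center_to_translation() (defined in utils.py, not shown in this listing)
# presumably computes t = -R(qvec) @ c via quaternion_to_rotation_matrix().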
99 | t = camera_center_to_translation(c, qvec) 100 | 101 | camera_parameters[image_name].set_pose(qvec=qvec, t=t) 102 | 103 | return camera_parameters 104 | 105 | 106 | def generate_empty_reconstruction(images, cameras, camera_parameters, paths, args): 107 | print('Generating the empty reconstruction...') 108 | 109 | if not paths.empty_model_path.exists(): 110 | paths.empty_model_path.makedirs_p() 111 | 112 | with open(paths.empty_model_path/'cameras.txt', 'w') as f: 113 | for image_name in images: 114 | image_id = images[image_name] 115 | camera_id = cameras[image_name] 116 | try: 117 | camera = camera_parameters[image_name] 118 | except: 119 | continue 120 | f.write('%d %s %s\n' % ( 121 | camera_id, 122 | camera.camera_model, 123 | ' '.join(map(str, camera.intrinsics)) 124 | )) 125 | 126 | with open(paths.empty_model_path/'images.txt', 'w') as f: 127 | for image_name in images: 128 | image_id = images[image_name] 129 | camera_id = cameras[image_name] 130 | try: 131 | camera = camera_parameters[image_name] 132 | except: 133 | continue 134 | f.write('%d %s %s %d %s\n\n' % ( 135 | image_id, 136 | ' '.join(map(str, camera.qvec)), 137 | ' '.join(map(str, camera.t)), 138 | camera_id, 139 | image_name 140 | )) 141 | 142 | with open(paths.empty_model_path/'points3D.txt', 'w') as f: 143 | pass 144 | 145 | 146 | def import_features(images, paths, args): 147 | # Connect to the database. 148 | connection = sqlite3.connect(paths.database_path) 149 | cursor = connection.cursor() 150 | 151 | # Import the features. 152 | print('Importing features...') 153 | 154 | for image_name, image_id in tqdm(images.items(), total=len(images.items())): 155 | features_path = paths.features_path/'{}.{}'.format(image_name, args.method_name) 156 | 157 | keypoints = np.load(features_path)['keypoints'] 158 | n_keypoints = keypoints.shape[0] 159 | 160 | # Keep only x, y coordinates. 161 | keypoints = keypoints[:, : 2] 162 | # Add placeholder scale, orientation. 163 | keypoints = np.concatenate([keypoints, np.ones((n_keypoints, 1)), np.zeros((n_keypoints, 1))], axis=1).astype(np.float32) 164 | 165 | keypoints_str = keypoints.tostring() 166 | cursor.execute("INSERT INTO keypoints(image_id, rows, cols, data) VALUES(?, ?, ?, ?);", 167 | (image_id, keypoints.shape[0], keypoints.shape[1], keypoints_str)) 168 | connection.commit() 169 | 170 | # Close the connection to the database. 171 | cursor.close() 172 | connection.close() 173 | 174 | 175 | def image_ids_to_pair_id(image_id1, image_id2): 176 | if image_id1 > image_id2: 177 | return 2147483647 * image_id2 + image_id1 178 | else: 179 | return 2147483647 * image_id1 + image_id2 180 | 181 | 182 | def match_features(images, paths, args): 183 | # Connect to the database. 184 | connection = sqlite3.connect(paths.database_path) 185 | cursor = connection.cursor() 186 | 187 | # Match the features and insert the matches in the database. 
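# COLMAP addresses an image pair with a single integer,
#   pair_id = 2147483647 * min(id1, id2) + max(id1, id2)
# (see image_ids_to_pair_id above), and expects the match blob as a row-major
# uint32 array of shape (num_matches, 2) holding keypoint indices into the first
# and second image of the sorted pair -- hence the column swap below whenever
# image_id1 > image_id2.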
188 | print('Matching...') 189 | 190 | with open(paths.match_list_path, 'r') as f: 191 | raw_pairs = f.readlines() 192 | 193 | image_pair_ids = set() 194 | bar = tqdm(raw_pairs, total=len(raw_pairs)) 195 | for raw_pair in bar: 196 | image_name1, image_name2 = raw_pair.strip('\n').split(' ') 197 | 198 | features_path1 = paths.features_path/'{}.{}'.format(image_name1, args.method_name) 199 | features_path2 = paths.features_path/'{}.{}'.format(image_name2, args.method_name) 200 | 201 | descriptors1 = torch.from_numpy(np.load(features_path1)['descriptors']).to(device) 202 | descriptors2 = torch.from_numpy(np.load(features_path2)['descriptors']).to(device) 203 | matches = mutual_nn_matcher(descriptors1, descriptors2).astype(np.uint32) 204 | 205 | image_id1, image_id2 = images[image_name1], images[image_name2] 206 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2) 207 | if image_pair_id in image_pair_ids: 208 | continue 209 | image_pair_ids.add(image_pair_id) 210 | 211 | if image_id1 > image_id2: 212 | matches = matches[:, [1, 0]] 213 | 214 | matches_str = np.int32(matches).tostring() 215 | cursor.execute("INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);", 216 | (image_pair_id, matches.shape[0], matches.shape[1], matches_str)) 217 | connection.commit() 218 | 219 | # Close the connection to the database. 220 | cursor.close() 221 | connection.close() 222 | 223 | 224 | def geometric_verification(paths, args): 225 | print('Running geometric verification...') 226 | 227 | subprocess.call([Path(args.colmap_path)/'colmap', 'matches_importer', 228 | '--database_path', paths.database_path, 229 | '--match_list_path', paths.match_list_path, 230 | '--match_type', 'pairs']) 231 | 232 | 233 | def reconstruct(paths, args): 234 | if not paths.database_model_path.isdir(): 235 | paths.database_model_path.makedirs_p() 236 | 237 | # Reconstruct the database model. 238 | subprocess.call([Path(args.colmap_path)/'colmap', 'point_triangulator', 239 | '--database_path', paths.database_path, 240 | '--image_path', paths.image_path, 241 | '--input_path', paths.empty_model_path, 242 | '--output_path', paths.database_model_path, 243 | '--Mapper.ba_refine_focal_length', '0', 244 | '--Mapper.ba_refine_principal_point', '0', 245 | '--Mapper.ba_refine_extra_params', '0',]) 246 | 247 | 248 | def register_queries(paths, args): 249 | if not paths.final_model_path.isdir(): 250 | paths.final_model_path.makedirs_p() 251 | 252 | # Register the query images. 253 | subprocess.call([Path(args.colmap_path)/'colmap', 'image_registrator', 254 | '--database_path', paths.database_path, 255 | '--input_path', paths.database_model_path, 256 | '--output_path', paths.final_model_path, 257 | '--Mapper.ba_refine_focal_length', '0', 258 | '--Mapper.ba_refine_principal_point', '0', 259 | '--Mapper.ba_refine_extra_params', '0']) 260 | 261 | 262 | def recover_query_poses(paths, args): 263 | print('Recovering query poses...') 264 | 265 | if not paths.final_txt_model_path.isdir(): 266 | paths.final_txt_model_path.makedirs_p() 267 | 268 | # Convert the model to TXT. 269 | subprocess.call([Path(args.colmap_path)/'colmap', 'model_converter', 270 | '--input_path', paths.final_model_path, 271 | '--output_path', paths.final_txt_model_path, 272 | '--output_type', 'TXT']) 273 | 274 | # Recover query names. 
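# With the default --match_list_path only the night-time queries are evaluated;
# for any other pair list the day-time query names are appended as well, so that
# the poses of both query sets are written to the prediction file below.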
275 | if args.match_list_path == 'image_pairs_to_match.txt': 276 | query_image_list_path = Path(args.dataset_path)/'queries/night_time_queries_with_intrinsics.txt' 277 | with open(query_image_list_path) as f: 278 | raw_queries = f.readlines() 279 | else: 280 | query_image_list_path = Path(args.dataset_path)/'queries/night_time_queries_with_intrinsics.txt' 281 | with open(query_image_list_path) as f: 282 | raw_queries = f.readlines() 283 | query_image_list_path = Path(args.dataset_path)/'queries/day_time_queries_with_intrinsics.txt' 284 | with open(query_image_list_path) as f: 285 | tmp = f.readlines() 286 | raw_queries.extend(tmp) 287 | 288 | query_names = set() 289 | for raw_query in raw_queries: 290 | raw_query = raw_query.strip('\n').split(' ') 291 | query_name = raw_query[0] 292 | query_names.add(query_name) 293 | 294 | with open(os.path.join(paths.final_txt_model_path, 'images.txt')) as f: 295 | raw_extrinsics = f.readlines() 296 | 297 | if not paths.prediction_path.parent.exists(): 298 | paths.prediction_path.parent.makedirs_p() 299 | f = open(paths.prediction_path, 'w') 300 | 301 | # Skip the header. 302 | for extrinsics in raw_extrinsics[4 :: 2]: 303 | extrinsics = extrinsics.strip('\n').split(' ') 304 | 305 | image_name = extrinsics[-1] 306 | 307 | if image_name in query_names: 308 | # Skip the IMAGE_ID ([0]), CAMERA_ID ([-2]), and IMAGE_NAME ([-1]). 309 | f.write('%s %s\n' % (image_name.split('/')[-1], ' '.join(extrinsics[1 : -2]))) 310 | 311 | f.close() 312 | 313 | 314 | if __name__ == "__main__": 315 | parser = argparse.ArgumentParser() 316 | parser.add_argument('--dataset_path', required=True, help='Path to the dataset') 317 | parser.add_argument('--feature_path', required=True, help='Path to the features') 318 | parser.add_argument('--colmap_path', required=True, help='Path to the COLMAP executable folder') 319 | parser.add_argument('--method_name', required=True, help='Name of the method') 320 | parser.add_argument('--match_list_path', type=str, default='image_pairs_to_match.txt', help='config of reconstruct and register') 321 | args = parser.parse_args() 322 | 323 | # Torch settings for the matcher. 324 | use_cuda = torch.cuda.is_available() 325 | device = torch.device("cuda:0" if use_cuda else "cpu") 326 | 327 | # Create the extra paths. 328 | paths = types.SimpleNamespace() 329 | paths.dummy_database_path = Path(args.dataset_path)/'others/database.db' 330 | paths.database_path = Path(args.dataset_path)/'intermedia/{}/{}.db'.format(args.method_name,args.method_name) 331 | paths.image_path = Path(args.dataset_path)/'images/images_upright' 332 | paths.features_path = Path(args.feature_path) 333 | paths.reference_model_path = Path(args.dataset_path)/'3D-models/aachen_v_1' 334 | paths.match_list_path = Path(args.dataset_path)/'others/{}'.format(args.match_list_path) 335 | paths.empty_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-empty'.format(args.method_name, args.method_name) 336 | paths.database_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-database'.format(args.method_name, args.method_name) 337 | paths.final_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-final'.format(args.method_name, args.method_name) 338 | paths.final_txt_model_path = Path(args.dataset_path)/'intermedia/{}/sparse-{}-final-txt'.format(args.method_name, args.method_name) 339 | paths.prediction_path = Path(args.dataset_path)/'results/Aachen_eval_[{}].txt'.format(args.method_name) 340 | 341 | # Create a copy of the dummy database. 
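# The dummy database shipped with the benchmark already contains the image and
# camera rows (names, ids, intrinsics); it is copied once per method so that the
# keypoints and matches imported later never modify the original file and results
# from different methods cannot clash.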
342 | if paths.database_path.exists(): 343 | raise FileExistsError('The database file already exists for method %s.' % args.method_name) 344 | if not paths.database_path.parent.exists(): 345 | paths.database_path.parent.makedirs_p() 346 | shutil.copyfile(paths.dummy_database_path, paths.database_path) 347 | 348 | # Reconstruction pipeline. 349 | camera_parameters = preprocess_reference_model(paths, args) 350 | images, cameras = recover_database_images_and_ids(paths, args) 351 | generate_empty_reconstruction(images, cameras, camera_parameters, paths, args) 352 | import_features(images, paths, args) 353 | match_features(images, paths, args) 354 | geometric_verification(paths, args) 355 | reconstruct(paths, args) 356 | register_queries(paths, args) 357 | recover_query_poses(paths, args) -------------------------------------------------------------------------------- /evaluations/ETH_local_feature/reconstruction_pipeline.py: -------------------------------------------------------------------------------- 1 | # Import the features and matches into a COLMAP database. 2 | # 3 | # Copyright 2017: Johannes L. Schoenberger 4 | 5 | from __future__ import print_function, division 6 | 7 | import os 8 | import sys 9 | import glob 10 | import yaml 11 | import types 12 | import torch 13 | import shutil 14 | import sqlite3 15 | import argparse 16 | import subprocess 17 | import multiprocessing 18 | 19 | import numpy as np 20 | from path import Path 21 | from tqdm import tqdm 22 | import custom_matcher as cms 23 | 24 | 25 | IS_PYTHON3 = sys.version_info[0] >= 3 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | # parser.add_argument("--dataset_path", required=True, 31 | # help="Path to the dataset, e.g., path/to/Fountain") 32 | # parser.add_argument("--colmap_path", required=True, 33 | # help="Path to the COLMAP executable folder, e.g., " 34 | # "path/to/colmap/build/src/exe") 35 | # parser.add_argument("--features_path", required=True, 36 | # help="Path to the features folder, e.g., " 37 | # "path/to/feature") 38 | # parser.add_argument("--method_postfix", required=True, 39 | # help="the postfix of the method") 40 | # parser.add_argument("--matcher", required=True, 41 | # help="the matcher") 42 | parser.add_argument("--config", required=True, 43 | help="Path to the configs, e.g., path/to/Fountain") 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def image_ids_to_pair_id(image_id1, image_id2): 49 | if image_id1 > image_id2: 50 | return 2147483647 * image_id2 + image_id1 51 | else: 52 | return 2147483647 * image_id1 + image_id2 53 | 54 | 55 | def import_features_and_match(configs, paths): 56 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 57 | 58 | connection = sqlite3.connect(paths.database_path) 59 | cursor = connection.cursor() 60 | 61 | cursor.execute("SELECT name FROM sqlite_master " 62 | "WHERE type='table' AND name='inlier_matches';") 63 | try: 64 | inlier_matches_table_exists = bool(next(cursor)[0]) 65 | except StopIteration: 66 | inlier_matches_table_exists = False 67 | 68 | cursor.execute("DELETE FROM keypoints;") 69 | cursor.execute("DELETE FROM descriptors;") 70 | cursor.execute("DELETE FROM matches;") 71 | if inlier_matches_table_exists: 72 | cursor.execute("DELETE FROM inlier_matches;") 73 | else: 74 | cursor.execute("DELETE FROM two_view_geometries;") 75 | connection.commit() 76 | 77 | images = {} 78 | cursor.execute("SELECT name, image_id FROM images;") 79 | for row in cursor: 80 | images[row[0]] = row[1] 81 | 82 | for 
image_name, image_id in tqdm(images.items(), total=len(images.items())): 83 | feature_path = paths.features_path/'{}.{}'.format(image_name, configs['method_postfix']) 84 | feature_file = np.load(feature_path) 85 | 86 | keypoints = feature_file['keypoints'][:,:2] 87 | descriptors = feature_file['descriptors'] 88 | assert keypoints.shape[1] == 2 89 | assert keypoints.shape[0] == descriptors.shape[0] 90 | 91 | keypoints_str = keypoints.tobytes() # early python3 use .tostring() 92 | cursor.execute("INSERT INTO keypoints(image_id, rows, cols, data) " 93 | "VALUES(?, ?, ?, ?);", 94 | (image_id, keypoints.shape[0], keypoints.shape[1], 95 | keypoints_str)) 96 | connection.commit() 97 | 98 | # custom match 99 | matcher = getattr(cms, configs['matcher']) 100 | image_names = list(images.keys()) 101 | image_pairs = [] 102 | image_pair_ids = set() 103 | for idx_total, image_name1 in enumerate(tqdm(image_names[:-1])): 104 | feature_path1 = paths.features_path/'{}.{}'.format(image_name1, configs['method_postfix']) 105 | descriptors1 = np.load(feature_path1)['descriptors'] 106 | descriptors1 = torch.from_numpy(descriptors1).to(device) 107 | bar = tqdm(image_names[idx_total+1:]) 108 | for idx_sub, image_name2 in enumerate(bar): 109 | image_pairs.append((image_name1, image_name2)) 110 | image_id1, image_id2 = images[image_name1], images[image_name2] 111 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2) 112 | if image_pair_id in image_pair_ids: 113 | continue 114 | 115 | feature_path2 = paths.features_path/'{}.{}'.format(image_name2, configs['method_postfix']) 116 | descriptors2 = np.load(feature_path2)['descriptors'] 117 | descriptors2 = torch.from_numpy(descriptors2).to(device) 118 | 119 | matches = matcher(descriptors1, descriptors2, **configs['matcher_config']) 120 | assert matches.shape[1] == 2 121 | # bar.write("matches: {}".format(matches.shape[0])) 122 | image_pair_ids.add(image_pair_id) 123 | if image_id1 > image_id2: 124 | matches = matches[:, [1, 0]] 125 | 126 | matches_str = np.int32(matches).tostring() 127 | cursor.execute("INSERT INTO matches(pair_id, rows, cols, data) " 128 | "VALUES(?, ?, ?, ?);", 129 | (image_pair_id, matches.shape[0], matches.shape[1], 130 | matches_str)) 131 | connection.commit() 132 | 133 | torch.cuda.empty_cache() 134 | with open(paths.match_list_path, 'w') as fid: 135 | for image_name1, image_name2 in image_pairs: 136 | fid.write("{} {}\n".format(image_name1, image_name2)) 137 | cursor.close() 138 | connection.close() 139 | 140 | subprocess.call([paths.colmap_path, 141 | "matches_importer", 142 | "--database_path", 143 | paths.database_path, 144 | "--match_list_path", 145 | paths.match_list_path, 146 | "--match_type", "pairs"]) 147 | 148 | # connection = sqlite3.connect(os.path.join(args.dataset_path, "database.db")) 149 | connection = sqlite3.connect(paths.database_path) 150 | cursor = connection.cursor() 151 | 152 | cursor.execute("SELECT count(*) FROM images;") 153 | num_images = next(cursor)[0] 154 | 155 | cursor.execute("SELECT count(*) FROM two_view_geometries WHERE rows > 0;") 156 | num_inlier_pairs = next(cursor)[0] 157 | 158 | cursor.execute("SELECT sum(rows) FROM two_view_geometries WHERE rows > 0;") 159 | num_inlier_matches = next(cursor)[0] 160 | 161 | cursor.close() 162 | connection.close() 163 | 164 | return dict(num_images=num_images, 165 | num_inlier_pairs=num_inlier_pairs, 166 | num_inlier_matches=num_inlier_matches) 167 | 168 | 169 | def reconstruct(configs, paths): 170 | database_path = paths.database_path 171 | image_path = 
paths.image_path 172 | sparse_path = paths.features_path.parent/"{}_sparse".format(configs['subfolder']) 173 | dense_path = paths.features_path.parent/"{}_dense".format(configs['subfolder']) 174 | if not sparse_path.exists(): 175 | sparse_path.makedirs_p() 176 | if not dense_path.exists(): 177 | dense_path.makedirs_p() 178 | 179 | # Run the sparse reconstruction. 180 | subprocess.call([paths.colmap_path, 181 | "mapper", 182 | "--database_path", database_path, 183 | "--image_path", image_path, 184 | "--output_path", sparse_path, 185 | "--Mapper.num_threads", 186 | str(min(multiprocessing.cpu_count(), 16))]) 187 | 188 | # Find the largest reconstructed sparse model. 189 | models = sparse_path.listdir() 190 | if len(models) == 0: 191 | print("Warning: Could not reconstruct any model") 192 | return 193 | 194 | largest_model = None 195 | largest_model_num_images = 0 196 | for model in models: 197 | subprocess.call([paths.colmap_path, 198 | "model_converter", 199 | "--input_path", model, 200 | "--output_path", model, 201 | "--output_type", "TXT"]) 202 | with open("{}/cameras.txt".format(model), 'r') as fid: 203 | for line in fid: 204 | if line.startswith("# Number of cameras"): 205 | num_images = int(line.split()[-1]) 206 | if num_images > largest_model_num_images: 207 | largest_model = model 208 | largest_model_num_images = num_images 209 | break 210 | assert largest_model_num_images > 0 211 | 212 | # Run the dense reconstruction. 213 | largest_model_path = largest_model 214 | ### the codes for dense reconstruction 215 | # workspace_path = dense_path/largest_model.name 216 | # if not workspace_path.exists(): 217 | # workspace_path.makedirs_p() 218 | 219 | # subprocess.call([paths.colmap_path, 220 | # "image_undistorter", 221 | # "--image_path", image_path, 222 | # "--input_path", largest_model_path, 223 | # "--output_path", workspace_path, 224 | # "--max_image_size", "1200"]) 225 | 226 | # subprocess.call([paths.colmap_path, 227 | # "patch_match_stereo", 228 | # "--workspace_path", workspace_path, 229 | # "--PatchMatchStereo.geom_consistency", "false"]) 230 | 231 | # subprocess.call([paths.colmap_path, 232 | # "stereo_fusion", 233 | # "--workspace_path", workspace_path, 234 | # "--input_type", "photometric", 235 | # "--output_path", os.path.join(workspace_path, "fused.ply"), 236 | # "--StereoFusion.min_num_pixels", "5"]) 237 | 238 | stats = subprocess.check_output( 239 | [paths.colmap_path, "model_analyzer", 240 | "--path", largest_model_path]) 241 | 242 | stats = stats.decode().split("\n") 243 | for stat in stats: 244 | if stat.startswith("Registered images"): 245 | num_reg_images = int(stat.split()[-1]) 246 | elif stat.startswith("Points"): 247 | num_sparse_points = int(stat.split()[-1]) 248 | elif stat.startswith("Observations"): 249 | num_observations = int(stat.split()[-1]) 250 | elif stat.startswith("Mean track length"): 251 | mean_track_length = float(stat.split()[-1]) 252 | elif stat.startswith("Mean observations per image"): 253 | num_observations_per_image = float(stat.split()[-1]) 254 | elif stat.startswith("Mean reprojection error"): 255 | mean_reproj_error = float(stat.split()[-1][:-2]) 256 | 257 | # returns with dense results 258 | # with open(os.path.join(workspace_path, "fused.ply"), "rb") as fid: 259 | # line = fid.readline().decode() 260 | # while line: 261 | # if line.startswith("element vertex"): 262 | # num_dense_points = int(line.split()[-1]) 263 | # break 264 | # line = fid.readline().decode() 265 | 266 | # return dict(num_reg_images=num_reg_images, 267 | # 
num_sparse_points=num_sparse_points, 268 | # num_observations=num_observations, 269 | # mean_track_length=mean_track_length, 270 | # num_observations_per_image=num_observations_per_image, 271 | # mean_reproj_error=mean_reproj_error, 272 | # num_dense_points=num_dense_points) 273 | 274 | ## returns without dense results 275 | return dict(num_reg_images=num_reg_images, 276 | num_sparse_points=num_sparse_points, 277 | num_observations=num_observations, 278 | mean_track_length=mean_track_length, 279 | num_observations_per_image=num_observations_per_image, 280 | mean_reproj_error=mean_reproj_error) 281 | 282 | 283 | def main(): 284 | args = parse_args() 285 | with open(args.config, 'r') as f: 286 | configs = yaml.load(f, Loader=yaml.FullLoader) 287 | configs['method_postfix'] = configs['postfix'] 288 | configs['features_path'] = '../../ckpts/{}/desc'.format(configs['output_root']) 289 | configs['dataset_path'] = configs['data_config_extract']['data_path'] 290 | configs['subfolder'] = configs['data_config_extract']['subfolder'] 291 | 292 | paths = types.SimpleNamespace() 293 | paths.colmap_path = Path(configs['colmap_path'])/'colmap' 294 | paths.dataset_path = Path(configs['dataset_path'])/'{}'.format( 295 | configs['subfolder']) 296 | paths.image_path = paths.dataset_path/"images" 297 | 298 | paths.features_path = Path(configs['features_path'])/'{}'.format( 299 | configs['subfolder']) 300 | paths.database_path = paths.features_path.parent/'{}_{}.db'.format( 301 | configs['subfolder'], configs['method_postfix']) 302 | paths.match_list_path = paths.features_path/'image_pairs_{}.txt'.format( 303 | configs['method_postfix']) 304 | paths.result_path = Path(configs['features_path'])/'res_{}_{}.txt'.format( 305 | configs['subfolder'], configs['method_postfix']) 306 | 307 | # print(paths.match_list_path) 308 | if paths.database_path.exists(): 309 | raise FileExistsError('The {} database already exists for method \ 310 | {}.'.format(configs['subfolder'], configs['method_postfix'])) 311 | shutil.copyfile(paths.dataset_path/'database.db', paths.database_path) 312 | 313 | matching_stats = import_features_and_match(configs, paths) 314 | reconstruction_stats = reconstruct(configs, paths) 315 | 316 | print() 317 | print(78 * "=") 318 | print("Raw statistics") 319 | print(78 * "=") 320 | print(matching_stats) 321 | print(reconstruction_stats) 322 | 323 | print() 324 | print(78 * "=") 325 | print("Formatted statistics") 326 | print(78 * "=") 327 | 328 | # strings = "| " + " | ".join( 329 | # map(str, [paths.dataset_path.basename(), 330 | # "METHOD", 331 | # matching_stats["num_images"], 332 | # reconstruction_stats["num_reg_images"], 333 | # reconstruction_stats["num_sparse_points"], 334 | # reconstruction_stats["num_observations"], 335 | # reconstruction_stats["mean_track_length"], 336 | # reconstruction_stats["num_observations_per_image"], 337 | # reconstruction_stats["mean_reproj_error"], 338 | # reconstruction_stats["num_dense_points"], 339 | # "", 340 | # "", 341 | # "", 342 | # "", 343 | # matching_stats["num_inlier_pairs"], 344 | # matching_stats["num_inlier_matches"]])) + " |" 345 | 346 | strings_key = '{}|'.format(paths.dataset_path.basename()) 347 | strings_val = '{}|'.format(paths.dataset_path.basename()) 348 | for key, val in reconstruction_stats.items(): 349 | strings_key += '{}|'.format(key) 350 | tmp_str = '{}'.format(val) 351 | tmp_str = tmp_str.rjust(len(key), ' ') 352 | tmp_str = tmp_str +'|' 353 | strings_val += tmp_str 354 | strings_key += '\n' 355 | strings_val += '\n' 356 | 357 | 
print(strings_key+strings_val) 358 | with open(paths.result_path, 'w') as fid: 359 | fid.write(strings_key+strings_val) 360 | 361 | 362 | if __name__ == "__main__": 363 | main() 364 | -------------------------------------------------------------------------------- /managers/extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import datetime 4 | import shutil 5 | import logging 6 | import yaml 7 | import importlib 8 | import numpy as np 9 | import time 10 | import h5py 11 | from path import Path 12 | from abc import ABC, abstractmethod 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | import torch.distributed as dist 18 | from torch.utils.data.distributed import DistributedSampler 19 | 20 | import networks 21 | import datasets 22 | import losses 23 | import datasets.data_utils as dutils 24 | from losses.preprocess_utils import * 25 | import losses.preprocess_utils as putils 26 | 27 | from tqdm import tqdm 28 | import colorlog 29 | from PIL import Image as Im 30 | 31 | class TqdmHandler(logging.StreamHandler): 32 | def __init__(self): 33 | logging.StreamHandler.__init__(self) 34 | 35 | def emit(self, record): 36 | msg = self.format(record) 37 | tqdm.write(msg) 38 | 39 | 40 | class Extractor(ABC): 41 | def __init__(self, args): 42 | self.args = args 43 | with open(self.args.config, 'r') as f: 44 | self.config = yaml.load(f, Loader=yaml.FullLoader) 45 | # timestamp = datetime.datetime.now().strftime("%m%d-%H%M") 46 | self.save_root = Path('./ckpts/{}'.format(self.config['output_root'])) 47 | self.logfile = self.save_root/'logging_file.txt' 48 | self.desc_root = self.save_root/'desc' 49 | self.img_root = self.save_root/'image' 50 | self.sift_kp = self.config['use_sift'] 51 | 52 | if 'save_npz' in list(self.config.keys()): 53 | self.save_npz = self.config['save_npz'] 54 | else: 55 | self.save_npz = True 56 | 57 | if 'save_h5' in list(self.config.keys()): 58 | self.save_h5 = self.config['save_h5'] 59 | else: 60 | self.save_h5 = False 61 | 62 | ckpt_path = Path(self.config['load_path']) 63 | cfg_path = ckpt_path.dirname()/'config.yaml' 64 | with open(cfg_path, 'r') as f: 65 | pre_conf = yaml.load(f, Loader=yaml.FullLoader) 66 | self.config['model_config'].update(pre_conf['model_config']) 67 | if 'model' in list(pre_conf.keys()): 68 | self.config['model'] = pre_conf['model'] 69 | 70 | self.set_device() 71 | self.set_folder_and_logger() 72 | 73 | ## model 74 | if 'model' in list(self.config.keys()): 75 | tmp_model = getattr(networks, self.config['model']) 76 | self.model = tmp_model(self.config['model_config'], self.device) 77 | else: 78 | self.model = networks.WSFNet(self.config['model_config'], self.device) 79 | if self.multi_gpu: 80 | self.model.set_parallel(self.args.local_rank) 81 | 82 | # self.model.save_checkpoint(self.save_root) 83 | self.model.load_checkpoint(self.config['load_path']) 84 | self.model.set_eval() 85 | 86 | if not self.config['use_sift']: 87 | self.detector = getattr(putils, self.config['detector']) 88 | self.logger.info('use {} to detect keypoints'.format(self.config['detector'])) 89 | else: 90 | self.logger.info('use sift keypoints') 91 | 92 | ## dataloader 93 | dataset = getattr(datasets, self.config['data']) 94 | extract_dataset = dataset(configs=self.config['data_config_extract']) 95 | if self.multi_gpu: 96 | extract_sampler = torch.utils.data.distributed.DistributedSampler(extract_dataset) 97 | else: 98 | extract_sampler = None 99 | self.extract_loader = 
torch.utils.data.DataLoader(extract_dataset, batch_size=self.config['data_config_extract']['batch_size'], 100 | shuffle=False, num_workers=self.config['data_config_extract']['workers'], 101 | collate_fn=self.my_collate, sampler=extract_sampler) 102 | 103 | 104 | def my_collate(self, batch): 105 | ''' Puts each data field into a tensor with outer dimension batch size ''' 106 | batch = list(filter(lambda b: b is not None, batch)) 107 | return torch.utils.data.dataloader.default_collate(batch) 108 | 109 | def set_device(self): 110 | if torch.cuda.device_count() == 0: 111 | self.device = torch.device("cpu") 112 | self.output_flag=True 113 | self.multi_gpu = False 114 | print('use CPU for extraction') 115 | elif torch.cuda.device_count() == 1: 116 | self.device = torch.device("cuda") 117 | self.output_flag=True 118 | self.multi_gpu = False 119 | print('use a single GPU for extraction') 120 | else: 121 | self.device = torch.device("cuda", self.args.local_rank) 122 | self.multi_gpu = True 123 | dist.init_process_group(backend='nccl') 124 | # torch.autograd.set_detect_anomaly(True) # for debug 125 | if self.args.local_rank == 0: 126 | self.output_flag=True 127 | print('use {} GPUs for extraction'.format(torch.cuda.device_count())) 128 | else: 129 | self.output_flag=False 130 | 131 | def set_folder_and_logger(self): 132 | if self.output_flag: 133 | if not os.path.exists(self.save_root) : 134 | self.save_root.makedirs_p() 135 | else: 136 | # if path exsists, quit to make sure that the previous setting.txt would not be overwritten 137 | if self.config['data'] == 'ETH_LFB' or self.config['data'] == 'IMC_eval': 138 | pass 139 | else: 140 | raise "The save path is already exists, please change the output_root in config" 141 | print('=> will save everything to {}'.format(self.save_root)) 142 | # shutil.copy(self.args.config, self.save_root/'config.yaml') 143 | with open(self.save_root/'config.yaml', 'w') as fout: 144 | yaml.dump(self.config, fout) 145 | self.logfile.touch() 146 | 147 | if not os.path.exists(self.desc_root) : 148 | self.desc_root.makedirs_p() 149 | if not os.path.exists(self.img_root) : 150 | self.img_root.makedirs_p() 151 | 152 | while not os.path.exists(self.logfile): 153 | time.sleep(0.5) 154 | continue 155 | 156 | self.logger = logging.getLogger() 157 | if self.output_flag: 158 | self.logger.setLevel(logging.INFO) 159 | fh = logging.FileHandler(self.logfile, mode='a') 160 | fh.setLevel(logging.DEBUG) 161 | 162 | # ch = logging.StreamHandler() 163 | ch = TqdmHandler() 164 | ch.setLevel(logging.INFO) 165 | 166 | formatter = logging.Formatter("%(asctime)s - gpu {} - %(levelname)s: %(message)s".format(self.args.local_rank)) 167 | fh.setFormatter(formatter) 168 | # ch.setFormatter(formatter) 169 | ch.setFormatter(colorlog.ColoredFormatter( 170 | "%(asctime)s - gpu {} - %(levelname)s: %(message)s".format(self.args.local_rank), 171 | log_colors={ 172 | 'DEBUG': 'cyan', 173 | 'INFO': 'white', 174 | 'SUCCESS:': 'green', 175 | 'WARNING': 'yellow', 176 | 'ERROR': 'red', 177 | 'CRITICAL': 'red,bg_white'},)) 178 | 179 | self.logger.addHandler(fh) 180 | self.logger.addHandler(ch) 181 | else: 182 | self.logger.setLevel(logging.ERROR) 183 | fh = logging.FileHandler(self.logfile, mode='a') 184 | fh.setLevel(logging.ERROR) 185 | 186 | ch = logging.StreamHandler() 187 | ch.setLevel(logging.ERROR) 188 | 189 | formatter = logging.Formatter("%(asctime)s - gpu {} - %(levelname)s: %(message)s".format(self.local_rank)) 190 | fh.setFormatter(formatter) 191 | # ch.setFormatter(formatter) 192 | 
ch.setFormatter(colorlog.ColoredFormatter( 193 | "%(asctime)s - gpu {} - %(levelname)s: %(message)s".format(self.args.local_rank), 194 | log_colors={ 195 | 'DEBUG': 'cyan', 196 | 'INFO': 'white', 197 | 'SUCCESS:': 'green', 198 | 'WARNING': 'yellow', 199 | 'ERROR': 'red', 200 | 'CRITICAL': 'red,bg_white'},)) 201 | 202 | self.logger.addHandler(fh) 203 | self.logger.addHandler(ch) 204 | # logger.info('test logger') 205 | 206 | def findthr(self, tensor, thr): 207 | tensor_np = tensor.cpu().numpy().reshape(-1,1) 208 | max_val = np.percentile(tensor_np, thr) 209 | return max_val 210 | 211 | def save_imgs(self, inputs, outputs, processed): 212 | local_point = outputs['local_point'] 213 | message = "\nlocal_min:{:.3f} max:{:.3f} global_min:{:.3f} max:{:.3f}".format( 214 | local_point.min(), local_point.max(), global_point.min(), global_point.max()) 215 | 216 | save_path = self.img_root/inputs['name1'][0] 217 | name = save_path.name.split('.')[0] 218 | save_path = save_path.dirname() 219 | if not save_path.exists(): 220 | save_path.makedirs_p() 221 | 222 | bi, ci, hi, wi = inputs['im1'].shape 223 | bo, co, ho, wo = local_point.shape 224 | if hi != ho or wi != wo: 225 | local_point = F.interpolate(local_point, (hi, wi)) 226 | bi, hi, wi, ci = inputs['im1_ori'].shape 227 | 228 | pad = inputs['pad1'] 229 | if pad[3] != 0: 230 | local_point = local_point[:,:,:-pad[3],:] 231 | if pad[1] != 0: 232 | local_point = local_point[:,:,:,:-pad[1]] 233 | local_point1 = local_point[:,0,:,:] 234 | 235 | local_point1 = local_point1/self.findthr(local_point1, 100*self.config['local_thr']) 236 | local_point1 = local_point1.clamp(0,1) 237 | 238 | local_point1 = dutils.tensor2array(local_point1.squeeze())[:3,:,:].transpose(1,2,0) 239 | local_point1 = Im.fromarray((255*local_point1).astype(np.uint8)) 240 | local_point1.save(save_path/'{:>05d}_score_map.jpg'.format(name)) 241 | 242 | imgs_with_kps = inputs['im1_ori'].squeeze().cpu().numpy().astype(np.uint8) 243 | # imgs_with_kps = cv2.cvtColor(imgs_with_kps, cv2.COLOR_RGB2BGR) 244 | color = (0,255,0) 245 | for kp in processed['kpt']: 246 | kp = (kp[0], kp[1]) 247 | cv2.circle(imgs_with_kps, kp, radius=2, color=color, thickness=-1) 248 | imgs_with_kps = cv2.cvtColor(imgs_with_kps, cv2.COLOR_BGR2RGB) 249 | imgs_with_kps = Im.fromarray(imgs_with_kps) 250 | imgs_with_kps.save(save_path/'{:>05d}_image_with_kp.jpg'.format(name)) 251 | 252 | return message 253 | 254 | def save_desc(self, inputs, outputs, processed): 255 | kpt = processed['kpt'] 256 | feat_f = processed['desc'] 257 | kp_score = processed['kp_score'] 258 | 259 | name = inputs['name1'][0]#.replace('ppm','wsf') 260 | save_path = self.desc_root/name 261 | h5_path = self.desc_root+'h5' 262 | if not save_path.dirname().exists(): 263 | save_path.dirname().makedirs_p() 264 | 265 | message = "\nkpts: {}".format(kpt.shape[0]) 266 | 267 | if self.save_npz: 268 | desc = feat_f.squeeze(0).detach().cpu().numpy() 269 | scores = kp_score.squeeze(0).detach().cpu().numpy() 270 | with open(save_path + '.{}'.format(self.config['postfix']), 'wb') as output_file: 271 | np.savez(output_file, keypoints=kpt, scores=scores, descriptors=desc) 272 | 273 | if self.save_h5: 274 | # now it is only for image-matching-benchmark, so the name is seq/name.jpg 275 | desc = feat_f.squeeze(0).detach().cpu().numpy() #save as nxc 276 | scores = kp_score.squeeze(0).detach().cpu().numpy() 277 | scales = np.ones_like(scores) 278 | h5_name = name.split('.')[0] 279 | h5_seq = h5_name.split('/')[:-1] 280 | h5_seq = '/'.join(h5_seq) 281 | h5_name = 
h5_name.split('/')[-1] 282 | if not os.path.exists(h5_path/h5_seq): 283 | (h5_path/h5_seq).makedirs_p() 284 | with h5py.File(h5_path/h5_seq+'/keypoints.h5', 'a') as fkp, \ 285 | h5py.File(h5_path/h5_seq+'/descriptors.h5', 'a') as fdesc, \ 286 | h5py.File(h5_path/h5_seq+'/scores.h5', 'a') as fsco, \ 287 | h5py.File(h5_path/h5_seq+'/scales.h5', 'a') as fsca: 288 | try: 289 | fkp[h5_name] = kpt 290 | fdesc[h5_name] = desc 291 | fsco[h5_name] = scores 292 | fsca[h5_name] = scales 293 | except OSError as error: 294 | if 'No space left on device' in error.args[0]: 295 | self.logger.error( 296 | 'Out of disk space: storing features on disk can take ' 297 | 'significant space, did you enable the as_half flag?') 298 | del grp, fh5[name] 299 | raise error 300 | # for hloc input 301 | with h5py.File(h5_path/'feat.h5', 'a') as fh5: 302 | try: 303 | grp = fh5.create_group(name) 304 | grp.create_dataset('keypoints', data=kpt) 305 | grp.create_dataset('scores', data=scores) 306 | grp.create_dataset('descriptors', data=desc) 307 | grp.create_dataset('image_size', data=np.array([w,h])) 308 | except OSError as error: 309 | if 'No space left on device' in error.args[0]: 310 | self.logger.error( 311 | 'Out of disk space: storing features on disk can take ' 312 | 'significant space, did you enable the as_half flag?') 313 | del grp, fh5[name] 314 | raise error 315 | 316 | return message 317 | 318 | def process(self, inputs, outputs, remove_pad=False): 319 | desc_f = outputs['local_map'] 320 | name = inputs['name1'][0] 321 | 322 | if remove_pad: 323 | b,c,h,w = inputs['im1_ori'].shape 324 | pad = inputs['pad1'] 325 | desc_f = desc_f[:,:,:-(pad[3]//4),:-(pad[0]//4)] 326 | outputs['local_point'] = outputs['local_point'][:,:,:-(pad[3]//4),:-(pad[0]//4)] 327 | else: 328 | b,c,h,w = inputs['im1'].shape 329 | 330 | if self.sift_kp: 331 | coords = inputs['coord1'] 332 | coord_n = normalize_coords(coords, h, w) 333 | kp_score = torch.ones_like(coord_n)[:,:,:1] 334 | else: 335 | if self.config['data'] == 'Aachen_Day_Night': 336 | cur_name_split = name.split('/') 337 | if cur_name_split[0] == 'query': 338 | coord_n, kp_score = self.detector(outputs['local_point'], **self.config['detector_config_query']) 339 | else: 340 | coord_n, kp_score = self.detector(outputs['local_point'], **self.config['detector_config']) 341 | else: 342 | coord_n, kp_score = self.detector(outputs['local_point'], **self.config['detector_config']) 343 | 344 | coords = denormalize_coords(coord_n, h, w) 345 | 346 | feat_f = sample_feat_by_coord(desc_f, coord_n, self.config['loss_distance']=='cos') 347 | kpt = coords.cpu().numpy().squeeze(0) 348 | 349 | # scale for inloc 350 | if 'scale' in list(inputs.keys()): 351 | kpt = kpt*inputs['scale'].cpu().numpy() 352 | 353 | return {'kpt': kpt, 354 | 'desc': feat_f, 355 | 'kp_score': kp_score} 356 | 357 | @torch.no_grad() 358 | def extract(self): 359 | bar = tqdm(self.extract_loader, total=int(len(self.extract_loader)), ncols=80) 360 | color = np.array(range(256)).astype(np.float)[None,:].repeat(30, axis=0) 361 | color = np.concatenate([np.zeros((30,20)),255*np.ones((30,20)),color], axis=1) 362 | color = dutils.tensor2array(torch.tensor(color))[:3,:,:].transpose(1,2,0) 363 | color = Im.fromarray((255*color).astype(np.uint8)) 364 | color.save(self.img_root/'0_colorbar.jpg') 365 | name_list = '' 366 | for idx, inputs in enumerate(bar): 367 | for key, val in inputs.items(): 368 | if key == 'name1' or key == 'pad1': 369 | continue 370 | inputs[key] = val.to(self.device) 371 | message = inputs['name1'][0] 372 | 
outputs = self.model.extract(inputs['im1']) 373 | processed = self.process(inputs, outputs) 374 | if self.config['output_desc']: 375 | message += self.save_desc(inputs, outputs, processed) 376 | if self.config['output_img']: 377 | message += self.save_imgs(inputs, outputs, processed, idx) 378 | self.logger.info(message) 379 | name_list += '{} {}\n'.format(idx, inputs['name1'][0]) 380 | torch.cuda.empty_cache() 381 | with open(self.img_root/'name_list.txt', 'w') as f: 382 | f.write(name_list) 383 | -------------------------------------------------------------------------------- /managers/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import shutil 4 | import logging 5 | import yaml 6 | import importlib 7 | import time 8 | from path import Path 9 | from abc import ABC, abstractmethod 10 | from PIL import Image as Im 11 | import numpy as np 12 | import torch.nn.functional as F 13 | 14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | import torch.distributed as dist 18 | from torch.utils.data.distributed import DistributedSampler 19 | 20 | import networks 21 | import datasets 22 | import losses 23 | import datasets.data_utils as dutils 24 | import losses.preprocess_utils as putils 25 | 26 | from tqdm import tqdm 27 | import cv2 28 | import copy 29 | import matplotlib 30 | import matplotlib.pyplot as plt 31 | 32 | class TqdmHandler(logging.StreamHandler): 33 | def __init__(self): 34 | logging.StreamHandler.__init__(self) 35 | 36 | def emit(self, record): 37 | msg = self.format(record) 38 | tqdm.write(msg) 39 | 40 | 41 | class Trainer(ABC): 42 | def __init__(self, args): 43 | ## read the config file 44 | ## 读取配置文件 45 | self.args = args 46 | with open(self.args.config, 'r') as f: 47 | self.config = yaml.load(f, Loader=yaml.FullLoader) 48 | self.save_root = Path('./ckpts/{}'.format(self.config['checkpoint_name'])) 49 | self.logfile = self.save_root/'logging_file.txt' 50 | 51 | ## update the model config if there is a checkpoint 52 | ## 如果存在checkpoint,则根据checkpoint中的模型配置来更新配置文件,确保参数正确载入 53 | ckpt_path = None 54 | if 'load_path' in list(self.config.keys()): 55 | if self.config['load_path'] is not None: 56 | ckpt_path = Path(self.config['load_path']) 57 | cfg_path = ckpt_path.dirname()/'config.yaml' 58 | with open(cfg_path, 'r') as f: 59 | pre_conf = yaml.load(f, Loader=yaml.FullLoader) 60 | self.config['model_config'].update(pre_conf['model_config']) 61 | 62 | if 'model' in list(pre_conf.keys()): 63 | self.config['model'] = pre_conf['model'] 64 | 65 | ## set training device, and now the multi-GPU training is slow (unknown reason) 66 | ## 设置训练设备(CPU/GPU/multi GPU),目前的多GPU训练速度较慢(原因未知) 67 | self.set_device() 68 | ## set logger, create folder and save config file into the folder 69 | ## 设置logger,创建训练文件夹并存储配置文件 70 | self.set_folder_and_logger() 71 | 72 | ## model 73 | if 'model' in list(self.config.keys()): 74 | tmp_model = getattr(networks, self.config['model']) 75 | self.model = tmp_model(self.config['model_config'], self.device, self.config['no_cuda']) 76 | else: 77 | self.model = networks.PoSFeat(self.config['model_config'], self.device, self.config['no_cuda']) 78 | parameters = [] 79 | for module_name, module_lr in zip(self.config['optimal_modules'], self.config['optimal_lrs']): 80 | tmp_module = getattr(self.model, module_name) 81 | parameters.append({'params':tmp_module.parameters(), 'lr':module_lr}) 82 | self.all_optimized_modules = self.config['optimal_modules'] 83 | for module_name 
in self.model.modules: 84 | if module_name not in self.all_optimized_modules: 85 | tmp_module = getattr(self.model, module_name) 86 | for p in tmp_module.parameters(): 87 | p.requires_grad = False 88 | if ckpt_path is not None: 89 | self.logger.info('load checkpoint from {}'.format(ckpt_path)) 90 | self.model.load_checkpoint(ckpt_path) 91 | if self.multi_gpu: 92 | self.model.set_parallel(self.args.local_rank) 93 | 94 | ## losses 95 | if 'preprocess_train' in list(self.config.keys()): 96 | tmp_model = getattr(losses, self.config['preprocess_train']) 97 | self.preprocess = tmp_model(self.config['preprocess_train_config'], self.device).to(self.device) 98 | self.skip_preprocess = False 99 | else: 100 | self.preprocess = losses.Preprocess_Skip().to(self.device) 101 | self.skip_preprocess = True 102 | 103 | self.losses = [] 104 | self.losses_weight = [] 105 | for loss_name, loss_weight in zip(self.config['losses'], self.config['losses_weight']): 106 | loss_module = getattr(losses, loss_name) 107 | self.losses.append(loss_module(self.config['{}_config'.format(loss_name)], self.device).to(self.device)) 108 | self.losses_weight.append(float(loss_weight)) 109 | if hasattr(self.losses[-1], 'load_checkpoint'): 110 | if ckpt_path is not None: 111 | self.losses[-1].load_checkpoint(ckpt_path) 112 | # parameters += list(self.losses[-1].parameters()) 113 | parameters.append({'params':self.losses[-1].parameters()}) 114 | 115 | ## optimizer 116 | self.logger.info(parameters) 117 | self.logger.info(self.all_optimized_modules) 118 | tmp_optimizer = getattr(torch.optim, self.config['optimizer']) 119 | self.optimizer = tmp_optimizer(parameters) 120 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 121 | step_size=self.config['lr_decay_step'], 122 | gamma=self.config['lr_decay_factor']) 123 | self.logger.info(self.config['optimizer']) 124 | 125 | ## dataloader 126 | dataset = getattr(datasets, self.config['data']) 127 | train_dataset = dataset(configs=self.config['data_config_train'], is_train=True) 128 | if self.multi_gpu: 129 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 130 | else: 131 | train_sampler = None 132 | self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config['data_config_train']['batch_size'], 133 | shuffle= ~self.multi_gpu, num_workers=self.config['data_config_train']['workers'], 134 | collate_fn=self.my_collate, sampler=train_sampler) 135 | 136 | val_dataset = dataset(configs=self.config['val_config']['data_config_val'], is_train=False) 137 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.config['val_config']['data_config_val']['batch_size'], 138 | shuffle= self.config['val_config']['data_config_val']['shuffle'], 139 | num_workers=self.config['val_config']['data_config_val']['workers'], 140 | collate_fn=self.my_collate) 141 | val_iter = iter(putils.cycle(val_loader)) 142 | self.val_data = next(val_iter) 143 | del val_dataset, val_loader, val_iter 144 | with open(self.save_root/'val_data.npz', 'wb') as out_f: 145 | np.savez(out_f, val_data=self.val_data) 146 | 147 | def my_collate(self, batch): 148 | ''' Puts each data field into a tensor with outer dimension batch size ''' 149 | batch = list(filter(lambda b: b is not None, batch)) 150 | return torch.utils.data.dataloader.default_collate(batch) 151 | 152 | def set_device(self): 153 | if torch.cuda.device_count() == 0: 154 | self.device = torch.device("cpu") 155 | self.output_flag=True 156 | self.multi_gpu = False 157 | print('use CPU for 
training') 158 | elif torch.cuda.device_count() == 1: 159 | self.device = torch.device("cuda") 160 | self.output_flag=True 161 | self.multi_gpu = False 162 | self.args.local_rank = 0 163 | print('use a single GPU for training') 164 | else: 165 | self.device = torch.device("cuda", self.args.local_rank) 166 | self.multi_gpu = True 167 | dist.init_process_group(backend='nccl') 168 | # torch.autograd.set_detect_anomaly(True) # for debug 169 | if self.args.local_rank == 0: 170 | self.output_flag=True 171 | print('use {} GPUs for training'.format(torch.cuda.device_count())) 172 | else: 173 | self.output_flag=False 174 | 175 | def set_folder_and_logger(self): 176 | if self.output_flag: 177 | if not os.path.exists(self.save_root) : 178 | self.save_root.makedirs_p() 179 | else: 180 | # if path exsists, quit to make sure that the previous setting.txt would not be overwritten 181 | # 如果路径已存在,退出训练保证之间的配置文件不会被覆盖 182 | raise "The save path is already exists, please update the folder name" 183 | print('=> will save everything to {}'.format(self.save_root)) 184 | with open(self.save_root/'config.yaml', 'w') as fout: 185 | yaml.dump(self.config, fout) 186 | self.logfile.touch() 187 | 188 | self.writer = SummaryWriter(self.save_root) 189 | 190 | self.logger = logging.getLogger() 191 | 192 | # color settings 193 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) 194 | RESET_SEQ = "\033[0m" 195 | COLOR_SEQ = "\033[1;%dm" 196 | BOLD_SEQ = "\033[1m" 197 | 198 | def formatter_message(message, use_color = True): 199 | if use_color: 200 | message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ) 201 | else: 202 | message = message.replace("$RESET", "").replace("$BOLD", "") 203 | return message 204 | 205 | COLORS = { 206 | 'WARNING': YELLOW, 207 | 'INFO': CYAN, 208 | 'DEBUG': BLUE, 209 | 'CRITICAL': YELLOW, 210 | 'ERROR': RED 211 | } 212 | 213 | class ColoredFormatter(logging.Formatter): 214 | def __init__(self, msg, use_color = True): 215 | logging.Formatter.__init__(self, msg) 216 | self.use_color = use_color 217 | 218 | def format(self, record): 219 | levelname = record.levelname 220 | if self.use_color and levelname in COLORS: 221 | levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ 222 | record.levelname = levelname_color 223 | return logging.Formatter.format(self, record) 224 | 225 | msg = "%(asctime)s-gpu {}-%(levelname)s: %(message)s".format(self.args.local_rank) 226 | formatter = logging.Formatter(msg) 227 | color_formatter = ColoredFormatter(formatter_message(msg, True)) 228 | 229 | if self.output_flag: 230 | self.logger.setLevel(logging.INFO) 231 | fh = logging.FileHandler(self.logfile, mode='a') 232 | fh.setLevel(logging.DEBUG) 233 | fh.setFormatter(formatter) 234 | 235 | ch = TqdmHandler() 236 | ch.setLevel(logging.INFO) 237 | ch.setFormatter(color_formatter) 238 | else: 239 | self.logger.setLevel(logging.ERROR) 240 | fh = logging.FileHandler(self.logfile, mode='a') 241 | fh.setLevel(logging.ERROR) 242 | fh.setFormatter(formatter) 243 | 244 | ch = TqdmHandler() 245 | ch.setLevel(logging.ERROR) 246 | ch.setFormatter(color_formatter) 247 | 248 | self.logger.addHandler(fh) 249 | self.logger.addHandler(ch) 250 | 251 | def save_errors(self, inputs, outputs, losses, loss_items): 252 | if not os.path.exists(self.save_root/"error.pt"): 253 | save_dict = {"inputs":inputs, "outputs": outputs, 254 | "losses":losses, "loss_items":loss_items} 255 | torch.save(save_dict, self.save_root/"error.pt") 256 | 257 | def save_loss(self, save_path): 258 | save_path 
= Path(save_path) 259 | for idx in range(len(self.config['losses'])): 260 | if hasattr(self.losses[idx], 'save_checkpoint'): 261 | self.losses[idx].save_checkpoint(save_path) 262 | 263 | def train(self): 264 | batch_size_val = self.val_data['im1'].shape[0] 265 | epoch_path = self.save_root/'{:>03d}'.format(0) 266 | epoch_path.makedirs_p() 267 | self.model.save_checkpoint(epoch_path) 268 | self.save_loss(epoch_path) 269 | 270 | for epoch in range(self.config['epoch']): 271 | epoch += 1 272 | epoch_path = self.save_root/'{:>03d}'.format(epoch) 273 | epoch_path.makedirs_p() 274 | batch_path_list = [] 275 | for i in range(batch_size_val): 276 | batch_path = epoch_path/'{}'.format(i) 277 | batch_path.makedirs_p() 278 | batch_path_list.append(batch_path) 279 | if self.config['epoch_step'] > 0: 280 | total_steps = self.config['epoch_step'] 281 | else: 282 | total_steps = len(self.train_loader) 283 | bar = tqdm(self.train_loader, total=int(total_steps), ncols=80) 284 | bar.set_description('{}/{} {}/{}'.format(self.config['checkpoint_name'], self.save_root.name, epoch, self.config['epoch'])) 285 | self.model.set_train() 286 | for idx, inputs in enumerate(bar): 287 | # val and vis 288 | self.model.set_eval() 289 | if self.output_flag and idx % self.config['log_freq'] == 0: 290 | self.val_and_vis(batch_path_list, idx) 291 | torch.cuda.empty_cache() 292 | # train 293 | self.model.set_eval() 294 | for module in self.config['optimal_modules']: 295 | tmp_module = getattr(self.model, module) 296 | tmp_module.train() 297 | outputs = self.model.forward(inputs) 298 | outputs['epoch'] = epoch 299 | outputs['iterations'] = int((epoch-1)*total_steps+idx) 300 | processed = self.preprocess(inputs, outputs) 301 | if self.skip_preprocess: 302 | message = "epoch {} batch {}".format(epoch, idx) 303 | else: 304 | message = "epoch {} batch {} temperature {}".format(epoch, idx, processed['temperature']) 305 | total_loss = 0 306 | loss_items = [] 307 | temp_log = {} 308 | for loss_name, loss_module, loss_weight in zip(self.config['losses'], self.losses, self.losses_weight): 309 | tmp_loss, tmp_items = loss_module(inputs, outputs, processed) 310 | total_loss += loss_weight*tmp_loss.mean() 311 | temp_log[loss_name] = tmp_loss.detach().mean().item() 312 | message += "\n {}:{:.5f}[{:.2f}] (total: {:.5f} ".format(loss_name, loss_weight*tmp_loss.detach().mean().item(), loss_weight, 313 | tmp_loss.detach().mean().item()) 314 | for key, val in tmp_items.items(): 315 | message += "{}[{:.5f}] ".format(key, val.detach().mean().item()) 316 | message += ")" 317 | loss_items.append(tmp_items) 318 | message += '\n' 319 | 320 | # if the loss is nan, skip this batch 321 | # 如果loss是nan,则跳过当前batch 322 | if total_loss.isnan(): 323 | self.logger.info(message) 324 | self.logger.error("loss is nan in {}, check the error.pt".format(idx)) 325 | self.save_errors(inputs, outputs, total_loss, loss_items) 326 | total_loss.backward() 327 | self.optimizer.zero_grad() 328 | continue 329 | 330 | self.optimizer.zero_grad() 331 | total_loss.backward() 332 | 333 | if 'localheader' in self.all_optimized_modules: 334 | grad_message = 'grad localheader conv1 mean {:.6f} max{:.6f}'.format(self.model.localheader.conv1.weight.grad.mean().item(), 335 | self.model.localheader.conv1.weight.grad.max().item()) 336 | self.logger.info(grad_message) 337 | if 'backbone' in self.all_optimized_modules: 338 | grad_message = 'grad backbone conv_fine mean {:.6f} max{:.6f}'.format(self.model.backbone.conv_fine.conv.weight.grad.mean().item(), 339 | 
self.model.backbone.conv_fine.conv.weight.grad.max().item()) 340 | self.logger.info(grad_message) 341 | grad_message = 'grad backbone firstconv mean {:.6f} max{:.6f}'.format(self.model.backbone.firstconv.weight.grad.mean().item(), 342 | self.model.backbone.firstconv.weight.grad.max().item()) 343 | self.logger.info(grad_message) 344 | if self.config['grad_clip']: 345 | for module_name in self.all_optimized_modules: 346 | tmp_module = getattr(self.model, module_name) 347 | torch.nn.utils.clip_grad_norm_(tmp_module.parameters(), self.config['clip_norm']) 348 | if 'localheader' in self.all_optimized_modules: 349 | grad_message = 'grad clipped localheader conv1 mean {:.6f} max{:.6f}'.format(self.model.localheader.conv1.weight.grad.mean().item(), 350 | self.model.localheader.conv1.weight.grad.max().item()) 351 | self.logger.info(grad_message) 352 | if 'backbone' in self.all_optimized_modules: 353 | grad_message = 'grad clipped backbone firstconv mean {:.6f} max{:.6f}'.format(self.model.backbone.firstconv.weight.grad.mean().item(), 354 | self.model.backbone.firstconv.weight.grad.max().item()) 355 | self.logger.info(grad_message) 356 | self.optimizer.step() 357 | 358 | self.logger.info(message) 359 | if self.output_flag and idx%self.config['log_freq'] == 0: 360 | self.writer.add_scalar('losses', total_loss.item(), int((epoch-1)*total_steps+idx)) 361 | for loss_name in self.config['losses']: 362 | self.writer.add_scalar(loss_name, temp_log[loss_name], int((epoch-1)*total_steps+idx)) 363 | for components in loss_items: 364 | for component_name in list(components.keys()): 365 | if component_name in self.config['tb_component']: 366 | self.writer.add_scalar(component_name, components[component_name], int((epoch-1)*total_steps+idx)) 367 | if self.output_flag and idx%100 == 0: 368 | self.model.save_checkpoint(epoch_path) 369 | 370 | torch.cuda.empty_cache() 371 | 372 | if idx>=self.config['epoch_step']: 373 | break 374 | 375 | 376 | self.model.save_checkpoint(epoch_path) 377 | self.save_loss(epoch_path) 378 | self.scheduler.step() 379 | 380 | @torch.no_grad() 381 | def val_and_vis(self, batch_path_list, idx): 382 | val_config = self.config['val_config'] 383 | self.model.set_eval() 384 | outputs = self.model.forward(self.val_data) 385 | mid_pad = 20 386 | 387 | preds1 = outputs['preds1'] 388 | preds2 = outputs['preds2'] 389 | 390 | b,c,h,w = self.val_data['im1'].shape 391 | 392 | all_images = ['0_original_images', '1_score_maps', '2_all_keypoints', 393 | '3_matched_keypoints', '4_matches_less', '5_matches_all'] 394 | 395 | # if val_config['detector'] == 'sift': 396 | # coord1 = self.val_data['coord1'] 397 | # coord2 = self.val_data['coord2'] 398 | 399 | # coord1_n = putils.normalize_coords(coord1, h, w) 400 | # coord2_n = putils.normalize_coords(coord2, h, w) 401 | # else: 402 | # detector = getattr(putils, val_config['detector']) 403 | # coord1_n = detector(preds1['local_point'], **val_config['detector_config']) 404 | # coord2_n = detector(preds2['local_point'], **val_config['detector_config']) 405 | 406 | # coord1 = putils.denormalize_coords(coord1_n, h, w) 407 | # coord2 = putils.denormalize_coords(coord2_n, h, w) 408 | 409 | # desc1 = putils.sample_feat_by_coord(preds1['local_map'], coord1_n, val_config['loss_distance']=='cos') 410 | # desc2 = putils.sample_feat_by_coord(preds2['local_map'], coord1_n, val_config['loss_distance']=='cos') 411 | 412 | for i, cur_path in enumerate(batch_path_list): 413 | for image_folder in all_images: 414 | tmp_path = cur_path/image_folder 415 | if not tmp_path.exists(): 
416 |                     tmp_path.makedirs_p()
417 |             cur_img1 = self.val_data['im1_ori'][i,...]
418 |             cur_img2 = self.val_data['im2_ori'][i,...]
419 |             cur_F12 = self.val_data['F1'][i,...]
420 |             score_map1 = preds1['local_point'][i,...]
421 |             score_map2 = preds2['local_point'][i,...]
422 |             comb_img = torch.cat((cur_img1, torch.zeros_like(cur_img1)[:,:mid_pad,:], cur_img2), dim=1)
423 |             comb_score = torch.cat((score_map1, torch.zeros_like(score_map1)[:,:,:mid_pad], score_map2), dim=2)
424 | 
425 |             if val_config['detector'] == 'sift':
426 |                 cur_kps1 = self.val_data['coord1'][i,:,:2]
427 |                 cur_kps2 = self.val_data['coord2'][i,:,:2]
428 |                 cur_score1 = torch.ones_like(cur_kps1)[...,0:1]
429 |                 cur_score2 = torch.ones_like(cur_kps2)[...,0:1]
430 | 
431 |                 cur_kps1_n = putils.normalize_coords(cur_kps1, h, w).unsqueeze(0)
432 |                 cur_kps2_n = putils.normalize_coords(cur_kps2, h, w).unsqueeze(0)
433 |             else:
434 |                 detector = getattr(putils, val_config['detector'])
435 |                 cur_kps1_n, cur_score1 = detector(preds1['local_point'][i:i+1,...],
436 |                     **val_config['detector_config'])
437 |                 cur_kps2_n, cur_score2 = detector(preds2['local_point'][i:i+1,...],
438 |                     **val_config['detector_config'])
439 | 
440 |                 cur_kps1 = putils.denormalize_coords(cur_kps1_n, h, w).squeeze(0)
441 |                 cur_kps2 = putils.denormalize_coords(cur_kps2_n, h, w).squeeze(0)
442 |                 cur_score1 = cur_score1.squeeze(0)
443 |                 cur_score2 = cur_score2.squeeze(0)
444 | 
445 | 
446 |             cur_desc1 = putils.sample_feat_by_coord(preds1['local_map'][i:i+1,...],
447 |                 cur_kps1_n, val_config['loss_distance']=='cos').squeeze(0)
448 |             cur_desc2 = putils.sample_feat_by_coord(preds2['local_map'][i:i+1,...],
449 |                 cur_kps2_n, val_config['loss_distance']=='cos').squeeze(0)
450 | 
451 |             cur_matches = putils.mnn_matcher(cur_desc1, cur_desc2)
452 |             cur_matchkp1 = cur_kps1[cur_matches[:,0],:2]
453 |             cur_matchkp2 = cur_kps2[cur_matches[:,1],:2]
454 |             cur_kpscore_m1 = cur_score1[cur_matches[:,0],:1]
455 |             cur_kpscore_m2 = cur_score2[cur_matches[:,1],:1]
456 |             cur_kpscore = cur_kpscore_m1 + cur_kpscore_m2.to(cur_score1)
457 |             # cur_kpscore = cur_kpscore_m1 + cur_kpscore_m2
458 |             _, topk_idx = cur_kpscore.topk(min(val_config['vis_topk'], cur_kpscore.shape[0]), dim=0)
459 | 
460 |             cur_matchkp1_h = putils.homogenize(cur_matchkp1).transpose(0, 1)
461 |             cur_matchkp2_h = putils.homogenize(cur_matchkp2).transpose(0, 1)
462 |             cur_epi_line1 = cur_F12@cur_matchkp1_h
463 |             cur_epi_line1 = cur_epi_line1 / torch.clamp(
464 |                 torch.norm(cur_epi_line1[:2, :], dim=0, keepdim=True), min=1e-8)
465 |             epi_dist = torch.abs(torch.sum(cur_matchkp2_h * cur_epi_line1, dim=0)).unsqueeze(1)
466 |             epi_dist = epi_dist.clamp(min=0, max=val_config['vis_err_thr']).repeat(1,2)
467 | 
468 |             match_color = dutils.tensor2array(val_config['vis_err_thr'] - epi_dist,
469 |                 max_value=val_config['vis_err_thr'], colormap='RdYlGn')[:3,:,:1].transpose(1,2,0)
470 |             match_color = (255*match_color).astype(np.uint8)
471 |             match_color = cv2.cvtColor(match_color, cv2.COLOR_RGB2BGR).squeeze(1)
472 | 
473 |             cur_matchkp1_less = cur_matchkp1[topk_idx, :2]
474 |             cur_matchkp2_less = cur_matchkp2[topk_idx, :2]
475 |             match_color_less = match_color[topk_idx.cpu().numpy()[:,0], :3]
476 | 
477 |             cur_kps1 = list(map(tuple,cur_kps1.reshape(-1, 2).cpu().numpy()))
478 |             cur_kps2 = list(map(tuple,cur_kps2.reshape(-1, 2).cpu().numpy()))
479 |             cur_matchkp1 = list(map(tuple,cur_matchkp1.reshape(-1, 2).cpu().numpy()))
480 |             cur_matchkp2 = list(map(tuple,cur_matchkp2.reshape(-1, 2).cpu().numpy()))
481 |             match_color = list(map(tuple,match_color))
482 |             cur_matchkp1_less = list(map(tuple,cur_matchkp1_less.reshape(-1, 2).cpu().numpy()))
483 |             cur_matchkp2_less = list(map(tuple,cur_matchkp2_less.reshape(-1, 2).cpu().numpy()))
484 |             match_color_less = list(map(tuple,match_color_less))
485 | 
486 |             comb_img = comb_img.cpu().numpy()
487 |             save_img = comb_img
488 |             save_img = Im.fromarray(save_img.astype(np.uint8))
489 |             save_img.save(cur_path/'0_original_images/{}.jpg'.format(idx))
490 | 
491 |             comb_score = dutils.tensor2array(comb_score.squeeze())[:3,:,:].transpose(1,2,0)
492 |             save_img = 255*comb_score
493 |             save_img = Im.fromarray(save_img.astype(np.uint8))
494 |             save_img.save(cur_path/'1_score_maps/{}.jpg'.format(idx))
495 | 
496 |             comb_img_kps = cv2.cvtColor(comb_img,cv2.COLOR_RGB2BGR)
497 |             color = (0,255,0)
498 |             for kp1 in cur_kps1:
499 |                 cv2.circle(comb_img_kps, kp1, radius=2, color=color, thickness=-1)
500 |             for kp2 in cur_kps2:
501 |                 kp2_comb = (int(kp2[0]+w+mid_pad), int(kp2[1]))
502 |                 cv2.circle(comb_img_kps, kp2_comb, radius=2, color=color, thickness=-1)
503 |             comb_img_kps = cv2.cvtColor(comb_img_kps, cv2.COLOR_BGR2RGB)
504 |             save_img = comb_img_kps
505 |             save_img = Im.fromarray(save_img.astype(np.uint8))
506 |             save_img.save(cur_path/'2_all_keypoints/{}.jpg'.format(idx))
507 | 
508 |             comb_img_kps_m = cv2.cvtColor(comb_img,cv2.COLOR_RGB2BGR)
509 |             color = (0,255,0)
510 |             for kp1, kp2 in zip(cur_matchkp1, cur_matchkp2):
511 |                 cv2.circle(comb_img_kps_m, kp1, radius=2, color=color, thickness=-1)
512 |                 # kp2_comb = kp2 + torch.tensor([w, 0]).reshape(1,2).to(kp2)
513 |                 kp2_comb = (int(kp2[0]+w+mid_pad), int(kp2[1]))
514 |                 cv2.circle(comb_img_kps_m, kp2_comb, radius=2, color=color, thickness=-1)
515 |             comb_img_kps_m = cv2.cvtColor(comb_img_kps_m, cv2.COLOR_BGR2RGB)
516 |             save_img = comb_img_kps_m
517 |             save_img = Im.fromarray(save_img.astype(np.uint8))
518 |             save_img.save(cur_path/'3_matched_keypoints/{}.jpg'.format(idx))
519 | 
520 |             comb_img_m_less = cv2.cvtColor(comb_img,cv2.COLOR_RGB2BGR)
521 |             for kp1, kp2, color in zip(cur_matchkp1_less, cur_matchkp2_less, match_color_less):
522 |                 # kp2_comb = kp2 + torch.tensor([w, 0]).reshape(1,2).to(kp2)
523 |                 kp2_comb = (int(kp2[0]+w+mid_pad), int(kp2[1]))
524 |                 color = (int(color[0]), int(color[1]), int(color[2]))
525 |                 cv2.line(comb_img_m_less, kp1, kp2_comb, color, thickness=2)
526 |                 cv2.circle(comb_img_m_less, kp1, radius=2, color=(0,255,0), thickness=-1)
527 |                 cv2.circle(comb_img_m_less, kp2_comb, radius=2, color=(0,255,0), thickness=-1)
528 |             comb_img_m_less = cv2.cvtColor(comb_img_m_less, cv2.COLOR_BGR2RGB)
529 |             save_img = comb_img_m_less
530 |             save_img = Im.fromarray(save_img.astype(np.uint8))
531 |             save_img.save(cur_path/'4_matches_less/{}.jpg'.format(idx))
532 | 
533 |             comb_img_m = cv2.cvtColor(comb_img,cv2.COLOR_RGB2BGR)
534 |             for kp1, kp2, color in zip(cur_matchkp1, cur_matchkp2, match_color):
535 |                 # kp2_comb = kp2 + torch.tensor([w, 0]).reshape(1,2).to(kp2)
536 |                 kp2_comb = (int(kp2[0]+w+mid_pad), int(kp2[1]))
537 |                 color = (int(color[0]), int(color[1]), int(color[2]))
538 |                 cv2.line(comb_img_m, kp1, kp2_comb, color, thickness=2)
539 |                 cv2.circle(comb_img_m, kp1, radius=2, color=(0,255,0), thickness=-1)
540 |                 cv2.circle(comb_img_m, kp2_comb, radius=2, color=(0,255,0), thickness=-1)
541 |             comb_img_m = cv2.cvtColor(comb_img_m, cv2.COLOR_BGR2RGB)
542 |             save_img = comb_img_m
543 |             save_img = Im.fromarray(save_img.astype(np.uint8))
544 |             save_img.save(cur_path/'5_matches_all/{}.jpg'.format(idx))
--------------------------------------------------------------------------------
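Note on lines 460-466 of managers/trainer.py above: matches are coloured by their point-to-epipolar-line distance. The following is a minimal, self-contained sketch of that check; it is not part of the repository, and the function name and the 3-pixel threshold in the usage comment are illustrative assumptions (the repository uses its own putils.homogenize helper and val_config['vis_err_thr']).

# Standalone sketch (not repository code): given matched keypoints kp1, kp2
# (N x 2 pixel coordinates) and a fundamental matrix F12 mapping image-1 points
# to epipolar lines in image 2, compute the distance of each kp2 to its line.
import torch

def epipolar_distance(kp1: torch.Tensor, kp2: torch.Tensor, F12: torch.Tensor) -> torch.Tensor:
    # homogenize: (N, 2) -> (3, N)
    kp1_h = torch.cat([kp1, torch.ones_like(kp1[:, :1])], dim=1).transpose(0, 1)
    kp2_h = torch.cat([kp2, torch.ones_like(kp2[:, :1])], dim=1).transpose(0, 1)
    # epipolar lines in image 2: l = F12 @ x1, scaled so (a, b) has unit norm
    line2 = F12 @ kp1_h
    line2 = line2 / torch.clamp(torch.norm(line2[:2, :], dim=0, keepdim=True), min=1e-8)
    # with (a, b) unit-norm, |a*x + b*y + c| is the point-to-line distance in pixels
    return torch.abs(torch.sum(kp2_h * line2, dim=0))

# Usage (hypothetical threshold): small distances indicate matches consistent
# with the epipolar geometry, which trainer.py maps onto a green-to-red colormap.
# dist = epipolar_distance(kp1, kp2, F12)   # shape (N,)
# good = dist < 3.0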